fix vector python embedding
This commit is contained in:
@@ -1,13 +1,26 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
# Keep HuggingFace/SentenceTransformer model loading deterministic in CLI jobs.
|
||||
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
|
||||
os.environ.setdefault("HF_HUB_ETAG_TIMEOUT", "10")
|
||||
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "30")
|
||||
|
||||
MODEL_LOAD_TIMEOUT_SECONDS = int(os.environ.get("RETRIEX_EMBEDDING_MODEL_LOAD_TIMEOUT_SECONDS", "60"))
|
||||
|
||||
|
||||
def log_event(event: str, **payload: Any) -> None:
|
||||
print(json.dumps({"event": event, **payload}, ensure_ascii=False), file=sys.stderr, flush=True)
|
||||
|
||||
|
||||
def fail(message: str, code: int) -> None:
|
||||
print(f"ERROR: {message}", file=sys.stderr)
|
||||
print(f"ERROR: {message}", file=sys.stderr, flush=True)
|
||||
sys.exit(code)
|
||||
|
||||
|
||||
@@ -40,27 +53,6 @@ except Exception:
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Load embedding model from index_meta.json (Single Source of Truth)
|
||||
# ---------------------------------------------------------
|
||||
BASE_PATH = Path(__file__).resolve().parents[2]
|
||||
INDEX_META_PATH = BASE_PATH / "var" / "knowledge" / "index_meta.json"
|
||||
|
||||
if not INDEX_META_PATH.exists():
|
||||
fail("index_meta.json not found", 30)
|
||||
|
||||
try:
|
||||
meta = json.loads(INDEX_META_PATH.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
fail("index_meta.json is invalid", 30)
|
||||
|
||||
embedding_model = meta.get("embedding_model")
|
||||
if not isinstance(embedding_model, str) or embedding_model.strip() == "":
|
||||
fail("embedding_model missing in index_meta.json", 31)
|
||||
|
||||
model = SentenceTransformer(embedding_model.strip())
|
||||
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# File checks
|
||||
# ---------------------------------------------------------
|
||||
@@ -90,6 +82,61 @@ def normalize_text(value: Any) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def resolve_embedding_model_from_meta() -> str:
|
||||
# Local model path wins. This avoids implicit network/cache lookup in production.
|
||||
override = os.environ.get("RETRIEX_EMBEDDING_MODEL_PATH", "").strip()
|
||||
if override:
|
||||
return override
|
||||
|
||||
base_path = Path(__file__).resolve().parents[2]
|
||||
index_meta_path = base_path / "var" / "knowledge" / "index_meta.json"
|
||||
|
||||
if not index_meta_path.exists():
|
||||
fail("index_meta.json not found", 30)
|
||||
|
||||
try:
|
||||
meta = json.loads(index_meta_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
fail("index_meta.json is invalid", 30)
|
||||
|
||||
embedding_model = meta.get("embedding_model")
|
||||
if not isinstance(embedding_model, str) or embedding_model.strip() == "":
|
||||
fail("embedding_model missing in index_meta.json", 31)
|
||||
|
||||
return embedding_model.strip()
|
||||
|
||||
|
||||
def load_sentence_transformer(model_name_or_path: str) -> SentenceTransformer:
|
||||
def timeout_handler(_signum: int, _frame: Any) -> None:
|
||||
raise TimeoutError(
|
||||
"Embedding model load timed out. "
|
||||
"Cache the model locally or set RETRIEX_EMBEDDING_MODEL_PATH."
|
||||
)
|
||||
|
||||
log_event(
|
||||
"tag_embedding_model_load_start",
|
||||
model=model_name_or_path,
|
||||
timeout_seconds=MODEL_LOAD_TIMEOUT_SECONDS,
|
||||
hf_hub_disable_xet=os.environ.get("HF_HUB_DISABLE_XET"),
|
||||
)
|
||||
|
||||
previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(MODEL_LOAD_TIMEOUT_SECONDS)
|
||||
|
||||
try:
|
||||
loaded_model = SentenceTransformer(model_name_or_path)
|
||||
except TimeoutError as exc:
|
||||
fail(str(exc), 32)
|
||||
except Exception as exc:
|
||||
fail(f"Unable to load embedding model '{model_name_or_path}': {exc}", 33)
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
signal.signal(signal.SIGALRM, previous_handler)
|
||||
|
||||
log_event("tag_embedding_model_load_done", model=model_name_or_path)
|
||||
return loaded_model
|
||||
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Streaming read NDJSON
|
||||
# ---------------------------------------------------------
|
||||
@@ -146,33 +193,34 @@ def load_rows(path: Path) -> Tuple[List[str], List[str], Dict[str, int]]:
|
||||
|
||||
|
||||
texts, ids, stats = load_rows(tags_path)
|
||||
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"event": "tag_rows_loaded",
|
||||
**stats,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
file=sys.stderr,
|
||||
)
|
||||
log_event("tag_rows_loaded", **stats)
|
||||
|
||||
if not texts:
|
||||
cleanup_outputs()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Load model only after we know that usable tags exist
|
||||
# ---------------------------------------------------------
|
||||
embedding_model = resolve_embedding_model_from_meta()
|
||||
model = load_sentence_transformer(embedding_model)
|
||||
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Build embeddings
|
||||
# ---------------------------------------------------------
|
||||
log_event("tag_embedding_encode_start", rows=len(texts))
|
||||
|
||||
embeddings = model.encode(
|
||||
texts,
|
||||
normalize_embeddings=True,
|
||||
show_progress_bar=True,
|
||||
show_progress_bar=False,
|
||||
batch_size=128,
|
||||
)
|
||||
|
||||
log_event("tag_embedding_encode_done", rows=len(texts))
|
||||
|
||||
embeddings = np.array(embeddings, dtype="float32")
|
||||
|
||||
if embeddings.ndim != 2 or embeddings.shape[0] != len(ids) or embeddings.shape[0] == 0:
|
||||
@@ -198,4 +246,12 @@ meta_path.write_text(
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
sys.exit(0)
|
||||
log_event(
|
||||
"tag_vector_index_written",
|
||||
index=str(out_path),
|
||||
meta=str(meta_path),
|
||||
rows=len(ids),
|
||||
dimension=dim,
|
||||
)
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
@@ -1,8 +1,28 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Keep stdout clean for the PHP caller. Diagnostics go to stderr only.
|
||||
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
|
||||
os.environ.setdefault("HF_HUB_ETAG_TIMEOUT", "5")
|
||||
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "10")
|
||||
|
||||
MODEL_LOAD_TIMEOUT_SECONDS = int(os.environ.get("RETRIEX_EMBEDDING_MODEL_LOAD_TIMEOUT_SECONDS", "30"))
|
||||
|
||||
|
||||
def empty() -> None:
|
||||
print("[]")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def debug(message: str) -> None:
|
||||
print(message, file=sys.stderr, flush=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Positional args (aligned with PHP client exec call)
|
||||
@@ -12,14 +32,9 @@ from pathlib import Path
|
||||
# 3 index_path
|
||||
# 4 meta_path
|
||||
# 5 model
|
||||
#
|
||||
# Example:
|
||||
# python vector_search_tags.py "foo" 8 /path/vector_tags.index /path/vector_tags.index.meta.json all-MiniLM-L6-v2
|
||||
# ---------------------------------------------------------
|
||||
|
||||
if len(sys.argv) < 6:
|
||||
print("[]")
|
||||
sys.exit(0)
|
||||
empty()
|
||||
|
||||
query = sys.argv[1]
|
||||
|
||||
@@ -29,66 +44,83 @@ except Exception:
|
||||
limit = 5
|
||||
|
||||
index_path = Path(sys.argv[3]).resolve()
|
||||
meta_path = Path(sys.argv[4]).resolve()
|
||||
meta_path = Path(sys.argv[4]).resolve()
|
||||
model_name = sys.argv[5]
|
||||
|
||||
model_override = os.environ.get("RETRIEX_EMBEDDING_MODEL_PATH", "").strip()
|
||||
if model_override:
|
||||
model_name = model_override
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Dependency checks
|
||||
# ---------------------------------------------------------
|
||||
try:
|
||||
import faiss
|
||||
except Exception:
|
||||
# keep stdout clean for caller
|
||||
print("[]")
|
||||
sys.exit(0)
|
||||
empty()
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except Exception:
|
||||
print("[]")
|
||||
sys.exit(0)
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
empty()
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# File checks
|
||||
# ---------------------------------------------------------
|
||||
if limit <= 0:
|
||||
print("[]")
|
||||
sys.exit(0)
|
||||
empty()
|
||||
|
||||
if not index_path.is_file() or not meta_path.is_file():
|
||||
# No tag index available => no routing
|
||||
print("[]")
|
||||
sys.exit(0)
|
||||
empty()
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Load model
|
||||
# Load model with timeout
|
||||
# ---------------------------------------------------------
|
||||
model = SentenceTransformer(model_name)
|
||||
def load_model(model_name_or_path: str) -> SentenceTransformer:
|
||||
def timeout_handler(_signum: int, _frame: Any) -> None:
|
||||
raise TimeoutError("tag search embedding model load timed out")
|
||||
|
||||
previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(MODEL_LOAD_TIMEOUT_SECONDS)
|
||||
|
||||
try:
|
||||
return SentenceTransformer(model_name_or_path)
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
signal.signal(signal.SIGALRM, previous_handler)
|
||||
|
||||
|
||||
try:
|
||||
model = load_model(model_name)
|
||||
except Exception as exc:
|
||||
debug(f"Unable to load tag search embedding model '{model_name}': {exc}")
|
||||
empty()
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Load index + meta
|
||||
# ---------------------------------------------------------
|
||||
index = faiss.read_index(str(index_path))
|
||||
try:
|
||||
index = faiss.read_index(str(index_path))
|
||||
except Exception:
|
||||
empty()
|
||||
|
||||
try:
|
||||
with open(meta_path, "r", encoding="utf-8") as f:
|
||||
ids = json.load(f)
|
||||
except Exception:
|
||||
print("[]")
|
||||
sys.exit(0)
|
||||
empty()
|
||||
|
||||
if not isinstance(ids, list) or len(ids) == 0:
|
||||
print("[]")
|
||||
sys.exit(0)
|
||||
empty()
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Embed & search
|
||||
# ---------------------------------------------------------
|
||||
qvec = model.encode([query], normalize_embeddings=True)
|
||||
|
||||
scores, idxs = index.search(qvec, limit)
|
||||
try:
|
||||
qvec = model.encode([query], normalize_embeddings=True, show_progress_bar=False)
|
||||
scores, idxs = index.search(qvec, limit)
|
||||
except Exception:
|
||||
empty()
|
||||
|
||||
out = []
|
||||
for score, idx in zip(scores[0], idxs[0]):
|
||||
@@ -100,4 +132,4 @@ for score, idx in zip(scores[0], idxs[0]):
|
||||
})
|
||||
|
||||
print(json.dumps(out))
|
||||
sys.exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
import threading
|
||||
@@ -13,6 +14,12 @@ import numpy as np
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Keep HuggingFace/SentenceTransformer model loading deterministic.
|
||||
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
|
||||
os.environ.setdefault("HF_HUB_ETAG_TIMEOUT", "10")
|
||||
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "30")
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
@@ -147,6 +154,15 @@ def _safe_read_json(path: Path) -> Optional[Any]:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def _resolve_embedding_model_name(configured_model_name: str) -> str:
|
||||
# A local model directory avoids implicit network/cache lookups in production.
|
||||
model_override = os.environ.get("RETRIEX_EMBEDDING_MODEL_PATH", "").strip()
|
||||
if model_override:
|
||||
return model_override
|
||||
|
||||
return configured_model_name.strip()
|
||||
|
||||
def _as_key(value: Any) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
@@ -362,10 +378,12 @@ def load_all() -> None:
|
||||
if not embedding_model_name:
|
||||
raise RuntimeError("embedding_model missing in index_meta.json")
|
||||
|
||||
if model is None or embedding_model_name != loaded_embedding_model_name:
|
||||
logger.info("[Reload] Loading embedding model: %s", embedding_model_name)
|
||||
model = SentenceTransformer(embedding_model_name)
|
||||
loaded_embedding_model_name = embedding_model_name
|
||||
resolved_embedding_model_name = _resolve_embedding_model_name(str(embedding_model_name))
|
||||
|
||||
if model is None or resolved_embedding_model_name != loaded_embedding_model_name:
|
||||
logger.info("[Reload] Loading embedding model: %s", resolved_embedding_model_name)
|
||||
model = SentenceTransformer(resolved_embedding_model_name)
|
||||
loaded_embedding_model_name = resolved_embedding_model_name
|
||||
|
||||
runtime = _safe_read_json(INDEX_RUNTIME_PATH)
|
||||
chunk_runtime_stamp, tags_runtime_stamp, tags_index_present = _extract_runtime_state(runtime)
|
||||
|
||||
Reference in New Issue
Block a user