fix vector python embedding

This commit is contained in:
team 1
2026-04-24 10:43:20 +02:00
parent 63b7011567
commit 4a8ffc5875
6 changed files with 233 additions and 174 deletions

View File

@@ -1,13 +1,26 @@
#!/usr/bin/env python3
import json
import os
import signal
import sys
from pathlib import Path
from typing import Any, Dict, List, Tuple
# Keep HuggingFace/SentenceTransformer model loading deterministic in CLI jobs.
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
os.environ.setdefault("HF_HUB_ETAG_TIMEOUT", "10")
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "30")
MODEL_LOAD_TIMEOUT_SECONDS = int(os.environ.get("RETRIEX_EMBEDDING_MODEL_LOAD_TIMEOUT_SECONDS", "60"))
def log_event(event: str, **payload: Any) -> None:
print(json.dumps({"event": event, **payload}, ensure_ascii=False), file=sys.stderr, flush=True)
def fail(message: str, code: int) -> None:
print(f"ERROR: {message}", file=sys.stderr)
print(f"ERROR: {message}", file=sys.stderr, flush=True)
sys.exit(code)
@@ -40,27 +53,6 @@ except Exception:
import numpy as np
# ---------------------------------------------------------
# Load embedding model from index_meta.json (Single Source of Truth)
# ---------------------------------------------------------
BASE_PATH = Path(__file__).resolve().parents[2]
INDEX_META_PATH = BASE_PATH / "var" / "knowledge" / "index_meta.json"
if not INDEX_META_PATH.exists():
fail("index_meta.json not found", 30)
try:
meta = json.loads(INDEX_META_PATH.read_text(encoding="utf-8"))
except Exception:
fail("index_meta.json is invalid", 30)
embedding_model = meta.get("embedding_model")
if not isinstance(embedding_model, str) or embedding_model.strip() == "":
fail("embedding_model missing in index_meta.json", 31)
model = SentenceTransformer(embedding_model.strip())
# ---------------------------------------------------------
# File checks
# ---------------------------------------------------------
@@ -90,6 +82,61 @@ def normalize_text(value: Any) -> str:
return text
def resolve_embedding_model_from_meta() -> str:
# Local model path wins. This avoids implicit network/cache lookup in production.
override = os.environ.get("RETRIEX_EMBEDDING_MODEL_PATH", "").strip()
if override:
return override
base_path = Path(__file__).resolve().parents[2]
index_meta_path = base_path / "var" / "knowledge" / "index_meta.json"
if not index_meta_path.exists():
fail("index_meta.json not found", 30)
try:
meta = json.loads(index_meta_path.read_text(encoding="utf-8"))
except Exception:
fail("index_meta.json is invalid", 30)
embedding_model = meta.get("embedding_model")
if not isinstance(embedding_model, str) or embedding_model.strip() == "":
fail("embedding_model missing in index_meta.json", 31)
return embedding_model.strip()
def load_sentence_transformer(model_name_or_path: str) -> SentenceTransformer:
def timeout_handler(_signum: int, _frame: Any) -> None:
raise TimeoutError(
"Embedding model load timed out. "
"Cache the model locally or set RETRIEX_EMBEDDING_MODEL_PATH."
)
log_event(
"tag_embedding_model_load_start",
model=model_name_or_path,
timeout_seconds=MODEL_LOAD_TIMEOUT_SECONDS,
hf_hub_disable_xet=os.environ.get("HF_HUB_DISABLE_XET"),
)
previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(MODEL_LOAD_TIMEOUT_SECONDS)
try:
loaded_model = SentenceTransformer(model_name_or_path)
except TimeoutError as exc:
fail(str(exc), 32)
except Exception as exc:
fail(f"Unable to load embedding model '{model_name_or_path}': {exc}", 33)
finally:
signal.alarm(0)
signal.signal(signal.SIGALRM, previous_handler)
log_event("tag_embedding_model_load_done", model=model_name_or_path)
return loaded_model
# ---------------------------------------------------------
# Streaming read NDJSON
# ---------------------------------------------------------
@@ -146,33 +193,34 @@ def load_rows(path: Path) -> Tuple[List[str], List[str], Dict[str, int]]:
texts, ids, stats = load_rows(tags_path)
print(
json.dumps(
{
"event": "tag_rows_loaded",
**stats,
},
ensure_ascii=False,
),
file=sys.stderr,
)
log_event("tag_rows_loaded", **stats)
if not texts:
cleanup_outputs()
sys.exit(0)
# ---------------------------------------------------------
# Load model only after we know that usable tags exist
# ---------------------------------------------------------
embedding_model = resolve_embedding_model_from_meta()
model = load_sentence_transformer(embedding_model)
# ---------------------------------------------------------
# Build embeddings
# ---------------------------------------------------------
log_event("tag_embedding_encode_start", rows=len(texts))
embeddings = model.encode(
texts,
normalize_embeddings=True,
show_progress_bar=True,
show_progress_bar=False,
batch_size=128,
)
log_event("tag_embedding_encode_done", rows=len(texts))
embeddings = np.array(embeddings, dtype="float32")
if embeddings.ndim != 2 or embeddings.shape[0] != len(ids) or embeddings.shape[0] == 0:
@@ -198,4 +246,12 @@ meta_path.write_text(
encoding="utf-8",
)
sys.exit(0)
log_event(
"tag_vector_index_written",
index=str(out_path),
meta=str(meta_path),
rows=len(ids),
dimension=dim,
)
sys.exit(0)

View File

@@ -1,8 +1,28 @@
#!/usr/bin/env python3
import sys
import json
import os
import signal
import sys
from pathlib import Path
from typing import Any
# Keep stdout clean for the PHP caller. Diagnostics go to stderr only.
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
os.environ.setdefault("HF_HUB_ETAG_TIMEOUT", "5")
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "10")
MODEL_LOAD_TIMEOUT_SECONDS = int(os.environ.get("RETRIEX_EMBEDDING_MODEL_LOAD_TIMEOUT_SECONDS", "30"))
def empty() -> None:
print("[]")
sys.exit(0)
def debug(message: str) -> None:
print(message, file=sys.stderr, flush=True)
# ---------------------------------------------------------
# Positional args (aligned with PHP client exec call)
@@ -12,14 +32,9 @@ from pathlib import Path
# 3 index_path
# 4 meta_path
# 5 model
#
# Example:
# python vector_search_tags.py "foo" 8 /path/vector_tags.index /path/vector_tags.index.meta.json all-MiniLM-L6-v2
# ---------------------------------------------------------
if len(sys.argv) < 6:
print("[]")
sys.exit(0)
empty()
query = sys.argv[1]
@@ -29,66 +44,83 @@ except Exception:
limit = 5
index_path = Path(sys.argv[3]).resolve()
meta_path = Path(sys.argv[4]).resolve()
meta_path = Path(sys.argv[4]).resolve()
model_name = sys.argv[5]
model_override = os.environ.get("RETRIEX_EMBEDDING_MODEL_PATH", "").strip()
if model_override:
model_name = model_override
# ---------------------------------------------------------
# Dependency checks
# ---------------------------------------------------------
try:
import faiss
except Exception:
# keep stdout clean for caller
print("[]")
sys.exit(0)
empty()
try:
from sentence_transformers import SentenceTransformer
except Exception:
print("[]")
sys.exit(0)
from sentence_transformers import SentenceTransformer
empty()
# ---------------------------------------------------------
# File checks
# ---------------------------------------------------------
if limit <= 0:
print("[]")
sys.exit(0)
empty()
if not index_path.is_file() or not meta_path.is_file():
# No tag index available => no routing
print("[]")
sys.exit(0)
empty()
# ---------------------------------------------------------
# Load model
# Load model with timeout
# ---------------------------------------------------------
model = SentenceTransformer(model_name)
def load_model(model_name_or_path: str) -> SentenceTransformer:
def timeout_handler(_signum: int, _frame: Any) -> None:
raise TimeoutError("tag search embedding model load timed out")
previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(MODEL_LOAD_TIMEOUT_SECONDS)
try:
return SentenceTransformer(model_name_or_path)
finally:
signal.alarm(0)
signal.signal(signal.SIGALRM, previous_handler)
try:
model = load_model(model_name)
except Exception as exc:
debug(f"Unable to load tag search embedding model '{model_name}': {exc}")
empty()
# ---------------------------------------------------------
# Load index + meta
# ---------------------------------------------------------
index = faiss.read_index(str(index_path))
try:
index = faiss.read_index(str(index_path))
except Exception:
empty()
try:
with open(meta_path, "r", encoding="utf-8") as f:
ids = json.load(f)
except Exception:
print("[]")
sys.exit(0)
empty()
if not isinstance(ids, list) or len(ids) == 0:
print("[]")
sys.exit(0)
empty()
# ---------------------------------------------------------
# Embed & search
# ---------------------------------------------------------
qvec = model.encode([query], normalize_embeddings=True)
scores, idxs = index.search(qvec, limit)
try:
qvec = model.encode([query], normalize_embeddings=True, show_progress_bar=False)
scores, idxs = index.search(qvec, limit)
except Exception:
empty()
out = []
for score, idx in zip(scores[0], idxs[0]):
@@ -100,4 +132,4 @@ for score, idx in zip(scores[0], idxs[0]):
})
print(json.dumps(out))
sys.exit(0)
sys.exit(0)

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import json
import os
import logging
from logging.handlers import RotatingFileHandler
import threading
@@ -13,6 +14,12 @@ import numpy as np
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Keep HuggingFace/SentenceTransformer model loading deterministic.
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
os.environ.setdefault("HF_HUB_ETAG_TIMEOUT", "10")
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "30")
from sentence_transformers import SentenceTransformer
@@ -147,6 +154,15 @@ def _safe_read_json(path: Path) -> Optional[Any]:
return None
def _resolve_embedding_model_name(configured_model_name: str) -> str:
# A local model directory avoids implicit network/cache lookups in production.
model_override = os.environ.get("RETRIEX_EMBEDDING_MODEL_PATH", "").strip()
if model_override:
return model_override
return configured_model_name.strip()
def _as_key(value: Any) -> Optional[str]:
if value is None:
return None
@@ -362,10 +378,12 @@ def load_all() -> None:
if not embedding_model_name:
raise RuntimeError("embedding_model missing in index_meta.json")
if model is None or embedding_model_name != loaded_embedding_model_name:
logger.info("[Reload] Loading embedding model: %s", embedding_model_name)
model = SentenceTransformer(embedding_model_name)
loaded_embedding_model_name = embedding_model_name
resolved_embedding_model_name = _resolve_embedding_model_name(str(embedding_model_name))
if model is None or resolved_embedding_model_name != loaded_embedding_model_name:
logger.info("[Reload] Loading embedding model: %s", resolved_embedding_model_name)
model = SentenceTransformer(resolved_embedding_model_name)
loaded_embedding_model_name = resolved_embedding_model_name
runtime = _safe_read_json(INDEX_RUNTIME_PATH)
chunk_runtime_stamp, tags_runtime_stamp, tags_index_present = _extract_runtime_state(runtime)