add catalog mode
This commit is contained in:
@@ -5,24 +5,19 @@ import json
|
||||
from pathlib import Path
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Positional args (aligned with PHP builder exec call)
|
||||
# ---------------------------------------------------------
|
||||
# Positional args
|
||||
# 1 tags.ndjson
|
||||
# 2 out_index_path (can be .tmp)
|
||||
# 3 model
|
||||
# Example:
|
||||
# python vector_ingest_tags.py /var/knowledge/tags.ndjson /var/knowledge/vector_tags.index.tmp all-MiniLM-L6-v2
|
||||
# ---------------------------------------------------------
|
||||
|
||||
if len(sys.argv) < 4:
|
||||
print("ERROR: usage: vector_ingest_tags.py <tags.ndjson> <out.index> <model>", file=sys.stderr)
|
||||
if len(sys.argv) < 3:
|
||||
print("ERROR: usage: vector_ingest_tags.py <tags.ndjson> <out.index>", file=sys.stderr)
|
||||
sys.exit(2)
|
||||
|
||||
tags_path = Path(sys.argv[1]).resolve()
|
||||
out_path = Path(sys.argv[2]).resolve()
|
||||
model_name = sys.argv[3]
|
||||
|
||||
meta_path = Path(str(out_path) + ".meta.json") # vector_tags.index(.tmp).meta.json
|
||||
meta_path = Path(str(out_path) + ".meta.json")
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Dependency checks
|
||||
@@ -43,6 +38,25 @@ import numpy as np
|
||||
import faiss
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Load embedding model from index_meta.json (Single Source of Truth)
|
||||
# ---------------------------------------------------------
|
||||
BASE_PATH = Path(__file__).resolve().parents[2]
|
||||
INDEX_META_PATH = BASE_PATH / "var" / "knowledge" / "index_meta.json"
|
||||
|
||||
if not INDEX_META_PATH.exists():
|
||||
print("ERROR: index_meta.json not found", file=sys.stderr)
|
||||
sys.exit(30)
|
||||
|
||||
meta = json.loads(INDEX_META_PATH.read_text(encoding="utf-8"))
|
||||
embedding_model = meta.get("embedding_model")
|
||||
|
||||
if not embedding_model:
|
||||
print("ERROR: embedding_model missing in index_meta.json", file=sys.stderr)
|
||||
sys.exit(31)
|
||||
|
||||
model = SentenceTransformer(embedding_model)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# File checks
|
||||
# ---------------------------------------------------------
|
||||
@@ -50,14 +64,8 @@ if not tags_path.is_file():
|
||||
print(f"ERROR: tags.ndjson not found at {tags_path}", file=sys.stderr)
|
||||
sys.exit(20)
|
||||
|
||||
# Ensure output directory exists
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Load model
|
||||
# ---------------------------------------------------------
|
||||
model = SentenceTransformer(model_name)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Streaming read NDJSON
|
||||
# ---------------------------------------------------------
|
||||
@@ -85,13 +93,9 @@ with open(tags_path, "r", encoding="utf-8") as f:
|
||||
if len(text) > 4000:
|
||||
text = text[:4000]
|
||||
|
||||
# -------------------------------------------------
|
||||
# E5 requires "passage:" prefix for indexed texts
|
||||
# -------------------------------------------------
|
||||
texts.append(f"passage: {text}")
|
||||
ids.append(str(tag_id))
|
||||
|
||||
# If empty: remove outputs (tmp) and exit success
|
||||
if not texts:
|
||||
if out_path.exists():
|
||||
out_path.unlink()
|
||||
@@ -112,17 +116,11 @@ embeddings = model.encode(
|
||||
embeddings = np.array(embeddings).astype("float32")
|
||||
dim = embeddings.shape[1]
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Build FAISS index
|
||||
# ---------------------------------------------------------
|
||||
index = faiss.IndexFlatIP(dim)
|
||||
index.add(embeddings)
|
||||
|
||||
faiss.write_index(index, str(out_path))
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Write ID mapping meta
|
||||
# ---------------------------------------------------------
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(ids, f)
|
||||
|
||||
|
||||
@@ -10,11 +10,19 @@ from typing import Any, List, Optional, Dict
|
||||
|
||||
import numpy as np
|
||||
import faiss
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Service Stamp (to verify you are running THIS file)
|
||||
# ============================================================
|
||||
|
||||
SERVICE_STAMP = "vector_service.py@2026-02-28T10:20+01:00"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Paths
|
||||
# ============================================================
|
||||
@@ -42,6 +50,7 @@ INDEX_NDJSON_PATH = KNOWLEDGE_DIR / "index.ndjson"
|
||||
logger = logging.getLogger("vector_service")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -68,6 +77,23 @@ def setup_logging() -> None:
|
||||
if not any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
# Capture uvicorn logs in the same file as well (critical for hidden 500s)
|
||||
uvicorn_error = logging.getLogger("uvicorn.error")
|
||||
uvicorn_access = logging.getLogger("uvicorn.access")
|
||||
|
||||
uvicorn_error.setLevel(logging.INFO)
|
||||
uvicorn_access.setLevel(logging.INFO)
|
||||
|
||||
if not any(isinstance(h, RotatingFileHandler) for h in uvicorn_error.handlers):
|
||||
uvicorn_error.addHandler(file_handler)
|
||||
if not any(isinstance(h, logging.StreamHandler) for h in uvicorn_error.handlers):
|
||||
uvicorn_error.addHandler(stream_handler)
|
||||
|
||||
if not any(isinstance(h, RotatingFileHandler) for h in uvicorn_access.handlers):
|
||||
uvicorn_access.addHandler(file_handler)
|
||||
if not any(isinstance(h, logging.StreamHandler) for h in uvicorn_access.handlers):
|
||||
uvicorn_access.addHandler(stream_handler)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# FastAPI
|
||||
@@ -79,9 +105,6 @@ model: Optional[SentenceTransformer] = None
|
||||
chunk_index = None
|
||||
chunk_ids: Optional[List[Any]] = None
|
||||
|
||||
# Sales-RAG signals derived from NDJSON (loaded on startup and reload):
|
||||
# - chunk_doc_map: chunk_id -> document_id
|
||||
# - chunk_pos_map: chunk_id -> chunk_index (position within document, if available)
|
||||
chunk_doc_map: Dict[str, str] = {}
|
||||
chunk_pos_map: Dict[str, int] = {}
|
||||
|
||||
@@ -89,7 +112,6 @@ tag_index = None
|
||||
tag_ids: Optional[List[Any]] = None
|
||||
|
||||
loaded_embedding_model_name: Optional[str] = None
|
||||
|
||||
current_index_version: Optional[int] = None
|
||||
current_runtime_stamp: Optional[str] = None
|
||||
|
||||
@@ -107,10 +129,10 @@ class SearchRequest(BaseModel):
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Loader
|
||||
# Loader Helpers
|
||||
# ============================================================
|
||||
|
||||
def _safe_read_json(path: Path) -> Optional[dict]:
|
||||
def _safe_read_json(path: Path) -> Optional[Any]:
|
||||
try:
|
||||
if not path.exists():
|
||||
return None
|
||||
@@ -121,9 +143,6 @@ def _safe_read_json(path: Path) -> Optional[dict]:
|
||||
|
||||
|
||||
def _as_key(value: Any) -> Optional[str]:
|
||||
"""
|
||||
Normalize IDs to string keys for maps. Returns None if unusable.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
@@ -136,12 +155,19 @@ def _as_key(value: Any) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_limit(limit: int, default: int = 8, max_limit: int = 200) -> int:
|
||||
try:
|
||||
v = int(limit)
|
||||
except Exception:
|
||||
return default
|
||||
if v <= 0:
|
||||
return default
|
||||
if v > max_limit:
|
||||
return max_limit
|
||||
return v
|
||||
|
||||
|
||||
def load_chunk_maps_from_ndjson() -> None:
|
||||
"""
|
||||
Builds two maps from index.ndjson:
|
||||
- chunk_id -> document_id
|
||||
- chunk_id -> chunk_index (position inside document, if present)
|
||||
"""
|
||||
global chunk_doc_map, chunk_pos_map
|
||||
|
||||
chunk_doc_map = {}
|
||||
@@ -156,7 +182,6 @@ def load_chunk_maps_from_ndjson() -> None:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
row = json.loads(line)
|
||||
except Exception:
|
||||
@@ -166,40 +191,43 @@ def load_chunk_maps_from_ndjson() -> None:
|
||||
if not chunk_id_key:
|
||||
continue
|
||||
|
||||
document_id = row.get("document_id")
|
||||
doc_id_key = _as_key(document_id)
|
||||
doc_id_key = _as_key(row.get("document_id"))
|
||||
if doc_id_key:
|
||||
chunk_doc_map[chunk_id_key] = doc_id_key
|
||||
|
||||
# chunk_index is optional but very useful for Sales-RAG diversity rules
|
||||
# (e.g. min distance within a doc)
|
||||
ci = row.get("chunk_index")
|
||||
if isinstance(ci, int):
|
||||
chunk_pos_map[chunk_id_key] = ci
|
||||
else:
|
||||
# tolerate numeric strings
|
||||
if isinstance(ci, str):
|
||||
s = ci.strip()
|
||||
if s.isdigit():
|
||||
try:
|
||||
chunk_pos_map[chunk_id_key] = int(s)
|
||||
except Exception:
|
||||
pass
|
||||
elif isinstance(ci, str):
|
||||
s = ci.strip()
|
||||
if s.isdigit():
|
||||
try:
|
||||
chunk_pos_map[chunk_id_key] = int(s)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load chunk maps from ndjson: %s", str(e))
|
||||
|
||||
|
||||
def _sanitize_limit(limit: int, default: int = 8, max_limit: int = 200) -> int:
|
||||
try:
|
||||
v = int(limit)
|
||||
except Exception:
|
||||
return default
|
||||
if v <= 0:
|
||||
return default
|
||||
if v > max_limit:
|
||||
return max_limit
|
||||
return v
|
||||
def _normalize_meta_list(value: Any) -> Optional[List[Any]]:
|
||||
"""
|
||||
Accepts:
|
||||
- list: ok
|
||||
- dict like {"0": "...", "1": "..."}: convert to list sorted by numeric key
|
||||
Returns None if invalid.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
|
||||
if isinstance(value, dict):
|
||||
try:
|
||||
keys = sorted(int(k) for k in value.keys())
|
||||
return [value[str(i)] for i in keys]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def load_all() -> None:
|
||||
@@ -225,13 +253,14 @@ def load_all() -> None:
|
||||
model = SentenceTransformer(embedding_model_name)
|
||||
loaded_embedding_model_name = embedding_model_name
|
||||
|
||||
# Chunks
|
||||
if CHUNK_INDEX_PATH.exists() and CHUNK_MAP_PATH.exists():
|
||||
logger.info("[Reload] Loading chunk index")
|
||||
chunk_index = faiss.read_index(str(CHUNK_INDEX_PATH))
|
||||
chunk_ids = _safe_read_json(CHUNK_MAP_PATH) or None
|
||||
if not isinstance(chunk_ids, list):
|
||||
raw = _safe_read_json(CHUNK_MAP_PATH)
|
||||
chunk_ids = _normalize_meta_list(raw)
|
||||
if chunk_ids is None:
|
||||
chunk_index = None
|
||||
chunk_ids = None
|
||||
logger.warning("[Reload] chunk_ids meta invalid -> chunk index disabled")
|
||||
else:
|
||||
chunk_index = None
|
||||
@@ -240,13 +269,14 @@ def load_all() -> None:
|
||||
logger.info("[Reload] Loading chunk maps (doc_id + chunk_index)")
|
||||
load_chunk_maps_from_ndjson()
|
||||
|
||||
# Tags
|
||||
if TAG_INDEX_PATH.exists() and TAG_MAP_PATH.exists():
|
||||
logger.info("[Reload] Loading tag index")
|
||||
tag_index = faiss.read_index(str(TAG_INDEX_PATH))
|
||||
tag_ids = _safe_read_json(TAG_MAP_PATH) or None
|
||||
if not isinstance(tag_ids, list):
|
||||
raw = _safe_read_json(TAG_MAP_PATH)
|
||||
tag_ids = _normalize_meta_list(raw)
|
||||
if tag_ids is None:
|
||||
tag_index = None
|
||||
tag_ids = None
|
||||
logger.warning("[Reload] tag_ids meta invalid -> tag index disabled")
|
||||
else:
|
||||
tag_index = None
|
||||
@@ -262,15 +292,17 @@ def load_all() -> None:
|
||||
current_index_version = index_version if isinstance(index_version, int) else None
|
||||
|
||||
logger.info(
|
||||
"[Reload] Completed (index_version=%s runtime=%s embedding_model=%s)",
|
||||
"[Reload] Completed (index_version=%s runtime=%s embedding_model=%s stamp=%s file=%s)",
|
||||
str(current_index_version),
|
||||
str(current_runtime_stamp),
|
||||
str(loaded_embedding_model_name),
|
||||
SERVICE_STAMP,
|
||||
str(Path(__file__).resolve()),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Observer (Enterprise Auto Reload)
|
||||
# Observer
|
||||
# ============================================================
|
||||
|
||||
def observer_loop() -> None:
|
||||
@@ -294,24 +326,34 @@ def observer_loop() -> None:
|
||||
new_runtime = v if isinstance(v, str) else None
|
||||
|
||||
if new_version != current_index_version:
|
||||
logger.info(
|
||||
"[Observer] index_version changed (%s -> %s) -> Reload",
|
||||
str(current_index_version),
|
||||
str(new_version),
|
||||
)
|
||||
logger.info("[Observer] index_version changed (%s -> %s) -> Reload", str(current_index_version), str(new_version))
|
||||
load_all()
|
||||
continue
|
||||
|
||||
if new_runtime != current_runtime_stamp:
|
||||
logger.info(
|
||||
"[Observer] runtime changed (%s -> %s) -> Reload",
|
||||
str(current_runtime_stamp),
|
||||
str(new_runtime),
|
||||
)
|
||||
logger.info("[Observer] runtime changed (%s -> %s) -> Reload", str(current_runtime_stamp), str(new_runtime))
|
||||
load_all()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Observer ERROR] %s", str(e))
|
||||
logger.exception("[Observer ERROR] %s", str(e))
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Global Exception Handler (forces JSON + logs)
|
||||
# ============================================================
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
logger.exception("UNHANDLED_EXCEPTION path=%s method=%s", request.url.path, request.method)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": "Internal Server Error",
|
||||
"detail": str(exc),
|
||||
"path": request.url.path,
|
||||
"stamp": SERVICE_STAMP,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
@@ -321,13 +363,10 @@ def observer_loop() -> None:
|
||||
@app.on_event("startup")
|
||||
def startup_event():
|
||||
setup_logging()
|
||||
logger.info("[VectorService] Startup")
|
||||
|
||||
logger.info("[VectorService] Startup stamp=%s file=%s", SERVICE_STAMP, str(Path(__file__).resolve()))
|
||||
load_all()
|
||||
|
||||
t = threading.Thread(target=observer_loop, daemon=True)
|
||||
t.start()
|
||||
|
||||
logger.info("[VectorService] Ready (log=%s)", str(LOG_FILE))
|
||||
|
||||
|
||||
@@ -339,12 +378,18 @@ def startup_event():
|
||||
def health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"stamp": SERVICE_STAMP,
|
||||
"file": str(Path(__file__).resolve()),
|
||||
"chunk_index_loaded": chunk_index is not None,
|
||||
"tag_index_loaded": tag_index is not None,
|
||||
"model_loaded": model is not None,
|
||||
"embedding_model": loaded_embedding_model_name,
|
||||
"index_version": current_index_version,
|
||||
"runtime_stamp": current_runtime_stamp,
|
||||
"tag_meta_type": type(tag_ids).__name__ if tag_ids is not None else None,
|
||||
"tag_meta_len": len(tag_ids) if isinstance(tag_ids, list) else None,
|
||||
"chunk_meta_type": type(chunk_ids).__name__ if chunk_ids is not None else None,
|
||||
"chunk_meta_len": len(chunk_ids) if isinstance(chunk_ids, list) else None,
|
||||
"log_file": str(LOG_FILE),
|
||||
}
|
||||
|
||||
@@ -353,8 +398,9 @@ def health():
|
||||
def reload():
|
||||
try:
|
||||
load_all()
|
||||
return {"status": "reloaded"}
|
||||
return {"status": "reloaded", "stamp": SERVICE_STAMP}
|
||||
except Exception as e:
|
||||
logger.exception("reload failed")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@@ -363,74 +409,68 @@ def search_chunks(req: SearchRequest):
|
||||
if chunk_index is None or chunk_ids is None or model is None:
|
||||
raise HTTPException(status_code=503, detail="Chunk index not available")
|
||||
|
||||
# Safety: clamp limit to prevent abuse / accidental huge queries
|
||||
limit = _sanitize_limit(req.limit, default=8, max_limit=200)
|
||||
try:
|
||||
limit = _sanitize_limit(req.limit, default=8, max_limit=200)
|
||||
|
||||
query = (req.query or "").strip()
|
||||
if not query:
|
||||
raise HTTPException(status_code=400, detail="query must not be empty")
|
||||
query = (req.query or "").strip()
|
||||
if not query:
|
||||
raise HTTPException(status_code=400, detail="query must not be empty")
|
||||
|
||||
query_vec = model.encode(
|
||||
[f"query: {query}"],
|
||||
normalize_embeddings=True
|
||||
)
|
||||
query_vec = np.array(query_vec).astype("float32")
|
||||
query_vec = model.encode([f"query: {query}"], normalize_embeddings=True)
|
||||
query_vec = np.array(query_vec).astype("float32")
|
||||
|
||||
effective_limit = limit
|
||||
doc_filter: Optional[List[str]] = None
|
||||
if req.doc_ids:
|
||||
# Normalize incoming doc_ids for reliable matching
|
||||
doc_filter = []
|
||||
for d in req.doc_ids:
|
||||
dk = _as_key(d)
|
||||
if dk:
|
||||
doc_filter.append(dk)
|
||||
effective_limit = limit
|
||||
doc_filter: Optional[List[str]] = None
|
||||
if req.doc_ids:
|
||||
doc_filter = []
|
||||
for d in req.doc_ids:
|
||||
dk = _as_key(d)
|
||||
if dk:
|
||||
doc_filter.append(dk)
|
||||
effective_limit = max(limit * 5, 50)
|
||||
effective_limit = min(effective_limit, 500)
|
||||
|
||||
# When doc filtering is enabled, we fetch a wider pool and filter down.
|
||||
# Keep it bounded to avoid expensive scans on huge indices.
|
||||
effective_limit = max(limit * 5, 50)
|
||||
effective_limit = min(effective_limit, 500)
|
||||
scores, indices = chunk_index.search(query_vec, effective_limit)
|
||||
|
||||
scores, indices = chunk_index.search(query_vec, effective_limit)
|
||||
|
||||
results = []
|
||||
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx == -1:
|
||||
continue
|
||||
if idx < 0 or idx >= len(chunk_ids):
|
||||
continue
|
||||
|
||||
raw_chunk_id = chunk_ids[idx]
|
||||
chunk_id_key = _as_key(raw_chunk_id)
|
||||
if not chunk_id_key:
|
||||
continue
|
||||
|
||||
# Apply doc filter if requested
|
||||
doc_id = chunk_doc_map.get(chunk_id_key)
|
||||
if doc_filter is not None:
|
||||
if doc_id is None or doc_id not in doc_filter:
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx == -1:
|
||||
continue
|
||||
if idx < 0 or idx >= len(chunk_ids):
|
||||
continue
|
||||
|
||||
# Sales-RAG signals:
|
||||
# - document_id (for doc quotas / diversity rules)
|
||||
# - chunk_index (position within doc for distance constraints)
|
||||
payload = {
|
||||
"chunk_id": raw_chunk_id,
|
||||
"score": float(score),
|
||||
"document_id": doc_id, # may be None if ndjson missing/partial
|
||||
}
|
||||
raw_chunk_id = chunk_ids[idx]
|
||||
chunk_id_key = _as_key(raw_chunk_id)
|
||||
if not chunk_id_key:
|
||||
continue
|
||||
|
||||
ci = chunk_pos_map.get(chunk_id_key)
|
||||
if isinstance(ci, int):
|
||||
payload["chunk_index"] = ci
|
||||
doc_id = chunk_doc_map.get(chunk_id_key)
|
||||
if doc_filter is not None:
|
||||
if doc_id is None or doc_id not in doc_filter:
|
||||
continue
|
||||
|
||||
results.append(payload)
|
||||
payload = {
|
||||
"chunk_id": raw_chunk_id,
|
||||
"score": float(score),
|
||||
"document_id": doc_id,
|
||||
}
|
||||
|
||||
if len(results) >= limit:
|
||||
break
|
||||
ci = chunk_pos_map.get(chunk_id_key)
|
||||
if isinstance(ci, int):
|
||||
payload["chunk_index"] = ci
|
||||
|
||||
return results
|
||||
results.append(payload)
|
||||
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("search-chunks failure")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/search-tags")
|
||||
@@ -438,31 +478,36 @@ def search_tags(req: SearchRequest):
|
||||
if tag_index is None or tag_ids is None or model is None:
|
||||
raise HTTPException(status_code=503, detail="Tag index not available")
|
||||
|
||||
limit = _sanitize_limit(req.limit, default=8, max_limit=200)
|
||||
try:
|
||||
limit = _sanitize_limit(req.limit, default=8, max_limit=200)
|
||||
|
||||
query = (req.query or "").strip()
|
||||
if not query:
|
||||
raise HTTPException(status_code=400, detail="query must not be empty")
|
||||
query = (req.query or "").strip()
|
||||
if not query:
|
||||
raise HTTPException(status_code=400, detail="query must not be empty")
|
||||
|
||||
query_vec = model.encode(
|
||||
[f"query: {query}"],
|
||||
normalize_embeddings=True
|
||||
)
|
||||
query_vec = np.array(query_vec).astype("float32")
|
||||
query_vec = model.encode([f"query: {query}"], normalize_embeddings=True)
|
||||
query_vec = np.array(query_vec).astype("float32")
|
||||
|
||||
scores, indices = tag_index.search(query_vec, limit)
|
||||
if query_vec.ndim != 2:
|
||||
raise RuntimeError(f"Invalid embedding shape: {query_vec.shape}")
|
||||
|
||||
results = []
|
||||
if query_vec.shape[1] != tag_index.d:
|
||||
raise RuntimeError(f"Embedding dimension mismatch (vec={query_vec.shape[1]}, index={tag_index.d})")
|
||||
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx == -1:
|
||||
continue
|
||||
if idx < 0 or idx >= len(tag_ids):
|
||||
continue
|
||||
scores, indices = tag_index.search(query_vec, limit)
|
||||
|
||||
results.append({
|
||||
"tag_id": tag_ids[idx],
|
||||
"score": float(score),
|
||||
})
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx == -1:
|
||||
continue
|
||||
if idx < 0 or idx >= len(tag_ids):
|
||||
continue
|
||||
results.append({"tag_id": tag_ids[idx], "score": float(score)})
|
||||
|
||||
return results
|
||||
return results
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("search-tags failure")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
Reference in New Issue
Block a user