696 lines
23 KiB
Python
696 lines
23 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import json
|
|
import logging
|
|
from logging.handlers import RotatingFileHandler
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import faiss
|
|
import numpy as np
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
# ============================================================
|
|
# Service Stamp (to verify you are running THIS file)
|
|
# ============================================================
|
|
|
|
SERVICE_STAMP = "vector_service.py@2026-04-20T00:00+02:00"


# ============================================================
# Paths
# ============================================================

# Root directory: two levels above the directory that contains this file.
BASE_PATH = Path(__file__).resolve().parents[2]

KNOWLEDGE_DIR = BASE_PATH.joinpath("var", "knowledge")
LOG_DIR = BASE_PATH.joinpath("var", "log")
LOG_FILE = LOG_DIR.joinpath("vector_service.log")

# Chunk FAISS index and its position -> chunk_id list.
CHUNK_INDEX_PATH = KNOWLEDGE_DIR.joinpath("vector.index")
CHUNK_MAP_PATH = KNOWLEDGE_DIR.joinpath("vector.index.meta.json")

# Tag FAISS index and its position -> tag_id list.
TAG_INDEX_PATH = KNOWLEDGE_DIR.joinpath("vector_tags.index")
TAG_MAP_PATH = KNOWLEDGE_DIR.joinpath("vector_tags.index.meta.json")

# Files watched by the observer / read during reload.
INDEX_META_PATH = KNOWLEDGE_DIR.joinpath("index_meta.json")
INDEX_RUNTIME_PATH = KNOWLEDGE_DIR.joinpath("index_runtime.json")
INDEX_NDJSON_PATH = KNOWLEDGE_DIR.joinpath("index.ndjson")
TAGS_NDJSON_PATH = KNOWLEDGE_DIR.joinpath("tags.ndjson")
|
|
|
|
|
|
# ============================================================
|
|
# Logging
|
|
# ============================================================
|
|
|
|
# Service logger; handlers are attached by setup_logging() at startup.
logger = logging.getLogger("vector_service")
logger.setLevel("INFO")
|
|
|
|
|
|
# ============================================================
|
|
# App State
|
|
# ============================================================
|
|
|
|
app = FastAPI()

# Embedding model and chunk index (replaced by load_all()).
model: Optional[SentenceTransformer] = None
chunk_index: Any = None
chunk_ids: Optional[List[Any]] = None

# chunk_id -> document_id and chunk_id -> chunk position, read from index.ndjson.
chunk_doc_map: Dict[str, str] = dict()
chunk_pos_map: Dict[str, int] = dict()

# Tag index and its position -> tag_id list (replaced by load_all()).
tag_index: Any = None
tag_ids: Optional[List[Any]] = None

# tag_id -> {"label": "...", "tag_type": "..."}
tag_meta_map: Dict[str, Dict[str, str]] = dict()

# What is currently loaded; observer_loop compares these with the files on disk.
loaded_embedding_model_name: Optional[str] = None
current_index_version: Optional[int] = None
current_chunk_runtime_stamp: Optional[str] = None
current_tags_runtime_stamp: Optional[str] = None
current_tags_index_present: Optional[bool] = None

# Held for the whole of load_all() so concurrent reload requests run one at a time.
reload_lock = threading.Lock()
|
|
|
|
|
|
# ============================================================
|
|
# Models
|
|
# ============================================================
|
|
|
|
class SearchRequest(BaseModel):
    """Request body accepted by both /search-chunks and /search-tags."""

    # Free-text query; the endpoints strip it and answer HTTP 400 when empty.
    query: str
    # Requested number of hits; the endpoints pass it through _sanitize_limit
    # (non-positive -> 8, values above 200 are capped at 200).
    limit: int = 8
    # Optional document-id filter; only /search-chunks reads it.
    doc_ids: Optional[List[str]] = None
|
|
|
|
|
|
# ============================================================
|
|
# Helpers
|
|
# ============================================================
|
|
|
|
def setup_logging() -> None:
    """Attach a rotating log file and a console handler to the service loggers.

    Targets the module ``logger`` plus the ``uvicorn.error`` and
    ``uvicorn.access`` loggers, all set to INFO. A logger that already has a
    ``RotatingFileHandler`` does not get another one, and a logger that
    already has a plain ``StreamHandler`` does not get another console
    handler, so repeated calls do not duplicate output.

    The file handler is created with ``delay=True``: the log file is opened on
    the first emitted record instead of at construction. Without this, a
    repeated call builds a handler that no logger needs, and its open file
    descriptor is never closed.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)

    fmt = logging.Formatter(
        fmt="%(asctime)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S%z",
    )

    file_handler = RotatingFileHandler(
        str(LOG_FILE),
        maxBytes=10 * 1024 * 1024,
        backupCount=5,
        encoding="utf-8",
        delay=True,
    )
    file_handler.setFormatter(fmt)
    file_handler.setLevel(logging.INFO)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(fmt)
    stream_handler.setLevel(logging.INFO)

    uvicorn_error = logging.getLogger("uvicorn.error")
    uvicorn_access = logging.getLogger("uvicorn.access")
    uvicorn_error.setLevel(logging.INFO)
    uvicorn_access.setLevel(logging.INFO)

    for target in (logger, uvicorn_error, uvicorn_access):
        if not any(isinstance(h, RotatingFileHandler) for h in target.handlers):
            target.addHandler(file_handler)
        # Exact type check on purpose: FileHandler subclasses StreamHandler
        # and must not count as an existing console handler.
        if not any(type(h) is logging.StreamHandler for h in target.handlers):
            target.addHandler(stream_handler)
|
|
|
|
|
|
def _safe_read_json(path: Path) -> Optional[Any]:
|
|
try:
|
|
if not path.exists():
|
|
return None
|
|
return json.loads(path.read_text(encoding="utf-8"))
|
|
except Exception as exc:
|
|
logger.warning("Failed to read json %s: %s", str(path), str(exc))
|
|
return None
|
|
|
|
|
|
def _as_key(value: Any) -> Optional[str]:
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, str):
|
|
value = value.strip()
|
|
return value or None
|
|
try:
|
|
value = str(value).strip()
|
|
return value or None
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def _sanitize_limit(limit: int, default: int = 8, max_limit: int = 200) -> int:
|
|
try:
|
|
value = int(limit)
|
|
except Exception:
|
|
return default
|
|
if value <= 0:
|
|
return default
|
|
if value > max_limit:
|
|
return max_limit
|
|
return value
|
|
|
|
|
|
def _normalize_meta_list(value: Any) -> Optional[List[Any]]:
|
|
"""
|
|
Accepts:
|
|
- list: ok
|
|
- dict like {"0": "...", "1": "..."}: convert to list sorted by numeric key
|
|
Returns None if invalid.
|
|
"""
|
|
if isinstance(value, list):
|
|
return value
|
|
|
|
if isinstance(value, dict):
|
|
try:
|
|
keys = sorted(int(key) for key in value.keys())
|
|
return [value[str(i)] for i in keys]
|
|
except Exception:
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
def _normalize_tag_type(value: Any) -> str:
    """Map a raw tag type to one of the supported lowercase names.

    Returns "generic", "catalog_entity" or "sales_signal". Anything missing,
    blank or unrecognized becomes "generic".
    """
    candidate = (_as_key(value) or "").lower()
    if candidate in ("generic", "catalog_entity", "sales_signal"):
        return candidate
    return "generic"
|
|
|
|
|
|
def _extract_runtime_state(runtime: Any) -> Tuple[Optional[str], Optional[str], Optional[bool]]:
|
|
if not isinstance(runtime, dict):
|
|
return None, None, None
|
|
|
|
chunk_runtime = runtime.get("last_rebuild_at")
|
|
tags_runtime = runtime.get("last_tags_rebuild_at")
|
|
tags_index_present = runtime.get("tags_index_present")
|
|
|
|
if not isinstance(chunk_runtime, str):
|
|
chunk_runtime = None
|
|
if not isinstance(tags_runtime, str):
|
|
tags_runtime = None
|
|
if not isinstance(tags_index_present, bool):
|
|
tags_index_present = None
|
|
|
|
return chunk_runtime, tags_runtime, tags_index_present
|
|
|
|
|
|
def _validate_index_alignment(index_obj: Any, ids: Optional[List[Any]], label: str) -> Tuple[Any, Optional[List[Any]]]:
|
|
if index_obj is None or ids is None:
|
|
return None, None
|
|
|
|
try:
|
|
index_count = int(index_obj.ntotal)
|
|
except Exception:
|
|
logger.warning("[Reload] %s index has no ntotal -> disabled", label)
|
|
return None, None
|
|
|
|
if index_count != len(ids):
|
|
logger.warning(
|
|
"[Reload] %s meta/index mismatch (ids=%s index=%s) -> disabled",
|
|
label,
|
|
len(ids),
|
|
index_count,
|
|
)
|
|
return None, None
|
|
|
|
return index_obj, ids
|
|
|
|
|
|
def _parse_chunk_position(value: Any) -> Optional[int]:
    """Return *value* as a chunk position: an int (bools excluded) or a string of decimal digits."""
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        stripped = value.strip()
        # isdecimal (not isdigit) guarantees int() accepts the string.
        if stripped.isdecimal():
            return int(stripped)
    return None


def load_chunk_maps_from_ndjson() -> None:
    """Rebuild ``chunk_doc_map`` and ``chunk_pos_map`` from index.ndjson.

    Each non-blank line is parsed as a JSON object. ``chunk_id`` is the key,
    ``document_id`` fills ``chunk_doc_map``, and ``chunk_index`` fills
    ``chunk_pos_map``. Skipped lines: unparsable lines, lines that are not
    JSON objects, and rows without a usable ``chunk_id``.

    The maps are built in locals and bound to the module globals in one step
    at the end, so searches running during a reload never see a half-filled
    map. A missing file results in empty maps. If reading fails part-way, a
    warning is logged and the rows parsed so far are published.
    """
    global chunk_doc_map, chunk_pos_map

    doc_map: Dict[str, str] = {}
    pos_map: Dict[str, int] = {}

    if INDEX_NDJSON_PATH.exists():
        try:
            with INDEX_NDJSON_PATH.open("r", encoding="utf-8") as handle:
                for line in handle:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        row = json.loads(line)
                    except ValueError:
                        continue
                    # A valid but non-object line (list, number, ...) must not
                    # crash row.get() and abort the whole load.
                    if not isinstance(row, dict):
                        continue

                    chunk_id_key = _as_key(row.get("chunk_id"))
                    if not chunk_id_key:
                        continue

                    doc_id_key = _as_key(row.get("document_id"))
                    if doc_id_key:
                        doc_map[chunk_id_key] = doc_id_key

                    position = _parse_chunk_position(row.get("chunk_index"))
                    if position is not None:
                        pos_map[chunk_id_key] = position
        except Exception as exc:
            logger.warning("Failed to load chunk maps from ndjson: %s", str(exc))

    chunk_doc_map = doc_map
    chunk_pos_map = pos_map
|
|
|
|
|
|
def load_tag_meta_from_tags_ndjson() -> None:
    """
    Loads minimal tag metadata from tags.ndjson to enrich /search-tags results.

    Expected line format:
    {
      "tag_id": "...",
      "text": "LABEL\\nSLUG\\noptional description",
      "type": "catalog_entity|generic|sales_signal",
      "document_ids": ["..."]
    }

    For each kept row, ``tag_meta_map[tag_id]`` becomes
    ``{"label": <first line of text, stripped>, "tag_type": <normalized type>}``.

    Skipped rows:
      - unparsable lines and lines that are not JSON objects
      - rows without a usable tag_id
      - rows whose document_ids is an empty list (rows that omit the field
        are kept)

    The map is built in a local dict and bound to the global in one step, so
    concurrent searches never see a half-filled map. A missing file, or any
    read error, leaves the map empty; read errors are also logged as a warning.
    """
    global tag_meta_map

    if not TAGS_NDJSON_PATH.exists():
        tag_meta_map = {}
        logger.info("[Reload] tags.ndjson missing -> tag_meta_map empty (%s)", str(TAGS_NDJSON_PATH))
        return

    meta_map: Dict[str, Dict[str, str]] = {}
    try:
        with TAGS_NDJSON_PATH.open("r", encoding="utf-8") as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    continue

                try:
                    row = json.loads(line)
                except ValueError:
                    continue
                # A valid but non-object line must not crash row.get() and
                # wipe the whole map.
                if not isinstance(row, dict):
                    continue

                tag_id = _as_key(row.get("tag_id"))
                if not tag_id:
                    continue

                document_ids = row.get("document_ids")
                if isinstance(document_ids, list) and not document_ids:
                    continue

                label = ""
                text_value = row.get("text")
                if isinstance(text_value, str) and text_value.strip():
                    # Non-empty here because text contains a non-space char.
                    label = text_value.splitlines()[0].strip()

                meta_map[tag_id] = {
                    "label": label,
                    "tag_type": _normalize_tag_type(row.get("type")),
                }
    except Exception as exc:
        logger.warning("Failed to load tag meta from tags.ndjson: %s", str(exc))
        meta_map = {}

    tag_meta_map = meta_map
|
|
|
|
|
|
def _load_index_pair(index_path: Path, map_path: Path, label: str) -> Tuple[Any, Optional[List[Any]]]:
    """Read one FAISS index and its position -> id list without touching globals.

    Returns ``(None, None)`` in three cases: either file is missing, the id
    metadata is not a usable list, or its length disagrees with the index.
    Exceptions from ``faiss.read_index`` propagate to the caller.
    """
    if not (index_path.exists() and map_path.exists()):
        return None, None

    logger.info("[Reload] Loading %s index", label)
    loaded_index = faiss.read_index(str(index_path))
    loaded_ids = _normalize_meta_list(_safe_read_json(map_path))
    if loaded_ids is None:
        logger.warning("[Reload] %s_ids meta invalid -> %s index disabled", label, label)
        return None, None
    return _validate_index_alignment(loaded_index, loaded_ids, label)


def load_all() -> None:
    """(Re)load the embedding model, both FAISS indexes and their metadata.

    Reads index_meta.json (required) and index_runtime.json (optional). The
    embedding model is reloaded only when its name changed. The tag index is
    skipped when the runtime file sets ``tags_index_present`` to False.

    The model and both indexes are loaded into locals first and published to
    the module globals together at the end. If any of those loads raises, the
    previously loaded state stays in place untouched. The requests never pair
    a new model with an old index. The version/stamp globals are also left
    unchanged, so observer_loop will attempt the reload again.

    Raises:
        RuntimeError: index_meta.json is missing/invalid or has no embedding_model.
        Errors from SentenceTransformer / faiss.read_index propagate unchanged.
    """
    global model, chunk_index, chunk_ids
    global tag_index, tag_ids
    global loaded_embedding_model_name
    global current_index_version
    global current_chunk_runtime_stamp, current_tags_runtime_stamp, current_tags_index_present

    with reload_lock:
        meta = _safe_read_json(INDEX_META_PATH)
        if not isinstance(meta, dict):
            raise RuntimeError("index_meta.json not found or invalid")

        embedding_model_name = meta.get("embedding_model")
        index_version = meta.get("index_version")

        if not embedding_model_name:
            raise RuntimeError("embedding_model missing in index_meta.json")

        # --- Stage: nothing below touches the globals until everything loaded.
        new_model = model
        if new_model is None or embedding_model_name != loaded_embedding_model_name:
            logger.info("[Reload] Loading embedding model: %s", embedding_model_name)
            new_model = SentenceTransformer(embedding_model_name)

        runtime = _safe_read_json(INDEX_RUNTIME_PATH)
        chunk_runtime_stamp, tags_runtime_stamp, tags_index_present = _extract_runtime_state(runtime)

        new_chunk_index, new_chunk_ids = _load_index_pair(CHUNK_INDEX_PATH, CHUNK_MAP_PATH, "chunk")

        if tags_index_present is False:
            new_tag_index, new_tag_ids = None, None
            logger.info("[Reload] Runtime marks tags index as absent -> tag index disabled")
        else:
            new_tag_index, new_tag_ids = _load_index_pair(TAG_INDEX_PATH, TAG_MAP_PATH, "tag")

        # --- Publish: consecutive rebinds keep the inconsistent window minimal.
        model = new_model
        loaded_embedding_model_name = embedding_model_name
        chunk_index, chunk_ids = new_chunk_index, new_chunk_ids
        tag_index, tag_ids = new_tag_index, new_tag_ids

        # These helpers build their maps privately and swap them in themselves;
        # they log (rather than raise) on read errors.
        logger.info("[Reload] Loading chunk maps (doc_id + chunk_index)")
        load_chunk_maps_from_ndjson()

        logger.info("[Reload] Loading tag meta from tags.ndjson")
        load_tag_meta_from_tags_ndjson()

        current_index_version = index_version if isinstance(index_version, int) else None
        current_chunk_runtime_stamp = chunk_runtime_stamp
        current_tags_runtime_stamp = tags_runtime_stamp
        current_tags_index_present = tags_index_present

        logger.info(
            "[Reload] Completed (index_version=%s chunk_runtime=%s tags_runtime=%s tags_index_present=%s embedding_model=%s tag_meta=%s stamp=%s file=%s)",
            str(current_index_version),
            str(current_chunk_runtime_stamp),
            str(current_tags_runtime_stamp),
            str(current_tags_index_present),
            str(loaded_embedding_model_name),
            str(len(tag_meta_map)),
            SERVICE_STAMP,
            str(Path(__file__).resolve()),
        )
|
|
|
|
|
|
# ============================================================
|
|
# Observer
|
|
# ============================================================
|
|
|
|
def observer_loop() -> None:
    """Poll index_meta.json / index_runtime.json every 2 seconds and reload on change.

    Compares, in order, the index version, the chunk rebuild stamp, the tags
    rebuild stamp, and the tags_index_present flag against the values
    recorded by the last successful load_all(). The first difference found
    triggers one load_all() call. An invalid meta file is ignored for that
    round. Errors are logged and the loop keeps running.
    """
    while True:
        time.sleep(2)

        try:
            meta = _safe_read_json(INDEX_META_PATH)
            if not isinstance(meta, dict):
                continue

            raw_version = meta.get("index_version")
            new_version = raw_version if isinstance(raw_version, int) else None
            new_chunk_runtime, new_tags_runtime, new_tags_present = _extract_runtime_state(
                _safe_read_json(INDEX_RUNTIME_PATH)
            )

            checks = (
                ("[Observer] index_version changed (%s -> %s) -> Reload", current_index_version, new_version),
                ("[Observer] chunk runtime changed (%s -> %s) -> Reload", current_chunk_runtime_stamp, new_chunk_runtime),
                ("[Observer] tags runtime changed (%s -> %s) -> Reload", current_tags_runtime_stamp, new_tags_runtime),
                ("[Observer] tags_index_present changed (%s -> %s) -> Reload", current_tags_index_present, new_tags_present),
            )

            for message, before, after in checks:
                if after != before:
                    logger.info(message, str(before), str(after))
                    load_all()
                    break

        except Exception as exc:
            logger.exception("[Observer ERROR] %s", str(exc))
|
|
|
|
|
|
# ============================================================
|
|
# Global Exception Handler
|
|
# ============================================================
|
|
|
|
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Log any uncaught exception with its traceback and answer with a JSON 500.

    The body echoes ``str(exc)`` as "detail", together with the request path
    and the service stamp.
    """
    path = request.url.path
    logger.exception("UNHANDLED_EXCEPTION path=%s method=%s", path, request.method)

    body = {
        "error": "Internal Server Error",
        "detail": str(exc),
        "path": path,
        "stamp": SERVICE_STAMP,
    }
    return JSONResponse(content=body, status_code=500)
|
|
|
|
|
|
# ============================================================
|
|
# Startup
|
|
# ============================================================
|
|
|
|
@app.on_event("startup")
def startup_event() -> None:
    """Configure logging, do the initial load, and start the observer thread.

    load_all() runs synchronously first, so a startup with missing or invalid
    index metadata fails before the observer is started.
    """
    setup_logging()

    this_file = str(Path(__file__).resolve())
    logger.info("[VectorService] Startup stamp=%s file=%s", SERVICE_STAMP, this_file)

    load_all()

    # Daemon thread: it never blocks interpreter shutdown.
    threading.Thread(target=observer_loop, daemon=True).start()

    logger.info("[VectorService] Ready (log=%s)", str(LOG_FILE))
|
|
|
|
|
|
# ============================================================
|
|
# Endpoints
|
|
# ============================================================
|
|
|
|
@app.get("/health")
def health() -> Dict[str, Any]:
    """Report what is currently loaded, plus the stamps the observer compares."""
    tag_meta = tag_ids
    chunk_meta = chunk_ids

    return {
        "status": "ok",
        "stamp": SERVICE_STAMP,
        "file": str(Path(__file__).resolve()),
        "chunk_index_loaded": chunk_index is not None,
        "tag_index_loaded": tag_index is not None,
        "model_loaded": model is not None,
        "embedding_model": loaded_embedding_model_name,
        "index_version": current_index_version,
        "chunk_runtime_stamp": current_chunk_runtime_stamp,
        "tags_runtime_stamp": current_tags_runtime_stamp,
        "tags_index_present": current_tags_index_present,
        "tag_meta_type": None if tag_meta is None else type(tag_meta).__name__,
        "tag_meta_len": len(tag_meta) if isinstance(tag_meta, list) else None,
        "chunk_meta_type": None if chunk_meta is None else type(chunk_meta).__name__,
        "chunk_meta_len": len(chunk_meta) if isinstance(chunk_meta, list) else None,
        "tag_meta_map_len": len(tag_meta_map),
        "tags_ndjson_path": str(TAGS_NDJSON_PATH),
        "log_file": str(LOG_FILE),
    }
|
|
|
|
|
|
@app.post("/reload")
def reload() -> Dict[str, str]:
    """Force a synchronous load_all(); any failure is logged and returned as HTTP 500."""
    try:
        load_all()
    except Exception as exc:
        logger.exception("reload failed")
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "reloaded", "stamp": SERVICE_STAMP}
|
|
|
|
|
|
@app.post("/search-chunks")
def search_chunks(req: SearchRequest) -> List[Dict[str, Any]]:
    """Return the chunks nearest to ``req.query`` in the chunk index.

    Each hit is ``{"chunk_id", "score", "document_id"}``, plus "chunk_index"
    when index.ndjson gave a position for that chunk. When ``req.doc_ids`` is
    given, only chunks of those documents are returned. In that case the index
    is searched for ``max(limit * 5, 50)`` candidates, capped at 500, and then
    filtered. Matches ranked below that window are not found.

    Raises:
        HTTPException: 503 if no model/chunk index is loaded, 400 for an empty
            query, 500 for any other failure (including an embedding whose
            dimension does not match the index).
    """
    # Read the shared state once: a reload in another thread rebinds these
    # globals, and mixing an old index with new ids could return wrong chunk ids.
    encoder = model
    index = chunk_index
    ids = chunk_ids
    doc_map = chunk_doc_map
    pos_map = chunk_pos_map

    if index is None or ids is None or encoder is None:
        raise HTTPException(status_code=503, detail="Chunk index not available")

    try:
        limit = _sanitize_limit(req.limit, default=8, max_limit=200)

        query = (req.query or "").strip()
        if not query:
            raise HTTPException(status_code=400, detail="query must not be empty")

        effective_limit = limit
        doc_filter: Optional[set] = None
        if req.doc_ids:
            doc_filter = {key for key in (_as_key(d) for d in req.doc_ids) if key}
            if not doc_filter:
                # Every supplied id was blank: nothing can match.
                return []
            # Over-fetch so enough candidates survive the document filter.
            effective_limit = min(max(limit * 5, 50), 500)

        query_vec = encoder.encode([f"query: {query}"], normalize_embeddings=True)
        query_vec = np.array(query_vec).astype("float32")

        if query_vec.ndim != 2:
            raise RuntimeError(f"Invalid embedding shape: {query_vec.shape}")
        if query_vec.shape[1] != index.d:
            raise RuntimeError(f"Embedding dimension mismatch (vec={query_vec.shape[1]}, index={index.d})")

        scores, indices = index.search(query_vec, effective_limit)

        results: List[Dict[str, Any]] = []
        id_count = len(ids)
        for score, raw_idx in zip(scores[0], indices[0]):
            position = int(raw_idx)
            # FAISS pads missing hits with -1; also guard against stale ids.
            if position < 0 or position >= id_count:
                continue

            raw_chunk_id = ids[position]
            chunk_id_key = _as_key(raw_chunk_id)
            if not chunk_id_key:
                continue

            document_id = doc_map.get(chunk_id_key)
            if doc_filter is not None and document_id not in doc_filter:
                continue

            payload: Dict[str, Any] = {
                "chunk_id": raw_chunk_id,
                "score": float(score),
                "document_id": document_id,
            }

            chunk_position = pos_map.get(chunk_id_key)
            if isinstance(chunk_position, int):
                payload["chunk_index"] = chunk_position

            results.append(payload)
            if len(results) >= limit:
                break

        return results

    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("search-chunks failure")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
|
|
|
|
@app.post("/search-tags")
def search_tags(req: SearchRequest) -> List[Dict[str, Any]]:
    """Return up to ``limit`` distinct tags nearest to ``req.query``.

    Each hit is ``{"tag_id", "score", "label", "tag_type"}``; label and
    tag_type come from tags.ndjson metadata. When no metadata exists, or the
    stored label is not a string, label is "" and tag_type falls back to
    "generic" (tag_type only when metadata is missing). Duplicate tag ids
    are dropped; the index is searched for ``2 * limit`` candidates so
    dropped duplicates do not leave the response short.

    Raises:
        HTTPException: 503 if no model/tag index is loaded, 400 for an empty
            query, 500 for any other failure (including an embedding whose
            dimension does not match the index).
    """
    # Read the shared state once so a concurrent reload cannot mix index,
    # ids and metadata from different generations within one request.
    encoder = model
    index = tag_index
    ids = tag_ids
    meta_map = tag_meta_map

    if index is None or ids is None or encoder is None:
        raise HTTPException(status_code=503, detail="Tag index not available")

    try:
        limit = _sanitize_limit(req.limit, default=8, max_limit=200)

        query = (req.query or "").strip()
        if not query:
            raise HTTPException(status_code=400, detail="query must not be empty")

        query_vec = encoder.encode([f"query: {query}"], normalize_embeddings=True)
        query_vec = np.array(query_vec).astype("float32")

        if query_vec.ndim != 2:
            raise RuntimeError(f"Invalid embedding shape: {query_vec.shape}")
        if query_vec.shape[1] != index.d:
            raise RuntimeError(f"Embedding dimension mismatch (vec={query_vec.shape[1]}, index={index.d})")

        scores, indices = index.search(query_vec, limit * 2)

        results: List[Dict[str, Any]] = []
        seen_tag_ids = set()
        id_count = len(ids)

        for score, raw_idx in zip(scores[0], indices[0]):
            position = int(raw_idx)
            # FAISS pads missing hits with -1; also guard against stale ids.
            if position < 0 or position >= id_count:
                continue

            raw_tag_id = ids[position]
            tag_id_key = _as_key(raw_tag_id)
            if not tag_id_key or tag_id_key in seen_tag_ids:
                continue
            seen_tag_ids.add(tag_id_key)

            meta = meta_map.get(tag_id_key)
            if isinstance(meta, dict):
                label = meta.get("label")
                tag_type = _normalize_tag_type(meta.get("tag_type"))
            else:
                label = ""
                tag_type = "generic"

            results.append(
                {
                    "tag_id": raw_tag_id,
                    "score": float(score),
                    # Always present so every hit has the same shape.
                    "label": label.strip() if isinstance(label, str) else "",
                    "tag_type": tag_type,
                }
            )

            if len(results) >= limit:
                break

        return results

    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("search-tags failure")
        raise HTTPException(status_code=500, detail=str(exc)) from exc