#!/usr/bin/env python3
"""FastAPI vector search service backed by FAISS and SentenceTransformer."""
import json
import logging
import os
import threading
import time
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import faiss
import numpy as np
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Hugging Face Hub tuning for model downloads: disable the Xet backend and
# bound the metadata / download timeouts. setdefault() never overrides a value
# the deployment already exported. These are set before sentence_transformers
# is imported -- presumably because huggingface_hub reads them at import time.
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
os.environ.setdefault("HF_HUB_ETAG_TIMEOUT", "10")
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "30")

from sentence_transformers import SentenceTransformer  # noqa: E402  (must follow the env setup)

# ============================================================
# Service Stamp (to verify you are running THIS file)
# ============================================================

SERVICE_STAMP = "vector_service.py@2026-04-20T00:00+02:00"

# ============================================================
# Paths
# ============================================================

# Project root: the grandparent of the directory that contains this file.
BASE_PATH = Path(__file__).resolve().parents[2]

KNOWLEDGE_DIR = BASE_PATH / "var" / "knowledge"
LOG_DIR = BASE_PATH / "var" / "log"
LOG_FILE = LOG_DIR / "vector_service.log"

# Chunk-level FAISS index and its row-position -> chunk_id list.
CHUNK_INDEX_PATH = KNOWLEDGE_DIR / "vector.index"
CHUNK_MAP_PATH = KNOWLEDGE_DIR / "vector.index.meta.json"

# Tag-level FAISS index and its row-position -> tag_id list.
TAG_INDEX_PATH = KNOWLEDGE_DIR / "vector_tags.index"
TAG_MAP_PATH = KNOWLEDGE_DIR / "vector_tags.index.meta.json"

# Version / rebuild markers polled by the observer thread.
INDEX_META_PATH = KNOWLEDGE_DIR / "index_meta.json"
INDEX_RUNTIME_PATH = KNOWLEDGE_DIR / "index_runtime.json"

# Line-delimited JSON exports used to enrich search results.
INDEX_NDJSON_PATH = KNOWLEDGE_DIR / "index.ndjson"
TAGS_NDJSON_PATH = KNOWLEDGE_DIR / "tags.ndjson"

# ============================================================
# Logging
# ============================================================

logger = logging.getLogger("vector_service")
logger.setLevel(logging.INFO)

# ============================================================
# App State
# ============================================================

app = FastAPI()
# Search state. load_all() builds a complete new generation of these values
# off to the side and publishes them together under _state_lock; request
# handlers read them through the _snapshot_* helpers below.
model: Optional[SentenceTransformer] = None

chunk_index = None
chunk_ids: Optional[List[Any]] = None
chunk_doc_map: Dict[str, str] = {}
chunk_pos_map: Dict[str, int] = {}

tag_index = None
tag_ids: Optional[List[Any]] = None
# tag_id -> {"label": "...", "tag_type": "..."}
tag_meta_map: Dict[str, Dict[str, str]] = {}

loaded_embedding_model_name: Optional[str] = None
current_index_version: Optional[int] = None
current_chunk_runtime_stamp: Optional[str] = None
current_tags_runtime_stamp: Optional[str] = None
current_tags_index_present: Optional[bool] = None

# Serialises whole reloads (startup, /reload and the observer thread).
reload_lock = threading.Lock()

# Held only for the brief moment a reload publishes new state or a request
# snapshots it. FastAPI runs plain `def` endpoints in a threadpool, so without
# this a search could pair a new index with an old id list or see maps from
# two different generations while the observer thread is reloading.
_state_lock = threading.Lock()


# ============================================================
# Models
# ============================================================

class SearchRequest(BaseModel):
    query: str
    limit: int = 8
    doc_ids: Optional[List[str]] = None


# ============================================================
# Helpers
# ============================================================

def setup_logging() -> None:
    """Attach a rotating file handler and a console handler to the service and uvicorn loggers.

    Creates LOG_DIR if needed. Handlers are only added to a logger that does
    not already have one of the same kind, so repeated calls do not duplicate output.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)

    fmt = logging.Formatter(
        fmt="%(asctime)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S%z",
    )

    file_handler = RotatingFileHandler(
        str(LOG_FILE),
        maxBytes=10 * 1024 * 1024,
        backupCount=5,
        encoding="utf-8",
    )
    file_handler.setFormatter(fmt)
    file_handler.setLevel(logging.INFO)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(fmt)
    stream_handler.setLevel(logging.INFO)

    uvicorn_error = logging.getLogger("uvicorn.error")
    uvicorn_access = logging.getLogger("uvicorn.access")
    uvicorn_error.setLevel(logging.INFO)
    uvicorn_access.setLevel(logging.INFO)

    for target in (logger, uvicorn_error, uvicorn_access):
        if not any(isinstance(h, RotatingFileHandler) for h in target.handlers):
            target.addHandler(file_handler)
        # Exact type check: RotatingFileHandler is itself a StreamHandler subclass.
        if not any(type(h) is logging.StreamHandler for h in target.handlers):
            target.addHandler(stream_handler)


def _safe_read_json(path: Path) -> Optional[Any]:
    """Parse a JSON file; return None (and log a warning on errors) if missing or unreadable."""
    try:
        if not path.exists():
            return None
        return json.loads(path.read_text(encoding="utf-8"))
    except Exception as exc:
        logger.warning("Failed to read json %s: %s", str(path), str(exc))
        return None


def _resolve_embedding_model_name(configured_model_name: str) -> str:
    """Return RETRIEX_EMBEDDING_MODEL_PATH if set (non-blank), else the configured model name, stripped."""
    # A local model directory avoids implicit network/cache lookups in production.
    model_override = os.environ.get("RETRIEX_EMBEDDING_MODEL_PATH", "").strip()
    if model_override:
        return model_override
    return configured_model_name.strip()


def _as_key(value: Any) -> Optional[str]:
    """Convert an id-like value to a stripped string key; None for None, blank or unconvertible values."""
    if value is None:
        return None
    if isinstance(value, str):
        value = value.strip()
        return value or None
    try:
        value = str(value).strip()
        return value or None
    except Exception:
        return None


def _sanitize_limit(limit: int, default: int = 8, max_limit: int = 200) -> int:
    """Coerce *limit* to an int in [1, max_limit]; non-numeric or non-positive values give *default*."""
    try:
        value = int(limit)
    except Exception:
        return default
    if value <= 0:
        return default
    if value > max_limit:
        return max_limit
    return value


def _normalize_meta_list(value: Any) -> Optional[List[Any]]:
    """
    Normalise an id-list meta file into a list indexed by FAISS row position.

    Accepts:
      - list: returned as-is
      - dict like {"0": "...", "1": "..."}: converted to a list ordered by the
        numeric value of each key. The keys must cover exactly 0..n-1 because
        list positions are used as FAISS row numbers; a gap would silently
        shift every id after it onto the wrong vector.
    Returns None if the value is neither, or if the dict keys are not
    integers forming a contiguous 0..n-1 range.
    """
    if isinstance(value, list):
        return value
    if isinstance(value, dict):
        try:
            by_position = {int(key): item for key, item in value.items()}
        except (TypeError, ValueError):
            return None
        if len(by_position) != len(value):
            # Two spellings (e.g. "1" and "01") collapsed onto one position.
            return None
        if set(by_position) != set(range(len(by_position))):
            return None
        return [by_position[position] for position in range(len(by_position))]
    return None


def _normalize_tag_type(value: Any) -> str:
    """Map a raw tag type to one of generic / catalog_entity / sales_signal (default generic)."""
    normalized = _as_key(value)
    if normalized is None:
        return "generic"
    normalized = normalized.lower()
    if normalized in {"generic", "catalog_entity", "sales_signal"}:
        return normalized
    return "generic"


def _extract_runtime_state(runtime: Any) -> Tuple[Optional[str], Optional[str], Optional[bool]]:
    """Pull (last_rebuild_at, last_tags_rebuild_at, tags_index_present) from index_runtime.json.

    Each field is None when missing or of the wrong type.
    """
    if not isinstance(runtime, dict):
        return None, None, None

    chunk_runtime = runtime.get("last_rebuild_at")
    tags_runtime = runtime.get("last_tags_rebuild_at")
    tags_index_present = runtime.get("tags_index_present")

    if not isinstance(chunk_runtime, str):
        chunk_runtime = None
    if not isinstance(tags_runtime, str):
        tags_runtime = None
    if not isinstance(tags_index_present, bool):
        tags_index_present = None

    return chunk_runtime, tags_runtime, tags_index_present


def _validate_index_alignment(index_obj: Any, ids: Optional[List[Any]], label: str) -> Tuple[Any, Optional[List[Any]]]:
    """Return (index_obj, ids) if the index row count equals len(ids), else (None, None) with a warning."""
    if index_obj is None or ids is None:
        return None, None

    try:
        index_count = int(index_obj.ntotal)
    except Exception:
        logger.warning("[Reload] %s index has no ntotal -> disabled", label)
        return None, None

    if index_count != len(ids):
        logger.warning(
            "[Reload] %s meta/index mismatch (ids=%s index=%s) -> disabled",
            label,
            len(ids),
            index_count,
        )
        return None, None

    return index_obj, ids


def _read_chunk_maps() -> Tuple[Dict[str, str], Dict[str, int]]:
    """Parse index.ndjson into (chunk_id -> document_id, chunk_id -> chunk_index) maps.

    Best-effort: unparsable lines, non-object rows and rows without a chunk_id
    are skipped. An I/O or decoding error stops reading; whatever was parsed
    up to that point is returned and a warning is logged.
    """
    doc_map: Dict[str, str] = {}
    pos_map: Dict[str, int] = {}
    if not INDEX_NDJSON_PATH.exists():
        return doc_map, pos_map

    try:
        with INDEX_NDJSON_PATH.open("r", encoding="utf-8") as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                try:
                    row = json.loads(line)
                except (ValueError, RecursionError):
                    continue
                # A valid JSON line that is not an object (null, [], "x") would
                # otherwise raise AttributeError below and abort the whole load.
                if not isinstance(row, dict):
                    continue

                chunk_id_key = _as_key(row.get("chunk_id"))
                if not chunk_id_key:
                    continue

                doc_id_key = _as_key(row.get("document_id"))
                if doc_id_key:
                    doc_map[chunk_id_key] = doc_id_key

                position = row.get("chunk_index")
                if isinstance(position, int):
                    pos_map[chunk_id_key] = position
                elif isinstance(position, str):
                    stripped = position.strip()
                    # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
                    if stripped.isdecimal():
                        pos_map[chunk_id_key] = int(stripped)
    except Exception as exc:
        logger.warning("Failed to load chunk maps from ndjson: %s", str(exc))

    return doc_map, pos_map


def load_chunk_maps_from_ndjson() -> None:
    """Rebuild chunk_doc_map / chunk_pos_map from index.ndjson.

    The maps are built off to the side and swapped in together, so concurrent
    searches never observe an empty or half-filled map.
    """
    global chunk_doc_map, chunk_pos_map
    doc_map, pos_map = _read_chunk_maps()
    with _state_lock:
        chunk_doc_map = doc_map
        chunk_pos_map = pos_map


def _read_tag_meta() -> Dict[str, Dict[str, str]]:
    """
    Parse tags.ndjson into tag_id -> {"label", "tag_type"} for /search-tags results.

    Expected line format:
      {
        "tag_id": "...",
        "text": "LABEL\\nSLUG\\noptional description",
        "type": "catalog_entity|generic|sales_signal",
        "document_ids": ["..."]
      }

    Tags whose "document_ids" is an explicit empty list are skipped; tags
    without that field are kept. The label is the first non-blank line of
    "text". Unparsable or non-object lines are skipped. On an I/O or decoding
    error an empty map is returned and a warning is logged.
    """
    meta_map: Dict[str, Dict[str, str]] = {}
    if not TAGS_NDJSON_PATH.exists():
        logger.info("[Reload] tags.ndjson missing -> tag_meta_map empty (%s)", str(TAGS_NDJSON_PATH))
        return meta_map

    try:
        with TAGS_NDJSON_PATH.open("r", encoding="utf-8") as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                try:
                    row = json.loads(line)
                except (ValueError, RecursionError):
                    continue
                if not isinstance(row, dict):
                    continue

                tag_id = _as_key(row.get("tag_id"))
                if not tag_id:
                    continue

                document_ids = row.get("document_ids")
                if isinstance(document_ids, list) and len(document_ids) == 0:
                    continue

                label = ""
                text_value = row.get("text")
                if isinstance(text_value, str) and text_value.strip():
                    # Non-empty after strip(), so there is at least one line.
                    label = text_value.strip().splitlines()[0].strip()

                meta_map[tag_id] = {
                    "label": label,
                    "tag_type": _normalize_tag_type(row.get("type")),
                }
    except Exception as exc:
        logger.warning("Failed to load tag meta from tags.ndjson: %s", str(exc))
        return {}

    return meta_map


def load_tag_meta_from_tags_ndjson() -> None:
    """Rebuild tag_meta_map from tags.ndjson and swap it in atomically."""
    global tag_meta_map
    meta_map = _read_tag_meta()
    with _state_lock:
        tag_meta_map = meta_map


def _load_index_pair(index_path: Path, map_path: Path, label: str) -> Tuple[Any, Optional[List[Any]]]:
    """Read a FAISS index plus its id-list meta file; (None, None) if absent or inconsistent.

    Exceptions from faiss.read_index propagate to the caller.
    """
    if not (index_path.exists() and map_path.exists()):
        return None, None

    logger.info("[Reload] Loading %s index", label)
    loaded_index = faiss.read_index(str(index_path))
    loaded_ids = _normalize_meta_list(_safe_read_json(map_path))
    if loaded_ids is None:
        logger.warning("[Reload] %s_ids meta invalid -> %s index disabled", label, label)
        return None, None
    return _validate_index_alignment(loaded_index, loaded_ids, label)


def load_all() -> None:
    """Reload the embedding model, FAISS indexes and id/metadata maps from KNOWLEDGE_DIR.

    Raises RuntimeError if index_meta.json is missing/invalid or has no
    "embedding_model". The model is swapped in as soon as it is loaded (it is
    only reloaded when its resolved name changes). All index, map and version
    state is built into locals and published together at the end, so if any
    step raises, the previously published index state stays in place.
    """
    global model, loaded_embedding_model_name
    global chunk_index, chunk_ids, chunk_doc_map, chunk_pos_map
    global tag_index, tag_ids, tag_meta_map
    global current_index_version
    global current_chunk_runtime_stamp, current_tags_runtime_stamp, current_tags_index_present

    with reload_lock:
        meta = _safe_read_json(INDEX_META_PATH)
        if not isinstance(meta, dict):
            raise RuntimeError("index_meta.json not found or invalid")

        embedding_model_name = meta.get("embedding_model")
        index_version = meta.get("index_version")
        if not embedding_model_name:
            raise RuntimeError("embedding_model missing in index_meta.json")

        resolved_embedding_model_name = _resolve_embedding_model_name(str(embedding_model_name))
        if model is None or resolved_embedding_model_name != loaded_embedding_model_name:
            logger.info("[Reload] Loading embedding model: %s", resolved_embedding_model_name)
            new_model = SentenceTransformer(resolved_embedding_model_name)
            with _state_lock:
                model = new_model
                loaded_embedding_model_name = resolved_embedding_model_name

        runtime = _safe_read_json(INDEX_RUNTIME_PATH)
        chunk_runtime_stamp, tags_runtime_stamp, tags_index_present = _extract_runtime_state(runtime)

        # Chunks
        new_chunk_index, new_chunk_ids = _load_index_pair(CHUNK_INDEX_PATH, CHUNK_MAP_PATH, "chunk")

        logger.info("[Reload] Loading chunk maps (doc_id + chunk_index)")
        new_doc_map, new_pos_map = _read_chunk_maps()

        # Tags: skipped only when the runtime file explicitly says the index is absent.
        if tags_index_present is not False:
            new_tag_index, new_tag_ids = _load_index_pair(TAG_INDEX_PATH, TAG_MAP_PATH, "tag")
        else:
            new_tag_index, new_tag_ids = None, None
            logger.info("[Reload] Runtime marks tags index as absent -> tag index disabled")

        logger.info("[Reload] Loading tag meta from tags.ndjson")
        new_tag_meta = _read_tag_meta()

        # Publish the whole generation at once.
        with _state_lock:
            chunk_index, chunk_ids = new_chunk_index, new_chunk_ids
            chunk_doc_map, chunk_pos_map = new_doc_map, new_pos_map
            tag_index, tag_ids = new_tag_index, new_tag_ids
            tag_meta_map = new_tag_meta
            current_index_version = index_version if isinstance(index_version, int) else None
            current_chunk_runtime_stamp = chunk_runtime_stamp
            current_tags_runtime_stamp = tags_runtime_stamp
            current_tags_index_present = tags_index_present

        logger.info(
            "[Reload] Completed (index_version=%s chunk_runtime=%s tags_runtime=%s tags_index_present=%s embedding_model=%s tag_meta=%s stamp=%s file=%s)",
            str(current_index_version),
            str(current_chunk_runtime_stamp),
            str(current_tags_runtime_stamp),
            str(current_tags_index_present),
            str(loaded_embedding_model_name),
            str(len(new_tag_meta)),
            SERVICE_STAMP,
            str(Path(__file__).resolve()),
        )


def _snapshot_chunk_state() -> Tuple[Any, Optional[List[Any]], Dict[str, str], Dict[str, int], Optional[SentenceTransformer]]:
    """Return (chunk_index, chunk_ids, chunk_doc_map, chunk_pos_map, model) from one published generation."""
    with _state_lock:
        return chunk_index, chunk_ids, chunk_doc_map, chunk_pos_map, model


def _snapshot_tag_state() -> Tuple[Any, Optional[List[Any]], Dict[str, Dict[str, str]], Optional[SentenceTransformer]]:
    """Return (tag_index, tag_ids, tag_meta_map, model) from one published generation."""
    with _state_lock:
        return tag_index, tag_ids, tag_meta_map, model


def _encode_query(encoder: SentenceTransformer, query: str, index: Any) -> np.ndarray:
    """Embed *query* (with the "query: " prefix) as a 2-D float32 matrix and check it fits *index*.

    Raises RuntimeError if the embedding is not 2-D or its width differs from index.d
    (e.g. the model was swapped mid-reload before the matching index was published).
    """
    query_vec = np.ascontiguousarray(
        encoder.encode([f"query: {query}"], normalize_embeddings=True),
        dtype="float32",
    )
    if query_vec.ndim != 2:
        raise RuntimeError(f"Invalid embedding shape: {query_vec.shape}")
    if query_vec.shape[1] != index.d:
        raise RuntimeError(f"Embedding dimension mismatch (vec={query_vec.shape[1]}, index={index.d})")
    return query_vec


# ============================================================
# Observer
# ============================================================

def observer_loop() -> None:
    """Every 2 seconds, compare index_meta.json / index_runtime.json with the loaded state and call load_all() on any change.

    Runs forever; errors are logged and polling continues.
    """
    while True:
        time.sleep(2)
        try:
            meta = _safe_read_json(INDEX_META_PATH)
            if not isinstance(meta, dict):
                continue

            raw_version = meta.get("index_version")
            new_version = raw_version if isinstance(raw_version, int) else None

            runtime = _safe_read_json(INDEX_RUNTIME_PATH)
            new_chunk_runtime, new_tags_runtime, new_tags_index_present = _extract_runtime_state(runtime)

            # (what changed, currently loaded value, value on disk) -- first difference wins.
            checks = (
                ("index_version", current_index_version, new_version),
                ("chunk runtime", current_chunk_runtime_stamp, new_chunk_runtime),
                ("tags runtime", current_tags_runtime_stamp, new_tags_runtime),
                ("tags_index_present", current_tags_index_present, new_tags_index_present),
            )
            for what, old_value, new_value in checks:
                if new_value != old_value:
                    logger.info(
                        "[Observer] %s changed (%s -> %s) -> Reload",
                        what,
                        str(old_value),
                        str(new_value),
                    )
                    load_all()
                    break
        except Exception as exc:
            logger.exception("[Observer ERROR] %s", str(exc))


# ============================================================
# Global Exception Handler
# ============================================================

@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Log any unhandled exception and return a JSON 500 including the exception text and service stamp."""
    # NOTE(review): "detail" echoes str(exc) to the client; fine for an internal
    # service, but reconsider if this endpoint is ever exposed publicly.
    logger.exception("UNHANDLED_EXCEPTION path=%s method=%s", request.url.path, request.method)
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal Server Error",
            "detail": str(exc),
            "path": request.url.path,
            "stamp": SERVICE_STAMP,
        },
    )


# ============================================================
# Startup
# ============================================================

@app.on_event("startup")
def startup_event() -> None:
    """Configure logging, perform the initial load and start the daemon observer thread."""
    setup_logging()
    logger.info("[VectorService] Startup stamp=%s file=%s", SERVICE_STAMP, str(Path(__file__).resolve()))
    load_all()
    observer = threading.Thread(target=observer_loop, daemon=True)
    observer.start()
    logger.info("[VectorService] Ready (log=%s)", str(LOG_FILE))


# ============================================================
# Endpoints
# ============================================================

@app.get("/health")
def health() -> Dict[str, Any]:
    """Report which components are loaded plus the current version/runtime stamps."""
    return {
        "status": "ok",
        "stamp": SERVICE_STAMP,
        "file": str(Path(__file__).resolve()),
        "chunk_index_loaded": chunk_index is not None,
        "tag_index_loaded": tag_index is not None,
        "model_loaded": model is not None,
        "embedding_model": loaded_embedding_model_name,
        "index_version": current_index_version,
        "chunk_runtime_stamp": current_chunk_runtime_stamp,
        "tags_runtime_stamp": current_tags_runtime_stamp,
        "tags_index_present": current_tags_index_present,
        "tag_meta_type": type(tag_ids).__name__ if tag_ids is not None else None,
        "tag_meta_len": len(tag_ids) if isinstance(tag_ids, list) else None,
        "chunk_meta_type": type(chunk_ids).__name__ if chunk_ids is not None else None,
        "chunk_meta_len": len(chunk_ids) if isinstance(chunk_ids, list) else None,
        "tag_meta_map_len": len(tag_meta_map),
        "tags_ndjson_path": str(TAGS_NDJSON_PATH),
        "log_file": str(LOG_FILE),
    }


@app.post("/reload")
def reload() -> Dict[str, str]:
    """Force a synchronous load_all(); failures become HTTP 500 with the error text."""
    try:
        load_all()
        return {"status": "reloaded", "stamp": SERVICE_STAMP}
    except Exception as exc:
        logger.exception("reload failed")
        raise HTTPException(status_code=500, detail=str(exc))


@app.post("/search-chunks")
def search_chunks(req: SearchRequest) -> List[Dict[str, Any]]:
    """Return up to `limit` nearest chunks for `query`, optionally restricted to `doc_ids`.

    Each hit has chunk_id, score, document_id (None if unknown) and, when
    known, chunk_index. Raises 503 if no chunk index/model is loaded and 400
    for an empty query.
    """
    index, ids, doc_map, pos_map, encoder = _snapshot_chunk_state()
    if index is None or ids is None or encoder is None:
        raise HTTPException(status_code=503, detail="Chunk index not available")

    try:
        limit = _sanitize_limit(req.limit, default=8, max_limit=200)
        query = (req.query or "").strip()
        if not query:
            raise HTTPException(status_code=400, detail="query must not be empty")

        query_vec = _encode_query(encoder, query, index)

        effective_limit = limit
        doc_filter: Optional[set] = None
        if req.doc_ids:
            doc_filter = {key for key in (_as_key(doc_id) for doc_id in req.doc_ids) if key}
            # Over-fetch, because hits from other documents are discarded below.
            effective_limit = min(max(limit * 5, 50), 500)

        scores, indices = index.search(query_vec, effective_limit)

        results: List[Dict[str, Any]] = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing neighbours with -1; the upper bound guards bad ids.
            if idx < 0 or idx >= len(ids):
                continue

            raw_chunk_id = ids[idx]
            chunk_id_key = _as_key(raw_chunk_id)
            if not chunk_id_key:
                continue

            document_id = doc_map.get(chunk_id_key)
            if doc_filter is not None and (document_id is None or document_id not in doc_filter):
                continue

            payload: Dict[str, Any] = {
                "chunk_id": raw_chunk_id,
                "score": float(score),
                "document_id": document_id,
            }
            chunk_position = pos_map.get(chunk_id_key)
            if isinstance(chunk_position, int):
                payload["chunk_index"] = chunk_position

            results.append(payload)
            if len(results) >= limit:
                break

        return results
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("search-chunks failure")
        raise HTTPException(status_code=500, detail=str(exc))


@app.post("/search-tags")
def search_tags(req: SearchRequest) -> List[Dict[str, Any]]:
    """Return up to `limit` nearest tags for `query`, de-duplicated by tag id.

    Each hit has tag_id, score, tag_type and (normally) label taken from
    tags.ndjson. Raises 503 if no tag index/model is loaded and 400 for an
    empty query. `doc_ids` is not used by this endpoint.
    """
    index, ids, meta_map, encoder = _snapshot_tag_state()
    if index is None or ids is None or encoder is None:
        raise HTTPException(status_code=503, detail="Tag index not available")

    try:
        limit = _sanitize_limit(req.limit, default=8, max_limit=200)
        query = (req.query or "").strip()
        if not query:
            raise HTTPException(status_code=400, detail="query must not be empty")

        query_vec = _encode_query(encoder, query, index)
        scores, indices = index.search(query_vec, limit)

        results: List[Dict[str, Any]] = []
        seen_tag_ids = set()
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing neighbours with -1; the upper bound guards bad ids.
            if idx < 0 or idx >= len(ids):
                continue

            raw_tag_id = ids[idx]
            tag_id_key = _as_key(raw_tag_id)
            if not tag_id_key or tag_id_key in seen_tag_ids:
                continue

            payload: Dict[str, Any] = {
                "tag_id": raw_tag_id,
                "score": float(score),
            }
            meta = meta_map.get(tag_id_key)
            if isinstance(meta, dict):
                label = meta.get("label")
                if isinstance(label, str):
                    payload["label"] = label.strip()
                payload["tag_type"] = _normalize_tag_type(meta.get("tag_type"))
            else:
                payload["label"] = ""
                payload["tag_type"] = "generic"

            results.append(payload)
            seen_tag_ids.add(tag_id_key)
            if len(results) >= limit:
                break

        return results
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("search-tags failure")
        raise HTTPException(status_code=500, detail=str(exc))