#!/usr/bin/env python3
import json
import logging
from logging.handlers import RotatingFileHandler
import threading
import time
from pathlib import Path
from typing import Any, List, Optional, Dict

import numpy as np
import faiss
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# ============================================================
# Paths
# ============================================================
BASE_PATH = Path(__file__).resolve().parents[2]
KNOWLEDGE_DIR = BASE_PATH / "var" / "knowledge"
LOG_DIR = BASE_PATH / "var" / "log"
LOG_FILE = LOG_DIR / "vector_service.log"

CHUNK_INDEX_PATH = KNOWLEDGE_DIR / "vector.index"
CHUNK_MAP_PATH = KNOWLEDGE_DIR / "vector.index.meta.json"

TAG_INDEX_PATH = KNOWLEDGE_DIR / "vector_tags.index"
TAG_MAP_PATH = KNOWLEDGE_DIR / "vector_tags.index.meta.json"

INDEX_META_PATH = KNOWLEDGE_DIR / "index_meta.json"
INDEX_RUNTIME_PATH = KNOWLEDGE_DIR / "index_runtime.json"
INDEX_NDJSON_PATH = KNOWLEDGE_DIR / "index.ndjson"

# ============================================================
# Logging
# ============================================================
logger = logging.getLogger("vector_service")
logger.setLevel(logging.INFO)


def _is_console_handler(handler: logging.Handler) -> bool:
    """Return True for a StreamHandler that is not a file-backed handler.

    FileHandler (and therefore RotatingFileHandler) subclasses StreamHandler,
    so a bare isinstance(handler, StreamHandler) check also matches file handlers.
    """
    return isinstance(handler, logging.StreamHandler) and not isinstance(
        handler, logging.FileHandler
    )


def setup_logging() -> None:
    """Attach a rotating file handler and a console handler to ``logger``.

    Creates LOG_DIR if needed. The function is safe to call repeatedly: each
    handler kind is created and attached only if the logger does not already
    have one. Handlers are built only when they will be used, because
    RotatingFileHandler opens the log file on construction. An unused
    instance would leak a file descriptor.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)

    fmt = logging.Formatter(
        fmt="%(asctime)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S%z",
    )

    if not any(isinstance(h, RotatingFileHandler) for h in logger.handlers):
        file_handler = RotatingFileHandler(
            str(LOG_FILE),
            maxBytes=10 * 1024 * 1024,
            backupCount=5,
            encoding="utf-8",
        )
        file_handler.setFormatter(fmt)
        file_handler.setLevel(logging.INFO)
        logger.addHandler(file_handler)

    # Must exclude file handlers. Otherwise the handler added just above
    # satisfies the check and console output is silently never enabled.
    if not any(_is_console_handler(h) for h in logger.handlers):
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(fmt)
        stream_handler.setLevel(logging.INFO)
        logger.addHandler(stream_handler)
# ============================================================
# FastAPI
# ============================================================
app = FastAPI()

model: Optional[SentenceTransformer] = None

chunk_index = None
chunk_ids: Optional[List[Any]] = None

# Sales-RAG signals derived from NDJSON (loaded on startup and reload):
# - chunk_doc_map: chunk_id -> document_id
# - chunk_pos_map: chunk_id -> chunk_index (position within document, if available)
chunk_doc_map: Dict[str, str] = {}
chunk_pos_map: Dict[str, int] = {}

tag_index = None
tag_ids: Optional[List[Any]] = None

loaded_embedding_model_name: Optional[str] = None
current_index_version: Optional[int] = None
current_runtime_stamp: Optional[str] = None

# Serializes load_all() between the observer thread and the /reload endpoint.
reload_lock = threading.Lock()

# (model_name, model) loaded by a reload that failed afterwards, e.g. on a
# corrupt index file. Reusing it keeps the observer's 2-second retries from
# re-instantiating the same SentenceTransformer each cycle. Cleared once a
# reload publishes successfully.
_staged_model: Optional[Any] = None


# ============================================================
# Models
# ============================================================
class SearchRequest(BaseModel):
    query: str
    limit: int = 8
    doc_ids: Optional[List[str]] = None


# ============================================================
# Loader
# ============================================================
def _safe_read_json(path: Path) -> Optional[Any]:
    """Parse the JSON file at ``path``; return None if missing or unreadable.

    The top-level value may be any JSON type (callers read both objects and
    arrays). Failures are logged as warnings, never raised.
    """
    try:
        if not path.exists():
            return None
        return json.loads(path.read_text(encoding="utf-8"))
    except Exception as e:
        logger.warning("Failed to read json %s: %s", str(path), str(e))
        return None


def _as_key(value: Any) -> Optional[str]:
    """
    Normalize IDs to string keys for maps. Returns None if unusable.

    Strings are stripped. Other values go through str() and are then stripped.
    None, empty/whitespace-only results, and values whose str() raises all
    yield None.
    """
    if value is None:
        return None
    if isinstance(value, str):
        v = value.strip()
        return v if v else None
    try:
        v = str(value).strip()
    except Exception:
        return None
    return v if v else None


def _parse_chunk_pos(value: Any) -> Optional[int]:
    """Return a chunk position from an int or an all-digit string, else None.

    bool is rejected even though it subclasses int. A true/false
    "chunk_index" is malformed data, not a position.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        s = value.strip()
        if s.isdigit():
            try:
                return int(s)
            except ValueError:
                # str.isdigit() accepts characters such as superscripts that
                # int() rejects.
                return None
    return None


def _read_chunk_maps() -> tuple:
    """Read index.ndjson into fresh ``(doc_map, pos_map)`` dicts.

    - doc_map: chunk_id -> document_id
    - pos_map: chunk_id -> chunk_index (position inside document, if present)

    Blank lines, unparsable lines, non-object rows, and rows without a usable
    chunk_id are skipped. An I/O error while reading is logged, and whatever
    was collected up to that point is returned. Returns two empty dicts when
    the file does not exist.
    """
    doc_map: Dict[str, str] = {}
    pos_map: Dict[str, int] = {}

    if not INDEX_NDJSON_PATH.exists():
        return doc_map, pos_map

    try:
        with INDEX_NDJSON_PATH.open("r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    row = json.loads(line)
                except ValueError:
                    continue
                # A valid JSON line that is not an object (e.g. a list) has
                # no .get(). Skip it instead of aborting the whole load.
                if not isinstance(row, dict):
                    continue

                chunk_id_key = _as_key(row.get("chunk_id"))
                if not chunk_id_key:
                    continue

                doc_id_key = _as_key(row.get("document_id"))
                if doc_id_key:
                    doc_map[chunk_id_key] = doc_id_key

                # chunk_index is optional but very useful for Sales-RAG
                # diversity rules (e.g. min distance within a doc).
                pos = _parse_chunk_pos(row.get("chunk_index"))
                if pos is not None:
                    pos_map[chunk_id_key] = pos
    except Exception as e:
        logger.warning("Failed to load chunk maps from ndjson: %s", str(e))

    return doc_map, pos_map


def load_chunk_maps_from_ndjson() -> None:
    """
    Builds two maps from index.ndjson:
    - chunk_id -> document_id
    - chunk_id -> chunk_index (position inside document, if present)

    The maps are built off to the side and then swapped in. Readers therefore
    never observe a partially filled map while a reload is in progress.
    """
    global chunk_doc_map, chunk_pos_map
    doc_map, pos_map = _read_chunk_maps()
    chunk_doc_map, chunk_pos_map = doc_map, pos_map


def _sanitize_limit(limit: int, default: int = 8, max_limit: int = 200) -> int:
    """Coerce ``limit`` to an int in [1, max_limit].

    Returns ``default`` if the value is not convertible or is <= 0.
    """
    try:
        v = int(limit)
    except (TypeError, ValueError, OverflowError):
        return default
    if v <= 0:
        return default
    if v > max_limit:
        return max_limit
    return v


def _load_index_pair(index_path: Path, map_path: Path, label: str) -> tuple:
    """Load a FAISS index plus its JSON id list as ``(index, ids)``.

    Returns ``(None, None)`` when either file is missing or the id list is
    not a non-empty JSON array. Errors raised by faiss.read_index propagate
    to the caller.
    """
    if not (index_path.exists() and map_path.exists()):
        return None, None

    logger.info("[Reload] Loading %s index", label)
    index = faiss.read_index(str(index_path))
    ids = _safe_read_json(map_path) or None
    if not isinstance(ids, list):
        logger.warning("[Reload] %s_ids meta invalid -> %s index disabled", label, label)
        return None, None

    # Searches bound-check positions against len(ids), so a mismatch does not
    # crash. It does mean some vectors are unreachable or ids are unused.
    if index.ntotal != len(ids):
        logger.warning(
            "[Reload] %s index has %d vectors but %d ids",
            label,
            int(index.ntotal),
            len(ids),
        )
    return index, ids


def _resolve_model(name: str) -> SentenceTransformer:
    """Return a SentenceTransformer for ``name``, reusing a loaded instance.

    Checks the published model first, then one staged by an earlier failed
    reload, and only then constructs a new one (which is staged).
    """
    global _staged_model
    if model is not None and name == loaded_embedding_model_name:
        return model
    if _staged_model is not None and _staged_model[0] == name:
        return _staged_model[1]
    logger.info("[Reload] Loading embedding model: %s", name)
    new_model = SentenceTransformer(name)
    _staged_model = (name, new_model)
    return new_model


def load_all() -> None:
    """(Re)load the embedding model, both FAISS indexes and the chunk maps.

    Everything is loaded into locals first and published only after every
    step has succeeded. A failing reload (missing meta, corrupt index, ...)
    leaves the previously published state untouched. The new values are then
    assigned in consecutive statements. Search handlers snapshot the globals
    once per request, which narrows the window for mixing old and new state.

    Raises:
        RuntimeError: index_meta.json is missing/invalid or lacks
            "embedding_model".
        Exception: anything raised by SentenceTransformer or faiss.read_index.
    """
    global model, chunk_index, chunk_ids
    global chunk_doc_map, chunk_pos_map
    global tag_index, tag_ids
    global loaded_embedding_model_name
    global current_index_version
    global current_runtime_stamp
    global _staged_model

    with reload_lock:
        meta = _safe_read_json(INDEX_META_PATH)
        if not isinstance(meta, dict):
            raise RuntimeError("index_meta.json not found or invalid")

        embedding_model_name = meta.get("embedding_model")
        index_version = meta.get("index_version")

        if not embedding_model_name:
            raise RuntimeError("embedding_model missing in index_meta.json")

        # --- Stage -------------------------------------------------------
        new_model = _resolve_model(embedding_model_name)

        new_chunk_index, new_chunk_ids = _load_index_pair(
            CHUNK_INDEX_PATH, CHUNK_MAP_PATH, "chunk"
        )

        logger.info("[Reload] Loading chunk maps (doc_id + chunk_index)")
        new_doc_map, new_pos_map = _read_chunk_maps()

        new_tag_index, new_tag_ids = _load_index_pair(
            TAG_INDEX_PATH, TAG_MAP_PATH, "tag"
        )

        runtime = _safe_read_json(INDEX_RUNTIME_PATH)
        stamp = runtime.get("last_rebuild_at") if isinstance(runtime, dict) else None

        # --- Publish -----------------------------------------------------
        model = new_model
        loaded_embedding_model_name = embedding_model_name
        chunk_index, chunk_ids = new_chunk_index, new_chunk_ids
        chunk_doc_map, chunk_pos_map = new_doc_map, new_pos_map
        tag_index, tag_ids = new_tag_index, new_tag_ids
        current_runtime_stamp = stamp if isinstance(stamp, str) else None
        current_index_version = index_version if isinstance(index_version, int) else None
        _staged_model = None

        logger.info(
            "[Reload] Completed (index_version=%s runtime=%s embedding_model=%s)",
            str(current_index_version),
            str(current_runtime_stamp),
            str(loaded_embedding_model_name),
        )


# ============================================================
# Observer (Enterprise Auto Reload)
# ============================================================
def observer_loop(interval: float = 2.0) -> None:
    """Poll the index metadata every ``interval`` seconds and reload on change.

    A reload is triggered when the int "index_version" in index_meta.json, or
    the str "last_rebuild_at" in index_runtime.json, differs from the loaded
    value. Polls are skipped while index_meta.json is unreadable. Exceptions
    are logged and polling continues, so a failed reload is retried on the
    next tick. The function never returns.
    """
    while True:
        time.sleep(interval)
        try:
            meta = _safe_read_json(INDEX_META_PATH)
            if not isinstance(meta, dict):
                continue

            version = meta.get("index_version")
            new_version = version if isinstance(version, int) else None

            runtime = _safe_read_json(INDEX_RUNTIME_PATH)
            stamp = runtime.get("last_rebuild_at") if isinstance(runtime, dict) else None
            new_runtime = stamp if isinstance(stamp, str) else None

            if new_version != current_index_version:
                logger.info(
                    "[Observer] index_version changed (%s -> %s) -> Reload",
                    str(current_index_version),
                    str(new_version),
                )
                load_all()
            elif new_runtime != current_runtime_stamp:
                logger.info(
                    "[Observer] runtime changed (%s -> %s) -> Reload",
                    str(current_runtime_stamp),
                    str(new_runtime),
                )
                load_all()

        except Exception as e:
            logger.error("[Observer ERROR] %s", str(e))


# ============================================================
# Startup
# ============================================================
@app.on_event("startup")
def startup_event():
    """Configure logging, perform the initial load, then start the observer.

    An exception from the initial load_all() propagates out of this hook,
    and the observer thread is not started in that case.
    """
    setup_logging()
    logger.info("[VectorService] Startup")
    load_all()
    threading.Thread(target=observer_loop, name="vector-observer", daemon=True).start()
    logger.info("[VectorService] Ready (log=%s)", str(LOG_FILE))


# ============================================================
# Endpoints
# ============================================================
@app.get("/health")
def health():
    """Report which components are loaded and the active index metadata."""
    return {
        "status": "ok",
        "chunk_index_loaded": chunk_index is not None,
        "tag_index_loaded": tag_index is not None,
        "model_loaded": model is not None,
        "embedding_model": loaded_embedding_model_name,
        "index_version": current_index_version,
        "runtime_stamp": current_runtime_stamp,
        "log_file": str(LOG_FILE),
    }


@app.post("/reload")
def reload():
    """Force a synchronous load_all(); any failure is returned as HTTP 500."""
    try:
        load_all()
        return {"status": "reloaded"}
    except Exception as e:
        logger.error("[Reload ERROR] %s", str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e


def _clean_query(raw: Optional[str]) -> str:
    """Strip the query text; raise HTTP 400 if nothing is left."""
    query = (raw or "").strip()
    if not query:
        raise HTTPException(status_code=400, detail="query must not be empty")
    return query


def _encode_query(encoder: SentenceTransformer, query: str) -> np.ndarray:
    """Embed ``query`` as a (1, dim) C-contiguous float32 array for FAISS.

    The "query: " prefix and L2 normalization match the original service's
    encoding. Presumably this mirrors how the indexes were built — confirm
    against the index builder.
    """
    vec = encoder.encode([f"query: {query}"], normalize_embeddings=True)
    return np.ascontiguousarray(vec, dtype="float32")


@app.post("/search-chunks")
def search_chunks(req: SearchRequest):
    """Return up to ``limit`` nearest chunks, optionally restricted to doc_ids.

    Each hit is ``{"chunk_id", "score", "document_id"}``, plus
    ``"chunk_index"`` when a position is known. With a doc filter, a wider
    candidate pool (5x limit, at least 50, at most 500) is searched and then
    filtered. Fewer than ``limit`` hits may be returned when the filtered
    docs are not in that pool.
    """
    # Snapshot shared state once so a concurrent reload cannot swap the index,
    # id list or maps halfway through this request.
    encoder = model
    index = chunk_index
    ids = chunk_ids
    doc_map = chunk_doc_map
    pos_map = chunk_pos_map

    if index is None or ids is None or encoder is None:
        raise HTTPException(status_code=503, detail="Chunk index not available")

    # Safety: clamp limit to prevent abuse / accidental huge queries
    limit = _sanitize_limit(req.limit, default=8, max_limit=200)
    query = _clean_query(req.query)
    query_vec = _encode_query(encoder, query)

    effective_limit = limit
    doc_filter: Optional[set] = None

    if req.doc_ids:
        # Normalize incoming doc_ids for reliable matching. A set keeps
        # per-hit membership O(1).
        doc_filter = {key for key in (_as_key(d) for d in req.doc_ids) if key}
        # Fetch a wider, bounded pool and filter it down.
        effective_limit = min(max(limit * 5, 50), 500)

    scores, indices = index.search(query_vec, effective_limit)

    results: List[Dict[str, Any]] = []
    for score, idx in zip(scores[0], indices[0]):
        # FAISS pads missing neighbours with -1. Also guard against an id
        # list shorter than the index.
        if idx < 0 or idx >= len(ids):
            continue

        raw_chunk_id = ids[idx]
        chunk_id_key = _as_key(raw_chunk_id)
        if not chunk_id_key:
            continue

        doc_id = doc_map.get(chunk_id_key)
        # doc_filter only holds non-empty strings, so an unknown (None)
        # doc_id is excluded as well.
        if doc_filter is not None and doc_id not in doc_filter:
            continue

        # Sales-RAG signals:
        # - document_id (for doc quotas / diversity rules)
        # - chunk_index (position within doc for distance constraints)
        payload = {
            "chunk_id": raw_chunk_id,
            "score": float(score),
            "document_id": doc_id,  # may be None if ndjson missing/partial
        }

        ci = pos_map.get(chunk_id_key)
        if isinstance(ci, int):
            payload["chunk_index"] = ci

        results.append(payload)
        if len(results) >= limit:
            break

    return results


@app.post("/search-tags")
def search_tags(req: SearchRequest):
    """Return up to ``limit`` nearest tags as ``{"tag_id", "score"}`` dicts."""
    # Snapshot shared state once (see search_chunks).
    encoder = model
    index = tag_index
    ids = tag_ids

    if index is None or ids is None or encoder is None:
        raise HTTPException(status_code=503, detail="Tag index not available")

    limit = _sanitize_limit(req.limit, default=8, max_limit=200)
    query = _clean_query(req.query)
    query_vec = _encode_query(encoder, query)

    scores, indices = index.search(query_vec, limit)

    results: List[Dict[str, Any]] = []
    for score, idx in zip(scores[0], indices[0]):
        if idx < 0 or idx >= len(ids):
            continue
        results.append({
            "tag_id": ids[idx],
            "score": float(score),
        })

    return results