#!/usr/bin/env python3
"""Vector search service: serves FAISS chunk and tag indexes over FastAPI."""
import json
import logging
from logging.handlers import RotatingFileHandler
import threading
import time
from pathlib import Path
from typing import Any, List, Optional, Dict

import numpy as np
import faiss
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# ============================================================
# Paths
# ============================================================

# Root two directory levels above the directory containing this file.
BASE_PATH = Path(__file__).resolve().parents[2]
KNOWLEDGE_DIR = BASE_PATH / "var" / "knowledge"
LOG_DIR = BASE_PATH / "var" / "log"
LOG_FILE = LOG_DIR / "vector_service.log"

CHUNK_INDEX_PATH = KNOWLEDGE_DIR / "vector.index"
CHUNK_MAP_PATH = KNOWLEDGE_DIR / "vector.index.meta.json"
TAG_INDEX_PATH = KNOWLEDGE_DIR / "vector_tags.index"
TAG_MAP_PATH = KNOWLEDGE_DIR / "vector_tags.index.meta.json"
INDEX_META_PATH = KNOWLEDGE_DIR / "index_meta.json"
INDEX_RUNTIME_PATH = KNOWLEDGE_DIR / "index_runtime.json"
INDEX_NDJSON_PATH = KNOWLEDGE_DIR / "index.ndjson"

# ============================================================
# Logging
# ============================================================

logger = logging.getLogger("vector_service")
logger.setLevel(logging.INFO)


def setup_logging() -> None:
    """Attach a rotating file handler and a console handler to ``logger``.

    The file handler writes to LOG_FILE (10 MiB per file, 5 backups) and
    LOG_DIR is created if missing. Safe to call repeatedly: each kind of
    handler is attached at most once.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)

    fmt = logging.Formatter(
        fmt="%(asctime)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S%z",
    )

    # RotatingFileHandler -> FileHandler -> StreamHandler, so a bare
    # isinstance(h, StreamHandler) check is also true for the file handler
    # and would wrongly suppress the console handler. Exclude FileHandler
    # subclasses to detect a genuine console handler.
    has_file_handler = any(
        isinstance(h, RotatingFileHandler) for h in logger.handlers
    )
    has_console_handler = any(
        isinstance(h, logging.StreamHandler)
        and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )

    # Construct handlers only when they will be attached: RotatingFileHandler
    # opens LOG_FILE in its constructor, so building a spare one on every
    # call would leak an open file handle.
    if not has_file_handler:
        file_handler = RotatingFileHandler(
            str(LOG_FILE),
            maxBytes=10 * 1024 * 1024,
            backupCount=5,
            encoding="utf-8",
        )
        file_handler.setFormatter(fmt)
        file_handler.setLevel(logging.INFO)
        logger.addHandler(file_handler)

    if not has_console_handler:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(fmt)
        stream_handler.setLevel(logging.INFO)
        logger.addHandler(stream_handler)
# ============================================================
# FastAPI
# ============================================================

app = FastAPI()

# Shared service state. load_all() rebinds these (under reload_lock) from the
# startup hook, the observer thread and POST /reload, while the sync request
# handlers below read them concurrently. Handlers therefore copy what they
# need into locals once at entry instead of re-reading the globals.
model: Optional[SentenceTransformer] = None
chunk_index = None  # FAISS index over chunks, or None when disabled
chunk_ids: Optional[List[Any]] = None  # FAISS row position -> chunk id
chunk_doc_map: Dict[str, str] = {}  # chunk id -> document id
tag_index = None  # FAISS index over tags, or None when disabled
tag_ids: Optional[List[Any]] = None  # FAISS row position -> tag id

loaded_embedding_model_name: Optional[str] = None
current_index_version: Optional[int] = None
current_runtime_stamp: Optional[str] = None

reload_lock = threading.Lock()  # serializes load_all()

# ============================================================
# Models
# ============================================================


class SearchRequest(BaseModel):
    """Request body for /search-chunks and /search-tags."""

    query: str
    limit: int = 8  # maximum number of results; values <= 0 yield []
    doc_ids: Optional[List[str]] = None  # /search-chunks only: document filter

# ============================================================
# Loader
# ============================================================


def _safe_read_json(path: Path) -> Optional[Any]:
    """Parse the JSON file at *path*; return None if it is missing or unreadable.

    Failures are logged as warnings instead of raised. The result can be any
    JSON value (the id-map files hold lists), so callers must type-check it.
    """
    try:
        if not path.exists():
            return None
        return json.loads(path.read_text(encoding="utf-8"))
    except Exception as e:
        logger.warning("Failed to read json %s: %s", str(path), str(e))
        return None


def _read_chunk_doc_map() -> Dict[str, str]:
    """Build a ``chunk_id -> document_id`` dict from INDEX_NDJSON_PATH.

    Lines that are not JSON objects, or whose ``chunk_id``/``document_id`` are
    not strings, are skipped individually. If reading fails part-way, the
    entries parsed so far are returned and a warning is logged. A missing
    file yields an empty dict.
    """
    mapping: Dict[str, str] = {}
    if not INDEX_NDJSON_PATH.exists():
        return mapping
    try:
        with INDEX_NDJSON_PATH.open("r", encoding="utf-8") as f:
            for line in f:
                try:
                    row = json.loads(line)
                except ValueError:
                    continue
                # A valid JSON line that is not an object (list, number, ...)
                # must be skipped here; calling .get() on it would raise and
                # abort loading the rest of the file.
                if not isinstance(row, dict):
                    continue
                chunk_id = row.get("chunk_id")
                document_id = row.get("document_id")
                if isinstance(chunk_id, str) and isinstance(document_id, str):
                    mapping[chunk_id] = document_id
    except Exception as e:
        logger.warning("Failed to load chunk-doc map from ndjson: %s", str(e))
    return mapping


def load_chunk_doc_map() -> None:
    """Reload only ``chunk_doc_map`` from INDEX_NDJSON_PATH.

    The new map is fully built before it is bound, so concurrent searches
    never see an empty or half-filled map.
    """
    global chunk_doc_map
    chunk_doc_map = _read_chunk_doc_map()


def _read_runtime_stamp() -> Optional[str]:
    """Return ``last_rebuild_at`` from index_runtime.json if it is a string, else None."""
    runtime = _safe_read_json(INDEX_RUNTIME_PATH)
    if isinstance(runtime, dict):
        stamp = runtime.get("last_rebuild_at")
        if isinstance(stamp, str):
            return stamp
    return None


def _load_index_pair(index_path: Path, ids_path: Path, label: str):
    """Load a FAISS index plus its JSON id list; return ``(index, ids)``.

    Returns ``(None, None)`` (index disabled) when either file is missing or
    the id file is not a non-empty JSON list. Errors from
    ``faiss.read_index`` propagate to the caller.
    """
    if not (index_path.exists() and ids_path.exists()):
        return None, None

    logger.info("[Reload] Loading %s index", label)
    ids = _safe_read_json(ids_path)
    if not isinstance(ids, list) or not ids:
        logger.warning("[Reload] %s_ids meta invalid -> %s index disabled", label, label)
        return None, None

    index = faiss.read_index(str(index_path))
    if index.ntotal != len(ids):
        # Searches bounds-check against len(ids), so this is survivable, but
        # it usually means the index and its id map are from different builds.
        logger.warning(
            "[Reload] %s index has %d vectors but %d ids",
            label, index.ntotal, len(ids),
        )
    return index, ids


def load_all() -> None:
    """(Re)load the embedding model, FAISS indexes and metadata from KNOWLEDGE_DIR.

    Everything is first loaded into locals and only published to the module
    globals once every step has succeeded. A failed reload therefore leaves
    the previously loaded state untouched. Requests running concurrently
    never observe an index that was just reset to None, or a half-built
    chunk-doc map. The embedding model is only re-instantiated when
    ``embedding_model`` in index_meta.json changes.

    Raises:
        RuntimeError: if index_meta.json is missing/invalid or has no
            ``embedding_model``. Errors from SentenceTransformer or
            faiss.read_index propagate unchanged.
    """
    global model, chunk_index, chunk_ids, chunk_doc_map
    global tag_index, tag_ids
    global loaded_embedding_model_name
    global current_index_version
    global current_runtime_stamp

    with reload_lock:
        meta = _safe_read_json(INDEX_META_PATH)
        if not isinstance(meta, dict):
            raise RuntimeError("index_meta.json not found or invalid")

        embedding_model_name = meta.get("embedding_model")
        index_version = meta.get("index_version")
        if not embedding_model_name:
            raise RuntimeError("embedding_model missing in index_meta.json")

        if model is not None and embedding_model_name == loaded_embedding_model_name:
            new_model = model
        else:
            logger.info("[Reload] Loading embedding model: %s", embedding_model_name)
            new_model = SentenceTransformer(embedding_model_name)

        new_chunk_index, new_chunk_ids = _load_index_pair(
            CHUNK_INDEX_PATH, CHUNK_MAP_PATH, "chunk"
        )
        logger.info("[Reload] Loading chunk-doc map")
        new_chunk_doc_map = _read_chunk_doc_map()
        new_tag_index, new_tag_ids = _load_index_pair(
            TAG_INDEX_PATH, TAG_MAP_PATH, "tag"
        )
        new_runtime_stamp = _read_runtime_stamp()

        # Publish only after everything above succeeded.
        model = new_model
        loaded_embedding_model_name = embedding_model_name
        chunk_index, chunk_ids = new_chunk_index, new_chunk_ids
        chunk_doc_map = new_chunk_doc_map
        tag_index, tag_ids = new_tag_index, new_tag_ids
        current_runtime_stamp = new_runtime_stamp
        current_index_version = index_version if isinstance(index_version, int) else None

        logger.info("[Reload] Completed (index_version=%s runtime=%s)",
                    str(current_index_version), str(current_runtime_stamp))

# ============================================================
# Observer (Enterprise Auto Reload)
# ============================================================


def observer_loop(poll_interval: float = 2.0) -> None:
    """Poll index metadata forever and call load_all() when it changes.

    Every *poll_interval* seconds this compares ``index_version``
    (index_meta.json) and ``last_rebuild_at`` (index_runtime.json) with the
    currently loaded values. Errors are logged and the loop keeps running.
    A failed reload leaves the loaded values unchanged, so it is retried on
    the next poll.
    """
    while True:
        time.sleep(poll_interval)
        try:
            meta = _safe_read_json(INDEX_META_PATH)
            if not isinstance(meta, dict):
                continue

            raw_version = meta.get("index_version")
            new_version = raw_version if isinstance(raw_version, int) else None
            new_runtime = _read_runtime_stamp()

            if new_version != current_index_version:
                logger.info("[Observer] index_version changed (%s -> %s) -> Reload",
                            str(current_index_version), str(new_version))
                load_all()
            elif new_runtime != current_runtime_stamp:
                logger.info("[Observer] runtime changed (%s -> %s) -> Reload",
                            str(current_runtime_stamp), str(new_runtime))
                load_all()
        except Exception as e:
            logger.error("[Observer ERROR] %s", str(e))

# ============================================================
# Startup
# ============================================================


@app.on_event("startup")
def startup_event():
    """Configure logging, do the initial load, and start the observer thread.

    If the initial load_all() raises, startup fails and no observer is started.
    """
    setup_logging()
    logger.info("[VectorService] Startup")
    load_all()
    t = threading.Thread(target=observer_loop, name="vector-index-observer", daemon=True)
    t.start()
    logger.info("[VectorService] Ready (log=%s)", str(LOG_FILE))

# ============================================================
# Endpoints
# ============================================================


@app.get("/health")
def health():
    """Report which components are loaded plus the current index version/stamp."""
    return {
        "status": "ok",
        "chunk_index_loaded": chunk_index is not None,
        "tag_index_loaded": tag_index is not None,
        "model_loaded": model is not None,
        "index_version": current_index_version,
        "runtime_stamp": current_runtime_stamp,
        "log_file": str(LOG_FILE),
    }


@app.post("/reload")
def reload():
    """Force a synchronous load_all(); failures are logged and returned as HTTP 500."""
    try:
        load_all()
    except Exception as e:
        logger.exception("[Reload] Manual reload failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "reloaded"}


def _encode_query(encoder: SentenceTransformer, query: str) -> np.ndarray:
    """Embed *query* as a single-row float32 matrix for ``index.search``.

    Uses the ``"query: "`` prefix and L2-normalised embeddings, as both search
    endpoints always have. Presumably this matches how the indexes were built;
    TODO confirm against the indexer.
    """
    vec = encoder.encode([f"query: {query}"], normalize_embeddings=True)
    return np.ascontiguousarray(vec, dtype=np.float32)


@app.post("/search-chunks")
def search_chunks(req: SearchRequest):
    """Return up to ``req.limit`` nearest chunks as ``{"chunk_id", "score"}`` dicts.

    When ``req.doc_ids`` is non-empty, only chunks whose document (according
    to ``chunk_doc_map``) is listed are returned. The index is over-fetched
    (``max(limit * 5, 50)`` candidates) to leave room for the post-filter,
    but fewer than ``limit`` results may still come back. An empty or missing
    ``doc_ids`` means no filter. A non-positive ``limit`` returns ``[]``.

    Raises:
        HTTPException(503): if the model or chunk index is not loaded.
    """
    # Snapshot shared state once; load_all() may rebind these globals from
    # another thread while this request is running.
    index, ids, encoder, doc_map = chunk_index, chunk_ids, model, chunk_doc_map
    if index is None or ids is None or encoder is None:
        raise HTTPException(status_code=503, detail="Chunk index not available")
    if req.limit <= 0:
        return []

    allowed_docs = set(req.doc_ids) if req.doc_ids else None
    wanted = req.limit if allowed_docs is None else max(req.limit * 5, 50)
    # limit comes from the client and FAISS allocates k-sized result arrays;
    # asking for more than ntotal only adds -1 padding anyway.
    k = min(wanted, index.ntotal)
    if k <= 0:
        return []

    scores, indices = index.search(_encode_query(encoder, req.query), k)

    results = []
    for score, idx in zip(scores[0], indices[0]):
        # FAISS pads missing neighbours with -1; also skip rows the id map
        # does not cover.
        if not 0 <= idx < len(ids):
            continue
        chunk_id = ids[idx]
        if allowed_docs is not None and doc_map.get(chunk_id) not in allowed_docs:
            continue
        results.append({
            "chunk_id": chunk_id,
            "score": float(score),
        })
        if len(results) >= req.limit:
            break
    return results


@app.post("/search-tags")
def search_tags(req: SearchRequest):
    """Return up to ``req.limit`` nearest tags as ``{"tag_id", "score"}`` dicts.

    ``req.doc_ids`` is ignored. A non-positive ``limit`` returns ``[]``.

    Raises:
        HTTPException(503): if the model or tag index is not loaded.
    """
    # Snapshot shared state once; see search_chunks.
    index, ids, encoder = tag_index, tag_ids, model
    if index is None or ids is None or encoder is None:
        raise HTTPException(status_code=503, detail="Tag index not available")

    k = min(req.limit, index.ntotal)
    if k <= 0:
        return []

    scores, indices = index.search(_encode_query(encoder, req.query), k)

    return [
        {"tag_id": ids[idx], "score": float(score)}
        for score, idx in zip(scores[0], indices[0])
        # Skips FAISS -1 padding and rows the id map does not cover.
        if 0 <= idx < len(ids)
    ]