#!/usr/bin/env python3
"""Vector search service.

Serves FAISS similarity search over knowledge chunks and tags through a
FastAPI app, and reloads the embedding model and indices whenever the
on-disk index metadata (index_meta.json / index_runtime.json) changes.
"""
import json
import logging
import threading
import time
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import faiss
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# ============================================================
# Paths
# ============================================================

BASE_PATH = Path(__file__).resolve().parents[2]

KNOWLEDGE_DIR = BASE_PATH / "var" / "knowledge"
LOG_DIR = BASE_PATH / "var" / "log"
LOG_FILE = LOG_DIR / "vector_service.log"

CHUNK_INDEX_PATH = KNOWLEDGE_DIR / "vector.index"
CHUNK_MAP_PATH = KNOWLEDGE_DIR / "vector.index.meta.json"

TAG_INDEX_PATH = KNOWLEDGE_DIR / "vector_tags.index"
TAG_MAP_PATH = KNOWLEDGE_DIR / "vector_tags.index.meta.json"

INDEX_META_PATH = KNOWLEDGE_DIR / "index_meta.json"
INDEX_RUNTIME_PATH = KNOWLEDGE_DIR / "index_runtime.json"
INDEX_NDJSON_PATH = KNOWLEDGE_DIR / "index.ndjson"

# ============================================================
# Logging
# ============================================================

logger = logging.getLogger("vector_service")
logger.setLevel(logging.INFO)


def setup_logging() -> None:
    """Attach a rotating file handler and a console handler to ``logger``.

    Safe to call repeatedly (e.g. if uvicorn re-runs startup): each kind of
    handler is only added when the logger does not already have one.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)

    fmt = logging.Formatter(
        fmt="%(asctime)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S%z",
    )

    # Only construct handlers we will attach: a RotatingFileHandler opens
    # the log file as soon as it is created.
    if not any(isinstance(h, RotatingFileHandler) for h in logger.handlers):
        file_handler = RotatingFileHandler(
            str(LOG_FILE),
            maxBytes=10 * 1024 * 1024,  # 10MB
            backupCount=5,
            encoding="utf-8",
        )
        file_handler.setFormatter(fmt)
        file_handler.setLevel(logging.INFO)
        logger.addHandler(file_handler)

    # FileHandler (and therefore RotatingFileHandler) subclasses
    # StreamHandler, so a bare isinstance(h, StreamHandler) check would match
    # the file handler above and the console handler would never be added.
    has_console = any(
        isinstance(h, logging.StreamHandler)
        and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )
    if not has_console:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(fmt)
        stream_handler.setLevel(logging.INFO)
        logger.addHandler(stream_handler)


# ============================================================
# FastAPI
# ============================================================

app = FastAPI()

model: Optional[SentenceTransformer] = None

chunk_index = None
chunk_ids: Optional[List[Any]] = None
chunk_doc_map: Dict[str, str] = {}

tag_index = None
tag_ids: Optional[List[Any]] = None

loaded_embedding_model_name: Optional[str] = None
current_index_version: Optional[int] = None
current_runtime_stamp: Optional[str] = None

# Serialises complete reloads (observer thread vs. manual /reload).
reload_lock = threading.Lock()
# Held only while the globals above are published or snapshotted, so request
# handlers never see a mix of old and new model / index / id-map state.
_state_lock = threading.Lock()

# ============================================================
# Models
# ============================================================


class SearchRequest(BaseModel):
    query: str
    limit: int = 8
    doc_ids: Optional[List[str]] = None


# ============================================================
# Loader
# ============================================================


def _safe_read_json(path: Path) -> Optional[Any]:
    """Return the parsed JSON content of *path*, or None if missing/unreadable.

    The result can be any JSON value (dict for metadata, list for id maps).
    Read/parse failures are logged as warnings rather than raised.
    """
    try:
        if not path.exists():
            return None
        return json.loads(path.read_text(encoding="utf-8"))
    except Exception as e:
        logger.warning("Failed to read json %s: %s", str(path), str(e))
        return None


def _read_chunk_doc_map() -> Dict[str, str]:
    """Parse index.ndjson into a ``chunk_id -> document_id`` mapping.

    Best effort: lines that are not JSON objects or lack string ids are
    skipped; an I/O failure is logged and the rows parsed so far are kept.
    """
    mapping: Dict[str, str] = {}
    if not INDEX_NDJSON_PATH.exists():
        return mapping
    try:
        with INDEX_NDJSON_PATH.open("r", encoding="utf-8") as f:
            for line in f:
                try:
                    row = json.loads(line)
                except ValueError:
                    continue
                # A valid-JSON but non-object line (list, number, ...) must
                # be skipped, not abort the whole load via AttributeError.
                if not isinstance(row, dict):
                    continue
                chunk_id = row.get("chunk_id")
                document_id = row.get("document_id")
                if isinstance(chunk_id, str) and isinstance(document_id, str):
                    mapping[chunk_id] = document_id
    except Exception as e:
        logger.warning("Failed to load chunk-doc map from ndjson: %s", str(e))
    return mapping


def load_chunk_doc_map() -> None:
    """Rebuild the global ``chunk_doc_map`` from index.ndjson (best effort).

    The new mapping is built first and then swapped in with one assignment,
    so readers never observe a temporarily empty map.
    """
    global chunk_doc_map
    chunk_doc_map = _read_chunk_doc_map()


def _load_index_with_ids(
    index_path: Path, ids_path: Path, label: str
) -> Tuple[Any, Optional[List[Any]]]:
    """Load a FAISS index plus its position -> id list.

    Returns ``(None, None)`` when either file is missing or the id file is
    not a non-empty JSON list (the index is then treated as disabled).
    Exceptions from ``faiss.read_index`` propagate to the caller.
    """
    if not (index_path.exists() and ids_path.exists()):
        return None, None

    logger.info("[Reload] Loading %s index", label)
    index = faiss.read_index(str(index_path))
    ids = _safe_read_json(ids_path)
    if not isinstance(ids, list) or not ids:
        logger.warning(
            "[Reload] %s_ids meta invalid -> %s index disabled", label, label
        )
        return None, None

    if index.ntotal != len(ids):
        # Search skips positions without an id, but a mismatch usually means
        # the index and its id file come from different rebuilds.
        logger.warning(
            "[Reload] %s index has %d vectors but %d ids",
            label,
            int(index.ntotal),
            len(ids),
        )
    return index, ids


def _read_runtime_stamp() -> Optional[str]:
    """Return ``last_rebuild_at`` from index_runtime.json if it is a string."""
    runtime = _safe_read_json(INDEX_RUNTIME_PATH)
    if not isinstance(runtime, dict):
        return None
    stamp = runtime.get("last_rebuild_at")
    return stamp if isinstance(stamp, str) else None


def load_all() -> None:
    """Reload model, indices and maps, then publish them together.

    Everything is loaded into locals first and swapped into the module
    globals only after all loads succeed. A failure therefore leaves the
    previous consistent state serving requests, and concurrent searches
    never pair a new model with an old index (or a new index with old ids).
    Guarded by ``reload_lock`` so reloads never overlap.

    Raises:
        RuntimeError: index_meta.json is missing/invalid or has no
            ``embedding_model``.
        Exception: anything raised by SentenceTransformer or
            faiss.read_index while loading.
    """
    global model, chunk_index, chunk_ids, chunk_doc_map
    global tag_index, tag_ids
    global loaded_embedding_model_name
    global current_index_version
    global current_runtime_stamp

    with reload_lock:
        meta = _safe_read_json(INDEX_META_PATH)
        if not isinstance(meta, dict):
            raise RuntimeError("index_meta.json not found or invalid")

        embedding_model_name = meta.get("embedding_model")
        index_version = meta.get("index_version")

        if not embedding_model_name:
            raise RuntimeError("embedding_model missing in index_meta.json")

        # Reuse the current model unless the configured one changed.
        new_model = model
        if new_model is None or embedding_model_name != loaded_embedding_model_name:
            logger.info("[Reload] Loading embedding model: %s", embedding_model_name)
            new_model = SentenceTransformer(embedding_model_name)

        new_chunk_index, new_chunk_ids = _load_index_with_ids(
            CHUNK_INDEX_PATH, CHUNK_MAP_PATH, "chunk"
        )

        logger.info("[Reload] Loading chunk-doc map")
        new_chunk_doc_map = _read_chunk_doc_map()

        new_tag_index, new_tag_ids = _load_index_with_ids(
            TAG_INDEX_PATH, TAG_MAP_PATH, "tag"
        )

        # Runtime stamp (commit marker for tags+chunks)
        new_runtime_stamp = _read_runtime_stamp()
        new_index_version = index_version if isinstance(index_version, int) else None

        # Publish everything at once.
        with _state_lock:
            model = new_model
            loaded_embedding_model_name = embedding_model_name
            chunk_index, chunk_ids = new_chunk_index, new_chunk_ids
            chunk_doc_map = new_chunk_doc_map
            tag_index, tag_ids = new_tag_index, new_tag_ids
            current_runtime_stamp = new_runtime_stamp
            current_index_version = new_index_version

        logger.info(
            "[Reload] Completed (index_version=%s runtime=%s)",
            str(current_index_version),
            str(current_runtime_stamp),
        )


# ============================================================
# Observer (Enterprise Auto Reload)
# ============================================================


def observer_loop() -> None:
    """Poll the index metadata every 2 seconds and reload when it changes.

    Runs forever. A change in ``index_version`` or in the runtime stamp
    triggers ``load_all()``. Errors are logged and polling continues, so a
    failed reload is retried on the next tick.
    """
    while True:
        time.sleep(2)
        try:
            meta = _safe_read_json(INDEX_META_PATH)
            if not isinstance(meta, dict):
                continue

            raw_version = meta.get("index_version")
            new_version = raw_version if isinstance(raw_version, int) else None
            new_runtime = _read_runtime_stamp()

            # Structure change (embedding, dim, scoring_version, etc.) -> reload
            if new_version != current_index_version:
                logger.info(
                    "[Observer] index_version changed (%s -> %s) -> Reload",
                    str(current_index_version),
                    str(new_version),
                )
                load_all()
                continue

            # Content change (chunks OR tags) -> reload
            if new_runtime != current_runtime_stamp:
                logger.info(
                    "[Observer] runtime changed (%s -> %s) -> Reload",
                    str(current_runtime_stamp),
                    str(new_runtime),
                )
                load_all()
        except Exception as e:
            logger.exception("[Observer ERROR] %s", str(e))


# ============================================================
# Startup
# ============================================================


# NOTE: on_event is deprecated in recent FastAPI in favour of lifespan
# handlers; kept here for compatibility with the existing deployment.
@app.on_event("startup")
def startup_event():
    """Configure logging, load everything once, and start the observer thread.

    An exception from the initial ``load_all()`` propagates out of this
    startup hook.
    """
    setup_logging()
    logger.info("[VectorService] Startup")
    load_all()

    t = threading.Thread(target=observer_loop, daemon=True)
    t.start()

    logger.info("[VectorService] Ready (log=%s)", str(LOG_FILE))


# ============================================================
# Endpoints
# ============================================================


def _encode_query(encoder: SentenceTransformer, query: str) -> np.ndarray:
    """Embed *query* as a normalized, C-contiguous float32 batch of one."""
    vec = encoder.encode([query], normalize_embeddings=True)
    # FAISS requires contiguous float32 input.
    return np.ascontiguousarray(np.asarray(vec, dtype="float32"))


@app.get("/health")
def health():
    """Report which components are loaded plus the current index version."""
    with _state_lock:
        snapshot = {
            "status": "ok",
            "chunk_index_loaded": chunk_index is not None,
            "tag_index_loaded": tag_index is not None,
            "model_loaded": model is not None,
            "index_version": current_index_version,
            "runtime_stamp": current_runtime_stamp,
        }
    snapshot["log_file"] = str(LOG_FILE)
    return snapshot


@app.post("/reload")
def reload():
    """
    Manual reload endpoint (kept for compatibility with mto:agent:vector:control --reload).
    Auto-reload still runs via observer_loop.
    """
    try:
        load_all()
        return {"status": "reloaded"}
    except Exception as e:
        logger.exception("[Reload] Manual reload failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/search-chunks")
def search_chunks(req: SearchRequest):
    """Return up to ``req.limit`` chunk hits as ``{"chunk_id", "score"}``.

    When ``req.doc_ids`` is given, only chunks whose document (per
    index.ndjson) is in that list are returned. A non-positive limit yields
    an empty list. Responds 503 if the model or chunk index is not loaded.
    """
    # Snapshot state so a concurrent reload cannot swap it mid-request.
    with _state_lock:
        index, ids, doc_map, encoder = chunk_index, chunk_ids, chunk_doc_map, model
    if index is None or ids is None or encoder is None:
        raise HTTPException(status_code=503, detail="Chunk index not available")
    if req.limit <= 0:
        return []

    query_vec = _encode_query(encoder, req.query)

    # Over-fetch when filtering by document so enough hits survive the
    # filter. Matches beyond this window are not considered.
    effective_limit = max(req.limit * 5, 50) if req.doc_ids else req.limit
    # FAISS needs k > 0 and allocates k result slots per query, so never ask
    # for more results than the index holds.
    k = min(effective_limit, int(index.ntotal))
    if k <= 0:
        return []

    scores, indices = index.search(query_vec, k)

    allowed_docs = set(req.doc_ids) if req.doc_ids else None
    results = []
    for score, idx in zip(scores[0], indices[0]):
        # -1 marks "no result"; also skip positions that have no id.
        if idx < 0 or idx >= len(ids):
            continue
        chunk_id = ids[idx]
        if allowed_docs is not None and doc_map.get(chunk_id) not in allowed_docs:
            continue
        results.append({
            "chunk_id": chunk_id,
            "score": float(score),
        })
        if len(results) >= req.limit:
            break

    return results


@app.post("/search-tags")
def search_tags(req: SearchRequest):
    """Return up to ``req.limit`` tag hits as ``{"tag_id", "score"}``.

    ``req.doc_ids`` is ignored. A non-positive limit yields an empty list.
    Responds 503 if the model or tag index is not loaded.
    """
    # Snapshot state so a concurrent reload cannot swap it mid-request.
    with _state_lock:
        index, ids, encoder = tag_index, tag_ids, model
    if index is None or ids is None or encoder is None:
        raise HTTPException(status_code=503, detail="Tag index not available")
    if req.limit <= 0:
        return []

    query_vec = _encode_query(encoder, req.query)

    k = min(req.limit, int(index.ntotal))
    if k <= 0:
        return []

    scores, indices = index.search(query_vec, k)

    results = []
    for score, idx in zip(scores[0], indices[0]):
        # -1 marks "no result"; also skip positions that have no id.
        if idx < 0 or idx >= len(ids):
            continue
        results.append({
            "tag_id": ids[idx],
            "score": float(score),
        })

    return results