Files
MtoRagSystem/python/vector/vector_service.py
2026-02-26 18:36:57 +01:00

354 lines
10 KiB
Python

#!/usr/bin/env python3
import json
import logging
from logging.handlers import RotatingFileHandler
import threading
import time
from pathlib import Path
from typing import Any, List, Optional, Dict
import numpy as np
import faiss
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# ============================================================
# Paths
# ============================================================
# Base directory: two levels above the directory that contains this file.
BASE_PATH = Path(__file__).resolve().parents[2]
# Directory holding the index artifacts and their metadata files.
KNOWLEDGE_DIR = BASE_PATH / "var" / "knowledge"
# Directory and file used by the rotating service log (see setup_logging).
LOG_DIR = BASE_PATH / "var" / "log"
LOG_FILE = LOG_DIR / "vector_service.log"
# FAISS chunk index and the JSON list that maps a FAISS row number to a chunk id.
CHUNK_INDEX_PATH = KNOWLEDGE_DIR / "vector.index"
CHUNK_MAP_PATH = KNOWLEDGE_DIR / "vector.index.meta.json"
# FAISS tag index and the JSON list that maps a FAISS row number to a tag id.
TAG_INDEX_PATH = KNOWLEDGE_DIR / "vector_tags.index"
TAG_MAP_PATH = KNOWLEDGE_DIR / "vector_tags.index.meta.json"
# Metadata file; "embedding_model" and "index_version" are read from it.
INDEX_META_PATH = KNOWLEDGE_DIR / "index_meta.json"
# Runtime file; "last_rebuild_at" is read from it to detect rebuilds.
INDEX_RUNTIME_PATH = KNOWLEDGE_DIR / "index_runtime.json"
# Newline-delimited JSON; "chunk_id" and "document_id" are read from each row.
INDEX_NDJSON_PATH = KNOWLEDGE_DIR / "index.ndjson"
# ============================================================
# Logging
# ============================================================
# Service-wide logger. Handlers are attached later by setup_logging(), which
# is called from the startup hook rather than at import time.
logger = logging.getLogger("vector_service")
logger.setLevel(logging.INFO)
def setup_logging() -> None:
    """Attach a rotating file handler and a console handler to ``logger``.

    The log directory is created if needed. The function is idempotent:
    calling it again does not add a second handler of either kind.

    The file handler rotates at 10 MiB and keeps 5 backups. Both handlers
    emit at INFO level with the same timestamped format.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)

    fmt = logging.Formatter(
        fmt="%(asctime)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S%z",
    )

    # Build the file handler only when it will actually be attached:
    # constructing a RotatingFileHandler opens the log file, so creating one
    # just to throw it away would leak an open file on every repeated call.
    if not any(isinstance(h, RotatingFileHandler) for h in logger.handlers):
        file_handler = RotatingFileHandler(
            str(LOG_FILE),
            maxBytes=10 * 1024 * 1024,
            backupCount=5,
            encoding="utf-8",
        )
        file_handler.setFormatter(fmt)
        file_handler.setLevel(logging.INFO)
        logger.addHandler(file_handler)

    # RotatingFileHandler is itself a StreamHandler subclass, so an isinstance
    # check would always match the file handler attached above and the console
    # handler would never be added. Compare the exact type instead.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(fmt)
        stream_handler.setLevel(logging.INFO)
        logger.addHandler(stream_handler)
# ============================================================
# FastAPI
# ============================================================
# ASGI application; the startup hook and the endpoints below register on it.
app = FastAPI()
# Embedding model; None until load_all() has loaded one.
model: Optional[SentenceTransformer] = None
# FAISS chunk index and its parallel id list (FAISS row number -> chunk id).
# load_all() sets both to None when the files are missing or the id list is invalid.
chunk_index = None
chunk_ids: Optional[List[Any]] = None
# chunk id -> document id, filled by load_chunk_doc_map(); used for doc_ids filtering.
chunk_doc_map: Dict[str, str] = {}
# FAISS tag index and its parallel id list (FAISS row number -> tag id).
tag_index = None
tag_ids: Optional[List[Any]] = None
# Name of the embedding model currently held in `model`; load_all() compares it
# with index_meta.json to decide whether the model must be reloaded.
loaded_embedding_model_name: Optional[str] = None
# Values recorded by the last successful load_all(); observer_loop() compares
# them with the files on disk to detect changes.
current_index_version: Optional[int] = None
current_runtime_stamp: Optional[str] = None
# Held by load_all() for the duration of a reload.
reload_lock = threading.Lock()
# ============================================================
# Models
# ============================================================
class SearchRequest(BaseModel):
    """Request body accepted by the /search-chunks and /search-tags endpoints."""

    # Free-text query; the endpoints embed it with a "query: " prefix.
    query: str
    # Maximum number of results to return.
    limit: int = 8
    # Optional document-id filter. Only /search-chunks reads it;
    # /search-tags does not use it.
    doc_ids: Optional[List[str]] = None
# ============================================================
# Loader
# ============================================================
def _safe_read_json(path: Path) -> Optional[dict]:
    """Return the parsed JSON content of ``path``, or ``None`` on any problem.

    A missing file yields ``None`` silently. Any error while reading or
    parsing is logged as a warning and also yields ``None``; this function
    does not raise. The parsed value is returned as-is, so callers must
    check its type themselves.
    """
    try:
        if path.exists():
            with path.open("r", encoding="utf-8") as handle:
                return json.load(handle)
    except Exception as e:
        logger.warning("Failed to read json %s: %s", str(path), str(e))
    return None
def load_chunk_doc_map() -> None:
    """Rebuild the global ``chunk_doc_map`` (chunk id -> document id).

    Reads ``INDEX_NDJSON_PATH`` line by line. A line is used only when it
    parses to a JSON object whose ``chunk_id`` and ``document_id`` are both
    strings; every other line is skipped. A missing file yields an empty map.

    The map is built in a local dict and assigned to the global in one step,
    so readers never observe an empty or half-filled map during a reload.
    If reading fails partway, a warning is logged and the rows parsed so far
    are kept. This function does not raise.
    """
    global chunk_doc_map

    mapping = {}
    if INDEX_NDJSON_PATH.exists():
        try:
            with INDEX_NDJSON_PATH.open("r", encoding="utf-8") as f:
                for line in f:
                    try:
                        row = json.loads(line)
                    except (ValueError, RecursionError):
                        # Blank or malformed line: skip it, keep going.
                        continue
                    # A line can be valid JSON without being an object
                    # (list, number, string). Calling .get() on it would
                    # raise and abort the rest of the file, so skip it.
                    if not isinstance(row, dict):
                        continue
                    chunk_id = row.get("chunk_id")
                    document_id = row.get("document_id")
                    if isinstance(chunk_id, str) and isinstance(document_id, str):
                        mapping[chunk_id] = document_id
        except Exception as e:
            # Deliberately best-effort: a read failure must not break a reload.
            logger.warning("Failed to load chunk-doc map from ndjson: %s", str(e))

    chunk_doc_map = mapping
def _load_index_pair(index_path: Path, map_path: Path, label: str) -> tuple:
    """Read one FAISS index plus its id list; return ``(None, None)`` if unusable.

    ``label`` ("chunk" or "tag") is only used in log messages. The pair is
    unusable when either file is missing or the id file does not contain a
    non-empty JSON list. Errors from ``faiss.read_index`` propagate.
    """
    if not (index_path.exists() and map_path.exists()):
        return None, None

    logger.info("[Reload] Loading %s index", label)
    index = faiss.read_index(str(index_path))
    ids = _safe_read_json(map_path) or None
    if not isinstance(ids, list):
        logger.warning("[Reload] %s_ids meta invalid -> %s index disabled", label, label)
        return None, None
    return index, ids


def load_all() -> None:
    """Load (or reload) the embedding model, both FAISS indexes and metadata.

    Runs under ``reload_lock``. Every step that can fail is executed into
    local variables first; the module-level state is only replaced once all
    of them have succeeded. A failure partway through (for example an index
    file that cannot be read) therefore leaves the previously loaded state
    fully intact instead of a mix of old and new objects.

    The embedding model is reloaded only when none is loaded yet or when the
    model name in index_meta.json differs from the loaded one.

    Raises:
        RuntimeError: index_meta.json is missing/invalid or names no
            embedding model.
        Exception: whatever the model constructor or ``faiss.read_index``
            raises.
    """
    global model, chunk_index, chunk_ids
    global tag_index, tag_ids
    global loaded_embedding_model_name
    global current_index_version
    global current_runtime_stamp

    with reload_lock:
        meta = _safe_read_json(INDEX_META_PATH)
        if not isinstance(meta, dict):
            raise RuntimeError("index_meta.json not found or invalid")

        embedding_model_name = meta.get("embedding_model")
        index_version = meta.get("index_version")
        if not embedding_model_name:
            raise RuntimeError("embedding_model missing in index_meta.json")

        # --- Phase 1: load everything into locals (may raise) -------------
        new_model = model
        new_model_name = loaded_embedding_model_name
        if model is None or embedding_model_name != loaded_embedding_model_name:
            logger.info("[Reload] Loading embedding model: %s", embedding_model_name)
            new_model = SentenceTransformer(embedding_model_name)
            new_model_name = embedding_model_name

        new_chunk_index, new_chunk_ids = _load_index_pair(
            CHUNK_INDEX_PATH, CHUNK_MAP_PATH, "chunk"
        )
        new_tag_index, new_tag_ids = _load_index_pair(
            TAG_INDEX_PATH, TAG_MAP_PATH, "tag"
        )

        runtime = _safe_read_json(INDEX_RUNTIME_PATH)
        stamp = runtime.get("last_rebuild_at") if isinstance(runtime, dict) else None
        new_runtime_stamp = stamp if isinstance(stamp, str) else None

        # --- Phase 2: publish (nothing below is expected to raise) --------
        # The chunk-doc map is refreshed here, after both indexes were read,
        # so it is not replaced when an index load fails.
        logger.info("[Reload] Loading chunk-doc map")
        load_chunk_doc_map()

        model = new_model
        loaded_embedding_model_name = new_model_name
        chunk_ids = new_chunk_ids
        chunk_index = new_chunk_index
        tag_ids = new_tag_ids
        tag_index = new_tag_index
        current_runtime_stamp = new_runtime_stamp
        current_index_version = index_version if isinstance(index_version, int) else None

        logger.info(
            "[Reload] Completed (index_version=%s runtime=%s)",
            str(current_index_version),
            str(current_runtime_stamp),
        )
# ============================================================
# Observer (Enterprise Auto Reload)
# ============================================================
def observer_loop() -> None:
    """Poll the index metadata every two seconds and reload on change.

    Runs forever; intended as the target of a background thread. On each
    pass the on-disk ``index_version`` and ``last_rebuild_at`` values are
    compared with the values recorded by the last load. A version change
    takes precedence over a runtime-stamp change, and either one triggers
    ``load_all()``. A pass is skipped when index_meta.json cannot be read
    as a JSON object. Any exception is logged and the loop continues.
    """
    while True:
        time.sleep(2)
        try:
            meta = _safe_read_json(INDEX_META_PATH)
            if not isinstance(meta, dict):
                continue

            raw_version = meta.get("index_version")
            new_version = raw_version if isinstance(raw_version, int) else None

            runtime = _safe_read_json(INDEX_RUNTIME_PATH)
            raw_stamp = runtime.get("last_rebuild_at") if isinstance(runtime, dict) else None
            new_runtime = raw_stamp if isinstance(raw_stamp, str) else None

            if new_version != current_index_version:
                logger.info("[Observer] index_version changed (%s -> %s) -> Reload", str(current_index_version), str(new_version))
                load_all()
            elif new_runtime != current_runtime_stamp:
                logger.info("[Observer] runtime changed (%s -> %s) -> Reload", str(current_runtime_stamp), str(new_runtime))
                load_all()
        except Exception as e:
            logger.error("[Observer ERROR] %s", str(e))
# ============================================================
# Startup
# ============================================================
@app.on_event("startup")
def startup_event():
    """Startup hook: configure logging, load all state, start the observer.

    ``load_all()`` is called synchronously, so an exception from it
    propagates out of the startup hook.
    """
    setup_logging()
    logger.info("[VectorService] Startup")

    load_all()

    # Daemon thread, so the polling loop does not keep the process alive.
    threading.Thread(target=observer_loop, daemon=True).start()

    logger.info("[VectorService] Ready (log=%s)", str(LOG_FILE))
# ============================================================
# Endpoints
# ============================================================
@app.get("/health")
def health():
    """Report which components are loaded and which index state is active.

    ``status`` is always ``"ok"``; the ``*_loaded`` flags say whether the
    corresponding object is currently set.
    """
    report = {"status": "ok"}
    report["chunk_index_loaded"] = chunk_index is not None
    report["tag_index_loaded"] = tag_index is not None
    report["model_loaded"] = model is not None
    report["index_version"] = current_index_version
    report["runtime_stamp"] = current_runtime_stamp
    report["log_file"] = str(LOG_FILE)
    return report
@app.post("/reload")
def reload():
    """Trigger a reload of model, indexes and metadata on demand.

    Returns:
        ``{"status": "reloaded"}`` on success.

    Raises:
        HTTPException: status 500 with the error text when ``load_all()``
            fails. The failure is also written to the service log with its
            traceback, and the original exception is chained as the cause.
    """
    try:
        load_all()
    except Exception as e:
        # Without this the only trace of a failed manual reload would be the
        # HTTP response body; keep a record in the service log as well.
        logger.exception("[Reload] Manual reload failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "reloaded"}
@app.post("/search-chunks")
def search_chunks(req: SearchRequest):
    """Return up to ``req.limit`` chunk ids most similar to ``req.query``.

    The query is embedded with a ``"query: "`` prefix and normalized
    embeddings, then searched in the chunk index. When ``req.doc_ids`` is
    given, only chunks whose document id (per ``chunk_doc_map``) is in that
    list are returned. To leave room for filtering, the index is then asked
    for ``max(limit * 5, 50)`` candidates, so fewer than ``limit`` results
    may come back.

    Returns:
        A list of ``{"chunk_id": ..., "score": float}`` dicts in the order
        the index returned them. Empty when ``req.limit`` is not positive.

    Raises:
        HTTPException: status 503 when the chunk index, its id list or the
            model is not loaded.
    """
    # Snapshot the shared state once. A reload running in another thread
    # may rebind these globals (possibly to None) at any moment; working on
    # locals keeps this request on the objects it validated.
    index, ids, encoder, doc_map = chunk_index, chunk_ids, model, chunk_doc_map
    if index is None or ids is None or encoder is None:
        raise HTTPException(status_code=503, detail="Chunk index not available")

    # A non-positive k is rejected by the index search and would surface as
    # an internal error; asking for no results simply yields none.
    if req.limit <= 0:
        return []

    query_vec = encoder.encode(
        [f"query: {req.query}"],
        normalize_embeddings=True
    )
    query_vec = np.array(query_vec).astype("float32")

    # Build the filter set once instead of scanning the list for every hit.
    allowed_docs = set(req.doc_ids) if req.doc_ids else None
    effective_limit = max(req.limit * 5, 50) if allowed_docs else req.limit

    scores, indices = index.search(query_vec, effective_limit)

    results = []
    for score, idx in zip(scores[0], indices[0]):
        # Skip the -1 "no result" marker and any row number that has no
        # entry in the id list.
        if idx < 0 or idx >= len(ids):
            continue
        chunk_id = ids[idx]
        if allowed_docs is not None and doc_map.get(chunk_id) not in allowed_docs:
            continue
        results.append({
            "chunk_id": chunk_id,
            "score": float(score),
        })
        if len(results) >= req.limit:
            break
    return results
@app.post("/search-tags")
def search_tags(req: SearchRequest):
    """Return up to ``req.limit`` tag ids most similar to ``req.query``.

    The query is embedded with a ``"query: "`` prefix and normalized
    embeddings, then searched in the tag index. ``req.doc_ids`` is not used
    by this endpoint.

    Returns:
        A list of ``{"tag_id": ..., "score": float}`` dicts in the order the
        index returned them. Empty when ``req.limit`` is not positive.

    Raises:
        HTTPException: status 503 when the tag index, its id list or the
            model is not loaded.
    """
    # Snapshot the shared state once. A reload running in another thread
    # may rebind these globals (possibly to None) at any moment; working on
    # locals keeps this request on the objects it validated.
    index, ids, encoder = tag_index, tag_ids, model
    if index is None or ids is None or encoder is None:
        raise HTTPException(status_code=503, detail="Tag index not available")

    # A non-positive k is rejected by the index search and would surface as
    # an internal error; asking for no results simply yields none.
    if req.limit <= 0:
        return []

    query_vec = encoder.encode(
        [f"query: {req.query}"],
        normalize_embeddings=True
    )
    query_vec = np.array(query_vec).astype("float32")

    scores, indices = index.search(query_vec, req.limit)

    results = []
    for score, idx in zip(scores[0], indices[0]):
        # Skip the -1 "no result" marker and any row number that has no
        # entry in the id list.
        if idx < 0 or idx >= len(ids):
            continue
        results.append({
            "tag_id": ids[idx],
            "score": float(score),
        })
    return results