224 lines
6.1 KiB
Python
224 lines
6.1 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Any, List, Optional, Dict
|
|
|
|
import numpy as np
|
|
import faiss
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
# ============================================================
|
|
# Paths
|
|
# ============================================================
|
|
|
|
# Base directory: the grandparent of the directory that contains this file
# (``parents[0]`` is the containing directory itself).
BASE_PATH = Path(__file__).resolve().parents[2]

# Every artefact read by this service lives under <BASE_PATH>/var/knowledge.
KNOWLEDGE_DIR = BASE_PATH.joinpath("var", "knowledge")

# Chunk search: FAISS index file plus the JSON id list loaded next to it.
CHUNK_INDEX_PATH = KNOWLEDGE_DIR.joinpath("vector.index")
CHUNK_MAP_PATH = KNOWLEDGE_DIR.joinpath("vector.index.meta.json")

# Tag search: FAISS index file plus the JSON id list loaded next to it.
TAG_INDEX_PATH = KNOWLEDGE_DIR.joinpath("vector_tags.index")
TAG_MAP_PATH = KNOWLEDGE_DIR.joinpath("vector_tags.index.meta.json")

# Index metadata (read for its "embedding_model" key) and the NDJSON file
# from which the chunk_id -> document_id map is built.
INDEX_META_PATH = KNOWLEDGE_DIR.joinpath("index_meta.json")
INDEX_NDJSON_PATH = KNOWLEDGE_DIR.joinpath("index.ndjson")
|
|
|
|
|
|
# ============================================================
|
|
# FastAPI
|
|
# ============================================================
|
|
|
|
app = FastAPI()

# ---- Module-level state, (re)bound by load_all() ----

# Embedding model and the name it was constructed from; the name is compared
# against index_meta.json so the model is only rebuilt when it changes.
model: Optional[SentenceTransformer] = None
loaded_embedding_model_name: Optional[str] = None

# Chunk search: FAISS index, the id list addressed by FAISS result position,
# and the chunk_id -> document_id lookup used for doc-scoped filtering.
chunk_index = None
chunk_ids: Optional[List[Any]] = None
chunk_doc_map: Dict[str, str] = {}

# Tag search: FAISS index and the id list addressed by FAISS result position.
tag_index = None
tag_ids: Optional[List[Any]] = None
|
|
|
|
|
|
# ============================================================
|
|
# Models
|
|
# ============================================================
|
|
|
|
class SearchRequest(BaseModel):
    """Request body accepted by the ``/search-chunks`` and ``/search-tags`` endpoints."""

    # Free-text query; it is embedded with the loaded model before searching.
    query: str

    # Maximum number of hits to return.
    limit: int = 8

    # Optional allow-list of document ids. Only /search-chunks reads it, to
    # keep hits whose chunk maps to one of these documents.
    doc_ids: Optional[List[str]] = None
|
|
|
|
|
|
# ============================================================
|
|
# Loader
|
|
# ============================================================
|
|
|
|
def load_chunk_doc_map():
    """Rebuild the global ``chunk_doc_map`` (chunk_id -> document_id) from ``index.ndjson``.

    The file is read line by line as UTF-8; each line is expected to hold one
    JSON object with string ``chunk_id`` and ``document_id`` keys. Loading is
    best-effort: lines that are not valid JSON, are not JSON objects, or lack
    string values for either key are skipped rather than aborting the load.

    If the file does not exist the map is reset to an empty dict.

    The new map is assembled in a local dict and bound to the global in a
    single assignment, so readers never observe a half-filled map.
    """
    global chunk_doc_map

    fresh_map = {}

    if INDEX_NDJSON_PATH.exists():
        with INDEX_NDJSON_PATH.open("r", encoding="utf-8") as f:
            for line in f:
                try:
                    row = json.loads(line)
                except (ValueError, RecursionError):
                    # Blank or malformed line (JSONDecodeError is a ValueError).
                    continue

                # Valid JSON that is not an object (list, string, number, ...)
                # has no .get(); skip it instead of crashing the whole load.
                if not isinstance(row, dict):
                    continue

                chunk_id = row.get("chunk_id")
                document_id = row.get("document_id")

                if isinstance(chunk_id, str) and isinstance(document_id, str):
                    fresh_map[chunk_id] = document_id

    chunk_doc_map = fresh_map
|
|
|
|
|
|
def load_all():
    """(Re)load the embedding model, both FAISS indexes and the chunk->doc map.

    Steps:
      1. Read ``index_meta.json`` and take its ``embedding_model`` value.
      2. Build a new ``SentenceTransformer`` only if no model is loaded yet or
         the configured name differs from the one currently loaded.
      3. Load the chunk index and the tag index. Each one is loaded only when
         both its index file and its JSON id list exist; otherwise that index
         and its id list are set to ``None``.
      4. Rebuild the chunk_id -> document_id map via ``load_chunk_doc_map()``.

    Everything is staged in local variables first and the module globals are
    rebound together at the very end, so if any load step raises, the
    previously loaded model, indexes and id lists all stay in place (an index
    is never paired with an id list from a different load).

    Raises:
        RuntimeError: if ``index_meta.json`` does not exist, or does not
            contain a non-empty ``embedding_model`` entry.
        Exception: anything raised by JSON parsing, FAISS or model loading
            propagates unchanged.
    """
    global model, chunk_index, chunk_ids, tag_index, tag_ids, loaded_embedding_model_name

    if not INDEX_META_PATH.exists():
        raise RuntimeError("index_meta.json not found")

    # Explicit UTF-8 so parsing does not depend on the process locale.
    meta = json.loads(INDEX_META_PATH.read_text(encoding="utf-8"))
    # Metadata that is valid JSON but not an object is treated as "key missing".
    embedding_model_name = meta.get("embedding_model") if isinstance(meta, dict) else None

    if not embedding_model_name:
        raise RuntimeError("embedding_model missing in index_meta.json")

    # Reload model only if changed
    staged_model = model
    if staged_model is None or embedding_model_name != loaded_embedding_model_name:
        print(f"[Reload] Loading embedding model: {embedding_model_name}")
        staged_model = SentenceTransformer(embedding_model_name)

    # Stage chunk index + id list
    staged_chunk_index = None
    staged_chunk_ids = None
    if CHUNK_INDEX_PATH.exists() and CHUNK_MAP_PATH.exists():
        print("[Reload] Loading chunk index")
        staged_chunk_index = faiss.read_index(str(CHUNK_INDEX_PATH))
        staged_chunk_ids = json.loads(CHUNK_MAP_PATH.read_text(encoding="utf-8"))

    # Stage tag index + id list
    staged_tag_index = None
    staged_tag_ids = None
    if TAG_INDEX_PATH.exists() and TAG_MAP_PATH.exists():
        print("[Reload] Loading tag index")
        staged_tag_index = faiss.read_index(str(TAG_INDEX_PATH))
        staged_tag_ids = json.loads(TAG_MAP_PATH.read_text(encoding="utf-8"))

    # Load chunk -> document map. This is the last step that can raise; it
    # rebinds its own global, so it runs right before the commit below.
    print("[Reload] Loading chunk-doc map")
    load_chunk_doc_map()

    # Commit: plain rebinding only, nothing below can fail.
    model = staged_model
    loaded_embedding_model_name = embedding_model_name
    chunk_index = staged_chunk_index
    chunk_ids = staged_chunk_ids
    tag_index = staged_tag_index
    tag_ids = staged_tag_ids

    print("[Reload] Completed")
|
|
|
|
|
|
# ============================================================
|
|
# Startup
|
|
# ============================================================
|
|
|
|
@app.on_event("startup")
def startup_event():
    """Startup hook: run the full load once, then log that the service is ready.

    Any exception raised by ``load_all()`` propagates out of the hook.
    """
    load_all()
    print("[VectorService] Ready")
|
|
|
|
|
|
# ============================================================
|
|
# Endpoints
|
|
# ============================================================
|
|
|
|
@app.get("/health")
def health():
    """Report service status plus which in-memory resources are currently loaded.

    Returns a dict with a constant ``"status": "ok"`` and one boolean per
    resource that is true when the corresponding module global is not ``None``.
    """
    resources = (
        ("chunk_index_loaded", chunk_index),
        ("tag_index_loaded", tag_index),
        ("model_loaded", model),
    )

    payload = {"status": "ok"}
    for key, resource in resources:
        payload[key] = resource is not None
    return payload
|
|
|
|
|
|
@app.post("/reload")
def reload():
    """Re-run ``load_all()`` on demand.

    Returns ``{"status": "reloaded"}`` on success. Any exception raised while
    loading is converted into an HTTP 500 whose detail is the exception text.
    """
    try:
        load_all()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    return {"status": "reloaded"}
|
|
|
|
|
|
@app.post("/search-chunks")
def search_chunks(req: SearchRequest):
    """Return up to ``req.limit`` chunk hits for ``req.query``, best match first.

    The query is embedded with the loaded model (normalized embeddings, cast
    to float32) and searched in the chunk FAISS index. When ``req.doc_ids`` is
    given, only chunks whose document id (looked up in ``chunk_doc_map``) is in
    that list are kept; chunks with no known document id are dropped.

    Returns:
        A list of ``{"chunk_id": ..., "score": float}`` dicts in the order
        FAISS returned them.

    Raises:
        HTTPException 503: chunk index, its id list, or the model is not loaded.
        HTTPException 422: ``req.limit`` is smaller than 1.
    """
    # Take one snapshot of the shared globals so the None-checks and the uses
    # below refer to the same objects even if load_all() rebinds them meanwhile.
    encoder, index, ids, doc_map = model, chunk_index, chunk_ids, chunk_doc_map

    if index is None or ids is None:
        raise HTTPException(status_code=503, detail="Chunk index not available")
    if encoder is None:
        raise HTTPException(status_code=503, detail="Embedding model not available")
    # A non-positive k is not a meaningful search; reject it up front instead
    # of passing it to FAISS (or returning one stray hit in the doc-scoped path).
    if req.limit < 1:
        raise HTTPException(status_code=422, detail="limit must be >= 1")

    query_vec = encoder.encode([req.query], normalize_embeddings=True)
    query_vec = np.array(query_vec).astype("float32")

    # Build the allow-list once as a set for O(1) membership tests.
    allowed_docs = set(req.doc_ids) if req.doc_ids else None

    # When doc_ids are set, over-fetch and filter afterwards. NOTE: this is a
    # fixed over-fetch, so fewer than `limit` hits may come back even though
    # more matching chunks exist further down the ranking.
    effective_limit = max(req.limit * 5, 50) if allowed_docs else req.limit

    scores, indices = index.search(query_vec, effective_limit)

    results = []
    for score, idx in zip(scores[0], indices[0]):
        # Negative positions (FAISS uses -1 for "no result") and positions
        # beyond the id list are skipped.
        if idx < 0 or idx >= len(ids):
            continue

        chunk_id = ids[idx]

        # Doc-scoped filter
        if allowed_docs is not None and doc_map.get(chunk_id) not in allowed_docs:
            continue

        results.append({
            "chunk_id": chunk_id,
            "score": float(score),
        })

        if len(results) >= req.limit:
            break

    return results
|
|
|
|
|
|
@app.post("/search-tags")
def search_tags(req: SearchRequest):
    """Return up to ``req.limit`` tag hits for ``req.query``, best match first.

    The query is embedded with the loaded model (normalized embeddings, cast
    to float32) and searched in the tag FAISS index. ``req.doc_ids`` is not
    used by this endpoint.

    Returns:
        A list of ``{"tag_id": ..., "score": float}`` dicts in the order
        FAISS returned them.

    Raises:
        HTTPException 503: tag index, its id list, or the model is not loaded.
        HTTPException 422: ``req.limit`` is smaller than 1.
    """
    # Take one snapshot of the shared globals so the None-checks and the uses
    # below refer to the same objects even if load_all() rebinds them meanwhile.
    encoder, index, ids = model, tag_index, tag_ids

    if index is None or ids is None:
        raise HTTPException(status_code=503, detail="Tag index not available")
    if encoder is None:
        raise HTTPException(status_code=503, detail="Embedding model not available")
    # A non-positive k is not a meaningful search; reject it up front instead
    # of passing it to FAISS.
    if req.limit < 1:
        raise HTTPException(status_code=422, detail="limit must be >= 1")

    query_vec = encoder.encode([req.query], normalize_embeddings=True)
    query_vec = np.array(query_vec).astype("float32")

    scores, indices = index.search(query_vec, req.limit)

    # Negative positions (FAISS uses -1 for "no result") and positions beyond
    # the id list are skipped.
    return [
        {"tag_id": ids[idx], "score": float(score)}
        for score, idx in zip(scores[0], indices[0])
        if 0 <= idx < len(ids)
    ]