alpha new hybrid retriever line
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, Dict
|
||||
|
||||
import numpy as np
|
||||
import faiss
|
||||
@@ -25,6 +25,7 @@ TAG_INDEX_PATH = KNOWLEDGE_DIR / "vector_tags.index"
|
||||
TAG_MAP_PATH = KNOWLEDGE_DIR / "vector_tags.index.meta.json"
|
||||
|
||||
INDEX_META_PATH = KNOWLEDGE_DIR / "index_meta.json"
|
||||
INDEX_NDJSON_PATH = KNOWLEDGE_DIR / "index.ndjson"
|
||||
|
||||
|
||||
# ============================================================
|
||||
@@ -36,8 +37,11 @@ app = FastAPI()
|
||||
model: Optional[SentenceTransformer] = None
|
||||
chunk_index = None
|
||||
chunk_ids: Optional[List[Any]] = None
|
||||
chunk_doc_map: Dict[str, str] = {}
|
||||
|
||||
tag_index = None
|
||||
tag_ids: Optional[List[Any]] = None
|
||||
|
||||
loaded_embedding_model_name: Optional[str] = None
|
||||
|
||||
|
||||
@@ -48,12 +52,35 @@ loaded_embedding_model_name: Optional[str] = None
|
||||
class SearchRequest(BaseModel):
    """Parameters of a chunk search request.

    Fields:
        query: Text that is embedded and searched for.
        limit: Maximum number of results to return; defaults to 8.
        doc_ids: Optional list of document ids. When given and non-empty,
            only chunks mapped to one of these documents are returned.
    """

    query: str
    limit: int = 8
    doc_ids: Optional[List[str]] = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Loader
|
||||
# ============================================================
|
||||
|
||||
def load_chunk_doc_map() -> None:
    """Rebuild the global chunk-id -> document-id map from the NDJSON index.

    Reads ``INDEX_NDJSON_PATH`` line by line. Every line that parses to a
    JSON object whose ``chunk_id`` and ``document_id`` are both strings adds
    one entry to ``chunk_doc_map``. Blank lines, malformed JSON, rows that
    are not JSON objects, and rows with missing or non-string ids are
    skipped. If the index file does not exist, the map is left empty.
    """
    global chunk_doc_map

    # Reset first so a missing index file never leaves stale entries behind.
    chunk_doc_map = {}

    if not INDEX_NDJSON_PATH.exists():
        return

    with INDEX_NDJSON_PATH.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            try:
                row = json.loads(line)
            except (ValueError, RecursionError):
                # Best effort: one unparsable line must not abort the load.
                # json.JSONDecodeError is a subclass of ValueError.
                continue

            # Valid JSON is not necessarily an object (e.g. a list or a
            # number); calling .get() on those would raise AttributeError.
            if not isinstance(row, dict):
                continue

            chunk_id = row.get("chunk_id")
            document_id = row.get("document_id")

            if isinstance(chunk_id, str) and isinstance(document_id, str):
                chunk_doc_map[chunk_id] = document_id
|
||||
|
||||
|
||||
def load_all():
|
||||
global model, chunk_index, chunk_ids, tag_index, tag_ids, loaded_embedding_model_name
|
||||
|
||||
@@ -81,6 +108,10 @@ def load_all():
|
||||
chunk_index = None
|
||||
chunk_ids = None
|
||||
|
||||
# Load chunk → document map
|
||||
print("[Reload] Loading chunk-doc map")
|
||||
load_chunk_doc_map()
|
||||
|
||||
# Reload tag index
|
||||
if TAG_INDEX_PATH.exists() and TAG_MAP_PATH.exists():
|
||||
print("[Reload] Loading tag index")
|
||||
@@ -134,20 +165,37 @@ def search_chunks(req: SearchRequest):
|
||||
query_vec = model.encode([req.query], normalize_embeddings=True)
|
||||
query_vec = np.array(query_vec).astype("float32")
|
||||
|
||||
scores, indices = chunk_index.search(query_vec, req.limit)
|
||||
# If doc_ids is set, fetch more candidates first, then filter them
|
||||
effective_limit = req.limit
|
||||
if req.doc_ids:
|
||||
effective_limit = max(req.limit * 5, 50)
|
||||
|
||||
scores, indices = chunk_index.search(query_vec, effective_limit)
|
||||
|
||||
results = []
|
||||
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx == -1:
|
||||
continue
|
||||
if idx < 0 or idx >= len(chunk_ids):
|
||||
continue
|
||||
|
||||
chunk_id = chunk_ids[idx]
|
||||
|
||||
# NEW: doc-scoped filter
|
||||
if req.doc_ids:
|
||||
doc_id = chunk_doc_map.get(chunk_id)
|
||||
if doc_id not in req.doc_ids:
|
||||
continue
|
||||
|
||||
results.append({
|
||||
"chunk_id": chunk_ids[idx],
|
||||
"chunk_id": chunk_id,
|
||||
"score": float(score),
|
||||
})
|
||||
|
||||
if len(results) >= req.limit:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user