optimize as sales rag

This commit is contained in:
team2
2026-02-27 21:03:59 +01:00
parent efa9b17c2f
commit 3a5804e44c
6 changed files with 541 additions and 213 deletions

View File

@@ -78,7 +78,12 @@ app = FastAPI()
model: Optional[SentenceTransformer] = None
chunk_index = None
chunk_ids: Optional[List[Any]] = None
# Sales-RAG signals derived from NDJSON (loaded on startup and reload):
# - chunk_doc_map: chunk_id -> document_id
# - chunk_pos_map: chunk_id -> chunk_index (position within document, if available)
chunk_doc_map: Dict[str, str] = {}
chunk_pos_map: Dict[str, int] = {}
tag_index = None
tag_ids: Optional[List[Any]] = None
@@ -115,10 +120,32 @@ def _safe_read_json(path: Path) -> Optional[dict]:
return None
def load_chunk_doc_map() -> None:
global chunk_doc_map
def _as_key(value: Any) -> Optional[str]:
"""
Normalize IDs to string keys for maps. Returns None if unusable.
"""
if value is None:
return None
if isinstance(value, str):
v = value.strip()
return v if v else None
try:
v = str(value).strip()
return v if v else None
except Exception:
return None
def load_chunk_maps_from_ndjson() -> None:
"""
Builds two maps from index.ndjson:
- chunk_id -> document_id
- chunk_id -> chunk_index (position inside document, if present)
"""
global chunk_doc_map, chunk_pos_map
chunk_doc_map = {}
chunk_pos_map = {}
if not INDEX_NDJSON_PATH.exists():
return
@@ -126,18 +153,53 @@ def load_chunk_doc_map() -> None:
try:
with INDEX_NDJSON_PATH.open("r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
row = json.loads(line)
except Exception:
continue
chunk_id = row.get("chunk_id")
document_id = row.get("document_id")
chunk_id_key = _as_key(row.get("chunk_id"))
if not chunk_id_key:
continue
document_id = row.get("document_id")
doc_id_key = _as_key(document_id)
if doc_id_key:
chunk_doc_map[chunk_id_key] = doc_id_key
# chunk_index is optional but very useful for Sales-RAG diversity rules
# (e.g. min distance within a doc)
ci = row.get("chunk_index")
if isinstance(ci, int):
chunk_pos_map[chunk_id_key] = ci
else:
# tolerate numeric strings
if isinstance(ci, str):
s = ci.strip()
if s.isdigit():
try:
chunk_pos_map[chunk_id_key] = int(s)
except Exception:
pass
if isinstance(chunk_id, str) and isinstance(document_id, str):
chunk_doc_map[chunk_id] = document_id
except Exception as e:
logger.warning("Failed to load chunk-doc map from ndjson: %s", str(e))
logger.warning("Failed to load chunk maps from ndjson: %s", str(e))
def _sanitize_limit(limit: int, default: int = 8, max_limit: int = 200) -> int:
try:
v = int(limit)
except Exception:
return default
if v <= 0:
return default
if v > max_limit:
return max_limit
return v
def load_all() -> None:
@@ -175,8 +237,8 @@ def load_all() -> None:
chunk_index = None
chunk_ids = None
logger.info("[Reload] Loading chunk-doc map")
load_chunk_doc_map()
logger.info("[Reload] Loading chunk maps (doc_id + chunk_index)")
load_chunk_maps_from_ndjson()
if TAG_INDEX_PATH.exists() and TAG_MAP_PATH.exists():
logger.info("[Reload] Loading tag index")
@@ -199,7 +261,12 @@ def load_all() -> None:
current_index_version = index_version if isinstance(index_version, int) else None
logger.info("[Reload] Completed (index_version=%s runtime=%s)", str(current_index_version), str(current_runtime_stamp))
logger.info(
"[Reload] Completed (index_version=%s runtime=%s embedding_model=%s)",
str(current_index_version),
str(current_runtime_stamp),
str(loaded_embedding_model_name),
)
# ============================================================
@@ -227,12 +294,20 @@ def observer_loop() -> None:
new_runtime = v if isinstance(v, str) else None
if new_version != current_index_version:
logger.info("[Observer] index_version changed (%s -> %s) -> Reload", str(current_index_version), str(new_version))
logger.info(
"[Observer] index_version changed (%s -> %s) -> Reload",
str(current_index_version),
str(new_version),
)
load_all()
continue
if new_runtime != current_runtime_stamp:
logger.info("[Observer] runtime changed (%s -> %s) -> Reload", str(current_runtime_stamp), str(new_runtime))
logger.info(
"[Observer] runtime changed (%s -> %s) -> Reload",
str(current_runtime_stamp),
str(new_runtime),
)
load_all()
except Exception as e:
@@ -267,6 +342,7 @@ def health():
"chunk_index_loaded": chunk_index is not None,
"tag_index_loaded": tag_index is not None,
"model_loaded": model is not None,
"embedding_model": loaded_embedding_model_name,
"index_version": current_index_version,
"runtime_stamp": current_runtime_stamp,
"log_file": str(LOG_FILE),
@@ -287,15 +363,33 @@ def search_chunks(req: SearchRequest):
if chunk_index is None or chunk_ids is None or model is None:
raise HTTPException(status_code=503, detail="Chunk index not available")
# Safety: clamp limit to prevent abuse / accidental huge queries
limit = _sanitize_limit(req.limit, default=8, max_limit=200)
query = (req.query or "").strip()
if not query:
raise HTTPException(status_code=400, detail="query must not be empty")
query_vec = model.encode(
[f"query: {req.query}"],
[f"query: {query}"],
normalize_embeddings=True
)
query_vec = np.array(query_vec).astype("float32")
effective_limit = req.limit
effective_limit = limit
doc_filter: Optional[List[str]] = None
if req.doc_ids:
effective_limit = max(req.limit * 5, 50)
# Normalize incoming doc_ids for reliable matching
doc_filter = []
for d in req.doc_ids:
dk = _as_key(d)
if dk:
doc_filter.append(dk)
# When doc filtering is enabled, we fetch a wider pool and filter down.
# Keep it bounded to avoid expensive scans on huge indices.
effective_limit = max(limit * 5, 50)
effective_limit = min(effective_limit, 500)
scores, indices = chunk_index.search(query_vec, effective_limit)
@@ -307,19 +401,33 @@ def search_chunks(req: SearchRequest):
if idx < 0 or idx >= len(chunk_ids):
continue
chunk_id = chunk_ids[idx]
raw_chunk_id = chunk_ids[idx]
chunk_id_key = _as_key(raw_chunk_id)
if not chunk_id_key:
continue
if req.doc_ids:
doc_id = chunk_doc_map.get(chunk_id)
if doc_id not in req.doc_ids:
# Apply doc filter if requested
doc_id = chunk_doc_map.get(chunk_id_key)
if doc_filter is not None:
if doc_id is None or doc_id not in doc_filter:
continue
results.append({
"chunk_id": chunk_id,
# Sales-RAG signals:
# - document_id (for doc quotas / diversity rules)
# - chunk_index (position within doc for distance constraints)
payload = {
"chunk_id": raw_chunk_id,
"score": float(score),
})
"document_id": doc_id, # may be None if ndjson missing/partial
}
if len(results) >= req.limit:
ci = chunk_pos_map.get(chunk_id_key)
if isinstance(ci, int):
payload["chunk_index"] = ci
results.append(payload)
if len(results) >= limit:
break
return results
@@ -330,13 +438,19 @@ def search_tags(req: SearchRequest):
if tag_index is None or tag_ids is None or model is None:
raise HTTPException(status_code=503, detail="Tag index not available")
limit = _sanitize_limit(req.limit, default=8, max_limit=200)
query = (req.query or "").strip()
if not query:
raise HTTPException(status_code=400, detail="query must not be empty")
query_vec = model.encode(
[f"query: {req.query}"],
[f"query: {query}"],
normalize_embeddings=True
)
query_vec = np.array(query_vec).astype("float32")
scores, indices = tag_index.search(query_vec, req.limit)
scores, indices = tag_index.search(query_vec, limit)
results = []