optimize as sales rag
This commit is contained in:
@@ -78,7 +78,12 @@ app = FastAPI()
|
||||
model: Optional[SentenceTransformer] = None
|
||||
chunk_index = None
|
||||
chunk_ids: Optional[List[Any]] = None
|
||||
|
||||
# Sales-RAG signals derived from NDJSON (loaded on startup and reload):
|
||||
# - chunk_doc_map: chunk_id -> document_id
|
||||
# - chunk_pos_map: chunk_id -> chunk_index (position within document, if available)
|
||||
chunk_doc_map: Dict[str, str] = {}
|
||||
chunk_pos_map: Dict[str, int] = {}
|
||||
|
||||
tag_index = None
|
||||
tag_ids: Optional[List[Any]] = None
|
||||
@@ -115,10 +120,32 @@ def _safe_read_json(path: Path) -> Optional[dict]:
|
||||
return None
|
||||
|
||||
|
||||
def load_chunk_doc_map() -> None:
|
||||
global chunk_doc_map
|
||||
def _as_key(value: Any) -> Optional[str]:
|
||||
"""
|
||||
Normalize IDs to string keys for maps. Returns None if unusable.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
v = value.strip()
|
||||
return v if v else None
|
||||
try:
|
||||
v = str(value).strip()
|
||||
return v if v else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def load_chunk_maps_from_ndjson() -> None:
    """
    Rebuild the chunk lookup maps from index.ndjson.

    Resets and repopulates two module-level maps:
    - chunk_doc_map: chunk_id -> document_id
    - chunk_pos_map: chunk_id -> chunk_index (position inside document, if present)

    Both maps stay empty when INDEX_NDJSON_PATH does not exist. Blank lines,
    lines that are not valid JSON, rows that are not JSON objects, and rows
    without a usable chunk_id are skipped individually. Any other failure
    (e.g. an I/O error while reading) is logged as a warning and leaves the
    maps with whatever was loaded up to that point.
    """
    global chunk_doc_map, chunk_pos_map

    chunk_doc_map = {}
    chunk_pos_map = {}

    if not INDEX_NDJSON_PATH.exists():
        return

    try:
        with INDEX_NDJSON_PATH.open("r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                try:
                    row = json.loads(line)
                except (ValueError, RecursionError):
                    # json.JSONDecodeError is a ValueError; skip the bad line only.
                    continue

                # A line may be valid JSON without being an object (e.g. `[]` or `3`).
                # Skip it here; otherwise row.get() would raise and the outer handler
                # would abort loading of every remaining row.
                if not isinstance(row, dict):
                    continue

                chunk_id_key = _as_key(row.get("chunk_id"))
                if not chunk_id_key:
                    continue

                doc_id_key = _as_key(row.get("document_id"))
                if doc_id_key:
                    chunk_doc_map[chunk_id_key] = doc_id_key

                # chunk_index is optional but very useful for Sales-RAG diversity rules
                # (e.g. min distance within a doc)
                ci = row.get("chunk_index")
                if isinstance(ci, str):
                    # tolerate numeric strings
                    s = ci.strip()
                    if not s.isdigit():
                        continue
                    try:
                        ci = int(s)
                    except ValueError:
                        # isdigit() also accepts characters int() rejects (e.g. "²").
                        continue

                # bool is a subclass of int; a JSON true/false is not a position.
                if isinstance(ci, int) and not isinstance(ci, bool):
                    chunk_pos_map[chunk_id_key] = ci
    except Exception as e:
        # Best-effort load: keep the service running with partial or empty maps.
        logger.warning("Failed to load chunk maps from ndjson: %s", e)
|
||||
|
||||
|
||||
def _sanitize_limit(limit: int, default: int = 8, max_limit: int = 200) -> int:
|
||||
try:
|
||||
v = int(limit)
|
||||
except Exception:
|
||||
return default
|
||||
if v <= 0:
|
||||
return default
|
||||
if v > max_limit:
|
||||
return max_limit
|
||||
return v
|
||||
|
||||
|
||||
def load_all() -> None:
|
||||
@@ -175,8 +237,8 @@ def load_all() -> None:
|
||||
chunk_index = None
|
||||
chunk_ids = None
|
||||
|
||||
logger.info("[Reload] Loading chunk-doc map")
|
||||
load_chunk_doc_map()
|
||||
logger.info("[Reload] Loading chunk maps (doc_id + chunk_index)")
|
||||
load_chunk_maps_from_ndjson()
|
||||
|
||||
if TAG_INDEX_PATH.exists() and TAG_MAP_PATH.exists():
|
||||
logger.info("[Reload] Loading tag index")
|
||||
@@ -199,7 +261,12 @@ def load_all() -> None:
|
||||
|
||||
current_index_version = index_version if isinstance(index_version, int) else None
|
||||
|
||||
logger.info("[Reload] Completed (index_version=%s runtime=%s)", str(current_index_version), str(current_runtime_stamp))
|
||||
logger.info(
|
||||
"[Reload] Completed (index_version=%s runtime=%s embedding_model=%s)",
|
||||
str(current_index_version),
|
||||
str(current_runtime_stamp),
|
||||
str(loaded_embedding_model_name),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
@@ -227,12 +294,20 @@ def observer_loop() -> None:
|
||||
new_runtime = v if isinstance(v, str) else None
|
||||
|
||||
if new_version != current_index_version:
|
||||
logger.info("[Observer] index_version changed (%s -> %s) -> Reload", str(current_index_version), str(new_version))
|
||||
logger.info(
|
||||
"[Observer] index_version changed (%s -> %s) -> Reload",
|
||||
str(current_index_version),
|
||||
str(new_version),
|
||||
)
|
||||
load_all()
|
||||
continue
|
||||
|
||||
if new_runtime != current_runtime_stamp:
|
||||
logger.info("[Observer] runtime changed (%s -> %s) -> Reload", str(current_runtime_stamp), str(new_runtime))
|
||||
logger.info(
|
||||
"[Observer] runtime changed (%s -> %s) -> Reload",
|
||||
str(current_runtime_stamp),
|
||||
str(new_runtime),
|
||||
)
|
||||
load_all()
|
||||
|
||||
except Exception as e:
|
||||
@@ -267,6 +342,7 @@ def health():
|
||||
"chunk_index_loaded": chunk_index is not None,
|
||||
"tag_index_loaded": tag_index is not None,
|
||||
"model_loaded": model is not None,
|
||||
"embedding_model": loaded_embedding_model_name,
|
||||
"index_version": current_index_version,
|
||||
"runtime_stamp": current_runtime_stamp,
|
||||
"log_file": str(LOG_FILE),
|
||||
@@ -287,15 +363,33 @@ def search_chunks(req: SearchRequest):
|
||||
if chunk_index is None or chunk_ids is None or model is None:
|
||||
raise HTTPException(status_code=503, detail="Chunk index not available")
|
||||
|
||||
# Safety: clamp limit to prevent abuse / accidental huge queries
|
||||
limit = _sanitize_limit(req.limit, default=8, max_limit=200)
|
||||
|
||||
query = (req.query or "").strip()
|
||||
if not query:
|
||||
raise HTTPException(status_code=400, detail="query must not be empty")
|
||||
|
||||
query_vec = model.encode(
|
||||
[f"query: {req.query}"],
|
||||
[f"query: {query}"],
|
||||
normalize_embeddings=True
|
||||
)
|
||||
query_vec = np.array(query_vec).astype("float32")
|
||||
|
||||
effective_limit = req.limit
|
||||
effective_limit = limit
|
||||
doc_filter: Optional[List[str]] = None
|
||||
if req.doc_ids:
|
||||
effective_limit = max(req.limit * 5, 50)
|
||||
# Normalize incoming doc_ids for reliable matching
|
||||
doc_filter = []
|
||||
for d in req.doc_ids:
|
||||
dk = _as_key(d)
|
||||
if dk:
|
||||
doc_filter.append(dk)
|
||||
|
||||
# When doc filtering is enabled, we fetch a wider pool and filter down.
|
||||
# Keep it bounded to avoid expensive scans on huge indices.
|
||||
effective_limit = max(limit * 5, 50)
|
||||
effective_limit = min(effective_limit, 500)
|
||||
|
||||
scores, indices = chunk_index.search(query_vec, effective_limit)
|
||||
|
||||
@@ -307,19 +401,33 @@ def search_chunks(req: SearchRequest):
|
||||
if idx < 0 or idx >= len(chunk_ids):
|
||||
continue
|
||||
|
||||
chunk_id = chunk_ids[idx]
|
||||
raw_chunk_id = chunk_ids[idx]
|
||||
chunk_id_key = _as_key(raw_chunk_id)
|
||||
if not chunk_id_key:
|
||||
continue
|
||||
|
||||
if req.doc_ids:
|
||||
doc_id = chunk_doc_map.get(chunk_id)
|
||||
if doc_id not in req.doc_ids:
|
||||
# Apply doc filter if requested
|
||||
doc_id = chunk_doc_map.get(chunk_id_key)
|
||||
if doc_filter is not None:
|
||||
if doc_id is None or doc_id not in doc_filter:
|
||||
continue
|
||||
|
||||
results.append({
|
||||
"chunk_id": chunk_id,
|
||||
# Sales-RAG signals:
|
||||
# - document_id (for doc quotas / diversity rules)
|
||||
# - chunk_index (position within doc for distance constraints)
|
||||
payload = {
|
||||
"chunk_id": raw_chunk_id,
|
||||
"score": float(score),
|
||||
})
|
||||
"document_id": doc_id, # may be None if ndjson missing/partial
|
||||
}
|
||||
|
||||
if len(results) >= req.limit:
|
||||
ci = chunk_pos_map.get(chunk_id_key)
|
||||
if isinstance(ci, int):
|
||||
payload["chunk_index"] = ci
|
||||
|
||||
results.append(payload)
|
||||
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
return results
|
||||
@@ -330,13 +438,19 @@ def search_tags(req: SearchRequest):
|
||||
if tag_index is None or tag_ids is None or model is None:
|
||||
raise HTTPException(status_code=503, detail="Tag index not available")
|
||||
|
||||
limit = _sanitize_limit(req.limit, default=8, max_limit=200)
|
||||
|
||||
query = (req.query or "").strip()
|
||||
if not query:
|
||||
raise HTTPException(status_code=400, detail="query must not be empty")
|
||||
|
||||
query_vec = model.encode(
|
||||
[f"query: {req.query}"],
|
||||
[f"query: {query}"],
|
||||
normalize_embeddings=True
|
||||
)
|
||||
query_vec = np.array(query_vec).astype("float32")
|
||||
|
||||
scores, indices = tag_index.search(query_vec, req.limit)
|
||||
scores, indices = tag_index.search(query_vec, limit)
|
||||
|
||||
results = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user