alpha new hybridretriver line
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional, Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import faiss
|
import faiss
|
||||||
@@ -25,6 +25,7 @@ TAG_INDEX_PATH = KNOWLEDGE_DIR / "vector_tags.index"
|
|||||||
TAG_MAP_PATH = KNOWLEDGE_DIR / "vector_tags.index.meta.json"
|
TAG_MAP_PATH = KNOWLEDGE_DIR / "vector_tags.index.meta.json"
|
||||||
|
|
||||||
INDEX_META_PATH = KNOWLEDGE_DIR / "index_meta.json"
|
INDEX_META_PATH = KNOWLEDGE_DIR / "index_meta.json"
|
||||||
|
INDEX_NDJSON_PATH = KNOWLEDGE_DIR / "index.ndjson"
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@@ -36,8 +37,11 @@ app = FastAPI()
|
|||||||
model: Optional[SentenceTransformer] = None
|
model: Optional[SentenceTransformer] = None
|
||||||
chunk_index = None
|
chunk_index = None
|
||||||
chunk_ids: Optional[List[Any]] = None
|
chunk_ids: Optional[List[Any]] = None
|
||||||
|
chunk_doc_map: Dict[str, str] = {}
|
||||||
|
|
||||||
tag_index = None
|
tag_index = None
|
||||||
tag_ids: Optional[List[Any]] = None
|
tag_ids: Optional[List[Any]] = None
|
||||||
|
|
||||||
loaded_embedding_model_name: Optional[str] = None
|
loaded_embedding_model_name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -48,12 +52,35 @@ loaded_embedding_model_name: Optional[str] = None
|
|||||||
class SearchRequest(BaseModel):
|
class SearchRequest(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
limit: int = 8
|
limit: int = 8
|
||||||
|
doc_ids: Optional[List[str]] = None # NEW
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# Loader
|
# Loader
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|
||||||
|
def load_chunk_doc_map():
|
||||||
|
global chunk_doc_map
|
||||||
|
|
||||||
|
chunk_doc_map = {}
|
||||||
|
|
||||||
|
if not INDEX_NDJSON_PATH.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
with INDEX_NDJSON_PATH.open("r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
try:
|
||||||
|
row = json.loads(line)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk_id = row.get("chunk_id")
|
||||||
|
document_id = row.get("document_id")
|
||||||
|
|
||||||
|
if isinstance(chunk_id, str) and isinstance(document_id, str):
|
||||||
|
chunk_doc_map[chunk_id] = document_id
|
||||||
|
|
||||||
|
|
||||||
def load_all():
|
def load_all():
|
||||||
global model, chunk_index, chunk_ids, tag_index, tag_ids, loaded_embedding_model_name
|
global model, chunk_index, chunk_ids, tag_index, tag_ids, loaded_embedding_model_name
|
||||||
|
|
||||||
@@ -81,6 +108,10 @@ def load_all():
|
|||||||
chunk_index = None
|
chunk_index = None
|
||||||
chunk_ids = None
|
chunk_ids = None
|
||||||
|
|
||||||
|
# Load chunk → document map
|
||||||
|
print("[Reload] Loading chunk-doc map")
|
||||||
|
load_chunk_doc_map()
|
||||||
|
|
||||||
# Reload tag index
|
# Reload tag index
|
||||||
if TAG_INDEX_PATH.exists() and TAG_MAP_PATH.exists():
|
if TAG_INDEX_PATH.exists() and TAG_MAP_PATH.exists():
|
||||||
print("[Reload] Loading tag index")
|
print("[Reload] Loading tag index")
|
||||||
@@ -134,20 +165,37 @@ def search_chunks(req: SearchRequest):
|
|||||||
query_vec = model.encode([req.query], normalize_embeddings=True)
|
query_vec = model.encode([req.query], normalize_embeddings=True)
|
||||||
query_vec = np.array(query_vec).astype("float32")
|
query_vec = np.array(query_vec).astype("float32")
|
||||||
|
|
||||||
scores, indices = chunk_index.search(query_vec, req.limit)
|
# Wenn doc_ids gesetzt sind → mehr holen, dann filtern
|
||||||
|
effective_limit = req.limit
|
||||||
|
if req.doc_ids:
|
||||||
|
effective_limit = max(req.limit * 5, 50)
|
||||||
|
|
||||||
|
scores, indices = chunk_index.search(query_vec, effective_limit)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for score, idx in zip(scores[0], indices[0]):
|
for score, idx in zip(scores[0], indices[0]):
|
||||||
if idx == -1:
|
if idx == -1:
|
||||||
continue
|
continue
|
||||||
if idx < 0 or idx >= len(chunk_ids):
|
if idx < 0 or idx >= len(chunk_ids):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
chunk_id = chunk_ids[idx]
|
||||||
|
|
||||||
|
# NEW: doc-scoped filter
|
||||||
|
if req.doc_ids:
|
||||||
|
doc_id = chunk_doc_map.get(chunk_id)
|
||||||
|
if doc_id not in req.doc_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"chunk_id": chunk_ids[idx],
|
"chunk_id": chunk_id,
|
||||||
"score": float(score),
|
"score": float(score),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if len(results) >= req.limit:
|
||||||
|
break
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ use App\Vector\VectorSearchClient;
|
|||||||
|
|
||||||
final class NdjsonHybridRetriever implements RetrieverInterface
|
final class NdjsonHybridRetriever implements RetrieverInterface
|
||||||
{
|
{
|
||||||
private const VECTOR_SCORE_THRESHOLD = 0.25;
|
private const VECTOR_SCORE_THRESHOLD = 0.22;
|
||||||
private const VECTOR_TOPK_MULTIPLIER_WHEN_ROUTED = 10;
|
private const VECTOR_TOPK_MULTIPLIER_WHEN_ROUTED = 3;
|
||||||
|
|
||||||
private const HARD_MAX_CHUNKS = 200;
|
private const HARD_MAX_CHUNKS = 200;
|
||||||
private const HARD_MAX_VECTORK = 200;
|
private const HARD_MAX_VECTORK = 200;
|
||||||
@@ -49,9 +49,9 @@ final class NdjsonHybridRetriever implements RetrieverInterface
|
|||||||
|
|
||||||
$isListQuery = $this->isListQuery($prompt);
|
$isListQuery = $this->isListQuery($prompt);
|
||||||
|
|
||||||
// -------------------------------
|
// -------------------------------------------------
|
||||||
// 1) Tag Routing
|
// 1) Tag Routing
|
||||||
// -------------------------------
|
// -------------------------------------------------
|
||||||
$candidateDocIds = $this->tagRouting->route($prompt);
|
$candidateDocIds = $this->tagRouting->route($prompt);
|
||||||
$candidateSet = null;
|
$candidateSet = null;
|
||||||
|
|
||||||
@@ -59,19 +59,40 @@ final class NdjsonHybridRetriever implements RetrieverInterface
|
|||||||
$candidateSet = array_fill_keys($candidateDocIds, true);
|
$candidateSet = array_fill_keys($candidateDocIds, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------
|
// -------------------------------------------------
|
||||||
// 2) Vector Search
|
// 2) TopK bestimmen
|
||||||
// -------------------------------
|
// -------------------------------------------------
|
||||||
$topK = $vectorTopKBase;
|
$topK = $vectorTopKBase;
|
||||||
|
|
||||||
|
if ($isListQuery) {
|
||||||
|
$topK = max($vectorTopKBase * 3, 80);
|
||||||
|
}
|
||||||
|
|
||||||
if ($candidateSet !== null) {
|
if ($candidateSet !== null) {
|
||||||
$topK = min(
|
$topK = min(
|
||||||
max($vectorTopKBase * self::VECTOR_TOPK_MULTIPLIER_WHEN_ROUTED, $vectorTopKBase),
|
max($topK * self::VECTOR_TOPK_MULTIPLIER_WHEN_ROUTED, $topK),
|
||||||
self::HARD_MAX_VECTORK
|
self::HARD_MAX_VECTORK
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
$hits = $this->vectorClient->search($prompt, $topK);
|
// -------------------------------------------------
|
||||||
|
// 3) Vector Search (Scoped wenn möglich)
|
||||||
|
// -------------------------------------------------
|
||||||
|
if ($candidateSet !== null) {
|
||||||
|
$hits = $this->vectorClient->searchScoped(
|
||||||
|
$prompt,
|
||||||
|
$topK,
|
||||||
|
array_keys($candidateSet)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Wenn scoped nichts liefert → global fallback
|
||||||
|
if ($hits === []) {
|
||||||
|
$hits = $this->vectorClient->search($prompt, $vectorTopKBase);
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
$hits = $this->vectorClient->search($prompt, $topK);
|
||||||
|
}
|
||||||
|
|
||||||
if ($hits === []) {
|
if ($hits === []) {
|
||||||
return $candidateSet !== null
|
return $candidateSet !== null
|
||||||
@@ -79,9 +100,9 @@ final class NdjsonHybridRetriever implements RetrieverInterface
|
|||||||
: [];
|
: [];
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------
|
// -------------------------------------------------
|
||||||
// 3) Chunk-IDs + Lookup einmalig
|
// 4) ChunkIds + Lookup
|
||||||
// -------------------------------
|
// -------------------------------------------------
|
||||||
$chunkIds = [];
|
$chunkIds = [];
|
||||||
|
|
||||||
foreach ($hits as $hit) {
|
foreach ($hits as $hit) {
|
||||||
@@ -104,9 +125,9 @@ final class NdjsonHybridRetriever implements RetrieverInterface
|
|||||||
|
|
||||||
$rows = $this->lookup->findByChunkIds($chunkIds);
|
$rows = $this->lookup->findByChunkIds($chunkIds);
|
||||||
|
|
||||||
// -------------------------------
|
// -------------------------------------------------
|
||||||
// 4) Listen-Modus → Dokument-Ranking
|
// 5) Listenmodus → Dokument-Ranking
|
||||||
// -------------------------------
|
// -------------------------------------------------
|
||||||
if ($isListQuery && $candidateSet !== null) {
|
if ($isListQuery && $candidateSet !== null) {
|
||||||
|
|
||||||
$rankedDocIds = $this->rankDocumentsFromHits($hits, $rows, $candidateSet);
|
$rankedDocIds = $this->rankDocumentsFromHits($hits, $rows, $candidateSet);
|
||||||
@@ -120,9 +141,9 @@ final class NdjsonHybridRetriever implements RetrieverInterface
|
|||||||
return $this->collectBestChunkPerDocument($topDocIds, $hits, $rows);
|
return $this->collectBestChunkPerDocument($topDocIds, $hits, $rows);
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------
|
// -------------------------------------------------
|
||||||
// 5) Normaler Chunk-Modus
|
// 6) Normaler Chunk-Modus
|
||||||
// -------------------------------
|
// -------------------------------------------------
|
||||||
return $this->collectTexts($chunkIds, $rows, $limit);
|
return $this->collectTexts($chunkIds, $rows, $limit);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,18 +174,13 @@ final class NdjsonHybridRetriever implements RetrieverInterface
|
|||||||
$documentScores = [];
|
$documentScores = [];
|
||||||
|
|
||||||
foreach ($hits as $hit) {
|
foreach ($hits as $hit) {
|
||||||
if (!isset($hit['chunk_id'], $hit['score'])) {
|
$chunkId = (string)($hit['chunk_id'] ?? '');
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
$chunkId = (string)$hit['chunk_id'];
|
|
||||||
|
|
||||||
if (!isset($rows[$chunkId])) {
|
if (!isset($rows[$chunkId])) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
$row = $rows[$chunkId];
|
$docId = $rows[$chunkId]['document_id'] ?? null;
|
||||||
$docId = $row['document_id'] ?? null;
|
|
||||||
|
|
||||||
if (!is_string($docId) || !isset($candidateSet[$docId])) {
|
if (!is_string($docId) || !isset($candidateSet[$docId])) {
|
||||||
continue;
|
continue;
|
||||||
@@ -203,25 +219,19 @@ final class NdjsonHybridRetriever implements RetrieverInterface
|
|||||||
$bestText = null;
|
$bestText = null;
|
||||||
|
|
||||||
foreach ($hits as $hit) {
|
foreach ($hits as $hit) {
|
||||||
if (!isset($hit['chunk_id'], $hit['score'])) {
|
$chunkId = (string)($hit['chunk_id'] ?? '');
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
$chunkId = (string)$hit['chunk_id'];
|
|
||||||
|
|
||||||
if (!isset($rows[$chunkId])) {
|
if (!isset($rows[$chunkId])) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
$row = $rows[$chunkId];
|
if (($rows[$chunkId]['document_id'] ?? null) !== $docId) {
|
||||||
|
|
||||||
if (($row['document_id'] ?? null) !== $docId) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((float)$hit['score'] > $bestScore) {
|
if ((float)$hit['score'] > $bestScore) {
|
||||||
$bestScore = (float)$hit['score'];
|
$bestScore = (float)$hit['score'];
|
||||||
$bestText = $row['text'] ?? null;
|
$bestText = $rows[$chunkId]['text'] ?? null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,17 +25,47 @@ final class VectorSearchClient
|
|||||||
$this->agentLogger = $agentLogger;
|
$this->agentLogger = $agentLogger;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Standard global search
|
||||||
|
*/
|
||||||
public function search(string $query, int $limit = 5): array
|
public function search(string $query, int $limit = 5): array
|
||||||
|
{
|
||||||
|
return $this->executeSearch([
|
||||||
|
'query' => $query,
|
||||||
|
'limit' => $limit,
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Scoped search: nur innerhalb bestimmter Dokumente
|
||||||
|
*/
|
||||||
|
public function searchScoped(
|
||||||
|
string $query,
|
||||||
|
int $limit,
|
||||||
|
array $docIds
|
||||||
|
): array {
|
||||||
|
if ($docIds === []) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
return $this->executeSearch([
|
||||||
|
'query' => $query,
|
||||||
|
'limit' => $limit,
|
||||||
|
'doc_ids' => array_values($docIds),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gemeinsame HTTP-Logik (keine Duplikation)
|
||||||
|
*/
|
||||||
|
private function executeSearch(array $payload): array
|
||||||
{
|
{
|
||||||
try {
|
try {
|
||||||
$response = $this->http->request(
|
$response = $this->http->request(
|
||||||
'POST',
|
'POST',
|
||||||
$this->serviceUrl . '/search-chunks',
|
$this->serviceUrl . '/search-chunks',
|
||||||
[
|
[
|
||||||
'json' => [
|
'json' => $payload,
|
||||||
'query' => $query,
|
|
||||||
'limit' => $limit,
|
|
||||||
],
|
|
||||||
'timeout' => 10,
|
'timeout' => 10,
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user