112 lines
3.3 KiB
Python
112 lines
3.3 KiB
Python
import json
|
|
import re
|
|
|
|
from fastapi import APIRouter, Depends
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import text
|
|
from sqlalchemy.orm import Session
|
|
|
|
from core.db import get_db
|
|
from core.i18n import pick
|
|
from core.redis_client import redis_client
|
|
|
|
from apps.ai_core.ollama_client import get_llm
|
|
|
|
# Router carrying this module's ``POST /search`` endpoint (ki_search).
router = APIRouter()
|
|
|
|
|
|
class SearchIn(BaseModel):
    """Request body for the ``POST /search`` product search endpoint."""

    # Free-text search query (German or English); the endpoint strips
    # surrounding whitespace and returns no products for an empty query.
    query: str
    # Maximum number of products to return. The model itself enforces no
    # bounds on this value.
    limit: int = 12
|
|
|
|
|
|
_STOP = {
|
|
"ich", "suche", "brauche", "will", "möchte", "eine", "einen", "ein",
|
|
"die", "der", "das", "und", "oder", "mit", "für", "zum", "zur",
|
|
"i", "am", "looking", "for", "want", "need", "a", "an", "the",
|
|
"of", "to", "with", "etwas", "some", "nach",
|
|
}
|
|
|
|
_SYN = {
|
|
"pulli": ["pullover", "sweater"],
|
|
"shirt": ["t-shirt", "tshirt"],
|
|
"hose": ["pants", "jeans"],
|
|
"warme": ["warm"],
|
|
"warm": ["warme"],
|
|
"grüner": ["grün", "green"],
|
|
"grüne": ["grün", "green"],
|
|
"grünes": ["grün", "green"],
|
|
"blauer": ["blau", "blue"],
|
|
"blaue": ["blau", "blue"],
|
|
"blaues": ["blau", "blue"],
|
|
"wandern": ["wander", "hiking"],
|
|
}
|
|
|
|
|
|
def _tokenize(s: str) -> list[str]:
|
|
tokens = [t.lower() for t in re.findall(r"[\wäöüß]+", s, flags=re.UNICODE)]
|
|
expanded: list[str] = []
|
|
for t in tokens:
|
|
if t in _STOP or len(t) < 2:
|
|
continue
|
|
expanded.append(t)
|
|
expanded.extend(_SYN.get(t, []))
|
|
return expanded
|
|
|
|
|
|
def _keyword_score(product: dict, tokens: list[str]) -> float:
|
|
if not tokens:
|
|
return 0.0
|
|
haystack = " ".join([
|
|
pick(product.get("name", {}), "de").lower(),
|
|
pick(product.get("name", {}), "en").lower(),
|
|
pick(product.get("description", {}), "de").lower(),
|
|
pick(product.get("description", {}), "en").lower(),
|
|
" ".join(str(v).lower() for v in (product.get("attributes") or {}).values()),
|
|
product.get("sku", "").lower(),
|
|
])
|
|
hits = sum(1 for t in tokens if t in haystack)
|
|
return hits / len(tokens)
|
|
|
|
|
|
# Upper bound applied to the client-supplied ``limit``. The candidate pool is
# about 3x this value, so the cap also bounds the vector query size and the
# number of Redis lookups a single request can trigger.
_MAX_LIMIT = 100


@router.post("/search")
def ki_search(body: SearchIn, db: Session = Depends(get_db)):
    """Hybrid product search: embedding similarity + keyword boost.

    The query is embedded with the configured LLM, and the nearest product
    documents are fetched from ``ai_documents`` using the pgvector ``<=>``
    distance. Each candidate is then loaded from Redis (``product:<id>``) and
    re-ranked by ``0.6 * embedding score + 0.4 * keyword score``.

    Returns:
        ``{"query": <stripped query>, "products": [...]}``. There are at most
        ``limit`` products (clamped to ``[0, _MAX_LIMIT]``), ordered by
        ``_score`` descending. Each product dict is the cached product plus the
        debug fields ``_score``, ``_emb`` and ``_kw``. Candidates without a
        Redis entry are skipped.
    """
    q = body.query.strip()
    # Clamp the limit. A negative value would turn the final slice into
    # "all but the last N", and an unbounded value lets one request pull an
    # arbitrarily large candidate pool.
    limit = min(max(body.limit, 0), _MAX_LIMIT)
    if not q or limit == 0:
        return {"query": q, "products": []}

    emb = get_llm().embed(q)

    # Pull a larger candidate pool, then re-rank with the keyword boost.
    pool_size = max(limit * 3, 20)
    rows = db.execute(
        text(
            """
            SELECT source_id, 1 - (embedding <=> (:q)::vector) AS score
            FROM ai_documents
            WHERE source_type = 'product'
              AND embedding IS NOT NULL
            ORDER BY embedding <=> (:q)::vector
            LIMIT :lim
            """
        ),
        # NULL embeddings are excluded above: they would yield a NULL score,
        # and float(None) below would raise.
        {"q": emb, "lim": pool_size},
    ).mappings().all()

    tokens = _tokenize(q)
    candidates: list[dict] = []
    for r in rows:
        raw = redis_client.get(f"product:{r['source_id']}")
        if not raw:
            # No cached product payload: nothing to display, skip it.
            continue
        d = json.loads(raw)
        emb_s = float(r["score"])
        kw_s = _keyword_score(d, tokens)
        # Blend: 60% embedding similarity, 40% keyword coverage. A candidate
        # with no keyword hits keeps only the embedding term, so keyword
        # matches can lift products above semantically closer ones.
        d["_score"] = round(0.6 * emb_s + 0.4 * kw_s, 4)
        d["_emb"] = round(emb_s, 4)
        d["_kw"] = round(kw_s, 4)
        candidates.append(d)

    # Stable sort: ties keep the embedding-distance order from the SQL query.
    candidates.sort(key=lambda p: p["_score"], reverse=True)
    return {"query": q, "products": candidates[:limit]}
|