wahnsinn vibe
This commit is contained in:
111
backend/apps/ai_shop/__init__.py
Normal file
111
backend/apps/ai_shop/__init__.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import json
|
||||
import re
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.db import get_db
|
||||
from core.i18n import pick
|
||||
from core.redis_client import redis_client
|
||||
|
||||
from apps.ai_core.ollama_client import get_llm
|
||||
|
||||
# Router for this module; the search endpoint in this file registers on it.
router = APIRouter()
|
||||
|
||||
|
||||
class SearchIn(BaseModel):
    """Request body of the product search endpoint."""

    # Free-text search query.
    query: str
    # Maximum number of products to return.
    limit: int = 12
|
||||
|
||||
|
||||
_STOP = {
|
||||
"ich", "suche", "brauche", "will", "möchte", "eine", "einen", "ein",
|
||||
"die", "der", "das", "und", "oder", "mit", "für", "zum", "zur",
|
||||
"i", "am", "looking", "for", "want", "need", "a", "an", "the",
|
||||
"of", "to", "with", "etwas", "some", "nach",
|
||||
}
|
||||
|
||||
_SYN = {
|
||||
"pulli": ["pullover", "sweater"],
|
||||
"shirt": ["t-shirt", "tshirt"],
|
||||
"hose": ["pants", "jeans"],
|
||||
"warme": ["warm"],
|
||||
"warm": ["warme"],
|
||||
"grüner": ["grün", "green"],
|
||||
"grüne": ["grün", "green"],
|
||||
"grünes": ["grün", "green"],
|
||||
"blauer": ["blau", "blue"],
|
||||
"blaue": ["blau", "blue"],
|
||||
"blaues": ["blau", "blue"],
|
||||
"wandern": ["wander", "hiking"],
|
||||
}
|
||||
|
||||
|
||||
def _tokenize(s: str) -> list[str]:
|
||||
tokens = [t.lower() for t in re.findall(r"[\wäöüß]+", s, flags=re.UNICODE)]
|
||||
expanded: list[str] = []
|
||||
for t in tokens:
|
||||
if t in _STOP or len(t) < 2:
|
||||
continue
|
||||
expanded.append(t)
|
||||
expanded.extend(_SYN.get(t, []))
|
||||
return expanded
|
||||
|
||||
|
||||
def _keyword_score(product: dict, tokens: list[str]) -> float:
|
||||
if not tokens:
|
||||
return 0.0
|
||||
haystack = " ".join([
|
||||
pick(product.get("name", {}), "de").lower(),
|
||||
pick(product.get("name", {}), "en").lower(),
|
||||
pick(product.get("description", {}), "de").lower(),
|
||||
pick(product.get("description", {}), "en").lower(),
|
||||
" ".join(str(v).lower() for v in (product.get("attributes") or {}).values()),
|
||||
product.get("sku", "").lower(),
|
||||
])
|
||||
hits = sum(1 for t in tokens if t in haystack)
|
||||
return hits / len(tokens)
|
||||
|
||||
|
||||
@router.post("/search")
def ki_search(body: SearchIn, db: Session = Depends(get_db)):
    """Hybrid product search: embedding similarity + keyword boost.

    The query is embedded, a candidate pool of product documents is fetched
    from ``ai_documents`` ordered by vector distance, each candidate's
    product record is loaded from Redis (key ``product:<source_id>``) and
    the pool is re-ranked by a weighted sum of embedding similarity and
    keyword score.

    Args:
        body: Search request (query text and result limit).
        db: Database session injected by FastAPI.

    Returns:
        ``{"query": <stripped query>, "products": [...]}`` where each
        product dict carries the extra keys ``_score`` (combined),
        ``_emb`` (embedding similarity) and ``_kw`` (keyword score),
        sorted by ``_score`` descending and cut to the requested limit.
    """
    q = body.query.strip()
    # A negative limit would make the final slice drop items from the END
    # of the ranking instead of returning nothing, so clamp it.
    limit = max(body.limit, 0)
    if not q or limit == 0:
        return {"query": q, "products": []}

    emb = get_llm().embed(q)
    # Pull a larger candidate pool, then re-rank with keyword boost
    pool_size = max(limit * 3, 20)
    rows = db.execute(
        text(
            """
            SELECT source_id, 1 - (embedding <=> (:q)::vector) AS score
            FROM ai_documents
            WHERE source_type = 'product'
            ORDER BY embedding <=> (:q)::vector
            LIMIT :lim
            """
        ),
        {"q": emb, "lim": pool_size},
    ).mappings().all()

    tokens = _tokenize(q)
    candidates: list[dict] = []
    for row in rows:
        raw = redis_client.get(f"product:{row['source_id']}")
        if not raw:
            # No cached product record for this document: skip it.
            continue
        product = json.loads(raw)
        emb_s = float(row["score"])
        kw_s = _keyword_score(product, tokens)
        # Combined score: 60% embedding similarity, 40% keyword score.
        # A candidate without any keyword hit keeps only the embedding part.
        product["_score"] = round(0.6 * emb_s + 0.4 * kw_s, 4)
        product["_emb"] = round(emb_s, 4)
        product["_kw"] = round(kw_s, 4)
        candidates.append(product)

    candidates.sort(key=lambda p: p["_score"], reverse=True)
    return {"query": q, "products": candidates[:limit]}
|
||||
Reference in New Issue
Block a user