wahnsinn vibe
This commit is contained in:
69
backend/apps/ai_core/__init__.py
Normal file
69
backend/apps/ai_core/__init__.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.db import get_db
|
||||
from core.di import register_service
|
||||
from core.security import require_admin
|
||||
|
||||
from .indexer import reindex_all, subscribe_indexer
|
||||
from .ollama_client import get_llm
|
||||
from .tools import describe_for_prompt, list_tools
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class QueryIn(BaseModel):
    # Request body accepted by the POST /query endpoint.

    # Free-text search phrase; the endpoint rejects it when it is blank.
    query: str

    # Optional filter: when set, only rows with this ``source_type`` match.
    source_type: str | None = None

    # Maximum number of rows returned; bound to the SQL ``LIMIT`` clause.
    limit: int = 10
|
||||
|
||||
|
||||
@router.post("/query")
def query_rag(body: QueryIn, db: Session = Depends(get_db)):
    """Run a vector similarity search over ``ai_documents``.

    The query string is embedded through ``get_llm().embed`` and rows are
    ordered by pgvector's ``<=>`` distance to that embedding. The reported
    ``score`` is ``1 - distance``.

    Args:
        body: Search phrase, optional ``source_type`` filter and row limit.
        db: Request-scoped SQLAlchemy session.

    Returns:
        A list of dicts with ``source_type``, ``source_id``, ``text``,
        ``meta`` and ``score`` keys, nearest match first.

    Raises:
        HTTPException: 400 when the query is blank or the limit is negative.
    """
    if not body.query.strip():
        raise HTTPException(400, "Empty query")
    # PostgreSQL rejects a negative LIMIT at execution time, which would
    # otherwise surface as an opaque 500. Report it as a client error.
    if body.limit < 0:
        raise HTTPException(400, "limit must not be negative")

    emb = get_llm().embed(body.query)

    # Only a fixed clause is interpolated into the statement; every
    # user-supplied value travels as a bound parameter.
    where_clause = "WHERE source_type = :st" if body.source_type else ""
    stmt = text(
        f"""
        SELECT source_type, source_id, text, meta,
               1 - (embedding <=> (:q)::vector) AS score
        FROM ai_documents
        {where_clause}
        ORDER BY embedding <=> (:q)::vector
        LIMIT :lim
        """
    )

    params: dict = {"q": emb, "lim": body.limit}
    if body.source_type:
        params["st"] = body.source_type

    rows = db.execute(stmt, params).mappings().all()
    return [
        {
            "source_type": r["source_type"],
            "source_id": r["source_id"],
            "text": r["text"],
            "meta": r["meta"],
            "score": float(r["score"]),
        }
        for r in rows
    ]
|
||||
|
||||
|
||||
@router.post("/reindex", dependencies=[Depends(require_admin)])
def trigger_reindex():
    """Rebuild the whole index and return the summary from ``reindex_all``.

    The route is guarded by the ``require_admin`` dependency.
    """
    summary = reindex_all()
    return summary
|
||||
|
||||
|
||||
@router.get("/tools")
def catalog(_: dict = Depends(require_admin)):
    """Return the tool descriptions produced by ``describe_for_prompt``.

    The unused ``_`` parameter exists only to run the ``require_admin``
    dependency for this route.
    """
    tool_descriptions = describe_for_prompt()
    return tool_descriptions
|
||||
|
||||
|
||||
def on_load() -> None:
    """Wire the app up: subscribe the indexer, then register its services.

    ``LLMProvider`` is registered as the client instance returned by
    ``get_llm()``; ``ToolRegistry`` is registered as the ``list_tools``
    function itself.
    """
    subscribe_indexer()

    services = (
        ("LLMProvider", get_llm()),
        ("ToolRegistry", list_tools),
    )
    for service_name, service in services:
        register_service(service_name, service)
|
||||
158
backend/apps/ai_core/indexer.py
Normal file
158
backend/apps/ai_core/indexer.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""RAG indexer: subscribes to product/category/setting events and (re)builds embeddings."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.db import SessionLocal
|
||||
from core.events import event_bus
|
||||
from core.i18n import pick
|
||||
|
||||
from apps.catalog.models import Category, Product
|
||||
|
||||
from .models import AIDocument
|
||||
from .ollama_client import get_llm
|
||||
|
||||
|
||||
# Extra search terms appended to a product's indexed text when its "color"
# attribute equals one of these keys. Each value is a space-separated list of
# German and English word forms.
_COLOR_SYNONYMS = dict(
    green="grün grüner grünes Grün green olive",
    blue="blau blauer blaues Blau blue navy",
    black="schwarz schwarzer schwarzes Schwarz black",
    white="weiß weißer weißes Weiß white blank",
    olive="oliv olivgrün olive green khaki",
    red="rot roter rotes Rot red",
    khaki="khaki beige oliv",
    brown="braun brauner braunes Braun brown",
    grey="grau grauer graues Grau grey gray",
)
|
||||
|
||||
|
||||
def _product_text(p: Product, cat: Category | None) -> str:
    """Build the text that gets embedded for a product.

    The text is made of the product's name and description in "de" and "en"
    (via ``pick``), the category name in both languages when a category is
    given, and one ``key: value`` line per product attribute. A string
    ``color`` attribute additionally contributes its synonym list from
    ``_COLOR_SYNONYMS``.

    Args:
        p: Product to describe.
        cat: The product's category, or None.

    Returns:
        The non-empty parts joined with newlines.
    """
    parts = [
        pick(p.name, "de"),
        pick(p.name, "en"),
        pick(p.description, "de"),
        pick(p.description, "en"),
    ]
    if cat:
        parts.append(pick(cat.name, "de"))
        parts.append(pick(cat.name, "en"))

    for key, value in (p.attributes or {}).items():
        parts.append(f"{key}: {value}")
        if key == "color" and isinstance(value, str):
            # The synonym table is keyed by lowercase names; normalise so
            # that "Green" or " green " still picks up its synonyms.
            synonyms = _COLOR_SYNONYMS.get(value.strip().lower())
            if synonyms:
                parts.append(synonyms)

    return "\n".join(s for s in parts if s)
|
||||
|
||||
|
||||
def _category_text(c: Category) -> str:
    """Build the text that gets embedded for a category.

    Uses the category name in "de" and "en" (via ``pick``). Empty parts are
    dropped, matching ``_product_text``, so a missing translation neither
    breaks the join nor leaves a stray newline in the indexed text.

    Args:
        c: Category to describe.

    Returns:
        The non-empty names joined with a newline; "" when there are none.
    """
    names = (pick(c.name, "de"), pick(c.name, "en"))
    return "\n".join(name for name in names if name)
|
||||
|
||||
|
||||
def _upsert(db: Session, source_type: str, source_id: str, text: str, meta: dict) -> None:
    """Create or refresh the ``AIDocument`` row for one source record.

    Blank text is ignored entirely (nothing is embedded, written or
    committed). Otherwise the text is embedded, the row matching
    ``(source_type, source_id)`` is updated or inserted, and the session is
    committed.
    """
    if not text.strip():
        return

    vector = get_llm().embed(text)

    lookup = db.query(AIDocument).filter_by(
        source_type=source_type,
        source_id=source_id,
    )
    doc = lookup.first()

    if not doc:
        doc = AIDocument(
            source_type=source_type,
            source_id=source_id,
            text=text,
            embedding=vector,
            meta=meta,
        )
        db.add(doc)
    else:
        doc.text = text
        doc.embedding = vector
        doc.meta = meta

    db.commit()
|
||||
|
||||
|
||||
def _remove(db: Session, source_type: str, source_id: str) -> None:
    """Delete every ``AIDocument`` row for the given source and commit."""
    matches_source = (
        AIDocument.source_type == source_type,
        AIDocument.source_id == source_id,
    )
    db.execute(delete(AIDocument).where(*matches_source))
    db.commit()
|
||||
|
||||
|
||||
def index_product(db: Session, product_id: int) -> None:
    """(Re)index one product, or drop it from the index.

    A product that no longer exists or is not active has its document
    removed. ``reindex_all`` only indexes active products, so the
    event-driven path applies the same rule; otherwise deactivating a
    product would leave it searchable until the next full rebuild.

    Args:
        db: Session used for the lookup and the resulting write.
        product_id: Primary key of the product.
    """
    doc_id = str(product_id)
    p = db.get(Product, product_id)
    if p is None or not p.active:
        _remove(db, "product", doc_id)
        return

    cat = db.get(Category, p.category_id) if p.category_id else None
    meta = {"category_id": p.category_id, "price": float(p.price)}
    _upsert(db, "product", doc_id, _product_text(p, cat), meta)
|
||||
|
||||
|
||||
def index_category(db: Session, category_id: int) -> None:
    """(Re)index one category; remove its document if the category is gone."""
    category = db.get(Category, category_id)
    doc_id = str(category_id)

    if category:
        _upsert(db, "category", doc_id, _category_text(category), {})
    else:
        _remove(db, "category", doc_id)
|
||||
|
||||
|
||||
def reindex_all() -> dict:
    """Rebuild the whole index from scratch on a dedicated session.

    All ``AIDocument`` rows are deleted and committed first; then every
    active product and every category is indexed in turn.

    Returns:
        ``{"products": <count>, "categories": <count>}`` — the number of
        records visited, not necessarily the number of rows written.
    """
    db = SessionLocal()
    try:
        # NOTE(review): the wipe is committed before anything is re-embedded,
        # so a failure part-way through leaves the index empty or partial.
        db.execute(delete(AIDocument))
        db.commit()

        active_products = db.query(Product).filter(Product.active.is_(True)).all()
        for product in active_products:
            index_product(db, product.id)

        categories = db.query(Category).all()
        for category in categories:
            index_category(db, category.id)

        return {"products": len(active_products), "categories": len(categories)}
    finally:
        db.close()
|
||||
|
||||
|
||||
# Event subscribers -----------------------------------------------------
|
||||
|
||||
|
||||
def _on_product_event(event_type: str, payload: dict[str, Any], db: Session) -> None:
    """Keep the index in sync with ``product.*`` events.

    Events without a truthy ``id`` in the payload are ignored. A
    ``product.deleted`` event removes the document; any other event
    re-indexes the product, and a failure there is printed, not raised.
    """
    pid = payload.get("id")
    if not pid:
        return

    if event_type != "product.deleted":
        try:
            index_product(db, pid)
        except Exception as e:  # noqa: BLE001
            print(f"[ai-indexer] product {pid} failed: {e}")
        return

    _remove(db, "product", str(pid))
|
||||
|
||||
|
||||
def _on_category_event(event_type: str, payload: dict[str, Any], db: Session) -> None:
    """Keep the index in sync with ``category.*`` events.

    Events without a truthy ``id`` in the payload are ignored. A
    ``category.deleted`` event removes the document; any other event
    re-indexes the category, and a failure there is printed, not raised.
    """
    cid = payload.get("id")
    if not cid:
        return

    if event_type != "category.deleted":
        try:
            index_category(db, cid)
        except Exception as e:  # noqa: BLE001
            print(f"[ai-indexer] category {cid} failed: {e}")
        return

    _remove(db, "category", str(cid))
|
||||
|
||||
|
||||
def subscribe_indexer() -> None:
    """Register the indexer's handlers on the event bus."""
    subscriptions = (
        ("product.*", _on_product_event),
        ("category.*", _on_category_event),
    )
    for pattern, handler in subscriptions:
        event_bus.subscribe(pattern, handler)
|
||||
6
backend/apps/ai_core/manifest.yaml
Normal file
6
backend/apps/ai_core/manifest.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
name: ai_core
|
||||
version: 0.1.0
|
||||
depends_on: [core, catalog]
|
||||
conflicts_with: []
|
||||
required: true
|
||||
provides: [LLMProvider, ToolRegistry]
|
||||
34
backend/apps/ai_core/models.py
Normal file
34
backend/apps/ai_core/models.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import JSON, DateTime, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.config import settings
|
||||
from core.db import Base
|
||||
|
||||
|
||||
class AIDocument(Base):
    """Indexed text and its embedding for one source record."""

    __tablename__ = "ai_documents"

    id: Mapped[int] = mapped_column(
        Integer,
        primary_key=True,
        autoincrement=True,
    )
    # Kind of record the row was built from: 'product', 'category', 'setting'.
    source_type: Mapped[str] = mapped_column(
        String(64),
        index=True,
    )
    # Identifier of the source record, stored as a string.
    source_id: Mapped[str] = mapped_column(
        String(64),
        index=True,
    )
    # The text that was embedded.
    text: Mapped[str] = mapped_column(Text)
    # Vector column sized by ``settings.OLLAMA_EMBED_DIM``.
    embedding: Mapped[list[float]] = mapped_column(
        Vector(settings.OLLAMA_EMBED_DIM),
    )
    # Free-form JSON attached to the document; defaults to an empty dict.
    meta: Mapped[dict] = mapped_column(
        JSON,
        default=dict,
    )
    # Set by the database on insert and refreshed on every update.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||
|
||||
|
||||
class AIAuditLog(Base):
    """Audit row recording a tool name, its arguments, result and outcome."""

    __tablename__ = "ai_audit"

    id: Mapped[int] = mapped_column(
        Integer,
        primary_key=True,
        autoincrement=True,
    )
    # Nullable user reference.
    user_id: Mapped[int | None] = mapped_column(
        Integer,
        nullable=True,
    )
    # Name of the tool.
    tool: Mapped[str] = mapped_column(String(128))
    # JSON columns; both default to an empty dict.
    args: Mapped[dict] = mapped_column(
        JSON,
        default=dict,
    )
    result: Mapped[dict] = mapped_column(
        JSON,
        default=dict,
    )
    # Success flag; the column type is derived from the ``Mapped[bool]`` annotation.
    ok: Mapped[bool] = mapped_column()
    # Set by the database when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
|
||||
69
backend/apps/ai_core/ollama_client.py
Normal file
69
backend/apps/ai_core/ollama_client.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Thin Ollama client. Interface is kept minimal so Ollama can be swapped later."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from core.config import settings
|
||||
|
||||
|
||||
class OllamaClient:
    """Small synchronous client for an Ollama HTTP server.

    Exposes embedding and chat calls over one shared ``httpx.Client``.
    """

    def __init__(self, base_url: str | None = None) -> None:
        """Create the client.

        Args:
            base_url: Server URL; falls back to ``settings.OLLAMA_URL``.
        """
        self.base_url = base_url or settings.OLLAMA_URL
        # 600 s timeout on every request made through this client.
        self._client = httpx.Client(base_url=self.base_url, timeout=600.0)

    def embed(self, text: str, model: str | None = None) -> list[float]:
        """Return the embedding of ``text``.

        Args:
            text: Text to embed.
            model: Model name; falls back to ``settings.OLLAMA_EMBED_MODEL``.

        Raises:
            httpx.HTTPStatusError: When the server answers with an error status.
            KeyError: When the response carries no ``embedding`` field.
        """
        model = model or settings.OLLAMA_EMBED_MODEL
        r = self._client.post("/api/embeddings", json={"model": model, "prompt": text})
        r.raise_for_status()
        data = r.json()
        return data["embedding"]

    def chat(
        self,
        system: str,
        user: str,
        model: str | None = None,
        json_mode: bool = False,
    ) -> str:
        """Send one system + user exchange and return the reply text.

        Args:
            system: System prompt.
            user: User message.
            model: Model name; falls back to ``settings.OLLAMA_CHAT_MODEL``.
            json_mode: When true, ``"format": "json"`` is added to the request.

        Returns:
            The ``message.content`` field of the response, or "" if absent.

        Raises:
            httpx.HTTPStatusError: When the server answers with an error status.
        """
        model = model or settings.OLLAMA_CHAT_MODEL
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        payload = {"model": model, "messages": messages, "stream": False}
        if json_mode:
            payload["format"] = "json"
        r = self._client.post("/api/chat", json=payload)
        r.raise_for_status()
        data = r.json()
        return data.get("message", {}).get("content", "")

    def chat_json(self, system: str, user: str, model: str | None = None) -> dict:
        """Run ``chat`` in JSON mode and parse the reply.

        If the reply is not valid JSON as a whole, the first JSON object or
        array embedded in it is extracted instead.

        Returns:
            The decoded value. This is normally a dict, but whatever JSON
            value the model produced is returned as-is (a list is possible).

        Raises:
            json.JSONDecodeError: The original parse error, when nothing
                could be recovered from the reply.
        """
        raw = self.chat(system, user, model=model, json_mode=True)
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            recovered = self._extract_json(raw)
            if recovered is None:
                raise
            return recovered

    @staticmethod
    def _extract_json(raw: str):
        """Return the first JSON object/array found in ``raw``, or None.

        Decoding starts at the first ``{`` and the first ``[`` (earliest
        position first) and stops at the end of that value. Text after the
        value, even text containing brackets, therefore cannot break the
        parse the way a "first opener to last closer" slice would.
        """
        decoder = json.JSONDecoder()
        starts = sorted(i for i in (raw.find("{"), raw.find("[")) if i != -1)
        for start in starts:
            try:
                value, _ = decoder.raw_decode(raw, start)
            except json.JSONDecodeError:
                continue
            return value
        return None
|
||||
|
||||
|
||||
# Process-wide client instance, created on first use by get_llm().
_client: OllamaClient | None = None


def get_llm() -> OllamaClient:
    """Return the shared ``OllamaClient``, creating it on the first call."""
    global _client
    if _client is not None:
        return _client
    _client = OllamaClient()
    return _client
|
||||
6
backend/apps/ai_core/reindex.py
Normal file
6
backend/apps/ai_core/reindex.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Standalone entrypoint: `uv run python -m apps.ai_core.reindex`"""
|
||||
from .indexer import reindex_all
|
||||
|
||||
if __name__ == "__main__":
    # Run a full rebuild and print the summary returned by reindex_all().
    print(f"Reindexed: {reindex_all()}")
|
||||
56
backend/apps/ai_core/tools.py
Normal file
56
backend/apps/ai_core/tools.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Tool Registry: apps register callable tools + JSON Schema for the KI to use.
|
||||
|
||||
KI never runs tools directly — the registry is only a catalog for the planner,
|
||||
and handlers are invoked by the `execute` endpoint after user confirmation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@dataclass
class ToolSpec:
    """Catalog entry describing one tool that an app registers."""

    # Tool name, e.g. "catalog.product.create".
    name: str
    description: str
    # JSON Schema for the tool's arguments.
    args_schema: dict
    # Callable taking the argument dict and a session, returning a dict.
    handler: Callable[[dict, Session], dict]
    # Only admin-exposed in the AI admin chat.
    required_role: str = "admin"
    # Example entries; every instance gets its own fresh list.
    examples: list[dict] = field(default_factory=list)
|
||||
|
||||
|
||||
# Registry of all registered tools, keyed by tool name.
_tools: dict[str, ToolSpec] = {}


def register_tool(spec: ToolSpec) -> None:
    """Add ``spec`` to the registry, replacing any entry with the same name."""
    _tools.update({spec.name: spec})
|
||||
|
||||
|
||||
def get_tool(name: str) -> ToolSpec | None:
    """Look up a registered tool by name; None when it is not registered."""
    try:
        return _tools[name]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def list_tools(role: str = "admin") -> list[ToolSpec]:
    """Return the tools visible to ``role``.

    The "admin" role sees every tool; any other role sees only the tools
    whose ``required_role`` equals it.
    """
    is_admin = role == "admin"
    visible = []
    for spec in _tools.values():
        if spec.required_role == role or is_admin:
            visible.append(spec)
    return visible
|
||||
|
||||
|
||||
def describe_for_prompt(role: str = "admin") -> list[dict[str, Any]]:
    """Return a JSON-serializable description of all tools for the LLM prompt."""
    exposed_fields = ("name", "description", "args_schema", "examples")
    return [
        {field_name: getattr(spec, field_name) for field_name in exposed_fields}
        for spec in list_tools(role)
    ]
|
||||
|
||||
|
||||
def validate_args(spec: ToolSpec, args: dict) -> list[str]:
    """Return list of missing required keys (basic check — not full JSON-schema).

    A key counts as missing when it is absent from ``args`` or its value is
    ``None`` or the empty string.
    """
    missing = []
    for key in spec.args_schema.get("required", []):
        provided = key in args and args[key] not in (None, "")
        if not provided:
            missing.append(key)
    return missing
|
||||
Reference in New Issue
Block a user