73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
from fastapi import APIRouter, Depends, HTTPException
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
|
|
from core.db import get_db
|
|
from core.security import require_admin
|
|
|
|
from apps.ai_core.models import AIAuditLog
|
|
from apps.ai_core.tools import get_tool, validate_args
|
|
|
|
from .planner import build_plan
|
|
from .tool_defs import register_all
|
|
|
|
# Every route registered on this router carries ``require_admin`` as a
# dependency, so it runs before any handler below.
# NOTE(review): presumably ``require_admin`` rejects non-admin callers — confirm
# in core.security.
router = APIRouter(
    dependencies=[Depends(require_admin)],
)
|
|
|
|
|
|
class PlanIn(BaseModel):
    """Request body accepted by the ``/plan`` endpoint."""

    # Free-text prompt; the endpoint forwards it to ``build_plan``.
    prompt: str
|
|
|
|
|
|
class ExecuteCardIn(BaseModel):
    """One tool invocation requested through the ``/execute`` endpoint."""

    # Name used to look the tool up with ``get_tool``.
    tool: str

    # Arguments handed to the tool's handler; optional, empty when omitted.
    args: dict = {}
|
|
|
|
|
|
class ExecuteIn(BaseModel):
    """Request body accepted by the ``/execute`` endpoint."""

    # Cards to run; the endpoint walks them in the order given.
    cards: list[ExecuteCardIn]
|
|
|
|
|
|
@router.post("/plan")
def plan_endpoint(body: PlanIn):
    """Build a plan for the submitted prompt.

    Args:
        body: Request payload carrying the prompt text.

    Returns:
        A dict ``{"cards": ...}`` holding whatever ``build_plan`` produced
        for the prompt.

    Raises:
        HTTPException: 400 when the prompt is empty or whitespace only.
    """
    prompt = body.prompt
    # Guard clause: a prompt with no non-whitespace characters is rejected
    # before the planner is invoked.
    if prompt.strip() == "":
        raise HTTPException(400, "Empty prompt")
    return {"cards": build_plan(prompt)}
|
|
|
|
|
|
def _audit_card(db: Session, user_id: int, card: ExecuteCardIn, result, ok: bool) -> None:
    """Store one ``AIAuditLog`` row for *card* and commit it immediately."""
    db.add(
        AIAuditLog(
            user_id=user_id,
            tool=card.tool,
            args=card.args,
            result=result,
            ok=ok,
        )
    )
    db.commit()


@router.post("/execute")
def execute_endpoint(
    body: ExecuteIn,
    claims: dict = Depends(require_admin),
    db: Session = Depends(get_db),
):
    """Run every requested tool card in order and audit each outcome.

    For each card the tool is looked up with ``get_tool``, its arguments are
    checked with ``validate_args``, and the tool's ``handler`` is called with
    ``(card.args, db)``. Every outcome (unknown tool, missing arguments,
    success, or handler failure) is written to ``AIAuditLog`` and committed
    before the next card is processed, so one failing card does not stop the
    remaining ones.

    Args:
        body: Payload listing the cards to execute.
        claims: Value returned by ``require_admin``; ``claims["sub"]`` is
            converted to ``int`` and used as the acting user's id.
        db: SQLAlchemy session shared by the handlers and the audit log.

    Returns:
        ``{"results": [...]}`` with exactly one entry per card, in request
        order. Each entry has ``tool`` and ``ok`` plus either ``result``
        (success) or ``error`` (failure).

    Raises:
        KeyError, ValueError: propagated when ``claims`` has no ``"sub"``
            entry or it is not convertible to ``int``.
    """
    user_id = int(claims["sub"])
    results = []
    for card in body.cards:
        spec = get_tool(card.tool)
        if not spec:
            results.append({"tool": card.tool, "ok": False, "error": "unknown tool"})
            _audit_card(db, user_id, card, {"error": "unknown tool"}, False)
            continue

        missing = validate_args(spec, card.args)
        if missing:
            results.append({"tool": card.tool, "ok": False, "error": f"missing: {missing}"})
            _audit_card(db, user_id, card, {"missing": missing}, False)
            continue

        try:
            res = spec.handler(card.args, db)
            # Commit the handler's work together with its audit row. The
            # success entry is appended only after this commit succeeds, so a
            # failing commit cannot leave both an ok and a failed entry for
            # the same card.
            _audit_card(db, user_id, card, res, True)
        except Exception as e:  # noqa: BLE001
            # Discard whatever the failed handler (or failed commit) left
            # pending. Without this the failure audit commit would either
            # persist the tool's partial changes or raise because the session
            # is still in a failed-flush state.
            db.rollback()
            error_text = str(e)
            results.append({"tool": card.tool, "ok": False, "error": error_text})
            _audit_card(db, user_id, card, {"error": error_text}, False)
        else:
            results.append({"tool": card.tool, "ok": True, "result": res})
    return {"results": results}
|
|
|
|
|
|
def on_load() -> None:
    """Call ``register_all`` from ``.tool_defs``.

    NOTE(review): presumably invoked once by the application's module loader
    so the tool definitions are registered before requests arrive — confirm
    against the caller.
    """
    register_all()
|