add multi model
This commit is contained in:
@@ -39,6 +39,7 @@ final readonly class AgentRunner
|
||||
private LoggerInterface $agentLogger,
|
||||
private AgentRunnerConfig $agentRunnerConfig,
|
||||
private LanguageCleanupConfig $languageCleanupConfig,
|
||||
private array $llmCallModels,
|
||||
private bool $debug,
|
||||
private bool $logPrompt,
|
||||
private bool $logContext,
|
||||
@@ -46,6 +47,18 @@ final readonly class AgentRunner
|
||||
$this->systemMsgOn = true;
|
||||
}
|
||||
|
||||
private function llmCallModel(string $callName): ?string
|
||||
{
|
||||
$modelName = $this->llmCallModels[$callName] ?? null;
|
||||
if (!is_string($modelName)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
$modelName = trim($modelName);
|
||||
|
||||
return $modelName !== '' ? $modelName : null;
|
||||
}
|
||||
|
||||
public function run(string $prompt, string $userId, bool $forceFullContext = false, string $requestContextHint = ''): Generator
|
||||
{
|
||||
$originalPrompt = trim($prompt);
|
||||
@@ -973,7 +986,7 @@ final readonly class AgentRunner
|
||||
$this->thinkSuppressor->reset();
|
||||
|
||||
try {
|
||||
foreach ($this->ollamaClient->stream($normalizationPrompt) as $token) {
|
||||
foreach ($this->ollamaClient->stream($normalizationPrompt, $this->llmCallModel('input_normalization')) as $token) {
|
||||
if (!is_string($token)) {
|
||||
continue;
|
||||
}
|
||||
@@ -1539,7 +1552,7 @@ final readonly class AgentRunner
|
||||
$this->thinkSuppressor->reset();
|
||||
|
||||
try {
|
||||
foreach ($this->ollamaClient->stream($shopPrompt) as $token) {
|
||||
foreach ($this->ollamaClient->stream($shopPrompt, $this->llmCallModel('shop_query_optimization')) as $token) {
|
||||
if (!is_string($token)) {
|
||||
continue;
|
||||
}
|
||||
@@ -4655,7 +4668,7 @@ final readonly class AgentRunner
|
||||
$thinkingNoticeShown = true;
|
||||
|
||||
try {
|
||||
foreach ($this->ollamaClient->stream($finalPrompt) as $token) {
|
||||
foreach ($this->ollamaClient->stream($finalPrompt, $this->llmCallModel('final_answer')) as $token) {
|
||||
if (!is_string($token)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ final readonly class RetriexEffectiveConfigProvider
|
||||
'llm' => [
|
||||
'timeout_seconds' => $this->param('retriex.llm.timeout_seconds'),
|
||||
'num_predict' => $this->param('retriex.llm.num_predict'),
|
||||
'call_models' => $this->param('retriex.llm.call_models'),
|
||||
],
|
||||
'retrieval' => $this->retrievalConfig(),
|
||||
'prompt' => $this->promptConfig(),
|
||||
@@ -85,6 +86,7 @@ final readonly class RetriexEffectiveConfigProvider
|
||||
$this->validateRuntime($config['runtime'], $errors, $warnings);
|
||||
$this->validateIndex($config['index'], $errors, $warnings);
|
||||
$this->validateModel($config['model_generation'], $errors, $warnings);
|
||||
$this->validateLlm($config['llm'], $errors, $warnings);
|
||||
$this->validateRetrieval($config['retrieval'], $errors, $warnings);
|
||||
$this->validatePrompt($config['prompt'], $errors, $warnings);
|
||||
$this->validateAgent($config['agent'], $errors, $warnings);
|
||||
@@ -1714,6 +1716,46 @@ final readonly class RetriexEffectiveConfigProvider
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array<string, mixed> $llm
|
||||
* @param list<string> $errors
|
||||
* @param list<string> $warnings
|
||||
*/
|
||||
private function validateLlm(array $llm, array &$errors, array &$warnings): void
|
||||
{
|
||||
$callModels = $llm['call_models'] ?? [];
|
||||
if (!is_array($callModels)) {
|
||||
$errors[] = 'llm.call_models must be a map.';
|
||||
return;
|
||||
}
|
||||
|
||||
$knownCalls = [
|
||||
'input_normalization',
|
||||
'shop_query_optimization',
|
||||
'final_answer',
|
||||
];
|
||||
|
||||
foreach ($callModels as $callName => $modelName) {
|
||||
if (!is_string($callName) || trim($callName) === '') {
|
||||
$errors[] = 'llm.call_models contains an invalid call name.';
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!in_array($callName, $knownCalls, true)) {
|
||||
$warnings[] = 'llm.call_models contains an unknown call name: ' . $callName . '.';
|
||||
}
|
||||
|
||||
if ($modelName !== null && !is_string($modelName)) {
|
||||
$errors[] = 'llm.call_models.' . $callName . ' must be null or a string model name.';
|
||||
continue;
|
||||
}
|
||||
|
||||
if (is_string($modelName) && trim($modelName) === '') {
|
||||
$warnings[] = 'llm.call_models.' . $callName . ' is empty and will use the default model.';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array<string, mixed> $retrieval
|
||||
* @param list<string> $errors
|
||||
|
||||
@@ -17,7 +17,7 @@ final class OllamaClient
|
||||
private const LOW_SPEED_LIMIT_BYTES = 1;
|
||||
private const LOW_SPEED_TIME_SECONDS = 45;
|
||||
private ?ModelGenerationConfig $cachedConfig = null;
|
||||
private $config = null;
|
||||
private ?ModelGenerationConfig $config = null;
|
||||
|
||||
public function __construct(
|
||||
private string $apiUrl,
|
||||
@@ -29,33 +29,35 @@ final class OllamaClient
|
||||
/**
|
||||
* Public Streaming API
|
||||
*/
|
||||
public function stream(string $prompt): Generator
|
||||
public function stream(string $prompt, ?string $modelName = null): Generator
|
||||
{
|
||||
$this->config = $this->getConfig();
|
||||
|
||||
if ($this->config->isStream()) {
|
||||
yield from $this->streamInternal($prompt);
|
||||
yield from $this->streamInternal($prompt, $modelName);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallback: Blocking generate → Generator-kompatibel ausgeben
|
||||
yield $this->generateInternal($prompt);
|
||||
// Fallback: Blocking generate with Generator-compatible output
|
||||
yield $this->generateInternal($prompt, $modelName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Public Blocking API
|
||||
*/
|
||||
public function generate(string $prompt): string
|
||||
public function generate(string $prompt, ?string $modelName = null): string
|
||||
{
|
||||
return $this->generateInternal($prompt);
|
||||
$this->config = $this->getConfig();
|
||||
|
||||
return $this->generateInternal($prompt, $modelName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal streaming transport
|
||||
*/
|
||||
private function streamInternal(string $prompt): Generator
|
||||
private function streamInternal(string $prompt, ?string $modelName = null): Generator
|
||||
{
|
||||
$payload = $this->buildPayload($prompt, true);
|
||||
$payload = $this->buildPayload($prompt, true, $modelName);
|
||||
|
||||
$buffer = '';
|
||||
$done = false;
|
||||
@@ -137,9 +139,9 @@ final class OllamaClient
|
||||
/**
|
||||
* Internal blocking transport
|
||||
*/
|
||||
private function generateInternal(string $prompt): string
|
||||
private function generateInternal(string $prompt, ?string $modelName = null): string
|
||||
{
|
||||
$payload = $this->buildPayload($prompt, false);
|
||||
$payload = $this->buildPayload($prompt, false, $modelName);
|
||||
|
||||
$ch = curl_init($this->apiUrl);
|
||||
if ($ch === false) {
|
||||
@@ -173,10 +175,18 @@ final class OllamaClient
|
||||
/**
|
||||
* Central Payload Builder (DRY)
|
||||
*/
|
||||
private function buildPayload(string $prompt, bool $stream): string
|
||||
private function buildPayload(string $prompt, bool $stream, ?string $modelName = null): string
|
||||
{
|
||||
$config = $this->getConfig();
|
||||
$this->config = $config;
|
||||
|
||||
$effectiveModelName = trim((string) $modelName);
|
||||
if ($effectiveModelName === '') {
|
||||
$effectiveModelName = $config->getModelName();
|
||||
}
|
||||
|
||||
return json_encode([
|
||||
'model' => $this->config->getModelName(),
|
||||
'model' => $effectiveModelName,
|
||||
'prompt' => $prompt,
|
||||
'stream' => $stream,
|
||||
'options' => $this->buildOptions()
|
||||
|
||||
Reference in New Issue
Block a user