179 lines
5.1 KiB
PHP
179 lines
5.1 KiB
PHP
<?php
|
|
|
|
declare(strict_types=1);
|
|
|
|
namespace App\Command;
|
|
|
|
use App\Eval\AgentEvalRunner;
|
|
use App\Eval\Dto\EvalCase;
|
|
use App\Eval\Dto\EvalResult;
|
|
use App\Eval\EvalCaseLoader;
|
|
use App\Eval\EvalReportWriter;
|
|
use Symfony\Component\Console\Attribute\AsCommand;
|
|
use Symfony\Component\Console\Command\Command;
|
|
use Symfony\Component\Console\Input\InputArgument;
|
|
use Symfony\Component\Console\Input\InputInterface;
|
|
use Symfony\Component\Console\Input\InputOption;
|
|
use Symfony\Component\Console\Output\OutputInterface;
|
|
use Symfony\Component\Console\Style\SymfonyStyle;
|
|
|
|
/**
 * Console command that loads eval cases of a given type, runs them through
 * the eval runner, optionally persists a report and prints the outcome
 * either as a human-readable summary or as JSON.
 *
 * Exit code: Command::FAILURE when loading, running, report writing or JSON
 * encoding throws, or when at least one case did not pass; Command::SUCCESS
 * when every case passed or when no case was selected.
 */
#[AsCommand(
    name: 'mto:agent:eval:run',
    description: 'Run versioned eval cases for RetrieX',
)]
final class AgentEvalRunCommand extends Command
{
    public function __construct(
        private readonly EvalCaseLoader $loader,
        private readonly AgentEvalRunner $runner,
        private readonly EvalReportWriter $reportWriter,
    ) {
        parent::__construct();
    }

    /**
     * Declares the `type` argument (defaults to "retrieval") and the
     * `--case`, `--json` and `--no-write` options.
     */
    protected function configure(): void
    {
        $this
            ->addArgument(
                'type',
                InputArgument::OPTIONAL,
                'Eval type to run (retrieval, shop_query, followup, answer_guard)',
                'retrieval',
            )
            ->addOption(
                'case',
                null,
                InputOption::VALUE_OPTIONAL,
                'Run only a single case by id',
            )
            ->addOption(
                'json',
                null,
                InputOption::VALUE_NONE,
                'Print the full report as JSON',
            )
            ->addOption(
                'no-write',
                null,
                InputOption::VALUE_NONE,
                'Do not write the report file',
            );
    }

    /**
     * Runs the selected eval cases and reports the outcome.
     *
     * Steps: load cases for the requested type, narrow them down to the
     * `--case` id when one is given, run them, build the report array, write
     * it through the report writer unless `--no-write` is set, then print
     * either the JSON report (`--json`) or a textual summary.
     *
     * @return int Command::SUCCESS or Command::FAILURE (see class docblock)
     */
    protected function execute(InputInterface $input, OutputInterface $output): int
    {
        $io = new SymfonyStyle($input, $output);

        $type = trim((string) $input->getArgument('type'));
        $caseId = trim((string) $input->getOption('case'));
        $asJson = (bool) $input->getOption('json');
        $noWrite = (bool) $input->getOption('no-write');

        try {
            $cases = $this->loader->load($type);
        } catch (\Throwable $e) {
            $io->error($e->getMessage());

            return Command::FAILURE;
        }

        if ($caseId !== '') {
            // array_filter() preserves keys; re-index so an empty selection
            // compares equal to [] below and the runner receives a list.
            $cases = array_values(array_filter(
                $cases,
                static fn (EvalCase $case): bool => $case->id === $caseId,
            ));
        }

        if ($cases === []) {
            $io->warning('No eval cases selected.');

            return Command::SUCCESS;
        }

        try {
            $results = $this->runner->runAll($cases);
        } catch (\Throwable $e) {
            $io->error($e->getMessage());

            return Command::FAILURE;
        }

        $total = count($results);
        $passed = count(array_filter(
            $results,
            static fn (EvalResult $result): bool => $result->passed,
        ));
        $failed = $total - $passed;
        $exitCode = $failed > 0 ? Command::FAILURE : Command::SUCCESS;

        $report = [
            'type' => $type,
            'case_filter' => $caseId !== '' ? $caseId : null,
            'total' => $total,
            'passed' => $passed,
            'failed' => $failed,
            'generated_at' => (new \DateTimeImmutable())->format(\DateTimeInterface::ATOM),
            'results' => array_map(
                static fn (EvalResult $result): array => $result->toArray(),
                $results,
            ),
        ];

        $writtenPath = null;

        if (!$noWrite) {
            try {
                $writtenPath = $this->reportWriter->write($report);
            } catch (\Throwable $e) {
                $io->error($e->getMessage());

                return Command::FAILURE;
            }
        }

        if ($asJson) {
            $jsonReport = $report;

            // The written path is only part of the printed JSON, not of the
            // report handed to the writer above.
            if ($writtenPath !== null) {
                $jsonReport['written_to'] = $writtenPath;
            }

            try {
                $json = json_encode(
                    $jsonReport,
                    JSON_THROW_ON_ERROR | JSON_PRETTY_PRINT | JSON_UNESCAPED_SLASHES | JSON_UNESCAPED_UNICODE,
                );
            } catch (\JsonException $e) {
                // Keep the underlying reason (e.g. malformed UTF-8) visible.
                $io->error(sprintf('json_encode failed: %s', $e->getMessage()));

                return Command::FAILURE;
            }

            // OUTPUT_RAW: the payload must reach the consumer byte-for-byte;
            // without it the console formatter would interpret and strip
            // tag-like sequences such as "<info>" inside result data.
            $output->writeln($json, OutputInterface::OUTPUT_RAW);

            return $exitCode;
        }

        $io->title('RetrieX Eval Run');
        $io->definitionList(
            ['type' => $type],
            ['total' => (string) $total],
            ['passed' => (string) $passed],
            ['failed' => (string) $failed],
            ['report_file' => $writtenPath ?? 'disabled (--no-write)'],
        );

        $this->printResults($io, $results);

        return $exitCode;
    }

    /**
     * Prints one PASS/FAIL line per result, followed by the failure messages
     * of each result that did not pass.
     *
     * @param array<array-key, EvalResult> $results
     */
    private function printResults(SymfonyStyle $io, array $results): void
    {
        // Case ids and failure messages are data, not markup: escape them so
        // the console formatter prints tag-like text literally.
        $escape = static fn (mixed $text): string
            => \Symfony\Component\Console\Formatter\OutputFormatter::escape((string) $text);

        foreach ($results as $result) {
            if ($result->passed) {
                $io->writeln(sprintf('<info>PASS</info> %s', $escape($result->caseId)));

                continue;
            }

            $io->writeln(sprintf('<error>FAIL</error> %s', $escape($result->caseId)));

            foreach ($result->failures as $failure) {
                $io->writeln(' - ' . $escape($failure));
            }
        }
    }
}