mirror of https://github.com/nextcloud/server.git
parent
27e1c86652
commit
ebc7631544
@ -0,0 +1,269 @@
|
||||
<?php
|
||||
/**
|
||||
* Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net>
|
||||
* This file is licensed under the Affero General Public License version 3 or
|
||||
* later.
|
||||
* See the COPYING-README file.
|
||||
*/
|
||||
|
||||
namespace Test\LanguageModel;
|
||||
|
||||
use OC\AppFramework\Bootstrap\Coordinator;
|
||||
use OC\LanguageModel\Db\TaskMapper;
|
||||
use OC\LanguageModel\LanguageModelManager;
|
||||
use OC\LanguageModel\TaskBackgroundJob;
|
||||
use OCP\BackgroundJob\IJobList;
|
||||
use OCP\Common\Exception\NotFoundException;
|
||||
use OCP\EventDispatcher\IEventDispatcher;
|
||||
use OCP\IServerContainer;
|
||||
use OCP\LanguageModel\Events\TaskFailedEvent;
|
||||
use OCP\LanguageModel\Events\TaskSuccessfulEvent;
|
||||
use OCP\LanguageModel\FreePromptTask;
|
||||
use OCP\LanguageModel\HeadlineTask;
|
||||
use OCP\LanguageModel\IHeadlineProvider;
|
||||
use OCP\LanguageModel\ILanguageModelManager;
|
||||
use OCP\LanguageModel\ILanguageModelProvider;
|
||||
use OCP\LanguageModel\ILanguageModelTask;
|
||||
use OCP\LanguageModel\ISummaryProvider;
|
||||
use OCP\LanguageModel\SummaryTask;
|
||||
use OCP\LanguageModel\TopicsTask;
|
||||
use OCP\PreConditionNotMetException;
|
||||
use Psr\Log\LoggerInterface;
|
||||
use Test\BackgroundJob\DummyJobList;
|
||||
|
||||
class TestVanillaLanguageModelProvider implements ILanguageModelProvider {
|
||||
public bool $ran = false;
|
||||
|
||||
public function getName(): string {
|
||||
return 'TEST Vanilla LLM Provider';
|
||||
}
|
||||
|
||||
public function prompt(string $prompt): string {
|
||||
$this->ran = true;
|
||||
return $prompt . ' Free Prompt';
|
||||
}
|
||||
}
|
||||
|
||||
class TestFailingLanguageModelProvider implements ILanguageModelProvider {
|
||||
public bool $ran = false;
|
||||
|
||||
public function getName(): string {
|
||||
return 'TEST Vanilla LLM Provider';
|
||||
}
|
||||
|
||||
public function prompt(string $prompt): string {
|
||||
$this->ran = true;
|
||||
throw new \Exception('ERROR');
|
||||
}
|
||||
}
|
||||
|
||||
class TestFullLanguageModelProvider implements ILanguageModelProvider, ISummaryProvider, IHeadlineProvider {
|
||||
public function getName(): string {
|
||||
return 'TEST Full LLM Provider';
|
||||
}
|
||||
|
||||
public function prompt(string $prompt): string {
|
||||
return $prompt . ' Free Prompt';
|
||||
}
|
||||
|
||||
public function findHeadline(string $text): string {
|
||||
return $text . ' Headline';
|
||||
}
|
||||
|
||||
public function summarize(string $text): string {
|
||||
return $text. ' Summarize';
|
||||
}
|
||||
}
|
||||
|
||||
class LanguageModelManagerTest extends \Test\TestCase {
|
||||
private ILanguageModelManager $languageModelManager;
|
||||
private Coordinator $coordinator;
|
||||
|
||||
protected function setUp(): void {
|
||||
parent::setUp();
|
||||
|
||||
$this->languageModelManager = new LanguageModelManager(
|
||||
\OC::$server->get(IServerContainer::class),
|
||||
$this->coordinator = \OC::$server->get(Coordinator::class),
|
||||
\OC::$server->get(LoggerInterface::class),
|
||||
\OC::$server->get(IJobList::class),
|
||||
\OC::$server->get(TaskMapper::class),
|
||||
);
|
||||
}
|
||||
|
||||
public function testShouldNotHaveAnyProviders() {
|
||||
$this->assertCount(0, $this->languageModelManager->getAvailableTasks());
|
||||
$this->assertCount(0, $this->languageModelManager->getAvailableTaskTypes());
|
||||
$this->assertFalse($this->languageModelManager->hasProviders());
|
||||
$this->expectException(PreConditionNotMetException::class);
|
||||
$this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null));
|
||||
}
|
||||
|
||||
public function testProviderShouldBeRegisteredAndRun() {
|
||||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class);
|
||||
$this->assertCount(1, $this->languageModelManager->getAvailableTasks());
|
||||
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
|
||||
$this->assertTrue($this->languageModelManager->hasProviders());
|
||||
$this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null)));
|
||||
|
||||
// Summaries are not implemented by the vanilla provider, only free prompt
|
||||
$this->expectException(PreConditionNotMetException::class);
|
||||
$this->languageModelManager->runTask(new SummaryTask('Hello', 'test', null));
|
||||
}
|
||||
|
||||
public function testProviderShouldBeRegisteredAndScheduled() {
|
||||
// register provider
|
||||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class);
|
||||
$this->assertCount(1, $this->languageModelManager->getAvailableTasks());
|
||||
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
|
||||
$this->assertTrue($this->languageModelManager->hasProviders());
|
||||
|
||||
// create task object
|
||||
$task = new FreePromptTask('Hello', 'test', null);
|
||||
$this->assertNull($task->getId());
|
||||
$this->assertNull($task->getOutput());
|
||||
|
||||
// schedule works
|
||||
$this->assertEquals(ILanguageModelTask::STATUS_UNKNOWN, $task->getStatus());
|
||||
$this->languageModelManager->scheduleTask($task);
|
||||
|
||||
// Task object is up-to-date
|
||||
$this->assertNotNull($task->getId());
|
||||
$this->assertNull($task->getOutput());
|
||||
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task->getStatus());
|
||||
|
||||
// Task object retrieved from db is up-to-date
|
||||
$task2 = $this->languageModelManager->getTask($task->getId());
|
||||
$this->assertEquals($task->getId(), $task2->getId());
|
||||
$this->assertEquals('Hello', $task2->getInput());
|
||||
$this->assertNull($task2->getOutput());
|
||||
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus());
|
||||
|
||||
/** @var IEventDispatcher $eventDispatcher */
|
||||
$eventDispatcher = \OC::$server->get(IEventDispatcher::class);
|
||||
$successfulEventFired = false;
|
||||
$eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) {
|
||||
$successfulEventFired = true;
|
||||
$t = $event->getTask();
|
||||
$this->assertEquals($task->getId(), $t->getId());
|
||||
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $t->getStatus());
|
||||
$this->assertEquals('Hello Free Prompt', $t->getOutput());
|
||||
});
|
||||
$failedEventFired = false;
|
||||
$eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) {
|
||||
$failedEventFired = true;
|
||||
$t = $event->getTask();
|
||||
$this->assertEquals($task->getId(), $t->getId());
|
||||
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $t->getStatus());
|
||||
$this->assertEquals('ERROR', $event->getErrorMessage());
|
||||
});
|
||||
|
||||
// run background job
|
||||
/** @var TaskBackgroundJob $bgJob */
|
||||
$bgJob = \OC::$server->get(TaskBackgroundJob::class);
|
||||
$bgJob->setArgument(['taskId' => $task->getId()]);
|
||||
$bgJob->start(new DummyJobList());
|
||||
$provider = \OC::$server->get(TestVanillaLanguageModelProvider::class);
|
||||
$this->assertTrue($provider->ran);
|
||||
$this->assertTrue($successfulEventFired);
|
||||
$this->assertFalse($failedEventFired);
|
||||
|
||||
// Task object retrieved from db is up-to-date
|
||||
$task3 = $this->languageModelManager->getTask($task->getId());
|
||||
$this->assertEquals($task->getId(), $task3->getId());
|
||||
$this->assertEquals('Hello', $task3->getInput());
|
||||
$this->assertEquals('Hello Free Prompt', $task3->getOutput());
|
||||
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $task2->getStatus());
|
||||
}
|
||||
|
||||
public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly() {
|
||||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class);
|
||||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFullLanguageModelProvider::class);
|
||||
$this->assertCount(3, $this->languageModelManager->getAvailableTasks());
|
||||
$this->assertCount(3, $this->languageModelManager->getAvailableTaskTypes());
|
||||
$this->assertTrue($this->languageModelManager->hasProviders());
|
||||
|
||||
// Try free prompt again
|
||||
$this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null)));
|
||||
|
||||
// Try headline task
|
||||
$this->assertEquals('Hello Headline', $this->languageModelManager->runTask(new HeadlineTask('Hello', 'test', null)));
|
||||
|
||||
// Try summary task
|
||||
$this->assertEquals('Hello Summarize', $this->languageModelManager->runTask(new SummaryTask('Hello', 'test', null)));
|
||||
|
||||
// Topics are not implemented by both the vanilla provider and the full provider
|
||||
$this->expectException(PreConditionNotMetException::class);
|
||||
$this->languageModelManager->runTask(new TopicsTask('Hello', 'test', null));
|
||||
}
|
||||
|
||||
public function testNonexistentTask() {
|
||||
$this->expectException(NotFoundException::class);
|
||||
$this->languageModelManager->getTask(98765432456);
|
||||
}
|
||||
|
||||
public function testTaskFailure() {
|
||||
// register provider
|
||||
$this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFailingLanguageModelProvider::class);
|
||||
$this->assertCount(1, $this->languageModelManager->getAvailableTasks());
|
||||
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
|
||||
$this->assertTrue($this->languageModelManager->hasProviders());
|
||||
|
||||
// create task object
|
||||
$task = new FreePromptTask('Hello', 'test', null);
|
||||
$this->assertNull($task->getId());
|
||||
$this->assertNull($task->getOutput());
|
||||
|
||||
// schedule works
|
||||
$this->assertEquals(ILanguageModelTask::STATUS_UNKNOWN, $task->getStatus());
|
||||
$this->languageModelManager->scheduleTask($task);
|
||||
|
||||
// Task object is up-to-date
|
||||
$this->assertNotNull($task->getId());
|
||||
$this->assertNull($task->getOutput());
|
||||
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task->getStatus());
|
||||
|
||||
// Task object retrieved from db is up-to-date
|
||||
$task2 = $this->languageModelManager->getTask($task->getId());
|
||||
$this->assertEquals($task->getId(), $task2->getId());
|
||||
$this->assertEquals('Hello', $task2->getInput());
|
||||
$this->assertNull($task2->getOutput());
|
||||
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus());
|
||||
|
||||
/** @var IEventDispatcher $eventDispatcher */
|
||||
$eventDispatcher = \OC::$server->get(IEventDispatcher::class);
|
||||
$successfulEventFired = false;
|
||||
$eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) {
|
||||
$successfulEventFired = true;
|
||||
$t = $event->getTask();
|
||||
$this->assertEquals($task->getId(), $t->getId());
|
||||
$this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $t->getStatus());
|
||||
$this->assertEquals('Hello Free Prompt', $t->getOutput());
|
||||
});
|
||||
$failedEventFired = false;
|
||||
$eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) {
|
||||
$failedEventFired = true;
|
||||
$t = $event->getTask();
|
||||
$this->assertEquals($task->getId(), $t->getId());
|
||||
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $t->getStatus());
|
||||
$this->assertEquals('ERROR', $event->getErrorMessage());
|
||||
});
|
||||
|
||||
// run background job
|
||||
/** @var TaskBackgroundJob $bgJob */
|
||||
$bgJob = \OC::$server->get(TaskBackgroundJob::class);
|
||||
$bgJob->setArgument(['taskId' => $task->getId()]);
|
||||
$bgJob->start(new DummyJobList());
|
||||
$provider = \OC::$server->get(TestFailingLanguageModelProvider::class);
|
||||
$this->assertTrue($provider->ran);
|
||||
$this->assertTrue($failedEventFired);
|
||||
$this->assertFalse($successfulEventFired);
|
||||
|
||||
// Task object retrieved from db is up-to-date
|
||||
$task3 = $this->languageModelManager->getTask($task->getId());
|
||||
$this->assertEquals($task->getId(), $task3->getId());
|
||||
$this->assertEquals('Hello', $task3->getInput());
|
||||
$this->assertNull($task3->getOutput());
|
||||
$this->assertEquals(ILanguageModelTask::STATUS_FAILED, $task2->getStatus());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue