Improve Templates (#183)

* Improve Templates

* Fix test case

* Update AI GenerateTemplate

* update openai client and GPT completer

* composer.lock

* Update types and list json with script

* Template changes

* fix on draft template

* Finish opnform templates

---------

Co-authored-by: Forms Dev <chirag+new@notionforms.io>
Co-authored-by: Julien Nahum <julien@nahum.net>
This commit is contained in:
Chirag Chhatrala
2023-09-08 16:30:28 +05:30
committed by GitHub
parent d93eca7410
commit 8e47b49e9a
36 changed files with 3130 additions and 1381 deletions

View File

@@ -14,7 +14,7 @@ use OpenAI\Exceptions\ErrorException;
*/
class GptCompleter
{
const AI_MODEL = 'gpt-3.5-turbo';
const AI_MODEL = 'gpt-4';
protected Client $openAi;
protected mixed $result;
@@ -22,19 +22,32 @@ class GptCompleter
protected ?string $systemMessage;
protected int $tokenUsed = 0;
protected bool $useStreaming = false;
public function __construct(string $apiKey, protected int $retries = 2)
public function __construct(string $apiKey, protected int $retries = 2, protected string $model = self::AI_MODEL)
{
$this->openAi = \OpenAI::client($apiKey);
}
public function setAiModel(string $model)
{
$this->model = $model;
return $this;
}
public function setSystemMessage(string $systemMessage): self
{
$this->systemMessage = $systemMessage;
return $this;
}
public function completeChat(array $messages, int $maxTokens = 512, float $temperature = 0.81): self
public function useStreaming(): self
{
$this->useStreaming = true;
return $this;
}
public function completeChat(array $messages, int $maxTokens = 4096, float $temperature = 0.81): self
{
$this->computeChatCompletion($messages, $maxTokens, $temperature)
->queryCompletion();
@@ -56,30 +69,40 @@ class GptCompleter
public function getArray(): array
{
$payload = Str::of($this->result)->trim();
if ($payload->contains('```json')) {
$payload = $payload->after('```json')->before('```');
} else if ($payload->contains('```')) {
$payload = $payload->after('```')->before('```');
}
$payload = $payload->toString();
$exception = null;
for ($i = 0; $i < $this->retries; $i++) {
$payload = Str::of($this->result)->trim();
if ($payload->contains('```json')) {
$payload = $payload->after('```json')->before('```');
} else if ($payload->contains('```')) {
$payload = $payload->after('```')->before('```');
}
$payload = $payload->toString();
$exception = null;
try {
$payload = (new JsonFixer)->fix($payload);
return json_decode($payload, true);
$newPayload = (new JsonFixer)->fix($payload);
return json_decode($newPayload, true);
} catch (\Aws\Exception\InvalidJsonException $e) {
$exception = $e;
Log::warning("Invalid JSON, retrying:");
Log::warning($payload);
Log::warning(json_encode($this->completionInput));
$this->queryCompletion();
}
}
throw $exception;
}
public function getHtml(): string
{
$payload = Str::of($this->result)->trim();
if ($payload->contains('```html')) {
$payload = $payload->after('```html')->before('```');
} else if ($payload->contains('```')) {
$payload = $payload->after('```')->before('```');
}
return $payload->toString();
}
public function getString(): string
{
return trim($this->result);
@@ -90,7 +113,7 @@ class GptCompleter
return $this->tokenUsed;
}
protected function computeChatCompletion(array $messages, int $maxTokens = 512, float $temperature = 0.81): self
protected function computeChatCompletion(array $messages, int $maxTokens = 4096, float $temperature = 0.81): self
{
if (isset($this->systemMessage) && $messages[0]['role'] !== 'system') {
$messages = array_merge([[
@@ -100,7 +123,7 @@ class GptCompleter
}
$completionInput = [
'model' => self::AI_MODEL,
'model' => $this->model,
'messages' => $messages,
'max_tokens' => $maxTokens,
'temperature' => $temperature
@@ -110,7 +133,12 @@ class GptCompleter
return $this;
}
protected function queryCompletion(): self {
protected function queryCompletion(): self
{
if ($this->useStreaming) {
return $this->queryStreamedCompletion();
}
try {
Log::debug("Open AI query: " . json_encode($this->completionInput));
$response = $this->openAi->chat()->create($this->completionInput);
@@ -123,4 +151,19 @@ class GptCompleter
$this->result = $response->choices[0]->message->content;
return $this;
}
protected function queryStreamedCompletion(): self
{
Log::debug("Open AI query: " . json_encode($this->completionInput));
$this->result = '';
$response = $this->openAi->chat()->createStreamed($this->completionInput);
foreach ($response as $chunk) {
$choice = $chunk->choices[0];
if (is_null($choice->delta->role)) {
$this->result .= $choice->delta->content;
}
}
return $this;
}
}

View File

@@ -95,6 +95,7 @@ class JsonFixer
*/
public function fix($json)
{
$json = preg_replace('/(?<!\\\\)(?:\\\\{2})*\p{C}+/u', '', $json);
list($head, $json, $tail) = $this->trim($json);
if (empty($json) || $this->isValid($json)) {
@@ -124,7 +125,7 @@ class JsonFixer
protected function isValid($json)
{
\json_decode($json);
\json_decode($json,true,512,JSON_INVALID_UTF8_SUBSTITUTE);
return \JSON_ERROR_NONE === \json_last_error();
}
@@ -265,6 +266,10 @@ class JsonFixer
return $json;
}
\Log::debug('Broken json received: ', [
'json' => $json
]);
throw new InvalidJsonException(
\sprintf('Could not fix JSON (tried padding `%s`)', \substr($tmpJson, $length), $json)
);