diff --git a/go-gpt3/.github/workflows/pr.yml b/go-gpt3/.github/workflows/pr.yml new file mode 100644 index 0000000..9a1b3ac --- /dev/null +++ b/go-gpt3/.github/workflows/pr.yml @@ -0,0 +1,27 @@ +name: PR sanity check + +on: + - push + - pull_request + +jobs: + prcheck: + name: PR sanity check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Setup Go + uses: actions/setup-go@v2 + with: + go-version: '1.19' + - name: Run vet + run: | + go vet . + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: latest + - name: Run tests + run: go test -race -covermode=atomic -coverprofile=coverage.out -v . + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 diff --git a/go-gpt3/README.md b/go-gpt3/README.md index 7f8b016..e6e352e 100644 --- a/go-gpt3/README.md +++ b/go-gpt3/README.md @@ -1,17 +1,65 @@ -# go-gpt3 -[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-gpt3) -[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-gpt3)](https://goreportcard.com/report/github.com/sashabaranov/go-gpt3) +# Go OpenAI +[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-openai) +[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai) +[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai) +> **Note**: the repository was recently renamed from `go-gpt3` to `go-openai` -[OpenAI ChatGPT and GPT-3](https://platform.openai.com/) API client for Go +This library provides Go clients for [OpenAI API](https://platform.openai.com/). We support: + +* ChatGPT +* GPT-3, GPT-4 +* DALL·E 2 +* Whisper Installation: ``` -go get github.com/sashabaranov/go-gpt3 +go get github.com/sashabaranov/go-openai +``` + + +ChatGPT example usage: + +```go +package main + +import ( + "context" + "fmt" + openai "github.com/sashabaranov/go-openai" +) + +func main() { + client := openai.NewClient("your token") + resp, err := client.CreateChatCompletion( + context.Background(), + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + ) + + if err != nil { + fmt.Printf("ChatCompletion error: %v\n", err) + return + } + + fmt.Println(resp.Choices[0].Message.Content) +} + ``` -Example usage: + +Other examples: + +
+GPT-3 completion ```go package main @@ -19,27 +67,30 @@ package main import ( "context" "fmt" - gogpt "github.com/sashabaranov/go-gpt3" + openai "github.com/sashabaranov/go-openai" ) func main() { - c := gogpt.NewClient("your token") + c := openai.NewClient("your token") ctx := context.Background() - req := gogpt.CompletionRequest{ - Model: gogpt.GPT3Ada, + req := openai.CompletionRequest{ + Model: openai.GPT3Ada, MaxTokens: 5, Prompt: "Lorem ipsum", } resp, err := c.CreateCompletion(ctx, req) if err != nil { + fmt.Printf("Completion error: %v\n", err) return } fmt.Println(resp.Choices[0].Text) } ``` +
-Streaming response example: +
+GPT-3 streaming completion ```go package main @@ -49,21 +100,22 @@ import ( "context" "fmt" "io" - gogpt "github.com/sashabaranov/go-gpt3" + openai "github.com/sashabaranov/go-openai" ) func main() { - c := gogpt.NewClient("your token") + c := openai.NewClient("your token") ctx := context.Background() - req := gogpt.CompletionRequest{ - Model: gogpt.GPT3Ada, + req := openai.CompletionRequest{ + Model: openai.GPT3Ada, MaxTokens: 5, Prompt: "Lorem ipsum", Stream: true, } stream, err := c.CreateCompletionStream(ctx, req) if err != nil { + fmt.Printf("CompletionStream error: %v\n", err) return } defer stream.Close() @@ -85,3 +137,194 @@ func main() { } } ``` +
+ +
+Audio Speech-To-Text + +```go +package main + +import ( + "context" + "fmt" + + openai "github.com/sashabaranov/go-openai" +) + +func main() { + c := openai.NewClient("your token") + ctx := context.Background() + + req := openai.AudioRequest{ + Model: openai.Whisper1, + FilePath: "recording.mp3", + } + resp, err := c.CreateTranscription(ctx, req) + if err != nil { + fmt.Printf("Transcription error: %v\n", err) + return + } + fmt.Println(resp.Text) +} +``` +
+ +
+DALL-E 2 image generation + +```go +package main + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + openai "github.com/sashabaranov/go-openai" + "image/png" + "os" +) + +func main() { + c := openai.NewClient("your token") + ctx := context.Background() + + // Sample image by link + reqUrl := openai.ImageRequest{ + Prompt: "Parrot on a skateboard performs a trick, cartoon style, natural light, high detail", + Size: openai.CreateImageSize256x256, + ResponseFormat: openai.CreateImageResponseFormatURL, + N: 1, + } + + respUrl, err := c.CreateImage(ctx, reqUrl) + if err != nil { + fmt.Printf("Image creation error: %v\n", err) + return + } + fmt.Println(respUrl.Data[0].URL) + + // Example image as base64 + reqBase64 := openai.ImageRequest{ + Prompt: "Portrait of a humanoid parrot in a classic costume, high detail, realistic light, unreal engine", + Size: openai.CreateImageSize256x256, + ResponseFormat: openai.CreateImageResponseFormatB64JSON, + N: 1, + } + + respBase64, err := c.CreateImage(ctx, reqBase64) + if err != nil { + fmt.Printf("Image creation error: %v\n", err) + return + } + + imgBytes, err := base64.StdEncoding.DecodeString(respBase64.Data[0].B64JSON) + if err != nil { + fmt.Printf("Base64 decode error: %v\n", err) + return + } + + r := bytes.NewReader(imgBytes) + imgData, err := png.Decode(r) + if err != nil { + fmt.Printf("PNG decode error: %v\n", err) + return + } + + file, err := os.Create("image.png") + if err != nil { + fmt.Printf("File creation error: %v\n", err) + return + } + defer file.Close() + + if err := png.Encode(file, imgData); err != nil { + fmt.Printf("PNG encode error: %v\n", err) + return + } + + fmt.Println("The image was saved as example.png") +} + +``` +
+ +
+Configuring proxy + +```go +config := openai.DefaultConfig("token") +proxyUrl, err := url.Parse("http://localhost:{port}") +if err != nil { + panic(err) +} +transport := &http.Transport{ + Proxy: http.ProxyURL(proxyUrl), +} +config.HTTPClient = &http.Client{ + Transport: transport, +} + +c := openai.NewClientWithConfig(config) +``` + +See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig +
+ +
+ChatGPT support context + +```go +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + "github.com/sashabaranov/go-openai" +) + +func main() { + client := openai.NewClient("your token") + messages := make([]openai.ChatCompletionMessage, 0) + reader := bufio.NewReader(os.Stdin) + fmt.Println("Conversation") + fmt.Println("---------------------") + + for { + fmt.Print("-> ") + text, _ := reader.ReadString('\n') + // convert CRLF to LF + text = strings.Replace(text, "\n", "", -1) + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: text, + }) + + resp, err := client.CreateChatCompletion( + context.Background(), + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: messages, + }, + ) + + if err != nil { + fmt.Printf("ChatCompletion error: %v\n", err) + continue + } + + content := resp.Choices[0].Message.Content + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: content, + }) + fmt.Println(content) + } +} +``` +
\ No newline at end of file diff --git a/go-gpt3/answers.go b/go-gpt3/answers.go deleted file mode 100644 index 5a99078..0000000 --- a/go-gpt3/answers.go +++ /dev/null @@ -1,50 +0,0 @@ -package gogpt - -import ( - "bytes" - "context" - "encoding/json" - "net/http" -) - -type AnswerRequest struct { - Documents []string `json:"documents,omitempty"` - File string `json:"file,omitempty"` - Question string `json:"question"` - SearchModel string `json:"search_model,omitempty"` - Model string `json:"model"` - ExamplesContext string `json:"examples_context"` - Examples [][]string `json:"examples"` - MaxTokens int `json:"max_tokens,omitempty"` - Stop []string `json:"stop,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` -} - -type AnswerResponse struct { - Answers []string `json:"answers"` - Completion string `json:"completion"` - Model string `json:"model"` - Object string `json:"object"` - SearchModel string `json:"search_model"` - SelectedDocuments []struct { - Document int `json:"document"` - Text string `json:"text"` - } `json:"selected_documents"` -} - -// Search — perform a semantic search api call over a list of documents. -func (c *Client) Answers(ctx context.Context, request AnswerRequest) (response AnswerResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) - if err != nil { - return - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/answers"), bytes.NewBuffer(reqBytes)) - if err != nil { - return - } - - err = c.sendRequest(req, &response) - return -} diff --git a/go-gpt3/api.go b/go-gpt3/api.go index 715c9dd..00d6d35 100644 --- a/go-gpt3/api.go +++ b/go-gpt3/api.go @@ -1,6 +1,7 @@ -package gogpt +package openai import ( + "context" "encoding/json" "fmt" "net/http" @@ -9,17 +10,22 @@ import ( // Client is OpenAI GPT-3 API client. type Client struct { config ClientConfig + + requestBuilder requestBuilder } // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { config := DefaultConfig(authToken) - return &Client{config} + return NewClientWithConfig(config) } // NewClientWithConfig creates new OpenAI API client for specified config. func NewClientWithConfig(config ClientConfig) *Client { - return &Client{config} + return &Client{ + config: config, + requestBuilder: newRequestBuilder(), + } } // NewOrgClient creates new OpenAI API client for specified Organization ID. @@ -28,7 +34,7 @@ func NewClientWithConfig(config ClientConfig) *Client { func NewOrgClient(authToken, org string) *Client { config := DefaultConfig(authToken) config.OrgID = org - return &Client{config} + return NewClientWithConfig(config) } func (c *Client) sendRequest(req *http.Request, v interface{}) error { @@ -68,7 +74,7 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { } if v != nil { - if err = json.NewDecoder(res.Body).Decode(&v); err != nil { + if err = json.NewDecoder(res.Body).Decode(v); err != nil { return err } } @@ -79,3 +85,22 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { func (c *Client) fullURL(suffix string) string { return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } + +func (c *Client) newStreamRequest( + ctx context.Context, + method string, + urlSuffix string, + body any) (*http.Request, error) { + req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + + return req, nil +} diff --git a/go-gpt3/api_test.go b/go-gpt3/api_test.go new file mode 100644 index 0000000..202ec94 --- /dev/null +++ b/go-gpt3/api_test.go @@ -0,0 +1,182 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + + "context" + "errors" + "io" + "os" + "testing" +) + +func TestAPI(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := NewClient(apiToken) + ctx := context.Background() + _, err = c.ListEngines(ctx) + if err != nil { + t.Fatalf("ListEngines error: %v", err) + } + + _, err = c.GetEngine(ctx, "davinci") + if err != nil { + t.Fatalf("GetEngine error: %v", err) + } + + fileRes, err := c.ListFiles(ctx) + if err != nil { + t.Fatalf("ListFiles error: %v", err) + } + + if len(fileRes.Files) > 0 { + _, err = c.GetFile(ctx, fileRes.Files[0].ID) + if err != nil { + t.Fatalf("GetFile error: %v", err) + } + } // else skip + + embeddingReq := EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: AdaSearchQuery, + } + _, err = c.CreateEmbeddings(ctx, embeddingReq) + if err != nil { + t.Fatalf("Embedding error: %v", err) + } + + _, err = c.CreateChatCompletion( + ctx, + ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + ) + + if err != nil { + t.Errorf("CreateChatCompletion (without name) returned error: %v", err) + } + + _, err = c.CreateChatCompletion( + ctx, + ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Name: "John_Doe", + Content: "Hello!", + }, + }, + }, + ) + + if err != nil { + t.Errorf("CreateChatCompletion (with name) returned error: %v", err) + } + + stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: GPT3Ada, + MaxTokens: 5, + Stream: true, + }) + if err != nil { + t.Errorf("CreateCompletionStream returned error: %v", err) + } + defer stream.Close() + + counter := 0 + for { + _, err = stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Errorf("Stream error: %v", err) + } else { + counter++ + } + } + if counter == 0 { + t.Error("Stream did not return any responses") + } +} + +func TestAPIError(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := NewClient(apiToken + "_invalid") + ctx := context.Background() + _, err = c.ListEngines(ctx) + if err == nil { + t.Fatal("ListEngines did not fail") + } + + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("Error is not an APIError: %+v", err) + } + + if apiErr.StatusCode != 401 { + t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode) + } + if *apiErr.Code != "invalid_api_key" { + t.Fatalf("Unexpected API error code: %s", *apiErr.Code) + } + if apiErr.Error() == "" { + t.Fatal("Empty error message occured") + } +} + +func TestRequestError(t *testing.T) { + var err error + + config := DefaultConfig("dummy") + config.BaseURL = "https://httpbin.org/status/418?" + c := NewClientWithConfig(config) + ctx := context.Background() + _, err = c.ListEngines(ctx) + if err == nil { + t.Fatal("ListEngines request did not fail") + } + + var reqErr *RequestError + if !errors.As(err, &reqErr) { + t.Fatalf("Error is not a RequestError: %+v", err) + } + + if reqErr.StatusCode != 418 { + t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode) + } + + if reqErr.Unwrap() == nil { + t.Fatalf("Empty request error occured") + } +} + +// numTokens Returns the number of GPT-3 encoded tokens in the given text. +// This function approximates based on the rule of thumb stated by OpenAI: +// https://beta.openai.com/tokenizer +// +// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) +func numTokens(s string) int { + return int(float32(len(s)) / 4) +} diff --git a/go-gpt3/audio.go b/go-gpt3/audio.go index 0dc611e..54bd32f 100644 --- a/go-gpt3/audio.go +++ b/go-gpt3/audio.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "bytes" diff --git a/go-gpt3/audio_test.go b/go-gpt3/audio_test.go new file mode 100644 index 0000000..2a035c9 --- /dev/null +++ b/go-gpt3/audio_test.go @@ -0,0 +1,143 @@ +package openai_test + +import ( + "bytes" + "errors" + "io" + "mime" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "strings" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "testing" +) + +// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. +func TestAudio(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) + server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + + testcases := []struct { + name string + createFn func(context.Context, AudioRequest) (AudioResponse, error) + }{ + { + "transcribe", + client.CreateTranscription, + }, + { + "translate", + client.CreateTranslation, + }, + } + + ctx := context.Background() + + dir, cleanup := createTestDirectory(t) + defer cleanup() + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + path := filepath.Join(dir, "fake.mp3") + createTestFile(t, path) + + req := AudioRequest{ + FilePath: path, + Model: "whisper-3", + } + _, err = tc.createFn(ctx, req) + if err != nil { + t.Fatalf("audio API error: %v", err) + } + }) + } +} + +// createTestFile creates a fake file with "hello" as the content. +func createTestFile(t *testing.T, path string) { + file, err := os.Create(path) + if err != nil { + t.Fatalf("failed to create file %v", err) + } + if _, err = file.WriteString("hello"); err != nil { + t.Fatalf("failed to write to file %v", err) + } + file.Close() +} + +// createTestDirectory creates a temporary folder which will be deleted when cleanup is called. +func createTestDirectory(t *testing.T) (path string, cleanup func()) { + t.Helper() + + path, err := os.MkdirTemp(os.TempDir(), "") + if err != nil { + t.Fatal(err) + } + + return path, func() { os.RemoveAll(path) } +} + +// handleAudioEndpoint Handles the completion endpoint by the test server. +func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + + // audio endpoints only accept POST requests + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } + + mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + http.Error(w, "failed to parse media type", http.StatusBadRequest) + return + } + + if !strings.HasPrefix(mediaType, "multipart") { + http.Error(w, "request is not multipart", http.StatusBadRequest) + } + + boundary, ok := params["boundary"] + if !ok { + http.Error(w, "no boundary in params", http.StatusBadRequest) + return + } + + fileData := &bytes.Buffer{} + mr := multipart.NewReader(r.Body, boundary) + part, err := mr.NextPart() + if err != nil && errors.Is(err, io.EOF) { + http.Error(w, "error accessing file", http.StatusBadRequest) + return + } + if _, err = io.Copy(fileData, part); err != nil { + http.Error(w, "failed to copy file", http.StatusInternalServerError) + return + } + + if len(fileData.Bytes()) == 0 { + w.WriteHeader(http.StatusInternalServerError) + http.Error(w, "received empty file data", http.StatusBadRequest) + return + } + + if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { + http.Error(w, "failed to write body", http.StatusInternalServerError) + return + } +} diff --git a/go-gpt3/chat.go b/go-gpt3/chat.go index 7ac1121..99edfe8 100644 --- a/go-gpt3/chat.go +++ b/go-gpt3/chat.go @@ -1,15 +1,18 @@ -package gogpt +package openai import ( - "bufio" - "bytes" "context" - "encoding/json" "errors" - "fmt" "net/http" ) +// Chat message role defined by the OpenAI API. +const ( + ChatMessageRoleSystem = "system" + ChatMessageRoleUser = "user" + ChatMessageRoleAssistant = "assistant" +) + var ( ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported") ) @@ -17,6 +20,12 @@ var ( type ChatCompletionMessage struct { Role string `json:"role"` Content string `json:"content"` + + // This property isn't in the official documentation, but it's in + // the documentation for the official library for python: + // - https://github.com/openai/openai-python/blob/main/chatml.md + // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + Name string `json:"name,omitempty"` } // ChatCompletionRequest represents a request structure for chat completion API. @@ -41,24 +50,6 @@ type ChatCompletionChoice struct { FinishReason string `json:"finish_reason"` } -type ChatCompletionChoiceForStream struct { - Index int `json:"index"` - Delta struct { - Content string `json:"content"` - } `json:"delta"` - FinishReason string `json:"finish_reason"` -} - -// ChatCompletionResponseForStream represents a response structure for chat completion API. -type ChatCompletionResponseForStream struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionChoiceForStream `json:"choices"` - Usage Usage `json:"usage"` -} - // ChatCompletionResponse represents a response structure for chat completion API. type ChatCompletionResponse struct { ID string `json:"id"` @@ -69,25 +60,21 @@ type ChatCompletionResponse struct { Usage Usage `json:"usage"` } -// CreateChatCompletion — API call to Creates a completion for the chat message. +// CreateChatCompletion — API call to Create a completion for the chat message. func (c *Client) CreateChatCompletion( ctx context.Context, request ChatCompletionRequest, ) (response ChatCompletionResponse, err error) { model := request.Model - if model != GPT3Dot5Turbo0301 && model != GPT3Dot5Turbo { + switch model { + case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K: + default: err = ErrChatCompletionInvalidModel return } - var reqBytes []byte - reqBytes, err = json.Marshal(request) - if err != nil { - return - } - urlSuffix := "/chat/completions" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return } @@ -95,44 +82,3 @@ func (c *Client) CreateChatCompletion( err = c.sendRequest(req, &response) return } - -// CreateChatCompletionStream — API call to create a completion w/ streaming -func (c *Client) CreateChatCompletionStream( - ctx context.Context, - request ChatCompletionRequest, -) (stream *CompletionStream, err error) { - model := request.Model - if model != GPT3Dot5Turbo0301 && model != GPT3Dot5Turbo { - err = ErrChatCompletionInvalidModel - return - } - request.Stream = true - reqBytes, err := json.Marshal(request) - if err != nil { - return - } - - urlSuffix := "/chat/completions" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) - if err != nil { - return - } - - resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() - if err != nil { - return - } - - stream = &CompletionStream{ - emptyMessagesLimit: c.config.EmptyMessagesLimit, - - reader: bufio.NewReader(resp.Body), - response: resp, - } - return -} diff --git a/go-gpt3/chat_stream.go b/go-gpt3/chat_stream.go new file mode 100644 index 0000000..26e964c --- /dev/null +++ b/go-gpt3/chat_stream.go @@ -0,0 +1,106 @@ +package openai + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "io" + "net/http" +) + +type ChatCompletionStreamChoiceDelta struct { + Content string `json:"content"` +} + +type ChatCompletionStreamChoice struct { + Index int `json:"index"` + Delta ChatCompletionStreamChoiceDelta `json:"delta"` + FinishReason string `json:"finish_reason"` +} + +type ChatCompletionStreamResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionStreamChoice `json:"choices"` +} + +// ChatCompletionStream +// Note: Perhaps it is more elegant to abstract Stream using generics. +type ChatCompletionStream struct { + emptyMessagesLimit uint + isFinished bool + + reader *bufio.Reader + response *http.Response +} + +func (stream *ChatCompletionStream) Recv() (response ChatCompletionStreamResponse, err error) { + if stream.isFinished { + err = io.EOF + return + } + + var emptyMessagesCount uint + +waitForData: + line, err := stream.reader.ReadBytes('\n') + if err != nil { + return + } + + var headerData = []byte("data: ") + line = bytes.TrimSpace(line) + if !bytes.HasPrefix(line, headerData) { + emptyMessagesCount++ + if emptyMessagesCount > stream.emptyMessagesLimit { + err = ErrTooManyEmptyStreamMessages + return + } + + goto waitForData + } + + line = bytes.TrimPrefix(line, headerData) + if string(line) == "[DONE]" { + stream.isFinished = true + err = io.EOF + return + } + + err = json.Unmarshal(line, &response) + return +} + +func (stream *ChatCompletionStream) Close() { + stream.response.Body.Close() +} + +// CreateChatCompletionStream — API call to create a chat completion w/ streaming +// support. It sets whether to stream back partial progress. If set, tokens will be +// sent as data-only server-sent events as they become available, with the +// stream terminated by a data: [DONE] message. +func (c *Client) CreateChatCompletionStream( + ctx context.Context, + request ChatCompletionRequest, +) (stream *ChatCompletionStream, err error) { + request.Stream = true + req, err := c.newStreamRequest(ctx, "POST", "/chat/completions", request) + if err != nil { + return + } + + resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + if err != nil { + return + } + + stream = &ChatCompletionStream{ + emptyMessagesLimit: c.config.EmptyMessagesLimit, + reader: bufio.NewReader(resp.Body), + response: resp, + } + return +} diff --git a/go-gpt3/chat_stream_test.go b/go-gpt3/chat_stream_test.go new file mode 100644 index 0000000..e3da2da --- /dev/null +++ b/go-gpt3/chat_stream_test.go @@ -0,0 +1,153 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestCreateChatCompletionStream(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + if err != nil { + t.Errorf("Write error: %s", err) + } + })) + defer server.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" + config.HTTPClient.Transport = &tokenRoundTripper{ + test.GetTestToken(), + http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + } + + stream, err := client.CreateChatCompletionStream(ctx, request) + if err != nil { + t.Errorf("CreateCompletionStream returned error: %v", err) + } + defer stream.Close() + + expectedResponses := []ChatCompletionStreamResponse{ + { + ID: "1", + Object: "completion", + Created: 1598069254, + Model: GPT3Dot5Turbo, + Choices: []ChatCompletionStreamChoice{ + { + Delta: ChatCompletionStreamChoiceDelta{ + Content: "response1", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "2", + Object: "completion", + Created: 1598069255, + Model: GPT3Dot5Turbo, + Choices: []ChatCompletionStreamChoice{ + { + Delta: ChatCompletionStreamChoiceDelta{ + Content: "response2", + }, + FinishReason: "max_tokens", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + if streamErr != nil { + t.Errorf("stream.Recv() failed: %v", streamErr) + } + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } + + _, streamErr = stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) + } +} + +// Helper funcs. +func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { + if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { + return false + } + if len(r1.Choices) != len(r2.Choices) { + return false + } + for i := range r1.Choices { + if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) { + return false + } + } + return true +} + +func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool { + if c1.Index != c2.Index { + return false + } + if c1.Delta.Content != c2.Delta.Content { + return false + } + if c1.FinishReason != c2.FinishReason { + return false + } + return true +} diff --git a/go-gpt3/chat_test.go b/go-gpt3/chat_test.go new file mode 100644 index 0000000..5c03ebf --- /dev/null +++ b/go-gpt3/chat_test.go @@ -0,0 +1,132 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "testing" + "time" +) + +func TestChatCompletionsWrongModel(t *testing.T) { + config := DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + req := ChatCompletionRequest{ + MaxTokens: 5, + Model: "ada", + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + } + _, err := client.CreateChatCompletion(ctx, req) + if !errors.Is(err, ErrChatCompletionInvalidModel) { + t.Fatalf("CreateChatCompletion should return wrong model error, but returned: %v", err) + } +} + +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestChatCompletions(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + req := ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + } + _, err = client.CreateChatCompletion(ctx, req) + if err != nil { + t.Fatalf("CreateChatCompletion error: %v", err) + } +} + +// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. +func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var completionReq ChatCompletionRequest + if completionReq, err = getChatCompletionBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := ChatCompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + // would be nice to validate Model during testing, but + // this may not be possible with how much upkeep + // would be required / wouldn't make much sense + Model: completionReq.Model, + } + // create completions + for i := 0; i < completionReq.N; i++ { + // generate a random string of length completionReq.Length + completionStr := strings.Repeat("a", completionReq.MaxTokens) + + res.Choices = append(res.Choices, ChatCompletionChoice{ + Message: ChatCompletionMessage{ + Role: ChatMessageRoleAssistant, + Content: completionStr, + }, + Index: i, + }) + } + inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N + completionTokens := completionReq.MaxTokens * completionReq.N + res.Usage = Usage{ + PromptTokens: inputTokens, + CompletionTokens: completionTokens, + TotalTokens: inputTokens + completionTokens, + } + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// getChatCompletionBody Returns the body of the request to create a completion. +func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { + completion := ChatCompletionRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return ChatCompletionRequest{}, err + } + err = json.Unmarshal(reqBody, &completion) + if err != nil { + return ChatCompletionRequest{}, err + } + return completion, nil +} diff --git a/go-gpt3/common.go b/go-gpt3/common.go index 9fb0178..3b555a7 100644 --- a/go-gpt3/common.go +++ b/go-gpt3/common.go @@ -1,5 +1,5 @@ // common.go defines common types used throughout the OpenAI API. -package gogpt +package openai // Usage Represents the total token usage per request to OpenAI. type Usage struct { diff --git a/go-gpt3/completion.go b/go-gpt3/completion.go index 74762c9..66b4866 100644 --- a/go-gpt3/completion.go +++ b/go-gpt3/completion.go @@ -1,17 +1,24 @@ -package gogpt +package openai import ( - "bytes" "context" - "encoding/json" + "errors" "net/http" ) +var ( + ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll +) + // GPT3 Defines the models provided by OpenAI to use when generating // completions from OpenAI. // GPT3 Models are designed for text-based tasks. For code-specific // tasks, please refer to the Codex series of models. const ( + GPT432K0314 = "gpt-4-32k-0314" + GPT432K = "gpt-4-32k" + GPT40314 = "gpt-4-0314" + GPT4 = "gpt-4" GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" GPT3Dot5Turbo = "gpt-3.5-turbo" GPT3TextDavinci003 = "text-davinci-003" @@ -92,14 +99,13 @@ func (c *Client) CreateCompletion( ctx context.Context, request CompletionRequest, ) (response CompletionResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) - if err != nil { + if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo { + err = ErrCompletionUnsupportedModel return } urlSuffix := "/completions" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return } diff --git a/go-gpt3/completion_test.go b/go-gpt3/completion_test.go new file mode 100644 index 0000000..9868eb2 --- /dev/null +++ b/go-gpt3/completion_test.go @@ -0,0 +1,121 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "testing" + "time" +) + +func TestCompletionsWrongModel(t *testing.T) { + config := DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + CompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + }, + ) + if !errors.Is(err, ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) + } +} + +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestCompletions(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + req := CompletionRequest{ + MaxTokens: 5, + Model: "ada", + } + req.Prompt = "Lorem ipsum" + _, err = client.CreateCompletion(ctx, req) + if err != nil { + t.Fatalf("CreateCompletion error: %v", err) + } +} + +// handleCompletionEndpoint Handles the completion endpoint by the test server. +func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var completionReq CompletionRequest + if completionReq, err = getCompletionBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := CompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + // would be nice to validate Model during testing, but + // this may not be possible with how much upkeep + // would be required / wouldn't make much sense + Model: completionReq.Model, + } + // create completions + for i := 0; i < completionReq.N; i++ { + // generate a random string of length completionReq.Length + completionStr := strings.Repeat("a", completionReq.MaxTokens) + if completionReq.Echo { + completionStr = completionReq.Prompt + completionStr + } + res.Choices = append(res.Choices, CompletionChoice{ + Text: completionStr, + Index: i, + }) + } + inputTokens := numTokens(completionReq.Prompt) * completionReq.N + completionTokens := completionReq.MaxTokens * completionReq.N + res.Usage = Usage{ + PromptTokens: inputTokens, + CompletionTokens: completionTokens, + TotalTokens: inputTokens + completionTokens, + } + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// getCompletionBody Returns the body of the request to create a completion. +func getCompletionBody(r *http.Request) (CompletionRequest, error) { + completion := CompletionRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return CompletionRequest{}, err + } + err = json.Unmarshal(reqBody, &completion) + if err != nil { + return CompletionRequest{}, err + } + return completion, nil +} diff --git a/go-gpt3/config.go b/go-gpt3/config.go index 8e5f211..e09c256 100644 --- a/go-gpt3/config.go +++ b/go-gpt3/config.go @@ -1,11 +1,11 @@ -package gogpt +package openai import ( "net/http" ) const ( - apiURLv1 = "http://chatgpt.zhiyinos.cn/" + apiURLv1 = "https://api.openai.com/v1" defaultEmptyMessagesLimit uint = 300 ) diff --git a/go-gpt3/edits.go b/go-gpt3/edits.go index 8cfc21c..858a8e5 100644 --- a/go-gpt3/edits.go +++ b/go-gpt3/edits.go @@ -1,9 +1,7 @@ -package gogpt +package openai import ( - "bytes" "context" - "encoding/json" "net/http" ) @@ -33,13 +31,7 @@ type EditsResponse struct { // Perform an API call to the Edits endpoint. func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) - if err != nil { - return - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/edits"), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request) if err != nil { return } diff --git a/go-gpt3/edits_test.go b/go-gpt3/edits_test.go new file mode 100644 index 0000000..6a16f7c --- /dev/null +++ b/go-gpt3/edits_test.go @@ -0,0 +1,104 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + "time" +) + +// TestEdits Tests the edits endpoint of the API using the mocked server. +func TestEdits(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/edits", handleEditEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + // create an edit request + model := "ada" + editReq := EditsRequest{ + Model: &model, + Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + + " ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" + + " ex ea commodo consequat. Duis aute irure dolor in reprehe", + Instruction: "test instruction", + N: 3, + } + response, err := client.Edits(ctx, editReq) + if err != nil { + t.Fatalf("Edits error: %v", err) + } + if len(response.Choices) != editReq.N { + t.Fatalf("edits does not properly return the correct number of choices") + } +} + +// handleEditEndpoint Handles the edit endpoint by the test server. +func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // edits only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var editReq EditsRequest + editReq, err = getEditBody(r) + if err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + // create a response + res := EditsResponse{ + Object: "test-object", + Created: time.Now().Unix(), + } + // edit and calculate token usage + editString := "edited by mocked OpenAI server :)" + inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N + completionTokens := int(float32(len(editString))/4) * editReq.N + for i := 0; i < editReq.N; i++ { + // instruction will be hidden and only seen by OpenAI + res.Choices = append(res.Choices, EditsChoice{ + Text: editReq.Input + editString, + Index: i, + }) + } + res.Usage = Usage{ + PromptTokens: inputTokens, + CompletionTokens: completionTokens, + TotalTokens: inputTokens + completionTokens, + } + resBytes, _ = json.Marshal(res) + fmt.Fprint(w, string(resBytes)) +} + +// getEditBody Returns the body of the request to create an edit. +func getEditBody(r *http.Request) (EditsRequest, error) { + edit := EditsRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return EditsRequest{}, err + } + err = json.Unmarshal(reqBody, &edit) + if err != nil { + return EditsRequest{}, err + } + return edit, nil +} diff --git a/go-gpt3/embeddings.go b/go-gpt3/embeddings.go index bfdb802..2deaccc 100644 --- a/go-gpt3/embeddings.go +++ b/go-gpt3/embeddings.go @@ -1,9 +1,7 @@ -package gogpt +package openai import ( - "bytes" "context" - "encoding/json" "net/http" ) @@ -103,7 +101,7 @@ var stringToEnum = map[string]EmbeddingModel{ // then their vector representations should also be similar. type Embedding struct { Object string `json:"object"` - Embedding []float64 `json:"embedding"` + Embedding []float32 `json:"embedding"` Index int `json:"index"` } @@ -134,14 +132,7 @@ type EmbeddingRequest struct { // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. // https://beta.openai.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) - if err != nil { - return - } - - urlSuffix := "/embeddings" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request) if err != nil { return } diff --git a/go-gpt3/embeddings_test.go b/go-gpt3/embeddings_test.go new file mode 100644 index 0000000..2aa48c5 --- /dev/null +++ b/go-gpt3/embeddings_test.go @@ -0,0 +1,48 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + + "bytes" + "encoding/json" + "testing" +) + +func TestEmbedding(t *testing.T) { + embeddedModels := []EmbeddingModel{ + AdaSimilarity, + BabbageSimilarity, + CurieSimilarity, + DavinciSimilarity, + AdaSearchDocument, + AdaSearchQuery, + BabbageSearchDocument, + BabbageSearchQuery, + CurieSearchDocument, + CurieSearchQuery, + DavinciSearchDocument, + DavinciSearchQuery, + AdaCodeSearchCode, + AdaCodeSearchText, + BabbageCodeSearchCode, + BabbageCodeSearchText, + } + for _, model := range embeddedModels { + embeddingReq := EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: model, + } + // marshal embeddingReq to JSON and confirm that the model field equals + // the AdaSearchQuery type + marshaled, err := json.Marshal(embeddingReq) + if err != nil { + t.Fatalf("Could not marshal embedding request: %v", err) + } + if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + } +} diff --git a/go-gpt3/engines.go b/go-gpt3/engines.go index 2805f15..bb6a66c 100644 --- a/go-gpt3/engines.go +++ b/go-gpt3/engines.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "context" @@ -22,7 +22,7 @@ type EnginesList struct { // ListEngines Lists the currently available engines, and provides basic // information about each option such as the owner and availability. func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/engines"), nil) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil) if err != nil { return } @@ -38,7 +38,7 @@ func (c *Client) GetEngine( engineID string, ) (engine Engine, err error) { urlSuffix := fmt.Sprintf("/engines/%s", engineID) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) if err != nil { return } diff --git a/go-gpt3/error.go b/go-gpt3/error.go index 927fafd..d041da2 100644 --- a/go-gpt3/error.go +++ b/go-gpt3/error.go @@ -1,4 +1,4 @@ -package gogpt +package openai import "fmt" diff --git a/go-gpt3/files.go b/go-gpt3/files.go index 385bb52..ec441c3 100644 --- a/go-gpt3/files.go +++ b/go-gpt3/files.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "bytes" @@ -112,7 +112,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File // DeleteFile deletes an existing file. func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) + req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) if err != nil { return } @@ -124,7 +124,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { // ListFiles Lists the currently available files, // and provides basic information about each file such as the file name and purpose. func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/files"), nil) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil) if err != nil { return } @@ -137,7 +137,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { // such as the file name and purpose. func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { urlSuffix := fmt.Sprintf("/files/%s", fileID) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) if err != nil { return } diff --git a/go-gpt3/files_test.go b/go-gpt3/files_test.go new file mode 100644 index 0000000..6a78ce1 --- /dev/null +++ b/go-gpt3/files_test.go @@ -0,0 +1,81 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "testing" + "time" +) + +func TestFileUpload(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/files", handleCreateFile) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + req := FileRequest{ + FileName: "test.go", + FilePath: "api.go", + Purpose: "fine-tune", + } + _, err = client.CreateFile(ctx, req) + if err != nil { + t.Fatalf("CreateFile error: %v", err) + } +} + +// handleCreateFile Handles the images endpoint by the test server. +func handleCreateFile(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // edits only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + err = r.ParseMultipartForm(1024 * 1024 * 1024) + if err != nil { + http.Error(w, "file is more than 1GB", http.StatusInternalServerError) + return + } + + values := r.Form + var purpose string + for key, value := range values { + if key == "purpose" { + purpose = value[0] + } + } + file, header, err := r.FormFile("file") + if err != nil { + return + } + defer file.Close() + + var fileReq = File{ + Bytes: int(header.Size), + ID: strconv.Itoa(int(time.Now().Unix())), + FileName: header.Filename, + Purpose: purpose, + CreatedAt: time.Now().Unix(), + Object: "test-objecct", + Owner: "test-owner", + } + + resBytes, _ = json.Marshal(fileReq) + fmt.Fprint(w, string(resBytes)) +} diff --git a/go-gpt3/fine_tunes.go b/go-gpt3/fine_tunes.go new file mode 100644 index 0000000..a121867 --- /dev/null +++ b/go-gpt3/fine_tunes.go @@ -0,0 +1,130 @@ +package openai + +import ( + "context" + "fmt" + "net/http" +) + +type FineTuneRequest struct { + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + Model string `json:"model,omitempty"` + Epochs int `json:"n_epochs,omitempty"` + BatchSize int `json:"batch_size,omitempty"` + LearningRateMultiplier float32 `json:"learning_rate_multiplier,omitempty"` + PromptLossRate float32 `json:"prompt_loss_rate,omitempty"` + ComputeClassificationMetrics bool `json:"compute_classification_metrics,omitempty"` + ClassificationClasses int `json:"classification_n_classes,omitempty"` + ClassificationPositiveClass string `json:"classification_positive_class,omitempty"` + ClassificationBetas []float32 `json:"classification_betas,omitempty"` + Suffix string `json:"suffix,omitempty"` +} + +type FineTune struct { + ID string `json:"id"` + Object string `json:"object"` + Model string `json:"model"` + CreatedAt int64 `json:"created_at"` + FineTuneEventList []FineTuneEvent `json:"events,omitempty"` + FineTunedModel string `json:"fine_tuned_model"` + HyperParams FineTuneHyperParams `json:"hyperparams"` + OrganizationID string `json:"organization_id"` + ResultFiles []File `json:"result_files"` + Status string `json:"status"` + ValidationFiles []File `json:"validation_files"` + TrainingFiles []File `json:"training_files"` + UpdatedAt int64 `json:"updated_at"` +} + +type FineTuneEvent struct { + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Level string `json:"level"` + Message string `json:"message"` +} + +type FineTuneHyperParams struct { + BatchSize int `json:"batch_size"` + LearningRateMultiplier float64 `json:"learning_rate_multiplier"` + Epochs int `json:"n_epochs"` + PromptLossWeight float64 `json:"prompt_loss_weight"` +} + +type FineTuneList struct { + Object string `json:"object"` + Data []FineTune `json:"data"` +} +type FineTuneEventList struct { + Object string `json:"object"` + Data []FineTuneEvent `json:"data"` +} + +type FineTuneDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` +} + +func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { + urlSuffix := "/fine-tunes" + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CancelFineTune cancel a fine-tune job. +func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { + urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { + req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/go-gpt3/fine_tunes_test.go b/go-gpt3/fine_tunes_test.go new file mode 100644 index 0000000..1f6f967 --- /dev/null +++ b/go-gpt3/fine_tunes_test.go @@ -0,0 +1,101 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "encoding/json" + "fmt" + "net/http" + "testing" +) + +const testFineTuneID = "fine-tune-id" + +// TestFineTunes Tests the fine tunes endpoint of the API using the mocked server. +func TestFineTunes(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler( + "/v1/fine-tunes", + func(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + if r.Method == http.MethodGet { + resBytes, _ = json.Marshal(FineTuneList{}) + } else { + resBytes, _ = json.Marshal(FineTune{}) + } + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine-tunes/"+testFineTuneID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FineTune{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine-tunes/"+testFineTuneID, + func(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + if r.Method == http.MethodDelete { + resBytes, _ = json.Marshal(FineTuneDeleteResponse{}) + } else { + resBytes, _ = json.Marshal(FineTune{}) + } + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine-tunes/"+testFineTuneID+"/events", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FineTuneEventList{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err = client.ListFineTunes(ctx) + if err != nil { + t.Fatalf("ListFineTunes error: %v", err) + } + + _, err = client.CreateFineTune(ctx, FineTuneRequest{}) + if err != nil { + t.Fatalf("CreateFineTune error: %v", err) + } + + _, err = client.CancelFineTune(ctx, testFineTuneID) + if err != nil { + t.Fatalf("CancelFineTune error: %v", err) + } + + _, err = client.GetFineTune(ctx, testFineTuneID) + if err != nil { + t.Fatalf("GetFineTune error: %v", err) + } + + _, err = client.DeleteFineTune(ctx, testFineTuneID) + if err != nil { + t.Fatalf("DeleteFineTune error: %v", err) + } + + _, err = client.ListFineTuneEvents(ctx, testFineTuneID) + if err != nil { + t.Fatalf("ListFineTuneEvents error: %v", err) + } +} diff --git a/go-gpt3/go.mod b/go-gpt3/go.mod new file mode 100644 index 0000000..42cc7b3 --- /dev/null +++ b/go-gpt3/go.mod @@ -0,0 +1,3 @@ +module github.com/sashabaranov/go-openai + +go 1.18 diff --git a/go-gpt3/image.go b/go-gpt3/image.go index 07ecaa7..c0dfa64 100644 --- a/go-gpt3/image.go +++ b/go-gpt3/image.go @@ -1,9 +1,8 @@ -package gogpt +package openai import ( "bytes" "context" - "encoding/json" "io" "mime/multipart" "net/http" @@ -46,14 +45,8 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) - if err != nil { - return - } - urlSuffix := "/images/generations" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return } @@ -86,20 +79,65 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - // mask - mask, err := writer.CreateFormFile("mask", request.Mask.Name()) + // mask, it is optional + if request.Mask != nil { + mask, err2 := writer.CreateFormFile("mask", request.Mask.Name()) + if err2 != nil { + return + } + _, err = io.Copy(mask, request.Mask) + if err != nil { + return + } + } + + err = writer.WriteField("prompt", request.Prompt) if err != nil { return } - _, err = io.Copy(mask, request.Mask) + err = writer.WriteField("n", strconv.Itoa(request.N)) + if err != nil { + return + } + err = writer.WriteField("size", request.Size) + if err != nil { + return + } + writer.Close() + urlSuffix := "/images/edits" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) if err != nil { return } - err = writer.WriteField("prompt", request.Prompt) + req.Header.Set("Content-Type", writer.FormDataContentType()) + err = c.sendRequest(req, &response) + return +} + +// ImageVariRequest represents the request structure for the image API. +type ImageVariRequest struct { + Image *os.File `json:"image,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` +} + +// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. +// Use abbreviations(vari for variation) because ci-lint has a single-line length limit ... +func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) (response ImageResponse, err error) { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // image + image, err := writer.CreateFormFile("image", request.Image.Name()) if err != nil { return } + _, err = io.Copy(image, request.Image) + if err != nil { + return + } + err = writer.WriteField("n", strconv.Itoa(request.N)) if err != nil { return @@ -109,7 +147,8 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } writer.Close() - urlSuffix := "/images/edits" + //https://platform.openai.com/docs/api-reference/images/create-variation + urlSuffix := "/images/variations" req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) if err != nil { return diff --git a/go-gpt3/image_test.go b/go-gpt3/image_test.go new file mode 100644 index 0000000..b7949c8 --- /dev/null +++ b/go-gpt3/image_test.go @@ -0,0 +1,268 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "testing" + "time" +) + +func TestImages(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/images/generations", handleImageEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + req := ImageRequest{} + req.Prompt = "Lorem ipsum" + _, err = client.CreateImage(ctx, req) + if err != nil { + t.Fatalf("CreateImage error: %v", err) + } +} + +// handleImageEndpoint Handles the images endpoint by the test server. +func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // imagess only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var imageReq ImageRequest + if imageReq, err = getImageBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := ImageResponse{ + Created: time.Now().Unix(), + } + for i := 0; i < imageReq.N; i++ { + imageData := ImageResponseDataInner{} + switch imageReq.ResponseFormat { + case CreateImageResponseFormatURL, "": + imageData.URL = "https://example.com/image.png" + case CreateImageResponseFormatB64JSON: + // This decodes to "{}" in base64. + imageData.B64JSON = "e30K" + default: + http.Error(w, "invalid response format", http.StatusBadRequest) + return + } + res.Data = append(res.Data, imageData) + } + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// getImageBody Returns the body of the request to create a image. +func getImageBody(r *http.Request) (ImageRequest, error) { + image := ImageRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return ImageRequest{}, err + } + err = json.Unmarshal(reqBody, &image) + if err != nil { + return ImageRequest{}, err + } + return image, nil +} + +func TestImageEdit(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + origin, err := os.Create("image.png") + if err != nil { + t.Error("open origin file error") + return + } + + mask, err := os.Create("mask.png") + if err != nil { + t.Error("open mask file error") + return + } + + defer func() { + mask.Close() + origin.Close() + os.Remove("mask.png") + os.Remove("image.png") + }() + + req := ImageEditRequest{ + Image: origin, + Mask: mask, + Prompt: "There is a turtle in the pool", + N: 3, + Size: CreateImageSize1024x1024, + } + _, err = client.CreateEditImage(ctx, req) + if err != nil { + t.Fatalf("CreateImage error: %v", err) + } +} + +func TestImageEditWithoutMask(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + origin, err := os.Create("image.png") + if err != nil { + t.Error("open origin file error") + return + } + + defer func() { + origin.Close() + os.Remove("image.png") + }() + + req := ImageEditRequest{ + Image: origin, + Prompt: "There is a turtle in the pool", + N: 3, + Size: CreateImageSize1024x1024, + } + _, err = client.CreateEditImage(ctx, req) + if err != nil { + t.Fatalf("CreateImage error: %v", err) + } +} + +// handleEditImageEndpoint Handles the images endpoint by the test server. +func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // imagess only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + responses := ImageResponse{ + Created: time.Now().Unix(), + Data: []ImageResponseDataInner{ + { + URL: "test-url1", + B64JSON: "", + }, + { + URL: "test-url2", + B64JSON: "", + }, + { + URL: "test-url3", + B64JSON: "", + }, + }, + } + + resBytes, _ = json.Marshal(responses) + fmt.Fprintln(w, string(resBytes)) +} + +func TestImageVariation(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + origin, err := os.Create("image.png") + if err != nil { + t.Error("open origin file error") + return + } + + defer func() { + origin.Close() + os.Remove("image.png") + }() + + req := ImageVariRequest{ + Image: origin, + N: 3, + Size: CreateImageSize1024x1024, + } + _, err = client.CreateVariImage(ctx, req) + if err != nil { + t.Fatalf("CreateImage error: %v", err) + } +} + +// handleVariateImageEndpoint Handles the images endpoint by the test server. +func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // imagess only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + responses := ImageResponse{ + Created: time.Now().Unix(), + Data: []ImageResponseDataInner{ + { + URL: "test-url1", + B64JSON: "", + }, + { + URL: "test-url2", + B64JSON: "", + }, + { + URL: "test-url3", + B64JSON: "", + }, + }, + } + + resBytes, _ = json.Marshal(responses) + fmt.Fprintln(w, string(resBytes)) +} diff --git a/go-gpt3/marshaller.go b/go-gpt3/marshaller.go new file mode 100644 index 0000000..308ccd1 --- /dev/null +++ b/go-gpt3/marshaller.go @@ -0,0 +1,15 @@ +package openai + +import ( + "encoding/json" +) + +type marshaller interface { + marshal(value any) ([]byte, error) +} + +type jsonMarshaller struct{} + +func (jm *jsonMarshaller) marshal(value any) ([]byte, error) { + return json.Marshal(value) +} diff --git a/go-gpt3/models.go b/go-gpt3/models.go index c18e502..2be91aa 100644 --- a/go-gpt3/models.go +++ b/go-gpt3/models.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "context" @@ -7,7 +7,7 @@ import ( // Model struct represents an OpenAPI model. type Model struct { - CreatedAt int64 `json:"created_at"` + CreatedAt int64 `json:"created"` ID string `json:"id"` Object string `json:"object"` OwnedBy string `json:"owned_by"` @@ -18,7 +18,7 @@ type Model struct { // Permission struct represents an OpenAPI permission. type Permission struct { - CreatedAt int64 `json:"created_at"` + CreatedAt int64 `json:"created"` ID string `json:"id"` Object string `json:"object"` AllowCreateEngine bool `json:"allow_create_engine"` @@ -40,7 +40,7 @@ type ModelsList struct { // ListModels Lists the currently available models, // and provides basic information about each model such as the model id and parent. func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/models"), nil) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/models"), nil) if err != nil { return } diff --git a/go-gpt3/models_test.go b/go-gpt3/models_test.go new file mode 100644 index 0000000..c96ece8 --- /dev/null +++ b/go-gpt3/models_test.go @@ -0,0 +1,39 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestListModels Tests the models endpoint of the API using the mocked server. +func TestListModels(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/models", handleModelsEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err = client.ListModels(ctx) + if err != nil { + t.Fatalf("ListModels error: %v", err) + } +} + +// handleModelsEndpoint Handles the models endpoint by the test server. +func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(ModelsList{}) + fmt.Fprintln(w, string(resBytes)) +} diff --git a/go-gpt3/moderation.go b/go-gpt3/moderation.go index 1849e10..ff789a6 100644 --- a/go-gpt3/moderation.go +++ b/go-gpt3/moderation.go @@ -1,9 +1,7 @@ -package gogpt +package openai import ( - "bytes" "context" - "encoding/json" "net/http" ) @@ -52,13 +50,7 @@ type ModerationResponse struct { // Moderations — perform a moderation api call over a string. // Input can be an array or slice but a string will reduce the complexity. func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) - if err != nil { - return - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/moderations"), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request) if err != nil { return } diff --git a/go-gpt3/moderation_test.go b/go-gpt3/moderation_test.go new file mode 100644 index 0000000..f501245 --- /dev/null +++ b/go-gpt3/moderation_test.go @@ -0,0 +1,102 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "testing" + "time" +) + +// TestModeration Tests the moderations endpoint of the API using the mocked server. +func TestModerations(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/moderations", handleModerationEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + // create an edit request + model := "text-moderation-stable" + moderationReq := ModerationRequest{ + Model: &model, + Input: "I want to kill them.", + } + _, err = client.Moderations(ctx, moderationReq) + if err != nil { + t.Fatalf("Moderation error: %v", err) + } +} + +// handleModerationEndpoint Handles the moderation endpoint by the test server. +func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var moderationReq ModerationRequest + if moderationReq, err = getModerationBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + + resCat := ResultCategories{} + resCatScore := ResultCategoryScores{} + switch { + case strings.Contains(moderationReq.Input, "kill"): + resCat = ResultCategories{Violence: true} + resCatScore = ResultCategoryScores{Violence: 1} + case strings.Contains(moderationReq.Input, "hate"): + resCat = ResultCategories{Hate: true} + resCatScore = ResultCategoryScores{Hate: 1} + case strings.Contains(moderationReq.Input, "suicide"): + resCat = ResultCategories{SelfHarm: true} + resCatScore = ResultCategoryScores{SelfHarm: 1} + case strings.Contains(moderationReq.Input, "porn"): + resCat = ResultCategories{Sexual: true} + resCatScore = ResultCategoryScores{Sexual: 1} + } + + result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} + + res := ModerationResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Model: *moderationReq.Model, + } + res.Results = append(res.Results, result) + + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// getModerationBody Returns the body of the request to do a moderation. +func getModerationBody(r *http.Request) (ModerationRequest, error) { + moderation := ModerationRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return ModerationRequest{}, err + } + err = json.Unmarshal(reqBody, &moderation) + if err != nil { + return ModerationRequest{}, err + } + return moderation, nil +} diff --git a/go-gpt3/request_builder.go b/go-gpt3/request_builder.go new file mode 100644 index 0000000..f0cef10 --- /dev/null +++ b/go-gpt3/request_builder.go @@ -0,0 +1,40 @@ +package openai + +import ( + "bytes" + "context" + "net/http" +) + +type requestBuilder interface { + build(ctx context.Context, method, url string, request any) (*http.Request, error) +} + +type httpRequestBuilder struct { + marshaller marshaller +} + +func newRequestBuilder() *httpRequestBuilder { + return &httpRequestBuilder{ + marshaller: &jsonMarshaller{}, + } +} + +func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) { + if request == nil { + return http.NewRequestWithContext(ctx, method, url, nil) + } + + var reqBytes []byte + reqBytes, err := b.marshaller.marshal(request) + if err != nil { + return nil, err + } + + return http.NewRequestWithContext( + ctx, + method, + url, + bytes.NewBuffer(reqBytes), + ) +} diff --git a/go-gpt3/request_builder_test.go b/go-gpt3/request_builder_test.go new file mode 100644 index 0000000..533977a --- /dev/null +++ b/go-gpt3/request_builder_test.go @@ -0,0 +1,148 @@ +package openai //nolint:testpackage // testing private field + +import ( + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "errors" + "net/http" + "testing" +) + +var ( + errTestMarshallerFailed = errors.New("test marshaller failed") + errTestRequestBuilderFailed = errors.New("test request builder failed") +) + +type ( + failingRequestBuilder struct{} + failingMarshaller struct{} +) + +func (*failingMarshaller) marshal(value any) ([]byte, error) { + return []byte{}, errTestMarshallerFailed +} + +func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) { + return nil, errTestRequestBuilderFailed +} + +func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { + builder := httpRequestBuilder{ + marshaller: &failingMarshaller{}, + } + + _, err := builder.build(context.Background(), "", "", struct{}{}) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } +} + +func TestClientReturnsRequestBuilderErrors(t *testing.T) { + var err error + ts := test.NewTestServer().OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + + ctx := context.Background() + + _, err = client.CreateCompletion(ctx, CompletionRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateFineTune(ctx, FineTuneRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListFineTunes(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CancelFineTune(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.GetFineTune(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.DeleteFineTune(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListFineTuneEvents(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.Moderations(ctx, ModerationRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.Edits(ctx, EditsRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateImage(ctx, ImageRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + err = client.DeleteFile(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.GetFile(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListFiles(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListEngines(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.GetEngine(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListModels(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } +} diff --git a/go-gpt3/stream.go b/go-gpt3/stream.go index 8f35d2e..0eed4aa 100644 --- a/go-gpt3/stream.go +++ b/go-gpt3/stream.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "bufio" @@ -6,7 +6,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "net/http" ) @@ -60,43 +59,6 @@ waitForData: return } -func (stream *CompletionStream) ChatRecv() (response ChatCompletionResponseForStream, err error) { - if stream.isFinished { - err = io.EOF - return - } - - var emptyMessagesCount uint - -waitForData: - line, err := stream.reader.ReadBytes('\n') - if err != nil { - return - } - - var headerData = []byte("data: ") - line = bytes.TrimSpace(line) - if !bytes.HasPrefix(line, headerData) { - emptyMessagesCount++ - if emptyMessagesCount > stream.emptyMessagesLimit { - err = ErrTooManyEmptyStreamMessages - return - } - - goto waitForData - } - - line = bytes.TrimPrefix(line, headerData) - if string(line) == "[DONE]" { - stream.isFinished = true - err = io.EOF - return - } - - err = json.Unmarshal(line, &response) - return -} - func (stream *CompletionStream) Close() { stream.response.Body.Close() } @@ -110,18 +72,7 @@ func (c *Client) CreateCompletionStream( request CompletionRequest, ) (stream *CompletionStream, err error) { request.Stream = true - reqBytes, err := json.Marshal(request) - if err != nil { - return - } - - urlSuffix := "/completions" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + req, err := c.newStreamRequest(ctx, "POST", "/completions", request) if err != nil { return } diff --git a/go-gpt3/stream_test.go b/go-gpt3/stream_test.go new file mode 100644 index 0000000..8f89e6b --- /dev/null +++ b/go-gpt3/stream_test.go @@ -0,0 +1,147 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestCreateCompletionStream(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + if err != nil { + t.Errorf("Write error: %s", err) + } + })) + defer server.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" + config.HTTPClient.Transport = &tokenRoundTripper{ + test.GetTestToken(), + http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + } + + stream, err := client.CreateCompletionStream(ctx, request) + if err != nil { + t.Errorf("CreateCompletionStream returned error: %v", err) + } + defer stream.Close() + + expectedResponses := []CompletionResponse{ + { + ID: "1", + Object: "completion", + Created: 1598069254, + Model: "text-davinci-002", + Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, + }, + { + ID: "2", + Object: "completion", + Created: 1598069255, + Model: "text-davinci-002", + Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, + }, + } + + for ix, expectedResponse := range expectedResponses { + receivedResponse, streamErr := stream.Recv() + if streamErr != nil { + t.Errorf("stream.Recv() failed: %v", streamErr) + } + if !compareResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } + + _, streamErr = stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) + } +} + +// A "tokenRoundTripper" is a struct that implements the RoundTripper +// interface, specifically to handle the authentication token by adding a token +// to the request header. We need this because the API requires that each +// request include a valid API token in the headers for authentication and +// authorization. +type tokenRoundTripper struct { + token string + fallback http.RoundTripper +} + +// RoundTrip takes an *http.Request as input and returns an +// *http.Response and an error. +// +// It is expected to use the provided request to create a connection to an HTTP +// server and return the response, or an error if one occurred. The returned +// Response should have its Body closed. If the RoundTrip method returns an +// error, the Client's Get, Head, Post, and PostForm methods return the same +// error. +func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("Authorization", "Bearer "+t.token) + return t.fallback.RoundTrip(req) +} + +// Helper funcs. +func compareResponses(r1, r2 CompletionResponse) bool { + if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { + return false + } + if len(r1.Choices) != len(r2.Choices) { + return false + } + for i := range r1.Choices { + if !compareResponseChoices(r1.Choices[i], r2.Choices[i]) { + return false + } + } + return true +} + +func compareResponseChoices(c1, c2 CompletionChoice) bool { + if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason { + return false + } + return true +}