diff --git a/go-gpt3/.gitignore b/go-gpt3/.gitignore index 99b40bf..708f197 100644 --- a/go-gpt3/.gitignore +++ b/go-gpt3/.gitignore @@ -15,5 +15,5 @@ # vendor/ # Auth token for tests -.openai-token +.gogpt-token .idea \ No newline at end of file diff --git a/go-gpt3/README.md b/go-gpt3/README.md index e6e352e..4f6a437 100644 --- a/go-gpt3/README.md +++ b/go-gpt3/README.md @@ -1,11 +1,11 @@ # Go OpenAI -[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-openai) -[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai) -[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai) +[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-gogpt) +[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-gogpt)](https://goreportcard.com/report/github.com/sashabaranov/go-gogpt) +[![codecov](https://codecov.io/gh/sashabaranov/go-gogpt/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-gogpt) -> **Note**: the repository was recently renamed from `go-gpt3` to `go-openai` +> **Note**: the repository was recently renamed from `go-gpt3` to `go-gogpt` -This library provides Go clients for [OpenAI API](https://platform.openai.com/). We support: +This library provides Go clients for [OpenAI API](https://platform.gogpt.com/). We support: * ChatGPT * GPT-3, GPT-4 @@ -14,7 +14,7 @@ This library provides Go clients for [OpenAI API](https://platform.openai.com/). Installation: ``` -go get github.com/sashabaranov/go-openai +go get github.com/sashabaranov/go-gogpt ``` @@ -26,18 +26,18 @@ package main import ( "context" "fmt" - openai "github.com/sashabaranov/go-openai" + gogpt "github.com/sashabaranov/go-gogpt" ) func main() { - client := openai.NewClient("your token") + client := gogpt.NewClient("your token") resp, err := client.CreateChatCompletion( context.Background(), - openai.ChatCompletionRequest{ - Model: openai.GPT3Dot5Turbo, - Messages: []openai.ChatCompletionMessage{ + gogpt.ChatCompletionRequest{ + Model: gogpt.GPT3Dot5Turbo, + Messages: []gogpt.ChatCompletionMessage{ { - Role: openai.ChatMessageRoleUser, + Role: gogpt.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -67,15 +67,15 @@ package main import ( "context" "fmt" - openai "github.com/sashabaranov/go-openai" + gogpt "github.com/sashabaranov/go-gogpt" ) func main() { - c := openai.NewClient("your token") + c := gogpt.NewClient("your token") ctx := context.Background() - req := openai.CompletionRequest{ - Model: openai.GPT3Ada, + req := gogpt.CompletionRequest{ + Model: gogpt.GPT3Ada, MaxTokens: 5, Prompt: "Lorem ipsum", } @@ -100,15 +100,15 @@ import ( "context" "fmt" "io" - openai "github.com/sashabaranov/go-openai" + gogpt "github.com/sashabaranov/go-gogpt" ) func main() { - c := openai.NewClient("your token") + c := gogpt.NewClient("your token") ctx := context.Background() - req := openai.CompletionRequest{ - Model: openai.GPT3Ada, + req := gogpt.CompletionRequest{ + Model: gogpt.GPT3Ada, MaxTokens: 5, Prompt: "Lorem ipsum", Stream: true, @@ -149,15 +149,15 @@ import ( "context" "fmt" - openai "github.com/sashabaranov/go-openai" + gogpt "github.com/sashabaranov/go-gogpt" ) func main() { - c := openai.NewClient("your token") + c := gogpt.NewClient("your token") ctx := context.Background() - req := openai.AudioRequest{ - Model: openai.Whisper1, + req := gogpt.AudioRequest{ + Model: gogpt.Whisper1, FilePath: "recording.mp3", } resp, err := c.CreateTranscription(ctx, req) @@ -181,20 +181,20 @@ import ( "context" "encoding/base64" "fmt" - openai "github.com/sashabaranov/go-openai" + gogpt "github.com/sashabaranov/go-gogpt" "image/png" "os" ) func main() { - c := openai.NewClient("your token") + c := gogpt.NewClient("your token") ctx := context.Background() // Sample image by link - reqUrl := openai.ImageRequest{ + reqUrl := gogpt.ImageRequest{ Prompt: "Parrot on a skateboard performs a trick, cartoon style, natural light, high detail", - Size: openai.CreateImageSize256x256, - ResponseFormat: openai.CreateImageResponseFormatURL, + Size: gogpt.CreateImageSize256x256, + ResponseFormat: gogpt.CreateImageResponseFormatURL, N: 1, } @@ -206,10 +206,10 @@ func main() { fmt.Println(respUrl.Data[0].URL) // Example image as base64 - reqBase64 := openai.ImageRequest{ + reqBase64 := gogpt.ImageRequest{ Prompt: "Portrait of a humanoid parrot in a classic costume, high detail, realistic light, unreal engine", - Size: openai.CreateImageSize256x256, - ResponseFormat: openai.CreateImageResponseFormatB64JSON, + Size: gogpt.CreateImageSize256x256, + ResponseFormat: gogpt.CreateImageResponseFormatB64JSON, N: 1, } @@ -254,7 +254,7 @@ func main() { Configuring proxy ```go -config := openai.DefaultConfig("token") +config := gogpt.DefaultConfig("token") proxyUrl, err := url.Parse("http://localhost:{port}") if err != nil { panic(err) @@ -266,10 +266,10 @@ config.HTTPClient = &http.Client{ Transport: transport, } -c := openai.NewClientWithConfig(config) +c := gogpt.NewClientWithConfig(config) ``` -See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig +See also: https://pkg.go.dev/github.com/sashabaranov/go-gogpt#ClientConfig
@@ -285,12 +285,12 @@ import ( "os" "strings" - "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-gogpt" ) func main() { - client := openai.NewClient("your token") - messages := make([]openai.ChatCompletionMessage, 0) + client := gogpt.NewClient("your token") + messages := make([]gogpt.ChatCompletionMessage, 0) reader := bufio.NewReader(os.Stdin) fmt.Println("Conversation") fmt.Println("---------------------") @@ -300,15 +300,15 @@ func main() { text, _ := reader.ReadString('\n') // convert CRLF to LF text = strings.Replace(text, "\n", "", -1) - messages = append(messages, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleUser, + messages = append(messages, gogpt.ChatCompletionMessage{ + Role: gogpt.ChatMessageRoleUser, Content: text, }) resp, err := client.CreateChatCompletion( context.Background(), - openai.ChatCompletionRequest{ - Model: openai.GPT3Dot5Turbo, + gogpt.ChatCompletionRequest{ + Model: gogpt.GPT3Dot5Turbo, Messages: messages, }, ) @@ -319,8 +319,8 @@ func main() { } content := resp.Choices[0].Message.Content - messages = append(messages, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleAssistant, + messages = append(messages, gogpt.ChatCompletionMessage{ + Role: gogpt.ChatMessageRoleAssistant, Content: content, }) fmt.Println(content) diff --git a/go-gpt3/api.go b/go-gpt3/api.go index 00d6d35..c36b184 100644 --- a/go-gpt3/api.go +++ b/go-gpt3/api.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "context" diff --git a/go-gpt3/api_test.go b/go-gpt3/api_test.go deleted file mode 100644 index 202ec94..0000000 --- a/go-gpt3/api_test.go +++ /dev/null @@ -1,182 +0,0 @@ -package openai_test - -import ( - . "github.com/sashabaranov/go-openai" - - "context" - "errors" - "io" - "os" - "testing" -) - -func TestAPI(t *testing.T) { - apiToken := os.Getenv("OPENAI_TOKEN") - if apiToken == "" { - t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") - } - - var err error - c := NewClient(apiToken) - ctx := context.Background() - _, err = c.ListEngines(ctx) - if err != nil { - t.Fatalf("ListEngines error: %v", err) - } - - _, err = c.GetEngine(ctx, "davinci") - if err != nil { - t.Fatalf("GetEngine error: %v", err) - } - - fileRes, err := c.ListFiles(ctx) - if err != nil { - t.Fatalf("ListFiles error: %v", err) - } - - if len(fileRes.Files) > 0 { - _, err = c.GetFile(ctx, fileRes.Files[0].ID) - if err != nil { - t.Fatalf("GetFile error: %v", err) - } - } // else skip - - embeddingReq := EmbeddingRequest{ - Input: []string{ - "The food was delicious and the waiter", - "Other examples of embedding request", - }, - Model: AdaSearchQuery, - } - _, err = c.CreateEmbeddings(ctx, embeddingReq) - if err != nil { - t.Fatalf("Embedding error: %v", err) - } - - _, err = c.CreateChatCompletion( - ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Content: "Hello!", - }, - }, - }, - ) - - if err != nil { - t.Errorf("CreateChatCompletion (without name) returned error: %v", err) - } - - _, err = c.CreateChatCompletion( - ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Name: "John_Doe", - Content: "Hello!", - }, - }, - }, - ) - - if err != nil { - t.Errorf("CreateChatCompletion (with name) returned error: %v", err) - } - - stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ - Prompt: "Ex falso quodlibet", - Model: GPT3Ada, - MaxTokens: 5, - Stream: true, - }) - if err != nil { - t.Errorf("CreateCompletionStream returned error: %v", err) - } - defer stream.Close() - - counter := 0 - for { - _, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Errorf("Stream error: %v", err) - } else { - counter++ - } - } - if counter == 0 { - t.Error("Stream did not return any responses") - } -} - -func TestAPIError(t *testing.T) { - apiToken := os.Getenv("OPENAI_TOKEN") - if apiToken == "" { - t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") - } - - var err error - c := NewClient(apiToken + "_invalid") - ctx := context.Background() - _, err = c.ListEngines(ctx) - if err == nil { - t.Fatal("ListEngines did not fail") - } - - var apiErr *APIError - if !errors.As(err, &apiErr) { - t.Fatalf("Error is not an APIError: %+v", err) - } - - if apiErr.StatusCode != 401 { - t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode) - } - if *apiErr.Code != "invalid_api_key" { - t.Fatalf("Unexpected API error code: %s", *apiErr.Code) - } - if apiErr.Error() == "" { - t.Fatal("Empty error message occured") - } -} - -func TestRequestError(t *testing.T) { - var err error - - config := DefaultConfig("dummy") - config.BaseURL = "https://httpbin.org/status/418?" - c := NewClientWithConfig(config) - ctx := context.Background() - _, err = c.ListEngines(ctx) - if err == nil { - t.Fatal("ListEngines request did not fail") - } - - var reqErr *RequestError - if !errors.As(err, &reqErr) { - t.Fatalf("Error is not a RequestError: %+v", err) - } - - if reqErr.StatusCode != 418 { - t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode) - } - - if reqErr.Unwrap() == nil { - t.Fatalf("Empty request error occured") - } -} - -// numTokens Returns the number of GPT-3 encoded tokens in the given text. -// This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer -// -// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) -func numTokens(s string) int { - return int(float32(len(s)) / 4) -} diff --git a/go-gpt3/audio.go b/go-gpt3/audio.go index 54bd32f..0dc611e 100644 --- a/go-gpt3/audio.go +++ b/go-gpt3/audio.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "bytes" diff --git a/go-gpt3/audio_test.go b/go-gpt3/audio_test.go deleted file mode 100644 index 2a035c9..0000000 --- a/go-gpt3/audio_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package openai_test - -import ( - "bytes" - "errors" - "io" - "mime" - "mime/multipart" - "net/http" - "os" - "path/filepath" - "strings" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "testing" -) - -// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. -func TestAudio(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) - server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - testcases := []struct { - name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) - }{ - { - "transcribe", - client.CreateTranscription, - }, - { - "translate", - client.CreateTranslation, - }, - } - - ctx := context.Background() - - dir, cleanup := createTestDirectory(t) - defer cleanup() - - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") - createTestFile(t, path) - - req := AudioRequest{ - FilePath: path, - Model: "whisper-3", - } - _, err = tc.createFn(ctx, req) - if err != nil { - t.Fatalf("audio API error: %v", err) - } - }) - } -} - -// createTestFile creates a fake file with "hello" as the content. -func createTestFile(t *testing.T, path string) { - file, err := os.Create(path) - if err != nil { - t.Fatalf("failed to create file %v", err) - } - if _, err = file.WriteString("hello"); err != nil { - t.Fatalf("failed to write to file %v", err) - } - file.Close() -} - -// createTestDirectory creates a temporary folder which will be deleted when cleanup is called. -func createTestDirectory(t *testing.T) (path string, cleanup func()) { - t.Helper() - - path, err := os.MkdirTemp(os.TempDir(), "") - if err != nil { - t.Fatal(err) - } - - return path, func() { os.RemoveAll(path) } -} - -// handleAudioEndpoint Handles the completion endpoint by the test server. -func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - - // audio endpoints only accept POST requests - if r.Method != "POST" { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } - - mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) - if err != nil { - http.Error(w, "failed to parse media type", http.StatusBadRequest) - return - } - - if !strings.HasPrefix(mediaType, "multipart") { - http.Error(w, "request is not multipart", http.StatusBadRequest) - } - - boundary, ok := params["boundary"] - if !ok { - http.Error(w, "no boundary in params", http.StatusBadRequest) - return - } - - fileData := &bytes.Buffer{} - mr := multipart.NewReader(r.Body, boundary) - part, err := mr.NextPart() - if err != nil && errors.Is(err, io.EOF) { - http.Error(w, "error accessing file", http.StatusBadRequest) - return - } - if _, err = io.Copy(fileData, part); err != nil { - http.Error(w, "failed to copy file", http.StatusInternalServerError) - return - } - - if len(fileData.Bytes()) == 0 { - w.WriteHeader(http.StatusInternalServerError) - http.Error(w, "received empty file data", http.StatusBadRequest) - return - } - - if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { - http.Error(w, "failed to write body", http.StatusInternalServerError) - return - } -} diff --git a/go-gpt3/chat.go b/go-gpt3/chat.go index 99edfe8..bd2b6f7 100644 --- a/go-gpt3/chat.go +++ b/go-gpt3/chat.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "context" @@ -23,8 +23,8 @@ type ChatCompletionMessage struct { // This property isn't in the official documentation, but it's in // the documentation for the official library for python: - // - https://github.com/openai/openai-python/blob/main/chatml.md - // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + // - https://github.com/gogpt/gogpt-python/blob/main/chatml.md + // - https://github.com/gogpt/gogpt-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb Name string `json:"name,omitempty"` } diff --git a/go-gpt3/chat_stream.go b/go-gpt3/chat_stream.go index 26e964c..6fc440b 100644 --- a/go-gpt3/chat_stream.go +++ b/go-gpt3/chat_stream.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "bufio" diff --git a/go-gpt3/chat_stream_test.go b/go-gpt3/chat_stream_test.go deleted file mode 100644 index e3da2da..0000000 --- a/go-gpt3/chat_stream_test.go +++ /dev/null @@ -1,153 +0,0 @@ -package openai_test - -import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "encoding/json" - "errors" - "io" - "net/http" - "net/http/httptest" - "testing" -) - -func TestCreateChatCompletionStream(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - - // Send test responses - dataBytes := []byte{} - dataBytes = append(dataBytes, []byte("event: message\n")...) - //nolint:lll - data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - - dataBytes = append(dataBytes, []byte("event: message\n")...) - //nolint:lll - data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - - dataBytes = append(dataBytes, []byte("event: done\n")...) - dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) - - _, err := w.Write(dataBytes) - if err != nil { - t.Errorf("Write error: %s", err) - } - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() - - request := ChatCompletionRequest{ - MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Content: "Hello!", - }, - }, - Stream: true, - } - - stream, err := client.CreateChatCompletionStream(ctx, request) - if err != nil { - t.Errorf("CreateCompletionStream returned error: %v", err) - } - defer stream.Close() - - expectedResponses := []ChatCompletionStreamResponse{ - { - ID: "1", - Object: "completion", - Created: 1598069254, - Model: GPT3Dot5Turbo, - Choices: []ChatCompletionStreamChoice{ - { - Delta: ChatCompletionStreamChoiceDelta{ - Content: "response1", - }, - FinishReason: "max_tokens", - }, - }, - }, - { - ID: "2", - Object: "completion", - Created: 1598069255, - Model: GPT3Dot5Turbo, - Choices: []ChatCompletionStreamChoice{ - { - Delta: ChatCompletionStreamChoiceDelta{ - Content: "response2", - }, - FinishReason: "max_tokens", - }, - }, - }, - } - - for ix, expectedResponse := range expectedResponses { - b, _ := json.Marshal(expectedResponse) - t.Logf("%d: %s", ix, string(b)) - - receivedResponse, streamErr := stream.Recv() - if streamErr != nil { - t.Errorf("stream.Recv() failed: %v", streamErr) - } - if !compareChatResponses(expectedResponse, receivedResponse) { - t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) - } - } - - _, streamErr := stream.Recv() - if !errors.Is(streamErr, io.EOF) { - t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) - } - - _, streamErr = stream.Recv() - if !errors.Is(streamErr, io.EOF) { - t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) - } -} - -// Helper funcs. -func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { - if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { - return false - } - if len(r1.Choices) != len(r2.Choices) { - return false - } - for i := range r1.Choices { - if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) { - return false - } - } - return true -} - -func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool { - if c1.Index != c2.Index { - return false - } - if c1.Delta.Content != c2.Delta.Content { - return false - } - if c1.FinishReason != c2.FinishReason { - return false - } - return true -} diff --git a/go-gpt3/chat_test.go b/go-gpt3/chat_test.go deleted file mode 100644 index 5c03ebf..0000000 --- a/go-gpt3/chat_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package openai_test - -import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "testing" - "time" -) - -func TestChatCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - req := ChatCompletionRequest{ - MaxTokens: 5, - Model: "ada", - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Content: "Hello!", - }, - }, - } - _, err := client.CreateChatCompletion(ctx, req) - if !errors.Is(err, ErrChatCompletionInvalidModel) { - t.Fatalf("CreateChatCompletion should return wrong model error, but returned: %v", err) - } -} - -// TestCompletions Tests the completions endpoint of the API using the mocked server. -func TestChatCompletions(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - req := ChatCompletionRequest{ - MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Content: "Hello!", - }, - }, - } - _, err = client.CreateChatCompletion(ctx, req) - if err != nil { - t.Fatalf("CreateChatCompletion error: %v", err) - } -} - -// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. -func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // completions only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - var completionReq ChatCompletionRequest - if completionReq, err = getChatCompletionBody(r); err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - res := ChatCompletionResponse{ - ID: strconv.Itoa(int(time.Now().Unix())), - Object: "test-object", - Created: time.Now().Unix(), - // would be nice to validate Model during testing, but - // this may not be possible with how much upkeep - // would be required / wouldn't make much sense - Model: completionReq.Model, - } - // create completions - for i := 0; i < completionReq.N; i++ { - // generate a random string of length completionReq.Length - completionStr := strings.Repeat("a", completionReq.MaxTokens) - - res.Choices = append(res.Choices, ChatCompletionChoice{ - Message: ChatCompletionMessage{ - Role: ChatMessageRoleAssistant, - Content: completionStr, - }, - Index: i, - }) - } - inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N - completionTokens := completionReq.MaxTokens * completionReq.N - res.Usage = Usage{ - PromptTokens: inputTokens, - CompletionTokens: completionTokens, - TotalTokens: inputTokens + completionTokens, - } - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) -} - -// getChatCompletionBody Returns the body of the request to create a completion. -func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { - completion := ChatCompletionRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return ChatCompletionRequest{}, err - } - err = json.Unmarshal(reqBody, &completion) - if err != nil { - return ChatCompletionRequest{}, err - } - return completion, nil -} diff --git a/go-gpt3/common.go b/go-gpt3/common.go index 3b555a7..9fb0178 100644 --- a/go-gpt3/common.go +++ b/go-gpt3/common.go @@ -1,5 +1,5 @@ // common.go defines common types used throughout the OpenAI API. -package openai +package gogpt // Usage Represents the total token usage per request to OpenAI. type Usage struct { diff --git a/go-gpt3/completion.go b/go-gpt3/completion.go index 66b4866..853c057 100644 --- a/go-gpt3/completion.go +++ b/go-gpt3/completion.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "context" diff --git a/go-gpt3/completion_test.go b/go-gpt3/completion_test.go deleted file mode 100644 index 9868eb2..0000000 --- a/go-gpt3/completion_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package openai_test - -import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "testing" - "time" -) - -func TestCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) - - _, err := client.CreateCompletion( - context.Background(), - CompletionRequest{ - MaxTokens: 5, - Model: GPT3Dot5Turbo, - }, - ) - if !errors.Is(err, ErrCompletionUnsupportedModel) { - t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) - } -} - -// TestCompletions Tests the completions endpoint of the API using the mocked server. -func TestCompletions(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/completions", handleCompletionEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - req := CompletionRequest{ - MaxTokens: 5, - Model: "ada", - } - req.Prompt = "Lorem ipsum" - _, err = client.CreateCompletion(ctx, req) - if err != nil { - t.Fatalf("CreateCompletion error: %v", err) - } -} - -// handleCompletionEndpoint Handles the completion endpoint by the test server. -func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // completions only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - var completionReq CompletionRequest - if completionReq, err = getCompletionBody(r); err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - res := CompletionResponse{ - ID: strconv.Itoa(int(time.Now().Unix())), - Object: "test-object", - Created: time.Now().Unix(), - // would be nice to validate Model during testing, but - // this may not be possible with how much upkeep - // would be required / wouldn't make much sense - Model: completionReq.Model, - } - // create completions - for i := 0; i < completionReq.N; i++ { - // generate a random string of length completionReq.Length - completionStr := strings.Repeat("a", completionReq.MaxTokens) - if completionReq.Echo { - completionStr = completionReq.Prompt + completionStr - } - res.Choices = append(res.Choices, CompletionChoice{ - Text: completionStr, - Index: i, - }) - } - inputTokens := numTokens(completionReq.Prompt) * completionReq.N - completionTokens := completionReq.MaxTokens * completionReq.N - res.Usage = Usage{ - PromptTokens: inputTokens, - CompletionTokens: completionTokens, - TotalTokens: inputTokens + completionTokens, - } - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) -} - -// getCompletionBody Returns the body of the request to create a completion. -func getCompletionBody(r *http.Request) (CompletionRequest, error) { - completion := CompletionRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return CompletionRequest{}, err - } - err = json.Unmarshal(reqBody, &completion) - if err != nil { - return CompletionRequest{}, err - } - return completion, nil -} diff --git a/go-gpt3/config.go b/go-gpt3/config.go index e09c256..236a2dd 100644 --- a/go-gpt3/config.go +++ b/go-gpt3/config.go @@ -1,11 +1,11 @@ -package openai +package gogpt import ( "net/http" ) const ( - apiURLv1 = "https://api.openai.com/v1" + apiURLv1 = "https://api.gogpt.com/v1" defaultEmptyMessagesLimit uint = 300 ) diff --git a/go-gpt3/edits.go b/go-gpt3/edits.go index 858a8e5..265cfec 100644 --- a/go-gpt3/edits.go +++ b/go-gpt3/edits.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "context" diff --git a/go-gpt3/edits_test.go b/go-gpt3/edits_test.go deleted file mode 100644 index 6a16f7c..0000000 --- a/go-gpt3/edits_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package openai_test - -import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "testing" - "time" -) - -// TestEdits Tests the edits endpoint of the API using the mocked server. -func TestEdits(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/edits", handleEditEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - // create an edit request - model := "ada" - editReq := EditsRequest{ - Model: &model, - Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + - "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + - " ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" + - " ex ea commodo consequat. Duis aute irure dolor in reprehe", - Instruction: "test instruction", - N: 3, - } - response, err := client.Edits(ctx, editReq) - if err != nil { - t.Fatalf("Edits error: %v", err) - } - if len(response.Choices) != editReq.N { - t.Fatalf("edits does not properly return the correct number of choices") - } -} - -// handleEditEndpoint Handles the edit endpoint by the test server. -func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // edits only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - var editReq EditsRequest - editReq, err = getEditBody(r) - if err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - // create a response - res := EditsResponse{ - Object: "test-object", - Created: time.Now().Unix(), - } - // edit and calculate token usage - editString := "edited by mocked OpenAI server :)" - inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N - completionTokens := int(float32(len(editString))/4) * editReq.N - for i := 0; i < editReq.N; i++ { - // instruction will be hidden and only seen by OpenAI - res.Choices = append(res.Choices, EditsChoice{ - Text: editReq.Input + editString, - Index: i, - }) - } - res.Usage = Usage{ - PromptTokens: inputTokens, - CompletionTokens: completionTokens, - TotalTokens: inputTokens + completionTokens, - } - resBytes, _ = json.Marshal(res) - fmt.Fprint(w, string(resBytes)) -} - -// getEditBody Returns the body of the request to create an edit. -func getEditBody(r *http.Request) (EditsRequest, error) { - edit := EditsRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return EditsRequest{}, err - } - err = json.Unmarshal(reqBody, &edit) - if err != nil { - return EditsRequest{}, err - } - return edit, nil -} diff --git a/go-gpt3/embeddings.go b/go-gpt3/embeddings.go index 2deaccc..01bb090 100644 --- a/go-gpt3/embeddings.go +++ b/go-gpt3/embeddings.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "context" @@ -130,7 +130,7 @@ type EmbeddingRequest struct { } // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. -// https://beta.openai.com/docs/api-reference/embeddings/create +// https://beta.gogpt.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request) if err != nil { diff --git a/go-gpt3/embeddings_test.go b/go-gpt3/embeddings_test.go deleted file mode 100644 index 2aa48c5..0000000 --- a/go-gpt3/embeddings_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package openai_test - -import ( - . "github.com/sashabaranov/go-openai" - - "bytes" - "encoding/json" - "testing" -) - -func TestEmbedding(t *testing.T) { - embeddedModels := []EmbeddingModel{ - AdaSimilarity, - BabbageSimilarity, - CurieSimilarity, - DavinciSimilarity, - AdaSearchDocument, - AdaSearchQuery, - BabbageSearchDocument, - BabbageSearchQuery, - CurieSearchDocument, - CurieSearchQuery, - DavinciSearchDocument, - DavinciSearchQuery, - AdaCodeSearchCode, - AdaCodeSearchText, - BabbageCodeSearchCode, - BabbageCodeSearchText, - } - for _, model := range embeddedModels { - embeddingReq := EmbeddingRequest{ - Input: []string{ - "The food was delicious and the waiter", - "Other examples of embedding request", - }, - Model: model, - } - // marshal embeddingReq to JSON and confirm that the model field equals - // the AdaSearchQuery type - marshaled, err := json.Marshal(embeddingReq) - if err != nil { - t.Fatalf("Could not marshal embedding request: %v", err) - } - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { - t.Fatalf("Expected embedding request to contain model field") - } - } -} diff --git a/go-gpt3/engines.go b/go-gpt3/engines.go index bb6a66c..019de1a 100644 --- a/go-gpt3/engines.go +++ b/go-gpt3/engines.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "context" diff --git a/go-gpt3/error.go b/go-gpt3/error.go index d041da2..927fafd 100644 --- a/go-gpt3/error.go +++ b/go-gpt3/error.go @@ -1,4 +1,4 @@ -package openai +package gogpt import "fmt" diff --git a/go-gpt3/files.go b/go-gpt3/files.go index ec441c3..3716453 100644 --- a/go-gpt3/files.go +++ b/go-gpt3/files.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "bytes" diff --git a/go-gpt3/files_test.go b/go-gpt3/files_test.go deleted file mode 100644 index 6a78ce1..0000000 --- a/go-gpt3/files_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package openai_test - -import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "encoding/json" - "fmt" - "net/http" - "strconv" - "testing" - "time" -) - -func TestFileUpload(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/files", handleCreateFile) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - req := FileRequest{ - FileName: "test.go", - FilePath: "api.go", - Purpose: "fine-tune", - } - _, err = client.CreateFile(ctx, req) - if err != nil { - t.Fatalf("CreateFile error: %v", err) - } -} - -// handleCreateFile Handles the images endpoint by the test server. -func handleCreateFile(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // edits only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - err = r.ParseMultipartForm(1024 * 1024 * 1024) - if err != nil { - http.Error(w, "file is more than 1GB", http.StatusInternalServerError) - return - } - - values := r.Form - var purpose string - for key, value := range values { - if key == "purpose" { - purpose = value[0] - } - } - file, header, err := r.FormFile("file") - if err != nil { - return - } - defer file.Close() - - var fileReq = File{ - Bytes: int(header.Size), - ID: strconv.Itoa(int(time.Now().Unix())), - FileName: header.Filename, - Purpose: purpose, - CreatedAt: time.Now().Unix(), - Object: "test-objecct", - Owner: "test-owner", - } - - resBytes, _ = json.Marshal(fileReq) - fmt.Fprint(w, string(resBytes)) -} diff --git a/go-gpt3/fine_tunes.go b/go-gpt3/fine_tunes.go index a121867..e48f5ba 100644 --- a/go-gpt3/fine_tunes.go +++ b/go-gpt3/fine_tunes.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "context" diff --git a/go-gpt3/fine_tunes_test.go b/go-gpt3/fine_tunes_test.go deleted file mode 100644 index 1f6f967..0000000 --- a/go-gpt3/fine_tunes_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package openai_test - -import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "encoding/json" - "fmt" - "net/http" - "testing" -) - -const testFineTuneID = "fine-tune-id" - -// TestFineTunes Tests the fine tunes endpoint of the API using the mocked server. -func TestFineTunes(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler( - "/v1/fine-tunes", - func(w http.ResponseWriter, r *http.Request) { - var resBytes []byte - if r.Method == http.MethodGet { - resBytes, _ = json.Marshal(FineTuneList{}) - } else { - resBytes, _ = json.Marshal(FineTune{}) - } - fmt.Fprintln(w, string(resBytes)) - }, - ) - - server.RegisterHandler( - "/v1/fine-tunes/"+testFineTuneID+"/cancel", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTune{}) - fmt.Fprintln(w, string(resBytes)) - }, - ) - - server.RegisterHandler( - "/v1/fine-tunes/"+testFineTuneID, - func(w http.ResponseWriter, r *http.Request) { - var resBytes []byte - if r.Method == http.MethodDelete { - resBytes, _ = json.Marshal(FineTuneDeleteResponse{}) - } else { - resBytes, _ = json.Marshal(FineTune{}) - } - fmt.Fprintln(w, string(resBytes)) - }, - ) - - server.RegisterHandler( - "/v1/fine-tunes/"+testFineTuneID+"/events", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuneEventList{}) - fmt.Fprintln(w, string(resBytes)) - }, - ) - - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err = client.ListFineTunes(ctx) - if err != nil { - t.Fatalf("ListFineTunes error: %v", err) - } - - _, err = client.CreateFineTune(ctx, FineTuneRequest{}) - if err != nil { - t.Fatalf("CreateFineTune error: %v", err) - } - - _, err = client.CancelFineTune(ctx, testFineTuneID) - if err != nil { - t.Fatalf("CancelFineTune error: %v", err) - } - - _, err = client.GetFineTune(ctx, testFineTuneID) - if err != nil { - t.Fatalf("GetFineTune error: %v", err) - } - - _, err = client.DeleteFineTune(ctx, testFineTuneID) - if err != nil { - t.Fatalf("DeleteFineTune error: %v", err) - } - - _, err = client.ListFineTuneEvents(ctx, testFineTuneID) - if err != nil { - t.Fatalf("ListFineTuneEvents error: %v", err) - } -} diff --git a/go-gpt3/go.mod b/go-gpt3/go.mod index 42cc7b3..2c6dc62 100644 --- a/go-gpt3/go.mod +++ b/go-gpt3/go.mod @@ -1,3 +1,3 @@ -module github.com/sashabaranov/go-openai +module github.com/sashabaranov/go-gogpt go 1.18 diff --git a/go-gpt3/image.go b/go-gpt3/image.go index c0dfa64..188eb58 100644 --- a/go-gpt3/image.go +++ b/go-gpt3/image.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "bytes" @@ -147,7 +147,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } writer.Close() - //https://platform.openai.com/docs/api-reference/images/create-variation + //https://platform.gogpt.com/docs/api-reference/images/create-variation urlSuffix := "/images/variations" req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) if err != nil { diff --git a/go-gpt3/image_test.go b/go-gpt3/image_test.go deleted file mode 100644 index b7949c8..0000000 --- a/go-gpt3/image_test.go +++ /dev/null @@ -1,268 +0,0 @@ -package openai_test - -import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "testing" - "time" -) - -func TestImages(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/generations", handleImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - req := ImageRequest{} - req.Prompt = "Lorem ipsum" - _, err = client.CreateImage(ctx, req) - if err != nil { - t.Fatalf("CreateImage error: %v", err) - } -} - -// handleImageEndpoint Handles the images endpoint by the test server. -func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // imagess only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - var imageReq ImageRequest - if imageReq, err = getImageBody(r); err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - res := ImageResponse{ - Created: time.Now().Unix(), - } - for i := 0; i < imageReq.N; i++ { - imageData := ImageResponseDataInner{} - switch imageReq.ResponseFormat { - case CreateImageResponseFormatURL, "": - imageData.URL = "https://example.com/image.png" - case CreateImageResponseFormatB64JSON: - // This decodes to "{}" in base64. - imageData.B64JSON = "e30K" - default: - http.Error(w, "invalid response format", http.StatusBadRequest) - return - } - res.Data = append(res.Data, imageData) - } - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) -} - -// getImageBody Returns the body of the request to create a image. -func getImageBody(r *http.Request) (ImageRequest, error) { - image := ImageRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return ImageRequest{}, err - } - err = json.Unmarshal(reqBody, &image) - if err != nil { - return ImageRequest{}, err - } - return image, nil -} - -func TestImageEdit(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - origin, err := os.Create("image.png") - if err != nil { - t.Error("open origin file error") - return - } - - mask, err := os.Create("mask.png") - if err != nil { - t.Error("open mask file error") - return - } - - defer func() { - mask.Close() - origin.Close() - os.Remove("mask.png") - os.Remove("image.png") - }() - - req := ImageEditRequest{ - Image: origin, - Mask: mask, - Prompt: "There is a turtle in the pool", - N: 3, - Size: CreateImageSize1024x1024, - } - _, err = client.CreateEditImage(ctx, req) - if err != nil { - t.Fatalf("CreateImage error: %v", err) - } -} - -func TestImageEditWithoutMask(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - origin, err := os.Create("image.png") - if err != nil { - t.Error("open origin file error") - return - } - - defer func() { - origin.Close() - os.Remove("image.png") - }() - - req := ImageEditRequest{ - Image: origin, - Prompt: "There is a turtle in the pool", - N: 3, - Size: CreateImageSize1024x1024, - } - _, err = client.CreateEditImage(ctx, req) - if err != nil { - t.Fatalf("CreateImage error: %v", err) - } -} - -// handleEditImageEndpoint Handles the images endpoint by the test server. -func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { - var resBytes []byte - - // imagess only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - - responses := ImageResponse{ - Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ - { - URL: "test-url1", - B64JSON: "", - }, - { - URL: "test-url2", - B64JSON: "", - }, - { - URL: "test-url3", - B64JSON: "", - }, - }, - } - - resBytes, _ = json.Marshal(responses) - fmt.Fprintln(w, string(resBytes)) -} - -func TestImageVariation(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - origin, err := os.Create("image.png") - if err != nil { - t.Error("open origin file error") - return - } - - defer func() { - origin.Close() - os.Remove("image.png") - }() - - req := ImageVariRequest{ - Image: origin, - N: 3, - Size: CreateImageSize1024x1024, - } - _, err = client.CreateVariImage(ctx, req) - if err != nil { - t.Fatalf("CreateImage error: %v", err) - } -} - -// handleVariateImageEndpoint Handles the images endpoint by the test server. -func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { - var resBytes []byte - - // imagess only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - - responses := ImageResponse{ - Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ - { - URL: "test-url1", - B64JSON: "", - }, - { - URL: "test-url2", - B64JSON: "", - }, - { - URL: "test-url3", - B64JSON: "", - }, - }, - } - - resBytes, _ = json.Marshal(responses) - fmt.Fprintln(w, string(resBytes)) -} diff --git a/go-gpt3/marshaller.go b/go-gpt3/marshaller.go index 308ccd1..651514e 100644 --- a/go-gpt3/marshaller.go +++ b/go-gpt3/marshaller.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "encoding/json" diff --git a/go-gpt3/models.go b/go-gpt3/models.go index 2be91aa..71d3553 100644 --- a/go-gpt3/models.go +++ b/go-gpt3/models.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "context" diff --git a/go-gpt3/models_test.go b/go-gpt3/models_test.go deleted file mode 100644 index c96ece8..0000000 --- a/go-gpt3/models_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package openai_test - -import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "encoding/json" - "fmt" - "net/http" - "testing" -) - -// TestListModels Tests the models endpoint of the API using the mocked server. -func TestListModels(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/models", handleModelsEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err = client.ListModels(ctx) - if err != nil { - t.Fatalf("ListModels error: %v", err) - } -} - -// handleModelsEndpoint Handles the models endpoint by the test server. -func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(ModelsList{}) - fmt.Fprintln(w, string(resBytes)) -} diff --git a/go-gpt3/moderation.go b/go-gpt3/moderation.go index ff789a6..87fbd57 100644 --- a/go-gpt3/moderation.go +++ b/go-gpt3/moderation.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "context" diff --git a/go-gpt3/moderation_test.go b/go-gpt3/moderation_test.go deleted file mode 100644 index f501245..0000000 --- a/go-gpt3/moderation_test.go +++ /dev/null @@ -1,102 +0,0 @@ -package openai_test - -import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "testing" - "time" -) - -// TestModeration Tests the moderations endpoint of the API using the mocked server. -func TestModerations(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/moderations", handleModerationEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - // create an edit request - model := "text-moderation-stable" - moderationReq := ModerationRequest{ - Model: &model, - Input: "I want to kill them.", - } - _, err = client.Moderations(ctx, moderationReq) - if err != nil { - t.Fatalf("Moderation error: %v", err) - } -} - -// handleModerationEndpoint Handles the moderation endpoint by the test server. -func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // completions only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - var moderationReq ModerationRequest - if moderationReq, err = getModerationBody(r); err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - - resCat := ResultCategories{} - resCatScore := ResultCategoryScores{} - switch { - case strings.Contains(moderationReq.Input, "kill"): - resCat = ResultCategories{Violence: true} - resCatScore = ResultCategoryScores{Violence: 1} - case strings.Contains(moderationReq.Input, "hate"): - resCat = ResultCategories{Hate: true} - resCatScore = ResultCategoryScores{Hate: 1} - case strings.Contains(moderationReq.Input, "suicide"): - resCat = ResultCategories{SelfHarm: true} - resCatScore = ResultCategoryScores{SelfHarm: 1} - case strings.Contains(moderationReq.Input, "porn"): - resCat = ResultCategories{Sexual: true} - resCatScore = ResultCategoryScores{Sexual: 1} - } - - result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} - - res := ModerationResponse{ - ID: strconv.Itoa(int(time.Now().Unix())), - Model: *moderationReq.Model, - } - res.Results = append(res.Results, result) - - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) -} - -// getModerationBody Returns the body of the request to do a moderation. -func getModerationBody(r *http.Request) (ModerationRequest, error) { - moderation := ModerationRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return ModerationRequest{}, err - } - err = json.Unmarshal(reqBody, &moderation) - if err != nil { - return ModerationRequest{}, err - } - return moderation, nil -} diff --git a/go-gpt3/request_builder.go b/go-gpt3/request_builder.go index f0cef10..c505d16 100644 --- a/go-gpt3/request_builder.go +++ b/go-gpt3/request_builder.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "bytes" diff --git a/go-gpt3/request_builder_test.go b/go-gpt3/request_builder_test.go deleted file mode 100644 index 533977a..0000000 --- a/go-gpt3/request_builder_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package openai //nolint:testpackage // testing private field - -import ( - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "errors" - "net/http" - "testing" -) - -var ( - errTestMarshallerFailed = errors.New("test marshaller failed") - errTestRequestBuilderFailed = errors.New("test request builder failed") -) - -type ( - failingRequestBuilder struct{} - failingMarshaller struct{} -) - -func (*failingMarshaller) marshal(value any) ([]byte, error) { - return []byte{}, errTestMarshallerFailed -} - -func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) { - return nil, errTestRequestBuilderFailed -} - -func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { - builder := httpRequestBuilder{ - marshaller: &failingMarshaller{}, - } - - _, err := builder.build(context.Background(), "", "", struct{}{}) - if !errors.Is(err, errTestMarshallerFailed) { - t.Fatalf("Did not return error when marshaller failed: %v", err) - } -} - -func TestClientReturnsRequestBuilderErrors(t *testing.T) { - var err error - ts := test.NewTestServer().OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - client.requestBuilder = &failingRequestBuilder{} - - ctx := context.Background() - - _, err = client.CreateCompletion(ctx, CompletionRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateFineTune(ctx, FineTuneRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFineTunes(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CancelFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.DeleteFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFineTuneEvents(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.Moderations(ctx, ModerationRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.Edits(ctx, EditsRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateImage(ctx, ImageRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - err = client.DeleteFile(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetFile(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFiles(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListEngines(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetEngine(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListModels(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } -} diff --git a/go-gpt3/stream.go b/go-gpt3/stream.go index 0eed4aa..4745b47 100644 --- a/go-gpt3/stream.go +++ b/go-gpt3/stream.go @@ -1,4 +1,4 @@ -package openai +package gogpt import ( "bufio" diff --git a/go-gpt3/stream_test.go b/go-gpt3/stream_test.go deleted file mode 100644 index 8f89e6b..0000000 --- a/go-gpt3/stream_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package openai_test - -import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "errors" - "io" - "net/http" - "net/http/httptest" - "testing" -) - -func TestCreateCompletionStream(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - - // Send test responses - dataBytes := []byte{} - dataBytes = append(dataBytes, []byte("event: message\n")...) - //nolint:lll - data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - - dataBytes = append(dataBytes, []byte("event: message\n")...) - //nolint:lll - data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - - dataBytes = append(dataBytes, []byte("event: done\n")...) - dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) - - _, err := w.Write(dataBytes) - if err != nil { - t.Errorf("Write error: %s", err) - } - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() - - request := CompletionRequest{ - Prompt: "Ex falso quodlibet", - Model: "text-davinci-002", - MaxTokens: 10, - Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) - if err != nil { - t.Errorf("CreateCompletionStream returned error: %v", err) - } - defer stream.Close() - - expectedResponses := []CompletionResponse{ - { - ID: "1", - Object: "completion", - Created: 1598069254, - Model: "text-davinci-002", - Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, - }, - { - ID: "2", - Object: "completion", - Created: 1598069255, - Model: "text-davinci-002", - Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, - }, - } - - for ix, expectedResponse := range expectedResponses { - receivedResponse, streamErr := stream.Recv() - if streamErr != nil { - t.Errorf("stream.Recv() failed: %v", streamErr) - } - if !compareResponses(expectedResponse, receivedResponse) { - t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) - } - } - - _, streamErr := stream.Recv() - if !errors.Is(streamErr, io.EOF) { - t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) - } - - _, streamErr = stream.Recv() - if !errors.Is(streamErr, io.EOF) { - t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) - } -} - -// A "tokenRoundTripper" is a struct that implements the RoundTripper -// interface, specifically to handle the authentication token by adding a token -// to the request header. We need this because the API requires that each -// request include a valid API token in the headers for authentication and -// authorization. -type tokenRoundTripper struct { - token string - fallback http.RoundTripper -} - -// RoundTrip takes an *http.Request as input and returns an -// *http.Response and an error. -// -// It is expected to use the provided request to create a connection to an HTTP -// server and return the response, or an error if one occurred. The returned -// Response should have its Body closed. If the RoundTrip method returns an -// error, the Client's Get, Head, Post, and PostForm methods return the same -// error. -func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - req.Header.Set("Authorization", "Bearer "+t.token) - return t.fallback.RoundTrip(req) -} - -// Helper funcs. -func compareResponses(r1, r2 CompletionResponse) bool { - if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { - return false - } - if len(r1.Choices) != len(r2.Choices) { - return false - } - for i := range r1.Choices { - if !compareResponseChoices(r1.Choices[i], r2.Choices[i]) { - return false - } - } - return true -} - -func compareResponseChoices(c1, c2 CompletionChoice) bool { - if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason { - return false - } - return true -}