diff --git a/go-gpt3/.gitignore b/go-gpt3/.gitignore
index 99b40bf..708f197 100644
--- a/go-gpt3/.gitignore
+++ b/go-gpt3/.gitignore
@@ -15,5 +15,5 @@
# vendor/
# Auth token for tests
-.openai-token
+.gogpt-token
.idea
\ No newline at end of file
diff --git a/go-gpt3/README.md b/go-gpt3/README.md
index e6e352e..4f6a437 100644
--- a/go-gpt3/README.md
+++ b/go-gpt3/README.md
@@ -1,11 +1,11 @@
# Go OpenAI
-[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-openai)
-[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai)
-[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai)
+[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-gogpt)
+[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-gogpt)](https://goreportcard.com/report/github.com/sashabaranov/go-gogpt)
+[![codecov](https://codecov.io/gh/sashabaranov/go-gogpt/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-gogpt)
-> **Note**: the repository was recently renamed from `go-gpt3` to `go-openai`
+> **Note**: the repository was recently renamed from `go-gpt3` to `go-gogpt`
-This library provides Go clients for [OpenAI API](https://platform.openai.com/). We support:
+This library provides Go clients for [OpenAI API](https://platform.gogpt.com/). We support:
* ChatGPT
* GPT-3, GPT-4
@@ -14,7 +14,7 @@ This library provides Go clients for [OpenAI API](https://platform.openai.com/).
Installation:
```
-go get github.com/sashabaranov/go-openai
+go get github.com/sashabaranov/go-gogpt
```
@@ -26,18 +26,18 @@ package main
import (
"context"
"fmt"
- openai "github.com/sashabaranov/go-openai"
+ gogpt "github.com/sashabaranov/go-gogpt"
)
func main() {
- client := openai.NewClient("your token")
+ client := gogpt.NewClient("your token")
resp, err := client.CreateChatCompletion(
context.Background(),
- openai.ChatCompletionRequest{
- Model: openai.GPT3Dot5Turbo,
- Messages: []openai.ChatCompletionMessage{
+ gogpt.ChatCompletionRequest{
+ Model: gogpt.GPT3Dot5Turbo,
+ Messages: []gogpt.ChatCompletionMessage{
{
- Role: openai.ChatMessageRoleUser,
+ Role: gogpt.ChatMessageRoleUser,
Content: "Hello!",
},
},
@@ -67,15 +67,15 @@ package main
import (
"context"
"fmt"
- openai "github.com/sashabaranov/go-openai"
+ gogpt "github.com/sashabaranov/go-gogpt"
)
func main() {
- c := openai.NewClient("your token")
+ c := gogpt.NewClient("your token")
ctx := context.Background()
- req := openai.CompletionRequest{
- Model: openai.GPT3Ada,
+ req := gogpt.CompletionRequest{
+ Model: gogpt.GPT3Ada,
MaxTokens: 5,
Prompt: "Lorem ipsum",
}
@@ -100,15 +100,15 @@ import (
"context"
"fmt"
"io"
- openai "github.com/sashabaranov/go-openai"
+ gogpt "github.com/sashabaranov/go-gogpt"
)
func main() {
- c := openai.NewClient("your token")
+ c := gogpt.NewClient("your token")
ctx := context.Background()
- req := openai.CompletionRequest{
- Model: openai.GPT3Ada,
+ req := gogpt.CompletionRequest{
+ Model: gogpt.GPT3Ada,
MaxTokens: 5,
Prompt: "Lorem ipsum",
Stream: true,
@@ -149,15 +149,15 @@ import (
"context"
"fmt"
- openai "github.com/sashabaranov/go-openai"
+ gogpt "github.com/sashabaranov/go-gogpt"
)
func main() {
- c := openai.NewClient("your token")
+ c := gogpt.NewClient("your token")
ctx := context.Background()
- req := openai.AudioRequest{
- Model: openai.Whisper1,
+ req := gogpt.AudioRequest{
+ Model: gogpt.Whisper1,
FilePath: "recording.mp3",
}
resp, err := c.CreateTranscription(ctx, req)
@@ -181,20 +181,20 @@ import (
"context"
"encoding/base64"
"fmt"
- openai "github.com/sashabaranov/go-openai"
+ gogpt "github.com/sashabaranov/go-gogpt"
"image/png"
"os"
)
func main() {
- c := openai.NewClient("your token")
+ c := gogpt.NewClient("your token")
ctx := context.Background()
// Sample image by link
- reqUrl := openai.ImageRequest{
+ reqUrl := gogpt.ImageRequest{
Prompt: "Parrot on a skateboard performs a trick, cartoon style, natural light, high detail",
- Size: openai.CreateImageSize256x256,
- ResponseFormat: openai.CreateImageResponseFormatURL,
+ Size: gogpt.CreateImageSize256x256,
+ ResponseFormat: gogpt.CreateImageResponseFormatURL,
N: 1,
}
@@ -206,10 +206,10 @@ func main() {
fmt.Println(respUrl.Data[0].URL)
// Example image as base64
- reqBase64 := openai.ImageRequest{
+ reqBase64 := gogpt.ImageRequest{
Prompt: "Portrait of a humanoid parrot in a classic costume, high detail, realistic light, unreal engine",
- Size: openai.CreateImageSize256x256,
- ResponseFormat: openai.CreateImageResponseFormatB64JSON,
+ Size: gogpt.CreateImageSize256x256,
+ ResponseFormat: gogpt.CreateImageResponseFormatB64JSON,
N: 1,
}
@@ -254,7 +254,7 @@ func main() {
Configuring proxy
```go
-config := openai.DefaultConfig("token")
+config := gogpt.DefaultConfig("token")
proxyUrl, err := url.Parse("http://localhost:{port}")
if err != nil {
panic(err)
@@ -266,10 +266,10 @@ config.HTTPClient = &http.Client{
Transport: transport,
}
-c := openai.NewClientWithConfig(config)
+c := gogpt.NewClientWithConfig(config)
```
-See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig
+See also: https://pkg.go.dev/github.com/sashabaranov/go-gogpt#ClientConfig
@@ -285,12 +285,12 @@ import (
"os"
"strings"
- "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-gogpt"
)
func main() {
- client := openai.NewClient("your token")
- messages := make([]openai.ChatCompletionMessage, 0)
+ client := gogpt.NewClient("your token")
+ messages := make([]gogpt.ChatCompletionMessage, 0)
reader := bufio.NewReader(os.Stdin)
fmt.Println("Conversation")
fmt.Println("---------------------")
@@ -300,15 +300,15 @@ func main() {
text, _ := reader.ReadString('\n')
// convert CRLF to LF
text = strings.Replace(text, "\n", "", -1)
- messages = append(messages, openai.ChatCompletionMessage{
- Role: openai.ChatMessageRoleUser,
+ messages = append(messages, gogpt.ChatCompletionMessage{
+ Role: gogpt.ChatMessageRoleUser,
Content: text,
})
resp, err := client.CreateChatCompletion(
context.Background(),
- openai.ChatCompletionRequest{
- Model: openai.GPT3Dot5Turbo,
+ gogpt.ChatCompletionRequest{
+ Model: gogpt.GPT3Dot5Turbo,
Messages: messages,
},
)
@@ -319,8 +319,8 @@ func main() {
}
content := resp.Choices[0].Message.Content
- messages = append(messages, openai.ChatCompletionMessage{
- Role: openai.ChatMessageRoleAssistant,
+ messages = append(messages, gogpt.ChatCompletionMessage{
+ Role: gogpt.ChatMessageRoleAssistant,
Content: content,
})
fmt.Println(content)
diff --git a/go-gpt3/api.go b/go-gpt3/api.go
index 00d6d35..c36b184 100644
--- a/go-gpt3/api.go
+++ b/go-gpt3/api.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"context"
diff --git a/go-gpt3/api_test.go b/go-gpt3/api_test.go
deleted file mode 100644
index 202ec94..0000000
--- a/go-gpt3/api_test.go
+++ /dev/null
@@ -1,182 +0,0 @@
-package openai_test
-
-import (
- . "github.com/sashabaranov/go-openai"
-
- "context"
- "errors"
- "io"
- "os"
- "testing"
-)
-
-func TestAPI(t *testing.T) {
- apiToken := os.Getenv("OPENAI_TOKEN")
- if apiToken == "" {
- t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
- }
-
- var err error
- c := NewClient(apiToken)
- ctx := context.Background()
- _, err = c.ListEngines(ctx)
- if err != nil {
- t.Fatalf("ListEngines error: %v", err)
- }
-
- _, err = c.GetEngine(ctx, "davinci")
- if err != nil {
- t.Fatalf("GetEngine error: %v", err)
- }
-
- fileRes, err := c.ListFiles(ctx)
- if err != nil {
- t.Fatalf("ListFiles error: %v", err)
- }
-
- if len(fileRes.Files) > 0 {
- _, err = c.GetFile(ctx, fileRes.Files[0].ID)
- if err != nil {
- t.Fatalf("GetFile error: %v", err)
- }
- } // else skip
-
- embeddingReq := EmbeddingRequest{
- Input: []string{
- "The food was delicious and the waiter",
- "Other examples of embedding request",
- },
- Model: AdaSearchQuery,
- }
- _, err = c.CreateEmbeddings(ctx, embeddingReq)
- if err != nil {
- t.Fatalf("Embedding error: %v", err)
- }
-
- _, err = c.CreateChatCompletion(
- ctx,
- ChatCompletionRequest{
- Model: GPT3Dot5Turbo,
- Messages: []ChatCompletionMessage{
- {
- Role: ChatMessageRoleUser,
- Content: "Hello!",
- },
- },
- },
- )
-
- if err != nil {
- t.Errorf("CreateChatCompletion (without name) returned error: %v", err)
- }
-
- _, err = c.CreateChatCompletion(
- ctx,
- ChatCompletionRequest{
- Model: GPT3Dot5Turbo,
- Messages: []ChatCompletionMessage{
- {
- Role: ChatMessageRoleUser,
- Name: "John_Doe",
- Content: "Hello!",
- },
- },
- },
- )
-
- if err != nil {
- t.Errorf("CreateChatCompletion (with name) returned error: %v", err)
- }
-
- stream, err := c.CreateCompletionStream(ctx, CompletionRequest{
- Prompt: "Ex falso quodlibet",
- Model: GPT3Ada,
- MaxTokens: 5,
- Stream: true,
- })
- if err != nil {
- t.Errorf("CreateCompletionStream returned error: %v", err)
- }
- defer stream.Close()
-
- counter := 0
- for {
- _, err = stream.Recv()
- if err != nil {
- if errors.Is(err, io.EOF) {
- break
- }
- t.Errorf("Stream error: %v", err)
- } else {
- counter++
- }
- }
- if counter == 0 {
- t.Error("Stream did not return any responses")
- }
-}
-
-func TestAPIError(t *testing.T) {
- apiToken := os.Getenv("OPENAI_TOKEN")
- if apiToken == "" {
- t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
- }
-
- var err error
- c := NewClient(apiToken + "_invalid")
- ctx := context.Background()
- _, err = c.ListEngines(ctx)
- if err == nil {
- t.Fatal("ListEngines did not fail")
- }
-
- var apiErr *APIError
- if !errors.As(err, &apiErr) {
- t.Fatalf("Error is not an APIError: %+v", err)
- }
-
- if apiErr.StatusCode != 401 {
- t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode)
- }
- if *apiErr.Code != "invalid_api_key" {
- t.Fatalf("Unexpected API error code: %s", *apiErr.Code)
- }
- if apiErr.Error() == "" {
- t.Fatal("Empty error message occured")
- }
-}
-
-func TestRequestError(t *testing.T) {
- var err error
-
- config := DefaultConfig("dummy")
- config.BaseURL = "https://httpbin.org/status/418?"
- c := NewClientWithConfig(config)
- ctx := context.Background()
- _, err = c.ListEngines(ctx)
- if err == nil {
- t.Fatal("ListEngines request did not fail")
- }
-
- var reqErr *RequestError
- if !errors.As(err, &reqErr) {
- t.Fatalf("Error is not a RequestError: %+v", err)
- }
-
- if reqErr.StatusCode != 418 {
- t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode)
- }
-
- if reqErr.Unwrap() == nil {
- t.Fatalf("Empty request error occured")
- }
-}
-
-// numTokens Returns the number of GPT-3 encoded tokens in the given text.
-// This function approximates based on the rule of thumb stated by OpenAI:
-// https://beta.openai.com/tokenizer
-//
-// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
-func numTokens(s string) int {
- return int(float32(len(s)) / 4)
-}
diff --git a/go-gpt3/audio.go b/go-gpt3/audio.go
index 54bd32f..0dc611e 100644
--- a/go-gpt3/audio.go
+++ b/go-gpt3/audio.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"bytes"
diff --git a/go-gpt3/audio_test.go b/go-gpt3/audio_test.go
deleted file mode 100644
index 2a035c9..0000000
--- a/go-gpt3/audio_test.go
+++ /dev/null
@@ -1,143 +0,0 @@
-package openai_test
-
-import (
- "bytes"
- "errors"
- "io"
- "mime"
- "mime/multipart"
- "net/http"
- "os"
- "path/filepath"
- "strings"
-
- . "github.com/sashabaranov/go-openai"
- "github.com/sashabaranov/go-openai/internal/test"
-
- "context"
- "testing"
-)
-
-// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
-func TestAudio(t *testing.T) {
- server := test.NewTestServer()
- server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
- server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
- // create the test server
- var err error
- ts := server.OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
-
- testcases := []struct {
- name string
- createFn func(context.Context, AudioRequest) (AudioResponse, error)
- }{
- {
- "transcribe",
- client.CreateTranscription,
- },
- {
- "translate",
- client.CreateTranslation,
- },
- }
-
- ctx := context.Background()
-
- dir, cleanup := createTestDirectory(t)
- defer cleanup()
-
- for _, tc := range testcases {
- t.Run(tc.name, func(t *testing.T) {
- path := filepath.Join(dir, "fake.mp3")
- createTestFile(t, path)
-
- req := AudioRequest{
- FilePath: path,
- Model: "whisper-3",
- }
- _, err = tc.createFn(ctx, req)
- if err != nil {
- t.Fatalf("audio API error: %v", err)
- }
- })
- }
-}
-
-// createTestFile creates a fake file with "hello" as the content.
-func createTestFile(t *testing.T, path string) {
- file, err := os.Create(path)
- if err != nil {
- t.Fatalf("failed to create file %v", err)
- }
- if _, err = file.WriteString("hello"); err != nil {
- t.Fatalf("failed to write to file %v", err)
- }
- file.Close()
-}
-
-// createTestDirectory creates a temporary folder which will be deleted when cleanup is called.
-func createTestDirectory(t *testing.T) (path string, cleanup func()) {
- t.Helper()
-
- path, err := os.MkdirTemp(os.TempDir(), "")
- if err != nil {
- t.Fatal(err)
- }
-
- return path, func() { os.RemoveAll(path) }
-}
-
-// handleAudioEndpoint Handles the completion endpoint by the test server.
-func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
- var err error
-
- // audio endpoints only accept POST requests
- if r.Method != "POST" {
- http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
- }
-
- mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
- if err != nil {
- http.Error(w, "failed to parse media type", http.StatusBadRequest)
- return
- }
-
- if !strings.HasPrefix(mediaType, "multipart") {
- http.Error(w, "request is not multipart", http.StatusBadRequest)
- }
-
- boundary, ok := params["boundary"]
- if !ok {
- http.Error(w, "no boundary in params", http.StatusBadRequest)
- return
- }
-
- fileData := &bytes.Buffer{}
- mr := multipart.NewReader(r.Body, boundary)
- part, err := mr.NextPart()
- if err != nil && errors.Is(err, io.EOF) {
- http.Error(w, "error accessing file", http.StatusBadRequest)
- return
- }
- if _, err = io.Copy(fileData, part); err != nil {
- http.Error(w, "failed to copy file", http.StatusInternalServerError)
- return
- }
-
- if len(fileData.Bytes()) == 0 {
- w.WriteHeader(http.StatusInternalServerError)
- http.Error(w, "received empty file data", http.StatusBadRequest)
- return
- }
-
- if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
- http.Error(w, "failed to write body", http.StatusInternalServerError)
- return
- }
-}
diff --git a/go-gpt3/chat.go b/go-gpt3/chat.go
index 99edfe8..bd2b6f7 100644
--- a/go-gpt3/chat.go
+++ b/go-gpt3/chat.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"context"
@@ -23,8 +23,8 @@ type ChatCompletionMessage struct {
// This property isn't in the official documentation, but it's in
// the documentation for the official library for python:
- // - https://github.com/openai/openai-python/blob/main/chatml.md
- // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ // - https://github.com/gogpt/gogpt-python/blob/main/chatml.md
+ // - https://github.com/gogpt/gogpt-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
Name string `json:"name,omitempty"`
}
diff --git a/go-gpt3/chat_stream.go b/go-gpt3/chat_stream.go
index 26e964c..6fc440b 100644
--- a/go-gpt3/chat_stream.go
+++ b/go-gpt3/chat_stream.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"bufio"
diff --git a/go-gpt3/chat_stream_test.go b/go-gpt3/chat_stream_test.go
deleted file mode 100644
index e3da2da..0000000
--- a/go-gpt3/chat_stream_test.go
+++ /dev/null
@@ -1,153 +0,0 @@
-package openai_test
-
-import (
- . "github.com/sashabaranov/go-openai"
- "github.com/sashabaranov/go-openai/internal/test"
-
- "context"
- "encoding/json"
- "errors"
- "io"
- "net/http"
- "net/http/httptest"
- "testing"
-)
-
-func TestCreateChatCompletionStream(t *testing.T) {
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
-
- // Send test responses
- dataBytes := []byte{}
- dataBytes = append(dataBytes, []byte("event: message\n")...)
- //nolint:lll
- data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}`
- dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
-
- dataBytes = append(dataBytes, []byte("event: message\n")...)
- //nolint:lll
- data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}`
- dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
-
- dataBytes = append(dataBytes, []byte("event: done\n")...)
- dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
-
- _, err := w.Write(dataBytes)
- if err != nil {
- t.Errorf("Write error: %s", err)
- }
- }))
- defer server.Close()
-
- // Client portion of the test
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = server.URL + "/v1"
- config.HTTPClient.Transport = &tokenRoundTripper{
- test.GetTestToken(),
- http.DefaultTransport,
- }
-
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- request := ChatCompletionRequest{
- MaxTokens: 5,
- Model: GPT3Dot5Turbo,
- Messages: []ChatCompletionMessage{
- {
- Role: ChatMessageRoleUser,
- Content: "Hello!",
- },
- },
- Stream: true,
- }
-
- stream, err := client.CreateChatCompletionStream(ctx, request)
- if err != nil {
- t.Errorf("CreateCompletionStream returned error: %v", err)
- }
- defer stream.Close()
-
- expectedResponses := []ChatCompletionStreamResponse{
- {
- ID: "1",
- Object: "completion",
- Created: 1598069254,
- Model: GPT3Dot5Turbo,
- Choices: []ChatCompletionStreamChoice{
- {
- Delta: ChatCompletionStreamChoiceDelta{
- Content: "response1",
- },
- FinishReason: "max_tokens",
- },
- },
- },
- {
- ID: "2",
- Object: "completion",
- Created: 1598069255,
- Model: GPT3Dot5Turbo,
- Choices: []ChatCompletionStreamChoice{
- {
- Delta: ChatCompletionStreamChoiceDelta{
- Content: "response2",
- },
- FinishReason: "max_tokens",
- },
- },
- },
- }
-
- for ix, expectedResponse := range expectedResponses {
- b, _ := json.Marshal(expectedResponse)
- t.Logf("%d: %s", ix, string(b))
-
- receivedResponse, streamErr := stream.Recv()
- if streamErr != nil {
- t.Errorf("stream.Recv() failed: %v", streamErr)
- }
- if !compareChatResponses(expectedResponse, receivedResponse) {
- t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
- }
- }
-
- _, streamErr := stream.Recv()
- if !errors.Is(streamErr, io.EOF) {
- t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
- }
-
- _, streamErr = stream.Recv()
- if !errors.Is(streamErr, io.EOF) {
- t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
- }
-}
-
-// Helper funcs.
-func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
- if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
- return false
- }
- if len(r1.Choices) != len(r2.Choices) {
- return false
- }
- for i := range r1.Choices {
- if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) {
- return false
- }
- }
- return true
-}
-
-func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool {
- if c1.Index != c2.Index {
- return false
- }
- if c1.Delta.Content != c2.Delta.Content {
- return false
- }
- if c1.FinishReason != c2.FinishReason {
- return false
- }
- return true
-}
diff --git a/go-gpt3/chat_test.go b/go-gpt3/chat_test.go
deleted file mode 100644
index 5c03ebf..0000000
--- a/go-gpt3/chat_test.go
+++ /dev/null
@@ -1,132 +0,0 @@
-package openai_test
-
-import (
- . "github.com/sashabaranov/go-openai"
- "github.com/sashabaranov/go-openai/internal/test"
-
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "strconv"
- "strings"
- "testing"
- "time"
-)
-
-func TestChatCompletionsWrongModel(t *testing.T) {
- config := DefaultConfig("whatever")
- config.BaseURL = "http://localhost/v1"
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- req := ChatCompletionRequest{
- MaxTokens: 5,
- Model: "ada",
- Messages: []ChatCompletionMessage{
- {
- Role: ChatMessageRoleUser,
- Content: "Hello!",
- },
- },
- }
- _, err := client.CreateChatCompletion(ctx, req)
- if !errors.Is(err, ErrChatCompletionInvalidModel) {
- t.Fatalf("CreateChatCompletion should return wrong model error, but returned: %v", err)
- }
-}
-
-// TestCompletions Tests the completions endpoint of the API using the mocked server.
-func TestChatCompletions(t *testing.T) {
- server := test.NewTestServer()
- server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
- // create the test server
- var err error
- ts := server.OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- req := ChatCompletionRequest{
- MaxTokens: 5,
- Model: GPT3Dot5Turbo,
- Messages: []ChatCompletionMessage{
- {
- Role: ChatMessageRoleUser,
- Content: "Hello!",
- },
- },
- }
- _, err = client.CreateChatCompletion(ctx, req)
- if err != nil {
- t.Fatalf("CreateChatCompletion error: %v", err)
- }
-}
-
-// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
-func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
- var err error
- var resBytes []byte
-
- // completions only accepts POST requests
- if r.Method != "POST" {
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
- }
- var completionReq ChatCompletionRequest
- if completionReq, err = getChatCompletionBody(r); err != nil {
- http.Error(w, "could not read request", http.StatusInternalServerError)
- return
- }
- res := ChatCompletionResponse{
- ID: strconv.Itoa(int(time.Now().Unix())),
- Object: "test-object",
- Created: time.Now().Unix(),
- // would be nice to validate Model during testing, but
- // this may not be possible with how much upkeep
- // would be required / wouldn't make much sense
- Model: completionReq.Model,
- }
- // create completions
- for i := 0; i < completionReq.N; i++ {
- // generate a random string of length completionReq.Length
- completionStr := strings.Repeat("a", completionReq.MaxTokens)
-
- res.Choices = append(res.Choices, ChatCompletionChoice{
- Message: ChatCompletionMessage{
- Role: ChatMessageRoleAssistant,
- Content: completionStr,
- },
- Index: i,
- })
- }
- inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N
- completionTokens := completionReq.MaxTokens * completionReq.N
- res.Usage = Usage{
- PromptTokens: inputTokens,
- CompletionTokens: completionTokens,
- TotalTokens: inputTokens + completionTokens,
- }
- resBytes, _ = json.Marshal(res)
- fmt.Fprintln(w, string(resBytes))
-}
-
-// getChatCompletionBody Returns the body of the request to create a completion.
-func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) {
- completion := ChatCompletionRequest{}
- // read the request body
- reqBody, err := io.ReadAll(r.Body)
- if err != nil {
- return ChatCompletionRequest{}, err
- }
- err = json.Unmarshal(reqBody, &completion)
- if err != nil {
- return ChatCompletionRequest{}, err
- }
- return completion, nil
-}
diff --git a/go-gpt3/common.go b/go-gpt3/common.go
index 3b555a7..9fb0178 100644
--- a/go-gpt3/common.go
+++ b/go-gpt3/common.go
@@ -1,5 +1,5 @@
// common.go defines common types used throughout the OpenAI API.
-package openai
+package gogpt
// Usage Represents the total token usage per request to OpenAI.
type Usage struct {
diff --git a/go-gpt3/completion.go b/go-gpt3/completion.go
index 66b4866..853c057 100644
--- a/go-gpt3/completion.go
+++ b/go-gpt3/completion.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"context"
diff --git a/go-gpt3/completion_test.go b/go-gpt3/completion_test.go
deleted file mode 100644
index 9868eb2..0000000
--- a/go-gpt3/completion_test.go
+++ /dev/null
@@ -1,121 +0,0 @@
-package openai_test
-
-import (
- . "github.com/sashabaranov/go-openai"
- "github.com/sashabaranov/go-openai/internal/test"
-
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "strconv"
- "strings"
- "testing"
- "time"
-)
-
-func TestCompletionsWrongModel(t *testing.T) {
- config := DefaultConfig("whatever")
- config.BaseURL = "http://localhost/v1"
- client := NewClientWithConfig(config)
-
- _, err := client.CreateCompletion(
- context.Background(),
- CompletionRequest{
- MaxTokens: 5,
- Model: GPT3Dot5Turbo,
- },
- )
- if !errors.Is(err, ErrCompletionUnsupportedModel) {
- t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err)
- }
-}
-
-// TestCompletions Tests the completions endpoint of the API using the mocked server.
-func TestCompletions(t *testing.T) {
- server := test.NewTestServer()
- server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
- // create the test server
- var err error
- ts := server.OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- req := CompletionRequest{
- MaxTokens: 5,
- Model: "ada",
- }
- req.Prompt = "Lorem ipsum"
- _, err = client.CreateCompletion(ctx, req)
- if err != nil {
- t.Fatalf("CreateCompletion error: %v", err)
- }
-}
-
-// handleCompletionEndpoint Handles the completion endpoint by the test server.
-func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
- var err error
- var resBytes []byte
-
- // completions only accepts POST requests
- if r.Method != "POST" {
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
- }
- var completionReq CompletionRequest
- if completionReq, err = getCompletionBody(r); err != nil {
- http.Error(w, "could not read request", http.StatusInternalServerError)
- return
- }
- res := CompletionResponse{
- ID: strconv.Itoa(int(time.Now().Unix())),
- Object: "test-object",
- Created: time.Now().Unix(),
- // would be nice to validate Model during testing, but
- // this may not be possible with how much upkeep
- // would be required / wouldn't make much sense
- Model: completionReq.Model,
- }
- // create completions
- for i := 0; i < completionReq.N; i++ {
- // generate a random string of length completionReq.Length
- completionStr := strings.Repeat("a", completionReq.MaxTokens)
- if completionReq.Echo {
- completionStr = completionReq.Prompt + completionStr
- }
- res.Choices = append(res.Choices, CompletionChoice{
- Text: completionStr,
- Index: i,
- })
- }
- inputTokens := numTokens(completionReq.Prompt) * completionReq.N
- completionTokens := completionReq.MaxTokens * completionReq.N
- res.Usage = Usage{
- PromptTokens: inputTokens,
- CompletionTokens: completionTokens,
- TotalTokens: inputTokens + completionTokens,
- }
- resBytes, _ = json.Marshal(res)
- fmt.Fprintln(w, string(resBytes))
-}
-
-// getCompletionBody Returns the body of the request to create a completion.
-func getCompletionBody(r *http.Request) (CompletionRequest, error) {
- completion := CompletionRequest{}
- // read the request body
- reqBody, err := io.ReadAll(r.Body)
- if err != nil {
- return CompletionRequest{}, err
- }
- err = json.Unmarshal(reqBody, &completion)
- if err != nil {
- return CompletionRequest{}, err
- }
- return completion, nil
-}
diff --git a/go-gpt3/config.go b/go-gpt3/config.go
index e09c256..236a2dd 100644
--- a/go-gpt3/config.go
+++ b/go-gpt3/config.go
@@ -1,11 +1,11 @@
-package openai
+package gogpt
import (
"net/http"
)
const (
- apiURLv1 = "https://api.openai.com/v1"
+ apiURLv1 = "https://api.gogpt.com/v1"
defaultEmptyMessagesLimit uint = 300
)
diff --git a/go-gpt3/edits.go b/go-gpt3/edits.go
index 858a8e5..265cfec 100644
--- a/go-gpt3/edits.go
+++ b/go-gpt3/edits.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"context"
diff --git a/go-gpt3/edits_test.go b/go-gpt3/edits_test.go
deleted file mode 100644
index 6a16f7c..0000000
--- a/go-gpt3/edits_test.go
+++ /dev/null
@@ -1,104 +0,0 @@
-package openai_test
-
-import (
- . "github.com/sashabaranov/go-openai"
- "github.com/sashabaranov/go-openai/internal/test"
-
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "testing"
- "time"
-)
-
-// TestEdits Tests the edits endpoint of the API using the mocked server.
-func TestEdits(t *testing.T) {
- server := test.NewTestServer()
- server.RegisterHandler("/v1/edits", handleEditEndpoint)
- // create the test server
- var err error
- ts := server.OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- // create an edit request
- model := "ada"
- editReq := EditsRequest{
- Model: &model,
- Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " +
- "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" +
- " ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" +
- " ex ea commodo consequat. Duis aute irure dolor in reprehe",
- Instruction: "test instruction",
- N: 3,
- }
- response, err := client.Edits(ctx, editReq)
- if err != nil {
- t.Fatalf("Edits error: %v", err)
- }
- if len(response.Choices) != editReq.N {
- t.Fatalf("edits does not properly return the correct number of choices")
- }
-}
-
-// handleEditEndpoint Handles the edit endpoint by the test server.
-func handleEditEndpoint(w http.ResponseWriter, r *http.Request) {
- var err error
- var resBytes []byte
-
- // edits only accepts POST requests
- if r.Method != "POST" {
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
- }
- var editReq EditsRequest
- editReq, err = getEditBody(r)
- if err != nil {
- http.Error(w, "could not read request", http.StatusInternalServerError)
- return
- }
- // create a response
- res := EditsResponse{
- Object: "test-object",
- Created: time.Now().Unix(),
- }
- // edit and calculate token usage
- editString := "edited by mocked OpenAI server :)"
- inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N
- completionTokens := int(float32(len(editString))/4) * editReq.N
- for i := 0; i < editReq.N; i++ {
- // instruction will be hidden and only seen by OpenAI
- res.Choices = append(res.Choices, EditsChoice{
- Text: editReq.Input + editString,
- Index: i,
- })
- }
- res.Usage = Usage{
- PromptTokens: inputTokens,
- CompletionTokens: completionTokens,
- TotalTokens: inputTokens + completionTokens,
- }
- resBytes, _ = json.Marshal(res)
- fmt.Fprint(w, string(resBytes))
-}
-
-// getEditBody Returns the body of the request to create an edit.
-func getEditBody(r *http.Request) (EditsRequest, error) {
- edit := EditsRequest{}
- // read the request body
- reqBody, err := io.ReadAll(r.Body)
- if err != nil {
- return EditsRequest{}, err
- }
- err = json.Unmarshal(reqBody, &edit)
- if err != nil {
- return EditsRequest{}, err
- }
- return edit, nil
-}
diff --git a/go-gpt3/embeddings.go b/go-gpt3/embeddings.go
index 2deaccc..01bb090 100644
--- a/go-gpt3/embeddings.go
+++ b/go-gpt3/embeddings.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"context"
@@ -130,7 +130,7 @@ type EmbeddingRequest struct {
}
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
-// https://beta.openai.com/docs/api-reference/embeddings/create
+// https://beta.gogpt.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request)
if err != nil {
diff --git a/go-gpt3/embeddings_test.go b/go-gpt3/embeddings_test.go
deleted file mode 100644
index 2aa48c5..0000000
--- a/go-gpt3/embeddings_test.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package openai_test
-
-import (
- . "github.com/sashabaranov/go-openai"
-
- "bytes"
- "encoding/json"
- "testing"
-)
-
-func TestEmbedding(t *testing.T) {
- embeddedModels := []EmbeddingModel{
- AdaSimilarity,
- BabbageSimilarity,
- CurieSimilarity,
- DavinciSimilarity,
- AdaSearchDocument,
- AdaSearchQuery,
- BabbageSearchDocument,
- BabbageSearchQuery,
- CurieSearchDocument,
- CurieSearchQuery,
- DavinciSearchDocument,
- DavinciSearchQuery,
- AdaCodeSearchCode,
- AdaCodeSearchText,
- BabbageCodeSearchCode,
- BabbageCodeSearchText,
- }
- for _, model := range embeddedModels {
- embeddingReq := EmbeddingRequest{
- Input: []string{
- "The food was delicious and the waiter",
- "Other examples of embedding request",
- },
- Model: model,
- }
- // marshal embeddingReq to JSON and confirm that the model field equals
- // the AdaSearchQuery type
- marshaled, err := json.Marshal(embeddingReq)
- if err != nil {
- t.Fatalf("Could not marshal embedding request: %v", err)
- }
- if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
- t.Fatalf("Expected embedding request to contain model field")
- }
- }
-}
diff --git a/go-gpt3/engines.go b/go-gpt3/engines.go
index bb6a66c..019de1a 100644
--- a/go-gpt3/engines.go
+++ b/go-gpt3/engines.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"context"
diff --git a/go-gpt3/error.go b/go-gpt3/error.go
index d041da2..927fafd 100644
--- a/go-gpt3/error.go
+++ b/go-gpt3/error.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import "fmt"
diff --git a/go-gpt3/files.go b/go-gpt3/files.go
index ec441c3..3716453 100644
--- a/go-gpt3/files.go
+++ b/go-gpt3/files.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"bytes"
diff --git a/go-gpt3/files_test.go b/go-gpt3/files_test.go
deleted file mode 100644
index 6a78ce1..0000000
--- a/go-gpt3/files_test.go
+++ /dev/null
@@ -1,81 +0,0 @@
-package openai_test
-
-import (
- . "github.com/sashabaranov/go-openai"
- "github.com/sashabaranov/go-openai/internal/test"
-
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "strconv"
- "testing"
- "time"
-)
-
-func TestFileUpload(t *testing.T) {
- server := test.NewTestServer()
- server.RegisterHandler("/v1/files", handleCreateFile)
- // create the test server
- var err error
- ts := server.OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- req := FileRequest{
- FileName: "test.go",
- FilePath: "api.go",
- Purpose: "fine-tune",
- }
- _, err = client.CreateFile(ctx, req)
- if err != nil {
- t.Fatalf("CreateFile error: %v", err)
- }
-}
-
-// handleCreateFile Handles the images endpoint by the test server.
-func handleCreateFile(w http.ResponseWriter, r *http.Request) {
- var err error
- var resBytes []byte
-
- // edits only accepts POST requests
- if r.Method != "POST" {
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
- }
- err = r.ParseMultipartForm(1024 * 1024 * 1024)
- if err != nil {
- http.Error(w, "file is more than 1GB", http.StatusInternalServerError)
- return
- }
-
- values := r.Form
- var purpose string
- for key, value := range values {
- if key == "purpose" {
- purpose = value[0]
- }
- }
- file, header, err := r.FormFile("file")
- if err != nil {
- return
- }
- defer file.Close()
-
- var fileReq = File{
- Bytes: int(header.Size),
- ID: strconv.Itoa(int(time.Now().Unix())),
- FileName: header.Filename,
- Purpose: purpose,
- CreatedAt: time.Now().Unix(),
- Object: "test-objecct",
- Owner: "test-owner",
- }
-
- resBytes, _ = json.Marshal(fileReq)
- fmt.Fprint(w, string(resBytes))
-}
diff --git a/go-gpt3/fine_tunes.go b/go-gpt3/fine_tunes.go
index a121867..e48f5ba 100644
--- a/go-gpt3/fine_tunes.go
+++ b/go-gpt3/fine_tunes.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"context"
diff --git a/go-gpt3/fine_tunes_test.go b/go-gpt3/fine_tunes_test.go
deleted file mode 100644
index 1f6f967..0000000
--- a/go-gpt3/fine_tunes_test.go
+++ /dev/null
@@ -1,101 +0,0 @@
-package openai_test
-
-import (
- . "github.com/sashabaranov/go-openai"
- "github.com/sashabaranov/go-openai/internal/test"
-
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "testing"
-)
-
-const testFineTuneID = "fine-tune-id"
-
-// TestFineTunes Tests the fine tunes endpoint of the API using the mocked server.
-func TestFineTunes(t *testing.T) {
- server := test.NewTestServer()
- server.RegisterHandler(
- "/v1/fine-tunes",
- func(w http.ResponseWriter, r *http.Request) {
- var resBytes []byte
- if r.Method == http.MethodGet {
- resBytes, _ = json.Marshal(FineTuneList{})
- } else {
- resBytes, _ = json.Marshal(FineTune{})
- }
- fmt.Fprintln(w, string(resBytes))
- },
- )
-
- server.RegisterHandler(
- "/v1/fine-tunes/"+testFineTuneID+"/cancel",
- func(w http.ResponseWriter, r *http.Request) {
- resBytes, _ := json.Marshal(FineTune{})
- fmt.Fprintln(w, string(resBytes))
- },
- )
-
- server.RegisterHandler(
- "/v1/fine-tunes/"+testFineTuneID,
- func(w http.ResponseWriter, r *http.Request) {
- var resBytes []byte
- if r.Method == http.MethodDelete {
- resBytes, _ = json.Marshal(FineTuneDeleteResponse{})
- } else {
- resBytes, _ = json.Marshal(FineTune{})
- }
- fmt.Fprintln(w, string(resBytes))
- },
- )
-
- server.RegisterHandler(
- "/v1/fine-tunes/"+testFineTuneID+"/events",
- func(w http.ResponseWriter, r *http.Request) {
- resBytes, _ := json.Marshal(FineTuneEventList{})
- fmt.Fprintln(w, string(resBytes))
- },
- )
-
- // create the test server
- var err error
- ts := server.OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- _, err = client.ListFineTunes(ctx)
- if err != nil {
- t.Fatalf("ListFineTunes error: %v", err)
- }
-
- _, err = client.CreateFineTune(ctx, FineTuneRequest{})
- if err != nil {
- t.Fatalf("CreateFineTune error: %v", err)
- }
-
- _, err = client.CancelFineTune(ctx, testFineTuneID)
- if err != nil {
- t.Fatalf("CancelFineTune error: %v", err)
- }
-
- _, err = client.GetFineTune(ctx, testFineTuneID)
- if err != nil {
- t.Fatalf("GetFineTune error: %v", err)
- }
-
- _, err = client.DeleteFineTune(ctx, testFineTuneID)
- if err != nil {
- t.Fatalf("DeleteFineTune error: %v", err)
- }
-
- _, err = client.ListFineTuneEvents(ctx, testFineTuneID)
- if err != nil {
- t.Fatalf("ListFineTuneEvents error: %v", err)
- }
-}
diff --git a/go-gpt3/go.mod b/go-gpt3/go.mod
index 42cc7b3..2c6dc62 100644
--- a/go-gpt3/go.mod
+++ b/go-gpt3/go.mod
@@ -1,3 +1,3 @@
-module github.com/sashabaranov/go-openai
+module github.com/sashabaranov/go-gogpt
go 1.18
diff --git a/go-gpt3/image.go b/go-gpt3/image.go
index c0dfa64..188eb58 100644
--- a/go-gpt3/image.go
+++ b/go-gpt3/image.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"bytes"
@@ -147,7 +147,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
return
}
writer.Close()
- //https://platform.openai.com/docs/api-reference/images/create-variation
+ //https://platform.gogpt.com/docs/api-reference/images/create-variation
urlSuffix := "/images/variations"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
if err != nil {
diff --git a/go-gpt3/image_test.go b/go-gpt3/image_test.go
deleted file mode 100644
index b7949c8..0000000
--- a/go-gpt3/image_test.go
+++ /dev/null
@@ -1,268 +0,0 @@
-package openai_test
-
-import (
- . "github.com/sashabaranov/go-openai"
- "github.com/sashabaranov/go-openai/internal/test"
-
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "os"
- "testing"
- "time"
-)
-
-func TestImages(t *testing.T) {
- server := test.NewTestServer()
- server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
- // create the test server
- var err error
- ts := server.OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- req := ImageRequest{}
- req.Prompt = "Lorem ipsum"
- _, err = client.CreateImage(ctx, req)
- if err != nil {
- t.Fatalf("CreateImage error: %v", err)
- }
-}
-
-// handleImageEndpoint Handles the images endpoint by the test server.
-func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
- var err error
- var resBytes []byte
-
- // imagess only accepts POST requests
- if r.Method != "POST" {
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
- }
- var imageReq ImageRequest
- if imageReq, err = getImageBody(r); err != nil {
- http.Error(w, "could not read request", http.StatusInternalServerError)
- return
- }
- res := ImageResponse{
- Created: time.Now().Unix(),
- }
- for i := 0; i < imageReq.N; i++ {
- imageData := ImageResponseDataInner{}
- switch imageReq.ResponseFormat {
- case CreateImageResponseFormatURL, "":
- imageData.URL = "https://example.com/image.png"
- case CreateImageResponseFormatB64JSON:
- // This decodes to "{}" in base64.
- imageData.B64JSON = "e30K"
- default:
- http.Error(w, "invalid response format", http.StatusBadRequest)
- return
- }
- res.Data = append(res.Data, imageData)
- }
- resBytes, _ = json.Marshal(res)
- fmt.Fprintln(w, string(resBytes))
-}
-
-// getImageBody Returns the body of the request to create a image.
-func getImageBody(r *http.Request) (ImageRequest, error) {
- image := ImageRequest{}
- // read the request body
- reqBody, err := io.ReadAll(r.Body)
- if err != nil {
- return ImageRequest{}, err
- }
- err = json.Unmarshal(reqBody, &image)
- if err != nil {
- return ImageRequest{}, err
- }
- return image, nil
-}
-
-func TestImageEdit(t *testing.T) {
- server := test.NewTestServer()
- server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
- // create the test server
- var err error
- ts := server.OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- origin, err := os.Create("image.png")
- if err != nil {
- t.Error("open origin file error")
- return
- }
-
- mask, err := os.Create("mask.png")
- if err != nil {
- t.Error("open mask file error")
- return
- }
-
- defer func() {
- mask.Close()
- origin.Close()
- os.Remove("mask.png")
- os.Remove("image.png")
- }()
-
- req := ImageEditRequest{
- Image: origin,
- Mask: mask,
- Prompt: "There is a turtle in the pool",
- N: 3,
- Size: CreateImageSize1024x1024,
- }
- _, err = client.CreateEditImage(ctx, req)
- if err != nil {
- t.Fatalf("CreateImage error: %v", err)
- }
-}
-
-func TestImageEditWithoutMask(t *testing.T) {
- server := test.NewTestServer()
- server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
- // create the test server
- var err error
- ts := server.OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- origin, err := os.Create("image.png")
- if err != nil {
- t.Error("open origin file error")
- return
- }
-
- defer func() {
- origin.Close()
- os.Remove("image.png")
- }()
-
- req := ImageEditRequest{
- Image: origin,
- Prompt: "There is a turtle in the pool",
- N: 3,
- Size: CreateImageSize1024x1024,
- }
- _, err = client.CreateEditImage(ctx, req)
- if err != nil {
- t.Fatalf("CreateImage error: %v", err)
- }
-}
-
-// handleEditImageEndpoint Handles the images endpoint by the test server.
-func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
- var resBytes []byte
-
- // imagess only accepts POST requests
- if r.Method != "POST" {
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
- }
-
- responses := ImageResponse{
- Created: time.Now().Unix(),
- Data: []ImageResponseDataInner{
- {
- URL: "test-url1",
- B64JSON: "",
- },
- {
- URL: "test-url2",
- B64JSON: "",
- },
- {
- URL: "test-url3",
- B64JSON: "",
- },
- },
- }
-
- resBytes, _ = json.Marshal(responses)
- fmt.Fprintln(w, string(resBytes))
-}
-
-func TestImageVariation(t *testing.T) {
- server := test.NewTestServer()
- server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
- // create the test server
- var err error
- ts := server.OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- origin, err := os.Create("image.png")
- if err != nil {
- t.Error("open origin file error")
- return
- }
-
- defer func() {
- origin.Close()
- os.Remove("image.png")
- }()
-
- req := ImageVariRequest{
- Image: origin,
- N: 3,
- Size: CreateImageSize1024x1024,
- }
- _, err = client.CreateVariImage(ctx, req)
- if err != nil {
- t.Fatalf("CreateImage error: %v", err)
- }
-}
-
-// handleVariateImageEndpoint Handles the images endpoint by the test server.
-func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
- var resBytes []byte
-
- // imagess only accepts POST requests
- if r.Method != "POST" {
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
- }
-
- responses := ImageResponse{
- Created: time.Now().Unix(),
- Data: []ImageResponseDataInner{
- {
- URL: "test-url1",
- B64JSON: "",
- },
- {
- URL: "test-url2",
- B64JSON: "",
- },
- {
- URL: "test-url3",
- B64JSON: "",
- },
- },
- }
-
- resBytes, _ = json.Marshal(responses)
- fmt.Fprintln(w, string(resBytes))
-}
diff --git a/go-gpt3/marshaller.go b/go-gpt3/marshaller.go
index 308ccd1..651514e 100644
--- a/go-gpt3/marshaller.go
+++ b/go-gpt3/marshaller.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"encoding/json"
diff --git a/go-gpt3/models.go b/go-gpt3/models.go
index 2be91aa..71d3553 100644
--- a/go-gpt3/models.go
+++ b/go-gpt3/models.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"context"
diff --git a/go-gpt3/models_test.go b/go-gpt3/models_test.go
deleted file mode 100644
index c96ece8..0000000
--- a/go-gpt3/models_test.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package openai_test
-
-import (
- . "github.com/sashabaranov/go-openai"
- "github.com/sashabaranov/go-openai/internal/test"
-
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "testing"
-)
-
-// TestListModels Tests the models endpoint of the API using the mocked server.
-func TestListModels(t *testing.T) {
- server := test.NewTestServer()
- server.RegisterHandler("/v1/models", handleModelsEndpoint)
- // create the test server
- var err error
- ts := server.OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- _, err = client.ListModels(ctx)
- if err != nil {
- t.Fatalf("ListModels error: %v", err)
- }
-}
-
-// handleModelsEndpoint Handles the models endpoint by the test server.
-func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) {
- resBytes, _ := json.Marshal(ModelsList{})
- fmt.Fprintln(w, string(resBytes))
-}
diff --git a/go-gpt3/moderation.go b/go-gpt3/moderation.go
index ff789a6..87fbd57 100644
--- a/go-gpt3/moderation.go
+++ b/go-gpt3/moderation.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"context"
diff --git a/go-gpt3/moderation_test.go b/go-gpt3/moderation_test.go
deleted file mode 100644
index f501245..0000000
--- a/go-gpt3/moderation_test.go
+++ /dev/null
@@ -1,102 +0,0 @@
-package openai_test
-
-import (
- . "github.com/sashabaranov/go-openai"
- "github.com/sashabaranov/go-openai/internal/test"
-
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "strconv"
- "strings"
- "testing"
- "time"
-)
-
-// TestModeration Tests the moderations endpoint of the API using the mocked server.
-func TestModerations(t *testing.T) {
- server := test.NewTestServer()
- server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
- // create the test server
- var err error
- ts := server.OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- // create an edit request
- model := "text-moderation-stable"
- moderationReq := ModerationRequest{
- Model: &model,
- Input: "I want to kill them.",
- }
- _, err = client.Moderations(ctx, moderationReq)
- if err != nil {
- t.Fatalf("Moderation error: %v", err)
- }
-}
-
-// handleModerationEndpoint Handles the moderation endpoint by the test server.
-func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
- var err error
- var resBytes []byte
-
- // completions only accepts POST requests
- if r.Method != "POST" {
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
- }
- var moderationReq ModerationRequest
- if moderationReq, err = getModerationBody(r); err != nil {
- http.Error(w, "could not read request", http.StatusInternalServerError)
- return
- }
-
- resCat := ResultCategories{}
- resCatScore := ResultCategoryScores{}
- switch {
- case strings.Contains(moderationReq.Input, "kill"):
- resCat = ResultCategories{Violence: true}
- resCatScore = ResultCategoryScores{Violence: 1}
- case strings.Contains(moderationReq.Input, "hate"):
- resCat = ResultCategories{Hate: true}
- resCatScore = ResultCategoryScores{Hate: 1}
- case strings.Contains(moderationReq.Input, "suicide"):
- resCat = ResultCategories{SelfHarm: true}
- resCatScore = ResultCategoryScores{SelfHarm: 1}
- case strings.Contains(moderationReq.Input, "porn"):
- resCat = ResultCategories{Sexual: true}
- resCatScore = ResultCategoryScores{Sexual: 1}
- }
-
- result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true}
-
- res := ModerationResponse{
- ID: strconv.Itoa(int(time.Now().Unix())),
- Model: *moderationReq.Model,
- }
- res.Results = append(res.Results, result)
-
- resBytes, _ = json.Marshal(res)
- fmt.Fprintln(w, string(resBytes))
-}
-
-// getModerationBody Returns the body of the request to do a moderation.
-func getModerationBody(r *http.Request) (ModerationRequest, error) {
- moderation := ModerationRequest{}
- // read the request body
- reqBody, err := io.ReadAll(r.Body)
- if err != nil {
- return ModerationRequest{}, err
- }
- err = json.Unmarshal(reqBody, &moderation)
- if err != nil {
- return ModerationRequest{}, err
- }
- return moderation, nil
-}
diff --git a/go-gpt3/request_builder.go b/go-gpt3/request_builder.go
index f0cef10..c505d16 100644
--- a/go-gpt3/request_builder.go
+++ b/go-gpt3/request_builder.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"bytes"
diff --git a/go-gpt3/request_builder_test.go b/go-gpt3/request_builder_test.go
deleted file mode 100644
index 533977a..0000000
--- a/go-gpt3/request_builder_test.go
+++ /dev/null
@@ -1,148 +0,0 @@
-package openai //nolint:testpackage // testing private field
-
-import (
- "github.com/sashabaranov/go-openai/internal/test"
-
- "context"
- "errors"
- "net/http"
- "testing"
-)
-
-var (
- errTestMarshallerFailed = errors.New("test marshaller failed")
- errTestRequestBuilderFailed = errors.New("test request builder failed")
-)
-
-type (
- failingRequestBuilder struct{}
- failingMarshaller struct{}
-)
-
-func (*failingMarshaller) marshal(value any) ([]byte, error) {
- return []byte{}, errTestMarshallerFailed
-}
-
-func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) {
- return nil, errTestRequestBuilderFailed
-}
-
-func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
- builder := httpRequestBuilder{
- marshaller: &failingMarshaller{},
- }
-
- _, err := builder.build(context.Background(), "", "", struct{}{})
- if !errors.Is(err, errTestMarshallerFailed) {
- t.Fatalf("Did not return error when marshaller failed: %v", err)
- }
-}
-
-func TestClientReturnsRequestBuilderErrors(t *testing.T) {
- var err error
- ts := test.NewTestServer().OpenAITestServer()
- ts.Start()
- defer ts.Close()
-
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = ts.URL + "/v1"
- client := NewClientWithConfig(config)
- client.requestBuilder = &failingRequestBuilder{}
-
- ctx := context.Background()
-
- _, err = client.CreateCompletion(ctx, CompletionRequest{})
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.CreateFineTune(ctx, FineTuneRequest{})
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.ListFineTunes(ctx)
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.CancelFineTune(ctx, "")
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.GetFineTune(ctx, "")
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.DeleteFineTune(ctx, "")
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.ListFineTuneEvents(ctx, "")
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.Moderations(ctx, ModerationRequest{})
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.Edits(ctx, EditsRequest{})
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.CreateImage(ctx, ImageRequest{})
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- err = client.DeleteFile(ctx, "")
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.GetFile(ctx, "")
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.ListFiles(ctx)
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.ListEngines(ctx)
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.GetEngine(ctx, "")
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-
- _, err = client.ListModels(ctx)
- if !errors.Is(err, errTestRequestBuilderFailed) {
- t.Fatalf("Did not return error when request builder failed: %v", err)
- }
-}
diff --git a/go-gpt3/stream.go b/go-gpt3/stream.go
index 0eed4aa..4745b47 100644
--- a/go-gpt3/stream.go
+++ b/go-gpt3/stream.go
@@ -1,4 +1,4 @@
-package openai
+package gogpt
import (
"bufio"
diff --git a/go-gpt3/stream_test.go b/go-gpt3/stream_test.go
deleted file mode 100644
index 8f89e6b..0000000
--- a/go-gpt3/stream_test.go
+++ /dev/null
@@ -1,147 +0,0 @@
-package openai_test
-
-import (
- . "github.com/sashabaranov/go-openai"
- "github.com/sashabaranov/go-openai/internal/test"
-
- "context"
- "errors"
- "io"
- "net/http"
- "net/http/httptest"
- "testing"
-)
-
-func TestCreateCompletionStream(t *testing.T) {
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
-
- // Send test responses
- dataBytes := []byte{}
- dataBytes = append(dataBytes, []byte("event: message\n")...)
- //nolint:lll
- data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}`
- dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
-
- dataBytes = append(dataBytes, []byte("event: message\n")...)
- //nolint:lll
- data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}`
- dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
-
- dataBytes = append(dataBytes, []byte("event: done\n")...)
- dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
-
- _, err := w.Write(dataBytes)
- if err != nil {
- t.Errorf("Write error: %s", err)
- }
- }))
- defer server.Close()
-
- // Client portion of the test
- config := DefaultConfig(test.GetTestToken())
- config.BaseURL = server.URL + "/v1"
- config.HTTPClient.Transport = &tokenRoundTripper{
- test.GetTestToken(),
- http.DefaultTransport,
- }
-
- client := NewClientWithConfig(config)
- ctx := context.Background()
-
- request := CompletionRequest{
- Prompt: "Ex falso quodlibet",
- Model: "text-davinci-002",
- MaxTokens: 10,
- Stream: true,
- }
-
- stream, err := client.CreateCompletionStream(ctx, request)
- if err != nil {
- t.Errorf("CreateCompletionStream returned error: %v", err)
- }
- defer stream.Close()
-
- expectedResponses := []CompletionResponse{
- {
- ID: "1",
- Object: "completion",
- Created: 1598069254,
- Model: "text-davinci-002",
- Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}},
- },
- {
- ID: "2",
- Object: "completion",
- Created: 1598069255,
- Model: "text-davinci-002",
- Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
- },
- }
-
- for ix, expectedResponse := range expectedResponses {
- receivedResponse, streamErr := stream.Recv()
- if streamErr != nil {
- t.Errorf("stream.Recv() failed: %v", streamErr)
- }
- if !compareResponses(expectedResponse, receivedResponse) {
- t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
- }
- }
-
- _, streamErr := stream.Recv()
- if !errors.Is(streamErr, io.EOF) {
- t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
- }
-
- _, streamErr = stream.Recv()
- if !errors.Is(streamErr, io.EOF) {
- t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
- }
-}
-
-// A "tokenRoundTripper" is a struct that implements the RoundTripper
-// interface, specifically to handle the authentication token by adding a token
-// to the request header. We need this because the API requires that each
-// request include a valid API token in the headers for authentication and
-// authorization.
-type tokenRoundTripper struct {
- token string
- fallback http.RoundTripper
-}
-
-// RoundTrip takes an *http.Request as input and returns an
-// *http.Response and an error.
-//
-// It is expected to use the provided request to create a connection to an HTTP
-// server and return the response, or an error if one occurred. The returned
-// Response should have its Body closed. If the RoundTrip method returns an
-// error, the Client's Get, Head, Post, and PostForm methods return the same
-// error.
-func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
- req.Header.Set("Authorization", "Bearer "+t.token)
- return t.fallback.RoundTrip(req)
-}
-
-// Helper funcs.
-func compareResponses(r1, r2 CompletionResponse) bool {
- if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
- return false
- }
- if len(r1.Choices) != len(r2.Choices) {
- return false
- }
- for i := range r1.Choices {
- if !compareResponseChoices(r1.Choices[i], r2.Choices[i]) {
- return false
- }
- }
- return true
-}
-
-func compareResponseChoices(c1, c2 CompletionChoice) bool {
- if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason {
- return false
- }
- return true
-}