@@ -15,5 +15,5 @@ | |||
# vendor/ | |||
# Auth token for tests | |||
.openai-token | |||
.gogpt-token | |||
.idea |
@@ -1,11 +1,11 @@ | |||
# Go OpenAI | |||
[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-openai) | |||
[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai) | |||
[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai) | |||
[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-gogpt) | |||
[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-gogpt)](https://goreportcard.com/report/github.com/sashabaranov/go-gogpt) | |||
[![codecov](https://codecov.io/gh/sashabaranov/go-gogpt/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-gogpt) | |||
> **Note**: the repository was recently renamed from `go-gpt3` to `go-openai` | |||
> **Note**: the repository was recently renamed from `go-gpt3` to `go-gogpt` | |||
This library provides Go clients for [OpenAI API](https://platform.openai.com/). We support: | |||
This library provides Go clients for [OpenAI API](https://platform.gogpt.com/). We support: | |||
* ChatGPT | |||
* GPT-3, GPT-4 | |||
@@ -14,7 +14,7 @@ This library provides Go clients for [OpenAI API](https://platform.openai.com/). | |||
Installation: | |||
``` | |||
go get github.com/sashabaranov/go-openai | |||
go get github.com/sashabaranov/go-gogpt | |||
``` | |||
@@ -26,18 +26,18 @@ package main | |||
import ( | |||
"context" | |||
"fmt" | |||
openai "github.com/sashabaranov/go-openai" | |||
gogpt "github.com/sashabaranov/go-gogpt" | |||
) | |||
func main() { | |||
client := openai.NewClient("your token") | |||
client := gogpt.NewClient("your token") | |||
resp, err := client.CreateChatCompletion( | |||
context.Background(), | |||
openai.ChatCompletionRequest{ | |||
Model: openai.GPT3Dot5Turbo, | |||
Messages: []openai.ChatCompletionMessage{ | |||
gogpt.ChatCompletionRequest{ | |||
Model: gogpt.GPT3Dot5Turbo, | |||
Messages: []gogpt.ChatCompletionMessage{ | |||
{ | |||
Role: openai.ChatMessageRoleUser, | |||
Role: gogpt.ChatMessageRoleUser, | |||
Content: "Hello!", | |||
}, | |||
}, | |||
@@ -67,15 +67,15 @@ package main | |||
import ( | |||
"context" | |||
"fmt" | |||
openai "github.com/sashabaranov/go-openai" | |||
gogpt "github.com/sashabaranov/go-gogpt" | |||
) | |||
func main() { | |||
c := openai.NewClient("your token") | |||
c := gogpt.NewClient("your token") | |||
ctx := context.Background() | |||
req := openai.CompletionRequest{ | |||
Model: openai.GPT3Ada, | |||
req := gogpt.CompletionRequest{ | |||
Model: gogpt.GPT3Ada, | |||
MaxTokens: 5, | |||
Prompt: "Lorem ipsum", | |||
} | |||
@@ -100,15 +100,15 @@ import ( | |||
"context" | |||
"fmt" | |||
"io" | |||
openai "github.com/sashabaranov/go-openai" | |||
gogpt "github.com/sashabaranov/go-gogpt" | |||
) | |||
func main() { | |||
c := openai.NewClient("your token") | |||
c := gogpt.NewClient("your token") | |||
ctx := context.Background() | |||
req := openai.CompletionRequest{ | |||
Model: openai.GPT3Ada, | |||
req := gogpt.CompletionRequest{ | |||
Model: gogpt.GPT3Ada, | |||
MaxTokens: 5, | |||
Prompt: "Lorem ipsum", | |||
Stream: true, | |||
@@ -149,15 +149,15 @@ import ( | |||
"context" | |||
"fmt" | |||
openai "github.com/sashabaranov/go-openai" | |||
gogpt "github.com/sashabaranov/go-gogpt" | |||
) | |||
func main() { | |||
c := openai.NewClient("your token") | |||
c := gogpt.NewClient("your token") | |||
ctx := context.Background() | |||
req := openai.AudioRequest{ | |||
Model: openai.Whisper1, | |||
req := gogpt.AudioRequest{ | |||
Model: gogpt.Whisper1, | |||
FilePath: "recording.mp3", | |||
} | |||
resp, err := c.CreateTranscription(ctx, req) | |||
@@ -181,20 +181,20 @@ import ( | |||
"context" | |||
"encoding/base64" | |||
"fmt" | |||
openai "github.com/sashabaranov/go-openai" | |||
gogpt "github.com/sashabaranov/go-gogpt" | |||
"image/png" | |||
"os" | |||
) | |||
func main() { | |||
c := openai.NewClient("your token") | |||
c := gogpt.NewClient("your token") | |||
ctx := context.Background() | |||
// Sample image by link | |||
reqUrl := openai.ImageRequest{ | |||
reqUrl := gogpt.ImageRequest{ | |||
Prompt: "Parrot on a skateboard performs a trick, cartoon style, natural light, high detail", | |||
Size: openai.CreateImageSize256x256, | |||
ResponseFormat: openai.CreateImageResponseFormatURL, | |||
Size: gogpt.CreateImageSize256x256, | |||
ResponseFormat: gogpt.CreateImageResponseFormatURL, | |||
N: 1, | |||
} | |||
@@ -206,10 +206,10 @@ func main() { | |||
fmt.Println(respUrl.Data[0].URL) | |||
// Example image as base64 | |||
reqBase64 := openai.ImageRequest{ | |||
reqBase64 := gogpt.ImageRequest{ | |||
Prompt: "Portrait of a humanoid parrot in a classic costume, high detail, realistic light, unreal engine", | |||
Size: openai.CreateImageSize256x256, | |||
ResponseFormat: openai.CreateImageResponseFormatB64JSON, | |||
Size: gogpt.CreateImageSize256x256, | |||
ResponseFormat: gogpt.CreateImageResponseFormatB64JSON, | |||
N: 1, | |||
} | |||
@@ -254,7 +254,7 @@ func main() { | |||
<summary>Configuring proxy</summary> | |||
```go | |||
config := openai.DefaultConfig("token") | |||
config := gogpt.DefaultConfig("token") | |||
proxyUrl, err := url.Parse("http://localhost:{port}") | |||
if err != nil { | |||
panic(err) | |||
@@ -266,10 +266,10 @@ config.HTTPClient = &http.Client{ | |||
Transport: transport, | |||
} | |||
c := openai.NewClientWithConfig(config) | |||
c := gogpt.NewClientWithConfig(config) | |||
``` | |||
See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig | |||
See also: https://pkg.go.dev/github.com/sashabaranov/go-gogpt#ClientConfig | |||
</details> | |||
<details> | |||
@@ -285,12 +285,12 @@ import ( | |||
"os" | |||
"strings" | |||
"github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-gogpt" | |||
) | |||
func main() { | |||
client := openai.NewClient("your token") | |||
messages := make([]openai.ChatCompletionMessage, 0) | |||
client := gogpt.NewClient("your token") | |||
messages := make([]gogpt.ChatCompletionMessage, 0) | |||
reader := bufio.NewReader(os.Stdin) | |||
fmt.Println("Conversation") | |||
fmt.Println("---------------------") | |||
@@ -300,15 +300,15 @@ func main() { | |||
text, _ := reader.ReadString('\n') | |||
// convert CRLF to LF | |||
text = strings.Replace(text, "\n", "", -1) | |||
messages = append(messages, openai.ChatCompletionMessage{ | |||
Role: openai.ChatMessageRoleUser, | |||
messages = append(messages, gogpt.ChatCompletionMessage{ | |||
Role: gogpt.ChatMessageRoleUser, | |||
Content: text, | |||
}) | |||
resp, err := client.CreateChatCompletion( | |||
context.Background(), | |||
openai.ChatCompletionRequest{ | |||
Model: openai.GPT3Dot5Turbo, | |||
gogpt.ChatCompletionRequest{ | |||
Model: gogpt.GPT3Dot5Turbo, | |||
Messages: messages, | |||
}, | |||
) | |||
@@ -319,8 +319,8 @@ func main() { | |||
} | |||
content := resp.Choices[0].Message.Content | |||
messages = append(messages, openai.ChatCompletionMessage{ | |||
Role: openai.ChatMessageRoleAssistant, | |||
messages = append(messages, gogpt.ChatCompletionMessage{ | |||
Role: gogpt.ChatMessageRoleAssistant, | |||
Content: content, | |||
}) | |||
fmt.Println(content) | |||
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"context" | |||
@@ -1,182 +0,0 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"context" | |||
"errors" | |||
"io" | |||
"os" | |||
"testing" | |||
) | |||
func TestAPI(t *testing.T) { | |||
apiToken := os.Getenv("OPENAI_TOKEN") | |||
if apiToken == "" { | |||
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") | |||
} | |||
var err error | |||
c := NewClient(apiToken) | |||
ctx := context.Background() | |||
_, err = c.ListEngines(ctx) | |||
if err != nil { | |||
t.Fatalf("ListEngines error: %v", err) | |||
} | |||
_, err = c.GetEngine(ctx, "davinci") | |||
if err != nil { | |||
t.Fatalf("GetEngine error: %v", err) | |||
} | |||
fileRes, err := c.ListFiles(ctx) | |||
if err != nil { | |||
t.Fatalf("ListFiles error: %v", err) | |||
} | |||
if len(fileRes.Files) > 0 { | |||
_, err = c.GetFile(ctx, fileRes.Files[0].ID) | |||
if err != nil { | |||
t.Fatalf("GetFile error: %v", err) | |||
} | |||
} // else skip | |||
embeddingReq := EmbeddingRequest{ | |||
Input: []string{ | |||
"The food was delicious and the waiter", | |||
"Other examples of embedding request", | |||
}, | |||
Model: AdaSearchQuery, | |||
} | |||
_, err = c.CreateEmbeddings(ctx, embeddingReq) | |||
if err != nil { | |||
t.Fatalf("Embedding error: %v", err) | |||
} | |||
_, err = c.CreateChatCompletion( | |||
ctx, | |||
ChatCompletionRequest{ | |||
Model: GPT3Dot5Turbo, | |||
Messages: []ChatCompletionMessage{ | |||
{ | |||
Role: ChatMessageRoleUser, | |||
Content: "Hello!", | |||
}, | |||
}, | |||
}, | |||
) | |||
if err != nil { | |||
t.Errorf("CreateChatCompletion (without name) returned error: %v", err) | |||
} | |||
_, err = c.CreateChatCompletion( | |||
ctx, | |||
ChatCompletionRequest{ | |||
Model: GPT3Dot5Turbo, | |||
Messages: []ChatCompletionMessage{ | |||
{ | |||
Role: ChatMessageRoleUser, | |||
Name: "John_Doe", | |||
Content: "Hello!", | |||
}, | |||
}, | |||
}, | |||
) | |||
if err != nil { | |||
t.Errorf("CreateChatCompletion (with name) returned error: %v", err) | |||
} | |||
stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ | |||
Prompt: "Ex falso quodlibet", | |||
Model: GPT3Ada, | |||
MaxTokens: 5, | |||
Stream: true, | |||
}) | |||
if err != nil { | |||
t.Errorf("CreateCompletionStream returned error: %v", err) | |||
} | |||
defer stream.Close() | |||
counter := 0 | |||
for { | |||
_, err = stream.Recv() | |||
if err != nil { | |||
if errors.Is(err, io.EOF) { | |||
break | |||
} | |||
t.Errorf("Stream error: %v", err) | |||
} else { | |||
counter++ | |||
} | |||
} | |||
if counter == 0 { | |||
t.Error("Stream did not return any responses") | |||
} | |||
} | |||
func TestAPIError(t *testing.T) { | |||
apiToken := os.Getenv("OPENAI_TOKEN") | |||
if apiToken == "" { | |||
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") | |||
} | |||
var err error | |||
c := NewClient(apiToken + "_invalid") | |||
ctx := context.Background() | |||
_, err = c.ListEngines(ctx) | |||
if err == nil { | |||
t.Fatal("ListEngines did not fail") | |||
} | |||
var apiErr *APIError | |||
if !errors.As(err, &apiErr) { | |||
t.Fatalf("Error is not an APIError: %+v", err) | |||
} | |||
if apiErr.StatusCode != 401 { | |||
t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode) | |||
} | |||
if *apiErr.Code != "invalid_api_key" { | |||
t.Fatalf("Unexpected API error code: %s", *apiErr.Code) | |||
} | |||
if apiErr.Error() == "" { | |||
t.Fatal("Empty error message occured") | |||
} | |||
} | |||
func TestRequestError(t *testing.T) { | |||
var err error | |||
config := DefaultConfig("dummy") | |||
config.BaseURL = "https://httpbin.org/status/418?" | |||
c := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
_, err = c.ListEngines(ctx) | |||
if err == nil { | |||
t.Fatal("ListEngines request did not fail") | |||
} | |||
var reqErr *RequestError | |||
if !errors.As(err, &reqErr) { | |||
t.Fatalf("Error is not a RequestError: %+v", err) | |||
} | |||
if reqErr.StatusCode != 418 { | |||
t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode) | |||
} | |||
if reqErr.Unwrap() == nil { | |||
t.Fatalf("Empty request error occured") | |||
} | |||
} | |||
// numTokens Returns the number of GPT-3 encoded tokens in the given text. | |||
// This function approximates based on the rule of thumb stated by OpenAI: | |||
// https://beta.openai.com/tokenizer | |||
// | |||
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) | |||
func numTokens(s string) int { | |||
return int(float32(len(s)) / 4) | |||
} |
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"bytes" | |||
@@ -1,143 +0,0 @@ | |||
package openai_test | |||
import ( | |||
"bytes" | |||
"errors" | |||
"io" | |||
"mime" | |||
"mime/multipart" | |||
"net/http" | |||
"os" | |||
"path/filepath" | |||
"strings" | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"testing" | |||
) | |||
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. | |||
func TestAudio(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) | |||
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
testcases := []struct { | |||
name string | |||
createFn func(context.Context, AudioRequest) (AudioResponse, error) | |||
}{ | |||
{ | |||
"transcribe", | |||
client.CreateTranscription, | |||
}, | |||
{ | |||
"translate", | |||
client.CreateTranslation, | |||
}, | |||
} | |||
ctx := context.Background() | |||
dir, cleanup := createTestDirectory(t) | |||
defer cleanup() | |||
for _, tc := range testcases { | |||
t.Run(tc.name, func(t *testing.T) { | |||
path := filepath.Join(dir, "fake.mp3") | |||
createTestFile(t, path) | |||
req := AudioRequest{ | |||
FilePath: path, | |||
Model: "whisper-3", | |||
} | |||
_, err = tc.createFn(ctx, req) | |||
if err != nil { | |||
t.Fatalf("audio API error: %v", err) | |||
} | |||
}) | |||
} | |||
} | |||
// createTestFile creates a fake file with "hello" as the content. | |||
func createTestFile(t *testing.T, path string) { | |||
file, err := os.Create(path) | |||
if err != nil { | |||
t.Fatalf("failed to create file %v", err) | |||
} | |||
if _, err = file.WriteString("hello"); err != nil { | |||
t.Fatalf("failed to write to file %v", err) | |||
} | |||
file.Close() | |||
} | |||
// createTestDirectory creates a temporary folder which will be deleted when cleanup is called. | |||
func createTestDirectory(t *testing.T) (path string, cleanup func()) { | |||
t.Helper() | |||
path, err := os.MkdirTemp(os.TempDir(), "") | |||
if err != nil { | |||
t.Fatal(err) | |||
} | |||
return path, func() { os.RemoveAll(path) } | |||
} | |||
// handleAudioEndpoint Handles the completion endpoint by the test server. | |||
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
// audio endpoints only accept POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) | |||
if err != nil { | |||
http.Error(w, "failed to parse media type", http.StatusBadRequest) | |||
return | |||
} | |||
if !strings.HasPrefix(mediaType, "multipart") { | |||
http.Error(w, "request is not multipart", http.StatusBadRequest) | |||
} | |||
boundary, ok := params["boundary"] | |||
if !ok { | |||
http.Error(w, "no boundary in params", http.StatusBadRequest) | |||
return | |||
} | |||
fileData := &bytes.Buffer{} | |||
mr := multipart.NewReader(r.Body, boundary) | |||
part, err := mr.NextPart() | |||
if err != nil && errors.Is(err, io.EOF) { | |||
http.Error(w, "error accessing file", http.StatusBadRequest) | |||
return | |||
} | |||
if _, err = io.Copy(fileData, part); err != nil { | |||
http.Error(w, "failed to copy file", http.StatusInternalServerError) | |||
return | |||
} | |||
if len(fileData.Bytes()) == 0 { | |||
w.WriteHeader(http.StatusInternalServerError) | |||
http.Error(w, "received empty file data", http.StatusBadRequest) | |||
return | |||
} | |||
if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { | |||
http.Error(w, "failed to write body", http.StatusInternalServerError) | |||
return | |||
} | |||
} |
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"context" | |||
@@ -23,8 +23,8 @@ type ChatCompletionMessage struct { | |||
// This property isn't in the official documentation, but it's in | |||
// the documentation for the official library for python: | |||
// - https://github.com/openai/openai-python/blob/main/chatml.md | |||
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | |||
// - https://github.com/gogpt/gogpt-python/blob/main/chatml.md | |||
// - https://github.com/gogpt/gogpt-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | |||
Name string `json:"name,omitempty"` | |||
} | |||
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"bufio" | |||
@@ -1,153 +0,0 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"errors" | |||
"io" | |||
"net/http" | |||
"net/http/httptest" | |||
"testing" | |||
) | |||
func TestCreateChatCompletionStream(t *testing.T) { | |||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
w.Header().Set("Content-Type", "text/event-stream") | |||
// Send test responses | |||
dataBytes := []byte{} | |||
dataBytes = append(dataBytes, []byte("event: message\n")...) | |||
//nolint:lll | |||
data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` | |||
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) | |||
dataBytes = append(dataBytes, []byte("event: message\n")...) | |||
//nolint:lll | |||
data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` | |||
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) | |||
dataBytes = append(dataBytes, []byte("event: done\n")...) | |||
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) | |||
_, err := w.Write(dataBytes) | |||
if err != nil { | |||
t.Errorf("Write error: %s", err) | |||
} | |||
})) | |||
defer server.Close() | |||
// Client portion of the test | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = server.URL + "/v1" | |||
config.HTTPClient.Transport = &tokenRoundTripper{ | |||
test.GetTestToken(), | |||
http.DefaultTransport, | |||
} | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
request := ChatCompletionRequest{ | |||
MaxTokens: 5, | |||
Model: GPT3Dot5Turbo, | |||
Messages: []ChatCompletionMessage{ | |||
{ | |||
Role: ChatMessageRoleUser, | |||
Content: "Hello!", | |||
}, | |||
}, | |||
Stream: true, | |||
} | |||
stream, err := client.CreateChatCompletionStream(ctx, request) | |||
if err != nil { | |||
t.Errorf("CreateCompletionStream returned error: %v", err) | |||
} | |||
defer stream.Close() | |||
expectedResponses := []ChatCompletionStreamResponse{ | |||
{ | |||
ID: "1", | |||
Object: "completion", | |||
Created: 1598069254, | |||
Model: GPT3Dot5Turbo, | |||
Choices: []ChatCompletionStreamChoice{ | |||
{ | |||
Delta: ChatCompletionStreamChoiceDelta{ | |||
Content: "response1", | |||
}, | |||
FinishReason: "max_tokens", | |||
}, | |||
}, | |||
}, | |||
{ | |||
ID: "2", | |||
Object: "completion", | |||
Created: 1598069255, | |||
Model: GPT3Dot5Turbo, | |||
Choices: []ChatCompletionStreamChoice{ | |||
{ | |||
Delta: ChatCompletionStreamChoiceDelta{ | |||
Content: "response2", | |||
}, | |||
FinishReason: "max_tokens", | |||
}, | |||
}, | |||
}, | |||
} | |||
for ix, expectedResponse := range expectedResponses { | |||
b, _ := json.Marshal(expectedResponse) | |||
t.Logf("%d: %s", ix, string(b)) | |||
receivedResponse, streamErr := stream.Recv() | |||
if streamErr != nil { | |||
t.Errorf("stream.Recv() failed: %v", streamErr) | |||
} | |||
if !compareChatResponses(expectedResponse, receivedResponse) { | |||
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) | |||
} | |||
} | |||
_, streamErr := stream.Recv() | |||
if !errors.Is(streamErr, io.EOF) { | |||
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) | |||
} | |||
_, streamErr = stream.Recv() | |||
if !errors.Is(streamErr, io.EOF) { | |||
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) | |||
} | |||
} | |||
// Helper funcs. | |||
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { | |||
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { | |||
return false | |||
} | |||
if len(r1.Choices) != len(r2.Choices) { | |||
return false | |||
} | |||
for i := range r1.Choices { | |||
if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) { | |||
return false | |||
} | |||
} | |||
return true | |||
} | |||
func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool { | |||
if c1.Index != c2.Index { | |||
return false | |||
} | |||
if c1.Delta.Content != c2.Delta.Content { | |||
return false | |||
} | |||
if c1.FinishReason != c2.FinishReason { | |||
return false | |||
} | |||
return true | |||
} |
@@ -1,132 +0,0 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"net/http" | |||
"strconv" | |||
"strings" | |||
"testing" | |||
"time" | |||
) | |||
func TestChatCompletionsWrongModel(t *testing.T) { | |||
config := DefaultConfig("whatever") | |||
config.BaseURL = "http://localhost/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
req := ChatCompletionRequest{ | |||
MaxTokens: 5, | |||
Model: "ada", | |||
Messages: []ChatCompletionMessage{ | |||
{ | |||
Role: ChatMessageRoleUser, | |||
Content: "Hello!", | |||
}, | |||
}, | |||
} | |||
_, err := client.CreateChatCompletion(ctx, req) | |||
if !errors.Is(err, ErrChatCompletionInvalidModel) { | |||
t.Fatalf("CreateChatCompletion should return wrong model error, but returned: %v", err) | |||
} | |||
} | |||
// TestCompletions Tests the completions endpoint of the API using the mocked server. | |||
func TestChatCompletions(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
req := ChatCompletionRequest{ | |||
MaxTokens: 5, | |||
Model: GPT3Dot5Turbo, | |||
Messages: []ChatCompletionMessage{ | |||
{ | |||
Role: ChatMessageRoleUser, | |||
Content: "Hello!", | |||
}, | |||
}, | |||
} | |||
_, err = client.CreateChatCompletion(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateChatCompletion error: %v", err) | |||
} | |||
} | |||
// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. | |||
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
var resBytes []byte | |||
// completions only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
var completionReq ChatCompletionRequest | |||
if completionReq, err = getChatCompletionBody(r); err != nil { | |||
http.Error(w, "could not read request", http.StatusInternalServerError) | |||
return | |||
} | |||
res := ChatCompletionResponse{ | |||
ID: strconv.Itoa(int(time.Now().Unix())), | |||
Object: "test-object", | |||
Created: time.Now().Unix(), | |||
// would be nice to validate Model during testing, but | |||
// this may not be possible with how much upkeep | |||
// would be required / wouldn't make much sense | |||
Model: completionReq.Model, | |||
} | |||
// create completions | |||
for i := 0; i < completionReq.N; i++ { | |||
// generate a random string of length completionReq.Length | |||
completionStr := strings.Repeat("a", completionReq.MaxTokens) | |||
res.Choices = append(res.Choices, ChatCompletionChoice{ | |||
Message: ChatCompletionMessage{ | |||
Role: ChatMessageRoleAssistant, | |||
Content: completionStr, | |||
}, | |||
Index: i, | |||
}) | |||
} | |||
inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N | |||
completionTokens := completionReq.MaxTokens * completionReq.N | |||
res.Usage = Usage{ | |||
PromptTokens: inputTokens, | |||
CompletionTokens: completionTokens, | |||
TotalTokens: inputTokens + completionTokens, | |||
} | |||
resBytes, _ = json.Marshal(res) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} | |||
// getChatCompletionBody Returns the body of the request to create a completion. | |||
func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { | |||
completion := ChatCompletionRequest{} | |||
// read the request body | |||
reqBody, err := io.ReadAll(r.Body) | |||
if err != nil { | |||
return ChatCompletionRequest{}, err | |||
} | |||
err = json.Unmarshal(reqBody, &completion) | |||
if err != nil { | |||
return ChatCompletionRequest{}, err | |||
} | |||
return completion, nil | |||
} |
@@ -1,5 +1,5 @@ | |||
// common.go defines common types used throughout the OpenAI API. | |||
package openai | |||
package gogpt | |||
// Usage Represents the total token usage per request to OpenAI. | |||
type Usage struct { | |||
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"context" | |||
@@ -1,121 +0,0 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"net/http" | |||
"strconv" | |||
"strings" | |||
"testing" | |||
"time" | |||
) | |||
func TestCompletionsWrongModel(t *testing.T) { | |||
config := DefaultConfig("whatever") | |||
config.BaseURL = "http://localhost/v1" | |||
client := NewClientWithConfig(config) | |||
_, err := client.CreateCompletion( | |||
context.Background(), | |||
CompletionRequest{ | |||
MaxTokens: 5, | |||
Model: GPT3Dot5Turbo, | |||
}, | |||
) | |||
if !errors.Is(err, ErrCompletionUnsupportedModel) { | |||
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) | |||
} | |||
} | |||
// TestCompletions Tests the completions endpoint of the API using the mocked server. | |||
func TestCompletions(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/completions", handleCompletionEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
req := CompletionRequest{ | |||
MaxTokens: 5, | |||
Model: "ada", | |||
} | |||
req.Prompt = "Lorem ipsum" | |||
_, err = client.CreateCompletion(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateCompletion error: %v", err) | |||
} | |||
} | |||
// handleCompletionEndpoint Handles the completion endpoint by the test server. | |||
func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
var resBytes []byte | |||
// completions only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
var completionReq CompletionRequest | |||
if completionReq, err = getCompletionBody(r); err != nil { | |||
http.Error(w, "could not read request", http.StatusInternalServerError) | |||
return | |||
} | |||
res := CompletionResponse{ | |||
ID: strconv.Itoa(int(time.Now().Unix())), | |||
Object: "test-object", | |||
Created: time.Now().Unix(), | |||
// would be nice to validate Model during testing, but | |||
// this may not be possible with how much upkeep | |||
// would be required / wouldn't make much sense | |||
Model: completionReq.Model, | |||
} | |||
// create completions | |||
for i := 0; i < completionReq.N; i++ { | |||
// generate a random string of length completionReq.Length | |||
completionStr := strings.Repeat("a", completionReq.MaxTokens) | |||
if completionReq.Echo { | |||
completionStr = completionReq.Prompt + completionStr | |||
} | |||
res.Choices = append(res.Choices, CompletionChoice{ | |||
Text: completionStr, | |||
Index: i, | |||
}) | |||
} | |||
inputTokens := numTokens(completionReq.Prompt) * completionReq.N | |||
completionTokens := completionReq.MaxTokens * completionReq.N | |||
res.Usage = Usage{ | |||
PromptTokens: inputTokens, | |||
CompletionTokens: completionTokens, | |||
TotalTokens: inputTokens + completionTokens, | |||
} | |||
resBytes, _ = json.Marshal(res) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} | |||
// getCompletionBody Returns the body of the request to create a completion. | |||
func getCompletionBody(r *http.Request) (CompletionRequest, error) { | |||
completion := CompletionRequest{} | |||
// read the request body | |||
reqBody, err := io.ReadAll(r.Body) | |||
if err != nil { | |||
return CompletionRequest{}, err | |||
} | |||
err = json.Unmarshal(reqBody, &completion) | |||
if err != nil { | |||
return CompletionRequest{}, err | |||
} | |||
return completion, nil | |||
} |
@@ -1,11 +1,11 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"net/http" | |||
) | |||
const ( | |||
apiURLv1 = "https://api.openai.com/v1" | |||
apiURLv1 = "https://api.gogpt.com/v1" | |||
defaultEmptyMessagesLimit uint = 300 | |||
) | |||
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"context" | |||
@@ -1,104 +0,0 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"io" | |||
"net/http" | |||
"testing" | |||
"time" | |||
) | |||
// TestEdits Tests the edits endpoint of the API using the mocked server. | |||
func TestEdits(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/edits", handleEditEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
// create an edit request | |||
model := "ada" | |||
editReq := EditsRequest{ | |||
Model: &model, | |||
Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + | |||
"sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + | |||
" ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" + | |||
" ex ea commodo consequat. Duis aute irure dolor in reprehe", | |||
Instruction: "test instruction", | |||
N: 3, | |||
} | |||
response, err := client.Edits(ctx, editReq) | |||
if err != nil { | |||
t.Fatalf("Edits error: %v", err) | |||
} | |||
if len(response.Choices) != editReq.N { | |||
t.Fatalf("edits does not properly return the correct number of choices") | |||
} | |||
} | |||
// handleEditEndpoint Handles the edit endpoint by the test server. | |||
func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
var resBytes []byte | |||
// edits only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
var editReq EditsRequest | |||
editReq, err = getEditBody(r) | |||
if err != nil { | |||
http.Error(w, "could not read request", http.StatusInternalServerError) | |||
return | |||
} | |||
// create a response | |||
res := EditsResponse{ | |||
Object: "test-object", | |||
Created: time.Now().Unix(), | |||
} | |||
// edit and calculate token usage | |||
editString := "edited by mocked OpenAI server :)" | |||
inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N | |||
completionTokens := int(float32(len(editString))/4) * editReq.N | |||
for i := 0; i < editReq.N; i++ { | |||
// instruction will be hidden and only seen by OpenAI | |||
res.Choices = append(res.Choices, EditsChoice{ | |||
Text: editReq.Input + editString, | |||
Index: i, | |||
}) | |||
} | |||
res.Usage = Usage{ | |||
PromptTokens: inputTokens, | |||
CompletionTokens: completionTokens, | |||
TotalTokens: inputTokens + completionTokens, | |||
} | |||
resBytes, _ = json.Marshal(res) | |||
fmt.Fprint(w, string(resBytes)) | |||
} | |||
// getEditBody Returns the body of the request to create an edit. | |||
func getEditBody(r *http.Request) (EditsRequest, error) { | |||
edit := EditsRequest{} | |||
// read the request body | |||
reqBody, err := io.ReadAll(r.Body) | |||
if err != nil { | |||
return EditsRequest{}, err | |||
} | |||
err = json.Unmarshal(reqBody, &edit) | |||
if err != nil { | |||
return EditsRequest{}, err | |||
} | |||
return edit, nil | |||
} |
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"context" | |||
@@ -130,7 +130,7 @@ type EmbeddingRequest struct { | |||
} | |||
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. | |||
// https://beta.openai.com/docs/api-reference/embeddings/create | |||
// https://beta.gogpt.com/docs/api-reference/embeddings/create | |||
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { | |||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request) | |||
if err != nil { | |||
@@ -1,48 +0,0 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"bytes" | |||
"encoding/json" | |||
"testing" | |||
) | |||
func TestEmbedding(t *testing.T) { | |||
embeddedModels := []EmbeddingModel{ | |||
AdaSimilarity, | |||
BabbageSimilarity, | |||
CurieSimilarity, | |||
DavinciSimilarity, | |||
AdaSearchDocument, | |||
AdaSearchQuery, | |||
BabbageSearchDocument, | |||
BabbageSearchQuery, | |||
CurieSearchDocument, | |||
CurieSearchQuery, | |||
DavinciSearchDocument, | |||
DavinciSearchQuery, | |||
AdaCodeSearchCode, | |||
AdaCodeSearchText, | |||
BabbageCodeSearchCode, | |||
BabbageCodeSearchText, | |||
} | |||
for _, model := range embeddedModels { | |||
embeddingReq := EmbeddingRequest{ | |||
Input: []string{ | |||
"The food was delicious and the waiter", | |||
"Other examples of embedding request", | |||
}, | |||
Model: model, | |||
} | |||
// marshal embeddingReq to JSON and confirm that the model field equals | |||
// the AdaSearchQuery type | |||
marshaled, err := json.Marshal(embeddingReq) | |||
if err != nil { | |||
t.Fatalf("Could not marshal embedding request: %v", err) | |||
} | |||
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { | |||
t.Fatalf("Expected embedding request to contain model field") | |||
} | |||
} | |||
} |
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"context" | |||
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import "fmt" | |||
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"bytes" | |||
@@ -1,81 +0,0 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"net/http" | |||
"strconv" | |||
"testing" | |||
"time" | |||
) | |||
func TestFileUpload(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/files", handleCreateFile) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
req := FileRequest{ | |||
FileName: "test.go", | |||
FilePath: "api.go", | |||
Purpose: "fine-tune", | |||
} | |||
_, err = client.CreateFile(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateFile error: %v", err) | |||
} | |||
} | |||
// handleCreateFile Handles the images endpoint by the test server. | |||
func handleCreateFile(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
var resBytes []byte | |||
// edits only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
err = r.ParseMultipartForm(1024 * 1024 * 1024) | |||
if err != nil { | |||
http.Error(w, "file is more than 1GB", http.StatusInternalServerError) | |||
return | |||
} | |||
values := r.Form | |||
var purpose string | |||
for key, value := range values { | |||
if key == "purpose" { | |||
purpose = value[0] | |||
} | |||
} | |||
file, header, err := r.FormFile("file") | |||
if err != nil { | |||
return | |||
} | |||
defer file.Close() | |||
var fileReq = File{ | |||
Bytes: int(header.Size), | |||
ID: strconv.Itoa(int(time.Now().Unix())), | |||
FileName: header.Filename, | |||
Purpose: purpose, | |||
CreatedAt: time.Now().Unix(), | |||
Object: "test-objecct", | |||
Owner: "test-owner", | |||
} | |||
resBytes, _ = json.Marshal(fileReq) | |||
fmt.Fprint(w, string(resBytes)) | |||
} |
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"context" | |||
@@ -1,101 +0,0 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"net/http" | |||
"testing" | |||
) | |||
const testFineTuneID = "fine-tune-id" | |||
// TestFineTunes Tests the fine tunes endpoint of the API using the mocked server. | |||
func TestFineTunes(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler( | |||
"/v1/fine-tunes", | |||
func(w http.ResponseWriter, r *http.Request) { | |||
var resBytes []byte | |||
if r.Method == http.MethodGet { | |||
resBytes, _ = json.Marshal(FineTuneList{}) | |||
} else { | |||
resBytes, _ = json.Marshal(FineTune{}) | |||
} | |||
fmt.Fprintln(w, string(resBytes)) | |||
}, | |||
) | |||
server.RegisterHandler( | |||
"/v1/fine-tunes/"+testFineTuneID+"/cancel", | |||
func(w http.ResponseWriter, r *http.Request) { | |||
resBytes, _ := json.Marshal(FineTune{}) | |||
fmt.Fprintln(w, string(resBytes)) | |||
}, | |||
) | |||
server.RegisterHandler( | |||
"/v1/fine-tunes/"+testFineTuneID, | |||
func(w http.ResponseWriter, r *http.Request) { | |||
var resBytes []byte | |||
if r.Method == http.MethodDelete { | |||
resBytes, _ = json.Marshal(FineTuneDeleteResponse{}) | |||
} else { | |||
resBytes, _ = json.Marshal(FineTune{}) | |||
} | |||
fmt.Fprintln(w, string(resBytes)) | |||
}, | |||
) | |||
server.RegisterHandler( | |||
"/v1/fine-tunes/"+testFineTuneID+"/events", | |||
func(w http.ResponseWriter, r *http.Request) { | |||
resBytes, _ := json.Marshal(FineTuneEventList{}) | |||
fmt.Fprintln(w, string(resBytes)) | |||
}, | |||
) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
_, err = client.ListFineTunes(ctx) | |||
if err != nil { | |||
t.Fatalf("ListFineTunes error: %v", err) | |||
} | |||
_, err = client.CreateFineTune(ctx, FineTuneRequest{}) | |||
if err != nil { | |||
t.Fatalf("CreateFineTune error: %v", err) | |||
} | |||
_, err = client.CancelFineTune(ctx, testFineTuneID) | |||
if err != nil { | |||
t.Fatalf("CancelFineTune error: %v", err) | |||
} | |||
_, err = client.GetFineTune(ctx, testFineTuneID) | |||
if err != nil { | |||
t.Fatalf("GetFineTune error: %v", err) | |||
} | |||
_, err = client.DeleteFineTune(ctx, testFineTuneID) | |||
if err != nil { | |||
t.Fatalf("DeleteFineTune error: %v", err) | |||
} | |||
_, err = client.ListFineTuneEvents(ctx, testFineTuneID) | |||
if err != nil { | |||
t.Fatalf("ListFineTuneEvents error: %v", err) | |||
} | |||
} |
@@ -1,3 +1,3 @@ | |||
module github.com/sashabaranov/go-openai | |||
module github.com/sashabaranov/go-gogpt | |||
go 1.18 |
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"bytes" | |||
@@ -147,7 +147,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) | |||
return | |||
} | |||
writer.Close() | |||
//https://platform.openai.com/docs/api-reference/images/create-variation | |||
//https://platform.gogpt.com/docs/api-reference/images/create-variation | |||
urlSuffix := "/images/variations" | |||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) | |||
if err != nil { | |||
@@ -1,268 +0,0 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"io" | |||
"net/http" | |||
"os" | |||
"testing" | |||
"time" | |||
) | |||
func TestImages(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/images/generations", handleImageEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
req := ImageRequest{} | |||
req.Prompt = "Lorem ipsum" | |||
_, err = client.CreateImage(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateImage error: %v", err) | |||
} | |||
} | |||
// handleImageEndpoint Handles the images endpoint by the test server. | |||
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
var resBytes []byte | |||
// imagess only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
var imageReq ImageRequest | |||
if imageReq, err = getImageBody(r); err != nil { | |||
http.Error(w, "could not read request", http.StatusInternalServerError) | |||
return | |||
} | |||
res := ImageResponse{ | |||
Created: time.Now().Unix(), | |||
} | |||
for i := 0; i < imageReq.N; i++ { | |||
imageData := ImageResponseDataInner{} | |||
switch imageReq.ResponseFormat { | |||
case CreateImageResponseFormatURL, "": | |||
imageData.URL = "https://example.com/image.png" | |||
case CreateImageResponseFormatB64JSON: | |||
// This decodes to "{}" in base64. | |||
imageData.B64JSON = "e30K" | |||
default: | |||
http.Error(w, "invalid response format", http.StatusBadRequest) | |||
return | |||
} | |||
res.Data = append(res.Data, imageData) | |||
} | |||
resBytes, _ = json.Marshal(res) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} | |||
// getImageBody Returns the body of the request to create a image. | |||
func getImageBody(r *http.Request) (ImageRequest, error) { | |||
image := ImageRequest{} | |||
// read the request body | |||
reqBody, err := io.ReadAll(r.Body) | |||
if err != nil { | |||
return ImageRequest{}, err | |||
} | |||
err = json.Unmarshal(reqBody, &image) | |||
if err != nil { | |||
return ImageRequest{}, err | |||
} | |||
return image, nil | |||
} | |||
func TestImageEdit(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
origin, err := os.Create("image.png") | |||
if err != nil { | |||
t.Error("open origin file error") | |||
return | |||
} | |||
mask, err := os.Create("mask.png") | |||
if err != nil { | |||
t.Error("open mask file error") | |||
return | |||
} | |||
defer func() { | |||
mask.Close() | |||
origin.Close() | |||
os.Remove("mask.png") | |||
os.Remove("image.png") | |||
}() | |||
req := ImageEditRequest{ | |||
Image: origin, | |||
Mask: mask, | |||
Prompt: "There is a turtle in the pool", | |||
N: 3, | |||
Size: CreateImageSize1024x1024, | |||
} | |||
_, err = client.CreateEditImage(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateImage error: %v", err) | |||
} | |||
} | |||
func TestImageEditWithoutMask(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
origin, err := os.Create("image.png") | |||
if err != nil { | |||
t.Error("open origin file error") | |||
return | |||
} | |||
defer func() { | |||
origin.Close() | |||
os.Remove("image.png") | |||
}() | |||
req := ImageEditRequest{ | |||
Image: origin, | |||
Prompt: "There is a turtle in the pool", | |||
N: 3, | |||
Size: CreateImageSize1024x1024, | |||
} | |||
_, err = client.CreateEditImage(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateImage error: %v", err) | |||
} | |||
} | |||
// handleEditImageEndpoint Handles the images endpoint by the test server. | |||
func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var resBytes []byte | |||
// imagess only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
responses := ImageResponse{ | |||
Created: time.Now().Unix(), | |||
Data: []ImageResponseDataInner{ | |||
{ | |||
URL: "test-url1", | |||
B64JSON: "", | |||
}, | |||
{ | |||
URL: "test-url2", | |||
B64JSON: "", | |||
}, | |||
{ | |||
URL: "test-url3", | |||
B64JSON: "", | |||
}, | |||
}, | |||
} | |||
resBytes, _ = json.Marshal(responses) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} | |||
func TestImageVariation(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
origin, err := os.Create("image.png") | |||
if err != nil { | |||
t.Error("open origin file error") | |||
return | |||
} | |||
defer func() { | |||
origin.Close() | |||
os.Remove("image.png") | |||
}() | |||
req := ImageVariRequest{ | |||
Image: origin, | |||
N: 3, | |||
Size: CreateImageSize1024x1024, | |||
} | |||
_, err = client.CreateVariImage(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateImage error: %v", err) | |||
} | |||
} | |||
// handleVariateImageEndpoint Handles the images endpoint by the test server. | |||
func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var resBytes []byte | |||
// imagess only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
responses := ImageResponse{ | |||
Created: time.Now().Unix(), | |||
Data: []ImageResponseDataInner{ | |||
{ | |||
URL: "test-url1", | |||
B64JSON: "", | |||
}, | |||
{ | |||
URL: "test-url2", | |||
B64JSON: "", | |||
}, | |||
{ | |||
URL: "test-url3", | |||
B64JSON: "", | |||
}, | |||
}, | |||
} | |||
resBytes, _ = json.Marshal(responses) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} |
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"encoding/json" | |||
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"context" | |||
@@ -1,39 +0,0 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"net/http" | |||
"testing" | |||
) | |||
// TestListModels Tests the models endpoint of the API using the mocked server. | |||
func TestListModels(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/models", handleModelsEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
_, err = client.ListModels(ctx) | |||
if err != nil { | |||
t.Fatalf("ListModels error: %v", err) | |||
} | |||
} | |||
// handleModelsEndpoint Handles the models endpoint by the test server. | |||
func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) { | |||
resBytes, _ := json.Marshal(ModelsList{}) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} |
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"context" | |||
@@ -1,102 +0,0 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"io" | |||
"net/http" | |||
"strconv" | |||
"strings" | |||
"testing" | |||
"time" | |||
) | |||
// TestModeration Tests the moderations endpoint of the API using the mocked server. | |||
func TestModerations(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/moderations", handleModerationEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
// create an edit request | |||
model := "text-moderation-stable" | |||
moderationReq := ModerationRequest{ | |||
Model: &model, | |||
Input: "I want to kill them.", | |||
} | |||
_, err = client.Moderations(ctx, moderationReq) | |||
if err != nil { | |||
t.Fatalf("Moderation error: %v", err) | |||
} | |||
} | |||
// handleModerationEndpoint Handles the moderation endpoint by the test server. | |||
func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
var resBytes []byte | |||
// completions only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
var moderationReq ModerationRequest | |||
if moderationReq, err = getModerationBody(r); err != nil { | |||
http.Error(w, "could not read request", http.StatusInternalServerError) | |||
return | |||
} | |||
resCat := ResultCategories{} | |||
resCatScore := ResultCategoryScores{} | |||
switch { | |||
case strings.Contains(moderationReq.Input, "kill"): | |||
resCat = ResultCategories{Violence: true} | |||
resCatScore = ResultCategoryScores{Violence: 1} | |||
case strings.Contains(moderationReq.Input, "hate"): | |||
resCat = ResultCategories{Hate: true} | |||
resCatScore = ResultCategoryScores{Hate: 1} | |||
case strings.Contains(moderationReq.Input, "suicide"): | |||
resCat = ResultCategories{SelfHarm: true} | |||
resCatScore = ResultCategoryScores{SelfHarm: 1} | |||
case strings.Contains(moderationReq.Input, "porn"): | |||
resCat = ResultCategories{Sexual: true} | |||
resCatScore = ResultCategoryScores{Sexual: 1} | |||
} | |||
result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} | |||
res := ModerationResponse{ | |||
ID: strconv.Itoa(int(time.Now().Unix())), | |||
Model: *moderationReq.Model, | |||
} | |||
res.Results = append(res.Results, result) | |||
resBytes, _ = json.Marshal(res) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} | |||
// getModerationBody Returns the body of the request to do a moderation. | |||
func getModerationBody(r *http.Request) (ModerationRequest, error) { | |||
moderation := ModerationRequest{} | |||
// read the request body | |||
reqBody, err := io.ReadAll(r.Body) | |||
if err != nil { | |||
return ModerationRequest{}, err | |||
} | |||
err = json.Unmarshal(reqBody, &moderation) | |||
if err != nil { | |||
return ModerationRequest{}, err | |||
} | |||
return moderation, nil | |||
} |
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"bytes" | |||
@@ -1,148 +0,0 @@ | |||
package openai //nolint:testpackage // testing private field | |||
import ( | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"errors" | |||
"net/http" | |||
"testing" | |||
) | |||
var ( | |||
errTestMarshallerFailed = errors.New("test marshaller failed") | |||
errTestRequestBuilderFailed = errors.New("test request builder failed") | |||
) | |||
type ( | |||
failingRequestBuilder struct{} | |||
failingMarshaller struct{} | |||
) | |||
func (*failingMarshaller) marshal(value any) ([]byte, error) { | |||
return []byte{}, errTestMarshallerFailed | |||
} | |||
func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) { | |||
return nil, errTestRequestBuilderFailed | |||
} | |||
func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { | |||
builder := httpRequestBuilder{ | |||
marshaller: &failingMarshaller{}, | |||
} | |||
_, err := builder.build(context.Background(), "", "", struct{}{}) | |||
if !errors.Is(err, errTestMarshallerFailed) { | |||
t.Fatalf("Did not return error when marshaller failed: %v", err) | |||
} | |||
} | |||
func TestClientReturnsRequestBuilderErrors(t *testing.T) { | |||
var err error | |||
ts := test.NewTestServer().OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
client.requestBuilder = &failingRequestBuilder{} | |||
ctx := context.Background() | |||
_, err = client.CreateCompletion(ctx, CompletionRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.CreateFineTune(ctx, FineTuneRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.ListFineTunes(ctx) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.CancelFineTune(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.GetFineTune(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.DeleteFineTune(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.ListFineTuneEvents(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.Moderations(ctx, ModerationRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.Edits(ctx, EditsRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.CreateImage(ctx, ImageRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
err = client.DeleteFile(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.GetFile(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.ListFiles(ctx) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.ListEngines(ctx) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.GetEngine(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.ListModels(ctx) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
} |
@@ -1,4 +1,4 @@ | |||
package openai | |||
package gogpt | |||
import ( | |||
"bufio" | |||
@@ -1,147 +0,0 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"errors" | |||
"io" | |||
"net/http" | |||
"net/http/httptest" | |||
"testing" | |||
) | |||
func TestCreateCompletionStream(t *testing.T) { | |||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
w.Header().Set("Content-Type", "text/event-stream") | |||
// Send test responses | |||
dataBytes := []byte{} | |||
dataBytes = append(dataBytes, []byte("event: message\n")...) | |||
//nolint:lll | |||
data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` | |||
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) | |||
dataBytes = append(dataBytes, []byte("event: message\n")...) | |||
//nolint:lll | |||
data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}` | |||
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) | |||
dataBytes = append(dataBytes, []byte("event: done\n")...) | |||
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) | |||
_, err := w.Write(dataBytes) | |||
if err != nil { | |||
t.Errorf("Write error: %s", err) | |||
} | |||
})) | |||
defer server.Close() | |||
// Client portion of the test | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = server.URL + "/v1" | |||
config.HTTPClient.Transport = &tokenRoundTripper{ | |||
test.GetTestToken(), | |||
http.DefaultTransport, | |||
} | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
request := CompletionRequest{ | |||
Prompt: "Ex falso quodlibet", | |||
Model: "text-davinci-002", | |||
MaxTokens: 10, | |||
Stream: true, | |||
} | |||
stream, err := client.CreateCompletionStream(ctx, request) | |||
if err != nil { | |||
t.Errorf("CreateCompletionStream returned error: %v", err) | |||
} | |||
defer stream.Close() | |||
expectedResponses := []CompletionResponse{ | |||
{ | |||
ID: "1", | |||
Object: "completion", | |||
Created: 1598069254, | |||
Model: "text-davinci-002", | |||
Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, | |||
}, | |||
{ | |||
ID: "2", | |||
Object: "completion", | |||
Created: 1598069255, | |||
Model: "text-davinci-002", | |||
Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, | |||
}, | |||
} | |||
for ix, expectedResponse := range expectedResponses { | |||
receivedResponse, streamErr := stream.Recv() | |||
if streamErr != nil { | |||
t.Errorf("stream.Recv() failed: %v", streamErr) | |||
} | |||
if !compareResponses(expectedResponse, receivedResponse) { | |||
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) | |||
} | |||
} | |||
_, streamErr := stream.Recv() | |||
if !errors.Is(streamErr, io.EOF) { | |||
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) | |||
} | |||
_, streamErr = stream.Recv() | |||
if !errors.Is(streamErr, io.EOF) { | |||
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) | |||
} | |||
} | |||
// A "tokenRoundTripper" is a struct that implements the RoundTripper | |||
// interface, specifically to handle the authentication token by adding a token | |||
// to the request header. We need this because the API requires that each | |||
// request include a valid API token in the headers for authentication and | |||
// authorization. | |||
type tokenRoundTripper struct { | |||
token string | |||
fallback http.RoundTripper | |||
} | |||
// RoundTrip takes an *http.Request as input and returns an | |||
// *http.Response and an error. | |||
// | |||
// It is expected to use the provided request to create a connection to an HTTP | |||
// server and return the response, or an error if one occurred. The returned | |||
// Response should have its Body closed. If the RoundTrip method returns an | |||
// error, the Client's Get, Head, Post, and PostForm methods return the same | |||
// error. | |||
func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { | |||
req.Header.Set("Authorization", "Bearer "+t.token) | |||
return t.fallback.RoundTrip(req) | |||
} | |||
// Helper funcs. | |||
func compareResponses(r1, r2 CompletionResponse) bool { | |||
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { | |||
return false | |||
} | |||
if len(r1.Choices) != len(r2.Choices) { | |||
return false | |||
} | |||
for i := range r1.Choices { | |||
if !compareResponseChoices(r1.Choices[i], r2.Choices[i]) { | |||
return false | |||
} | |||
} | |||
return true | |||
} | |||
func compareResponseChoices(c1, c2 CompletionChoice) bool { | |||
if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason { | |||
return false | |||
} | |||
return true | |||
} |