@@ -15,5 +15,5 @@
 # vendor/
 # Auth token for tests
-.openai-token
+.gogpt-token
 .idea
@@ -1,11 +1,11 @@
 # Go OpenAI
-[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-openai)
-[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai)
-[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai)
+[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-gogpt)
+[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-gogpt)](https://goreportcard.com/report/github.com/sashabaranov/go-gogpt)
+[![codecov](https://codecov.io/gh/sashabaranov/go-gogpt/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-gogpt)

-> **Note**: the repository was recently renamed from `go-gpt3` to `go-openai`
+> **Note**: the repository was recently renamed from `go-gpt3` to `go-gogpt`

-This library provides Go clients for [OpenAI API](https://platform.openai.com/). We support:
+This library provides Go clients for [OpenAI API](https://platform.gogpt.com/). We support:

 * ChatGPT
 * GPT-3, GPT-4
@@ -14,7 +14,7 @@ This library provides Go clients for [OpenAI API](https://platform.openai.com/).
 Installation:
 ```
-go get github.com/sashabaranov/go-openai
+go get github.com/sashabaranov/go-gogpt
 ```
@@ -26,18 +26,18 @@ package main
 import (
     "context"
     "fmt"
-    openai "github.com/sashabaranov/go-openai"
+    gogpt "github.com/sashabaranov/go-gogpt"
 )

 func main() {
-    client := openai.NewClient("your token")
+    client := gogpt.NewClient("your token")
     resp, err := client.CreateChatCompletion(
         context.Background(),
-        openai.ChatCompletionRequest{
-            Model: openai.GPT3Dot5Turbo,
-            Messages: []openai.ChatCompletionMessage{
+        gogpt.ChatCompletionRequest{
+            Model: gogpt.GPT3Dot5Turbo,
+            Messages: []gogpt.ChatCompletionMessage{
                 {
-                    Role:    openai.ChatMessageRoleUser,
+                    Role:    gogpt.ChatMessageRoleUser,
                     Content: "Hello!",
                 },
             },
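The hunk is cut off before the response is consumed. For reference, the surrounding README example presumably finishes along these lines under the renamed import, with `resp` coming from the `CreateChatCompletion` call above (a sketch, not part of the diff):

```go
if err != nil {
    fmt.Printf("ChatCompletion error: %v\n", err)
    return
}
fmt.Println(resp.Choices[0].Message.Content)
```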
@@ -67,15 +67,15 @@ package main
 import (
     "context"
     "fmt"
-    openai "github.com/sashabaranov/go-openai"
+    gogpt "github.com/sashabaranov/go-gogpt"
 )

 func main() {
-    c := openai.NewClient("your token")
+    c := gogpt.NewClient("your token")
     ctx := context.Background()

-    req := openai.CompletionRequest{
-        Model:     openai.GPT3Ada,
+    req := gogpt.CompletionRequest{
+        Model:     gogpt.GPT3Ada,
         MaxTokens: 5,
         Prompt:    "Lorem ipsum",
     }
@@ -100,15 +100,15 @@ import (
     "context"
     "fmt"
     "io"
-    openai "github.com/sashabaranov/go-openai"
+    gogpt "github.com/sashabaranov/go-gogpt"
 )

 func main() {
-    c := openai.NewClient("your token")
+    c := gogpt.NewClient("your token")
     ctx := context.Background()

-    req := openai.CompletionRequest{
-        Model:     openai.GPT3Ada,
+    req := gogpt.CompletionRequest{
+        Model:     gogpt.GPT3Ada,
         MaxTokens: 5,
         Prompt:    "Lorem ipsum",
         Stream:    true,
@@ -149,15 +149,15 @@ import (
     "context"
     "fmt"
-    openai "github.com/sashabaranov/go-openai"
+    gogpt "github.com/sashabaranov/go-gogpt"
 )

 func main() {
-    c := openai.NewClient("your token")
+    c := gogpt.NewClient("your token")
     ctx := context.Background()

-    req := openai.AudioRequest{
-        Model:    openai.Whisper1,
+    req := gogpt.AudioRequest{
+        Model:    gogpt.Whisper1,
         FilePath: "recording.mp3",
     }
     resp, err := c.CreateTranscription(ctx, req)
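The transcription hunk also stops before the response is used; the transcribed text is carried on the response's `Text` field, so the example plausibly continues like this (a sketch, not part of the diff):

```go
if err != nil {
    fmt.Printf("Transcription error: %v\n", err)
    return
}
fmt.Println(resp.Text)
```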
@@ -181,20 +181,20 @@ import (
     "context"
     "encoding/base64"
     "fmt"
-    openai "github.com/sashabaranov/go-openai"
+    gogpt "github.com/sashabaranov/go-gogpt"
     "image/png"
     "os"
 )

 func main() {
-    c := openai.NewClient("your token")
+    c := gogpt.NewClient("your token")
     ctx := context.Background()

     // Sample image by link
-    reqUrl := openai.ImageRequest{
+    reqUrl := gogpt.ImageRequest{
         Prompt:         "Parrot on a skateboard performs a trick, cartoon style, natural light, high detail",
-        Size:           openai.CreateImageSize256x256,
-        ResponseFormat: openai.CreateImageResponseFormatURL,
+        Size:           gogpt.CreateImageSize256x256,
+        ResponseFormat: gogpt.CreateImageResponseFormatURL,
         N:              1,
     }
@@ -206,10 +206,10 @@ func main() {
     fmt.Println(respUrl.Data[0].URL)

     // Example image as base64
-    reqBase64 := openai.ImageRequest{
+    reqBase64 := gogpt.ImageRequest{
         Prompt:         "Portrait of a humanoid parrot in a classic costume, high detail, realistic light, unreal engine",
-        Size:           openai.CreateImageSize256x256,
-        ResponseFormat: openai.CreateImageResponseFormatB64JSON,
+        Size:           gogpt.CreateImageSize256x256,
+        ResponseFormat: gogpt.CreateImageResponseFormatB64JSON,
         N:              1,
     }
@@ -254,7 +254,7 @@ func main() {
 <summary>Configuring proxy</summary>

 ```go
-config := openai.DefaultConfig("token")
+config := gogpt.DefaultConfig("token")
 proxyUrl, err := url.Parse("http://localhost:{port}")
 if err != nil {
     panic(err)
@@ -266,10 +266,10 @@ config.HTTPClient = &http.Client{
     Transport: transport,
 }

-c := openai.NewClientWithConfig(config)
+c := gogpt.NewClientWithConfig(config)
 ```

-See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig
+See also: https://pkg.go.dev/github.com/sashabaranov/go-gogpt#ClientConfig
 </details>
 <details>
@@ -285,12 +285,12 @@ import (
     "os"
     "strings"

-    "github.com/sashabaranov/go-openai"
+    "github.com/sashabaranov/go-gogpt"
 )

 func main() {
-    client := openai.NewClient("your token")
-    messages := make([]openai.ChatCompletionMessage, 0)
+    client := gogpt.NewClient("your token")
+    messages := make([]gogpt.ChatCompletionMessage, 0)
     reader := bufio.NewReader(os.Stdin)
     fmt.Println("Conversation")
     fmt.Println("---------------------")
@@ -300,15 +300,15 @@ func main() {
         text, _ := reader.ReadString('\n')
         // convert CRLF to LF
         text = strings.Replace(text, "\n", "", -1)
-        messages = append(messages, openai.ChatCompletionMessage{
-            Role:    openai.ChatMessageRoleUser,
+        messages = append(messages, gogpt.ChatCompletionMessage{
+            Role:    gogpt.ChatMessageRoleUser,
             Content: text,
         })

         resp, err := client.CreateChatCompletion(
             context.Background(),
-            openai.ChatCompletionRequest{
-                Model:    openai.GPT3Dot5Turbo,
+            gogpt.ChatCompletionRequest{
+                Model:    gogpt.GPT3Dot5Turbo,
                 Messages: messages,
             },
         )
@@ -319,8 +319,8 @@ func main() {
         }

         content := resp.Choices[0].Message.Content
-        messages = append(messages, openai.ChatCompletionMessage{
-            Role:    openai.ChatMessageRoleAssistant,
+        messages = append(messages, gogpt.ChatCompletionMessage{
+            Role:    gogpt.ChatMessageRoleAssistant,
             Content: content,
         })
         fmt.Println(content)
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "context"
@@ -1,182 +0,0 @@
-package openai_test
-
-import (
-    . "github.com/sashabaranov/go-openai"
-
-    "context"
-    "errors"
-    "io"
-    "os"
-    "testing"
-)
-
-func TestAPI(t *testing.T) {
-    apiToken := os.Getenv("OPENAI_TOKEN")
-    if apiToken == "" {
-        t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
-    }
-
-    var err error
-    c := NewClient(apiToken)
-    ctx := context.Background()
-    _, err = c.ListEngines(ctx)
-    if err != nil {
-        t.Fatalf("ListEngines error: %v", err)
-    }
-
-    _, err = c.GetEngine(ctx, "davinci")
-    if err != nil {
-        t.Fatalf("GetEngine error: %v", err)
-    }
-
-    fileRes, err := c.ListFiles(ctx)
-    if err != nil {
-        t.Fatalf("ListFiles error: %v", err)
-    }
-
-    if len(fileRes.Files) > 0 {
-        _, err = c.GetFile(ctx, fileRes.Files[0].ID)
-        if err != nil {
-            t.Fatalf("GetFile error: %v", err)
-        }
-    } // else skip
-
-    embeddingReq := EmbeddingRequest{
-        Input: []string{
-            "The food was delicious and the waiter",
-            "Other examples of embedding request",
-        },
-        Model: AdaSearchQuery,
-    }
-    _, err = c.CreateEmbeddings(ctx, embeddingReq)
-    if err != nil {
-        t.Fatalf("Embedding error: %v", err)
-    }
-
-    _, err = c.CreateChatCompletion(
-        ctx,
-        ChatCompletionRequest{
-            Model: GPT3Dot5Turbo,
-            Messages: []ChatCompletionMessage{
-                {
-                    Role:    ChatMessageRoleUser,
-                    Content: "Hello!",
-                },
-            },
-        },
-    )
-    if err != nil {
-        t.Errorf("CreateChatCompletion (without name) returned error: %v", err)
-    }
-
-    _, err = c.CreateChatCompletion(
-        ctx,
-        ChatCompletionRequest{
-            Model: GPT3Dot5Turbo,
-            Messages: []ChatCompletionMessage{
-                {
-                    Role:    ChatMessageRoleUser,
-                    Name:    "John_Doe",
-                    Content: "Hello!",
-                },
-            },
-        },
-    )
-    if err != nil {
-        t.Errorf("CreateChatCompletion (with name) returned error: %v", err)
-    }
-
-    stream, err := c.CreateCompletionStream(ctx, CompletionRequest{
-        Prompt:    "Ex falso quodlibet",
-        Model:     GPT3Ada,
-        MaxTokens: 5,
-        Stream:    true,
-    })
-    if err != nil {
-        t.Errorf("CreateCompletionStream returned error: %v", err)
-    }
-    defer stream.Close()
-
-    counter := 0
-    for {
-        _, err = stream.Recv()
-        if err != nil {
-            if errors.Is(err, io.EOF) {
-                break
-            }
-            t.Errorf("Stream error: %v", err)
-        } else {
-            counter++
-        }
-    }
-    if counter == 0 {
-        t.Error("Stream did not return any responses")
-    }
-}
-
-func TestAPIError(t *testing.T) {
-    apiToken := os.Getenv("OPENAI_TOKEN")
-    if apiToken == "" {
-        t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
-    }
-
-    var err error
-    c := NewClient(apiToken + "_invalid")
-    ctx := context.Background()
-    _, err = c.ListEngines(ctx)
-    if err == nil {
-        t.Fatal("ListEngines did not fail")
-    }
-
-    var apiErr *APIError
-    if !errors.As(err, &apiErr) {
-        t.Fatalf("Error is not an APIError: %+v", err)
-    }
-    if apiErr.StatusCode != 401 {
-        t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode)
-    }
-    if *apiErr.Code != "invalid_api_key" {
-        t.Fatalf("Unexpected API error code: %s", *apiErr.Code)
-    }
-    if apiErr.Error() == "" {
-        t.Fatal("Empty error message occurred")
-    }
-}
-
-func TestRequestError(t *testing.T) {
-    var err error
-
-    config := DefaultConfig("dummy")
-    config.BaseURL = "https://httpbin.org/status/418?"
-    c := NewClientWithConfig(config)
-    ctx := context.Background()
-    _, err = c.ListEngines(ctx)
-    if err == nil {
-        t.Fatal("ListEngines request did not fail")
-    }
-
-    var reqErr *RequestError
-    if !errors.As(err, &reqErr) {
-        t.Fatalf("Error is not a RequestError: %+v", err)
-    }
-    if reqErr.StatusCode != 418 {
-        t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode)
-    }
-    if reqErr.Unwrap() == nil {
-        t.Fatalf("Empty request error occurred")
-    }
-}
-
-// numTokens Returns the number of GPT-3 encoded tokens in the given text.
-// This function approximates based on the rule of thumb stated by OpenAI:
-// https://beta.openai.com/tokenizer
-//
-// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
-func numTokens(s string) int {
-    return int(float32(len(s)) / 4)
-}
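The deleted `numTokens` helper encodes OpenAI's stated rule of thumb of roughly four characters per token. A quick self-contained illustration of that approximation (not part of the diff):

```go
package main

import "fmt"

// numTokens approximates the GPT-3 token count as len(s)/4,
// matching the rule of thumb used by the deleted test helper.
func numTokens(s string) int {
    return int(float32(len(s)) / 4)
}

func main() {
    // "Lorem ipsum" is 11 characters, so the estimate is 2 tokens.
    fmt.Println(numTokens("Lorem ipsum"))
}
```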
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "bytes"
@@ -1,143 +0,0 @@
-package openai_test
-
-import (
-    "bytes"
-    "errors"
-    "io"
-    "mime"
-    "mime/multipart"
-    "net/http"
-    "os"
-    "path/filepath"
-    "strings"
-
-    . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
-
-    "context"
-    "testing"
-)
-
-// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
-func TestAudio(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
-    server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-
-    testcases := []struct {
-        name     string
-        createFn func(context.Context, AudioRequest) (AudioResponse, error)
-    }{
-        {
-            "transcribe",
-            client.CreateTranscription,
-        },
-        {
-            "translate",
-            client.CreateTranslation,
-        },
-    }
-
-    ctx := context.Background()
-
-    dir, cleanup := createTestDirectory(t)
-    defer cleanup()
-
-    for _, tc := range testcases {
-        t.Run(tc.name, func(t *testing.T) {
-            path := filepath.Join(dir, "fake.mp3")
-            createTestFile(t, path)
-
-            req := AudioRequest{
-                FilePath: path,
-                Model:    "whisper-3",
-            }
-            _, err = tc.createFn(ctx, req)
-            if err != nil {
-                t.Fatalf("audio API error: %v", err)
-            }
-        })
-    }
-}
-
-// createTestFile creates a fake file with "hello" as the content.
-func createTestFile(t *testing.T, path string) {
-    file, err := os.Create(path)
-    if err != nil {
-        t.Fatalf("failed to create file %v", err)
-    }
-
-    if _, err = file.WriteString("hello"); err != nil {
-        t.Fatalf("failed to write to file %v", err)
-    }
-    file.Close()
-}
-
-// createTestDirectory creates a temporary folder which will be deleted when cleanup is called.
-func createTestDirectory(t *testing.T) (path string, cleanup func()) {
-    t.Helper()
-
-    path, err := os.MkdirTemp(os.TempDir(), "")
-    if err != nil {
-        t.Fatal(err)
-    }
-
-    return path, func() { os.RemoveAll(path) }
-}
-
-// handleAudioEndpoint Handles the audio endpoints by the test server.
-func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
-    var err error
-
-    // audio endpoints only accept POST requests
-    if r.Method != "POST" {
-        http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
-    }
-
-    mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
-    if err != nil {
-        http.Error(w, "failed to parse media type", http.StatusBadRequest)
-        return
-    }
-
-    if !strings.HasPrefix(mediaType, "multipart") {
-        http.Error(w, "request is not multipart", http.StatusBadRequest)
-    }
-
-    boundary, ok := params["boundary"]
-    if !ok {
-        http.Error(w, "no boundary in params", http.StatusBadRequest)
-        return
-    }
-
-    fileData := &bytes.Buffer{}
-    mr := multipart.NewReader(r.Body, boundary)
-    part, err := mr.NextPart()
-    if err != nil && errors.Is(err, io.EOF) {
-        http.Error(w, "error accessing file", http.StatusBadRequest)
-        return
-    }
-
-    if _, err = io.Copy(fileData, part); err != nil {
-        http.Error(w, "failed to copy file", http.StatusInternalServerError)
-        return
-    }
-
-    if len(fileData.Bytes()) == 0 {
-        w.WriteHeader(http.StatusInternalServerError)
-        http.Error(w, "received empty file data", http.StatusBadRequest)
-        return
-    }
-
-    if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
-        http.Error(w, "failed to write body", http.StatusInternalServerError)
-        return
-    }
-}
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "context"
@@ -23,8 +23,8 @@ type ChatCompletionMessage struct {
     // This property isn't in the official documentation, but it's in
     // the documentation for the official library for python:
-    // - https://github.com/openai/openai-python/blob/main/chatml.md
-    // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    // - https://github.com/gogpt/gogpt-python/blob/main/chatml.md
+    // - https://github.com/gogpt/gogpt-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
     Name string `json:"name,omitempty"`
 }
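The `Name` field documented above is optional per-message metadata. A minimal sketch of setting it with the renamed package, mirroring the deleted `TestAPI` case (the values are placeholders):

```go
msg := gogpt.ChatCompletionMessage{
    Role:    gogpt.ChatMessageRoleUser,
    Name:    "John_Doe", // optional speaker name; serialized as "name" only when non-empty
    Content: "Hello!",
}
```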
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "bufio"
@@ -1,153 +0,0 @@
-package openai_test
-
-import (
-    . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
-
-    "context"
-    "encoding/json"
-    "errors"
-    "io"
-    "net/http"
-    "net/http/httptest"
-    "testing"
-)
-
-func TestCreateChatCompletionStream(t *testing.T) {
-    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-        w.Header().Set("Content-Type", "text/event-stream")
-
-        // Send test responses
-        dataBytes := []byte{}
-        dataBytes = append(dataBytes, []byte("event: message\n")...)
-        //nolint:lll
-        data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}`
-        dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
-
-        dataBytes = append(dataBytes, []byte("event: message\n")...)
-        //nolint:lll
-        data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}`
-        dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
-
-        dataBytes = append(dataBytes, []byte("event: done\n")...)
-        dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
-
-        _, err := w.Write(dataBytes)
-        if err != nil {
-            t.Errorf("Write error: %s", err)
-        }
-    }))
-    defer server.Close()
-
-    // Client portion of the test
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = server.URL + "/v1"
-    config.HTTPClient.Transport = &tokenRoundTripper{
-        test.GetTestToken(),
-        http.DefaultTransport,
-    }
-
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    request := ChatCompletionRequest{
-        MaxTokens: 5,
-        Model:     GPT3Dot5Turbo,
-        Messages: []ChatCompletionMessage{
-            {
-                Role:    ChatMessageRoleUser,
-                Content: "Hello!",
-            },
-        },
-        Stream: true,
-    }
-
-    stream, err := client.CreateChatCompletionStream(ctx, request)
-    if err != nil {
-        t.Errorf("CreateCompletionStream returned error: %v", err)
-    }
-    defer stream.Close()
-
-    expectedResponses := []ChatCompletionStreamResponse{
-        {
-            ID:      "1",
-            Object:  "completion",
-            Created: 1598069254,
-            Model:   GPT3Dot5Turbo,
-            Choices: []ChatCompletionStreamChoice{
-                {
-                    Delta: ChatCompletionStreamChoiceDelta{
-                        Content: "response1",
-                    },
-                    FinishReason: "max_tokens",
-                },
-            },
-        },
-        {
-            ID:      "2",
-            Object:  "completion",
-            Created: 1598069255,
-            Model:   GPT3Dot5Turbo,
-            Choices: []ChatCompletionStreamChoice{
-                {
-                    Delta: ChatCompletionStreamChoiceDelta{
-                        Content: "response2",
-                    },
-                    FinishReason: "max_tokens",
-                },
-            },
-        },
-    }
-
-    for ix, expectedResponse := range expectedResponses {
-        b, _ := json.Marshal(expectedResponse)
-        t.Logf("%d: %s", ix, string(b))
-
-        receivedResponse, streamErr := stream.Recv()
-        if streamErr != nil {
-            t.Errorf("stream.Recv() failed: %v", streamErr)
-        }
-        if !compareChatResponses(expectedResponse, receivedResponse) {
-            t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
-        }
-    }
-
-    _, streamErr := stream.Recv()
-    if !errors.Is(streamErr, io.EOF) {
-        t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
-    }
-
-    _, streamErr = stream.Recv()
-    if !errors.Is(streamErr, io.EOF) {
-        t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
-    }
-}
-
-// Helper funcs.
-func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
-    if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
-        return false
-    }
-    if len(r1.Choices) != len(r2.Choices) {
-        return false
-    }
-    for i := range r1.Choices {
-        if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) {
-            return false
-        }
-    }
-    return true
-}
-
-func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool {
-    if c1.Index != c2.Index {
-        return false
-    }
-    if c1.Delta.Content != c2.Delta.Content {
-        return false
-    }
-    if c1.FinishReason != c2.FinishReason {
-        return false
-    }
-    return true
-}
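Alongside the deleted streaming test, consuming a chat stream from application code would look roughly like this under the renamed import; `client` and `ctx` are assumed to exist (a sketch, not the library's documented example):

```go
stream, err := client.CreateChatCompletionStream(ctx, gogpt.ChatCompletionRequest{
    Model:     gogpt.GPT3Dot5Turbo,
    MaxTokens: 5,
    Messages: []gogpt.ChatCompletionMessage{
        {Role: gogpt.ChatMessageRoleUser, Content: "Hello!"},
    },
    Stream: true,
})
if err != nil {
    return err
}
defer stream.Close()

for {
    resp, err := stream.Recv()
    if errors.Is(err, io.EOF) {
        break // server sent [DONE]
    }
    if err != nil {
        return err
    }
    fmt.Print(resp.Choices[0].Delta.Content)
}
```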
@@ -1,132 +0,0 @@
-package openai_test
-
-import (
-    . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
-
-    "context"
-    "encoding/json"
-    "errors"
-    "fmt"
-    "io"
-    "net/http"
-    "strconv"
-    "strings"
-    "testing"
-    "time"
-)
-
-func TestChatCompletionsWrongModel(t *testing.T) {
-    config := DefaultConfig("whatever")
-    config.BaseURL = "http://localhost/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    req := ChatCompletionRequest{
-        MaxTokens: 5,
-        Model:     "ada",
-        Messages: []ChatCompletionMessage{
-            {
-                Role:    ChatMessageRoleUser,
-                Content: "Hello!",
-            },
-        },
-    }
-    _, err := client.CreateChatCompletion(ctx, req)
-    if !errors.Is(err, ErrChatCompletionInvalidModel) {
-        t.Fatalf("CreateChatCompletion should return wrong model error, but returned: %v", err)
-    }
-}
-
-// TestChatCompletions Tests the chat completions endpoint of the API using the mocked server.
-func TestChatCompletions(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    req := ChatCompletionRequest{
-        MaxTokens: 5,
-        Model:     GPT3Dot5Turbo,
-        Messages: []ChatCompletionMessage{
-            {
-                Role:    ChatMessageRoleUser,
-                Content: "Hello!",
-            },
-        },
-    }
-    _, err = client.CreateChatCompletion(ctx, req)
-    if err != nil {
-        t.Fatalf("CreateChatCompletion error: %v", err)
-    }
-}
-
-// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
-func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
-    var err error
-    var resBytes []byte
-
-    // completions only accepts POST requests
-    if r.Method != "POST" {
-        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-    }
-    var completionReq ChatCompletionRequest
-    if completionReq, err = getChatCompletionBody(r); err != nil {
-        http.Error(w, "could not read request", http.StatusInternalServerError)
-        return
-    }
-    res := ChatCompletionResponse{
-        ID:      strconv.Itoa(int(time.Now().Unix())),
-        Object:  "test-object",
-        Created: time.Now().Unix(),
-        // would be nice to validate Model during testing, but
-        // this may not be possible with how much upkeep
-        // would be required / wouldn't make much sense
-        Model: completionReq.Model,
-    }
-    // create completions
-    for i := 0; i < completionReq.N; i++ {
-        // generate a string of "a"s with length completionReq.MaxTokens
-        completionStr := strings.Repeat("a", completionReq.MaxTokens)
-        res.Choices = append(res.Choices, ChatCompletionChoice{
-            Message: ChatCompletionMessage{
-                Role:    ChatMessageRoleAssistant,
-                Content: completionStr,
-            },
-            Index: i,
-        })
-    }
-    inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N
-    completionTokens := completionReq.MaxTokens * completionReq.N
-    res.Usage = Usage{
-        PromptTokens:     inputTokens,
-        CompletionTokens: completionTokens,
-        TotalTokens:      inputTokens + completionTokens,
-    }
-    resBytes, _ = json.Marshal(res)
-    fmt.Fprintln(w, string(resBytes))
-}
-
-// getChatCompletionBody Returns the body of the request to create a completion.
-func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) {
-    completion := ChatCompletionRequest{}
-    // read the request body
-    reqBody, err := io.ReadAll(r.Body)
-    if err != nil {
-        return ChatCompletionRequest{}, err
-    }
-    err = json.Unmarshal(reqBody, &completion)
-    if err != nil {
-        return ChatCompletionRequest{}, err
-    }
-    return completion, nil
-}
@@ -1,5 +1,5 @@
 // common.go defines common types used throughout the OpenAI API.
-package openai
+package gogpt

 // Usage Represents the total token usage per request to OpenAI.
 type Usage struct {
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "context"
@@ -1,121 +0,0 @@
-package openai_test
-
-import (
-    . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
-
-    "context"
-    "encoding/json"
-    "errors"
-    "fmt"
-    "io"
-    "net/http"
-    "strconv"
-    "strings"
-    "testing"
-    "time"
-)
-
-func TestCompletionsWrongModel(t *testing.T) {
-    config := DefaultConfig("whatever")
-    config.BaseURL = "http://localhost/v1"
-    client := NewClientWithConfig(config)
-
-    _, err := client.CreateCompletion(
-        context.Background(),
-        CompletionRequest{
-            MaxTokens: 5,
-            Model:     GPT3Dot5Turbo,
-        },
-    )
-    if !errors.Is(err, ErrCompletionUnsupportedModel) {
-        t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err)
-    }
-}
-
-// TestCompletions Tests the completions endpoint of the API using the mocked server.
-func TestCompletions(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    req := CompletionRequest{
-        MaxTokens: 5,
-        Model:     "ada",
-    }
-    req.Prompt = "Lorem ipsum"
-    _, err = client.CreateCompletion(ctx, req)
-    if err != nil {
-        t.Fatalf("CreateCompletion error: %v", err)
-    }
-}
-
-// handleCompletionEndpoint Handles the completion endpoint by the test server.
-func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
-    var err error
-    var resBytes []byte
-
-    // completions only accepts POST requests
-    if r.Method != "POST" {
-        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-    }
-    var completionReq CompletionRequest
-    if completionReq, err = getCompletionBody(r); err != nil {
-        http.Error(w, "could not read request", http.StatusInternalServerError)
-        return
-    }
-    res := CompletionResponse{
-        ID:      strconv.Itoa(int(time.Now().Unix())),
-        Object:  "test-object",
-        Created: time.Now().Unix(),
-        // would be nice to validate Model during testing, but
-        // this may not be possible with how much upkeep
-        // would be required / wouldn't make much sense
-        Model: completionReq.Model,
-    }
-    // create completions
-    for i := 0; i < completionReq.N; i++ {
-        // generate a string of "a"s with length completionReq.MaxTokens
-        completionStr := strings.Repeat("a", completionReq.MaxTokens)
-        if completionReq.Echo {
-            completionStr = completionReq.Prompt + completionStr
-        }
-        res.Choices = append(res.Choices, CompletionChoice{
-            Text:  completionStr,
-            Index: i,
-        })
-    }
-    inputTokens := numTokens(completionReq.Prompt) * completionReq.N
-    completionTokens := completionReq.MaxTokens * completionReq.N
-    res.Usage = Usage{
-        PromptTokens:     inputTokens,
-        CompletionTokens: completionTokens,
-        TotalTokens:      inputTokens + completionTokens,
-    }
-    resBytes, _ = json.Marshal(res)
-    fmt.Fprintln(w, string(resBytes))
-}
-
-// getCompletionBody Returns the body of the request to create a completion.
-func getCompletionBody(r *http.Request) (CompletionRequest, error) {
-    completion := CompletionRequest{}
-    // read the request body
-    reqBody, err := io.ReadAll(r.Body)
-    if err != nil {
-        return CompletionRequest{}, err
-    }
-    err = json.Unmarshal(reqBody, &completion)
-    if err != nil {
-        return CompletionRequest{}, err
-    }
-    return completion, nil
-}
@@ -1,11 +1,11 @@
-package openai
+package gogpt

 import (
     "net/http"
 )

 const (
-    apiURLv1                       = "https://api.openai.com/v1"
+    apiURLv1                       = "https://api.gogpt.com/v1"
     defaultEmptyMessagesLimit uint = 300
 )
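Because `apiURLv1` is a private constant, callers that need a different endpoint (a mock server, a proxy, a gateway) override `ClientConfig.BaseURL` instead, exactly as the tests in this diff do. A minimal sketch with the renamed package:

```go
config := gogpt.DefaultConfig("your token")
config.BaseURL = "http://localhost:8080/v1" // point the client at a local mock or proxy
client := gogpt.NewClientWithConfig(config)
```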
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "context"
@@ -1,104 +0,0 @@
-package openai_test
-
-import (
-    . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
-
-    "context"
-    "encoding/json"
-    "fmt"
-    "io"
-    "net/http"
-    "testing"
-    "time"
-)
-
-// TestEdits Tests the edits endpoint of the API using the mocked server.
-func TestEdits(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/edits", handleEditEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    // create an edit request
-    model := "ada"
-    editReq := EditsRequest{
-        Model: &model,
-        Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " +
-            "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" +
-            " ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" +
-            " ex ea commodo consequat. Duis aute irure dolor in reprehe",
-        Instruction: "test instruction",
-        N:           3,
-    }
-    response, err := client.Edits(ctx, editReq)
-    if err != nil {
-        t.Fatalf("Edits error: %v", err)
-    }
-    if len(response.Choices) != editReq.N {
-        t.Fatalf("edits does not properly return the correct number of choices")
-    }
-}
-
-// handleEditEndpoint Handles the edit endpoint by the test server.
-func handleEditEndpoint(w http.ResponseWriter, r *http.Request) {
-    var err error
-    var resBytes []byte
-
-    // edits only accepts POST requests
-    if r.Method != "POST" {
-        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-    }
-    var editReq EditsRequest
-    editReq, err = getEditBody(r)
-    if err != nil {
-        http.Error(w, "could not read request", http.StatusInternalServerError)
-        return
-    }
-    // create a response
-    res := EditsResponse{
-        Object:  "test-object",
-        Created: time.Now().Unix(),
-    }
-    // edit and calculate token usage
-    editString := "edited by mocked OpenAI server :)"
-    inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N
-    completionTokens := int(float32(len(editString))/4) * editReq.N
-    for i := 0; i < editReq.N; i++ {
-        // instruction will be hidden and only seen by OpenAI
-        res.Choices = append(res.Choices, EditsChoice{
-            Text:  editReq.Input + editString,
-            Index: i,
-        })
-    }
-    res.Usage = Usage{
-        PromptTokens:     inputTokens,
-        CompletionTokens: completionTokens,
-        TotalTokens:      inputTokens + completionTokens,
-    }
-    resBytes, _ = json.Marshal(res)
-    fmt.Fprint(w, string(resBytes))
-}
-
-// getEditBody Returns the body of the request to create an edit.
-func getEditBody(r *http.Request) (EditsRequest, error) {
-    edit := EditsRequest{}
-    // read the request body
-    reqBody, err := io.ReadAll(r.Body)
-    if err != nil {
-        return EditsRequest{}, err
-    }
-    err = json.Unmarshal(reqBody, &edit)
-    if err != nil {
-        return EditsRequest{}, err
-    }
-    return edit, nil
-}
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "context"
@@ -130,7 +130,7 @@ type EmbeddingRequest struct {
 }

 // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
-// https://beta.openai.com/docs/api-reference/embeddings/create
+// https://beta.gogpt.com/docs/api-reference/embeddings/create
 func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
     req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request)
     if err != nil {
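A minimal usage sketch for `CreateEmbeddings`, mirroring the request shape used in the deleted `TestAPI` (`client` and `ctx` are assumed; the response is expected to carry one embedding per input on its `Data` slice):

```go
resp, err := client.CreateEmbeddings(ctx, gogpt.EmbeddingRequest{
    Input: []string{
        "The food was delicious and the waiter",
        "Other examples of embedding request",
    },
    Model: gogpt.AdaSearchQuery,
})
if err != nil {
    return err
}
fmt.Println(len(resp.Data)) // one embedding vector per input string
```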
@@ -1,48 +0,0 @@
-package openai_test
-
-import (
-    . "github.com/sashabaranov/go-openai"
-
-    "bytes"
-    "encoding/json"
-    "testing"
-)
-
-func TestEmbedding(t *testing.T) {
-    embeddedModels := []EmbeddingModel{
-        AdaSimilarity,
-        BabbageSimilarity,
-        CurieSimilarity,
-        DavinciSimilarity,
-        AdaSearchDocument,
-        AdaSearchQuery,
-        BabbageSearchDocument,
-        BabbageSearchQuery,
-        CurieSearchDocument,
-        CurieSearchQuery,
-        DavinciSearchDocument,
-        DavinciSearchQuery,
-        AdaCodeSearchCode,
-        AdaCodeSearchText,
-        BabbageCodeSearchCode,
-        BabbageCodeSearchText,
-    }
-    for _, model := range embeddedModels {
-        embeddingReq := EmbeddingRequest{
-            Input: []string{
-                "The food was delicious and the waiter",
-                "Other examples of embedding request",
-            },
-            Model: model,
-        }
-        // marshal embeddingReq to JSON and confirm that the model field equals
-        // the model currently under test
-        marshaled, err := json.Marshal(embeddingReq)
-        if err != nil {
-            t.Fatalf("Could not marshal embedding request: %v", err)
-        }
-        if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
-            t.Fatalf("Expected embedding request to contain model field")
-        }
-    }
-}
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "context"
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import "fmt"

@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "bytes"
@@ -1,81 +0,0 @@
-package openai_test
-
-import (
-    . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
-
-    "context"
-    "encoding/json"
-    "fmt"
-    "net/http"
-    "strconv"
-    "testing"
-    "time"
-)
-
-func TestFileUpload(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/files", handleCreateFile)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    req := FileRequest{
-        FileName: "test.go",
-        FilePath: "api.go",
-        Purpose:  "fine-tune",
-    }
-    _, err = client.CreateFile(ctx, req)
-    if err != nil {
-        t.Fatalf("CreateFile error: %v", err)
-    }
-}
-
-// handleCreateFile Handles the files endpoint by the test server.
-func handleCreateFile(w http.ResponseWriter, r *http.Request) {
-    var err error
-    var resBytes []byte
-
-    // files only accepts POST requests
-    if r.Method != "POST" {
-        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-    }
-    err = r.ParseMultipartForm(1024 * 1024 * 1024)
-    if err != nil {
-        http.Error(w, "file is more than 1GB", http.StatusInternalServerError)
-        return
-    }
-
-    values := r.Form
-    var purpose string
-    for key, value := range values {
-        if key == "purpose" {
-            purpose = value[0]
-        }
-    }
-    file, header, err := r.FormFile("file")
-    if err != nil {
-        return
-    }
-    defer file.Close()
-
-    var fileReq = File{
-        Bytes:     int(header.Size),
-        ID:        strconv.Itoa(int(time.Now().Unix())),
-        FileName:  header.Filename,
-        Purpose:   purpose,
-        CreatedAt: time.Now().Unix(),
-        Object:    "test-object",
-        Owner:     "test-owner",
-    }
-
-    resBytes, _ = json.Marshal(fileReq)
-    fmt.Fprint(w, string(resBytes))
-}
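Mirroring the deleted test, uploading a file from application code would look roughly like this with the renamed package (`client` and `ctx` assumed; the file names are placeholders):

```go
_, err := client.CreateFile(ctx, gogpt.FileRequest{
    FileName: "training.jsonl",
    FilePath: "./training.jsonl",
    Purpose:  "fine-tune",
})
if err != nil {
    return err
}
```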
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "context"
@@ -1,101 +0,0 @@
-package openai_test
-
-import (
-    . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
-
-    "context"
-    "encoding/json"
-    "fmt"
-    "net/http"
-    "testing"
-)
-
-const testFineTuneID = "fine-tune-id"
-
-// TestFineTunes Tests the fine tunes endpoint of the API using the mocked server.
-func TestFineTunes(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler(
-        "/v1/fine-tunes",
-        func(w http.ResponseWriter, r *http.Request) {
-            var resBytes []byte
-            if r.Method == http.MethodGet {
-                resBytes, _ = json.Marshal(FineTuneList{})
-            } else {
-                resBytes, _ = json.Marshal(FineTune{})
-            }
-            fmt.Fprintln(w, string(resBytes))
-        },
-    )
-    server.RegisterHandler(
-        "/v1/fine-tunes/"+testFineTuneID+"/cancel",
-        func(w http.ResponseWriter, r *http.Request) {
-            resBytes, _ := json.Marshal(FineTune{})
-            fmt.Fprintln(w, string(resBytes))
-        },
-    )
-    server.RegisterHandler(
-        "/v1/fine-tunes/"+testFineTuneID,
-        func(w http.ResponseWriter, r *http.Request) {
-            var resBytes []byte
-            if r.Method == http.MethodDelete {
-                resBytes, _ = json.Marshal(FineTuneDeleteResponse{})
-            } else {
-                resBytes, _ = json.Marshal(FineTune{})
-            }
-            fmt.Fprintln(w, string(resBytes))
-        },
-    )
-    server.RegisterHandler(
-        "/v1/fine-tunes/"+testFineTuneID+"/events",
-        func(w http.ResponseWriter, r *http.Request) {
-            resBytes, _ := json.Marshal(FineTuneEventList{})
-            fmt.Fprintln(w, string(resBytes))
-        },
-    )
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    _, err = client.ListFineTunes(ctx)
-    if err != nil {
-        t.Fatalf("ListFineTunes error: %v", err)
-    }
-
-    _, err = client.CreateFineTune(ctx, FineTuneRequest{})
-    if err != nil {
-        t.Fatalf("CreateFineTune error: %v", err)
-    }
-
-    _, err = client.CancelFineTune(ctx, testFineTuneID)
-    if err != nil {
-        t.Fatalf("CancelFineTune error: %v", err)
-    }
-
-    _, err = client.GetFineTune(ctx, testFineTuneID)
-    if err != nil {
-        t.Fatalf("GetFineTune error: %v", err)
-    }
-
-    _, err = client.DeleteFineTune(ctx, testFineTuneID)
-    if err != nil {
-        t.Fatalf("DeleteFineTune error: %v", err)
-    }
-
-    _, err = client.ListFineTuneEvents(ctx, testFineTuneID)
-    if err != nil {
-        t.Fatalf("ListFineTuneEvents error: %v", err)
-    }
-}
@@ -1,3 +1,3 @@
-module github.com/sashabaranov/go-openai
+module github.com/sashabaranov/go-gogpt

 go 1.18
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "bytes"
@@ -147,7 +147,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
         return
     }
     writer.Close()
-    //https://platform.openai.com/docs/api-reference/images/create-variation
+    //https://platform.gogpt.com/docs/api-reference/images/create-variation
     urlSuffix := "/images/variations"
     req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
     if err != nil {
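A usage sketch for `CreateVariImage`, based on the request shape in the deleted image tests (`client` and `ctx` assumed; the PNG path is a placeholder):

```go
origin, err := os.Open("image.png")
if err != nil {
    return err
}
defer origin.Close()

resp, err := client.CreateVariImage(ctx, gogpt.ImageVariRequest{
    Image: origin,
    N:     3,
    Size:  gogpt.CreateImageSize1024x1024,
})
if err != nil {
    return err
}
fmt.Println(resp.Data[0].URL)
```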
@@ -1,268 +0,0 @@
-package openai_test
-
-import (
-    . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
-
-    "context"
-    "encoding/json"
-    "fmt"
-    "io"
-    "net/http"
-    "os"
-    "testing"
-    "time"
-)
-
-func TestImages(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    req := ImageRequest{}
-    req.Prompt = "Lorem ipsum"
-    _, err = client.CreateImage(ctx, req)
-    if err != nil {
-        t.Fatalf("CreateImage error: %v", err)
-    }
-}
-
-// handleImageEndpoint Handles the images endpoint by the test server.
-func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
-    var err error
-    var resBytes []byte
-
-    // images only accepts POST requests
-    if r.Method != "POST" {
-        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-    }
-    var imageReq ImageRequest
-    if imageReq, err = getImageBody(r); err != nil {
-        http.Error(w, "could not read request", http.StatusInternalServerError)
-        return
-    }
-    res := ImageResponse{
-        Created: time.Now().Unix(),
-    }
-    for i := 0; i < imageReq.N; i++ {
-        imageData := ImageResponseDataInner{}
-        switch imageReq.ResponseFormat {
-        case CreateImageResponseFormatURL, "":
-            imageData.URL = "https://example.com/image.png"
-        case CreateImageResponseFormatB64JSON:
-            // This decodes to "{}" in base64.
-            imageData.B64JSON = "e30K"
-        default:
-            http.Error(w, "invalid response format", http.StatusBadRequest)
-            return
-        }
-        res.Data = append(res.Data, imageData)
-    }
-    resBytes, _ = json.Marshal(res)
-    fmt.Fprintln(w, string(resBytes))
-}
-
-// getImageBody Returns the body of the request to create an image.
-func getImageBody(r *http.Request) (ImageRequest, error) {
-    image := ImageRequest{}
-    // read the request body
-    reqBody, err := io.ReadAll(r.Body)
-    if err != nil {
-        return ImageRequest{}, err
-    }
-    err = json.Unmarshal(reqBody, &image)
-    if err != nil {
-        return ImageRequest{}, err
-    }
-    return image, nil
-}
-
-func TestImageEdit(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    origin, err := os.Create("image.png")
-    if err != nil {
-        t.Error("open origin file error")
-        return
-    }
-
-    mask, err := os.Create("mask.png")
-    if err != nil {
-        t.Error("open mask file error")
-        return
-    }
-
-    defer func() {
-        mask.Close()
-        origin.Close()
-        os.Remove("mask.png")
-        os.Remove("image.png")
-    }()
-
-    req := ImageEditRequest{
-        Image:  origin,
-        Mask:   mask,
-        Prompt: "There is a turtle in the pool",
-        N:      3,
-        Size:   CreateImageSize1024x1024,
-    }
-    _, err = client.CreateEditImage(ctx, req)
-    if err != nil {
-        t.Fatalf("CreateImage error: %v", err)
-    }
-}
-
-func TestImageEditWithoutMask(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    origin, err := os.Create("image.png")
-    if err != nil {
-        t.Error("open origin file error")
-        return
-    }
-
-    defer func() {
-        origin.Close()
-        os.Remove("image.png")
-    }()
-
-    req := ImageEditRequest{
-        Image:  origin,
-        Prompt: "There is a turtle in the pool",
-        N:      3,
-        Size:   CreateImageSize1024x1024,
-    }
-    _, err = client.CreateEditImage(ctx, req)
-    if err != nil {
-        t.Fatalf("CreateImage error: %v", err)
-    }
-}
-
-// handleEditImageEndpoint Handles the image edit endpoint by the test server.
-func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
-    var resBytes []byte
-
-    // images only accepts POST requests
-    if r.Method != "POST" {
-        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-    }
-
-    responses := ImageResponse{
-        Created: time.Now().Unix(),
-        Data: []ImageResponseDataInner{
-            {
-                URL:     "test-url1",
-                B64JSON: "",
-            },
-            {
-                URL:     "test-url2",
-                B64JSON: "",
-            },
-            {
-                URL:     "test-url3",
-                B64JSON: "",
-            },
-        },
-    }
-
-    resBytes, _ = json.Marshal(responses)
-    fmt.Fprintln(w, string(resBytes))
-}
-
-func TestImageVariation(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    origin, err := os.Create("image.png")
-    if err != nil {
-        t.Error("open origin file error")
-        return
-    }
-
-    defer func() {
-        origin.Close()
-        os.Remove("image.png")
-    }()
-
-    req := ImageVariRequest{
-        Image: origin,
-        N:     3,
-        Size:  CreateImageSize1024x1024,
-    }
-    _, err = client.CreateVariImage(ctx, req)
-    if err != nil {
-        t.Fatalf("CreateImage error: %v", err)
-    }
-}
-
-// handleVariateImageEndpoint Handles the image variation endpoint by the test server.
-func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
-    var resBytes []byte
-
-    // images only accepts POST requests
-    if r.Method != "POST" {
-        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-    }
-
-    responses := ImageResponse{
-        Created: time.Now().Unix(),
-        Data: []ImageResponseDataInner{
-            {
-                URL:     "test-url1",
-                B64JSON: "",
-            },
-            {
-                URL:     "test-url2",
-                B64JSON: "",
-            },
-            {
-                URL:     "test-url3",
-                B64JSON: "",
-            },
-        },
-    }
-
-    resBytes, _ = json.Marshal(responses)
-    fmt.Fprintln(w, string(resBytes))
-}
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "encoding/json"
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "context"
@@ -1,39 +0,0 @@
-package openai_test
-
-import (
-    . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
-
-    "context"
-    "encoding/json"
-    "fmt"
-    "net/http"
-    "testing"
-)
-
-// TestListModels Tests the models endpoint of the API using the mocked server.
-func TestListModels(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/models", handleModelsEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    _, err = client.ListModels(ctx)
-    if err != nil {
-        t.Fatalf("ListModels error: %v", err)
-    }
-}
-
-// handleModelsEndpoint Handles the models endpoint by the test server.
-func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) {
-    resBytes, _ := json.Marshal(ModelsList{})
-    fmt.Fprintln(w, string(resBytes))
-}
@@ -1,4 +1,4 @@
-package openai
+package gogpt

 import (
     "context"
@@ -1,102 +0,0 @@
package openai_test

import (
	. "github.com/sashabaranov/go-openai"
	"github.com/sashabaranov/go-openai/internal/test"

	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"testing"
	"time"
)

// TestModerations Tests the moderations endpoint of the API using the mocked server.
func TestModerations(t *testing.T) {
	server := test.NewTestServer()
	server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
	// create the test server
	var err error
	ts := server.OpenAITestServer()
	ts.Start()
	defer ts.Close()

	config := DefaultConfig(test.GetTestToken())
	config.BaseURL = ts.URL + "/v1"
	client := NewClientWithConfig(config)
	ctx := context.Background()

	// create a moderation request
	model := "text-moderation-stable"
	moderationReq := ModerationRequest{
		Model: &model,
		Input: "I want to kill them.",
	}
	_, err = client.Moderations(ctx, moderationReq)
	if err != nil {
		t.Fatalf("Moderation error: %v", err)
	}
}

// handleModerationEndpoint Handles the moderation endpoint by the test server.
func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
	var err error
	var resBytes []byte

	// moderations only accepts POST requests
	if r.Method != "POST" {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var moderationReq ModerationRequest
	if moderationReq, err = getModerationBody(r); err != nil {
		http.Error(w, "could not read request", http.StatusInternalServerError)
		return
	}

	resCat := ResultCategories{}
	resCatScore := ResultCategoryScores{}
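	// Flag the input by simple keyword matching so each moderation
	// category can be triggered deterministically from the test fixture.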
	switch {
	case strings.Contains(moderationReq.Input, "kill"):
		resCat = ResultCategories{Violence: true}
		resCatScore = ResultCategoryScores{Violence: 1}
	case strings.Contains(moderationReq.Input, "hate"):
		resCat = ResultCategories{Hate: true}
		resCatScore = ResultCategoryScores{Hate: 1}
	case strings.Contains(moderationReq.Input, "suicide"):
		resCat = ResultCategories{SelfHarm: true}
		resCatScore = ResultCategoryScores{SelfHarm: 1}
	case strings.Contains(moderationReq.Input, "porn"):
		resCat = ResultCategories{Sexual: true}
		resCatScore = ResultCategoryScores{Sexual: 1}
	}

	result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true}
	res := ModerationResponse{
		ID:    strconv.Itoa(int(time.Now().Unix())),
		Model: *moderationReq.Model,
	}
	res.Results = append(res.Results, result)
	resBytes, _ = json.Marshal(res)
	fmt.Fprintln(w, string(resBytes))
}
// getModerationBody Decodes the ModerationRequest from the HTTP request body.
func getModerationBody(r *http.Request) (ModerationRequest, error) {
	moderation := ModerationRequest{}
	// read the request body
	reqBody, err := io.ReadAll(r.Body)
	if err != nil {
		return ModerationRequest{}, err
	}
	err = json.Unmarshal(reqBody, &moderation)
	if err != nil {
		return ModerationRequest{}, err
	}
	return moderation, nil
}
@@ -1,4 +1,4 @@
package openai
package gogpt

import (
	"bytes"
@@ -1,148 +0,0 @@
package openai //nolint:testpackage // testing private field

import (
	"github.com/sashabaranov/go-openai/internal/test"

	"context"
	"errors"
	"net/http"
	"testing"
)

var (
	errTestMarshallerFailed     = errors.New("test marshaller failed")
	errTestRequestBuilderFailed = errors.New("test request builder failed")
)
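// failingRequestBuilder and failingMarshaller satisfy the client's internal
// requestBuilder and marshaller interfaces but always fail, so the
// error-propagation path of each API method can be exercised.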
type (
	failingRequestBuilder struct{}
	failingMarshaller     struct{}
)

func (*failingMarshaller) marshal(value any) ([]byte, error) {
	return []byte{}, errTestMarshallerFailed
}

func (*failingRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) {
	return nil, errTestRequestBuilderFailed
}
func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
	builder := httpRequestBuilder{
		marshaller: &failingMarshaller{},
	}

	_, err := builder.build(context.Background(), "", "", struct{}{})
	if !errors.Is(err, errTestMarshallerFailed) {
		t.Fatalf("Did not return error when marshaller failed: %v", err)
	}
}

func TestClientReturnsRequestBuilderErrors(t *testing.T) {
	var err error
	ts := test.NewTestServer().OpenAITestServer()
	ts.Start()
	defer ts.Close()

	config := DefaultConfig(test.GetTestToken())
	config.BaseURL = ts.URL + "/v1"
	client := NewClientWithConfig(config)
	client.requestBuilder = &failingRequestBuilder{}

	ctx := context.Background()

	_, err = client.CreateCompletion(ctx, CompletionRequest{})
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.CreateFineTune(ctx, FineTuneRequest{})
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.ListFineTunes(ctx)
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.CancelFineTune(ctx, "")
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.GetFineTune(ctx, "")
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.DeleteFineTune(ctx, "")
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.ListFineTuneEvents(ctx, "")
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.Moderations(ctx, ModerationRequest{})
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.Edits(ctx, EditsRequest{})
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.CreateImage(ctx, ImageRequest{})
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	err = client.DeleteFile(ctx, "")
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.GetFile(ctx, "")
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.ListFiles(ctx)
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.ListEngines(ctx)
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.GetEngine(ctx, "")
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}

	_, err = client.ListModels(ctx)
	if !errors.Is(err, errTestRequestBuilderFailed) {
		t.Fatalf("Did not return error when request builder failed: %v", err)
	}
}
@@ -1,4 +1,4 @@
package openai
package gogpt

import (
	"bufio"
@@ -1,147 +0,0 @@
package openai_test

import (
	. "github.com/sashabaranov/go-openai"
	"github.com/sashabaranov/go-openai/internal/test"

	"context"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"
)

func TestCreateCompletionStream(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")

		// Send test responses
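		// Each event uses Server-Sent Events framing: an "event:" line, a
		// "data:" line carrying the JSON payload, then a blank line. The
		// sentinel "data: [DONE]" terminates the stream.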
		dataBytes := []byte{}
		dataBytes = append(dataBytes, []byte("event: message\n")...)
		//nolint:lll
		data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}`
		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

		dataBytes = append(dataBytes, []byte("event: message\n")...)
		//nolint:lll
		data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}`
		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

		dataBytes = append(dataBytes, []byte("event: done\n")...)
		dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)

		_, err := w.Write(dataBytes)
		if err != nil {
			t.Errorf("Write error: %s", err)
		}
	}))
	defer server.Close()

	// Client portion of the test
	config := DefaultConfig(test.GetTestToken())
	config.BaseURL = server.URL + "/v1"
	config.HTTPClient.Transport = &tokenRoundTripper{
		test.GetTestToken(),
		http.DefaultTransport,
	}

	client := NewClientWithConfig(config)
	ctx := context.Background()

	request := CompletionRequest{
		Prompt:    "Ex falso quodlibet",
		Model:     "text-davinci-002",
		MaxTokens: 10,
		Stream:    true,
	}

	stream, err := client.CreateCompletionStream(ctx, request)
	if err != nil {
		t.Errorf("CreateCompletionStream returned error: %v", err)
	}
	defer stream.Close()

	expectedResponses := []CompletionResponse{
		{
			ID:      "1",
			Object:  "completion",
			Created: 1598069254,
			Model:   "text-davinci-002",
			Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}},
		},
		{
			ID:      "2",
			Object:  "completion",
			Created: 1598069255,
			Model:   "text-davinci-002",
			Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
		},
	}
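	// Drain the stream: each Recv should yield the next expected chunk in
	// order, and every call after the final chunk should report io.EOF.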
	for ix, expectedResponse := range expectedResponses {
		receivedResponse, streamErr := stream.Recv()
		if streamErr != nil {
			t.Errorf("stream.Recv() failed: %v", streamErr)
		}
		if !compareResponses(expectedResponse, receivedResponse) {
			t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
		}
	}

	_, streamErr := stream.Recv()
	if !errors.Is(streamErr, io.EOF) {
		t.Errorf("stream.Recv() did not return EOF at the end of the stream: %v", streamErr)
	}

	_, streamErr = stream.Recv()
	if !errors.Is(streamErr, io.EOF) {
		t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
	}
}
// A "tokenRoundTripper" is a struct that implements the RoundTripper | |||||
// interface, specifically to handle the authentication token by adding a token | |||||
// to the request header. We need this because the API requires that each | |||||
// request include a valid API token in the headers for authentication and | |||||
// authorization. | |||||
type tokenRoundTripper struct { | |||||
token string | |||||
fallback http.RoundTripper | |||||
} | |||||
// RoundTrip takes an *http.Request as input and returns an | |||||
// *http.Response and an error. | |||||
// | |||||
// It is expected to use the provided request to create a connection to an HTTP | |||||
// server and return the response, or an error if one occurred. The returned | |||||
// Response should have its Body closed. If the RoundTrip method returns an | |||||
// error, the Client's Get, Head, Post, and PostForm methods return the same | |||||
// error. | |||||
func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { | |||||
req.Header.Set("Authorization", "Bearer "+t.token) | |||||
return t.fallback.RoundTrip(req) | |||||
} | |||||
// Helper funcs.
func compareResponses(r1, r2 CompletionResponse) bool {
	if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
		return false
	}
	if len(r1.Choices) != len(r2.Choices) {
		return false
	}
	for i := range r1.Choices {
		if !compareResponseChoices(r1.Choices[i], r2.Choices[i]) {
			return false
		}
	}
	return true
}

func compareResponseChoices(c1, c2 CompletionChoice) bool {
	return c1.Text == c2.Text && c1.FinishReason == c2.FinishReason
}