@@ -0,0 +1,27 @@ | |||
name: PR sanity check | |||
on: | |||
- push | |||
- pull_request | |||
jobs: | |||
prcheck: | |||
name: PR sanity check | |||
runs-on: ubuntu-latest | |||
steps: | |||
- uses: actions/checkout@v2 | |||
- name: Setup Go | |||
uses: actions/setup-go@v2 | |||
with: | |||
go-version: '1.19' | |||
- name: Run vet | |||
run: | | |||
go vet . | |||
- name: Run golangci-lint | |||
uses: golangci/golangci-lint-action@v3 | |||
with: | |||
version: latest | |||
- name: Run tests | |||
run: go test -race -covermode=atomic -coverprofile=coverage.out -v . | |||
- name: Upload coverage reports to Codecov | |||
uses: codecov/codecov-action@v3 |
@@ -1,17 +1,65 @@ | |||
# go-gpt3 | |||
[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-gpt3) | |||
[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-gpt3)](https://goreportcard.com/report/github.com/sashabaranov/go-gpt3) | |||
# Go OpenAI | |||
[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-openai) | |||
[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai) | |||
[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai) | |||
> **Note**: the repository was recently renamed from `go-gpt3` to `go-openai` | |||
[OpenAI ChatGPT and GPT-3](https://platform.openai.com/) API client for Go | |||
This library provides Go clients for [OpenAI API](https://platform.openai.com/). We support: | |||
* ChatGPT | |||
* GPT-3, GPT-4 | |||
* DALL·E 2 | |||
* Whisper | |||
Installation: | |||
``` | |||
go get github.com/sashabaranov/go-gpt3 | |||
go get github.com/sashabaranov/go-openai | |||
``` | |||
ChatGPT example usage: | |||
```go | |||
package main | |||
import ( | |||
"context" | |||
"fmt" | |||
openai "github.com/sashabaranov/go-openai" | |||
) | |||
func main() { | |||
client := openai.NewClient("your token") | |||
resp, err := client.CreateChatCompletion( | |||
context.Background(), | |||
openai.ChatCompletionRequest{ | |||
Model: openai.GPT3Dot5Turbo, | |||
Messages: []openai.ChatCompletionMessage{ | |||
{ | |||
Role: openai.ChatMessageRoleUser, | |||
Content: "Hello!", | |||
}, | |||
}, | |||
}, | |||
) | |||
if err != nil { | |||
fmt.Printf("ChatCompletion error: %v\n", err) | |||
return | |||
} | |||
fmt.Println(resp.Choices[0].Message.Content) | |||
} | |||
``` | |||
Example usage: | |||
Other examples: | |||
<details> | |||
<summary>GPT-3 completion</summary> | |||
```go | |||
package main | |||
@@ -19,27 +67,30 @@ package main | |||
import ( | |||
"context" | |||
"fmt" | |||
gogpt "github.com/sashabaranov/go-gpt3" | |||
openai "github.com/sashabaranov/go-openai" | |||
) | |||
func main() { | |||
c := gogpt.NewClient("your token") | |||
c := openai.NewClient("your token") | |||
ctx := context.Background() | |||
req := gogpt.CompletionRequest{ | |||
Model: gogpt.GPT3Ada, | |||
req := openai.CompletionRequest{ | |||
Model: openai.GPT3Ada, | |||
MaxTokens: 5, | |||
Prompt: "Lorem ipsum", | |||
} | |||
resp, err := c.CreateCompletion(ctx, req) | |||
if err != nil { | |||
fmt.Printf("Completion error: %v\n", err) | |||
return | |||
} | |||
fmt.Println(resp.Choices[0].Text) | |||
} | |||
``` | |||
</details> | |||
Streaming response example: | |||
<details> | |||
<summary>GPT-3 streaming completion</summary> | |||
```go | |||
package main | |||
@@ -49,21 +100,22 @@ import ( | |||
"context" | |||
"fmt" | |||
"io" | |||
gogpt "github.com/sashabaranov/go-gpt3" | |||
openai "github.com/sashabaranov/go-openai" | |||
) | |||
func main() { | |||
c := gogpt.NewClient("your token") | |||
c := openai.NewClient("your token") | |||
ctx := context.Background() | |||
req := gogpt.CompletionRequest{ | |||
Model: gogpt.GPT3Ada, | |||
req := openai.CompletionRequest{ | |||
Model: openai.GPT3Ada, | |||
MaxTokens: 5, | |||
Prompt: "Lorem ipsum", | |||
Stream: true, | |||
} | |||
stream, err := c.CreateCompletionStream(ctx, req) | |||
if err != nil { | |||
fmt.Printf("CompletionStream error: %v\n", err) | |||
return | |||
} | |||
defer stream.Close() | |||
@@ -85,3 +137,194 @@ func main() { | |||
} | |||
} | |||
``` | |||
</details> | |||
<details> | |||
<summary>Audio Speech-To-Text</summary> | |||
```go | |||
package main | |||
import ( | |||
"context" | |||
"fmt" | |||
openai "github.com/sashabaranov/go-openai" | |||
) | |||
func main() { | |||
c := openai.NewClient("your token") | |||
ctx := context.Background() | |||
req := openai.AudioRequest{ | |||
Model: openai.Whisper1, | |||
FilePath: "recording.mp3", | |||
} | |||
resp, err := c.CreateTranscription(ctx, req) | |||
if err != nil { | |||
fmt.Printf("Transcription error: %v\n", err) | |||
return | |||
} | |||
fmt.Println(resp.Text) | |||
} | |||
``` | |||
</details> | |||
<details> | |||
<summary>DALL-E 2 image generation</summary> | |||
```go | |||
package main | |||
import ( | |||
"bytes" | |||
"context" | |||
"encoding/base64" | |||
"fmt" | |||
openai "github.com/sashabaranov/go-openai" | |||
"image/png" | |||
"os" | |||
) | |||
func main() { | |||
c := openai.NewClient("your token") | |||
ctx := context.Background() | |||
// Sample image by link | |||
reqUrl := openai.ImageRequest{ | |||
Prompt: "Parrot on a skateboard performs a trick, cartoon style, natural light, high detail", | |||
Size: openai.CreateImageSize256x256, | |||
ResponseFormat: openai.CreateImageResponseFormatURL, | |||
N: 1, | |||
} | |||
respUrl, err := c.CreateImage(ctx, reqUrl) | |||
if err != nil { | |||
fmt.Printf("Image creation error: %v\n", err) | |||
return | |||
} | |||
fmt.Println(respUrl.Data[0].URL) | |||
// Example image as base64 | |||
reqBase64 := openai.ImageRequest{ | |||
Prompt: "Portrait of a humanoid parrot in a classic costume, high detail, realistic light, unreal engine", | |||
Size: openai.CreateImageSize256x256, | |||
ResponseFormat: openai.CreateImageResponseFormatB64JSON, | |||
N: 1, | |||
} | |||
respBase64, err := c.CreateImage(ctx, reqBase64) | |||
if err != nil { | |||
fmt.Printf("Image creation error: %v\n", err) | |||
return | |||
} | |||
imgBytes, err := base64.StdEncoding.DecodeString(respBase64.Data[0].B64JSON) | |||
if err != nil { | |||
fmt.Printf("Base64 decode error: %v\n", err) | |||
return | |||
} | |||
r := bytes.NewReader(imgBytes) | |||
imgData, err := png.Decode(r) | |||
if err != nil { | |||
fmt.Printf("PNG decode error: %v\n", err) | |||
return | |||
} | |||
file, err := os.Create("image.png") | |||
if err != nil { | |||
fmt.Printf("File creation error: %v\n", err) | |||
return | |||
} | |||
defer file.Close() | |||
if err := png.Encode(file, imgData); err != nil { | |||
fmt.Printf("PNG encode error: %v\n", err) | |||
return | |||
} | |||
fmt.Println("The image was saved as example.png") | |||
} | |||
``` | |||
</details> | |||
<details> | |||
<summary>Configuring proxy</summary> | |||
```go | |||
config := openai.DefaultConfig("token") | |||
proxyUrl, err := url.Parse("http://localhost:{port}") | |||
if err != nil { | |||
panic(err) | |||
} | |||
transport := &http.Transport{ | |||
Proxy: http.ProxyURL(proxyUrl), | |||
} | |||
config.HTTPClient = &http.Client{ | |||
Transport: transport, | |||
} | |||
c := openai.NewClientWithConfig(config) | |||
``` | |||
See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig | |||
</details> | |||
<details> | |||
<summary>ChatGPT support context</summary> | |||
```go | |||
package main | |||
import ( | |||
"bufio" | |||
"context" | |||
"fmt" | |||
"os" | |||
"strings" | |||
"github.com/sashabaranov/go-openai" | |||
) | |||
func main() { | |||
client := openai.NewClient("your token") | |||
messages := make([]openai.ChatCompletionMessage, 0) | |||
reader := bufio.NewReader(os.Stdin) | |||
fmt.Println("Conversation") | |||
fmt.Println("---------------------") | |||
for { | |||
fmt.Print("-> ") | |||
text, _ := reader.ReadString('\n') | |||
// convert CRLF to LF | |||
text = strings.Replace(text, "\n", "", -1) | |||
messages = append(messages, openai.ChatCompletionMessage{ | |||
Role: openai.ChatMessageRoleUser, | |||
Content: text, | |||
}) | |||
resp, err := client.CreateChatCompletion( | |||
context.Background(), | |||
openai.ChatCompletionRequest{ | |||
Model: openai.GPT3Dot5Turbo, | |||
Messages: messages, | |||
}, | |||
) | |||
if err != nil { | |||
fmt.Printf("ChatCompletion error: %v\n", err) | |||
continue | |||
} | |||
content := resp.Choices[0].Message.Content | |||
messages = append(messages, openai.ChatCompletionMessage{ | |||
Role: openai.ChatMessageRoleAssistant, | |||
Content: content, | |||
}) | |||
fmt.Println(content) | |||
} | |||
} | |||
``` | |||
</details> |
@@ -1,50 +0,0 @@ | |||
package gogpt | |||
import ( | |||
"bytes" | |||
"context" | |||
"encoding/json" | |||
"net/http" | |||
) | |||
type AnswerRequest struct { | |||
Documents []string `json:"documents,omitempty"` | |||
File string `json:"file,omitempty"` | |||
Question string `json:"question"` | |||
SearchModel string `json:"search_model,omitempty"` | |||
Model string `json:"model"` | |||
ExamplesContext string `json:"examples_context"` | |||
Examples [][]string `json:"examples"` | |||
MaxTokens int `json:"max_tokens,omitempty"` | |||
Stop []string `json:"stop,omitempty"` | |||
Temperature *float64 `json:"temperature,omitempty"` | |||
} | |||
type AnswerResponse struct { | |||
Answers []string `json:"answers"` | |||
Completion string `json:"completion"` | |||
Model string `json:"model"` | |||
Object string `json:"object"` | |||
SearchModel string `json:"search_model"` | |||
SelectedDocuments []struct { | |||
Document int `json:"document"` | |||
Text string `json:"text"` | |||
} `json:"selected_documents"` | |||
} | |||
// Search — perform a semantic search api call over a list of documents. | |||
func (c *Client) Answers(ctx context.Context, request AnswerRequest) (response AnswerResponse, err error) { | |||
var reqBytes []byte | |||
reqBytes, err = json.Marshal(request) | |||
if err != nil { | |||
return | |||
} | |||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/answers"), bytes.NewBuffer(reqBytes)) | |||
if err != nil { | |||
return | |||
} | |||
err = c.sendRequest(req, &response) | |||
return | |||
} |
@@ -1,6 +1,7 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"net/http" | |||
@@ -9,17 +10,22 @@ import ( | |||
// Client is OpenAI GPT-3 API client. | |||
type Client struct { | |||
config ClientConfig | |||
requestBuilder requestBuilder | |||
} | |||
// NewClient creates new OpenAI API client. | |||
func NewClient(authToken string) *Client { | |||
config := DefaultConfig(authToken) | |||
return &Client{config} | |||
return NewClientWithConfig(config) | |||
} | |||
// NewClientWithConfig creates new OpenAI API client for specified config. | |||
func NewClientWithConfig(config ClientConfig) *Client { | |||
return &Client{config} | |||
return &Client{ | |||
config: config, | |||
requestBuilder: newRequestBuilder(), | |||
} | |||
} | |||
// NewOrgClient creates new OpenAI API client for specified Organization ID. | |||
@@ -28,7 +34,7 @@ func NewClientWithConfig(config ClientConfig) *Client { | |||
func NewOrgClient(authToken, org string) *Client { | |||
config := DefaultConfig(authToken) | |||
config.OrgID = org | |||
return &Client{config} | |||
return NewClientWithConfig(config) | |||
} | |||
func (c *Client) sendRequest(req *http.Request, v interface{}) error { | |||
@@ -68,7 +74,7 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { | |||
} | |||
if v != nil { | |||
if err = json.NewDecoder(res.Body).Decode(&v); err != nil { | |||
if err = json.NewDecoder(res.Body).Decode(v); err != nil { | |||
return err | |||
} | |||
} | |||
@@ -79,3 +85,22 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { | |||
func (c *Client) fullURL(suffix string) string { | |||
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) | |||
} | |||
func (c *Client) newStreamRequest( | |||
ctx context.Context, | |||
method string, | |||
urlSuffix string, | |||
body any) (*http.Request, error) { | |||
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body) | |||
if err != nil { | |||
return nil, err | |||
} | |||
req.Header.Set("Content-Type", "application/json") | |||
req.Header.Set("Accept", "text/event-stream") | |||
req.Header.Set("Cache-Control", "no-cache") | |||
req.Header.Set("Connection", "keep-alive") | |||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) | |||
return req, nil | |||
} |
@@ -0,0 +1,182 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"context" | |||
"errors" | |||
"io" | |||
"os" | |||
"testing" | |||
) | |||
func TestAPI(t *testing.T) { | |||
apiToken := os.Getenv("OPENAI_TOKEN") | |||
if apiToken == "" { | |||
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") | |||
} | |||
var err error | |||
c := NewClient(apiToken) | |||
ctx := context.Background() | |||
_, err = c.ListEngines(ctx) | |||
if err != nil { | |||
t.Fatalf("ListEngines error: %v", err) | |||
} | |||
_, err = c.GetEngine(ctx, "davinci") | |||
if err != nil { | |||
t.Fatalf("GetEngine error: %v", err) | |||
} | |||
fileRes, err := c.ListFiles(ctx) | |||
if err != nil { | |||
t.Fatalf("ListFiles error: %v", err) | |||
} | |||
if len(fileRes.Files) > 0 { | |||
_, err = c.GetFile(ctx, fileRes.Files[0].ID) | |||
if err != nil { | |||
t.Fatalf("GetFile error: %v", err) | |||
} | |||
} // else skip | |||
embeddingReq := EmbeddingRequest{ | |||
Input: []string{ | |||
"The food was delicious and the waiter", | |||
"Other examples of embedding request", | |||
}, | |||
Model: AdaSearchQuery, | |||
} | |||
_, err = c.CreateEmbeddings(ctx, embeddingReq) | |||
if err != nil { | |||
t.Fatalf("Embedding error: %v", err) | |||
} | |||
_, err = c.CreateChatCompletion( | |||
ctx, | |||
ChatCompletionRequest{ | |||
Model: GPT3Dot5Turbo, | |||
Messages: []ChatCompletionMessage{ | |||
{ | |||
Role: ChatMessageRoleUser, | |||
Content: "Hello!", | |||
}, | |||
}, | |||
}, | |||
) | |||
if err != nil { | |||
t.Errorf("CreateChatCompletion (without name) returned error: %v", err) | |||
} | |||
_, err = c.CreateChatCompletion( | |||
ctx, | |||
ChatCompletionRequest{ | |||
Model: GPT3Dot5Turbo, | |||
Messages: []ChatCompletionMessage{ | |||
{ | |||
Role: ChatMessageRoleUser, | |||
Name: "John_Doe", | |||
Content: "Hello!", | |||
}, | |||
}, | |||
}, | |||
) | |||
if err != nil { | |||
t.Errorf("CreateChatCompletion (with name) returned error: %v", err) | |||
} | |||
stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ | |||
Prompt: "Ex falso quodlibet", | |||
Model: GPT3Ada, | |||
MaxTokens: 5, | |||
Stream: true, | |||
}) | |||
if err != nil { | |||
t.Errorf("CreateCompletionStream returned error: %v", err) | |||
} | |||
defer stream.Close() | |||
counter := 0 | |||
for { | |||
_, err = stream.Recv() | |||
if err != nil { | |||
if errors.Is(err, io.EOF) { | |||
break | |||
} | |||
t.Errorf("Stream error: %v", err) | |||
} else { | |||
counter++ | |||
} | |||
} | |||
if counter == 0 { | |||
t.Error("Stream did not return any responses") | |||
} | |||
} | |||
func TestAPIError(t *testing.T) { | |||
apiToken := os.Getenv("OPENAI_TOKEN") | |||
if apiToken == "" { | |||
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") | |||
} | |||
var err error | |||
c := NewClient(apiToken + "_invalid") | |||
ctx := context.Background() | |||
_, err = c.ListEngines(ctx) | |||
if err == nil { | |||
t.Fatal("ListEngines did not fail") | |||
} | |||
var apiErr *APIError | |||
if !errors.As(err, &apiErr) { | |||
t.Fatalf("Error is not an APIError: %+v", err) | |||
} | |||
if apiErr.StatusCode != 401 { | |||
t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode) | |||
} | |||
if *apiErr.Code != "invalid_api_key" { | |||
t.Fatalf("Unexpected API error code: %s", *apiErr.Code) | |||
} | |||
if apiErr.Error() == "" { | |||
t.Fatal("Empty error message occured") | |||
} | |||
} | |||
func TestRequestError(t *testing.T) { | |||
var err error | |||
config := DefaultConfig("dummy") | |||
config.BaseURL = "https://httpbin.org/status/418?" | |||
c := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
_, err = c.ListEngines(ctx) | |||
if err == nil { | |||
t.Fatal("ListEngines request did not fail") | |||
} | |||
var reqErr *RequestError | |||
if !errors.As(err, &reqErr) { | |||
t.Fatalf("Error is not a RequestError: %+v", err) | |||
} | |||
if reqErr.StatusCode != 418 { | |||
t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode) | |||
} | |||
if reqErr.Unwrap() == nil { | |||
t.Fatalf("Empty request error occured") | |||
} | |||
} | |||
// numTokens Returns the number of GPT-3 encoded tokens in the given text. | |||
// This function approximates based on the rule of thumb stated by OpenAI: | |||
// https://beta.openai.com/tokenizer | |||
// | |||
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) | |||
func numTokens(s string) int { | |||
return int(float32(len(s)) / 4) | |||
} |
@@ -1,4 +1,4 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"bytes" | |||
@@ -0,0 +1,143 @@ | |||
package openai_test | |||
import ( | |||
"bytes" | |||
"errors" | |||
"io" | |||
"mime" | |||
"mime/multipart" | |||
"net/http" | |||
"os" | |||
"path/filepath" | |||
"strings" | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"testing" | |||
) | |||
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. | |||
func TestAudio(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) | |||
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
testcases := []struct { | |||
name string | |||
createFn func(context.Context, AudioRequest) (AudioResponse, error) | |||
}{ | |||
{ | |||
"transcribe", | |||
client.CreateTranscription, | |||
}, | |||
{ | |||
"translate", | |||
client.CreateTranslation, | |||
}, | |||
} | |||
ctx := context.Background() | |||
dir, cleanup := createTestDirectory(t) | |||
defer cleanup() | |||
for _, tc := range testcases { | |||
t.Run(tc.name, func(t *testing.T) { | |||
path := filepath.Join(dir, "fake.mp3") | |||
createTestFile(t, path) | |||
req := AudioRequest{ | |||
FilePath: path, | |||
Model: "whisper-3", | |||
} | |||
_, err = tc.createFn(ctx, req) | |||
if err != nil { | |||
t.Fatalf("audio API error: %v", err) | |||
} | |||
}) | |||
} | |||
} | |||
// createTestFile creates a fake file with "hello" as the content. | |||
func createTestFile(t *testing.T, path string) { | |||
file, err := os.Create(path) | |||
if err != nil { | |||
t.Fatalf("failed to create file %v", err) | |||
} | |||
if _, err = file.WriteString("hello"); err != nil { | |||
t.Fatalf("failed to write to file %v", err) | |||
} | |||
file.Close() | |||
} | |||
// createTestDirectory creates a temporary folder which will be deleted when cleanup is called. | |||
func createTestDirectory(t *testing.T) (path string, cleanup func()) { | |||
t.Helper() | |||
path, err := os.MkdirTemp(os.TempDir(), "") | |||
if err != nil { | |||
t.Fatal(err) | |||
} | |||
return path, func() { os.RemoveAll(path) } | |||
} | |||
// handleAudioEndpoint Handles the completion endpoint by the test server. | |||
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
// audio endpoints only accept POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) | |||
if err != nil { | |||
http.Error(w, "failed to parse media type", http.StatusBadRequest) | |||
return | |||
} | |||
if !strings.HasPrefix(mediaType, "multipart") { | |||
http.Error(w, "request is not multipart", http.StatusBadRequest) | |||
} | |||
boundary, ok := params["boundary"] | |||
if !ok { | |||
http.Error(w, "no boundary in params", http.StatusBadRequest) | |||
return | |||
} | |||
fileData := &bytes.Buffer{} | |||
mr := multipart.NewReader(r.Body, boundary) | |||
part, err := mr.NextPart() | |||
if err != nil && errors.Is(err, io.EOF) { | |||
http.Error(w, "error accessing file", http.StatusBadRequest) | |||
return | |||
} | |||
if _, err = io.Copy(fileData, part); err != nil { | |||
http.Error(w, "failed to copy file", http.StatusInternalServerError) | |||
return | |||
} | |||
if len(fileData.Bytes()) == 0 { | |||
w.WriteHeader(http.StatusInternalServerError) | |||
http.Error(w, "received empty file data", http.StatusBadRequest) | |||
return | |||
} | |||
if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { | |||
http.Error(w, "failed to write body", http.StatusInternalServerError) | |||
return | |||
} | |||
} |
@@ -1,15 +1,18 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"bufio" | |||
"bytes" | |||
"context" | |||
"encoding/json" | |||
"errors" | |||
"fmt" | |||
"net/http" | |||
) | |||
// Chat message role defined by the OpenAI API. | |||
const ( | |||
ChatMessageRoleSystem = "system" | |||
ChatMessageRoleUser = "user" | |||
ChatMessageRoleAssistant = "assistant" | |||
) | |||
var ( | |||
ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported") | |||
) | |||
@@ -17,6 +20,12 @@ var ( | |||
type ChatCompletionMessage struct { | |||
Role string `json:"role"` | |||
Content string `json:"content"` | |||
// This property isn't in the official documentation, but it's in | |||
// the documentation for the official library for python: | |||
// - https://github.com/openai/openai-python/blob/main/chatml.md | |||
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | |||
Name string `json:"name,omitempty"` | |||
} | |||
// ChatCompletionRequest represents a request structure for chat completion API. | |||
@@ -41,24 +50,6 @@ type ChatCompletionChoice struct { | |||
FinishReason string `json:"finish_reason"` | |||
} | |||
type ChatCompletionChoiceForStream struct { | |||
Index int `json:"index"` | |||
Delta struct { | |||
Content string `json:"content"` | |||
} `json:"delta"` | |||
FinishReason string `json:"finish_reason"` | |||
} | |||
// ChatCompletionResponseForStream represents a response structure for chat completion API. | |||
type ChatCompletionResponseForStream struct { | |||
ID string `json:"id"` | |||
Object string `json:"object"` | |||
Created int64 `json:"created"` | |||
Model string `json:"model"` | |||
Choices []ChatCompletionChoiceForStream `json:"choices"` | |||
Usage Usage `json:"usage"` | |||
} | |||
// ChatCompletionResponse represents a response structure for chat completion API. | |||
type ChatCompletionResponse struct { | |||
ID string `json:"id"` | |||
@@ -69,25 +60,21 @@ type ChatCompletionResponse struct { | |||
Usage Usage `json:"usage"` | |||
} | |||
// CreateChatCompletion — API call to Creates a completion for the chat message. | |||
// CreateChatCompletion — API call to Create a completion for the chat message. | |||
func (c *Client) CreateChatCompletion( | |||
ctx context.Context, | |||
request ChatCompletionRequest, | |||
) (response ChatCompletionResponse, err error) { | |||
model := request.Model | |||
if model != GPT3Dot5Turbo0301 && model != GPT3Dot5Turbo { | |||
switch model { | |||
case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K: | |||
default: | |||
err = ErrChatCompletionInvalidModel | |||
return | |||
} | |||
var reqBytes []byte | |||
reqBytes, err = json.Marshal(request) | |||
if err != nil { | |||
return | |||
} | |||
urlSuffix := "/chat/completions" | |||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) | |||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) | |||
if err != nil { | |||
return | |||
} | |||
@@ -95,44 +82,3 @@ func (c *Client) CreateChatCompletion( | |||
err = c.sendRequest(req, &response) | |||
return | |||
} | |||
// CreateChatCompletionStream — API call to create a completion w/ streaming | |||
func (c *Client) CreateChatCompletionStream( | |||
ctx context.Context, | |||
request ChatCompletionRequest, | |||
) (stream *CompletionStream, err error) { | |||
model := request.Model | |||
if model != GPT3Dot5Turbo0301 && model != GPT3Dot5Turbo { | |||
err = ErrChatCompletionInvalidModel | |||
return | |||
} | |||
request.Stream = true | |||
reqBytes, err := json.Marshal(request) | |||
if err != nil { | |||
return | |||
} | |||
urlSuffix := "/chat/completions" | |||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) | |||
req.Header.Set("Content-Type", "application/json") | |||
req.Header.Set("Accept", "text/event-stream") | |||
req.Header.Set("Cache-Control", "no-cache") | |||
req.Header.Set("Connection", "keep-alive") | |||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) | |||
if err != nil { | |||
return | |||
} | |||
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() | |||
if err != nil { | |||
return | |||
} | |||
stream = &CompletionStream{ | |||
emptyMessagesLimit: c.config.EmptyMessagesLimit, | |||
reader: bufio.NewReader(resp.Body), | |||
response: resp, | |||
} | |||
return | |||
} |
@@ -0,0 +1,106 @@ | |||
package openai | |||
import ( | |||
"bufio" | |||
"bytes" | |||
"context" | |||
"encoding/json" | |||
"io" | |||
"net/http" | |||
) | |||
type ChatCompletionStreamChoiceDelta struct { | |||
Content string `json:"content"` | |||
} | |||
type ChatCompletionStreamChoice struct { | |||
Index int `json:"index"` | |||
Delta ChatCompletionStreamChoiceDelta `json:"delta"` | |||
FinishReason string `json:"finish_reason"` | |||
} | |||
type ChatCompletionStreamResponse struct { | |||
ID string `json:"id"` | |||
Object string `json:"object"` | |||
Created int64 `json:"created"` | |||
Model string `json:"model"` | |||
Choices []ChatCompletionStreamChoice `json:"choices"` | |||
} | |||
// ChatCompletionStream | |||
// Note: Perhaps it is more elegant to abstract Stream using generics. | |||
type ChatCompletionStream struct { | |||
emptyMessagesLimit uint | |||
isFinished bool | |||
reader *bufio.Reader | |||
response *http.Response | |||
} | |||
func (stream *ChatCompletionStream) Recv() (response ChatCompletionStreamResponse, err error) { | |||
if stream.isFinished { | |||
err = io.EOF | |||
return | |||
} | |||
var emptyMessagesCount uint | |||
waitForData: | |||
line, err := stream.reader.ReadBytes('\n') | |||
if err != nil { | |||
return | |||
} | |||
var headerData = []byte("data: ") | |||
line = bytes.TrimSpace(line) | |||
if !bytes.HasPrefix(line, headerData) { | |||
emptyMessagesCount++ | |||
if emptyMessagesCount > stream.emptyMessagesLimit { | |||
err = ErrTooManyEmptyStreamMessages | |||
return | |||
} | |||
goto waitForData | |||
} | |||
line = bytes.TrimPrefix(line, headerData) | |||
if string(line) == "[DONE]" { | |||
stream.isFinished = true | |||
err = io.EOF | |||
return | |||
} | |||
err = json.Unmarshal(line, &response) | |||
return | |||
} | |||
func (stream *ChatCompletionStream) Close() { | |||
stream.response.Body.Close() | |||
} | |||
// CreateChatCompletionStream — API call to create a chat completion w/ streaming | |||
// support. It sets whether to stream back partial progress. If set, tokens will be | |||
// sent as data-only server-sent events as they become available, with the | |||
// stream terminated by a data: [DONE] message. | |||
func (c *Client) CreateChatCompletionStream( | |||
ctx context.Context, | |||
request ChatCompletionRequest, | |||
) (stream *ChatCompletionStream, err error) { | |||
request.Stream = true | |||
req, err := c.newStreamRequest(ctx, "POST", "/chat/completions", request) | |||
if err != nil { | |||
return | |||
} | |||
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() | |||
if err != nil { | |||
return | |||
} | |||
stream = &ChatCompletionStream{ | |||
emptyMessagesLimit: c.config.EmptyMessagesLimit, | |||
reader: bufio.NewReader(resp.Body), | |||
response: resp, | |||
} | |||
return | |||
} |
@@ -0,0 +1,153 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"errors" | |||
"io" | |||
"net/http" | |||
"net/http/httptest" | |||
"testing" | |||
) | |||
func TestCreateChatCompletionStream(t *testing.T) { | |||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
w.Header().Set("Content-Type", "text/event-stream") | |||
// Send test responses | |||
dataBytes := []byte{} | |||
dataBytes = append(dataBytes, []byte("event: message\n")...) | |||
//nolint:lll | |||
data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` | |||
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) | |||
dataBytes = append(dataBytes, []byte("event: message\n")...) | |||
//nolint:lll | |||
data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` | |||
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) | |||
dataBytes = append(dataBytes, []byte("event: done\n")...) | |||
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) | |||
_, err := w.Write(dataBytes) | |||
if err != nil { | |||
t.Errorf("Write error: %s", err) | |||
} | |||
})) | |||
defer server.Close() | |||
// Client portion of the test | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = server.URL + "/v1" | |||
config.HTTPClient.Transport = &tokenRoundTripper{ | |||
test.GetTestToken(), | |||
http.DefaultTransport, | |||
} | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
request := ChatCompletionRequest{ | |||
MaxTokens: 5, | |||
Model: GPT3Dot5Turbo, | |||
Messages: []ChatCompletionMessage{ | |||
{ | |||
Role: ChatMessageRoleUser, | |||
Content: "Hello!", | |||
}, | |||
}, | |||
Stream: true, | |||
} | |||
stream, err := client.CreateChatCompletionStream(ctx, request) | |||
if err != nil { | |||
t.Errorf("CreateCompletionStream returned error: %v", err) | |||
} | |||
defer stream.Close() | |||
expectedResponses := []ChatCompletionStreamResponse{ | |||
{ | |||
ID: "1", | |||
Object: "completion", | |||
Created: 1598069254, | |||
Model: GPT3Dot5Turbo, | |||
Choices: []ChatCompletionStreamChoice{ | |||
{ | |||
Delta: ChatCompletionStreamChoiceDelta{ | |||
Content: "response1", | |||
}, | |||
FinishReason: "max_tokens", | |||
}, | |||
}, | |||
}, | |||
{ | |||
ID: "2", | |||
Object: "completion", | |||
Created: 1598069255, | |||
Model: GPT3Dot5Turbo, | |||
Choices: []ChatCompletionStreamChoice{ | |||
{ | |||
Delta: ChatCompletionStreamChoiceDelta{ | |||
Content: "response2", | |||
}, | |||
FinishReason: "max_tokens", | |||
}, | |||
}, | |||
}, | |||
} | |||
for ix, expectedResponse := range expectedResponses { | |||
b, _ := json.Marshal(expectedResponse) | |||
t.Logf("%d: %s", ix, string(b)) | |||
receivedResponse, streamErr := stream.Recv() | |||
if streamErr != nil { | |||
t.Errorf("stream.Recv() failed: %v", streamErr) | |||
} | |||
if !compareChatResponses(expectedResponse, receivedResponse) { | |||
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) | |||
} | |||
} | |||
_, streamErr := stream.Recv() | |||
if !errors.Is(streamErr, io.EOF) { | |||
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) | |||
} | |||
_, streamErr = stream.Recv() | |||
if !errors.Is(streamErr, io.EOF) { | |||
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) | |||
} | |||
} | |||
// Helper funcs. | |||
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { | |||
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { | |||
return false | |||
} | |||
if len(r1.Choices) != len(r2.Choices) { | |||
return false | |||
} | |||
for i := range r1.Choices { | |||
if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) { | |||
return false | |||
} | |||
} | |||
return true | |||
} | |||
func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool { | |||
if c1.Index != c2.Index { | |||
return false | |||
} | |||
if c1.Delta.Content != c2.Delta.Content { | |||
return false | |||
} | |||
if c1.FinishReason != c2.FinishReason { | |||
return false | |||
} | |||
return true | |||
} |
@@ -0,0 +1,132 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"net/http" | |||
"strconv" | |||
"strings" | |||
"testing" | |||
"time" | |||
) | |||
func TestChatCompletionsWrongModel(t *testing.T) { | |||
config := DefaultConfig("whatever") | |||
config.BaseURL = "http://localhost/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
req := ChatCompletionRequest{ | |||
MaxTokens: 5, | |||
Model: "ada", | |||
Messages: []ChatCompletionMessage{ | |||
{ | |||
Role: ChatMessageRoleUser, | |||
Content: "Hello!", | |||
}, | |||
}, | |||
} | |||
_, err := client.CreateChatCompletion(ctx, req) | |||
if !errors.Is(err, ErrChatCompletionInvalidModel) { | |||
t.Fatalf("CreateChatCompletion should return wrong model error, but returned: %v", err) | |||
} | |||
} | |||
// TestCompletions Tests the completions endpoint of the API using the mocked server. | |||
func TestChatCompletions(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
req := ChatCompletionRequest{ | |||
MaxTokens: 5, | |||
Model: GPT3Dot5Turbo, | |||
Messages: []ChatCompletionMessage{ | |||
{ | |||
Role: ChatMessageRoleUser, | |||
Content: "Hello!", | |||
}, | |||
}, | |||
} | |||
_, err = client.CreateChatCompletion(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateChatCompletion error: %v", err) | |||
} | |||
} | |||
// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. | |||
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
var resBytes []byte | |||
// completions only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
var completionReq ChatCompletionRequest | |||
if completionReq, err = getChatCompletionBody(r); err != nil { | |||
http.Error(w, "could not read request", http.StatusInternalServerError) | |||
return | |||
} | |||
res := ChatCompletionResponse{ | |||
ID: strconv.Itoa(int(time.Now().Unix())), | |||
Object: "test-object", | |||
Created: time.Now().Unix(), | |||
// would be nice to validate Model during testing, but | |||
// this may not be possible with how much upkeep | |||
// would be required / wouldn't make much sense | |||
Model: completionReq.Model, | |||
} | |||
// create completions | |||
for i := 0; i < completionReq.N; i++ { | |||
// generate a random string of length completionReq.Length | |||
completionStr := strings.Repeat("a", completionReq.MaxTokens) | |||
res.Choices = append(res.Choices, ChatCompletionChoice{ | |||
Message: ChatCompletionMessage{ | |||
Role: ChatMessageRoleAssistant, | |||
Content: completionStr, | |||
}, | |||
Index: i, | |||
}) | |||
} | |||
inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N | |||
completionTokens := completionReq.MaxTokens * completionReq.N | |||
res.Usage = Usage{ | |||
PromptTokens: inputTokens, | |||
CompletionTokens: completionTokens, | |||
TotalTokens: inputTokens + completionTokens, | |||
} | |||
resBytes, _ = json.Marshal(res) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} | |||
// getChatCompletionBody Returns the body of the request to create a completion. | |||
func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { | |||
completion := ChatCompletionRequest{} | |||
// read the request body | |||
reqBody, err := io.ReadAll(r.Body) | |||
if err != nil { | |||
return ChatCompletionRequest{}, err | |||
} | |||
err = json.Unmarshal(reqBody, &completion) | |||
if err != nil { | |||
return ChatCompletionRequest{}, err | |||
} | |||
return completion, nil | |||
} |
@@ -1,5 +1,5 @@ | |||
// common.go defines common types used throughout the OpenAI API. | |||
package gogpt | |||
package openai | |||
// Usage Represents the total token usage per request to OpenAI. | |||
type Usage struct { | |||
@@ -1,17 +1,24 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"bytes" | |||
"context" | |||
"encoding/json" | |||
"errors" | |||
"net/http" | |||
) | |||
var ( | |||
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll | |||
) | |||
// GPT3 Defines the models provided by OpenAI to use when generating | |||
// completions from OpenAI. | |||
// GPT3 Models are designed for text-based tasks. For code-specific | |||
// tasks, please refer to the Codex series of models. | |||
const ( | |||
GPT432K0314 = "gpt-4-32k-0314" | |||
GPT432K = "gpt-4-32k" | |||
GPT40314 = "gpt-4-0314" | |||
GPT4 = "gpt-4" | |||
GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" | |||
GPT3Dot5Turbo = "gpt-3.5-turbo" | |||
GPT3TextDavinci003 = "text-davinci-003" | |||
@@ -92,14 +99,13 @@ func (c *Client) CreateCompletion( | |||
ctx context.Context, | |||
request CompletionRequest, | |||
) (response CompletionResponse, err error) { | |||
var reqBytes []byte | |||
reqBytes, err = json.Marshal(request) | |||
if err != nil { | |||
if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo { | |||
err = ErrCompletionUnsupportedModel | |||
return | |||
} | |||
urlSuffix := "/completions" | |||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) | |||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) | |||
if err != nil { | |||
return | |||
} | |||
@@ -0,0 +1,121 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"net/http" | |||
"strconv" | |||
"strings" | |||
"testing" | |||
"time" | |||
) | |||
func TestCompletionsWrongModel(t *testing.T) { | |||
config := DefaultConfig("whatever") | |||
config.BaseURL = "http://localhost/v1" | |||
client := NewClientWithConfig(config) | |||
_, err := client.CreateCompletion( | |||
context.Background(), | |||
CompletionRequest{ | |||
MaxTokens: 5, | |||
Model: GPT3Dot5Turbo, | |||
}, | |||
) | |||
if !errors.Is(err, ErrCompletionUnsupportedModel) { | |||
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) | |||
} | |||
} | |||
// TestCompletions Tests the completions endpoint of the API using the mocked server. | |||
func TestCompletions(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/completions", handleCompletionEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
req := CompletionRequest{ | |||
MaxTokens: 5, | |||
Model: "ada", | |||
} | |||
req.Prompt = "Lorem ipsum" | |||
_, err = client.CreateCompletion(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateCompletion error: %v", err) | |||
} | |||
} | |||
// handleCompletionEndpoint Handles the completion endpoint by the test server. | |||
func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
var resBytes []byte | |||
// completions only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
var completionReq CompletionRequest | |||
if completionReq, err = getCompletionBody(r); err != nil { | |||
http.Error(w, "could not read request", http.StatusInternalServerError) | |||
return | |||
} | |||
res := CompletionResponse{ | |||
ID: strconv.Itoa(int(time.Now().Unix())), | |||
Object: "test-object", | |||
Created: time.Now().Unix(), | |||
// would be nice to validate Model during testing, but | |||
// this may not be possible with how much upkeep | |||
// would be required / wouldn't make much sense | |||
Model: completionReq.Model, | |||
} | |||
// create completions | |||
for i := 0; i < completionReq.N; i++ { | |||
// generate a random string of length completionReq.Length | |||
completionStr := strings.Repeat("a", completionReq.MaxTokens) | |||
if completionReq.Echo { | |||
completionStr = completionReq.Prompt + completionStr | |||
} | |||
res.Choices = append(res.Choices, CompletionChoice{ | |||
Text: completionStr, | |||
Index: i, | |||
}) | |||
} | |||
inputTokens := numTokens(completionReq.Prompt) * completionReq.N | |||
completionTokens := completionReq.MaxTokens * completionReq.N | |||
res.Usage = Usage{ | |||
PromptTokens: inputTokens, | |||
CompletionTokens: completionTokens, | |||
TotalTokens: inputTokens + completionTokens, | |||
} | |||
resBytes, _ = json.Marshal(res) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} | |||
// getCompletionBody Returns the body of the request to create a completion. | |||
func getCompletionBody(r *http.Request) (CompletionRequest, error) { | |||
completion := CompletionRequest{} | |||
// read the request body | |||
reqBody, err := io.ReadAll(r.Body) | |||
if err != nil { | |||
return CompletionRequest{}, err | |||
} | |||
err = json.Unmarshal(reqBody, &completion) | |||
if err != nil { | |||
return CompletionRequest{}, err | |||
} | |||
return completion, nil | |||
} |
@@ -1,11 +1,11 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"net/http" | |||
) | |||
const ( | |||
apiURLv1 = "http://chatgpt.zhiyinos.cn/" | |||
apiURLv1 = "https://api.openai.com/v1" | |||
defaultEmptyMessagesLimit uint = 300 | |||
) | |||
@@ -1,9 +1,7 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"bytes" | |||
"context" | |||
"encoding/json" | |||
"net/http" | |||
) | |||
@@ -33,13 +31,7 @@ type EditsResponse struct { | |||
// Perform an API call to the Edits endpoint. | |||
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { | |||
var reqBytes []byte | |||
reqBytes, err = json.Marshal(request) | |||
if err != nil { | |||
return | |||
} | |||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/edits"), bytes.NewBuffer(reqBytes)) | |||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request) | |||
if err != nil { | |||
return | |||
} | |||
@@ -0,0 +1,104 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"io" | |||
"net/http" | |||
"testing" | |||
"time" | |||
) | |||
// TestEdits Tests the edits endpoint of the API using the mocked server. | |||
func TestEdits(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/edits", handleEditEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
// create an edit request | |||
model := "ada" | |||
editReq := EditsRequest{ | |||
Model: &model, | |||
Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + | |||
"sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + | |||
" ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" + | |||
" ex ea commodo consequat. Duis aute irure dolor in reprehe", | |||
Instruction: "test instruction", | |||
N: 3, | |||
} | |||
response, err := client.Edits(ctx, editReq) | |||
if err != nil { | |||
t.Fatalf("Edits error: %v", err) | |||
} | |||
if len(response.Choices) != editReq.N { | |||
t.Fatalf("edits does not properly return the correct number of choices") | |||
} | |||
} | |||
// handleEditEndpoint Handles the edit endpoint by the test server. | |||
func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
var resBytes []byte | |||
// edits only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
var editReq EditsRequest | |||
editReq, err = getEditBody(r) | |||
if err != nil { | |||
http.Error(w, "could not read request", http.StatusInternalServerError) | |||
return | |||
} | |||
// create a response | |||
res := EditsResponse{ | |||
Object: "test-object", | |||
Created: time.Now().Unix(), | |||
} | |||
// edit and calculate token usage | |||
editString := "edited by mocked OpenAI server :)" | |||
inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N | |||
completionTokens := int(float32(len(editString))/4) * editReq.N | |||
for i := 0; i < editReq.N; i++ { | |||
// instruction will be hidden and only seen by OpenAI | |||
res.Choices = append(res.Choices, EditsChoice{ | |||
Text: editReq.Input + editString, | |||
Index: i, | |||
}) | |||
} | |||
res.Usage = Usage{ | |||
PromptTokens: inputTokens, | |||
CompletionTokens: completionTokens, | |||
TotalTokens: inputTokens + completionTokens, | |||
} | |||
resBytes, _ = json.Marshal(res) | |||
fmt.Fprint(w, string(resBytes)) | |||
} | |||
// getEditBody Returns the body of the request to create an edit. | |||
func getEditBody(r *http.Request) (EditsRequest, error) { | |||
edit := EditsRequest{} | |||
// read the request body | |||
reqBody, err := io.ReadAll(r.Body) | |||
if err != nil { | |||
return EditsRequest{}, err | |||
} | |||
err = json.Unmarshal(reqBody, &edit) | |||
if err != nil { | |||
return EditsRequest{}, err | |||
} | |||
return edit, nil | |||
} |
@@ -1,9 +1,7 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"bytes" | |||
"context" | |||
"encoding/json" | |||
"net/http" | |||
) | |||
@@ -103,7 +101,7 @@ var stringToEnum = map[string]EmbeddingModel{ | |||
// then their vector representations should also be similar. | |||
type Embedding struct { | |||
Object string `json:"object"` | |||
Embedding []float64 `json:"embedding"` | |||
Embedding []float32 `json:"embedding"` | |||
Index int `json:"index"` | |||
} | |||
@@ -134,14 +132,7 @@ type EmbeddingRequest struct { | |||
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. | |||
// https://beta.openai.com/docs/api-reference/embeddings/create | |||
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { | |||
var reqBytes []byte | |||
reqBytes, err = json.Marshal(request) | |||
if err != nil { | |||
return | |||
} | |||
urlSuffix := "/embeddings" | |||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) | |||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request) | |||
if err != nil { | |||
return | |||
} | |||
@@ -0,0 +1,48 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"bytes" | |||
"encoding/json" | |||
"testing" | |||
) | |||
func TestEmbedding(t *testing.T) { | |||
embeddedModels := []EmbeddingModel{ | |||
AdaSimilarity, | |||
BabbageSimilarity, | |||
CurieSimilarity, | |||
DavinciSimilarity, | |||
AdaSearchDocument, | |||
AdaSearchQuery, | |||
BabbageSearchDocument, | |||
BabbageSearchQuery, | |||
CurieSearchDocument, | |||
CurieSearchQuery, | |||
DavinciSearchDocument, | |||
DavinciSearchQuery, | |||
AdaCodeSearchCode, | |||
AdaCodeSearchText, | |||
BabbageCodeSearchCode, | |||
BabbageCodeSearchText, | |||
} | |||
for _, model := range embeddedModels { | |||
embeddingReq := EmbeddingRequest{ | |||
Input: []string{ | |||
"The food was delicious and the waiter", | |||
"Other examples of embedding request", | |||
}, | |||
Model: model, | |||
} | |||
// marshal embeddingReq to JSON and confirm that the model field equals | |||
// the AdaSearchQuery type | |||
marshaled, err := json.Marshal(embeddingReq) | |||
if err != nil { | |||
t.Fatalf("Could not marshal embedding request: %v", err) | |||
} | |||
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { | |||
t.Fatalf("Expected embedding request to contain model field") | |||
} | |||
} | |||
} |
@@ -1,4 +1,4 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"context" | |||
@@ -22,7 +22,7 @@ type EnginesList struct { | |||
// ListEngines Lists the currently available engines, and provides basic | |||
// information about each option such as the owner and availability. | |||
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { | |||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/engines"), nil) | |||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil) | |||
if err != nil { | |||
return | |||
} | |||
@@ -38,7 +38,7 @@ func (c *Client) GetEngine( | |||
engineID string, | |||
) (engine Engine, err error) { | |||
urlSuffix := fmt.Sprintf("/engines/%s", engineID) | |||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) | |||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) | |||
if err != nil { | |||
return | |||
} | |||
@@ -1,4 +1,4 @@ | |||
package gogpt | |||
package openai | |||
import "fmt" | |||
@@ -1,4 +1,4 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"bytes" | |||
@@ -112,7 +112,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File | |||
// DeleteFile deletes an existing file. | |||
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { | |||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) | |||
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) | |||
if err != nil { | |||
return | |||
} | |||
@@ -124,7 +124,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { | |||
// ListFiles Lists the currently available files, | |||
// and provides basic information about each file such as the file name and purpose. | |||
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { | |||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/files"), nil) | |||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil) | |||
if err != nil { | |||
return | |||
} | |||
@@ -137,7 +137,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { | |||
// such as the file name and purpose. | |||
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { | |||
urlSuffix := fmt.Sprintf("/files/%s", fileID) | |||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) | |||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) | |||
if err != nil { | |||
return | |||
} | |||
@@ -0,0 +1,81 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"net/http" | |||
"strconv" | |||
"testing" | |||
"time" | |||
) | |||
func TestFileUpload(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/files", handleCreateFile) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
req := FileRequest{ | |||
FileName: "test.go", | |||
FilePath: "api.go", | |||
Purpose: "fine-tune", | |||
} | |||
_, err = client.CreateFile(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateFile error: %v", err) | |||
} | |||
} | |||
// handleCreateFile Handles the images endpoint by the test server. | |||
func handleCreateFile(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
var resBytes []byte | |||
// edits only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
err = r.ParseMultipartForm(1024 * 1024 * 1024) | |||
if err != nil { | |||
http.Error(w, "file is more than 1GB", http.StatusInternalServerError) | |||
return | |||
} | |||
values := r.Form | |||
var purpose string | |||
for key, value := range values { | |||
if key == "purpose" { | |||
purpose = value[0] | |||
} | |||
} | |||
file, header, err := r.FormFile("file") | |||
if err != nil { | |||
return | |||
} | |||
defer file.Close() | |||
var fileReq = File{ | |||
Bytes: int(header.Size), | |||
ID: strconv.Itoa(int(time.Now().Unix())), | |||
FileName: header.Filename, | |||
Purpose: purpose, | |||
CreatedAt: time.Now().Unix(), | |||
Object: "test-objecct", | |||
Owner: "test-owner", | |||
} | |||
resBytes, _ = json.Marshal(fileReq) | |||
fmt.Fprint(w, string(resBytes)) | |||
} |
@@ -0,0 +1,130 @@ | |||
package openai | |||
import ( | |||
"context" | |||
"fmt" | |||
"net/http" | |||
) | |||
type FineTuneRequest struct { | |||
TrainingFile string `json:"training_file"` | |||
ValidationFile string `json:"validation_file,omitempty"` | |||
Model string `json:"model,omitempty"` | |||
Epochs int `json:"n_epochs,omitempty"` | |||
BatchSize int `json:"batch_size,omitempty"` | |||
LearningRateMultiplier float32 `json:"learning_rate_multiplier,omitempty"` | |||
PromptLossRate float32 `json:"prompt_loss_rate,omitempty"` | |||
ComputeClassificationMetrics bool `json:"compute_classification_metrics,omitempty"` | |||
ClassificationClasses int `json:"classification_n_classes,omitempty"` | |||
ClassificationPositiveClass string `json:"classification_positive_class,omitempty"` | |||
ClassificationBetas []float32 `json:"classification_betas,omitempty"` | |||
Suffix string `json:"suffix,omitempty"` | |||
} | |||
type FineTune struct { | |||
ID string `json:"id"` | |||
Object string `json:"object"` | |||
Model string `json:"model"` | |||
CreatedAt int64 `json:"created_at"` | |||
FineTuneEventList []FineTuneEvent `json:"events,omitempty"` | |||
FineTunedModel string `json:"fine_tuned_model"` | |||
HyperParams FineTuneHyperParams `json:"hyperparams"` | |||
OrganizationID string `json:"organization_id"` | |||
ResultFiles []File `json:"result_files"` | |||
Status string `json:"status"` | |||
ValidationFiles []File `json:"validation_files"` | |||
TrainingFiles []File `json:"training_files"` | |||
UpdatedAt int64 `json:"updated_at"` | |||
} | |||
type FineTuneEvent struct { | |||
Object string `json:"object"` | |||
CreatedAt int64 `json:"created_at"` | |||
Level string `json:"level"` | |||
Message string `json:"message"` | |||
} | |||
type FineTuneHyperParams struct { | |||
BatchSize int `json:"batch_size"` | |||
LearningRateMultiplier float64 `json:"learning_rate_multiplier"` | |||
Epochs int `json:"n_epochs"` | |||
PromptLossWeight float64 `json:"prompt_loss_weight"` | |||
} | |||
type FineTuneList struct { | |||
Object string `json:"object"` | |||
Data []FineTune `json:"data"` | |||
} | |||
type FineTuneEventList struct { | |||
Object string `json:"object"` | |||
Data []FineTuneEvent `json:"data"` | |||
} | |||
type FineTuneDeleteResponse struct { | |||
ID string `json:"id"` | |||
Object string `json:"object"` | |||
Deleted bool `json:"deleted"` | |||
} | |||
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { | |||
urlSuffix := "/fine-tunes" | |||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) | |||
if err != nil { | |||
return | |||
} | |||
err = c.sendRequest(req, &response) | |||
return | |||
} | |||
// CancelFineTune cancel a fine-tune job. | |||
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { | |||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) | |||
if err != nil { | |||
return | |||
} | |||
err = c.sendRequest(req, &response) | |||
return | |||
} | |||
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { | |||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) | |||
if err != nil { | |||
return | |||
} | |||
err = c.sendRequest(req, &response) | |||
return | |||
} | |||
func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { | |||
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) | |||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) | |||
if err != nil { | |||
return | |||
} | |||
err = c.sendRequest(req, &response) | |||
return | |||
} | |||
func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { | |||
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) | |||
if err != nil { | |||
return | |||
} | |||
err = c.sendRequest(req, &response) | |||
return | |||
} | |||
func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { | |||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) | |||
if err != nil { | |||
return | |||
} | |||
err = c.sendRequest(req, &response) | |||
return | |||
} |
@@ -0,0 +1,101 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"net/http" | |||
"testing" | |||
) | |||
const testFineTuneID = "fine-tune-id" | |||
// TestFineTunes Tests the fine tunes endpoint of the API using the mocked server. | |||
func TestFineTunes(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler( | |||
"/v1/fine-tunes", | |||
func(w http.ResponseWriter, r *http.Request) { | |||
var resBytes []byte | |||
if r.Method == http.MethodGet { | |||
resBytes, _ = json.Marshal(FineTuneList{}) | |||
} else { | |||
resBytes, _ = json.Marshal(FineTune{}) | |||
} | |||
fmt.Fprintln(w, string(resBytes)) | |||
}, | |||
) | |||
server.RegisterHandler( | |||
"/v1/fine-tunes/"+testFineTuneID+"/cancel", | |||
func(w http.ResponseWriter, r *http.Request) { | |||
resBytes, _ := json.Marshal(FineTune{}) | |||
fmt.Fprintln(w, string(resBytes)) | |||
}, | |||
) | |||
server.RegisterHandler( | |||
"/v1/fine-tunes/"+testFineTuneID, | |||
func(w http.ResponseWriter, r *http.Request) { | |||
var resBytes []byte | |||
if r.Method == http.MethodDelete { | |||
resBytes, _ = json.Marshal(FineTuneDeleteResponse{}) | |||
} else { | |||
resBytes, _ = json.Marshal(FineTune{}) | |||
} | |||
fmt.Fprintln(w, string(resBytes)) | |||
}, | |||
) | |||
server.RegisterHandler( | |||
"/v1/fine-tunes/"+testFineTuneID+"/events", | |||
func(w http.ResponseWriter, r *http.Request) { | |||
resBytes, _ := json.Marshal(FineTuneEventList{}) | |||
fmt.Fprintln(w, string(resBytes)) | |||
}, | |||
) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
_, err = client.ListFineTunes(ctx) | |||
if err != nil { | |||
t.Fatalf("ListFineTunes error: %v", err) | |||
} | |||
_, err = client.CreateFineTune(ctx, FineTuneRequest{}) | |||
if err != nil { | |||
t.Fatalf("CreateFineTune error: %v", err) | |||
} | |||
_, err = client.CancelFineTune(ctx, testFineTuneID) | |||
if err != nil { | |||
t.Fatalf("CancelFineTune error: %v", err) | |||
} | |||
_, err = client.GetFineTune(ctx, testFineTuneID) | |||
if err != nil { | |||
t.Fatalf("GetFineTune error: %v", err) | |||
} | |||
_, err = client.DeleteFineTune(ctx, testFineTuneID) | |||
if err != nil { | |||
t.Fatalf("DeleteFineTune error: %v", err) | |||
} | |||
_, err = client.ListFineTuneEvents(ctx, testFineTuneID) | |||
if err != nil { | |||
t.Fatalf("ListFineTuneEvents error: %v", err) | |||
} | |||
} |
@@ -0,0 +1,3 @@ | |||
module github.com/sashabaranov/go-openai | |||
go 1.18 |
@@ -1,9 +1,8 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"bytes" | |||
"context" | |||
"encoding/json" | |||
"io" | |||
"mime/multipart" | |||
"net/http" | |||
@@ -46,14 +45,8 @@ type ImageResponseDataInner struct { | |||
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. | |||
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { | |||
var reqBytes []byte | |||
reqBytes, err = json.Marshal(request) | |||
if err != nil { | |||
return | |||
} | |||
urlSuffix := "/images/generations" | |||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) | |||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) | |||
if err != nil { | |||
return | |||
} | |||
@@ -86,20 +79,65 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) | |||
return | |||
} | |||
// mask | |||
mask, err := writer.CreateFormFile("mask", request.Mask.Name()) | |||
// mask, it is optional | |||
if request.Mask != nil { | |||
mask, err2 := writer.CreateFormFile("mask", request.Mask.Name()) | |||
if err2 != nil { | |||
return | |||
} | |||
_, err = io.Copy(mask, request.Mask) | |||
if err != nil { | |||
return | |||
} | |||
} | |||
err = writer.WriteField("prompt", request.Prompt) | |||
if err != nil { | |||
return | |||
} | |||
_, err = io.Copy(mask, request.Mask) | |||
err = writer.WriteField("n", strconv.Itoa(request.N)) | |||
if err != nil { | |||
return | |||
} | |||
err = writer.WriteField("size", request.Size) | |||
if err != nil { | |||
return | |||
} | |||
writer.Close() | |||
urlSuffix := "/images/edits" | |||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) | |||
if err != nil { | |||
return | |||
} | |||
err = writer.WriteField("prompt", request.Prompt) | |||
req.Header.Set("Content-Type", writer.FormDataContentType()) | |||
err = c.sendRequest(req, &response) | |||
return | |||
} | |||
// ImageVariRequest represents the request structure for the image API. | |||
type ImageVariRequest struct { | |||
Image *os.File `json:"image,omitempty"` | |||
N int `json:"n,omitempty"` | |||
Size string `json:"size,omitempty"` | |||
} | |||
// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. | |||
// Use abbreviations(vari for variation) because ci-lint has a single-line length limit ... | |||
func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) (response ImageResponse, err error) { | |||
body := &bytes.Buffer{} | |||
writer := multipart.NewWriter(body) | |||
// image | |||
image, err := writer.CreateFormFile("image", request.Image.Name()) | |||
if err != nil { | |||
return | |||
} | |||
_, err = io.Copy(image, request.Image) | |||
if err != nil { | |||
return | |||
} | |||
err = writer.WriteField("n", strconv.Itoa(request.N)) | |||
if err != nil { | |||
return | |||
@@ -109,7 +147,8 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) | |||
return | |||
} | |||
writer.Close() | |||
urlSuffix := "/images/edits" | |||
//https://platform.openai.com/docs/api-reference/images/create-variation | |||
urlSuffix := "/images/variations" | |||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) | |||
if err != nil { | |||
return | |||
@@ -0,0 +1,268 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"io" | |||
"net/http" | |||
"os" | |||
"testing" | |||
"time" | |||
) | |||
func TestImages(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/images/generations", handleImageEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
req := ImageRequest{} | |||
req.Prompt = "Lorem ipsum" | |||
_, err = client.CreateImage(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateImage error: %v", err) | |||
} | |||
} | |||
// handleImageEndpoint Handles the images endpoint by the test server. | |||
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
var resBytes []byte | |||
// imagess only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
var imageReq ImageRequest | |||
if imageReq, err = getImageBody(r); err != nil { | |||
http.Error(w, "could not read request", http.StatusInternalServerError) | |||
return | |||
} | |||
res := ImageResponse{ | |||
Created: time.Now().Unix(), | |||
} | |||
for i := 0; i < imageReq.N; i++ { | |||
imageData := ImageResponseDataInner{} | |||
switch imageReq.ResponseFormat { | |||
case CreateImageResponseFormatURL, "": | |||
imageData.URL = "https://example.com/image.png" | |||
case CreateImageResponseFormatB64JSON: | |||
// This decodes to "{}" in base64. | |||
imageData.B64JSON = "e30K" | |||
default: | |||
http.Error(w, "invalid response format", http.StatusBadRequest) | |||
return | |||
} | |||
res.Data = append(res.Data, imageData) | |||
} | |||
resBytes, _ = json.Marshal(res) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} | |||
// getImageBody Returns the body of the request to create a image. | |||
func getImageBody(r *http.Request) (ImageRequest, error) { | |||
image := ImageRequest{} | |||
// read the request body | |||
reqBody, err := io.ReadAll(r.Body) | |||
if err != nil { | |||
return ImageRequest{}, err | |||
} | |||
err = json.Unmarshal(reqBody, &image) | |||
if err != nil { | |||
return ImageRequest{}, err | |||
} | |||
return image, nil | |||
} | |||
func TestImageEdit(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
origin, err := os.Create("image.png") | |||
if err != nil { | |||
t.Error("open origin file error") | |||
return | |||
} | |||
mask, err := os.Create("mask.png") | |||
if err != nil { | |||
t.Error("open mask file error") | |||
return | |||
} | |||
defer func() { | |||
mask.Close() | |||
origin.Close() | |||
os.Remove("mask.png") | |||
os.Remove("image.png") | |||
}() | |||
req := ImageEditRequest{ | |||
Image: origin, | |||
Mask: mask, | |||
Prompt: "There is a turtle in the pool", | |||
N: 3, | |||
Size: CreateImageSize1024x1024, | |||
} | |||
_, err = client.CreateEditImage(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateImage error: %v", err) | |||
} | |||
} | |||
func TestImageEditWithoutMask(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
origin, err := os.Create("image.png") | |||
if err != nil { | |||
t.Error("open origin file error") | |||
return | |||
} | |||
defer func() { | |||
origin.Close() | |||
os.Remove("image.png") | |||
}() | |||
req := ImageEditRequest{ | |||
Image: origin, | |||
Prompt: "There is a turtle in the pool", | |||
N: 3, | |||
Size: CreateImageSize1024x1024, | |||
} | |||
_, err = client.CreateEditImage(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateImage error: %v", err) | |||
} | |||
} | |||
// handleEditImageEndpoint Handles the images endpoint by the test server. | |||
func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var resBytes []byte | |||
// imagess only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
responses := ImageResponse{ | |||
Created: time.Now().Unix(), | |||
Data: []ImageResponseDataInner{ | |||
{ | |||
URL: "test-url1", | |||
B64JSON: "", | |||
}, | |||
{ | |||
URL: "test-url2", | |||
B64JSON: "", | |||
}, | |||
{ | |||
URL: "test-url3", | |||
B64JSON: "", | |||
}, | |||
}, | |||
} | |||
resBytes, _ = json.Marshal(responses) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} | |||
func TestImageVariation(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
origin, err := os.Create("image.png") | |||
if err != nil { | |||
t.Error("open origin file error") | |||
return | |||
} | |||
defer func() { | |||
origin.Close() | |||
os.Remove("image.png") | |||
}() | |||
req := ImageVariRequest{ | |||
Image: origin, | |||
N: 3, | |||
Size: CreateImageSize1024x1024, | |||
} | |||
_, err = client.CreateVariImage(ctx, req) | |||
if err != nil { | |||
t.Fatalf("CreateImage error: %v", err) | |||
} | |||
} | |||
// handleVariateImageEndpoint Handles the images endpoint by the test server. | |||
func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var resBytes []byte | |||
// imagess only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
responses := ImageResponse{ | |||
Created: time.Now().Unix(), | |||
Data: []ImageResponseDataInner{ | |||
{ | |||
URL: "test-url1", | |||
B64JSON: "", | |||
}, | |||
{ | |||
URL: "test-url2", | |||
B64JSON: "", | |||
}, | |||
{ | |||
URL: "test-url3", | |||
B64JSON: "", | |||
}, | |||
}, | |||
} | |||
resBytes, _ = json.Marshal(responses) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} |
@@ -0,0 +1,15 @@ | |||
package openai | |||
import ( | |||
"encoding/json" | |||
) | |||
type marshaller interface { | |||
marshal(value any) ([]byte, error) | |||
} | |||
type jsonMarshaller struct{} | |||
func (jm *jsonMarshaller) marshal(value any) ([]byte, error) { | |||
return json.Marshal(value) | |||
} |
@@ -1,4 +1,4 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"context" | |||
@@ -7,7 +7,7 @@ import ( | |||
// Model struct represents an OpenAPI model. | |||
type Model struct { | |||
CreatedAt int64 `json:"created_at"` | |||
CreatedAt int64 `json:"created"` | |||
ID string `json:"id"` | |||
Object string `json:"object"` | |||
OwnedBy string `json:"owned_by"` | |||
@@ -18,7 +18,7 @@ type Model struct { | |||
// Permission struct represents an OpenAPI permission. | |||
type Permission struct { | |||
CreatedAt int64 `json:"created_at"` | |||
CreatedAt int64 `json:"created"` | |||
ID string `json:"id"` | |||
Object string `json:"object"` | |||
AllowCreateEngine bool `json:"allow_create_engine"` | |||
@@ -40,7 +40,7 @@ type ModelsList struct { | |||
// ListModels Lists the currently available models, | |||
// and provides basic information about each model such as the model id and parent. | |||
func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) { | |||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/models"), nil) | |||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/models"), nil) | |||
if err != nil { | |||
return | |||
} | |||
@@ -0,0 +1,39 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"net/http" | |||
"testing" | |||
) | |||
// TestListModels Tests the models endpoint of the API using the mocked server. | |||
func TestListModels(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/models", handleModelsEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
_, err = client.ListModels(ctx) | |||
if err != nil { | |||
t.Fatalf("ListModels error: %v", err) | |||
} | |||
} | |||
// handleModelsEndpoint Handles the models endpoint by the test server. | |||
func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) { | |||
resBytes, _ := json.Marshal(ModelsList{}) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} |
@@ -1,9 +1,7 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"bytes" | |||
"context" | |||
"encoding/json" | |||
"net/http" | |||
) | |||
@@ -52,13 +50,7 @@ type ModerationResponse struct { | |||
// Moderations — perform a moderation api call over a string. | |||
// Input can be an array or slice but a string will reduce the complexity. | |||
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { | |||
var reqBytes []byte | |||
reqBytes, err = json.Marshal(request) | |||
if err != nil { | |||
return | |||
} | |||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/moderations"), bytes.NewBuffer(reqBytes)) | |||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request) | |||
if err != nil { | |||
return | |||
} | |||
@@ -0,0 +1,102 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"encoding/json" | |||
"fmt" | |||
"io" | |||
"net/http" | |||
"strconv" | |||
"strings" | |||
"testing" | |||
"time" | |||
) | |||
// TestModeration Tests the moderations endpoint of the API using the mocked server. | |||
func TestModerations(t *testing.T) { | |||
server := test.NewTestServer() | |||
server.RegisterHandler("/v1/moderations", handleModerationEndpoint) | |||
// create the test server | |||
var err error | |||
ts := server.OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
// create an edit request | |||
model := "text-moderation-stable" | |||
moderationReq := ModerationRequest{ | |||
Model: &model, | |||
Input: "I want to kill them.", | |||
} | |||
_, err = client.Moderations(ctx, moderationReq) | |||
if err != nil { | |||
t.Fatalf("Moderation error: %v", err) | |||
} | |||
} | |||
// handleModerationEndpoint Handles the moderation endpoint by the test server. | |||
func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { | |||
var err error | |||
var resBytes []byte | |||
// completions only accepts POST requests | |||
if r.Method != "POST" { | |||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |||
} | |||
var moderationReq ModerationRequest | |||
if moderationReq, err = getModerationBody(r); err != nil { | |||
http.Error(w, "could not read request", http.StatusInternalServerError) | |||
return | |||
} | |||
resCat := ResultCategories{} | |||
resCatScore := ResultCategoryScores{} | |||
switch { | |||
case strings.Contains(moderationReq.Input, "kill"): | |||
resCat = ResultCategories{Violence: true} | |||
resCatScore = ResultCategoryScores{Violence: 1} | |||
case strings.Contains(moderationReq.Input, "hate"): | |||
resCat = ResultCategories{Hate: true} | |||
resCatScore = ResultCategoryScores{Hate: 1} | |||
case strings.Contains(moderationReq.Input, "suicide"): | |||
resCat = ResultCategories{SelfHarm: true} | |||
resCatScore = ResultCategoryScores{SelfHarm: 1} | |||
case strings.Contains(moderationReq.Input, "porn"): | |||
resCat = ResultCategories{Sexual: true} | |||
resCatScore = ResultCategoryScores{Sexual: 1} | |||
} | |||
result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} | |||
res := ModerationResponse{ | |||
ID: strconv.Itoa(int(time.Now().Unix())), | |||
Model: *moderationReq.Model, | |||
} | |||
res.Results = append(res.Results, result) | |||
resBytes, _ = json.Marshal(res) | |||
fmt.Fprintln(w, string(resBytes)) | |||
} | |||
// getModerationBody Returns the body of the request to do a moderation. | |||
func getModerationBody(r *http.Request) (ModerationRequest, error) { | |||
moderation := ModerationRequest{} | |||
// read the request body | |||
reqBody, err := io.ReadAll(r.Body) | |||
if err != nil { | |||
return ModerationRequest{}, err | |||
} | |||
err = json.Unmarshal(reqBody, &moderation) | |||
if err != nil { | |||
return ModerationRequest{}, err | |||
} | |||
return moderation, nil | |||
} |
@@ -0,0 +1,40 @@ | |||
package openai | |||
import ( | |||
"bytes" | |||
"context" | |||
"net/http" | |||
) | |||
type requestBuilder interface { | |||
build(ctx context.Context, method, url string, request any) (*http.Request, error) | |||
} | |||
type httpRequestBuilder struct { | |||
marshaller marshaller | |||
} | |||
func newRequestBuilder() *httpRequestBuilder { | |||
return &httpRequestBuilder{ | |||
marshaller: &jsonMarshaller{}, | |||
} | |||
} | |||
func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) { | |||
if request == nil { | |||
return http.NewRequestWithContext(ctx, method, url, nil) | |||
} | |||
var reqBytes []byte | |||
reqBytes, err := b.marshaller.marshal(request) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return http.NewRequestWithContext( | |||
ctx, | |||
method, | |||
url, | |||
bytes.NewBuffer(reqBytes), | |||
) | |||
} |
@@ -0,0 +1,148 @@ | |||
package openai //nolint:testpackage // testing private field | |||
import ( | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"errors" | |||
"net/http" | |||
"testing" | |||
) | |||
var ( | |||
errTestMarshallerFailed = errors.New("test marshaller failed") | |||
errTestRequestBuilderFailed = errors.New("test request builder failed") | |||
) | |||
type ( | |||
failingRequestBuilder struct{} | |||
failingMarshaller struct{} | |||
) | |||
func (*failingMarshaller) marshal(value any) ([]byte, error) { | |||
return []byte{}, errTestMarshallerFailed | |||
} | |||
func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) { | |||
return nil, errTestRequestBuilderFailed | |||
} | |||
func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { | |||
builder := httpRequestBuilder{ | |||
marshaller: &failingMarshaller{}, | |||
} | |||
_, err := builder.build(context.Background(), "", "", struct{}{}) | |||
if !errors.Is(err, errTestMarshallerFailed) { | |||
t.Fatalf("Did not return error when marshaller failed: %v", err) | |||
} | |||
} | |||
func TestClientReturnsRequestBuilderErrors(t *testing.T) { | |||
var err error | |||
ts := test.NewTestServer().OpenAITestServer() | |||
ts.Start() | |||
defer ts.Close() | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = ts.URL + "/v1" | |||
client := NewClientWithConfig(config) | |||
client.requestBuilder = &failingRequestBuilder{} | |||
ctx := context.Background() | |||
_, err = client.CreateCompletion(ctx, CompletionRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.CreateFineTune(ctx, FineTuneRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.ListFineTunes(ctx) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.CancelFineTune(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.GetFineTune(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.DeleteFineTune(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.ListFineTuneEvents(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.Moderations(ctx, ModerationRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.Edits(ctx, EditsRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.CreateImage(ctx, ImageRequest{}) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
err = client.DeleteFile(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.GetFile(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.ListFiles(ctx) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.ListEngines(ctx) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.GetEngine(ctx, "") | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
_, err = client.ListModels(ctx) | |||
if !errors.Is(err, errTestRequestBuilderFailed) { | |||
t.Fatalf("Did not return error when request builder failed: %v", err) | |||
} | |||
} |
@@ -1,4 +1,4 @@ | |||
package gogpt | |||
package openai | |||
import ( | |||
"bufio" | |||
@@ -6,7 +6,6 @@ import ( | |||
"context" | |||
"encoding/json" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"net/http" | |||
) | |||
@@ -60,43 +59,6 @@ waitForData: | |||
return | |||
} | |||
func (stream *CompletionStream) ChatRecv() (response ChatCompletionResponseForStream, err error) { | |||
if stream.isFinished { | |||
err = io.EOF | |||
return | |||
} | |||
var emptyMessagesCount uint | |||
waitForData: | |||
line, err := stream.reader.ReadBytes('\n') | |||
if err != nil { | |||
return | |||
} | |||
var headerData = []byte("data: ") | |||
line = bytes.TrimSpace(line) | |||
if !bytes.HasPrefix(line, headerData) { | |||
emptyMessagesCount++ | |||
if emptyMessagesCount > stream.emptyMessagesLimit { | |||
err = ErrTooManyEmptyStreamMessages | |||
return | |||
} | |||
goto waitForData | |||
} | |||
line = bytes.TrimPrefix(line, headerData) | |||
if string(line) == "[DONE]" { | |||
stream.isFinished = true | |||
err = io.EOF | |||
return | |||
} | |||
err = json.Unmarshal(line, &response) | |||
return | |||
} | |||
func (stream *CompletionStream) Close() { | |||
stream.response.Body.Close() | |||
} | |||
@@ -110,18 +72,7 @@ func (c *Client) CreateCompletionStream( | |||
request CompletionRequest, | |||
) (stream *CompletionStream, err error) { | |||
request.Stream = true | |||
reqBytes, err := json.Marshal(request) | |||
if err != nil { | |||
return | |||
} | |||
urlSuffix := "/completions" | |||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) | |||
req.Header.Set("Content-Type", "application/json") | |||
req.Header.Set("Accept", "text/event-stream") | |||
req.Header.Set("Cache-Control", "no-cache") | |||
req.Header.Set("Connection", "keep-alive") | |||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) | |||
req, err := c.newStreamRequest(ctx, "POST", "/completions", request) | |||
if err != nil { | |||
return | |||
} | |||
@@ -0,0 +1,147 @@ | |||
package openai_test | |||
import ( | |||
. "github.com/sashabaranov/go-openai" | |||
"github.com/sashabaranov/go-openai/internal/test" | |||
"context" | |||
"errors" | |||
"io" | |||
"net/http" | |||
"net/http/httptest" | |||
"testing" | |||
) | |||
func TestCreateCompletionStream(t *testing.T) { | |||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
w.Header().Set("Content-Type", "text/event-stream") | |||
// Send test responses | |||
dataBytes := []byte{} | |||
dataBytes = append(dataBytes, []byte("event: message\n")...) | |||
//nolint:lll | |||
data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` | |||
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) | |||
dataBytes = append(dataBytes, []byte("event: message\n")...) | |||
//nolint:lll | |||
data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}` | |||
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) | |||
dataBytes = append(dataBytes, []byte("event: done\n")...) | |||
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) | |||
_, err := w.Write(dataBytes) | |||
if err != nil { | |||
t.Errorf("Write error: %s", err) | |||
} | |||
})) | |||
defer server.Close() | |||
// Client portion of the test | |||
config := DefaultConfig(test.GetTestToken()) | |||
config.BaseURL = server.URL + "/v1" | |||
config.HTTPClient.Transport = &tokenRoundTripper{ | |||
test.GetTestToken(), | |||
http.DefaultTransport, | |||
} | |||
client := NewClientWithConfig(config) | |||
ctx := context.Background() | |||
request := CompletionRequest{ | |||
Prompt: "Ex falso quodlibet", | |||
Model: "text-davinci-002", | |||
MaxTokens: 10, | |||
Stream: true, | |||
} | |||
stream, err := client.CreateCompletionStream(ctx, request) | |||
if err != nil { | |||
t.Errorf("CreateCompletionStream returned error: %v", err) | |||
} | |||
defer stream.Close() | |||
expectedResponses := []CompletionResponse{ | |||
{ | |||
ID: "1", | |||
Object: "completion", | |||
Created: 1598069254, | |||
Model: "text-davinci-002", | |||
Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, | |||
}, | |||
{ | |||
ID: "2", | |||
Object: "completion", | |||
Created: 1598069255, | |||
Model: "text-davinci-002", | |||
Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, | |||
}, | |||
} | |||
for ix, expectedResponse := range expectedResponses { | |||
receivedResponse, streamErr := stream.Recv() | |||
if streamErr != nil { | |||
t.Errorf("stream.Recv() failed: %v", streamErr) | |||
} | |||
if !compareResponses(expectedResponse, receivedResponse) { | |||
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) | |||
} | |||
} | |||
_, streamErr := stream.Recv() | |||
if !errors.Is(streamErr, io.EOF) { | |||
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) | |||
} | |||
_, streamErr = stream.Recv() | |||
if !errors.Is(streamErr, io.EOF) { | |||
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) | |||
} | |||
} | |||
// A "tokenRoundTripper" is a struct that implements the RoundTripper | |||
// interface, specifically to handle the authentication token by adding a token | |||
// to the request header. We need this because the API requires that each | |||
// request include a valid API token in the headers for authentication and | |||
// authorization. | |||
type tokenRoundTripper struct { | |||
token string | |||
fallback http.RoundTripper | |||
} | |||
// RoundTrip takes an *http.Request as input and returns an | |||
// *http.Response and an error. | |||
// | |||
// It is expected to use the provided request to create a connection to an HTTP | |||
// server and return the response, or an error if one occurred. The returned | |||
// Response should have its Body closed. If the RoundTrip method returns an | |||
// error, the Client's Get, Head, Post, and PostForm methods return the same | |||
// error. | |||
func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { | |||
req.Header.Set("Authorization", "Bearer "+t.token) | |||
return t.fallback.RoundTrip(req) | |||
} | |||
// Helper funcs. | |||
func compareResponses(r1, r2 CompletionResponse) bool { | |||
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { | |||
return false | |||
} | |||
if len(r1.Choices) != len(r2.Choices) { | |||
return false | |||
} | |||
for i := range r1.Choices { | |||
if !compareResponseChoices(r1.Choices[i], r2.Choices[i]) { | |||
return false | |||
} | |||
} | |||
return true | |||
} | |||
func compareResponseChoices(c1, c2 CompletionChoice) bool { | |||
if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason { | |||
return false | |||
} | |||
return true | |||
} |