diff --git a/go-gpt3/.github/workflows/pr.yml b/go-gpt3/.github/workflows/pr.yml
new file mode 100644
index 0000000..9a1b3ac
--- /dev/null
+++ b/go-gpt3/.github/workflows/pr.yml
@@ -0,0 +1,27 @@
+name: PR sanity check
+
+on:
+ - push
+ - pull_request
+
+jobs:
+ prcheck:
+ name: PR sanity check
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - name: Setup Go
+ uses: actions/setup-go@v2
+ with:
+ go-version: '1.19'
+ - name: Run vet
+ run: |
+ go vet .
+ - name: Run golangci-lint
+ uses: golangci/golangci-lint-action@v3
+ with:
+ version: latest
+ - name: Run tests
+ run: go test -race -covermode=atomic -coverprofile=coverage.out -v .
+ - name: Upload coverage reports to Codecov
+ uses: codecov/codecov-action@v3
diff --git a/go-gpt3/README.md b/go-gpt3/README.md
index 7f8b016..e6e352e 100644
--- a/go-gpt3/README.md
+++ b/go-gpt3/README.md
@@ -1,17 +1,65 @@
-# go-gpt3
-[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-gpt3)
-[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-gpt3)](https://goreportcard.com/report/github.com/sashabaranov/go-gpt3)
+# Go OpenAI
+[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-openai)
+[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai)
+[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai)
+> **Note**: the repository was recently renamed from `go-gpt3` to `go-openai`
-[OpenAI ChatGPT and GPT-3](https://platform.openai.com/) API client for Go
+This library provides Go clients for [OpenAI API](https://platform.openai.com/). We support:
+
+* ChatGPT
+* GPT-3, GPT-4
+* DALL·E 2
+* Whisper
Installation:
```
-go get github.com/sashabaranov/go-gpt3
+go get github.com/sashabaranov/go-openai
+```
+
+
+ChatGPT example usage:
+
+```go
+package main
+
+import (
+ "context"
+ "fmt"
+ openai "github.com/sashabaranov/go-openai"
+)
+
+func main() {
+ client := openai.NewClient("your token")
+ resp, err := client.CreateChatCompletion(
+ context.Background(),
+ openai.ChatCompletionRequest{
+ Model: openai.GPT3Dot5Turbo,
+ Messages: []openai.ChatCompletionMessage{
+ {
+ Role: openai.ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ },
+ )
+
+ if err != nil {
+ fmt.Printf("ChatCompletion error: %v\n", err)
+ return
+ }
+
+ fmt.Println(resp.Choices[0].Message.Content)
+}
+
```
-Example usage:
+
+Other examples:
+
+
+GPT-3 completion
```go
package main
@@ -19,27 +67,30 @@ package main
import (
"context"
"fmt"
- gogpt "github.com/sashabaranov/go-gpt3"
+ openai "github.com/sashabaranov/go-openai"
)
func main() {
- c := gogpt.NewClient("your token")
+ c := openai.NewClient("your token")
ctx := context.Background()
- req := gogpt.CompletionRequest{
- Model: gogpt.GPT3Ada,
+ req := openai.CompletionRequest{
+ Model: openai.GPT3Ada,
MaxTokens: 5,
Prompt: "Lorem ipsum",
}
resp, err := c.CreateCompletion(ctx, req)
if err != nil {
+ fmt.Printf("Completion error: %v\n", err)
return
}
fmt.Println(resp.Choices[0].Text)
}
```
+
-Streaming response example:
+
+GPT-3 streaming completion
```go
package main
@@ -49,21 +100,22 @@ import (
"context"
"fmt"
"io"
- gogpt "github.com/sashabaranov/go-gpt3"
+ openai "github.com/sashabaranov/go-openai"
)
func main() {
- c := gogpt.NewClient("your token")
+ c := openai.NewClient("your token")
ctx := context.Background()
- req := gogpt.CompletionRequest{
- Model: gogpt.GPT3Ada,
+ req := openai.CompletionRequest{
+ Model: openai.GPT3Ada,
MaxTokens: 5,
Prompt: "Lorem ipsum",
Stream: true,
}
stream, err := c.CreateCompletionStream(ctx, req)
if err != nil {
+ fmt.Printf("CompletionStream error: %v\n", err)
return
}
defer stream.Close()
@@ -85,3 +137,194 @@ func main() {
}
}
```
+
+
+
+Audio Speech-To-Text
+
+```go
+package main
+
+import (
+ "context"
+ "fmt"
+
+ openai "github.com/sashabaranov/go-openai"
+)
+
+func main() {
+ c := openai.NewClient("your token")
+ ctx := context.Background()
+
+ req := openai.AudioRequest{
+ Model: openai.Whisper1,
+ FilePath: "recording.mp3",
+ }
+ resp, err := c.CreateTranscription(ctx, req)
+ if err != nil {
+ fmt.Printf("Transcription error: %v\n", err)
+ return
+ }
+ fmt.Println(resp.Text)
+}
+```
+
+
+
+DALL-E 2 image generation
+
+```go
+package main
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "fmt"
+ openai "github.com/sashabaranov/go-openai"
+ "image/png"
+ "os"
+)
+
+func main() {
+ c := openai.NewClient("your token")
+ ctx := context.Background()
+
+ // Sample image by link
+ reqUrl := openai.ImageRequest{
+ Prompt: "Parrot on a skateboard performs a trick, cartoon style, natural light, high detail",
+ Size: openai.CreateImageSize256x256,
+ ResponseFormat: openai.CreateImageResponseFormatURL,
+ N: 1,
+ }
+
+ respUrl, err := c.CreateImage(ctx, reqUrl)
+ if err != nil {
+ fmt.Printf("Image creation error: %v\n", err)
+ return
+ }
+ fmt.Println(respUrl.Data[0].URL)
+
+ // Example image as base64
+ reqBase64 := openai.ImageRequest{
+ Prompt: "Portrait of a humanoid parrot in a classic costume, high detail, realistic light, unreal engine",
+ Size: openai.CreateImageSize256x256,
+ ResponseFormat: openai.CreateImageResponseFormatB64JSON,
+ N: 1,
+ }
+
+ respBase64, err := c.CreateImage(ctx, reqBase64)
+ if err != nil {
+ fmt.Printf("Image creation error: %v\n", err)
+ return
+ }
+
+ imgBytes, err := base64.StdEncoding.DecodeString(respBase64.Data[0].B64JSON)
+ if err != nil {
+ fmt.Printf("Base64 decode error: %v\n", err)
+ return
+ }
+
+ r := bytes.NewReader(imgBytes)
+ imgData, err := png.Decode(r)
+ if err != nil {
+ fmt.Printf("PNG decode error: %v\n", err)
+ return
+ }
+
+ file, err := os.Create("image.png")
+ if err != nil {
+ fmt.Printf("File creation error: %v\n", err)
+ return
+ }
+ defer file.Close()
+
+ if err := png.Encode(file, imgData); err != nil {
+ fmt.Printf("PNG encode error: %v\n", err)
+ return
+ }
+
+ fmt.Println("The image was saved as example.png")
+}
+
+```
+
+
+
+Configuring proxy
+
+```go
+config := openai.DefaultConfig("token")
+proxyUrl, err := url.Parse("http://localhost:{port}")
+if err != nil {
+ panic(err)
+}
+transport := &http.Transport{
+ Proxy: http.ProxyURL(proxyUrl),
+}
+config.HTTPClient = &http.Client{
+ Transport: transport,
+}
+
+c := openai.NewClientWithConfig(config)
+```
+
+See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig
+
+
+
+ChatGPT support context
+
+```go
+package main
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "os"
+ "strings"
+
+ "github.com/sashabaranov/go-openai"
+)
+
+func main() {
+ client := openai.NewClient("your token")
+ messages := make([]openai.ChatCompletionMessage, 0)
+ reader := bufio.NewReader(os.Stdin)
+ fmt.Println("Conversation")
+ fmt.Println("---------------------")
+
+ for {
+ fmt.Print("-> ")
+ text, _ := reader.ReadString('\n')
+ // convert CRLF to LF
+ text = strings.Replace(text, "\n", "", -1)
+ messages = append(messages, openai.ChatCompletionMessage{
+ Role: openai.ChatMessageRoleUser,
+ Content: text,
+ })
+
+ resp, err := client.CreateChatCompletion(
+ context.Background(),
+ openai.ChatCompletionRequest{
+ Model: openai.GPT3Dot5Turbo,
+ Messages: messages,
+ },
+ )
+
+ if err != nil {
+ fmt.Printf("ChatCompletion error: %v\n", err)
+ continue
+ }
+
+ content := resp.Choices[0].Message.Content
+ messages = append(messages, openai.ChatCompletionMessage{
+ Role: openai.ChatMessageRoleAssistant,
+ Content: content,
+ })
+ fmt.Println(content)
+ }
+}
+```
+
\ No newline at end of file
diff --git a/go-gpt3/answers.go b/go-gpt3/answers.go
deleted file mode 100644
index 5a99078..0000000
--- a/go-gpt3/answers.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package gogpt
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "net/http"
-)
-
-type AnswerRequest struct {
- Documents []string `json:"documents,omitempty"`
- File string `json:"file,omitempty"`
- Question string `json:"question"`
- SearchModel string `json:"search_model,omitempty"`
- Model string `json:"model"`
- ExamplesContext string `json:"examples_context"`
- Examples [][]string `json:"examples"`
- MaxTokens int `json:"max_tokens,omitempty"`
- Stop []string `json:"stop,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
-}
-
-type AnswerResponse struct {
- Answers []string `json:"answers"`
- Completion string `json:"completion"`
- Model string `json:"model"`
- Object string `json:"object"`
- SearchModel string `json:"search_model"`
- SelectedDocuments []struct {
- Document int `json:"document"`
- Text string `json:"text"`
- } `json:"selected_documents"`
-}
-
-// Search — perform a semantic search api call over a list of documents.
-func (c *Client) Answers(ctx context.Context, request AnswerRequest) (response AnswerResponse, err error) {
- var reqBytes []byte
- reqBytes, err = json.Marshal(request)
- if err != nil {
- return
- }
-
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/answers"), bytes.NewBuffer(reqBytes))
- if err != nil {
- return
- }
-
- err = c.sendRequest(req, &response)
- return
-}
diff --git a/go-gpt3/api.go b/go-gpt3/api.go
index 715c9dd..00d6d35 100644
--- a/go-gpt3/api.go
+++ b/go-gpt3/api.go
@@ -1,6 +1,7 @@
-package gogpt
+package openai
import (
+ "context"
"encoding/json"
"fmt"
"net/http"
@@ -9,17 +10,22 @@ import (
// Client is OpenAI GPT-3 API client.
type Client struct {
config ClientConfig
+
+ requestBuilder requestBuilder
}
// NewClient creates new OpenAI API client.
func NewClient(authToken string) *Client {
config := DefaultConfig(authToken)
- return &Client{config}
+ return NewClientWithConfig(config)
}
// NewClientWithConfig creates new OpenAI API client for specified config.
func NewClientWithConfig(config ClientConfig) *Client {
- return &Client{config}
+ return &Client{
+ config: config,
+ requestBuilder: newRequestBuilder(),
+ }
}
// NewOrgClient creates new OpenAI API client for specified Organization ID.
@@ -28,7 +34,7 @@ func NewClientWithConfig(config ClientConfig) *Client {
func NewOrgClient(authToken, org string) *Client {
config := DefaultConfig(authToken)
config.OrgID = org
- return &Client{config}
+ return NewClientWithConfig(config)
}
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
@@ -68,7 +74,7 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
}
if v != nil {
- if err = json.NewDecoder(res.Body).Decode(&v); err != nil {
+ if err = json.NewDecoder(res.Body).Decode(v); err != nil {
return err
}
}
@@ -79,3 +85,22 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
func (c *Client) fullURL(suffix string) string {
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
}
+
+func (c *Client) newStreamRequest(
+ ctx context.Context,
+ method string,
+ urlSuffix string,
+ body any) (*http.Request, error) {
+ req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body)
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "text/event-stream")
+ req.Header.Set("Cache-Control", "no-cache")
+ req.Header.Set("Connection", "keep-alive")
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
+
+ return req, nil
+}
diff --git a/go-gpt3/api_test.go b/go-gpt3/api_test.go
new file mode 100644
index 0000000..202ec94
--- /dev/null
+++ b/go-gpt3/api_test.go
@@ -0,0 +1,182 @@
+package openai_test
+
+import (
+ . "github.com/sashabaranov/go-openai"
+
+ "context"
+ "errors"
+ "io"
+ "os"
+ "testing"
+)
+
+func TestAPI(t *testing.T) {
+ apiToken := os.Getenv("OPENAI_TOKEN")
+ if apiToken == "" {
+ t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
+ }
+
+ var err error
+ c := NewClient(apiToken)
+ ctx := context.Background()
+ _, err = c.ListEngines(ctx)
+ if err != nil {
+ t.Fatalf("ListEngines error: %v", err)
+ }
+
+ _, err = c.GetEngine(ctx, "davinci")
+ if err != nil {
+ t.Fatalf("GetEngine error: %v", err)
+ }
+
+ fileRes, err := c.ListFiles(ctx)
+ if err != nil {
+ t.Fatalf("ListFiles error: %v", err)
+ }
+
+ if len(fileRes.Files) > 0 {
+ _, err = c.GetFile(ctx, fileRes.Files[0].ID)
+ if err != nil {
+ t.Fatalf("GetFile error: %v", err)
+ }
+ } // else skip
+
+ embeddingReq := EmbeddingRequest{
+ Input: []string{
+ "The food was delicious and the waiter",
+ "Other examples of embedding request",
+ },
+ Model: AdaSearchQuery,
+ }
+ _, err = c.CreateEmbeddings(ctx, embeddingReq)
+ if err != nil {
+ t.Fatalf("Embedding error: %v", err)
+ }
+
+ _, err = c.CreateChatCompletion(
+ ctx,
+ ChatCompletionRequest{
+ Model: GPT3Dot5Turbo,
+ Messages: []ChatCompletionMessage{
+ {
+ Role: ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ },
+ )
+
+ if err != nil {
+ t.Errorf("CreateChatCompletion (without name) returned error: %v", err)
+ }
+
+ _, err = c.CreateChatCompletion(
+ ctx,
+ ChatCompletionRequest{
+ Model: GPT3Dot5Turbo,
+ Messages: []ChatCompletionMessage{
+ {
+ Role: ChatMessageRoleUser,
+ Name: "John_Doe",
+ Content: "Hello!",
+ },
+ },
+ },
+ )
+
+ if err != nil {
+ t.Errorf("CreateChatCompletion (with name) returned error: %v", err)
+ }
+
+ stream, err := c.CreateCompletionStream(ctx, CompletionRequest{
+ Prompt: "Ex falso quodlibet",
+ Model: GPT3Ada,
+ MaxTokens: 5,
+ Stream: true,
+ })
+ if err != nil {
+ t.Errorf("CreateCompletionStream returned error: %v", err)
+ }
+ defer stream.Close()
+
+ counter := 0
+ for {
+ _, err = stream.Recv()
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ t.Errorf("Stream error: %v", err)
+ } else {
+ counter++
+ }
+ }
+ if counter == 0 {
+ t.Error("Stream did not return any responses")
+ }
+}
+
+func TestAPIError(t *testing.T) {
+ apiToken := os.Getenv("OPENAI_TOKEN")
+ if apiToken == "" {
+ t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
+ }
+
+ var err error
+ c := NewClient(apiToken + "_invalid")
+ ctx := context.Background()
+ _, err = c.ListEngines(ctx)
+ if err == nil {
+ t.Fatal("ListEngines did not fail")
+ }
+
+ var apiErr *APIError
+ if !errors.As(err, &apiErr) {
+ t.Fatalf("Error is not an APIError: %+v", err)
+ }
+
+ if apiErr.StatusCode != 401 {
+ t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode)
+ }
+ if *apiErr.Code != "invalid_api_key" {
+ t.Fatalf("Unexpected API error code: %s", *apiErr.Code)
+ }
+ if apiErr.Error() == "" {
+ t.Fatal("Empty error message occured")
+ }
+}
+
+func TestRequestError(t *testing.T) {
+ var err error
+
+ config := DefaultConfig("dummy")
+ config.BaseURL = "https://httpbin.org/status/418?"
+ c := NewClientWithConfig(config)
+ ctx := context.Background()
+ _, err = c.ListEngines(ctx)
+ if err == nil {
+ t.Fatal("ListEngines request did not fail")
+ }
+
+ var reqErr *RequestError
+ if !errors.As(err, &reqErr) {
+ t.Fatalf("Error is not a RequestError: %+v", err)
+ }
+
+ if reqErr.StatusCode != 418 {
+ t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode)
+ }
+
+ if reqErr.Unwrap() == nil {
+ t.Fatalf("Empty request error occured")
+ }
+}
+
+// numTokens Returns the number of GPT-3 encoded tokens in the given text.
+// This function approximates based on the rule of thumb stated by OpenAI:
+// https://beta.openai.com/tokenizer
+//
+// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
+func numTokens(s string) int {
+ return int(float32(len(s)) / 4)
+}
diff --git a/go-gpt3/audio.go b/go-gpt3/audio.go
index 0dc611e..54bd32f 100644
--- a/go-gpt3/audio.go
+++ b/go-gpt3/audio.go
@@ -1,4 +1,4 @@
-package gogpt
+package openai
import (
"bytes"
diff --git a/go-gpt3/audio_test.go b/go-gpt3/audio_test.go
new file mode 100644
index 0000000..2a035c9
--- /dev/null
+++ b/go-gpt3/audio_test.go
@@ -0,0 +1,143 @@
+package openai_test
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "mime"
+ "mime/multipart"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+
+ . "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test"
+
+ "context"
+ "testing"
+)
+
+// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
+func TestAudio(t *testing.T) {
+ server := test.NewTestServer()
+ server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
+ server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
+ // create the test server
+ var err error
+ ts := server.OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+
+ testcases := []struct {
+ name string
+ createFn func(context.Context, AudioRequest) (AudioResponse, error)
+ }{
+ {
+ "transcribe",
+ client.CreateTranscription,
+ },
+ {
+ "translate",
+ client.CreateTranslation,
+ },
+ }
+
+ ctx := context.Background()
+
+ dir, cleanup := createTestDirectory(t)
+ defer cleanup()
+
+ for _, tc := range testcases {
+ t.Run(tc.name, func(t *testing.T) {
+ path := filepath.Join(dir, "fake.mp3")
+ createTestFile(t, path)
+
+ req := AudioRequest{
+ FilePath: path,
+ Model: "whisper-3",
+ }
+ _, err = tc.createFn(ctx, req)
+ if err != nil {
+ t.Fatalf("audio API error: %v", err)
+ }
+ })
+ }
+}
+
+// createTestFile creates a fake file with "hello" as the content.
+func createTestFile(t *testing.T, path string) {
+ file, err := os.Create(path)
+ if err != nil {
+ t.Fatalf("failed to create file %v", err)
+ }
+ if _, err = file.WriteString("hello"); err != nil {
+ t.Fatalf("failed to write to file %v", err)
+ }
+ file.Close()
+}
+
+// createTestDirectory creates a temporary folder which will be deleted when cleanup is called.
+func createTestDirectory(t *testing.T) (path string, cleanup func()) {
+ t.Helper()
+
+ path, err := os.MkdirTemp(os.TempDir(), "")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ return path, func() { os.RemoveAll(path) }
+}
+
+// handleAudioEndpoint Handles the completion endpoint by the test server.
+func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
+ var err error
+
+ // audio endpoints only accept POST requests
+ if r.Method != "POST" {
+ http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+ }
+
+ mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
+ if err != nil {
+ http.Error(w, "failed to parse media type", http.StatusBadRequest)
+ return
+ }
+
+ if !strings.HasPrefix(mediaType, "multipart") {
+ http.Error(w, "request is not multipart", http.StatusBadRequest)
+ }
+
+ boundary, ok := params["boundary"]
+ if !ok {
+ http.Error(w, "no boundary in params", http.StatusBadRequest)
+ return
+ }
+
+ fileData := &bytes.Buffer{}
+ mr := multipart.NewReader(r.Body, boundary)
+ part, err := mr.NextPart()
+ if err != nil && errors.Is(err, io.EOF) {
+ http.Error(w, "error accessing file", http.StatusBadRequest)
+ return
+ }
+ if _, err = io.Copy(fileData, part); err != nil {
+ http.Error(w, "failed to copy file", http.StatusInternalServerError)
+ return
+ }
+
+ if len(fileData.Bytes()) == 0 {
+ w.WriteHeader(http.StatusInternalServerError)
+ http.Error(w, "received empty file data", http.StatusBadRequest)
+ return
+ }
+
+ if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
+ http.Error(w, "failed to write body", http.StatusInternalServerError)
+ return
+ }
+}
diff --git a/go-gpt3/chat.go b/go-gpt3/chat.go
index 7ac1121..99edfe8 100644
--- a/go-gpt3/chat.go
+++ b/go-gpt3/chat.go
@@ -1,15 +1,18 @@
-package gogpt
+package openai
import (
- "bufio"
- "bytes"
"context"
- "encoding/json"
"errors"
- "fmt"
"net/http"
)
+// Chat message role defined by the OpenAI API.
+const (
+ ChatMessageRoleSystem = "system"
+ ChatMessageRoleUser = "user"
+ ChatMessageRoleAssistant = "assistant"
+)
+
var (
ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported")
)
@@ -17,6 +20,12 @@ var (
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content"`
+
+ // This property isn't in the official documentation, but it's in
+ // the documentation for the official library for python:
+ // - https://github.com/openai/openai-python/blob/main/chatml.md
+ // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ Name string `json:"name,omitempty"`
}
// ChatCompletionRequest represents a request structure for chat completion API.
@@ -41,24 +50,6 @@ type ChatCompletionChoice struct {
FinishReason string `json:"finish_reason"`
}
-type ChatCompletionChoiceForStream struct {
- Index int `json:"index"`
- Delta struct {
- Content string `json:"content"`
- } `json:"delta"`
- FinishReason string `json:"finish_reason"`
-}
-
-// ChatCompletionResponseForStream represents a response structure for chat completion API.
-type ChatCompletionResponseForStream struct {
- ID string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- Choices []ChatCompletionChoiceForStream `json:"choices"`
- Usage Usage `json:"usage"`
-}
-
// ChatCompletionResponse represents a response structure for chat completion API.
type ChatCompletionResponse struct {
ID string `json:"id"`
@@ -69,25 +60,21 @@ type ChatCompletionResponse struct {
Usage Usage `json:"usage"`
}
-// CreateChatCompletion — API call to Creates a completion for the chat message.
+// CreateChatCompletion — API call to Create a completion for the chat message.
func (c *Client) CreateChatCompletion(
ctx context.Context,
request ChatCompletionRequest,
) (response ChatCompletionResponse, err error) {
model := request.Model
- if model != GPT3Dot5Turbo0301 && model != GPT3Dot5Turbo {
+ switch model {
+ case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K:
+ default:
err = ErrChatCompletionInvalidModel
return
}
- var reqBytes []byte
- reqBytes, err = json.Marshal(request)
- if err != nil {
- return
- }
-
urlSuffix := "/chat/completions"
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
+ req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}
@@ -95,44 +82,3 @@ func (c *Client) CreateChatCompletion(
err = c.sendRequest(req, &response)
return
}
-
-// CreateChatCompletionStream — API call to create a completion w/ streaming
-func (c *Client) CreateChatCompletionStream(
- ctx context.Context,
- request ChatCompletionRequest,
-) (stream *CompletionStream, err error) {
- model := request.Model
- if model != GPT3Dot5Turbo0301 && model != GPT3Dot5Turbo {
- err = ErrChatCompletionInvalidModel
- return
- }
- request.Stream = true
- reqBytes, err := json.Marshal(request)
- if err != nil {
- return
- }
-
- urlSuffix := "/chat/completions"
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "text/event-stream")
- req.Header.Set("Cache-Control", "no-cache")
- req.Header.Set("Connection", "keep-alive")
- req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
- if err != nil {
- return
- }
-
- resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
- if err != nil {
- return
- }
-
- stream = &CompletionStream{
- emptyMessagesLimit: c.config.EmptyMessagesLimit,
-
- reader: bufio.NewReader(resp.Body),
- response: resp,
- }
- return
-}
diff --git a/go-gpt3/chat_stream.go b/go-gpt3/chat_stream.go
new file mode 100644
index 0000000..26e964c
--- /dev/null
+++ b/go-gpt3/chat_stream.go
@@ -0,0 +1,106 @@
+package openai
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+)
+
+type ChatCompletionStreamChoiceDelta struct {
+ Content string `json:"content"`
+}
+
+type ChatCompletionStreamChoice struct {
+ Index int `json:"index"`
+ Delta ChatCompletionStreamChoiceDelta `json:"delta"`
+ FinishReason string `json:"finish_reason"`
+}
+
+type ChatCompletionStreamResponse struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Model string `json:"model"`
+ Choices []ChatCompletionStreamChoice `json:"choices"`
+}
+
+// ChatCompletionStream
+// Note: Perhaps it is more elegant to abstract Stream using generics.
+type ChatCompletionStream struct {
+ emptyMessagesLimit uint
+ isFinished bool
+
+ reader *bufio.Reader
+ response *http.Response
+}
+
+func (stream *ChatCompletionStream) Recv() (response ChatCompletionStreamResponse, err error) {
+ if stream.isFinished {
+ err = io.EOF
+ return
+ }
+
+ var emptyMessagesCount uint
+
+waitForData:
+ line, err := stream.reader.ReadBytes('\n')
+ if err != nil {
+ return
+ }
+
+ var headerData = []byte("data: ")
+ line = bytes.TrimSpace(line)
+ if !bytes.HasPrefix(line, headerData) {
+ emptyMessagesCount++
+ if emptyMessagesCount > stream.emptyMessagesLimit {
+ err = ErrTooManyEmptyStreamMessages
+ return
+ }
+
+ goto waitForData
+ }
+
+ line = bytes.TrimPrefix(line, headerData)
+ if string(line) == "[DONE]" {
+ stream.isFinished = true
+ err = io.EOF
+ return
+ }
+
+ err = json.Unmarshal(line, &response)
+ return
+}
+
+func (stream *ChatCompletionStream) Close() {
+ stream.response.Body.Close()
+}
+
+// CreateChatCompletionStream — API call to create a chat completion w/ streaming
+// support. It sets whether to stream back partial progress. If set, tokens will be
+// sent as data-only server-sent events as they become available, with the
+// stream terminated by a data: [DONE] message.
+func (c *Client) CreateChatCompletionStream(
+ ctx context.Context,
+ request ChatCompletionRequest,
+) (stream *ChatCompletionStream, err error) {
+ request.Stream = true
+ req, err := c.newStreamRequest(ctx, "POST", "/chat/completions", request)
+ if err != nil {
+ return
+ }
+
+ resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
+ if err != nil {
+ return
+ }
+
+ stream = &ChatCompletionStream{
+ emptyMessagesLimit: c.config.EmptyMessagesLimit,
+ reader: bufio.NewReader(resp.Body),
+ response: resp,
+ }
+ return
+}
diff --git a/go-gpt3/chat_stream_test.go b/go-gpt3/chat_stream_test.go
new file mode 100644
index 0000000..e3da2da
--- /dev/null
+++ b/go-gpt3/chat_stream_test.go
@@ -0,0 +1,153 @@
+package openai_test
+
+import (
+ . "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test"
+
+ "context"
+ "encoding/json"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestCreateChatCompletionStream(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ // Send test responses
+ dataBytes := []byte{}
+ dataBytes = append(dataBytes, []byte("event: message\n")...)
+ //nolint:lll
+ data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}`
+ dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+ dataBytes = append(dataBytes, []byte("event: message\n")...)
+ //nolint:lll
+ data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}`
+ dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+ dataBytes = append(dataBytes, []byte("event: done\n")...)
+ dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
+
+ _, err := w.Write(dataBytes)
+ if err != nil {
+ t.Errorf("Write error: %s", err)
+ }
+ }))
+ defer server.Close()
+
+ // Client portion of the test
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = server.URL + "/v1"
+ config.HTTPClient.Transport = &tokenRoundTripper{
+ test.GetTestToken(),
+ http.DefaultTransport,
+ }
+
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ request := ChatCompletionRequest{
+ MaxTokens: 5,
+ Model: GPT3Dot5Turbo,
+ Messages: []ChatCompletionMessage{
+ {
+ Role: ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ Stream: true,
+ }
+
+ stream, err := client.CreateChatCompletionStream(ctx, request)
+ if err != nil {
+ t.Errorf("CreateCompletionStream returned error: %v", err)
+ }
+ defer stream.Close()
+
+ expectedResponses := []ChatCompletionStreamResponse{
+ {
+ ID: "1",
+ Object: "completion",
+ Created: 1598069254,
+ Model: GPT3Dot5Turbo,
+ Choices: []ChatCompletionStreamChoice{
+ {
+ Delta: ChatCompletionStreamChoiceDelta{
+ Content: "response1",
+ },
+ FinishReason: "max_tokens",
+ },
+ },
+ },
+ {
+ ID: "2",
+ Object: "completion",
+ Created: 1598069255,
+ Model: GPT3Dot5Turbo,
+ Choices: []ChatCompletionStreamChoice{
+ {
+ Delta: ChatCompletionStreamChoiceDelta{
+ Content: "response2",
+ },
+ FinishReason: "max_tokens",
+ },
+ },
+ },
+ }
+
+ for ix, expectedResponse := range expectedResponses {
+ b, _ := json.Marshal(expectedResponse)
+ t.Logf("%d: %s", ix, string(b))
+
+ receivedResponse, streamErr := stream.Recv()
+ if streamErr != nil {
+ t.Errorf("stream.Recv() failed: %v", streamErr)
+ }
+ if !compareChatResponses(expectedResponse, receivedResponse) {
+ t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
+ }
+ }
+
+ _, streamErr := stream.Recv()
+ if !errors.Is(streamErr, io.EOF) {
+ t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
+ }
+
+ _, streamErr = stream.Recv()
+ if !errors.Is(streamErr, io.EOF) {
+ t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
+ }
+}
+
+// Helper funcs.
+func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
+ if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
+ return false
+ }
+ if len(r1.Choices) != len(r2.Choices) {
+ return false
+ }
+ for i := range r1.Choices {
+ if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool {
+ if c1.Index != c2.Index {
+ return false
+ }
+ if c1.Delta.Content != c2.Delta.Content {
+ return false
+ }
+ if c1.FinishReason != c2.FinishReason {
+ return false
+ }
+ return true
+}
diff --git a/go-gpt3/chat_test.go b/go-gpt3/chat_test.go
new file mode 100644
index 0000000..5c03ebf
--- /dev/null
+++ b/go-gpt3/chat_test.go
@@ -0,0 +1,132 @@
+package openai_test
+
+import (
+ . "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test"
+
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestChatCompletionsWrongModel(t *testing.T) {
+ config := DefaultConfig("whatever")
+ config.BaseURL = "http://localhost/v1"
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ req := ChatCompletionRequest{
+ MaxTokens: 5,
+ Model: "ada",
+ Messages: []ChatCompletionMessage{
+ {
+ Role: ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ }
+ _, err := client.CreateChatCompletion(ctx, req)
+ if !errors.Is(err, ErrChatCompletionInvalidModel) {
+ t.Fatalf("CreateChatCompletion should return wrong model error, but returned: %v", err)
+ }
+}
+
+// TestCompletions Tests the completions endpoint of the API using the mocked server.
+func TestChatCompletions(t *testing.T) {
+ server := test.NewTestServer()
+ server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
+ // create the test server
+ var err error
+ ts := server.OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ req := ChatCompletionRequest{
+ MaxTokens: 5,
+ Model: GPT3Dot5Turbo,
+ Messages: []ChatCompletionMessage{
+ {
+ Role: ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ }
+ _, err = client.CreateChatCompletion(ctx, req)
+ if err != nil {
+ t.Fatalf("CreateChatCompletion error: %v", err)
+ }
+}
+
+// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
+func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
+ var err error
+ var resBytes []byte
+
+ // completions only accepts POST requests
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+ var completionReq ChatCompletionRequest
+ if completionReq, err = getChatCompletionBody(r); err != nil {
+ http.Error(w, "could not read request", http.StatusInternalServerError)
+ return
+ }
+ res := ChatCompletionResponse{
+ ID: strconv.Itoa(int(time.Now().Unix())),
+ Object: "test-object",
+ Created: time.Now().Unix(),
+ // would be nice to validate Model during testing, but
+ // this may not be possible with how much upkeep
+ // would be required / wouldn't make much sense
+ Model: completionReq.Model,
+ }
+ // create completions
+ for i := 0; i < completionReq.N; i++ {
+ // generate a random string of length completionReq.Length
+ completionStr := strings.Repeat("a", completionReq.MaxTokens)
+
+ res.Choices = append(res.Choices, ChatCompletionChoice{
+ Message: ChatCompletionMessage{
+ Role: ChatMessageRoleAssistant,
+ Content: completionStr,
+ },
+ Index: i,
+ })
+ }
+ inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N
+ completionTokens := completionReq.MaxTokens * completionReq.N
+ res.Usage = Usage{
+ PromptTokens: inputTokens,
+ CompletionTokens: completionTokens,
+ TotalTokens: inputTokens + completionTokens,
+ }
+ resBytes, _ = json.Marshal(res)
+ fmt.Fprintln(w, string(resBytes))
+}
+
+// getChatCompletionBody Returns the body of the request to create a completion.
+func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) {
+ completion := ChatCompletionRequest{}
+ // read the request body
+ reqBody, err := io.ReadAll(r.Body)
+ if err != nil {
+ return ChatCompletionRequest{}, err
+ }
+ err = json.Unmarshal(reqBody, &completion)
+ if err != nil {
+ return ChatCompletionRequest{}, err
+ }
+ return completion, nil
+}
diff --git a/go-gpt3/common.go b/go-gpt3/common.go
index 9fb0178..3b555a7 100644
--- a/go-gpt3/common.go
+++ b/go-gpt3/common.go
@@ -1,5 +1,5 @@
// common.go defines common types used throughout the OpenAI API.
-package gogpt
+package openai
// Usage Represents the total token usage per request to OpenAI.
type Usage struct {
diff --git a/go-gpt3/completion.go b/go-gpt3/completion.go
index 74762c9..66b4866 100644
--- a/go-gpt3/completion.go
+++ b/go-gpt3/completion.go
@@ -1,17 +1,24 @@
-package gogpt
+package openai
import (
- "bytes"
"context"
- "encoding/json"
+ "errors"
"net/http"
)
+var (
+ ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
+)
+
// GPT3 Defines the models provided by OpenAI to use when generating
// completions from OpenAI.
// GPT3 Models are designed for text-based tasks. For code-specific
// tasks, please refer to the Codex series of models.
const (
+ GPT432K0314 = "gpt-4-32k-0314"
+ GPT432K = "gpt-4-32k"
+ GPT40314 = "gpt-4-0314"
+ GPT4 = "gpt-4"
GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301"
GPT3Dot5Turbo = "gpt-3.5-turbo"
GPT3TextDavinci003 = "text-davinci-003"
@@ -92,14 +99,13 @@ func (c *Client) CreateCompletion(
ctx context.Context,
request CompletionRequest,
) (response CompletionResponse, err error) {
- var reqBytes []byte
- reqBytes, err = json.Marshal(request)
- if err != nil {
+ if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo {
+ err = ErrCompletionUnsupportedModel
return
}
urlSuffix := "/completions"
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
+ req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}
diff --git a/go-gpt3/completion_test.go b/go-gpt3/completion_test.go
new file mode 100644
index 0000000..9868eb2
--- /dev/null
+++ b/go-gpt3/completion_test.go
@@ -0,0 +1,121 @@
+package openai_test
+
+import (
+ . "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test"
+
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestCompletionsWrongModel(t *testing.T) {
+ config := DefaultConfig("whatever")
+ config.BaseURL = "http://localhost/v1"
+ client := NewClientWithConfig(config)
+
+ _, err := client.CreateCompletion(
+ context.Background(),
+ CompletionRequest{
+ MaxTokens: 5,
+ Model: GPT3Dot5Turbo,
+ },
+ )
+ if !errors.Is(err, ErrCompletionUnsupportedModel) {
+ t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err)
+ }
+}
+
+// TestCompletions Tests the completions endpoint of the API using the mocked server.
+func TestCompletions(t *testing.T) {
+ server := test.NewTestServer()
+ server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
+ // create the test server
+ var err error
+ ts := server.OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ req := CompletionRequest{
+ MaxTokens: 5,
+ Model: "ada",
+ }
+ req.Prompt = "Lorem ipsum"
+ _, err = client.CreateCompletion(ctx, req)
+ if err != nil {
+ t.Fatalf("CreateCompletion error: %v", err)
+ }
+}
+
+// handleCompletionEndpoint Handles the completion endpoint by the test server.
+func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
+ var err error
+ var resBytes []byte
+
+ // completions only accepts POST requests
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+ var completionReq CompletionRequest
+ if completionReq, err = getCompletionBody(r); err != nil {
+ http.Error(w, "could not read request", http.StatusInternalServerError)
+ return
+ }
+ res := CompletionResponse{
+ ID: strconv.Itoa(int(time.Now().Unix())),
+ Object: "test-object",
+ Created: time.Now().Unix(),
+ // would be nice to validate Model during testing, but
+ // this may not be possible with how much upkeep
+ // would be required / wouldn't make much sense
+ Model: completionReq.Model,
+ }
+ // create completions
+ for i := 0; i < completionReq.N; i++ {
+ // generate a random string of length completionReq.Length
+ completionStr := strings.Repeat("a", completionReq.MaxTokens)
+ if completionReq.Echo {
+ completionStr = completionReq.Prompt + completionStr
+ }
+ res.Choices = append(res.Choices, CompletionChoice{
+ Text: completionStr,
+ Index: i,
+ })
+ }
+ inputTokens := numTokens(completionReq.Prompt) * completionReq.N
+ completionTokens := completionReq.MaxTokens * completionReq.N
+ res.Usage = Usage{
+ PromptTokens: inputTokens,
+ CompletionTokens: completionTokens,
+ TotalTokens: inputTokens + completionTokens,
+ }
+ resBytes, _ = json.Marshal(res)
+ fmt.Fprintln(w, string(resBytes))
+}
+
+// getCompletionBody Returns the body of the request to create a completion.
+func getCompletionBody(r *http.Request) (CompletionRequest, error) {
+ completion := CompletionRequest{}
+ // read the request body
+ reqBody, err := io.ReadAll(r.Body)
+ if err != nil {
+ return CompletionRequest{}, err
+ }
+ err = json.Unmarshal(reqBody, &completion)
+ if err != nil {
+ return CompletionRequest{}, err
+ }
+ return completion, nil
+}
diff --git a/go-gpt3/config.go b/go-gpt3/config.go
index 8e5f211..e09c256 100644
--- a/go-gpt3/config.go
+++ b/go-gpt3/config.go
@@ -1,11 +1,11 @@
-package gogpt
+package openai
import (
"net/http"
)
const (
- apiURLv1 = "http://chatgpt.zhiyinos.cn/"
+ apiURLv1 = "https://api.openai.com/v1"
defaultEmptyMessagesLimit uint = 300
)
diff --git a/go-gpt3/edits.go b/go-gpt3/edits.go
index 8cfc21c..858a8e5 100644
--- a/go-gpt3/edits.go
+++ b/go-gpt3/edits.go
@@ -1,9 +1,7 @@
-package gogpt
+package openai
import (
- "bytes"
"context"
- "encoding/json"
"net/http"
)
@@ -33,13 +31,7 @@ type EditsResponse struct {
// Perform an API call to the Edits endpoint.
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
- var reqBytes []byte
- reqBytes, err = json.Marshal(request)
- if err != nil {
- return
- }
-
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/edits"), bytes.NewBuffer(reqBytes))
+ req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request)
if err != nil {
return
}
diff --git a/go-gpt3/edits_test.go b/go-gpt3/edits_test.go
new file mode 100644
index 0000000..6a16f7c
--- /dev/null
+++ b/go-gpt3/edits_test.go
@@ -0,0 +1,104 @@
+package openai_test
+
+import (
+ . "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test"
+
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "testing"
+ "time"
+)
+
+// TestEdits Tests the edits endpoint of the API using the mocked server.
+func TestEdits(t *testing.T) {
+ server := test.NewTestServer()
+ server.RegisterHandler("/v1/edits", handleEditEndpoint)
+ // create the test server
+ var err error
+ ts := server.OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ // create an edit request
+ model := "ada"
+ editReq := EditsRequest{
+ Model: &model,
+ Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " +
+ "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" +
+ " ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" +
+ " ex ea commodo consequat. Duis aute irure dolor in reprehe",
+ Instruction: "test instruction",
+ N: 3,
+ }
+ response, err := client.Edits(ctx, editReq)
+ if err != nil {
+ t.Fatalf("Edits error: %v", err)
+ }
+ if len(response.Choices) != editReq.N {
+ t.Fatalf("edits does not properly return the correct number of choices")
+ }
+}
+
+// handleEditEndpoint Handles the edit endpoint by the test server.
+func handleEditEndpoint(w http.ResponseWriter, r *http.Request) {
+ var err error
+ var resBytes []byte
+
+ // edits only accepts POST requests
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+ var editReq EditsRequest
+ editReq, err = getEditBody(r)
+ if err != nil {
+ http.Error(w, "could not read request", http.StatusInternalServerError)
+ return
+ }
+ // create a response
+ res := EditsResponse{
+ Object: "test-object",
+ Created: time.Now().Unix(),
+ }
+ // edit and calculate token usage
+ editString := "edited by mocked OpenAI server :)"
+ inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N
+ completionTokens := int(float32(len(editString))/4) * editReq.N
+ for i := 0; i < editReq.N; i++ {
+ // instruction will be hidden and only seen by OpenAI
+ res.Choices = append(res.Choices, EditsChoice{
+ Text: editReq.Input + editString,
+ Index: i,
+ })
+ }
+ res.Usage = Usage{
+ PromptTokens: inputTokens,
+ CompletionTokens: completionTokens,
+ TotalTokens: inputTokens + completionTokens,
+ }
+ resBytes, _ = json.Marshal(res)
+ fmt.Fprint(w, string(resBytes))
+}
+
+// getEditBody Returns the body of the request to create an edit.
+func getEditBody(r *http.Request) (EditsRequest, error) {
+ edit := EditsRequest{}
+ // read the request body
+ reqBody, err := io.ReadAll(r.Body)
+ if err != nil {
+ return EditsRequest{}, err
+ }
+ err = json.Unmarshal(reqBody, &edit)
+ if err != nil {
+ return EditsRequest{}, err
+ }
+ return edit, nil
+}
diff --git a/go-gpt3/embeddings.go b/go-gpt3/embeddings.go
index bfdb802..2deaccc 100644
--- a/go-gpt3/embeddings.go
+++ b/go-gpt3/embeddings.go
@@ -1,9 +1,7 @@
-package gogpt
+package openai
import (
- "bytes"
"context"
- "encoding/json"
"net/http"
)
@@ -103,7 +101,7 @@ var stringToEnum = map[string]EmbeddingModel{
// then their vector representations should also be similar.
type Embedding struct {
Object string `json:"object"`
- Embedding []float64 `json:"embedding"`
+ Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}
@@ -134,14 +132,7 @@ type EmbeddingRequest struct {
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
// https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
- var reqBytes []byte
- reqBytes, err = json.Marshal(request)
- if err != nil {
- return
- }
-
- urlSuffix := "/embeddings"
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
+ req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request)
if err != nil {
return
}
diff --git a/go-gpt3/embeddings_test.go b/go-gpt3/embeddings_test.go
new file mode 100644
index 0000000..2aa48c5
--- /dev/null
+++ b/go-gpt3/embeddings_test.go
@@ -0,0 +1,48 @@
+package openai_test
+
+import (
+ . "github.com/sashabaranov/go-openai"
+
+ "bytes"
+ "encoding/json"
+ "testing"
+)
+
+func TestEmbedding(t *testing.T) {
+ embeddedModels := []EmbeddingModel{
+ AdaSimilarity,
+ BabbageSimilarity,
+ CurieSimilarity,
+ DavinciSimilarity,
+ AdaSearchDocument,
+ AdaSearchQuery,
+ BabbageSearchDocument,
+ BabbageSearchQuery,
+ CurieSearchDocument,
+ CurieSearchQuery,
+ DavinciSearchDocument,
+ DavinciSearchQuery,
+ AdaCodeSearchCode,
+ AdaCodeSearchText,
+ BabbageCodeSearchCode,
+ BabbageCodeSearchText,
+ }
+ for _, model := range embeddedModels {
+ embeddingReq := EmbeddingRequest{
+ Input: []string{
+ "The food was delicious and the waiter",
+ "Other examples of embedding request",
+ },
+ Model: model,
+ }
+ // marshal embeddingReq to JSON and confirm that the model field equals
+ // the AdaSearchQuery type
+ marshaled, err := json.Marshal(embeddingReq)
+ if err != nil {
+ t.Fatalf("Could not marshal embedding request: %v", err)
+ }
+ if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
+ t.Fatalf("Expected embedding request to contain model field")
+ }
+ }
+}
diff --git a/go-gpt3/engines.go b/go-gpt3/engines.go
index 2805f15..bb6a66c 100644
--- a/go-gpt3/engines.go
+++ b/go-gpt3/engines.go
@@ -1,4 +1,4 @@
-package gogpt
+package openai
import (
"context"
@@ -22,7 +22,7 @@ type EnginesList struct {
// ListEngines Lists the currently available engines, and provides basic
// information about each option such as the owner and availability.
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/engines"), nil)
+ req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
if err != nil {
return
}
@@ -38,7 +38,7 @@ func (c *Client) GetEngine(
engineID string,
) (engine Engine, err error) {
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
+ req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}
diff --git a/go-gpt3/error.go b/go-gpt3/error.go
index 927fafd..d041da2 100644
--- a/go-gpt3/error.go
+++ b/go-gpt3/error.go
@@ -1,4 +1,4 @@
-package gogpt
+package openai
import "fmt"
diff --git a/go-gpt3/files.go b/go-gpt3/files.go
index 385bb52..ec441c3 100644
--- a/go-gpt3/files.go
+++ b/go-gpt3/files.go
@@ -1,4 +1,4 @@
-package gogpt
+package openai
import (
"bytes"
@@ -112,7 +112,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
// DeleteFile deletes an existing file.
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
+ req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
if err != nil {
return
}
@@ -124,7 +124,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
// ListFiles Lists the currently available files,
// and provides basic information about each file such as the file name and purpose.
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/files"), nil)
+ req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil)
if err != nil {
return
}
@@ -137,7 +137,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
// such as the file name and purpose.
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
urlSuffix := fmt.Sprintf("/files/%s", fileID)
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
+ req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}
diff --git a/go-gpt3/files_test.go b/go-gpt3/files_test.go
new file mode 100644
index 0000000..6a78ce1
--- /dev/null
+++ b/go-gpt3/files_test.go
@@ -0,0 +1,81 @@
+package openai_test
+
+import (
+ . "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test"
+
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strconv"
+ "testing"
+ "time"
+)
+
+func TestFileUpload(t *testing.T) {
+ server := test.NewTestServer()
+ server.RegisterHandler("/v1/files", handleCreateFile)
+ // create the test server
+ var err error
+ ts := server.OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ req := FileRequest{
+ FileName: "test.go",
+ FilePath: "api.go",
+ Purpose: "fine-tune",
+ }
+ _, err = client.CreateFile(ctx, req)
+ if err != nil {
+ t.Fatalf("CreateFile error: %v", err)
+ }
+}
+
+// handleCreateFile Handles the images endpoint by the test server.
+func handleCreateFile(w http.ResponseWriter, r *http.Request) {
+ var err error
+ var resBytes []byte
+
+ // edits only accepts POST requests
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+ err = r.ParseMultipartForm(1024 * 1024 * 1024)
+ if err != nil {
+ http.Error(w, "file is more than 1GB", http.StatusInternalServerError)
+ return
+ }
+
+ values := r.Form
+ var purpose string
+ for key, value := range values {
+ if key == "purpose" {
+ purpose = value[0]
+ }
+ }
+ file, header, err := r.FormFile("file")
+ if err != nil {
+ return
+ }
+ defer file.Close()
+
+ var fileReq = File{
+ Bytes: int(header.Size),
+ ID: strconv.Itoa(int(time.Now().Unix())),
+ FileName: header.Filename,
+ Purpose: purpose,
+ CreatedAt: time.Now().Unix(),
+ Object: "test-objecct",
+ Owner: "test-owner",
+ }
+
+ resBytes, _ = json.Marshal(fileReq)
+ fmt.Fprint(w, string(resBytes))
+}
diff --git a/go-gpt3/fine_tunes.go b/go-gpt3/fine_tunes.go
new file mode 100644
index 0000000..a121867
--- /dev/null
+++ b/go-gpt3/fine_tunes.go
@@ -0,0 +1,130 @@
+package openai
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+)
+
+type FineTuneRequest struct {
+ TrainingFile string `json:"training_file"`
+ ValidationFile string `json:"validation_file,omitempty"`
+ Model string `json:"model,omitempty"`
+ Epochs int `json:"n_epochs,omitempty"`
+ BatchSize int `json:"batch_size,omitempty"`
+ LearningRateMultiplier float32 `json:"learning_rate_multiplier,omitempty"`
+ PromptLossRate float32 `json:"prompt_loss_rate,omitempty"`
+ ComputeClassificationMetrics bool `json:"compute_classification_metrics,omitempty"`
+ ClassificationClasses int `json:"classification_n_classes,omitempty"`
+ ClassificationPositiveClass string `json:"classification_positive_class,omitempty"`
+ ClassificationBetas []float32 `json:"classification_betas,omitempty"`
+ Suffix string `json:"suffix,omitempty"`
+}
+
+type FineTune struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Model string `json:"model"`
+ CreatedAt int64 `json:"created_at"`
+ FineTuneEventList []FineTuneEvent `json:"events,omitempty"`
+ FineTunedModel string `json:"fine_tuned_model"`
+ HyperParams FineTuneHyperParams `json:"hyperparams"`
+ OrganizationID string `json:"organization_id"`
+ ResultFiles []File `json:"result_files"`
+ Status string `json:"status"`
+ ValidationFiles []File `json:"validation_files"`
+ TrainingFiles []File `json:"training_files"`
+ UpdatedAt int64 `json:"updated_at"`
+}
+
+type FineTuneEvent struct {
+ Object string `json:"object"`
+ CreatedAt int64 `json:"created_at"`
+ Level string `json:"level"`
+ Message string `json:"message"`
+}
+
+type FineTuneHyperParams struct {
+ BatchSize int `json:"batch_size"`
+ LearningRateMultiplier float64 `json:"learning_rate_multiplier"`
+ Epochs int `json:"n_epochs"`
+ PromptLossWeight float64 `json:"prompt_loss_weight"`
+}
+
+type FineTuneList struct {
+ Object string `json:"object"`
+ Data []FineTune `json:"data"`
+}
+type FineTuneEventList struct {
+ Object string `json:"object"`
+ Data []FineTuneEvent `json:"data"`
+}
+
+type FineTuneDeleteResponse struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Deleted bool `json:"deleted"`
+}
+
+func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
+ urlSuffix := "/fine-tunes"
+ req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
+ if err != nil {
+ return
+ }
+
+ err = c.sendRequest(req, &response)
+ return
+}
+
+// CancelFineTune cancel a fine-tune job.
+func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
+ req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
+ if err != nil {
+ return
+ }
+
+ err = c.sendRequest(req, &response)
+ return
+}
+
+func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
+ req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
+ if err != nil {
+ return
+ }
+
+ err = c.sendRequest(req, &response)
+ return
+}
+
+func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
+ urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
+ req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
+ if err != nil {
+ return
+ }
+
+ err = c.sendRequest(req, &response)
+ return
+}
+
+func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
+ req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
+ if err != nil {
+ return
+ }
+
+ err = c.sendRequest(req, &response)
+ return
+}
+
+func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
+ req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
+ if err != nil {
+ return
+ }
+
+ err = c.sendRequest(req, &response)
+ return
+}
diff --git a/go-gpt3/fine_tunes_test.go b/go-gpt3/fine_tunes_test.go
new file mode 100644
index 0000000..1f6f967
--- /dev/null
+++ b/go-gpt3/fine_tunes_test.go
@@ -0,0 +1,101 @@
+package openai_test
+
+import (
+ . "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test"
+
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+)
+
+const testFineTuneID = "fine-tune-id"
+
+// TestFineTunes Tests the fine tunes endpoint of the API using the mocked server.
+func TestFineTunes(t *testing.T) {
+ server := test.NewTestServer()
+ server.RegisterHandler(
+ "/v1/fine-tunes",
+ func(w http.ResponseWriter, r *http.Request) {
+ var resBytes []byte
+ if r.Method == http.MethodGet {
+ resBytes, _ = json.Marshal(FineTuneList{})
+ } else {
+ resBytes, _ = json.Marshal(FineTune{})
+ }
+ fmt.Fprintln(w, string(resBytes))
+ },
+ )
+
+ server.RegisterHandler(
+ "/v1/fine-tunes/"+testFineTuneID+"/cancel",
+ func(w http.ResponseWriter, r *http.Request) {
+ resBytes, _ := json.Marshal(FineTune{})
+ fmt.Fprintln(w, string(resBytes))
+ },
+ )
+
+ server.RegisterHandler(
+ "/v1/fine-tunes/"+testFineTuneID,
+ func(w http.ResponseWriter, r *http.Request) {
+ var resBytes []byte
+ if r.Method == http.MethodDelete {
+ resBytes, _ = json.Marshal(FineTuneDeleteResponse{})
+ } else {
+ resBytes, _ = json.Marshal(FineTune{})
+ }
+ fmt.Fprintln(w, string(resBytes))
+ },
+ )
+
+ server.RegisterHandler(
+ "/v1/fine-tunes/"+testFineTuneID+"/events",
+ func(w http.ResponseWriter, r *http.Request) {
+ resBytes, _ := json.Marshal(FineTuneEventList{})
+ fmt.Fprintln(w, string(resBytes))
+ },
+ )
+
+ // create the test server
+ var err error
+ ts := server.OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ _, err = client.ListFineTunes(ctx)
+ if err != nil {
+ t.Fatalf("ListFineTunes error: %v", err)
+ }
+
+ _, err = client.CreateFineTune(ctx, FineTuneRequest{})
+ if err != nil {
+ t.Fatalf("CreateFineTune error: %v", err)
+ }
+
+ _, err = client.CancelFineTune(ctx, testFineTuneID)
+ if err != nil {
+ t.Fatalf("CancelFineTune error: %v", err)
+ }
+
+ _, err = client.GetFineTune(ctx, testFineTuneID)
+ if err != nil {
+ t.Fatalf("GetFineTune error: %v", err)
+ }
+
+ _, err = client.DeleteFineTune(ctx, testFineTuneID)
+ if err != nil {
+ t.Fatalf("DeleteFineTune error: %v", err)
+ }
+
+ _, err = client.ListFineTuneEvents(ctx, testFineTuneID)
+ if err != nil {
+ t.Fatalf("ListFineTuneEvents error: %v", err)
+ }
+}
diff --git a/go-gpt3/go.mod b/go-gpt3/go.mod
new file mode 100644
index 0000000..42cc7b3
--- /dev/null
+++ b/go-gpt3/go.mod
@@ -0,0 +1,3 @@
+module github.com/sashabaranov/go-openai
+
+go 1.18
diff --git a/go-gpt3/image.go b/go-gpt3/image.go
index 07ecaa7..c0dfa64 100644
--- a/go-gpt3/image.go
+++ b/go-gpt3/image.go
@@ -1,9 +1,8 @@
-package gogpt
+package openai
import (
"bytes"
"context"
- "encoding/json"
"io"
"mime/multipart"
"net/http"
@@ -46,14 +45,8 @@ type ImageResponseDataInner struct {
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
- var reqBytes []byte
- reqBytes, err = json.Marshal(request)
- if err != nil {
- return
- }
-
urlSuffix := "/images/generations"
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
+ req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}
@@ -86,20 +79,65 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
return
}
- // mask
- mask, err := writer.CreateFormFile("mask", request.Mask.Name())
+ // mask, it is optional
+ if request.Mask != nil {
+ mask, err2 := writer.CreateFormFile("mask", request.Mask.Name())
+ if err2 != nil {
+ return
+ }
+ _, err = io.Copy(mask, request.Mask)
+ if err != nil {
+ return
+ }
+ }
+
+ err = writer.WriteField("prompt", request.Prompt)
if err != nil {
return
}
- _, err = io.Copy(mask, request.Mask)
+ err = writer.WriteField("n", strconv.Itoa(request.N))
+ if err != nil {
+ return
+ }
+ err = writer.WriteField("size", request.Size)
+ if err != nil {
+ return
+ }
+ writer.Close()
+ urlSuffix := "/images/edits"
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
if err != nil {
return
}
- err = writer.WriteField("prompt", request.Prompt)
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ err = c.sendRequest(req, &response)
+ return
+}
+
+// ImageVariRequest represents the request structure for the image API.
+type ImageVariRequest struct {
+ Image *os.File `json:"image,omitempty"`
+ N int `json:"n,omitempty"`
+ Size string `json:"size,omitempty"`
+}
+
+// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
+// Use abbreviations(vari for variation) because ci-lint has a single-line length limit ...
+func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) (response ImageResponse, err error) {
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+
+ // image
+ image, err := writer.CreateFormFile("image", request.Image.Name())
if err != nil {
return
}
+ _, err = io.Copy(image, request.Image)
+ if err != nil {
+ return
+ }
+
err = writer.WriteField("n", strconv.Itoa(request.N))
if err != nil {
return
@@ -109,7 +147,8 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
return
}
writer.Close()
- urlSuffix := "/images/edits"
+ //https://platform.openai.com/docs/api-reference/images/create-variation
+ urlSuffix := "/images/variations"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
if err != nil {
return
diff --git a/go-gpt3/image_test.go b/go-gpt3/image_test.go
new file mode 100644
index 0000000..b7949c8
--- /dev/null
+++ b/go-gpt3/image_test.go
@@ -0,0 +1,268 @@
+package openai_test
+
+import (
+ . "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test"
+
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "testing"
+ "time"
+)
+
+func TestImages(t *testing.T) {
+ server := test.NewTestServer()
+ server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
+ // create the test server
+ var err error
+ ts := server.OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ req := ImageRequest{}
+ req.Prompt = "Lorem ipsum"
+ _, err = client.CreateImage(ctx, req)
+ if err != nil {
+ t.Fatalf("CreateImage error: %v", err)
+ }
+}
+
+// handleImageEndpoint Handles the images endpoint by the test server.
+func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
+ var err error
+ var resBytes []byte
+
+ // imagess only accepts POST requests
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+ var imageReq ImageRequest
+ if imageReq, err = getImageBody(r); err != nil {
+ http.Error(w, "could not read request", http.StatusInternalServerError)
+ return
+ }
+ res := ImageResponse{
+ Created: time.Now().Unix(),
+ }
+ for i := 0; i < imageReq.N; i++ {
+ imageData := ImageResponseDataInner{}
+ switch imageReq.ResponseFormat {
+ case CreateImageResponseFormatURL, "":
+ imageData.URL = "https://example.com/image.png"
+ case CreateImageResponseFormatB64JSON:
+ // This decodes to "{}" in base64.
+ imageData.B64JSON = "e30K"
+ default:
+ http.Error(w, "invalid response format", http.StatusBadRequest)
+ return
+ }
+ res.Data = append(res.Data, imageData)
+ }
+ resBytes, _ = json.Marshal(res)
+ fmt.Fprintln(w, string(resBytes))
+}
+
+// getImageBody Returns the body of the request to create a image.
+func getImageBody(r *http.Request) (ImageRequest, error) {
+ image := ImageRequest{}
+ // read the request body
+ reqBody, err := io.ReadAll(r.Body)
+ if err != nil {
+ return ImageRequest{}, err
+ }
+ err = json.Unmarshal(reqBody, &image)
+ if err != nil {
+ return ImageRequest{}, err
+ }
+ return image, nil
+}
+
+func TestImageEdit(t *testing.T) {
+ server := test.NewTestServer()
+ server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
+ // create the test server
+ var err error
+ ts := server.OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ origin, err := os.Create("image.png")
+ if err != nil {
+ t.Error("open origin file error")
+ return
+ }
+
+ mask, err := os.Create("mask.png")
+ if err != nil {
+ t.Error("open mask file error")
+ return
+ }
+
+ defer func() {
+ mask.Close()
+ origin.Close()
+ os.Remove("mask.png")
+ os.Remove("image.png")
+ }()
+
+ req := ImageEditRequest{
+ Image: origin,
+ Mask: mask,
+ Prompt: "There is a turtle in the pool",
+ N: 3,
+ Size: CreateImageSize1024x1024,
+ }
+ _, err = client.CreateEditImage(ctx, req)
+ if err != nil {
+ t.Fatalf("CreateImage error: %v", err)
+ }
+}
+
+func TestImageEditWithoutMask(t *testing.T) {
+ server := test.NewTestServer()
+ server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
+ // create the test server
+ var err error
+ ts := server.OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ origin, err := os.Create("image.png")
+ if err != nil {
+ t.Error("open origin file error")
+ return
+ }
+
+ defer func() {
+ origin.Close()
+ os.Remove("image.png")
+ }()
+
+ req := ImageEditRequest{
+ Image: origin,
+ Prompt: "There is a turtle in the pool",
+ N: 3,
+ Size: CreateImageSize1024x1024,
+ }
+ _, err = client.CreateEditImage(ctx, req)
+ if err != nil {
+ t.Fatalf("CreateImage error: %v", err)
+ }
+}
+
+// handleEditImageEndpoint Handles the images endpoint by the test server.
+func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
+ var resBytes []byte
+
+ // imagess only accepts POST requests
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+
+ responses := ImageResponse{
+ Created: time.Now().Unix(),
+ Data: []ImageResponseDataInner{
+ {
+ URL: "test-url1",
+ B64JSON: "",
+ },
+ {
+ URL: "test-url2",
+ B64JSON: "",
+ },
+ {
+ URL: "test-url3",
+ B64JSON: "",
+ },
+ },
+ }
+
+ resBytes, _ = json.Marshal(responses)
+ fmt.Fprintln(w, string(resBytes))
+}
+
+func TestImageVariation(t *testing.T) {
+ server := test.NewTestServer()
+ server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
+ // create the test server
+ var err error
+ ts := server.OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ origin, err := os.Create("image.png")
+ if err != nil {
+ t.Error("open origin file error")
+ return
+ }
+
+ defer func() {
+ origin.Close()
+ os.Remove("image.png")
+ }()
+
+ req := ImageVariRequest{
+ Image: origin,
+ N: 3,
+ Size: CreateImageSize1024x1024,
+ }
+ _, err = client.CreateVariImage(ctx, req)
+ if err != nil {
+ t.Fatalf("CreateImage error: %v", err)
+ }
+}
+
+// handleVariateImageEndpoint Handles the images endpoint by the test server.
+func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
+ var resBytes []byte
+
+ // imagess only accepts POST requests
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+
+ responses := ImageResponse{
+ Created: time.Now().Unix(),
+ Data: []ImageResponseDataInner{
+ {
+ URL: "test-url1",
+ B64JSON: "",
+ },
+ {
+ URL: "test-url2",
+ B64JSON: "",
+ },
+ {
+ URL: "test-url3",
+ B64JSON: "",
+ },
+ },
+ }
+
+ resBytes, _ = json.Marshal(responses)
+ fmt.Fprintln(w, string(resBytes))
+}
diff --git a/go-gpt3/marshaller.go b/go-gpt3/marshaller.go
new file mode 100644
index 0000000..308ccd1
--- /dev/null
+++ b/go-gpt3/marshaller.go
@@ -0,0 +1,15 @@
+package openai
+
+import (
+ "encoding/json"
+)
+
+type marshaller interface {
+ marshal(value any) ([]byte, error)
+}
+
+type jsonMarshaller struct{}
+
+func (jm *jsonMarshaller) marshal(value any) ([]byte, error) {
+ return json.Marshal(value)
+}
diff --git a/go-gpt3/models.go b/go-gpt3/models.go
index c18e502..2be91aa 100644
--- a/go-gpt3/models.go
+++ b/go-gpt3/models.go
@@ -1,4 +1,4 @@
-package gogpt
+package openai
import (
"context"
@@ -7,7 +7,7 @@ import (
// Model struct represents an OpenAPI model.
type Model struct {
- CreatedAt int64 `json:"created_at"`
+ CreatedAt int64 `json:"created"`
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by"`
@@ -18,7 +18,7 @@ type Model struct {
// Permission struct represents an OpenAPI permission.
type Permission struct {
- CreatedAt int64 `json:"created_at"`
+ CreatedAt int64 `json:"created"`
ID string `json:"id"`
Object string `json:"object"`
AllowCreateEngine bool `json:"allow_create_engine"`
@@ -40,7 +40,7 @@ type ModelsList struct {
// ListModels Lists the currently available models,
// and provides basic information about each model such as the model id and parent.
func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/models"), nil)
+ req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/models"), nil)
if err != nil {
return
}
diff --git a/go-gpt3/models_test.go b/go-gpt3/models_test.go
new file mode 100644
index 0000000..c96ece8
--- /dev/null
+++ b/go-gpt3/models_test.go
@@ -0,0 +1,39 @@
+package openai_test
+
+import (
+ . "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test"
+
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+)
+
+// TestListModels Tests the models endpoint of the API using the mocked server.
+func TestListModels(t *testing.T) {
+ server := test.NewTestServer()
+ server.RegisterHandler("/v1/models", handleModelsEndpoint)
+ // create the test server
+ var err error
+ ts := server.OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ _, err = client.ListModels(ctx)
+ if err != nil {
+ t.Fatalf("ListModels error: %v", err)
+ }
+}
+
+// handleModelsEndpoint Handles the models endpoint by the test server.
+func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) {
+ resBytes, _ := json.Marshal(ModelsList{})
+ fmt.Fprintln(w, string(resBytes))
+}
diff --git a/go-gpt3/moderation.go b/go-gpt3/moderation.go
index 1849e10..ff789a6 100644
--- a/go-gpt3/moderation.go
+++ b/go-gpt3/moderation.go
@@ -1,9 +1,7 @@
-package gogpt
+package openai
import (
- "bytes"
"context"
- "encoding/json"
"net/http"
)
@@ -52,13 +50,7 @@ type ModerationResponse struct {
// Moderations — perform a moderation api call over a string.
// Input can be an array or slice but a string will reduce the complexity.
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
- var reqBytes []byte
- reqBytes, err = json.Marshal(request)
- if err != nil {
- return
- }
-
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/moderations"), bytes.NewBuffer(reqBytes))
+ req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request)
if err != nil {
return
}
diff --git a/go-gpt3/moderation_test.go b/go-gpt3/moderation_test.go
new file mode 100644
index 0000000..f501245
--- /dev/null
+++ b/go-gpt3/moderation_test.go
@@ -0,0 +1,102 @@
+package openai_test
+
+import (
+ . "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test"
+
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+// TestModeration Tests the moderations endpoint of the API using the mocked server.
+func TestModerations(t *testing.T) {
+ server := test.NewTestServer()
+ server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
+ // create the test server
+ var err error
+ ts := server.OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ // create an edit request
+ model := "text-moderation-stable"
+ moderationReq := ModerationRequest{
+ Model: &model,
+ Input: "I want to kill them.",
+ }
+ _, err = client.Moderations(ctx, moderationReq)
+ if err != nil {
+ t.Fatalf("Moderation error: %v", err)
+ }
+}
+
+// handleModerationEndpoint Handles the moderation endpoint by the test server.
+func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
+ var err error
+ var resBytes []byte
+
+ // completions only accepts POST requests
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+ var moderationReq ModerationRequest
+ if moderationReq, err = getModerationBody(r); err != nil {
+ http.Error(w, "could not read request", http.StatusInternalServerError)
+ return
+ }
+
+ resCat := ResultCategories{}
+ resCatScore := ResultCategoryScores{}
+ switch {
+ case strings.Contains(moderationReq.Input, "kill"):
+ resCat = ResultCategories{Violence: true}
+ resCatScore = ResultCategoryScores{Violence: 1}
+ case strings.Contains(moderationReq.Input, "hate"):
+ resCat = ResultCategories{Hate: true}
+ resCatScore = ResultCategoryScores{Hate: 1}
+ case strings.Contains(moderationReq.Input, "suicide"):
+ resCat = ResultCategories{SelfHarm: true}
+ resCatScore = ResultCategoryScores{SelfHarm: 1}
+ case strings.Contains(moderationReq.Input, "porn"):
+ resCat = ResultCategories{Sexual: true}
+ resCatScore = ResultCategoryScores{Sexual: 1}
+ }
+
+ result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true}
+
+ res := ModerationResponse{
+ ID: strconv.Itoa(int(time.Now().Unix())),
+ Model: *moderationReq.Model,
+ }
+ res.Results = append(res.Results, result)
+
+ resBytes, _ = json.Marshal(res)
+ fmt.Fprintln(w, string(resBytes))
+}
+
+// getModerationBody Returns the body of the request to do a moderation.
+func getModerationBody(r *http.Request) (ModerationRequest, error) {
+ moderation := ModerationRequest{}
+ // read the request body
+ reqBody, err := io.ReadAll(r.Body)
+ if err != nil {
+ return ModerationRequest{}, err
+ }
+ err = json.Unmarshal(reqBody, &moderation)
+ if err != nil {
+ return ModerationRequest{}, err
+ }
+ return moderation, nil
+}
diff --git a/go-gpt3/request_builder.go b/go-gpt3/request_builder.go
new file mode 100644
index 0000000..f0cef10
--- /dev/null
+++ b/go-gpt3/request_builder.go
@@ -0,0 +1,40 @@
+package openai
+
+import (
+ "bytes"
+ "context"
+ "net/http"
+)
+
+type requestBuilder interface {
+ build(ctx context.Context, method, url string, request any) (*http.Request, error)
+}
+
+type httpRequestBuilder struct {
+ marshaller marshaller
+}
+
+func newRequestBuilder() *httpRequestBuilder {
+ return &httpRequestBuilder{
+ marshaller: &jsonMarshaller{},
+ }
+}
+
+func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) {
+ if request == nil {
+ return http.NewRequestWithContext(ctx, method, url, nil)
+ }
+
+ var reqBytes []byte
+ reqBytes, err := b.marshaller.marshal(request)
+ if err != nil {
+ return nil, err
+ }
+
+ return http.NewRequestWithContext(
+ ctx,
+ method,
+ url,
+ bytes.NewBuffer(reqBytes),
+ )
+}
diff --git a/go-gpt3/request_builder_test.go b/go-gpt3/request_builder_test.go
new file mode 100644
index 0000000..533977a
--- /dev/null
+++ b/go-gpt3/request_builder_test.go
@@ -0,0 +1,148 @@
+package openai //nolint:testpackage // testing private field
+
+import (
+ "github.com/sashabaranov/go-openai/internal/test"
+
+ "context"
+ "errors"
+ "net/http"
+ "testing"
+)
+
+var (
+ errTestMarshallerFailed = errors.New("test marshaller failed")
+ errTestRequestBuilderFailed = errors.New("test request builder failed")
+)
+
+type (
+ failingRequestBuilder struct{}
+ failingMarshaller struct{}
+)
+
+func (*failingMarshaller) marshal(value any) ([]byte, error) {
+ return []byte{}, errTestMarshallerFailed
+}
+
+func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) {
+ return nil, errTestRequestBuilderFailed
+}
+
+func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
+ builder := httpRequestBuilder{
+ marshaller: &failingMarshaller{},
+ }
+
+ _, err := builder.build(context.Background(), "", "", struct{}{})
+ if !errors.Is(err, errTestMarshallerFailed) {
+ t.Fatalf("Did not return error when marshaller failed: %v", err)
+ }
+}
+
+func TestClientReturnsRequestBuilderErrors(t *testing.T) {
+ var err error
+ ts := test.NewTestServer().OpenAITestServer()
+ ts.Start()
+ defer ts.Close()
+
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = ts.URL + "/v1"
+ client := NewClientWithConfig(config)
+ client.requestBuilder = &failingRequestBuilder{}
+
+ ctx := context.Background()
+
+ _, err = client.CreateCompletion(ctx, CompletionRequest{})
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.CreateFineTune(ctx, FineTuneRequest{})
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.ListFineTunes(ctx)
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.CancelFineTune(ctx, "")
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.GetFineTune(ctx, "")
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.DeleteFineTune(ctx, "")
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.ListFineTuneEvents(ctx, "")
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.Moderations(ctx, ModerationRequest{})
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.Edits(ctx, EditsRequest{})
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.CreateImage(ctx, ImageRequest{})
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ err = client.DeleteFile(ctx, "")
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.GetFile(ctx, "")
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.ListFiles(ctx)
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.ListEngines(ctx)
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.GetEngine(ctx, "")
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+
+ _, err = client.ListModels(ctx)
+ if !errors.Is(err, errTestRequestBuilderFailed) {
+ t.Fatalf("Did not return error when request builder failed: %v", err)
+ }
+}
diff --git a/go-gpt3/stream.go b/go-gpt3/stream.go
index 8f35d2e..0eed4aa 100644
--- a/go-gpt3/stream.go
+++ b/go-gpt3/stream.go
@@ -1,4 +1,4 @@
-package gogpt
+package openai
import (
"bufio"
@@ -6,7 +6,6 @@ import (
"context"
"encoding/json"
"errors"
- "fmt"
"io"
"net/http"
)
@@ -60,43 +59,6 @@ waitForData:
return
}
-func (stream *CompletionStream) ChatRecv() (response ChatCompletionResponseForStream, err error) {
- if stream.isFinished {
- err = io.EOF
- return
- }
-
- var emptyMessagesCount uint
-
-waitForData:
- line, err := stream.reader.ReadBytes('\n')
- if err != nil {
- return
- }
-
- var headerData = []byte("data: ")
- line = bytes.TrimSpace(line)
- if !bytes.HasPrefix(line, headerData) {
- emptyMessagesCount++
- if emptyMessagesCount > stream.emptyMessagesLimit {
- err = ErrTooManyEmptyStreamMessages
- return
- }
-
- goto waitForData
- }
-
- line = bytes.TrimPrefix(line, headerData)
- if string(line) == "[DONE]" {
- stream.isFinished = true
- err = io.EOF
- return
- }
-
- err = json.Unmarshal(line, &response)
- return
-}
-
func (stream *CompletionStream) Close() {
stream.response.Body.Close()
}
@@ -110,18 +72,7 @@ func (c *Client) CreateCompletionStream(
request CompletionRequest,
) (stream *CompletionStream, err error) {
request.Stream = true
- reqBytes, err := json.Marshal(request)
- if err != nil {
- return
- }
-
- urlSuffix := "/completions"
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "text/event-stream")
- req.Header.Set("Cache-Control", "no-cache")
- req.Header.Set("Connection", "keep-alive")
- req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
+ req, err := c.newStreamRequest(ctx, "POST", "/completions", request)
if err != nil {
return
}
diff --git a/go-gpt3/stream_test.go b/go-gpt3/stream_test.go
new file mode 100644
index 0000000..8f89e6b
--- /dev/null
+++ b/go-gpt3/stream_test.go
@@ -0,0 +1,147 @@
+package openai_test
+
+import (
+ . "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test"
+
+ "context"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestCreateCompletionStream(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ // Send test responses
+ dataBytes := []byte{}
+ dataBytes = append(dataBytes, []byte("event: message\n")...)
+ //nolint:lll
+ data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}`
+ dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+ dataBytes = append(dataBytes, []byte("event: message\n")...)
+ //nolint:lll
+ data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}`
+ dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+ dataBytes = append(dataBytes, []byte("event: done\n")...)
+ dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
+
+ _, err := w.Write(dataBytes)
+ if err != nil {
+ t.Errorf("Write error: %s", err)
+ }
+ }))
+ defer server.Close()
+
+ // Client portion of the test
+ config := DefaultConfig(test.GetTestToken())
+ config.BaseURL = server.URL + "/v1"
+ config.HTTPClient.Transport = &tokenRoundTripper{
+ test.GetTestToken(),
+ http.DefaultTransport,
+ }
+
+ client := NewClientWithConfig(config)
+ ctx := context.Background()
+
+ request := CompletionRequest{
+ Prompt: "Ex falso quodlibet",
+ Model: "text-davinci-002",
+ MaxTokens: 10,
+ Stream: true,
+ }
+
+ stream, err := client.CreateCompletionStream(ctx, request)
+ if err != nil {
+ t.Errorf("CreateCompletionStream returned error: %v", err)
+ }
+ defer stream.Close()
+
+ expectedResponses := []CompletionResponse{
+ {
+ ID: "1",
+ Object: "completion",
+ Created: 1598069254,
+ Model: "text-davinci-002",
+ Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}},
+ },
+ {
+ ID: "2",
+ Object: "completion",
+ Created: 1598069255,
+ Model: "text-davinci-002",
+ Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
+ },
+ }
+
+ for ix, expectedResponse := range expectedResponses {
+ receivedResponse, streamErr := stream.Recv()
+ if streamErr != nil {
+ t.Errorf("stream.Recv() failed: %v", streamErr)
+ }
+ if !compareResponses(expectedResponse, receivedResponse) {
+ t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
+ }
+ }
+
+ _, streamErr := stream.Recv()
+ if !errors.Is(streamErr, io.EOF) {
+ t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
+ }
+
+ _, streamErr = stream.Recv()
+ if !errors.Is(streamErr, io.EOF) {
+ t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
+ }
+}
+
+// A "tokenRoundTripper" is a struct that implements the RoundTripper
+// interface, specifically to handle the authentication token by adding a token
+// to the request header. We need this because the API requires that each
+// request include a valid API token in the headers for authentication and
+// authorization.
+type tokenRoundTripper struct {
+ token string
+ fallback http.RoundTripper
+}
+
+// RoundTrip takes an *http.Request as input and returns an
+// *http.Response and an error.
+//
+// It is expected to use the provided request to create a connection to an HTTP
+// server and return the response, or an error if one occurred. The returned
+// Response should have its Body closed. If the RoundTrip method returns an
+// error, the Client's Get, Head, Post, and PostForm methods return the same
+// error.
+func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ req.Header.Set("Authorization", "Bearer "+t.token)
+ return t.fallback.RoundTrip(req)
+}
+
+// Helper funcs.
+func compareResponses(r1, r2 CompletionResponse) bool {
+ if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
+ return false
+ }
+ if len(r1.Choices) != len(r2.Choices) {
+ return false
+ }
+ for i := range r1.Choices {
+ if !compareResponseChoices(r1.Choices[i], r2.Choices[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+func compareResponseChoices(c1, c2 CompletionChoice) bool {
+ if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason {
+ return false
+ }
+ return true
+}