Browse Source

add Reverse: for v0.0.5

tags/v0.0.5
DengBiao 1 year ago
parent
commit
318faa37a5
37 changed files with 2488 additions and 263 deletions
  1. +27
    -0
      go-gpt3/.github/workflows/pr.yml
  2. +258
    -15
      go-gpt3/README.md
  3. +0
    -50
      go-gpt3/answers.go
  4. +30
    -5
      go-gpt3/api.go
  5. +182
    -0
      go-gpt3/api_test.go
  6. +1
    -1
      go-gpt3/audio.go
  7. +143
    -0
      go-gpt3/audio_test.go
  8. +19
    -73
      go-gpt3/chat.go
  9. +106
    -0
      go-gpt3/chat_stream.go
  10. +153
    -0
      go-gpt3/chat_stream_test.go
  11. +132
    -0
      go-gpt3/chat_test.go
  12. +1
    -1
      go-gpt3/common.go
  13. +13
    -7
      go-gpt3/completion.go
  14. +121
    -0
      go-gpt3/completion_test.go
  15. +2
    -2
      go-gpt3/config.go
  16. +2
    -10
      go-gpt3/edits.go
  17. +104
    -0
      go-gpt3/edits_test.go
  18. +3
    -12
      go-gpt3/embeddings.go
  19. +48
    -0
      go-gpt3/embeddings_test.go
  20. +3
    -3
      go-gpt3/engines.go
  21. +1
    -1
      go-gpt3/error.go
  22. +4
    -4
      go-gpt3/files.go
  23. +81
    -0
      go-gpt3/files_test.go
  24. +130
    -0
      go-gpt3/fine_tunes.go
  25. +101
    -0
      go-gpt3/fine_tunes_test.go
  26. +3
    -0
      go-gpt3/go.mod
  27. +53
    -14
      go-gpt3/image.go
  28. +268
    -0
      go-gpt3/image_test.go
  29. +15
    -0
      go-gpt3/marshaller.go
  30. +4
    -4
      go-gpt3/models.go
  31. +39
    -0
      go-gpt3/models_test.go
  32. +2
    -10
      go-gpt3/moderation.go
  33. +102
    -0
      go-gpt3/moderation_test.go
  34. +40
    -0
      go-gpt3/request_builder.go
  35. +148
    -0
      go-gpt3/request_builder_test.go
  36. +2
    -51
      go-gpt3/stream.go
  37. +147
    -0
      go-gpt3/stream_test.go

+ 27
- 0
go-gpt3/.github/workflows/pr.yml View File

@@ -0,0 +1,27 @@
name: PR sanity check

on:
- push
- pull_request

jobs:
prcheck:
name: PR sanity check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: '1.19'
- name: Run vet
run: |
go vet .
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v3
with:
version: latest
- name: Run tests
run: go test -race -covermode=atomic -coverprofile=coverage.out -v .
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3

+ 258
- 15
go-gpt3/README.md View File

@@ -1,17 +1,65 @@
# go-gpt3
[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-gpt3)
[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-gpt3)](https://goreportcard.com/report/github.com/sashabaranov/go-gpt3)
# Go OpenAI
[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-openai)
[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai)
[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai)

> **Note**: the repository was recently renamed from `go-gpt3` to `go-openai`

[OpenAI ChatGPT and GPT-3](https://platform.openai.com/) API client for Go
This library provides Go clients for [OpenAI API](https://platform.openai.com/). We support:

* ChatGPT
* GPT-3, GPT-4
* DALL·E 2
* Whisper

Installation:
```
go get github.com/sashabaranov/go-gpt3
go get github.com/sashabaranov/go-openai
```


ChatGPT example usage:

```go
package main

import (
"context"
"fmt"
openai "github.com/sashabaranov/go-openai"
)

func main() {
client := openai.NewClient("your token")
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
},
)

if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err)
return
}

fmt.Println(resp.Choices[0].Message.Content)
}

```


Example usage:

Other examples:

<details>
<summary>GPT-3 completion</summary>

```go
package main
@@ -19,27 +67,30 @@ package main
import (
"context"
"fmt"
gogpt "github.com/sashabaranov/go-gpt3"
openai "github.com/sashabaranov/go-openai"
)

func main() {
c := gogpt.NewClient("your token")
c := openai.NewClient("your token")
ctx := context.Background()

req := gogpt.CompletionRequest{
Model: gogpt.GPT3Ada,
req := openai.CompletionRequest{
Model: openai.GPT3Ada,
MaxTokens: 5,
Prompt: "Lorem ipsum",
}
resp, err := c.CreateCompletion(ctx, req)
if err != nil {
fmt.Printf("Completion error: %v\n", err)
return
}
fmt.Println(resp.Choices[0].Text)
}
```
</details>

Streaming response example:
<details>
<summary>GPT-3 streaming completion</summary>

```go
package main
@@ -49,21 +100,22 @@ import (
"context"
"fmt"
"io"
gogpt "github.com/sashabaranov/go-gpt3"
openai "github.com/sashabaranov/go-openai"
)

func main() {
c := gogpt.NewClient("your token")
c := openai.NewClient("your token")
ctx := context.Background()

req := gogpt.CompletionRequest{
Model: gogpt.GPT3Ada,
req := openai.CompletionRequest{
Model: openai.GPT3Ada,
MaxTokens: 5,
Prompt: "Lorem ipsum",
Stream: true,
}
stream, err := c.CreateCompletionStream(ctx, req)
if err != nil {
fmt.Printf("CompletionStream error: %v\n", err)
return
}
defer stream.Close()
@@ -85,3 +137,194 @@ func main() {
}
}
```
</details>

<details>
<summary>Audio Speech-To-Text</summary>

```go
package main

import (
"context"
"fmt"

openai "github.com/sashabaranov/go-openai"
)

func main() {
c := openai.NewClient("your token")
ctx := context.Background()

req := openai.AudioRequest{
Model: openai.Whisper1,
FilePath: "recording.mp3",
}
resp, err := c.CreateTranscription(ctx, req)
if err != nil {
fmt.Printf("Transcription error: %v\n", err)
return
}
fmt.Println(resp.Text)
}
```
</details>

<details>
<summary>DALL-E 2 image generation</summary>

```go
package main

import (
"bytes"
"context"
"encoding/base64"
"fmt"
openai "github.com/sashabaranov/go-openai"
"image/png"
"os"
)

func main() {
c := openai.NewClient("your token")
ctx := context.Background()

// Sample image by link
reqUrl := openai.ImageRequest{
Prompt: "Parrot on a skateboard performs a trick, cartoon style, natural light, high detail",
Size: openai.CreateImageSize256x256,
ResponseFormat: openai.CreateImageResponseFormatURL,
N: 1,
}

respUrl, err := c.CreateImage(ctx, reqUrl)
if err != nil {
fmt.Printf("Image creation error: %v\n", err)
return
}
fmt.Println(respUrl.Data[0].URL)

// Example image as base64
reqBase64 := openai.ImageRequest{
Prompt: "Portrait of a humanoid parrot in a classic costume, high detail, realistic light, unreal engine",
Size: openai.CreateImageSize256x256,
ResponseFormat: openai.CreateImageResponseFormatB64JSON,
N: 1,
}

respBase64, err := c.CreateImage(ctx, reqBase64)
if err != nil {
fmt.Printf("Image creation error: %v\n", err)
return
}

imgBytes, err := base64.StdEncoding.DecodeString(respBase64.Data[0].B64JSON)
if err != nil {
fmt.Printf("Base64 decode error: %v\n", err)
return
}

r := bytes.NewReader(imgBytes)
imgData, err := png.Decode(r)
if err != nil {
fmt.Printf("PNG decode error: %v\n", err)
return
}

file, err := os.Create("image.png")
if err != nil {
fmt.Printf("File creation error: %v\n", err)
return
}
defer file.Close()

if err := png.Encode(file, imgData); err != nil {
fmt.Printf("PNG encode error: %v\n", err)
return
}

fmt.Println("The image was saved as example.png")
}

```
</details>

<details>
<summary>Configuring proxy</summary>

```go
config := openai.DefaultConfig("token")
proxyUrl, err := url.Parse("http://localhost:{port}")
if err != nil {
panic(err)
}
transport := &http.Transport{
Proxy: http.ProxyURL(proxyUrl),
}
config.HTTPClient = &http.Client{
Transport: transport,
}

c := openai.NewClientWithConfig(config)
```

See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig
</details>

<details>
<summary>ChatGPT support context</summary>

```go
package main

import (
"bufio"
"context"
"fmt"
"os"
"strings"

"github.com/sashabaranov/go-openai"
)

func main() {
client := openai.NewClient("your token")
messages := make([]openai.ChatCompletionMessage, 0)
reader := bufio.NewReader(os.Stdin)
fmt.Println("Conversation")
fmt.Println("---------------------")

for {
fmt.Print("-> ")
text, _ := reader.ReadString('\n')
// convert CRLF to LF
text = strings.Replace(text, "\n", "", -1)
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: text,
})

resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Messages: messages,
},
)

if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err)
continue
}

content := resp.Choices[0].Message.Content
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: content,
})
fmt.Println(content)
}
}
```
</details>

+ 0
- 50
go-gpt3/answers.go View File

@@ -1,50 +0,0 @@
package gogpt

import (
"bytes"
"context"
"encoding/json"
"net/http"
)

type AnswerRequest struct {
Documents []string `json:"documents,omitempty"`
File string `json:"file,omitempty"`
Question string `json:"question"`
SearchModel string `json:"search_model,omitempty"`
Model string `json:"model"`
ExamplesContext string `json:"examples_context"`
Examples [][]string `json:"examples"`
MaxTokens int `json:"max_tokens,omitempty"`
Stop []string `json:"stop,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
}

type AnswerResponse struct {
Answers []string `json:"answers"`
Completion string `json:"completion"`
Model string `json:"model"`
Object string `json:"object"`
SearchModel string `json:"search_model"`
SelectedDocuments []struct {
Document int `json:"document"`
Text string `json:"text"`
} `json:"selected_documents"`
}

// Search — perform a semantic search api call over a list of documents.
func (c *Client) Answers(ctx context.Context, request AnswerRequest) (response AnswerResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {
return
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/answers"), bytes.NewBuffer(reqBytes))
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

+ 30
- 5
go-gpt3/api.go View File

@@ -1,6 +1,7 @@
package gogpt
package openai

import (
"context"
"encoding/json"
"fmt"
"net/http"
@@ -9,17 +10,22 @@ import (
// Client is OpenAI GPT-3 API client.
type Client struct {
config ClientConfig

requestBuilder requestBuilder
}

// NewClient creates new OpenAI API client.
func NewClient(authToken string) *Client {
config := DefaultConfig(authToken)
return &Client{config}
return NewClientWithConfig(config)
}

// NewClientWithConfig creates new OpenAI API client for specified config.
func NewClientWithConfig(config ClientConfig) *Client {
return &Client{config}
return &Client{
config: config,
requestBuilder: newRequestBuilder(),
}
}

// NewOrgClient creates new OpenAI API client for specified Organization ID.
@@ -28,7 +34,7 @@ func NewClientWithConfig(config ClientConfig) *Client {
func NewOrgClient(authToken, org string) *Client {
config := DefaultConfig(authToken)
config.OrgID = org
return &Client{config}
return NewClientWithConfig(config)
}

func (c *Client) sendRequest(req *http.Request, v interface{}) error {
@@ -68,7 +74,7 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
}

if v != nil {
if err = json.NewDecoder(res.Body).Decode(&v); err != nil {
if err = json.NewDecoder(res.Body).Decode(v); err != nil {
return err
}
}
@@ -79,3 +85,22 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
func (c *Client) fullURL(suffix string) string {
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
}

func (c *Client) newStreamRequest(
ctx context.Context,
method string,
urlSuffix string,
body any) (*http.Request, error) {
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body)
if err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))

return req, nil
}

+ 182
- 0
go-gpt3/api_test.go View File

@@ -0,0 +1,182 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"

"context"
"errors"
"io"
"os"
"testing"
)

func TestAPI(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
}

var err error
c := NewClient(apiToken)
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err != nil {
t.Fatalf("ListEngines error: %v", err)
}

_, err = c.GetEngine(ctx, "davinci")
if err != nil {
t.Fatalf("GetEngine error: %v", err)
}

fileRes, err := c.ListFiles(ctx)
if err != nil {
t.Fatalf("ListFiles error: %v", err)
}

if len(fileRes.Files) > 0 {
_, err = c.GetFile(ctx, fileRes.Files[0].ID)
if err != nil {
t.Fatalf("GetFile error: %v", err)
}
} // else skip

embeddingReq := EmbeddingRequest{
Input: []string{
"The food was delicious and the waiter",
"Other examples of embedding request",
},
Model: AdaSearchQuery,
}
_, err = c.CreateEmbeddings(ctx, embeddingReq)
if err != nil {
t.Fatalf("Embedding error: %v", err)
}

_, err = c.CreateChatCompletion(
ctx,
ChatCompletionRequest{
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
},
)

if err != nil {
t.Errorf("CreateChatCompletion (without name) returned error: %v", err)
}

_, err = c.CreateChatCompletion(
ctx,
ChatCompletionRequest{
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Name: "John_Doe",
Content: "Hello!",
},
},
},
)

if err != nil {
t.Errorf("CreateChatCompletion (with name) returned error: %v", err)
}

stream, err := c.CreateCompletionStream(ctx, CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: GPT3Ada,
MaxTokens: 5,
Stream: true,
})
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
defer stream.Close()

counter := 0
for {
_, err = stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
t.Errorf("Stream error: %v", err)
} else {
counter++
}
}
if counter == 0 {
t.Error("Stream did not return any responses")
}
}

func TestAPIError(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
}

var err error
c := NewClient(apiToken + "_invalid")
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err == nil {
t.Fatal("ListEngines did not fail")
}

var apiErr *APIError
if !errors.As(err, &apiErr) {
t.Fatalf("Error is not an APIError: %+v", err)
}

if apiErr.StatusCode != 401 {
t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode)
}
if *apiErr.Code != "invalid_api_key" {
t.Fatalf("Unexpected API error code: %s", *apiErr.Code)
}
if apiErr.Error() == "" {
t.Fatal("Empty error message occured")
}
}

func TestRequestError(t *testing.T) {
var err error

config := DefaultConfig("dummy")
config.BaseURL = "https://httpbin.org/status/418?"
c := NewClientWithConfig(config)
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err == nil {
t.Fatal("ListEngines request did not fail")
}

var reqErr *RequestError
if !errors.As(err, &reqErr) {
t.Fatalf("Error is not a RequestError: %+v", err)
}

if reqErr.StatusCode != 418 {
t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode)
}

if reqErr.Unwrap() == nil {
t.Fatalf("Empty request error occured")
}
}

// numTokens Returns the number of GPT-3 encoded tokens in the given text.
// This function approximates based on the rule of thumb stated by OpenAI:
// https://beta.openai.com/tokenizer
//
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
func numTokens(s string) int {
return int(float32(len(s)) / 4)
}

+ 1
- 1
go-gpt3/audio.go View File

@@ -1,4 +1,4 @@
package gogpt
package openai

import (
"bytes"


+ 143
- 0
go-gpt3/audio_test.go View File

@@ -0,0 +1,143 @@
package openai_test

import (
"bytes"
"errors"
"io"
"mime"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"testing"
)

// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
func TestAudio(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)

testcases := []struct {
name string
createFn func(context.Context, AudioRequest) (AudioResponse, error)
}{
{
"transcribe",
client.CreateTranscription,
},
{
"translate",
client.CreateTranslation,
},
}

ctx := context.Background()

dir, cleanup := createTestDirectory(t)
defer cleanup()

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
path := filepath.Join(dir, "fake.mp3")
createTestFile(t, path)

req := AudioRequest{
FilePath: path,
Model: "whisper-3",
}
_, err = tc.createFn(ctx, req)
if err != nil {
t.Fatalf("audio API error: %v", err)
}
})
}
}

// createTestFile creates a fake file with "hello" as the content.
func createTestFile(t *testing.T, path string) {
file, err := os.Create(path)
if err != nil {
t.Fatalf("failed to create file %v", err)
}
if _, err = file.WriteString("hello"); err != nil {
t.Fatalf("failed to write to file %v", err)
}
file.Close()
}

// createTestDirectory creates a temporary folder which will be deleted when cleanup is called.
func createTestDirectory(t *testing.T) (path string, cleanup func()) {
t.Helper()

path, err := os.MkdirTemp(os.TempDir(), "")
if err != nil {
t.Fatal(err)
}

return path, func() { os.RemoveAll(path) }
}

// handleAudioEndpoint Handles the completion endpoint by the test server.
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
var err error

// audio endpoints only accept POST requests
if r.Method != "POST" {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}

mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
http.Error(w, "failed to parse media type", http.StatusBadRequest)
return
}

if !strings.HasPrefix(mediaType, "multipart") {
http.Error(w, "request is not multipart", http.StatusBadRequest)
}

boundary, ok := params["boundary"]
if !ok {
http.Error(w, "no boundary in params", http.StatusBadRequest)
return
}

fileData := &bytes.Buffer{}
mr := multipart.NewReader(r.Body, boundary)
part, err := mr.NextPart()
if err != nil && errors.Is(err, io.EOF) {
http.Error(w, "error accessing file", http.StatusBadRequest)
return
}
if _, err = io.Copy(fileData, part); err != nil {
http.Error(w, "failed to copy file", http.StatusInternalServerError)
return
}

if len(fileData.Bytes()) == 0 {
w.WriteHeader(http.StatusInternalServerError)
http.Error(w, "received empty file data", http.StatusBadRequest)
return
}

if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
http.Error(w, "failed to write body", http.StatusInternalServerError)
return
}
}

+ 19
- 73
go-gpt3/chat.go View File

@@ -1,15 +1,18 @@
package gogpt
package openai

import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
)

// Chat message role defined by the OpenAI API.
const (
ChatMessageRoleSystem = "system"
ChatMessageRoleUser = "user"
ChatMessageRoleAssistant = "assistant"
)

var (
ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported")
)
@@ -17,6 +20,12 @@ var (
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content"`

// This property isn't in the official documentation, but it's in
// the documentation for the official library for python:
// - https://github.com/openai/openai-python/blob/main/chatml.md
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
Name string `json:"name,omitempty"`
}

// ChatCompletionRequest represents a request structure for chat completion API.
@@ -41,24 +50,6 @@ type ChatCompletionChoice struct {
FinishReason string `json:"finish_reason"`
}

type ChatCompletionChoiceForStream struct {
Index int `json:"index"`
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason string `json:"finish_reason"`
}

// ChatCompletionResponseForStream represents a response structure for chat completion API.
type ChatCompletionResponseForStream struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionChoiceForStream `json:"choices"`
Usage Usage `json:"usage"`
}

// ChatCompletionResponse represents a response structure for chat completion API.
type ChatCompletionResponse struct {
ID string `json:"id"`
@@ -69,25 +60,21 @@ type ChatCompletionResponse struct {
Usage Usage `json:"usage"`
}

// CreateChatCompletion — API call to Creates a completion for the chat message.
// CreateChatCompletion — API call to Create a completion for the chat message.
func (c *Client) CreateChatCompletion(
ctx context.Context,
request ChatCompletionRequest,
) (response ChatCompletionResponse, err error) {
model := request.Model
if model != GPT3Dot5Turbo0301 && model != GPT3Dot5Turbo {
switch model {
case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K:
default:
err = ErrChatCompletionInvalidModel
return
}

var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {
return
}

urlSuffix := "/chat/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}
@@ -95,44 +82,3 @@ func (c *Client) CreateChatCompletion(
err = c.sendRequest(req, &response)
return
}

// CreateChatCompletionStream — API call to create a completion w/ streaming
func (c *Client) CreateChatCompletionStream(
ctx context.Context,
request ChatCompletionRequest,
) (stream *CompletionStream, err error) {
model := request.Model
if model != GPT3Dot5Turbo0301 && model != GPT3Dot5Turbo {
err = ErrChatCompletionInvalidModel
return
}
request.Stream = true
reqBytes, err := json.Marshal(request)
if err != nil {
return
}

urlSuffix := "/chat/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
if err != nil {
return
}

resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
if err != nil {
return
}

stream = &CompletionStream{
emptyMessagesLimit: c.config.EmptyMessagesLimit,

reader: bufio.NewReader(resp.Body),
response: resp,
}
return
}

+ 106
- 0
go-gpt3/chat_stream.go View File

@@ -0,0 +1,106 @@
package openai

import (
"bufio"
"bytes"
"context"
"encoding/json"
"io"
"net/http"
)

type ChatCompletionStreamChoiceDelta struct {
Content string `json:"content"`
}

type ChatCompletionStreamChoice struct {
Index int `json:"index"`
Delta ChatCompletionStreamChoiceDelta `json:"delta"`
FinishReason string `json:"finish_reason"`
}

type ChatCompletionStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionStreamChoice `json:"choices"`
}

// ChatCompletionStream
// Note: Perhaps it is more elegant to abstract Stream using generics.
type ChatCompletionStream struct {
emptyMessagesLimit uint
isFinished bool

reader *bufio.Reader
response *http.Response
}

func (stream *ChatCompletionStream) Recv() (response ChatCompletionStreamResponse, err error) {
if stream.isFinished {
err = io.EOF
return
}

var emptyMessagesCount uint

waitForData:
line, err := stream.reader.ReadBytes('\n')
if err != nil {
return
}

var headerData = []byte("data: ")
line = bytes.TrimSpace(line)
if !bytes.HasPrefix(line, headerData) {
emptyMessagesCount++
if emptyMessagesCount > stream.emptyMessagesLimit {
err = ErrTooManyEmptyStreamMessages
return
}

goto waitForData
}

line = bytes.TrimPrefix(line, headerData)
if string(line) == "[DONE]" {
stream.isFinished = true
err = io.EOF
return
}

err = json.Unmarshal(line, &response)
return
}

func (stream *ChatCompletionStream) Close() {
stream.response.Body.Close()
}

// CreateChatCompletionStream — API call to create a chat completion w/ streaming
// support. It sets whether to stream back partial progress. If set, tokens will be
// sent as data-only server-sent events as they become available, with the
// stream terminated by a data: [DONE] message.
func (c *Client) CreateChatCompletionStream(
ctx context.Context,
request ChatCompletionRequest,
) (stream *ChatCompletionStream, err error) {
request.Stream = true
req, err := c.newStreamRequest(ctx, "POST", "/chat/completions", request)
if err != nil {
return
}

resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
if err != nil {
return
}

stream = &ChatCompletionStream{
emptyMessagesLimit: c.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
}
return
}

+ 153
- 0
go-gpt3/chat_stream_test.go View File

@@ -0,0 +1,153 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
)

func TestCreateChatCompletionStream(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: done\n")...)
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)

_, err := w.Write(dataBytes)
if err != nil {
t.Errorf("Write error: %s", err)
}
}))
defer server.Close()

// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
}

client := NewClientWithConfig(config)
ctx := context.Background()

request := ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
}

stream, err := client.CreateChatCompletionStream(ctx, request)
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
defer stream.Close()

expectedResponses := []ChatCompletionStreamResponse{
{
ID: "1",
Object: "completion",
Created: 1598069254,
Model: GPT3Dot5Turbo,
Choices: []ChatCompletionStreamChoice{
{
Delta: ChatCompletionStreamChoiceDelta{
Content: "response1",
},
FinishReason: "max_tokens",
},
},
},
{
ID: "2",
Object: "completion",
Created: 1598069255,
Model: GPT3Dot5Turbo,
Choices: []ChatCompletionStreamChoice{
{
Delta: ChatCompletionStreamChoiceDelta{
Content: "response2",
},
FinishReason: "max_tokens",
},
},
},
}

for ix, expectedResponse := range expectedResponses {
b, _ := json.Marshal(expectedResponse)
t.Logf("%d: %s", ix, string(b))

receivedResponse, streamErr := stream.Recv()
if streamErr != nil {
t.Errorf("stream.Recv() failed: %v", streamErr)
}
if !compareChatResponses(expectedResponse, receivedResponse) {
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
}
}

_, streamErr := stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
}

_, streamErr = stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
}
}

// Helper funcs.
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
return false
}
if len(r1.Choices) != len(r2.Choices) {
return false
}
for i := range r1.Choices {
if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) {
return false
}
}
return true
}

func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool {
if c1.Index != c2.Index {
return false
}
if c1.Delta.Content != c2.Delta.Content {
return false
}
if c1.FinishReason != c2.FinishReason {
return false
}
return true
}

+ 132
- 0
go-gpt3/chat_test.go View File

@@ -0,0 +1,132 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"testing"
"time"
)

func TestChatCompletionsWrongModel(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

req := ChatCompletionRequest{
MaxTokens: 5,
Model: "ada",
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
}
_, err := client.CreateChatCompletion(ctx, req)
if !errors.Is(err, ErrChatCompletionInvalidModel) {
t.Fatalf("CreateChatCompletion should return wrong model error, but returned: %v", err)
}
}

// TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestChatCompletions(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

req := ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
}
_, err = client.CreateChatCompletion(ctx, req)
if err != nil {
t.Fatalf("CreateChatCompletion error: %v", err)
}
}

// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var completionReq ChatCompletionRequest
if completionReq, err = getChatCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := ChatCompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object",
Created: time.Now().Unix(),
// would be nice to validate Model during testing, but
// this may not be possible with how much upkeep
// would be required / wouldn't make much sense
Model: completionReq.Model,
}
// create completions
for i := 0; i < completionReq.N; i++ {
// generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens)

res.Choices = append(res.Choices, ChatCompletionChoice{
Message: ChatCompletionMessage{
Role: ChatMessageRoleAssistant,
Content: completionStr,
},
Index: i,
})
}
inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N
completionTokens := completionReq.MaxTokens * completionReq.N
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}

// getChatCompletionBody Returns the body of the request to create a completion.
func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) {
completion := ChatCompletionRequest{}
// read the request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
return ChatCompletionRequest{}, err
}
err = json.Unmarshal(reqBody, &completion)
if err != nil {
return ChatCompletionRequest{}, err
}
return completion, nil
}

+ 1
- 1
go-gpt3/common.go View File

@@ -1,5 +1,5 @@
// common.go defines common types used throughout the OpenAI API.
package gogpt
package openai

// Usage Represents the total token usage per request to OpenAI.
type Usage struct {


+ 13
- 7
go-gpt3/completion.go View File

@@ -1,17 +1,24 @@
package gogpt
package openai

import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
)

var (
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
)

// GPT3 Defines the models provided by OpenAI to use when generating
// completions from OpenAI.
// GPT3 Models are designed for text-based tasks. For code-specific
// tasks, please refer to the Codex series of models.
const (
GPT432K0314 = "gpt-4-32k-0314"
GPT432K = "gpt-4-32k"
GPT40314 = "gpt-4-0314"
GPT4 = "gpt-4"
GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301"
GPT3Dot5Turbo = "gpt-3.5-turbo"
GPT3TextDavinci003 = "text-davinci-003"
@@ -92,14 +99,13 @@ func (c *Client) CreateCompletion(
ctx context.Context,
request CompletionRequest,
) (response CompletionResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {
if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo {
err = ErrCompletionUnsupportedModel
return
}

urlSuffix := "/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}


+ 121
- 0
go-gpt3/completion_test.go View File

@@ -0,0 +1,121 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"testing"
"time"
)

func TestCompletionsWrongModel(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config)

_, err := client.CreateCompletion(
context.Background(),
CompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
},
)
if !errors.Is(err, ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err)
}
}

// TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestCompletions(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

req := CompletionRequest{
MaxTokens: 5,
Model: "ada",
}
req.Prompt = "Lorem ipsum"
_, err = client.CreateCompletion(ctx, req)
if err != nil {
t.Fatalf("CreateCompletion error: %v", err)
}
}

// handleCompletionEndpoint Handles the completion endpoint by the test server.
func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var completionReq CompletionRequest
if completionReq, err = getCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := CompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object",
Created: time.Now().Unix(),
// would be nice to validate Model during testing, but
// this may not be possible with how much upkeep
// would be required / wouldn't make much sense
Model: completionReq.Model,
}
// create completions
for i := 0; i < completionReq.N; i++ {
// generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens)
if completionReq.Echo {
completionStr = completionReq.Prompt + completionStr
}
res.Choices = append(res.Choices, CompletionChoice{
Text: completionStr,
Index: i,
})
}
inputTokens := numTokens(completionReq.Prompt) * completionReq.N
completionTokens := completionReq.MaxTokens * completionReq.N
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}

// getCompletionBody Returns the body of the request to create a completion.
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
completion := CompletionRequest{}
// read the request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
return CompletionRequest{}, err
}
err = json.Unmarshal(reqBody, &completion)
if err != nil {
return CompletionRequest{}, err
}
return completion, nil
}

+ 2
- 2
go-gpt3/config.go View File

@@ -1,11 +1,11 @@
package gogpt
package openai

import (
"net/http"
)

const (
apiURLv1 = "http://chatgpt.zhiyinos.cn/"
apiURLv1 = "https://api.openai.com/v1"
defaultEmptyMessagesLimit uint = 300
)



+ 2
- 10
go-gpt3/edits.go View File

@@ -1,9 +1,7 @@
package gogpt
package openai

import (
"bytes"
"context"
"encoding/json"
"net/http"
)

@@ -33,13 +31,7 @@ type EditsResponse struct {

// Perform an API call to the Edits endpoint.
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {
return
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/edits"), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request)
if err != nil {
return
}


+ 104
- 0
go-gpt3/edits_test.go View File

@@ -0,0 +1,104 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
"time"
)

// TestEdits Tests the edits endpoint of the API using the mocked server.
func TestEdits(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/edits", handleEditEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

// create an edit request
model := "ada"
editReq := EditsRequest{
Model: &model,
Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " +
"sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" +
" ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" +
" ex ea commodo consequat. Duis aute irure dolor in reprehe",
Instruction: "test instruction",
N: 3,
}
response, err := client.Edits(ctx, editReq)
if err != nil {
t.Fatalf("Edits error: %v", err)
}
if len(response.Choices) != editReq.N {
t.Fatalf("edits does not properly return the correct number of choices")
}
}

// handleEditEndpoint Handles the edit endpoint by the test server.
func handleEditEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// edits only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var editReq EditsRequest
editReq, err = getEditBody(r)
if err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
// create a response
res := EditsResponse{
Object: "test-object",
Created: time.Now().Unix(),
}
// edit and calculate token usage
editString := "edited by mocked OpenAI server :)"
inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N
completionTokens := int(float32(len(editString))/4) * editReq.N
for i := 0; i < editReq.N; i++ {
// instruction will be hidden and only seen by OpenAI
res.Choices = append(res.Choices, EditsChoice{
Text: editReq.Input + editString,
Index: i,
})
}
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
fmt.Fprint(w, string(resBytes))
}

// getEditBody Returns the body of the request to create an edit.
func getEditBody(r *http.Request) (EditsRequest, error) {
edit := EditsRequest{}
// read the request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
return EditsRequest{}, err
}
err = json.Unmarshal(reqBody, &edit)
if err != nil {
return EditsRequest{}, err
}
return edit, nil
}

+ 3
- 12
go-gpt3/embeddings.go View File

@@ -1,9 +1,7 @@
package gogpt
package openai

import (
"bytes"
"context"
"encoding/json"
"net/http"
)

@@ -103,7 +101,7 @@ var stringToEnum = map[string]EmbeddingModel{
// then their vector representations should also be similar.
type Embedding struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}

@@ -134,14 +132,7 @@ type EmbeddingRequest struct {
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
// https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {
return
}

urlSuffix := "/embeddings"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request)
if err != nil {
return
}


+ 48
- 0
go-gpt3/embeddings_test.go View File

@@ -0,0 +1,48 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"

"bytes"
"encoding/json"
"testing"
)

func TestEmbedding(t *testing.T) {
embeddedModels := []EmbeddingModel{
AdaSimilarity,
BabbageSimilarity,
CurieSimilarity,
DavinciSimilarity,
AdaSearchDocument,
AdaSearchQuery,
BabbageSearchDocument,
BabbageSearchQuery,
CurieSearchDocument,
CurieSearchQuery,
DavinciSearchDocument,
DavinciSearchQuery,
AdaCodeSearchCode,
AdaCodeSearchText,
BabbageCodeSearchCode,
BabbageCodeSearchText,
}
for _, model := range embeddedModels {
embeddingReq := EmbeddingRequest{
Input: []string{
"The food was delicious and the waiter",
"Other examples of embedding request",
},
Model: model,
}
// marshal embeddingReq to JSON and confirm that the model field equals
// the AdaSearchQuery type
marshaled, err := json.Marshal(embeddingReq)
if err != nil {
t.Fatalf("Could not marshal embedding request: %v", err)
}
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
}
}

+ 3
- 3
go-gpt3/engines.go View File

@@ -1,4 +1,4 @@
package gogpt
package openai

import (
"context"
@@ -22,7 +22,7 @@ type EnginesList struct {
// ListEngines Lists the currently available engines, and provides basic
// information about each option such as the owner and availability.
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/engines"), nil)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
if err != nil {
return
}
@@ -38,7 +38,7 @@ func (c *Client) GetEngine(
engineID string,
) (engine Engine, err error) {
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}


+ 1
- 1
go-gpt3/error.go View File

@@ -1,4 +1,4 @@
package gogpt
package openai

import "fmt"



+ 4
- 4
go-gpt3/files.go View File

@@ -1,4 +1,4 @@
package gogpt
package openai

import (
"bytes"
@@ -112,7 +112,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File

// DeleteFile deletes an existing file.
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
if err != nil {
return
}
@@ -124,7 +124,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
// ListFiles Lists the currently available files,
// and provides basic information about each file such as the file name and purpose.
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/files"), nil)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil)
if err != nil {
return
}
@@ -137,7 +137,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
// such as the file name and purpose.
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
urlSuffix := fmt.Sprintf("/files/%s", fileID)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}


+ 81
- 0
go-gpt3/files_test.go View File

@@ -0,0 +1,81 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"testing"
"time"
)

func TestFileUpload(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/files", handleCreateFile)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

req := FileRequest{
FileName: "test.go",
FilePath: "api.go",
Purpose: "fine-tune",
}
_, err = client.CreateFile(ctx, req)
if err != nil {
t.Fatalf("CreateFile error: %v", err)
}
}

// handleCreateFile Handles the images endpoint by the test server.
func handleCreateFile(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// edits only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
err = r.ParseMultipartForm(1024 * 1024 * 1024)
if err != nil {
http.Error(w, "file is more than 1GB", http.StatusInternalServerError)
return
}

values := r.Form
var purpose string
for key, value := range values {
if key == "purpose" {
purpose = value[0]
}
}
file, header, err := r.FormFile("file")
if err != nil {
return
}
defer file.Close()

var fileReq = File{
Bytes: int(header.Size),
ID: strconv.Itoa(int(time.Now().Unix())),
FileName: header.Filename,
Purpose: purpose,
CreatedAt: time.Now().Unix(),
Object: "test-objecct",
Owner: "test-owner",
}

resBytes, _ = json.Marshal(fileReq)
fmt.Fprint(w, string(resBytes))
}

+ 130
- 0
go-gpt3/fine_tunes.go View File

@@ -0,0 +1,130 @@
package openai

import (
"context"
"fmt"
"net/http"
)

type FineTuneRequest struct {
TrainingFile string `json:"training_file"`
ValidationFile string `json:"validation_file,omitempty"`
Model string `json:"model,omitempty"`
Epochs int `json:"n_epochs,omitempty"`
BatchSize int `json:"batch_size,omitempty"`
LearningRateMultiplier float32 `json:"learning_rate_multiplier,omitempty"`
PromptLossRate float32 `json:"prompt_loss_rate,omitempty"`
ComputeClassificationMetrics bool `json:"compute_classification_metrics,omitempty"`
ClassificationClasses int `json:"classification_n_classes,omitempty"`
ClassificationPositiveClass string `json:"classification_positive_class,omitempty"`
ClassificationBetas []float32 `json:"classification_betas,omitempty"`
Suffix string `json:"suffix,omitempty"`
}

type FineTune struct {
ID string `json:"id"`
Object string `json:"object"`
Model string `json:"model"`
CreatedAt int64 `json:"created_at"`
FineTuneEventList []FineTuneEvent `json:"events,omitempty"`
FineTunedModel string `json:"fine_tuned_model"`
HyperParams FineTuneHyperParams `json:"hyperparams"`
OrganizationID string `json:"organization_id"`
ResultFiles []File `json:"result_files"`
Status string `json:"status"`
ValidationFiles []File `json:"validation_files"`
TrainingFiles []File `json:"training_files"`
UpdatedAt int64 `json:"updated_at"`
}

type FineTuneEvent struct {
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
Level string `json:"level"`
Message string `json:"message"`
}

type FineTuneHyperParams struct {
BatchSize int `json:"batch_size"`
LearningRateMultiplier float64 `json:"learning_rate_multiplier"`
Epochs int `json:"n_epochs"`
PromptLossWeight float64 `json:"prompt_loss_weight"`
}

type FineTuneList struct {
Object string `json:"object"`
Data []FineTune `json:"data"`
}
type FineTuneEventList struct {
Object string `json:"object"`
Data []FineTuneEvent `json:"data"`
}

type FineTuneDeleteResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Deleted bool `json:"deleted"`
}

func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
urlSuffix := "/fine-tunes"
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

// CancelFineTune cancel a fine-tune job.
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
if err != nil {
return
}

err = c.sendRequest(req, &response)
return
}

+ 101
- 0
go-gpt3/fine_tunes_test.go View File

@@ -0,0 +1,101 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"fmt"
"net/http"
"testing"
)

const testFineTuneID = "fine-tune-id"

// TestFineTunes Tests the fine tunes endpoint of the API using the mocked server.
func TestFineTunes(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler(
"/v1/fine-tunes",
func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
if r.Method == http.MethodGet {
resBytes, _ = json.Marshal(FineTuneList{})
} else {
resBytes, _ = json.Marshal(FineTune{})
}
fmt.Fprintln(w, string(resBytes))
},
)

server.RegisterHandler(
"/v1/fine-tunes/"+testFineTuneID+"/cancel",
func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(FineTune{})
fmt.Fprintln(w, string(resBytes))
},
)

server.RegisterHandler(
"/v1/fine-tunes/"+testFineTuneID,
func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
if r.Method == http.MethodDelete {
resBytes, _ = json.Marshal(FineTuneDeleteResponse{})
} else {
resBytes, _ = json.Marshal(FineTune{})
}
fmt.Fprintln(w, string(resBytes))
},
)

server.RegisterHandler(
"/v1/fine-tunes/"+testFineTuneID+"/events",
func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(FineTuneEventList{})
fmt.Fprintln(w, string(resBytes))
},
)

// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

_, err = client.ListFineTunes(ctx)
if err != nil {
t.Fatalf("ListFineTunes error: %v", err)
}

_, err = client.CreateFineTune(ctx, FineTuneRequest{})
if err != nil {
t.Fatalf("CreateFineTune error: %v", err)
}

_, err = client.CancelFineTune(ctx, testFineTuneID)
if err != nil {
t.Fatalf("CancelFineTune error: %v", err)
}

_, err = client.GetFineTune(ctx, testFineTuneID)
if err != nil {
t.Fatalf("GetFineTune error: %v", err)
}

_, err = client.DeleteFineTune(ctx, testFineTuneID)
if err != nil {
t.Fatalf("DeleteFineTune error: %v", err)
}

_, err = client.ListFineTuneEvents(ctx, testFineTuneID)
if err != nil {
t.Fatalf("ListFineTuneEvents error: %v", err)
}
}

+ 3
- 0
go-gpt3/go.mod View File

@@ -0,0 +1,3 @@
module github.com/sashabaranov/go-openai

go 1.18

+ 53
- 14
go-gpt3/image.go View File

@@ -1,9 +1,8 @@
package gogpt
package openai

import (
"bytes"
"context"
"encoding/json"
"io"
"mime/multipart"
"net/http"
@@ -46,14 +45,8 @@ type ImageResponseDataInner struct {

// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {
return
}

urlSuffix := "/images/generations"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}
@@ -86,20 +79,65 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
return
}

// mask
mask, err := writer.CreateFormFile("mask", request.Mask.Name())
// mask, it is optional
if request.Mask != nil {
mask, err2 := writer.CreateFormFile("mask", request.Mask.Name())
if err2 != nil {
return
}
_, err = io.Copy(mask, request.Mask)
if err != nil {
return
}
}

err = writer.WriteField("prompt", request.Prompt)
if err != nil {
return
}
_, err = io.Copy(mask, request.Mask)
err = writer.WriteField("n", strconv.Itoa(request.N))
if err != nil {
return
}
err = writer.WriteField("size", request.Size)
if err != nil {
return
}
writer.Close()
urlSuffix := "/images/edits"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
if err != nil {
return
}

err = writer.WriteField("prompt", request.Prompt)
req.Header.Set("Content-Type", writer.FormDataContentType())
err = c.sendRequest(req, &response)
return
}

// ImageVariRequest represents the request structure for the image API.
type ImageVariRequest struct {
Image *os.File `json:"image,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
}

// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
// Use abbreviations(vari for variation) because ci-lint has a single-line length limit ...
func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) (response ImageResponse, err error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)

// image
image, err := writer.CreateFormFile("image", request.Image.Name())
if err != nil {
return
}
_, err = io.Copy(image, request.Image)
if err != nil {
return
}

err = writer.WriteField("n", strconv.Itoa(request.N))
if err != nil {
return
@@ -109,7 +147,8 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
return
}
writer.Close()
urlSuffix := "/images/edits"
//https://platform.openai.com/docs/api-reference/images/create-variation
urlSuffix := "/images/variations"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
if err != nil {
return


+ 268
- 0
go-gpt3/image_test.go View File

@@ -0,0 +1,268 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"testing"
"time"
)

func TestImages(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

req := ImageRequest{}
req.Prompt = "Lorem ipsum"
_, err = client.CreateImage(ctx, req)
if err != nil {
t.Fatalf("CreateImage error: %v", err)
}
}

// handleImageEndpoint Handles the images endpoint by the test server.
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// imagess only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var imageReq ImageRequest
if imageReq, err = getImageBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := ImageResponse{
Created: time.Now().Unix(),
}
for i := 0; i < imageReq.N; i++ {
imageData := ImageResponseDataInner{}
switch imageReq.ResponseFormat {
case CreateImageResponseFormatURL, "":
imageData.URL = "https://example.com/image.png"
case CreateImageResponseFormatB64JSON:
// This decodes to "{}" in base64.
imageData.B64JSON = "e30K"
default:
http.Error(w, "invalid response format", http.StatusBadRequest)
return
}
res.Data = append(res.Data, imageData)
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}

// getImageBody Returns the body of the request to create a image.
func getImageBody(r *http.Request) (ImageRequest, error) {
image := ImageRequest{}
// read the request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
return ImageRequest{}, err
}
err = json.Unmarshal(reqBody, &image)
if err != nil {
return ImageRequest{}, err
}
return image, nil
}

func TestImageEdit(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

origin, err := os.Create("image.png")
if err != nil {
t.Error("open origin file error")
return
}

mask, err := os.Create("mask.png")
if err != nil {
t.Error("open mask file error")
return
}

defer func() {
mask.Close()
origin.Close()
os.Remove("mask.png")
os.Remove("image.png")
}()

req := ImageEditRequest{
Image: origin,
Mask: mask,
Prompt: "There is a turtle in the pool",
N: 3,
Size: CreateImageSize1024x1024,
}
_, err = client.CreateEditImage(ctx, req)
if err != nil {
t.Fatalf("CreateImage error: %v", err)
}
}

func TestImageEditWithoutMask(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

origin, err := os.Create("image.png")
if err != nil {
t.Error("open origin file error")
return
}

defer func() {
origin.Close()
os.Remove("image.png")
}()

req := ImageEditRequest{
Image: origin,
Prompt: "There is a turtle in the pool",
N: 3,
Size: CreateImageSize1024x1024,
}
_, err = client.CreateEditImage(ctx, req)
if err != nil {
t.Fatalf("CreateImage error: %v", err)
}
}

// handleEditImageEndpoint Handles the images endpoint by the test server.
func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
var resBytes []byte

// imagess only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}

responses := ImageResponse{
Created: time.Now().Unix(),
Data: []ImageResponseDataInner{
{
URL: "test-url1",
B64JSON: "",
},
{
URL: "test-url2",
B64JSON: "",
},
{
URL: "test-url3",
B64JSON: "",
},
},
}

resBytes, _ = json.Marshal(responses)
fmt.Fprintln(w, string(resBytes))
}

func TestImageVariation(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

origin, err := os.Create("image.png")
if err != nil {
t.Error("open origin file error")
return
}

defer func() {
origin.Close()
os.Remove("image.png")
}()

req := ImageVariRequest{
Image: origin,
N: 3,
Size: CreateImageSize1024x1024,
}
_, err = client.CreateVariImage(ctx, req)
if err != nil {
t.Fatalf("CreateImage error: %v", err)
}
}

// handleVariateImageEndpoint Handles the images endpoint by the test server.
func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
var resBytes []byte

// imagess only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}

responses := ImageResponse{
Created: time.Now().Unix(),
Data: []ImageResponseDataInner{
{
URL: "test-url1",
B64JSON: "",
},
{
URL: "test-url2",
B64JSON: "",
},
{
URL: "test-url3",
B64JSON: "",
},
},
}

resBytes, _ = json.Marshal(responses)
fmt.Fprintln(w, string(resBytes))
}

+ 15
- 0
go-gpt3/marshaller.go View File

@@ -0,0 +1,15 @@
package openai

import (
"encoding/json"
)

type marshaller interface {
marshal(value any) ([]byte, error)
}

type jsonMarshaller struct{}

func (jm *jsonMarshaller) marshal(value any) ([]byte, error) {
return json.Marshal(value)
}

+ 4
- 4
go-gpt3/models.go View File

@@ -1,4 +1,4 @@
package gogpt
package openai

import (
"context"
@@ -7,7 +7,7 @@ import (

// Model struct represents an OpenAPI model.
type Model struct {
CreatedAt int64 `json:"created_at"`
CreatedAt int64 `json:"created"`
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by"`
@@ -18,7 +18,7 @@ type Model struct {

// Permission struct represents an OpenAPI permission.
type Permission struct {
CreatedAt int64 `json:"created_at"`
CreatedAt int64 `json:"created"`
ID string `json:"id"`
Object string `json:"object"`
AllowCreateEngine bool `json:"allow_create_engine"`
@@ -40,7 +40,7 @@ type ModelsList struct {
// ListModels Lists the currently available models,
// and provides basic information about each model such as the model id and parent.
func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/models"), nil)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/models"), nil)
if err != nil {
return
}


+ 39
- 0
go-gpt3/models_test.go View File

@@ -0,0 +1,39 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"fmt"
"net/http"
"testing"
)

// TestListModels Tests the models endpoint of the API using the mocked server.
func TestListModels(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/models", handleModelsEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

_, err = client.ListModels(ctx)
if err != nil {
t.Fatalf("ListModels error: %v", err)
}
}

// handleModelsEndpoint Handles the models endpoint by the test server.
func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(ModelsList{})
fmt.Fprintln(w, string(resBytes))
}

+ 2
- 10
go-gpt3/moderation.go View File

@@ -1,9 +1,7 @@
package gogpt
package openai

import (
"bytes"
"context"
"encoding/json"
"net/http"
)

@@ -52,13 +50,7 @@ type ModerationResponse struct {
// Moderations — perform a moderation api call over a string.
// Input can be an array or slice but a string will reduce the complexity.
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {
return
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/moderations"), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request)
if err != nil {
return
}


+ 102
- 0
go-gpt3/moderation_test.go View File

@@ -0,0 +1,102 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"testing"
"time"
)

// TestModeration Tests the moderations endpoint of the API using the mocked server.
func TestModerations(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

// create an edit request
model := "text-moderation-stable"
moderationReq := ModerationRequest{
Model: &model,
Input: "I want to kill them.",
}
_, err = client.Moderations(ctx, moderationReq)
if err != nil {
t.Fatalf("Moderation error: %v", err)
}
}

// handleModerationEndpoint Handles the moderation endpoint by the test server.
func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var moderationReq ModerationRequest
if moderationReq, err = getModerationBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}

resCat := ResultCategories{}
resCatScore := ResultCategoryScores{}
switch {
case strings.Contains(moderationReq.Input, "kill"):
resCat = ResultCategories{Violence: true}
resCatScore = ResultCategoryScores{Violence: 1}
case strings.Contains(moderationReq.Input, "hate"):
resCat = ResultCategories{Hate: true}
resCatScore = ResultCategoryScores{Hate: 1}
case strings.Contains(moderationReq.Input, "suicide"):
resCat = ResultCategories{SelfHarm: true}
resCatScore = ResultCategoryScores{SelfHarm: 1}
case strings.Contains(moderationReq.Input, "porn"):
resCat = ResultCategories{Sexual: true}
resCatScore = ResultCategoryScores{Sexual: 1}
}

result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true}

res := ModerationResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Model: *moderationReq.Model,
}
res.Results = append(res.Results, result)

resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}

// getModerationBody Returns the body of the request to do a moderation.
func getModerationBody(r *http.Request) (ModerationRequest, error) {
moderation := ModerationRequest{}
// read the request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
return ModerationRequest{}, err
}
err = json.Unmarshal(reqBody, &moderation)
if err != nil {
return ModerationRequest{}, err
}
return moderation, nil
}

+ 40
- 0
go-gpt3/request_builder.go View File

@@ -0,0 +1,40 @@
package openai

import (
"bytes"
"context"
"net/http"
)

type requestBuilder interface {
build(ctx context.Context, method, url string, request any) (*http.Request, error)
}

type httpRequestBuilder struct {
marshaller marshaller
}

func newRequestBuilder() *httpRequestBuilder {
return &httpRequestBuilder{
marshaller: &jsonMarshaller{},
}
}

func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) {
if request == nil {
return http.NewRequestWithContext(ctx, method, url, nil)
}

var reqBytes []byte
reqBytes, err := b.marshaller.marshal(request)
if err != nil {
return nil, err
}

return http.NewRequestWithContext(
ctx,
method,
url,
bytes.NewBuffer(reqBytes),
)
}

+ 148
- 0
go-gpt3/request_builder_test.go View File

@@ -0,0 +1,148 @@
package openai //nolint:testpackage // testing private field

import (
"github.com/sashabaranov/go-openai/internal/test"

"context"
"errors"
"net/http"
"testing"
)

var (
errTestMarshallerFailed = errors.New("test marshaller failed")
errTestRequestBuilderFailed = errors.New("test request builder failed")
)

type (
failingRequestBuilder struct{}
failingMarshaller struct{}
)

func (*failingMarshaller) marshal(value any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed
}

func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) {
return nil, errTestRequestBuilderFailed
}

func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
builder := httpRequestBuilder{
marshaller: &failingMarshaller{},
}

_, err := builder.build(context.Background(), "", "", struct{}{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
}

func TestClientReturnsRequestBuilderErrors(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.requestBuilder = &failingRequestBuilder{}

ctx := context.Background()

_, err = client.CreateCompletion(ctx, CompletionRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateFineTune(ctx, FineTuneRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListFineTunes(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CancelFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.GetFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.DeleteFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListFineTuneEvents(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.Moderations(ctx, ModerationRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.Edits(ctx, EditsRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateImage(ctx, ImageRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

err = client.DeleteFile(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.GetFile(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListFiles(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListEngines(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.GetEngine(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListModels(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}

+ 2
- 51
go-gpt3/stream.go View File

@@ -1,4 +1,4 @@
package gogpt
package openai

import (
"bufio"
@@ -6,7 +6,6 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
)
@@ -60,43 +59,6 @@ waitForData:
return
}

func (stream *CompletionStream) ChatRecv() (response ChatCompletionResponseForStream, err error) {
if stream.isFinished {
err = io.EOF
return
}

var emptyMessagesCount uint

waitForData:
line, err := stream.reader.ReadBytes('\n')
if err != nil {
return
}

var headerData = []byte("data: ")
line = bytes.TrimSpace(line)
if !bytes.HasPrefix(line, headerData) {
emptyMessagesCount++
if emptyMessagesCount > stream.emptyMessagesLimit {
err = ErrTooManyEmptyStreamMessages
return
}

goto waitForData
}

line = bytes.TrimPrefix(line, headerData)
if string(line) == "[DONE]" {
stream.isFinished = true
err = io.EOF
return
}

err = json.Unmarshal(line, &response)
return
}

func (stream *CompletionStream) Close() {
stream.response.Body.Close()
}
@@ -110,18 +72,7 @@ func (c *Client) CreateCompletionStream(
request CompletionRequest,
) (stream *CompletionStream, err error) {
request.Stream = true
reqBytes, err := json.Marshal(request)
if err != nil {
return
}

urlSuffix := "/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
req, err := c.newStreamRequest(ctx, "POST", "/completions", request)
if err != nil {
return
}


+ 147
- 0
go-gpt3/stream_test.go View File

@@ -0,0 +1,147 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
)

func TestCreateCompletionStream(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: done\n")...)
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)

_, err := w.Write(dataBytes)
if err != nil {
t.Errorf("Write error: %s", err)
}
}))
defer server.Close()

// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
}

client := NewClientWithConfig(config)
ctx := context.Background()

request := CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
}

stream, err := client.CreateCompletionStream(ctx, request)
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
defer stream.Close()

expectedResponses := []CompletionResponse{
{
ID: "1",
Object: "completion",
Created: 1598069254,
Model: "text-davinci-002",
Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}},
},
{
ID: "2",
Object: "completion",
Created: 1598069255,
Model: "text-davinci-002",
Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
},
}

for ix, expectedResponse := range expectedResponses {
receivedResponse, streamErr := stream.Recv()
if streamErr != nil {
t.Errorf("stream.Recv() failed: %v", streamErr)
}
if !compareResponses(expectedResponse, receivedResponse) {
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
}
}

_, streamErr := stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
}

_, streamErr = stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
}
}

// A "tokenRoundTripper" is a struct that implements the RoundTripper
// interface, specifically to handle the authentication token by adding a token
// to the request header. We need this because the API requires that each
// request include a valid API token in the headers for authentication and
// authorization.
type tokenRoundTripper struct {
token string
fallback http.RoundTripper
}

// RoundTrip takes an *http.Request as input and returns an
// *http.Response and an error.
//
// It is expected to use the provided request to create a connection to an HTTP
// server and return the response, or an error if one occurred. The returned
// Response should have its Body closed. If the RoundTrip method returns an
// error, the Client's Get, Head, Post, and PostForm methods return the same
// error.
func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("Authorization", "Bearer "+t.token)
return t.fallback.RoundTrip(req)
}

// Helper funcs.
func compareResponses(r1, r2 CompletionResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
return false
}
if len(r1.Choices) != len(r2.Choices) {
return false
}
for i := range r1.Choices {
if !compareResponseChoices(r1.Choices[i], r2.Choices[i]) {
return false
}
}
return true
}

func compareResponseChoices(c1, c2 CompletionChoice) bool {
if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason {
return false
}
return true
}

Loading…
Cancel
Save