Browse Source

add Reverse: for v0.0.6

tags/v0.0.6
DengBiao 1 year ago
parent
commit
3520cdfb5f
36 changed files with 70 additions and 1839 deletions
  1. +1
    -1
      go-gpt3/.gitignore
  2. +44
    -44
      go-gpt3/README.md
  3. +1
    -1
      go-gpt3/api.go
  4. +0
    -182
      go-gpt3/api_test.go
  5. +1
    -1
      go-gpt3/audio.go
  6. +0
    -143
      go-gpt3/audio_test.go
  7. +3
    -3
      go-gpt3/chat.go
  8. +1
    -1
      go-gpt3/chat_stream.go
  9. +0
    -153
      go-gpt3/chat_stream_test.go
  10. +0
    -132
      go-gpt3/chat_test.go
  11. +1
    -1
      go-gpt3/common.go
  12. +1
    -1
      go-gpt3/completion.go
  13. +0
    -121
      go-gpt3/completion_test.go
  14. +2
    -2
      go-gpt3/config.go
  15. +1
    -1
      go-gpt3/edits.go
  16. +0
    -104
      go-gpt3/edits_test.go
  17. +2
    -2
      go-gpt3/embeddings.go
  18. +0
    -48
      go-gpt3/embeddings_test.go
  19. +1
    -1
      go-gpt3/engines.go
  20. +1
    -1
      go-gpt3/error.go
  21. +1
    -1
      go-gpt3/files.go
  22. +0
    -81
      go-gpt3/files_test.go
  23. +1
    -1
      go-gpt3/fine_tunes.go
  24. +0
    -101
      go-gpt3/fine_tunes_test.go
  25. +1
    -1
      go-gpt3/go.mod
  26. +2
    -2
      go-gpt3/image.go
  27. +0
    -268
      go-gpt3/image_test.go
  28. +1
    -1
      go-gpt3/marshaller.go
  29. +1
    -1
      go-gpt3/models.go
  30. +0
    -39
      go-gpt3/models_test.go
  31. +1
    -1
      go-gpt3/moderation.go
  32. +0
    -102
      go-gpt3/moderation_test.go
  33. +1
    -1
      go-gpt3/request_builder.go
  34. +0
    -148
      go-gpt3/request_builder_test.go
  35. +1
    -1
      go-gpt3/stream.go
  36. +0
    -147
      go-gpt3/stream_test.go

+ 1
- 1
go-gpt3/.gitignore View File

@@ -15,5 +15,5 @@
# vendor/

# Auth token for tests
.openai-token
.gogpt-token
.idea

+ 44
- 44
go-gpt3/README.md View File

@@ -1,11 +1,11 @@
# Go OpenAI
[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-openai)
[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai)
[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai)
[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-gogpt)
[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-gogpt)](https://goreportcard.com/report/github.com/sashabaranov/go-gogpt)
[![codecov](https://codecov.io/gh/sashabaranov/go-gogpt/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-gogpt)

> **Note**: the repository was recently renamed from `go-gpt3` to `go-openai`
> **Note**: the repository was recently renamed from `go-gpt3` to `go-gogpt`

This library provides Go clients for [OpenAI API](https://platform.openai.com/). We support:
This library provides Go clients for [OpenAI API](https://platform.gogpt.com/). We support:

* ChatGPT
* GPT-3, GPT-4
@@ -14,7 +14,7 @@ This library provides Go clients for [OpenAI API](https://platform.openai.com/).

Installation:
```
go get github.com/sashabaranov/go-openai
go get github.com/sashabaranov/go-gogpt
```


@@ -26,18 +26,18 @@ package main
import (
"context"
"fmt"
openai "github.com/sashabaranov/go-openai"
gogpt "github.com/sashabaranov/go-gogpt"
)

func main() {
client := openai.NewClient("your token")
client := gogpt.NewClient("your token")
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Messages: []openai.ChatCompletionMessage{
gogpt.ChatCompletionRequest{
Model: gogpt.GPT3Dot5Turbo,
Messages: []gogpt.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Role: gogpt.ChatMessageRoleUser,
Content: "Hello!",
},
},
@@ -67,15 +67,15 @@ package main
import (
"context"
"fmt"
openai "github.com/sashabaranov/go-openai"
gogpt "github.com/sashabaranov/go-gogpt"
)

func main() {
c := openai.NewClient("your token")
c := gogpt.NewClient("your token")
ctx := context.Background()

req := openai.CompletionRequest{
Model: openai.GPT3Ada,
req := gogpt.CompletionRequest{
Model: gogpt.GPT3Ada,
MaxTokens: 5,
Prompt: "Lorem ipsum",
}
@@ -100,15 +100,15 @@ import (
"context"
"fmt"
"io"
openai "github.com/sashabaranov/go-openai"
gogpt "github.com/sashabaranov/go-gogpt"
)

func main() {
c := openai.NewClient("your token")
c := gogpt.NewClient("your token")
ctx := context.Background()

req := openai.CompletionRequest{
Model: openai.GPT3Ada,
req := gogpt.CompletionRequest{
Model: gogpt.GPT3Ada,
MaxTokens: 5,
Prompt: "Lorem ipsum",
Stream: true,
@@ -149,15 +149,15 @@ import (
"context"
"fmt"

openai "github.com/sashabaranov/go-openai"
gogpt "github.com/sashabaranov/go-gogpt"
)

func main() {
c := openai.NewClient("your token")
c := gogpt.NewClient("your token")
ctx := context.Background()

req := openai.AudioRequest{
Model: openai.Whisper1,
req := gogpt.AudioRequest{
Model: gogpt.Whisper1,
FilePath: "recording.mp3",
}
resp, err := c.CreateTranscription(ctx, req)
@@ -181,20 +181,20 @@ import (
"context"
"encoding/base64"
"fmt"
openai "github.com/sashabaranov/go-openai"
gogpt "github.com/sashabaranov/go-gogpt"
"image/png"
"os"
)

func main() {
c := openai.NewClient("your token")
c := gogpt.NewClient("your token")
ctx := context.Background()

// Sample image by link
reqUrl := openai.ImageRequest{
reqUrl := gogpt.ImageRequest{
Prompt: "Parrot on a skateboard performs a trick, cartoon style, natural light, high detail",
Size: openai.CreateImageSize256x256,
ResponseFormat: openai.CreateImageResponseFormatURL,
Size: gogpt.CreateImageSize256x256,
ResponseFormat: gogpt.CreateImageResponseFormatURL,
N: 1,
}

@@ -206,10 +206,10 @@ func main() {
fmt.Println(respUrl.Data[0].URL)

// Example image as base64
reqBase64 := openai.ImageRequest{
reqBase64 := gogpt.ImageRequest{
Prompt: "Portrait of a humanoid parrot in a classic costume, high detail, realistic light, unreal engine",
Size: openai.CreateImageSize256x256,
ResponseFormat: openai.CreateImageResponseFormatB64JSON,
Size: gogpt.CreateImageSize256x256,
ResponseFormat: gogpt.CreateImageResponseFormatB64JSON,
N: 1,
}

@@ -254,7 +254,7 @@ func main() {
<summary>Configuring proxy</summary>

```go
config := openai.DefaultConfig("token")
config := gogpt.DefaultConfig("token")
proxyUrl, err := url.Parse("http://localhost:{port}")
if err != nil {
panic(err)
@@ -266,10 +266,10 @@ config.HTTPClient = &http.Client{
Transport: transport,
}

c := openai.NewClientWithConfig(config)
c := gogpt.NewClientWithConfig(config)
```

See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig
See also: https://pkg.go.dev/github.com/sashabaranov/go-gogpt#ClientConfig
</details>

<details>
@@ -285,12 +285,12 @@ import (
"os"
"strings"

"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-gogpt"
)

func main() {
client := openai.NewClient("your token")
messages := make([]openai.ChatCompletionMessage, 0)
client := gogpt.NewClient("your token")
messages := make([]gogpt.ChatCompletionMessage, 0)
reader := bufio.NewReader(os.Stdin)
fmt.Println("Conversation")
fmt.Println("---------------------")
@@ -300,15 +300,15 @@ func main() {
text, _ := reader.ReadString('\n')
// convert CRLF to LF
text = strings.Replace(text, "\n", "", -1)
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
messages = append(messages, gogpt.ChatCompletionMessage{
Role: gogpt.ChatMessageRoleUser,
Content: text,
})

resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
gogpt.ChatCompletionRequest{
Model: gogpt.GPT3Dot5Turbo,
Messages: messages,
},
)
@@ -319,8 +319,8 @@ func main() {
}

content := resp.Choices[0].Message.Content
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
messages = append(messages, gogpt.ChatCompletionMessage{
Role: gogpt.ChatMessageRoleAssistant,
Content: content,
})
fmt.Println(content)


+ 1
- 1
go-gpt3/api.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"context"


+ 0
- 182
go-gpt3/api_test.go View File

@@ -1,182 +0,0 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"

"context"
"errors"
"io"
"os"
"testing"
)

func TestAPI(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
}

var err error
c := NewClient(apiToken)
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err != nil {
t.Fatalf("ListEngines error: %v", err)
}

_, err = c.GetEngine(ctx, "davinci")
if err != nil {
t.Fatalf("GetEngine error: %v", err)
}

fileRes, err := c.ListFiles(ctx)
if err != nil {
t.Fatalf("ListFiles error: %v", err)
}

if len(fileRes.Files) > 0 {
_, err = c.GetFile(ctx, fileRes.Files[0].ID)
if err != nil {
t.Fatalf("GetFile error: %v", err)
}
} // else skip

embeddingReq := EmbeddingRequest{
Input: []string{
"The food was delicious and the waiter",
"Other examples of embedding request",
},
Model: AdaSearchQuery,
}
_, err = c.CreateEmbeddings(ctx, embeddingReq)
if err != nil {
t.Fatalf("Embedding error: %v", err)
}

_, err = c.CreateChatCompletion(
ctx,
ChatCompletionRequest{
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
},
)

if err != nil {
t.Errorf("CreateChatCompletion (without name) returned error: %v", err)
}

_, err = c.CreateChatCompletion(
ctx,
ChatCompletionRequest{
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Name: "John_Doe",
Content: "Hello!",
},
},
},
)

if err != nil {
t.Errorf("CreateChatCompletion (with name) returned error: %v", err)
}

stream, err := c.CreateCompletionStream(ctx, CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: GPT3Ada,
MaxTokens: 5,
Stream: true,
})
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
defer stream.Close()

counter := 0
for {
_, err = stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
t.Errorf("Stream error: %v", err)
} else {
counter++
}
}
if counter == 0 {
t.Error("Stream did not return any responses")
}
}

func TestAPIError(t *testing.T) {
apiToken := os.Getenv("OPENAI_TOKEN")
if apiToken == "" {
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
}

var err error
c := NewClient(apiToken + "_invalid")
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err == nil {
t.Fatal("ListEngines did not fail")
}

var apiErr *APIError
if !errors.As(err, &apiErr) {
t.Fatalf("Error is not an APIError: %+v", err)
}

if apiErr.StatusCode != 401 {
t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode)
}
if *apiErr.Code != "invalid_api_key" {
t.Fatalf("Unexpected API error code: %s", *apiErr.Code)
}
if apiErr.Error() == "" {
t.Fatal("Empty error message occured")
}
}

func TestRequestError(t *testing.T) {
var err error

config := DefaultConfig("dummy")
config.BaseURL = "https://httpbin.org/status/418?"
c := NewClientWithConfig(config)
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err == nil {
t.Fatal("ListEngines request did not fail")
}

var reqErr *RequestError
if !errors.As(err, &reqErr) {
t.Fatalf("Error is not a RequestError: %+v", err)
}

if reqErr.StatusCode != 418 {
t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode)
}

if reqErr.Unwrap() == nil {
t.Fatalf("Empty request error occured")
}
}

// numTokens Returns the number of GPT-3 encoded tokens in the given text.
// This function approximates based on the rule of thumb stated by OpenAI:
// https://beta.openai.com/tokenizer
//
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
func numTokens(s string) int {
return int(float32(len(s)) / 4)
}

+ 1
- 1
go-gpt3/audio.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"bytes"


+ 0
- 143
go-gpt3/audio_test.go View File

@@ -1,143 +0,0 @@
package openai_test

import (
"bytes"
"errors"
"io"
"mime"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"testing"
)

// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
func TestAudio(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)

testcases := []struct {
name string
createFn func(context.Context, AudioRequest) (AudioResponse, error)
}{
{
"transcribe",
client.CreateTranscription,
},
{
"translate",
client.CreateTranslation,
},
}

ctx := context.Background()

dir, cleanup := createTestDirectory(t)
defer cleanup()

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
path := filepath.Join(dir, "fake.mp3")
createTestFile(t, path)

req := AudioRequest{
FilePath: path,
Model: "whisper-3",
}
_, err = tc.createFn(ctx, req)
if err != nil {
t.Fatalf("audio API error: %v", err)
}
})
}
}

// createTestFile creates a fake file with "hello" as the content.
func createTestFile(t *testing.T, path string) {
file, err := os.Create(path)
if err != nil {
t.Fatalf("failed to create file %v", err)
}
if _, err = file.WriteString("hello"); err != nil {
t.Fatalf("failed to write to file %v", err)
}
file.Close()
}

// createTestDirectory creates a temporary folder which will be deleted when cleanup is called.
func createTestDirectory(t *testing.T) (path string, cleanup func()) {
t.Helper()

path, err := os.MkdirTemp(os.TempDir(), "")
if err != nil {
t.Fatal(err)
}

return path, func() { os.RemoveAll(path) }
}

// handleAudioEndpoint Handles the completion endpoint by the test server.
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
var err error

// audio endpoints only accept POST requests
if r.Method != "POST" {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}

mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
http.Error(w, "failed to parse media type", http.StatusBadRequest)
return
}

if !strings.HasPrefix(mediaType, "multipart") {
http.Error(w, "request is not multipart", http.StatusBadRequest)
}

boundary, ok := params["boundary"]
if !ok {
http.Error(w, "no boundary in params", http.StatusBadRequest)
return
}

fileData := &bytes.Buffer{}
mr := multipart.NewReader(r.Body, boundary)
part, err := mr.NextPart()
if err != nil && errors.Is(err, io.EOF) {
http.Error(w, "error accessing file", http.StatusBadRequest)
return
}
if _, err = io.Copy(fileData, part); err != nil {
http.Error(w, "failed to copy file", http.StatusInternalServerError)
return
}

if len(fileData.Bytes()) == 0 {
w.WriteHeader(http.StatusInternalServerError)
http.Error(w, "received empty file data", http.StatusBadRequest)
return
}

if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
http.Error(w, "failed to write body", http.StatusInternalServerError)
return
}
}

+ 3
- 3
go-gpt3/chat.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"context"
@@ -23,8 +23,8 @@ type ChatCompletionMessage struct {

// This property isn't in the official documentation, but it's in
// the documentation for the official library for python:
// - https://github.com/openai/openai-python/blob/main/chatml.md
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// - https://github.com/gogpt/gogpt-python/blob/main/chatml.md
// - https://github.com/gogpt/gogpt-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
Name string `json:"name,omitempty"`
}



+ 1
- 1
go-gpt3/chat_stream.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"bufio"


+ 0
- 153
go-gpt3/chat_stream_test.go View File

@@ -1,153 +0,0 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
)

func TestCreateChatCompletionStream(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: done\n")...)
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)

_, err := w.Write(dataBytes)
if err != nil {
t.Errorf("Write error: %s", err)
}
}))
defer server.Close()

// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
}

client := NewClientWithConfig(config)
ctx := context.Background()

request := ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
}

stream, err := client.CreateChatCompletionStream(ctx, request)
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
defer stream.Close()

expectedResponses := []ChatCompletionStreamResponse{
{
ID: "1",
Object: "completion",
Created: 1598069254,
Model: GPT3Dot5Turbo,
Choices: []ChatCompletionStreamChoice{
{
Delta: ChatCompletionStreamChoiceDelta{
Content: "response1",
},
FinishReason: "max_tokens",
},
},
},
{
ID: "2",
Object: "completion",
Created: 1598069255,
Model: GPT3Dot5Turbo,
Choices: []ChatCompletionStreamChoice{
{
Delta: ChatCompletionStreamChoiceDelta{
Content: "response2",
},
FinishReason: "max_tokens",
},
},
},
}

for ix, expectedResponse := range expectedResponses {
b, _ := json.Marshal(expectedResponse)
t.Logf("%d: %s", ix, string(b))

receivedResponse, streamErr := stream.Recv()
if streamErr != nil {
t.Errorf("stream.Recv() failed: %v", streamErr)
}
if !compareChatResponses(expectedResponse, receivedResponse) {
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
}
}

_, streamErr := stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
}

_, streamErr = stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
}
}

// Helper funcs.
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
return false
}
if len(r1.Choices) != len(r2.Choices) {
return false
}
for i := range r1.Choices {
if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) {
return false
}
}
return true
}

func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool {
if c1.Index != c2.Index {
return false
}
if c1.Delta.Content != c2.Delta.Content {
return false
}
if c1.FinishReason != c2.FinishReason {
return false
}
return true
}

+ 0
- 132
go-gpt3/chat_test.go View File

@@ -1,132 +0,0 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"testing"
"time"
)

func TestChatCompletionsWrongModel(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

req := ChatCompletionRequest{
MaxTokens: 5,
Model: "ada",
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
}
_, err := client.CreateChatCompletion(ctx, req)
if !errors.Is(err, ErrChatCompletionInvalidModel) {
t.Fatalf("CreateChatCompletion should return wrong model error, but returned: %v", err)
}
}

// TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestChatCompletions(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

req := ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
}
_, err = client.CreateChatCompletion(ctx, req)
if err != nil {
t.Fatalf("CreateChatCompletion error: %v", err)
}
}

// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var completionReq ChatCompletionRequest
if completionReq, err = getChatCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := ChatCompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object",
Created: time.Now().Unix(),
// would be nice to validate Model during testing, but
// this may not be possible with how much upkeep
// would be required / wouldn't make much sense
Model: completionReq.Model,
}
// create completions
for i := 0; i < completionReq.N; i++ {
// generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens)

res.Choices = append(res.Choices, ChatCompletionChoice{
Message: ChatCompletionMessage{
Role: ChatMessageRoleAssistant,
Content: completionStr,
},
Index: i,
})
}
inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N
completionTokens := completionReq.MaxTokens * completionReq.N
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}

// getChatCompletionBody Returns the body of the request to create a completion.
func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) {
completion := ChatCompletionRequest{}
// read the request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
return ChatCompletionRequest{}, err
}
err = json.Unmarshal(reqBody, &completion)
if err != nil {
return ChatCompletionRequest{}, err
}
return completion, nil
}

+ 1
- 1
go-gpt3/common.go View File

@@ -1,5 +1,5 @@
// common.go defines common types used throughout the OpenAI API.
package openai
package gogpt

// Usage Represents the total token usage per request to OpenAI.
type Usage struct {


+ 1
- 1
go-gpt3/completion.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"context"


+ 0
- 121
go-gpt3/completion_test.go View File

@@ -1,121 +0,0 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"testing"
"time"
)

func TestCompletionsWrongModel(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config)

_, err := client.CreateCompletion(
context.Background(),
CompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
},
)
if !errors.Is(err, ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err)
}
}

// TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestCompletions(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

req := CompletionRequest{
MaxTokens: 5,
Model: "ada",
}
req.Prompt = "Lorem ipsum"
_, err = client.CreateCompletion(ctx, req)
if err != nil {
t.Fatalf("CreateCompletion error: %v", err)
}
}

// handleCompletionEndpoint Handles the completion endpoint by the test server.
func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var completionReq CompletionRequest
if completionReq, err = getCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := CompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object",
Created: time.Now().Unix(),
// would be nice to validate Model during testing, but
// this may not be possible with how much upkeep
// would be required / wouldn't make much sense
Model: completionReq.Model,
}
// create completions
for i := 0; i < completionReq.N; i++ {
// generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens)
if completionReq.Echo {
completionStr = completionReq.Prompt + completionStr
}
res.Choices = append(res.Choices, CompletionChoice{
Text: completionStr,
Index: i,
})
}
inputTokens := numTokens(completionReq.Prompt) * completionReq.N
completionTokens := completionReq.MaxTokens * completionReq.N
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}

// getCompletionBody Returns the body of the request to create a completion.
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
completion := CompletionRequest{}
// read the request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
return CompletionRequest{}, err
}
err = json.Unmarshal(reqBody, &completion)
if err != nil {
return CompletionRequest{}, err
}
return completion, nil
}

+ 2
- 2
go-gpt3/config.go View File

@@ -1,11 +1,11 @@
package openai
package gogpt

import (
"net/http"
)

const (
apiURLv1 = "https://api.openai.com/v1"
apiURLv1 = "https://api.gogpt.com/v1"
defaultEmptyMessagesLimit uint = 300
)



+ 1
- 1
go-gpt3/edits.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"context"


+ 0
- 104
go-gpt3/edits_test.go View File

@@ -1,104 +0,0 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
"time"
)

// TestEdits Tests the edits endpoint of the API using the mocked server.
func TestEdits(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/edits", handleEditEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

// create an edit request
model := "ada"
editReq := EditsRequest{
Model: &model,
Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " +
"sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" +
" ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" +
" ex ea commodo consequat. Duis aute irure dolor in reprehe",
Instruction: "test instruction",
N: 3,
}
response, err := client.Edits(ctx, editReq)
if err != nil {
t.Fatalf("Edits error: %v", err)
}
if len(response.Choices) != editReq.N {
t.Fatalf("edits does not properly return the correct number of choices")
}
}

// handleEditEndpoint Handles the edit endpoint by the test server.
func handleEditEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// edits only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var editReq EditsRequest
editReq, err = getEditBody(r)
if err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
// create a response
res := EditsResponse{
Object: "test-object",
Created: time.Now().Unix(),
}
// edit and calculate token usage
editString := "edited by mocked OpenAI server :)"
inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N
completionTokens := int(float32(len(editString))/4) * editReq.N
for i := 0; i < editReq.N; i++ {
// instruction will be hidden and only seen by OpenAI
res.Choices = append(res.Choices, EditsChoice{
Text: editReq.Input + editString,
Index: i,
})
}
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
fmt.Fprint(w, string(resBytes))
}

// getEditBody Returns the body of the request to create an edit.
func getEditBody(r *http.Request) (EditsRequest, error) {
edit := EditsRequest{}
// read the request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
return EditsRequest{}, err
}
err = json.Unmarshal(reqBody, &edit)
if err != nil {
return EditsRequest{}, err
}
return edit, nil
}

+ 2
- 2
go-gpt3/embeddings.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"context"
@@ -130,7 +130,7 @@ type EmbeddingRequest struct {
}

// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
// https://beta.openai.com/docs/api-reference/embeddings/create
// https://beta.gogpt.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request)
if err != nil {


+ 0
- 48
go-gpt3/embeddings_test.go View File

@@ -1,48 +0,0 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"

"bytes"
"encoding/json"
"testing"
)

func TestEmbedding(t *testing.T) {
embeddedModels := []EmbeddingModel{
AdaSimilarity,
BabbageSimilarity,
CurieSimilarity,
DavinciSimilarity,
AdaSearchDocument,
AdaSearchQuery,
BabbageSearchDocument,
BabbageSearchQuery,
CurieSearchDocument,
CurieSearchQuery,
DavinciSearchDocument,
DavinciSearchQuery,
AdaCodeSearchCode,
AdaCodeSearchText,
BabbageCodeSearchCode,
BabbageCodeSearchText,
}
for _, model := range embeddedModels {
embeddingReq := EmbeddingRequest{
Input: []string{
"The food was delicious and the waiter",
"Other examples of embedding request",
},
Model: model,
}
// marshal embeddingReq to JSON and confirm that the model field equals
// the AdaSearchQuery type
marshaled, err := json.Marshal(embeddingReq)
if err != nil {
t.Fatalf("Could not marshal embedding request: %v", err)
}
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
}
}

+ 1
- 1
go-gpt3/engines.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"context"


+ 1
- 1
go-gpt3/error.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import "fmt"



+ 1
- 1
go-gpt3/files.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"bytes"


+ 0
- 81
go-gpt3/files_test.go View File

@@ -1,81 +0,0 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"testing"
"time"
)

func TestFileUpload(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/files", handleCreateFile)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

req := FileRequest{
FileName: "test.go",
FilePath: "api.go",
Purpose: "fine-tune",
}
_, err = client.CreateFile(ctx, req)
if err != nil {
t.Fatalf("CreateFile error: %v", err)
}
}

// handleCreateFile Handles the images endpoint by the test server.
func handleCreateFile(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// edits only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
err = r.ParseMultipartForm(1024 * 1024 * 1024)
if err != nil {
http.Error(w, "file is more than 1GB", http.StatusInternalServerError)
return
}

values := r.Form
var purpose string
for key, value := range values {
if key == "purpose" {
purpose = value[0]
}
}
file, header, err := r.FormFile("file")
if err != nil {
return
}
defer file.Close()

var fileReq = File{
Bytes: int(header.Size),
ID: strconv.Itoa(int(time.Now().Unix())),
FileName: header.Filename,
Purpose: purpose,
CreatedAt: time.Now().Unix(),
Object: "test-objecct",
Owner: "test-owner",
}

resBytes, _ = json.Marshal(fileReq)
fmt.Fprint(w, string(resBytes))
}

+ 1
- 1
go-gpt3/fine_tunes.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"context"


+ 0
- 101
go-gpt3/fine_tunes_test.go View File

@@ -1,101 +0,0 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"fmt"
"net/http"
"testing"
)

const testFineTuneID = "fine-tune-id"

// TestFineTunes Tests the fine tunes endpoint of the API using the mocked server.
func TestFineTunes(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler(
"/v1/fine-tunes",
func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
if r.Method == http.MethodGet {
resBytes, _ = json.Marshal(FineTuneList{})
} else {
resBytes, _ = json.Marshal(FineTune{})
}
fmt.Fprintln(w, string(resBytes))
},
)

server.RegisterHandler(
"/v1/fine-tunes/"+testFineTuneID+"/cancel",
func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(FineTune{})
fmt.Fprintln(w, string(resBytes))
},
)

server.RegisterHandler(
"/v1/fine-tunes/"+testFineTuneID,
func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
if r.Method == http.MethodDelete {
resBytes, _ = json.Marshal(FineTuneDeleteResponse{})
} else {
resBytes, _ = json.Marshal(FineTune{})
}
fmt.Fprintln(w, string(resBytes))
},
)

server.RegisterHandler(
"/v1/fine-tunes/"+testFineTuneID+"/events",
func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(FineTuneEventList{})
fmt.Fprintln(w, string(resBytes))
},
)

// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

_, err = client.ListFineTunes(ctx)
if err != nil {
t.Fatalf("ListFineTunes error: %v", err)
}

_, err = client.CreateFineTune(ctx, FineTuneRequest{})
if err != nil {
t.Fatalf("CreateFineTune error: %v", err)
}

_, err = client.CancelFineTune(ctx, testFineTuneID)
if err != nil {
t.Fatalf("CancelFineTune error: %v", err)
}

_, err = client.GetFineTune(ctx, testFineTuneID)
if err != nil {
t.Fatalf("GetFineTune error: %v", err)
}

_, err = client.DeleteFineTune(ctx, testFineTuneID)
if err != nil {
t.Fatalf("DeleteFineTune error: %v", err)
}

_, err = client.ListFineTuneEvents(ctx, testFineTuneID)
if err != nil {
t.Fatalf("ListFineTuneEvents error: %v", err)
}
}

+ 1
- 1
go-gpt3/go.mod View File

@@ -1,3 +1,3 @@
module github.com/sashabaranov/go-openai
module github.com/sashabaranov/go-gogpt

go 1.18

+ 2
- 2
go-gpt3/image.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"bytes"
@@ -147,7 +147,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
return
}
writer.Close()
//https://platform.openai.com/docs/api-reference/images/create-variation
//https://platform.gogpt.com/docs/api-reference/images/create-variation
urlSuffix := "/images/variations"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
if err != nil {


+ 0
- 268
go-gpt3/image_test.go View File

@@ -1,268 +0,0 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"testing"
"time"
)

func TestImages(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

req := ImageRequest{}
req.Prompt = "Lorem ipsum"
_, err = client.CreateImage(ctx, req)
if err != nil {
t.Fatalf("CreateImage error: %v", err)
}
}

// handleImageEndpoint Handles the images endpoint by the test server.
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// imagess only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var imageReq ImageRequest
if imageReq, err = getImageBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := ImageResponse{
Created: time.Now().Unix(),
}
for i := 0; i < imageReq.N; i++ {
imageData := ImageResponseDataInner{}
switch imageReq.ResponseFormat {
case CreateImageResponseFormatURL, "":
imageData.URL = "https://example.com/image.png"
case CreateImageResponseFormatB64JSON:
// This decodes to "{}" in base64.
imageData.B64JSON = "e30K"
default:
http.Error(w, "invalid response format", http.StatusBadRequest)
return
}
res.Data = append(res.Data, imageData)
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}

// getImageBody Returns the body of the request to create a image.
func getImageBody(r *http.Request) (ImageRequest, error) {
image := ImageRequest{}
// read the request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
return ImageRequest{}, err
}
err = json.Unmarshal(reqBody, &image)
if err != nil {
return ImageRequest{}, err
}
return image, nil
}

func TestImageEdit(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

origin, err := os.Create("image.png")
if err != nil {
t.Error("open origin file error")
return
}

mask, err := os.Create("mask.png")
if err != nil {
t.Error("open mask file error")
return
}

defer func() {
mask.Close()
origin.Close()
os.Remove("mask.png")
os.Remove("image.png")
}()

req := ImageEditRequest{
Image: origin,
Mask: mask,
Prompt: "There is a turtle in the pool",
N: 3,
Size: CreateImageSize1024x1024,
}
_, err = client.CreateEditImage(ctx, req)
if err != nil {
t.Fatalf("CreateImage error: %v", err)
}
}

func TestImageEditWithoutMask(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

origin, err := os.Create("image.png")
if err != nil {
t.Error("open origin file error")
return
}

defer func() {
origin.Close()
os.Remove("image.png")
}()

req := ImageEditRequest{
Image: origin,
Prompt: "There is a turtle in the pool",
N: 3,
Size: CreateImageSize1024x1024,
}
_, err = client.CreateEditImage(ctx, req)
if err != nil {
t.Fatalf("CreateImage error: %v", err)
}
}

// handleEditImageEndpoint Handles the images endpoint by the test server.
func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
var resBytes []byte

// imagess only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}

responses := ImageResponse{
Created: time.Now().Unix(),
Data: []ImageResponseDataInner{
{
URL: "test-url1",
B64JSON: "",
},
{
URL: "test-url2",
B64JSON: "",
},
{
URL: "test-url3",
B64JSON: "",
},
},
}

resBytes, _ = json.Marshal(responses)
fmt.Fprintln(w, string(resBytes))
}

func TestImageVariation(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

origin, err := os.Create("image.png")
if err != nil {
t.Error("open origin file error")
return
}

defer func() {
origin.Close()
os.Remove("image.png")
}()

req := ImageVariRequest{
Image: origin,
N: 3,
Size: CreateImageSize1024x1024,
}
_, err = client.CreateVariImage(ctx, req)
if err != nil {
t.Fatalf("CreateImage error: %v", err)
}
}

// handleVariateImageEndpoint Handles the images endpoint by the test server.
func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
var resBytes []byte

// imagess only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}

responses := ImageResponse{
Created: time.Now().Unix(),
Data: []ImageResponseDataInner{
{
URL: "test-url1",
B64JSON: "",
},
{
URL: "test-url2",
B64JSON: "",
},
{
URL: "test-url3",
B64JSON: "",
},
},
}

resBytes, _ = json.Marshal(responses)
fmt.Fprintln(w, string(resBytes))
}

+ 1
- 1
go-gpt3/marshaller.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"encoding/json"


+ 1
- 1
go-gpt3/models.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"context"


+ 0
- 39
go-gpt3/models_test.go View File

@@ -1,39 +0,0 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"fmt"
"net/http"
"testing"
)

// TestListModels Tests the models endpoint of the API using the mocked server.
func TestListModels(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/models", handleModelsEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

_, err = client.ListModels(ctx)
if err != nil {
t.Fatalf("ListModels error: %v", err)
}
}

// handleModelsEndpoint Handles the models endpoint by the test server.
func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(ModelsList{})
fmt.Fprintln(w, string(resBytes))
}

+ 1
- 1
go-gpt3/moderation.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"context"


+ 0
- 102
go-gpt3/moderation_test.go View File

@@ -1,102 +0,0 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"testing"
"time"
)

// TestModeration Tests the moderations endpoint of the API using the mocked server.
func TestModerations(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

// create an edit request
model := "text-moderation-stable"
moderationReq := ModerationRequest{
Model: &model,
Input: "I want to kill them.",
}
_, err = client.Moderations(ctx, moderationReq)
if err != nil {
t.Fatalf("Moderation error: %v", err)
}
}

// handleModerationEndpoint Handles the moderation endpoint by the test server.
func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte

// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var moderationReq ModerationRequest
if moderationReq, err = getModerationBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}

resCat := ResultCategories{}
resCatScore := ResultCategoryScores{}
switch {
case strings.Contains(moderationReq.Input, "kill"):
resCat = ResultCategories{Violence: true}
resCatScore = ResultCategoryScores{Violence: 1}
case strings.Contains(moderationReq.Input, "hate"):
resCat = ResultCategories{Hate: true}
resCatScore = ResultCategoryScores{Hate: 1}
case strings.Contains(moderationReq.Input, "suicide"):
resCat = ResultCategories{SelfHarm: true}
resCatScore = ResultCategoryScores{SelfHarm: 1}
case strings.Contains(moderationReq.Input, "porn"):
resCat = ResultCategories{Sexual: true}
resCatScore = ResultCategoryScores{Sexual: 1}
}

result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true}

res := ModerationResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Model: *moderationReq.Model,
}
res.Results = append(res.Results, result)

resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}

// getModerationBody Returns the body of the request to do a moderation.
func getModerationBody(r *http.Request) (ModerationRequest, error) {
moderation := ModerationRequest{}
// read the request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
return ModerationRequest{}, err
}
err = json.Unmarshal(reqBody, &moderation)
if err != nil {
return ModerationRequest{}, err
}
return moderation, nil
}

+ 1
- 1
go-gpt3/request_builder.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"bytes"


+ 0
- 148
go-gpt3/request_builder_test.go View File

@@ -1,148 +0,0 @@
package openai //nolint:testpackage // testing private field

import (
"github.com/sashabaranov/go-openai/internal/test"

"context"
"errors"
"net/http"
"testing"
)

var (
errTestMarshallerFailed = errors.New("test marshaller failed")
errTestRequestBuilderFailed = errors.New("test request builder failed")
)

type (
failingRequestBuilder struct{}
failingMarshaller struct{}
)

func (*failingMarshaller) marshal(value any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed
}

func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) {
return nil, errTestRequestBuilderFailed
}

func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
builder := httpRequestBuilder{
marshaller: &failingMarshaller{},
}

_, err := builder.build(context.Background(), "", "", struct{}{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
}

func TestClientReturnsRequestBuilderErrors(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.requestBuilder = &failingRequestBuilder{}

ctx := context.Background()

_, err = client.CreateCompletion(ctx, CompletionRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateFineTune(ctx, FineTuneRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListFineTunes(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CancelFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.GetFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.DeleteFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListFineTuneEvents(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.Moderations(ctx, ModerationRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.Edits(ctx, EditsRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.CreateImage(ctx, ImageRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

err = client.DeleteFile(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.GetFile(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListFiles(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListEngines(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.GetEngine(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}

_, err = client.ListModels(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}

+ 1
- 1
go-gpt3/stream.go View File

@@ -1,4 +1,4 @@
package openai
package gogpt

import (
"bufio"


+ 0
- 147
go-gpt3/stream_test.go View File

@@ -1,147 +0,0 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"

"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
)

func TestCreateCompletionStream(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: done\n")...)
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)

_, err := w.Write(dataBytes)
if err != nil {
t.Errorf("Write error: %s", err)
}
}))
defer server.Close()

// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
}

client := NewClientWithConfig(config)
ctx := context.Background()

request := CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
}

stream, err := client.CreateCompletionStream(ctx, request)
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
defer stream.Close()

expectedResponses := []CompletionResponse{
{
ID: "1",
Object: "completion",
Created: 1598069254,
Model: "text-davinci-002",
Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}},
},
{
ID: "2",
Object: "completion",
Created: 1598069255,
Model: "text-davinci-002",
Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
},
}

for ix, expectedResponse := range expectedResponses {
receivedResponse, streamErr := stream.Recv()
if streamErr != nil {
t.Errorf("stream.Recv() failed: %v", streamErr)
}
if !compareResponses(expectedResponse, receivedResponse) {
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
}
}

_, streamErr := stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
}

_, streamErr = stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
}
}

// A "tokenRoundTripper" is a struct that implements the RoundTripper
// interface, specifically to handle the authentication token by adding a token
// to the request header. We need this because the API requires that each
// request include a valid API token in the headers for authentication and
// authorization.
type tokenRoundTripper struct {
token string
fallback http.RoundTripper
}

// RoundTrip takes an *http.Request as input and returns an
// *http.Response and an error.
//
// It is expected to use the provided request to create a connection to an HTTP
// server and return the response, or an error if one occurred. The returned
// Response should have its Body closed. If the RoundTrip method returns an
// error, the Client's Get, Head, Post, and PostForm methods return the same
// error.
func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("Authorization", "Bearer "+t.token)
return t.fallback.RoundTrip(req)
}

// Helper funcs.
func compareResponses(r1, r2 CompletionResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
return false
}
if len(r1.Choices) != len(r2.Choices) {
return false
}
for i := range r1.Choices {
if !compareResponseChoices(r1.Choices[i], r2.Choices[i]) {
return false
}
}
return true
}

func compareResponseChoices(c1, c2 CompletionChoice) bool {
if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason {
return false
}
return true
}

Loading…
Cancel
Save