From 87200fff34f8a842d5af6aedb0046939e55a5776 Mon Sep 17 00:00:00 2001 From: DengBiao <2319963317@qq.com> Date: Thu, 9 Mar 2023 15:55:48 +0800 Subject: [PATCH] add Reverse: for v0.0.4 --- go-gpt3/audio.go | 100 ++++++++++++++++++++++++++++++++ go-gpt3/internal/test/server.go | 46 +++++++++++++++ 2 files changed, 146 insertions(+) create mode 100644 go-gpt3/audio.go create mode 100644 go-gpt3/internal/test/server.go diff --git a/go-gpt3/audio.go b/go-gpt3/audio.go new file mode 100644 index 0000000..0dc611e --- /dev/null +++ b/go-gpt3/audio.go @@ -0,0 +1,100 @@ +package gogpt + +import ( + "bytes" + "context" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" +) + +// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI. +const ( + Whisper1 = "whisper-1" +) + +// AudioRequest represents a request structure for audio API. +type AudioRequest struct { + Model string + FilePath string +} + +// AudioResponse represents a response structure for audio API. +type AudioResponse struct { + Text string `json:"text"` +} + +// CreateTranscription — API call to create a transcription. Returns transcribed text. +func (c *Client) CreateTranscription( + ctx context.Context, + request AudioRequest, +) (response AudioResponse, err error) { + response, err = c.callAudioAPI(ctx, request, "transcriptions") + return +} + +// CreateTranslation — API call to translate audio into English. +func (c *Client) CreateTranslation( + ctx context.Context, + request AudioRequest, +) (response AudioResponse, err error) { + response, err = c.callAudioAPI(ctx, request, "translations") + return +} + +// callAudioAPI — API call to an audio endpoint. +func (c *Client) callAudioAPI( + ctx context.Context, + request AudioRequest, + endpointSuffix string, +) (response AudioResponse, err error) { + var formBody bytes.Buffer + w := multipart.NewWriter(&formBody) + + if err = audioMultipartForm(request, w); err != nil { + return + } + + urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), &formBody) + if err != nil { + return + } + req.Header.Add("Content-Type", w.FormDataContentType()) + + err = c.sendRequest(req, &response) + return +} + +// audioMultipartForm creates a form with audio file contents and the name of the model to use for +// audio processing. +func audioMultipartForm(request AudioRequest, w *multipart.Writer) error { + f, err := os.Open(request.FilePath) + if err != nil { + return fmt.Errorf("opening audio file: %w", err) + } + + fw, err := w.CreateFormFile("file", f.Name()) + if err != nil { + return fmt.Errorf("creating form file: %w", err) + } + + if _, err = io.Copy(fw, f); err != nil { + return fmt.Errorf("reading from opened audio file: %w", err) + } + + fw, err = w.CreateFormField("model") + if err != nil { + return fmt.Errorf("creating form field: %w", err) + } + + modelName := bytes.NewReader([]byte(request.Model)) + if _, err = io.Copy(fw, modelName); err != nil { + return fmt.Errorf("writing model name: %w", err) + } + w.Close() + + return nil +} diff --git a/go-gpt3/internal/test/server.go b/go-gpt3/internal/test/server.go new file mode 100644 index 0000000..0c6f67d --- /dev/null +++ b/go-gpt3/internal/test/server.go @@ -0,0 +1,46 @@ +package test + +import ( + "log" + "net/http" + "net/http/httptest" +) + +const testAPI = "this-is-my-secure-token-do-not-steal!!" + +func GetTestToken() string { + return testAPI +} + +type ServerTest struct { + handlers map[string]handler +} +type handler func(w http.ResponseWriter, r *http.Request) + +func NewTestServer() *ServerTest { + return &ServerTest{handlers: make(map[string]handler)} +} + +func (ts *ServerTest) RegisterHandler(path string, handler handler) { + ts.handlers[path] = handler +} + +// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing. +func (ts *ServerTest) OpenAITestServer() *httptest.Server { + return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("received request at path %q\n", r.URL.Path) + + // check auth + if r.Header.Get("Authorization") != "Bearer "+GetTestToken() { + w.WriteHeader(http.StatusUnauthorized) + return + } + + handlerCall, ok := ts.handlers[r.URL.Path] + if !ok { + http.Error(w, "the resource path doesn't exist", http.StatusNotFound) + return + } + handlerCall(w, r) + })) +}