diff --git a/go-gpt3/chat.go b/go-gpt3/chat.go index 9200665..7ac1121 100644 --- a/go-gpt3/chat.go +++ b/go-gpt3/chat.go @@ -41,6 +41,24 @@ type ChatCompletionChoice struct { FinishReason string `json:"finish_reason"` } +type ChatCompletionChoiceForStream struct { + Index int `json:"index"` + Delta struct { + Content string `json:"content"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` +} + +// ChatCompletionResponseForStream represents a response structure for chat completion API. +type ChatCompletionResponseForStream struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoiceForStream `json:"choices"` + Usage Usage `json:"usage"` +} + // ChatCompletionResponse represents a response structure for chat completion API. type ChatCompletionResponse struct { ID string `json:"id"` @@ -81,8 +99,13 @@ func (c *Client) CreateChatCompletion( // CreateChatCompletionStream — API call to create a completion w/ streaming func (c *Client) CreateChatCompletionStream( ctx context.Context, - request CompletionRequest, + request ChatCompletionRequest, ) (stream *CompletionStream, err error) { + model := request.Model + if model != GPT3Dot5Turbo0301 && model != GPT3Dot5Turbo { + err = ErrChatCompletionInvalidModel + return + } request.Stream = true reqBytes, err := json.Marshal(request) if err != nil { diff --git a/go-gpt3/stream.go b/go-gpt3/stream.go index d1bdf48..8f35d2e 100644 --- a/go-gpt3/stream.go +++ b/go-gpt3/stream.go @@ -60,6 +60,43 @@ waitForData: return } +func (stream *CompletionStream) ChatRecv() (response ChatCompletionResponseForStream, err error) { + if stream.isFinished { + err = io.EOF + return + } + + var emptyMessagesCount uint + +waitForData: + line, err := stream.reader.ReadBytes('\n') + if err != nil { + return + } + + var headerData = []byte("data: ") + line = bytes.TrimSpace(line) + if !bytes.HasPrefix(line, headerData) { + emptyMessagesCount++ + if emptyMessagesCount > stream.emptyMessagesLimit { + err = ErrTooManyEmptyStreamMessages + return + } + + goto waitForData + } + + line = bytes.TrimPrefix(line, headerData) + if string(line) == "[DONE]" { + stream.isFinished = true + err = io.EOF + return + } + + err = json.Unmarshal(line, &response) + return +} + func (stream *CompletionStream) Close() { stream.response.Body.Close() }