go-chatgpt
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

131 lines
4.3 KiB

  1. package gogpt
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. )
  7. type FineTuneRequest struct {
  8. TrainingFile string `json:"training_file"`
  9. ValidationFile string `json:"validation_file,omitempty"`
  10. Model string `json:"model,omitempty"`
  11. Epochs int `json:"n_epochs,omitempty"`
  12. BatchSize int `json:"batch_size,omitempty"`
  13. LearningRateMultiplier float32 `json:"learning_rate_multiplier,omitempty"`
  14. PromptLossRate float32 `json:"prompt_loss_rate,omitempty"`
  15. ComputeClassificationMetrics bool `json:"compute_classification_metrics,omitempty"`
  16. ClassificationClasses int `json:"classification_n_classes,omitempty"`
  17. ClassificationPositiveClass string `json:"classification_positive_class,omitempty"`
  18. ClassificationBetas []float32 `json:"classification_betas,omitempty"`
  19. Suffix string `json:"suffix,omitempty"`
  20. }
  21. type FineTune struct {
  22. ID string `json:"id"`
  23. Object string `json:"object"`
  24. Model string `json:"model"`
  25. CreatedAt int64 `json:"created_at"`
  26. FineTuneEventList []FineTuneEvent `json:"events,omitempty"`
  27. FineTunedModel string `json:"fine_tuned_model"`
  28. HyperParams FineTuneHyperParams `json:"hyperparams"`
  29. OrganizationID string `json:"organization_id"`
  30. ResultFiles []File `json:"result_files"`
  31. Status string `json:"status"`
  32. ValidationFiles []File `json:"validation_files"`
  33. TrainingFiles []File `json:"training_files"`
  34. UpdatedAt int64 `json:"updated_at"`
  35. }
  36. type FineTuneEvent struct {
  37. Object string `json:"object"`
  38. CreatedAt int64 `json:"created_at"`
  39. Level string `json:"level"`
  40. Message string `json:"message"`
  41. }
  42. type FineTuneHyperParams struct {
  43. BatchSize int `json:"batch_size"`
  44. LearningRateMultiplier float64 `json:"learning_rate_multiplier"`
  45. Epochs int `json:"n_epochs"`
  46. PromptLossWeight float64 `json:"prompt_loss_weight"`
  47. }
  48. type FineTuneList struct {
  49. Object string `json:"object"`
  50. Data []FineTune `json:"data"`
  51. }
  52. type FineTuneEventList struct {
  53. Object string `json:"object"`
  54. Data []FineTuneEvent `json:"data"`
  55. }
  56. type FineTuneDeleteResponse struct {
  57. ID string `json:"id"`
  58. Object string `json:"object"`
  59. Deleted bool `json:"deleted"`
  60. }
  61. func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
  62. urlSuffix := "/fine-tunes"
  63. req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
  64. if err != nil {
  65. return
  66. }
  67. err = c.sendRequest(req, &response)
  68. return
  69. }
  70. // CancelFineTune cancel a fine-tune job.
  71. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
  72. req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
  73. if err != nil {
  74. return
  75. }
  76. err = c.sendRequest(req, &response)
  77. return
  78. }
  79. func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
  80. req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
  81. if err != nil {
  82. return
  83. }
  84. err = c.sendRequest(req, &response)
  85. return
  86. }
  87. func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
  88. urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
  89. req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
  90. if err != nil {
  91. return
  92. }
  93. err = c.sendRequest(req, &response)
  94. return
  95. }
  96. func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
  97. req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
  98. if err != nil {
  99. return
  100. }
  101. err = c.sendRequest(req, &response)
  102. return
  103. }
  104. func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
  105. req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
  106. if err != nil {
  107. return
  108. }
  109. err = c.sendRequest(req, &response)
  110. return
  111. }