diff --git a/cmd/gpt3-sdk/main.go b/cmd/gpt3-sdk/main.go new file mode 100644 index 0000000..7f600c5 --- /dev/null +++ b/cmd/gpt3-sdk/main.go @@ -0,0 +1,45 @@ +//Package gpt3 provides access to the GPT3 completions Api +//along with new beta APIs for classification, enhanced search, and question answering. +// +//The underlying structure is defined along a request / response interface pattern with a +//singular call to the client. +//The request is initialised as per required parameters an example being: +// +// req := gpt3.CompletionRequest{ +// Prompt: string(query), +// MaxTokens: 60, +// TopP: 1, +// Temperature: 0.3, +// FrequencyPenalty: 0.5, +// PresencePenalty: 0, +// Stop: []string{"You:"}, +// } +// +//The content filter endpoint is used to validate a prompt in order to safeguard responses ushered back to the enduser. +//The request object should always have the following parameters: +// +// reformattedPrompt := fmt.Sprintf("<|endoftext|>[%s]\n--\nLabel:", string(query)) +// +// req := gpt3.ContentFilterRequest{ +// Prompt: reformattedPrompt, +// MaxTokens: 1, +// TopP: 0, +// Temperature: 0, +// Logprobs: 10, +// } +// +// The Response is the same format as that of the Completions request with the following entries: +// +// 0 => text is safe +// 1 => This text is sensitive. This means that the text could be talking about a sensitive topic, something political, +// religious, or talking about a protected class such as race or nationality. +// 2 => This text is unsafe. This means that the text contains profane language, prejudiced or hateful language, +// something that could be NSFW, or text that portrays certain groups/people in a harmful manner. +// +// Code Generation: +// +// Added to the completions API are the codex engines for code generation. +// The Codex model series is a descendant of our base GPT-3 series that’s been trained on both +// natural language and billions of lines of code. + +package gpt3_sdk diff --git a/gpt3.go b/gpt3.go deleted file mode 100644 index 7636de7..0000000 --- a/gpt3.go +++ /dev/null @@ -1,142 +0,0 @@ -//Package gpt3 provides access to the the GPT3 completions Api -//along with new beta APIs for classification, enhanced search, and question answering. -// -//The underlying structure is defined along a request / response interface pattern with a -//singular call to the client. -//The request is initialised as per required parameters an example being: -// -// req := gpt3.CompletionRequest{ -// Prompt: string(query), -// MaxTokens: 60, -// TopP: 1, -// Temperature: 0.3, -// FrequencyPenalty: 0.5, -// PresencePenalty: 0, -// Stop: []string{"You:"}, -// } -// -//The content filter endpoint is used to validate a prompt in order to safeguard responses ushered back to the enduser. -//The request object should always have the following parameters: -// -// reformattedPrompt := fmt.Sprintf("<|endoftext|>[%s]\n--\nLabel:", string(query)) -// -// req := gpt3.ContentFilterRequest{ -// Prompt: reformattedPrompt, -// MaxTokens: 1, -// TopP: 0, -// Temperature: 0, -// Logprobs: 10, -// } -// -// The Response is the same format as that of the Completions request with the following entries: -// -// 0 => text is safe -// 1 => This text is sensitive. This means that the text could be talking about a sensitive topic, something political, -// religious, or talking about a protected class such as race or nationality. -// 2 => This text is unsafe. This means that the text contains profane language, prejudiced or hateful language, -// something that could be NSFW, or text that portrays certain groups/people in a harmful manner. -// -// Code Generation: -// -// Added to the completions API are the codex engines for code generation. -// The Codex model series is a descendant of our base GPT-3 series that’s been trained on both -// natural language and billions of lines of code. - -package gpt3 - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "sync" -) - -var once sync.Once - -const ( - baseUrl = "https://api.openai.com" - defaultVersion = "v1" - apiKeyName = "OPENAI_API_KEY" - apiKeyMissingError = "Api key required. Please ensure env variable %s is set." -) - -type Client interface { - Call(request Request) (*Response, error) - Setup(...string) -} - -type ApiClient struct { - apiKey string - engines []string -} - -func (a ApiClient) Call(request Request) (*Response, error) { - var err error - var req *http.Request - - config := RequestConfig{ - endpointVersion: defaultVersion, - baseUrl: baseUrl, - engine: a.engines[0], - } - - jsonStr, err := json.Marshal(request) - if err != nil { - _ = fmt.Errorf("Request marshalling error: %s\n", err) - return nil, err - } - - requestMethod, requestUrl := request.getRequestMeta(config) - - if requestMethod == getRequest{ - req, err = http.NewRequest(requestMethod, requestUrl, nil) - } else { - req, err = http.NewRequest(requestMethod, requestUrl, bytes.NewBuffer(jsonStr)) - } - - if err != nil { - _ = fmt.Errorf("Http Request creation error: %s\n", err) - return nil, err - } - - authHeader := fmt.Sprintf("Bearer %s", a.apiKey) - req.Header.Set("Authorization", authHeader) - req.Header.Set("Content-Type", "application/json") - client := &http.Client{} - resp, err := client.Do(req) - - if err != nil { - _ = fmt.Errorf("Http request error: %s\n", err) - return nil, err - } - defer resp.Body.Close() - - respObj := request.attachResponse() - data, err := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK{ - _ = fmt.Errorf("Http response user error %d\n", resp.StatusCode) - errObj := ErrorBag{} - json.Unmarshal(data, &errObj) - return nil, errObj - } - - if err := json.Unmarshal(data, respObj); err != nil { - _ = fmt.Errorf("Http response unmarshal error: %s\n", err) - return nil, err - } - return &respObj, nil -} - -func (a *ApiClient) Setup(engines ...string) *ApiClient { - once.Do(func() { - //Perform single action initialisations - }) - - a.apiKey = os.Getenv(apiKeyName) - a.engines = append(a.engines, engines...) - return a -} diff --git a/internal/config.go b/internal/config.go new file mode 100644 index 0000000..48e4eaf --- /dev/null +++ b/internal/config.go @@ -0,0 +1,37 @@ +package internal + +import ( + "log" + "os" +) + +const ( + baseUrl = "OPENAI_API_BASE_URL" + defaultVersion = "OPENAI_API_VERSION" + apiKeyName = "OPENAI_API_KEY" +) + +var Config config + +type config struct { + Gpt3BaseUrl string + Gpt3ApiVersion string + Gpt3ApiKey string +} + +func init() { + Config.Gpt3BaseUrl = os.Getenv(baseUrl) + if Config.Gpt3BaseUrl == "" { + Config.Gpt3BaseUrl = "https://api.openai.com" + } + + Config.Gpt3ApiVersion = os.Getenv(defaultVersion) + if Config.Gpt3ApiVersion == "" { + Config.Gpt3ApiVersion = "v1" + } + + Config.Gpt3ApiKey = os.Getenv(apiKeyName) + if Config.Gpt3ApiKey == "" { + log.Panicf("Api key required. Please ensure env variable %s is set.", apiKeyName) + } +} diff --git a/internal/examples/main.go b/internal/examples/main.go index 90a8e2a..96161ef 100644 --- a/internal/examples/main.go +++ b/internal/examples/main.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "github.com/Godzab/go-gpt3" + "github.com/Godzab/go-gpt3/cmd" "io/ioutil" "log" ) @@ -18,14 +19,13 @@ func main() { //EnginesCall() } - -func answersCall(){ +func answersCall() { examples := make([][]string, 1) - data1 := []string{"What is human life expectancy in the United States?","78 years."} + data1 := []string{"What is human life expectancy in the United States?", "78 years."} examples[0] = data1 req := gpt3.AnswerRequest{ - Documents: []string{"Puppy A is happy.","Puppy B is sad."}, + Documents: []string{"Puppy A is happy.", "Puppy B is sad."}, Question: "which puppy is happy?", SearchModel: gpt3.ADA, Model: gpt3.CURIE, @@ -37,7 +37,7 @@ func answersCall(){ N: 1, } - cl := gpt3.ApiClient{} + cl := cmd.ApiClient{} cl.Setup(gpt3.ADA, gpt3.DAVINCI) response, err := cl.Call(&req) @@ -50,22 +50,22 @@ func answersCall(){ fmt.Println(results) } -func completionCall(){ +func completionCall() { query, err := ioutil.ReadFile("prompts.txt") if err != nil { panic(err) } req := gpt3.CompletionRequest{ - Prompt: string(query), - MaxTokens: 60, - TopP: 1, - Temperature: 0.3, + Prompt: string(query), + MaxTokens: 60, + TopP: 1, + Temperature: 0.3, FrequencyPenalty: 0.5, - PresencePenalty: 0, - Stop: []string{"You:"}, + PresencePenalty: 0, + Stop: []string{"You:"}, } - cl := gpt3.ApiClient{} + cl := cmd.ApiClient{} cl.Setup(gpt3.DAVINCI_INSTRUCT_BETA, gpt3.DAVINCI) response, err := cl.Call(&req) @@ -76,26 +76,26 @@ func completionCall(){ data := *response results, _ := data.(*gpt3.CompletionResponse) - for _,t := range results.Choices{ + for _, t := range results.Choices { fmt.Println(t) } } -func completionCodexCall(){ +func completionCodexCall() { query, err := ioutil.ReadFile("prompts.txt") if err != nil { panic(err) } req := gpt3.CompletionRequest{ - Prompt: string(query), - MaxTokens: 300, - TopP: 1, - Temperature: 0.5, + Prompt: string(query), + MaxTokens: 300, + TopP: 1, + Temperature: 0.5, FrequencyPenalty: 0.5, - PresencePenalty: 0, + PresencePenalty: 0, } - cl := gpt3.ApiClient{} + cl := cmd.ApiClient{} cl.Setup(gpt3.DAVINCI_CODEX) response, err := cl.Call(&req) @@ -106,18 +106,18 @@ func completionCodexCall(){ data := *response results, _ := data.(*gpt3.CompletionResponse) - for _,t := range results.Choices{ + for _, t := range results.Choices { fmt.Println(t) } } -func SearchCall(){ +func SearchCall() { req := gpt3.SearchRequest{ - Documents: []string{"White House","hospital","school","City"}, - Query: "the headmaster", + Documents: []string{"White House", "hospital", "school", "City"}, + Query: "the headmaster", } - cl := gpt3.ApiClient{} + cl := cmd.ApiClient{} cl.Setup(gpt3.DAVINCI, gpt3.DAVINCI_INSTRUCT_BETA) response, err := cl.Call(&req) @@ -128,14 +128,14 @@ func SearchCall(){ data := *response results, _ := data.(*gpt3.SearchResponse) - for _,t := range results.Data{ + for _, t := range results.Data { fmt.Println(t) } } -func EnginesCall(){ +func EnginesCall() { req := gpt3.EnginesRequest{} - cl := gpt3.ApiClient{} + cl := cmd.ApiClient{} cl.Setup(gpt3.DAVINCI) response, err := cl.Call(&req) @@ -146,14 +146,14 @@ func EnginesCall(){ data := *response results, _ := data.(*gpt3.EnginesResponse) - for _,t := range results.Data{ + for _, t := range results.Data { fmt.Println(t) } } -func FilesCall(){ +func FilesCall() { req := gpt3.FilesRequest{} - cl := gpt3.ApiClient{} + cl := cmd.ApiClient{} cl.Setup(gpt3.CURIE) response, err := cl.Call(&req) @@ -164,13 +164,12 @@ func FilesCall(){ data := *response results, _ := data.(*gpt3.FilesResponse) - for _,t := range results.Data{ + for _, t := range results.Data { fmt.Println(t) } } - -func contentFilterCall(){ +func contentFilterCall() { query, err := ioutil.ReadFile("prompts.txt") if err != nil { panic(err) @@ -181,10 +180,10 @@ func contentFilterCall(){ MaxTokens: 1, TopP: 0, Temperature: 0, - Logprobs: 10, + Logprobs: 10, } - cl := gpt3.ApiClient{} + cl := cmd.ApiClient{} cl.Setup(gpt3.DAVINCI_INSTRUCT_BETA, gpt3.DAVINCI) response, err := cl.Call(&req) @@ -195,8 +194,8 @@ func contentFilterCall(){ data := *response results, _ := data.(*gpt3.CompletionResponse) jsn, err := json.MarshalIndent(results, "", " ") - if err != nil{ + if err != nil { log.Fatalln(err) } fmt.Print(string(jsn), "\n") -} \ No newline at end of file +} diff --git a/pkg/client/client.go b/pkg/client/client.go new file mode 100644 index 0000000..23c34a4 --- /dev/null +++ b/pkg/client/client.go @@ -0,0 +1,100 @@ +package client + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/Godzab/go-gpt3/internal" + "github.com/Godzab/go-gpt3/pkg/models" + "io" + "net/http" +) + +type Client interface { + Call(request models.Request) (*models.Response, error) + Setup(...string) +} + +type Gpt3Client struct { + apiKey string + apiBaseUrl string + apiVersion string + engines []string +} + +func (a *Gpt3Client) Call(request models.Request) (*models.Response, error) { + var err error + var req *http.Request + + config := models.RequestConfig{ + EndpointVersion: a.apiVersion, + BaseUrl: a.apiBaseUrl, + Engine: a.engines[0], + } + + jsonStr, err := json.Marshal(request) + + if err != nil { + _ = fmt.Errorf("Request marshalling error: %s\n", err) + return nil, err + } + + req, err = a.instantiateRequestObject(request, config, req, err, jsonStr) + + if err != nil { + _ = fmt.Errorf("Http Request creation error: %s\n", err) + return nil, err + } + + prepareRequestHeaders(a.apiKey, req) + + client := &http.Client{} + resp, err := client.Do(req) + + if err != nil { + _ = fmt.Errorf("Http request error: %s\n", err) + return nil, err + } + + defer resp.Body.Close() + + respObj := request.AttachResponse() + data, err := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + _ = fmt.Errorf("Http response user error %d\n", resp.StatusCode) + errObj := models.ErrorBag{} + json.Unmarshal(data, &errObj) + return nil, errObj + } + + if err := json.Unmarshal(data, respObj); err != nil { + _ = fmt.Errorf("Http response unmarshal error: %s\n", err) + return nil, err + } + return &respObj, nil +} + +func (a *Gpt3Client) instantiateRequestObject(request models.Request, config models.RequestConfig, req *http.Request, err error, jsonStr []byte) (*http.Request, error) { + requestMethod, requestUrl := request.GetRequestMeta(config) + + if requestMethod == "GET" { + req, err = http.NewRequest(requestMethod, requestUrl, nil) + } else { + req, err = http.NewRequest(requestMethod, requestUrl, bytes.NewBuffer(jsonStr)) + } + return req, err +} + +func prepareRequestHeaders(apiKey string, req *http.Request) { + authHeader := fmt.Sprintf("Bearer %s", apiKey) + req.Header.Set("Authorization", authHeader) + req.Header.Set("Content-Type", "application/json") +} + +func (a *Gpt3Client) Setup(engines ...string) { + a.apiKey = internal.Config.Gpt3ApiKey + a.apiVersion = internal.Config.Gpt3ApiVersion + a.apiBaseUrl = internal.Config.Gpt3ApiKey + a.engines = append(a.engines, engines...) +} diff --git a/models.go b/pkg/models/models.go similarity index 92% rename from models.go rename to pkg/models/models.go index e0a5b8b..f5b7a0e 100644 --- a/models.go +++ b/pkg/models/models.go @@ -1,4 +1,4 @@ -package gpt3 +package models import ( "fmt" @@ -6,7 +6,7 @@ import ( ) const ( - DAVINCI = "davinci" + DAVINCI = "text-davinci-002" CURIE = "curie" BABBAGE = "babbage" BABBAGE_INSTRUCT_BETA = "text-babbage-001" @@ -30,7 +30,6 @@ const ( CUSHMAN_CODEX = "cushman-codex" ) -// const ( getRequest = "GET" postRequest = "POST" @@ -43,12 +42,12 @@ const ( ) type RequestConfig struct { - endpointVersion, baseUrl, engine string + EndpointVersion, BaseUrl, Engine string } type Request interface { - attachResponse() Response - getRequestMeta(config RequestConfig) (string, string) + AttachResponse() Response + GetRequestMeta(config RequestConfig) (string, string) } type Response interface { @@ -121,7 +120,7 @@ func (r *FilesResponse) GetBody() Response { } func (r *FilesRequest) getRequestMeta(config RequestConfig) (string, string) { - return getRequest, fmt.Sprintf("%s/%s/files", config.baseUrl, config.endpointVersion) + return getRequest, fmt.Sprintf("%s/%s/files", config.BaseUrl, config.EndpointVersion) } // File models @@ -144,7 +143,7 @@ func (r *FileResponse) GetBody() Response { } func (r *FileRequest) getRequestMeta(config RequestConfig) (string, string) { - return postRequest, fmt.Sprintf("%s/%s/files", config.baseUrl, config.endpointVersion) + return postRequest, fmt.Sprintf("%s/%s/files", config.BaseUrl, config.EndpointVersion) } // CompletionRequest Completion model structures @@ -178,14 +177,14 @@ func (r *CompletionRequest) attachResponse() Response { } func (r *CompletionRequest) getRequestMeta(config RequestConfig) (string, string) { - return postRequest, fmt.Sprintf("%s/%s/engines/%s/completions", config.baseUrl, config.endpointVersion, config.engine) + return postRequest, fmt.Sprintf("%s/%s/engines/%s/completions", config.BaseUrl, config.EndpointVersion, config.Engine) } func (r *CompletionResponse) GetBody() Response { return r } -//ContentFilterRequest Content filter model structures +// ContentFilterRequest Content filter model structures type ContentFilterRequest struct { Prompt string `json:"prompt"` MaxTokens int `json:"max_tokens"` @@ -203,7 +202,7 @@ func (r *ContentFilterRequest) attachResponse() Response { } func (r *ContentFilterRequest) getRequestMeta(config RequestConfig) (string, string) { - return postRequest, fmt.Sprintf("%s/%s/engines/content-filter-alpha-c4/completions", config.baseUrl, config.endpointVersion) + return postRequest, fmt.Sprintf("%s/%s/engines/content-filter-alpha-c4/completions", config.BaseUrl, config.EndpointVersion) } // SearchRequest Search Model structures @@ -227,7 +226,7 @@ func (r *SearchRequest) attachResponse() Response { } func (r *SearchRequest) getRequestMeta(config RequestConfig) (string, string) { - return postRequest, fmt.Sprintf("%s/%s/engines/%s/search", config.baseUrl, config.endpointVersion, config.engine) + return postRequest, fmt.Sprintf("%s/%s/engines/%s/search", config.BaseUrl, config.EndpointVersion, config.Engine) } func (r *SearchResponse) GetBody() Response { @@ -251,7 +250,7 @@ func (r *EnginesRequest) attachResponse() Response { } func (r *EnginesRequest) getRequestMeta(config RequestConfig) (string, string) { - return getRequest, fmt.Sprintf("%s/%s/engines", config.baseUrl, config.endpointVersion) + return getRequest, fmt.Sprintf("%s/%s/engines", config.BaseUrl, config.EndpointVersion) } // ClassificationRequest Classification Model structures @@ -286,7 +285,7 @@ func (r *ClassificationRequest) attachResponse() Response { } func (r *ClassificationRequest) getRequestMeta(config RequestConfig) (string, string) { - return postRequest, fmt.Sprintf("%s/%s/classifications", config.baseUrl, config.endpointVersion) + return postRequest, fmt.Sprintf("%s/%s/classifications", config.BaseUrl, config.EndpointVersion) } func (r *ClassificationResponse) GetBody() Response { @@ -329,14 +328,14 @@ func (r *AnswerRequest) attachResponse() Response { } func (r *AnswerRequest) getRequestMeta(config RequestConfig) (string, string) { - return postRequest, fmt.Sprintf("%s/%s/answers", config.baseUrl, config.endpointVersion) + return postRequest, fmt.Sprintf("%s/%s/answers", config.BaseUrl, config.EndpointVersion) } func (r *AnswerResponse) GetBody() Response { return r } -//GptErrorResponse Error handling for client calls +// GptErrorResponse Error handling for client calls type GptErrorResponse struct { Code interface{} `json:"code"` Message string `json:"message"` diff --git a/pkg/models/models_test.go b/pkg/models/models_test.go new file mode 100644 index 0000000..b9b6625 --- /dev/null +++ b/pkg/models/models_test.go @@ -0,0 +1,899 @@ +package models + +import ( + "os" + "reflect" + "testing" +) + +func TestAnswerRequest_attachResponse(t *testing.T) { + type fields struct { + Documents []string + Question string + SearchModel string + Model string + ExamplesContext string + Examples [][]string + MaxTokens int + Stop []string + File string + MaxRerank int32 + Temperature float32 + Logprobs interface{} + N int + LogitBias map[string]int8 + ReturnPrompt bool + ReturnMetadata bool + Expand []string + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &AnswerRequest{ + Documents: tt.fields.Documents, + Question: tt.fields.Question, + SearchModel: tt.fields.SearchModel, + Model: tt.fields.Model, + ExamplesContext: tt.fields.ExamplesContext, + Examples: tt.fields.Examples, + MaxTokens: tt.fields.MaxTokens, + Stop: tt.fields.Stop, + File: tt.fields.File, + MaxRerank: tt.fields.MaxRerank, + Temperature: tt.fields.Temperature, + Logprobs: tt.fields.Logprobs, + N: tt.fields.N, + LogitBias: tt.fields.LogitBias, + ReturnPrompt: tt.fields.ReturnPrompt, + ReturnMetadata: tt.fields.ReturnMetadata, + Expand: tt.fields.Expand, + } + if got := r.attachResponse(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("attachResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAnswerRequest_getRequestMeta(t *testing.T) { + type fields struct { + Documents []string + Question string + SearchModel string + Model string + ExamplesContext string + Examples [][]string + MaxTokens int + Stop []string + File string + MaxRerank int32 + Temperature float32 + Logprobs interface{} + N int + LogitBias map[string]int8 + ReturnPrompt bool + ReturnMetadata bool + Expand []string + } + type args struct { + config RequestConfig + } + tests := []struct { + name string + fields fields + args args + want string + want1 string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &AnswerRequest{ + Documents: tt.fields.Documents, + Question: tt.fields.Question, + SearchModel: tt.fields.SearchModel, + Model: tt.fields.Model, + ExamplesContext: tt.fields.ExamplesContext, + Examples: tt.fields.Examples, + MaxTokens: tt.fields.MaxTokens, + Stop: tt.fields.Stop, + File: tt.fields.File, + MaxRerank: tt.fields.MaxRerank, + Temperature: tt.fields.Temperature, + Logprobs: tt.fields.Logprobs, + N: tt.fields.N, + LogitBias: tt.fields.LogitBias, + ReturnPrompt: tt.fields.ReturnPrompt, + ReturnMetadata: tt.fields.ReturnMetadata, + Expand: tt.fields.Expand, + } + got, got1 := r.getRequestMeta(tt.args.config) + if got != tt.want { + t.Errorf("getRequestMeta() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("getRequestMeta() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestAnswerResponse_GetBody(t *testing.T) { + type fields struct { + Answers []string + Completion CompletionResponse + Model string + Object string + SearchModel string + SelectedDocuments []Document + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &AnswerResponse{ + Answers: tt.fields.Answers, + Completion: tt.fields.Completion, + Model: tt.fields.Model, + Object: tt.fields.Object, + SearchModel: tt.fields.SearchModel, + SelectedDocuments: tt.fields.SelectedDocuments, + } + if got := r.GetBody(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetBody() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClassificationRequest_attachResponse(t *testing.T) { + type fields struct { + Examples [][]string + Labels []string + Query string + File string + SearchModel string + Model string + Temperature float32 + Logprobs interface{} + MaxExamples int32 + LogitBias map[string]int8 + ReturnPrompt bool + ReturnMetadata bool + Expand []string + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &ClassificationRequest{ + Examples: tt.fields.Examples, + Labels: tt.fields.Labels, + Query: tt.fields.Query, + File: tt.fields.File, + SearchModel: tt.fields.SearchModel, + Model: tt.fields.Model, + Temperature: tt.fields.Temperature, + Logprobs: tt.fields.Logprobs, + MaxExamples: tt.fields.MaxExamples, + LogitBias: tt.fields.LogitBias, + ReturnPrompt: tt.fields.ReturnPrompt, + ReturnMetadata: tt.fields.ReturnMetadata, + Expand: tt.fields.Expand, + } + if got := r.attachResponse(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("attachResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClassificationRequest_getRequestMeta(t *testing.T) { + type fields struct { + Examples [][]string + Labels []string + Query string + File string + SearchModel string + Model string + Temperature float32 + Logprobs interface{} + MaxExamples int32 + LogitBias map[string]int8 + ReturnPrompt bool + ReturnMetadata bool + Expand []string + } + type args struct { + config RequestConfig + } + tests := []struct { + name string + fields fields + args args + want string + want1 string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &ClassificationRequest{ + Examples: tt.fields.Examples, + Labels: tt.fields.Labels, + Query: tt.fields.Query, + File: tt.fields.File, + SearchModel: tt.fields.SearchModel, + Model: tt.fields.Model, + Temperature: tt.fields.Temperature, + Logprobs: tt.fields.Logprobs, + MaxExamples: tt.fields.MaxExamples, + LogitBias: tt.fields.LogitBias, + ReturnPrompt: tt.fields.ReturnPrompt, + ReturnMetadata: tt.fields.ReturnMetadata, + Expand: tt.fields.Expand, + } + got, got1 := r.getRequestMeta(tt.args.config) + if got != tt.want { + t.Errorf("getRequestMeta() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("getRequestMeta() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestClassificationResponse_GetBody(t *testing.T) { + type fields struct { + Completion string + Label string + Model string + Object string + SearchModel string + SelectedExamples []ClassificationExamples + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &ClassificationResponse{ + Completion: tt.fields.Completion, + Label: tt.fields.Label, + Model: tt.fields.Model, + Object: tt.fields.Object, + SearchModel: tt.fields.SearchModel, + SelectedExamples: tt.fields.SelectedExamples, + } + if got := r.GetBody(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetBody() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCompletionRequest_attachResponse(t *testing.T) { + type fields struct { + Prompt string + MaxTokens int + Temperature float32 + TopP float32 + N int + Stream bool + Logprobs int + Stop []string + Echo bool + PresencePenalty float32 + FrequencyPenalty float32 + BestOf float32 + LogitBias map[string]int8 + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &CompletionRequest{ + Prompt: tt.fields.Prompt, + MaxTokens: tt.fields.MaxTokens, + Temperature: tt.fields.Temperature, + TopP: tt.fields.TopP, + N: tt.fields.N, + Stream: tt.fields.Stream, + Logprobs: tt.fields.Logprobs, + Stop: tt.fields.Stop, + Echo: tt.fields.Echo, + PresencePenalty: tt.fields.PresencePenalty, + FrequencyPenalty: tt.fields.FrequencyPenalty, + BestOf: tt.fields.BestOf, + LogitBias: tt.fields.LogitBias, + } + if got := r.attachResponse(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("attachResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCompletionRequest_getRequestMeta(t *testing.T) { + type fields struct { + Prompt string + MaxTokens int + Temperature float32 + TopP float32 + N int + Stream bool + Logprobs int + Stop []string + Echo bool + PresencePenalty float32 + FrequencyPenalty float32 + BestOf float32 + LogitBias map[string]int8 + } + type args struct { + config RequestConfig + } + tests := []struct { + name string + fields fields + args args + want string + want1 string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &CompletionRequest{ + Prompt: tt.fields.Prompt, + MaxTokens: tt.fields.MaxTokens, + Temperature: tt.fields.Temperature, + TopP: tt.fields.TopP, + N: tt.fields.N, + Stream: tt.fields.Stream, + Logprobs: tt.fields.Logprobs, + Stop: tt.fields.Stop, + Echo: tt.fields.Echo, + PresencePenalty: tt.fields.PresencePenalty, + FrequencyPenalty: tt.fields.FrequencyPenalty, + BestOf: tt.fields.BestOf, + LogitBias: tt.fields.LogitBias, + } + got, got1 := r.getRequestMeta(tt.args.config) + if got != tt.want { + t.Errorf("getRequestMeta() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("getRequestMeta() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestCompletionResponse_GetBody(t *testing.T) { + type fields struct { + ID string + Object string + Created int + Model string + Choices []Choices + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &CompletionResponse{ + ID: tt.fields.ID, + Object: tt.fields.Object, + Created: tt.fields.Created, + Model: tt.fields.Model, + Choices: tt.fields.Choices, + } + if got := r.GetBody(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetBody() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestContentFilterRequest_attachResponse(t *testing.T) { + type fields struct { + Prompt string + MaxTokens int + Temperature float32 + TopP float32 + N int + Logprobs int + PresencePenalty float32 + FrequencyPenalty float32 + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &ContentFilterRequest{ + Prompt: tt.fields.Prompt, + MaxTokens: tt.fields.MaxTokens, + Temperature: tt.fields.Temperature, + TopP: tt.fields.TopP, + N: tt.fields.N, + Logprobs: tt.fields.Logprobs, + PresencePenalty: tt.fields.PresencePenalty, + FrequencyPenalty: tt.fields.FrequencyPenalty, + } + if got := r.attachResponse(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("attachResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestContentFilterRequest_getRequestMeta(t *testing.T) { + type fields struct { + Prompt string + MaxTokens int + Temperature float32 + TopP float32 + N int + Logprobs int + PresencePenalty float32 + FrequencyPenalty float32 + } + type args struct { + config RequestConfig + } + tests := []struct { + name string + fields fields + args args + want string + want1 string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &ContentFilterRequest{ + Prompt: tt.fields.Prompt, + MaxTokens: tt.fields.MaxTokens, + Temperature: tt.fields.Temperature, + TopP: tt.fields.TopP, + N: tt.fields.N, + Logprobs: tt.fields.Logprobs, + PresencePenalty: tt.fields.PresencePenalty, + FrequencyPenalty: tt.fields.FrequencyPenalty, + } + got, got1 := r.getRequestMeta(tt.args.config) + if got != tt.want { + t.Errorf("getRequestMeta() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("getRequestMeta() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestEnginesRequest_attachResponse(t *testing.T) { + tests := []struct { + name string + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &EnginesRequest{} + if got := r.attachResponse(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("attachResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEnginesRequest_getRequestMeta(t *testing.T) { + type args struct { + config RequestConfig + } + tests := []struct { + name string + args args + want string + want1 string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &EnginesRequest{} + got, got1 := r.getRequestMeta(tt.args.config) + if got != tt.want { + t.Errorf("getRequestMeta() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("getRequestMeta() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestEnginesResponse_GetBody(t *testing.T) { + type fields struct { + Data []interface{} + Object string + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := EnginesResponse{ + Data: tt.fields.Data, + Object: tt.fields.Object, + } + if got := e.GetBody(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetBody() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestErrorBag_Error(t *testing.T) { + type fields struct { + Err GptErrorResponse + } + tests := []struct { + name string + fields fields + want string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := ErrorBag{ + Err: tt.fields.Err, + } + if got := e.Error(); got != tt.want { + t.Errorf("Error() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestErrorBag_Temporary(t *testing.T) { + type fields struct { + Err GptErrorResponse + } + tests := []struct { + name string + fields fields + want bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := ErrorBag{ + Err: tt.fields.Err, + } + if got := e.Temporary(); got != tt.want { + t.Errorf("Temporary() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestErrorBag_Timeout(t *testing.T) { + type fields struct { + Err GptErrorResponse + } + tests := []struct { + name string + fields fields + want bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := ErrorBag{ + Err: tt.fields.Err, + } + if got := e.Timeout(); got != tt.want { + t.Errorf("Timeout() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFileRequest_attachResponse(t *testing.T) { + type fields struct { + File os.File + Purpose string + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &FileRequest{ + File: tt.fields.File, + Purpose: tt.fields.Purpose, + } + if got := r.attachResponse(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("attachResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFileRequest_getRequestMeta(t *testing.T) { + type fields struct { + File os.File + Purpose string + } + type args struct { + config RequestConfig + } + tests := []struct { + name string + fields fields + args args + want string + want1 string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &FileRequest{ + File: tt.fields.File, + Purpose: tt.fields.Purpose, + } + got, got1 := r.getRequestMeta(tt.args.config) + if got != tt.want { + t.Errorf("getRequestMeta() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("getRequestMeta() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestFileResponse_GetBody(t *testing.T) { + type fields struct { + File File + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &FileResponse{ + File: tt.fields.File, + } + if got := r.GetBody(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetBody() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFilesRequest_attachResponse(t *testing.T) { + tests := []struct { + name string + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &FilesRequest{} + if got := r.attachResponse(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("attachResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFilesRequest_getRequestMeta(t *testing.T) { + type args struct { + config RequestConfig + } + tests := []struct { + name string + args args + want string + want1 string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &FilesRequest{} + got, got1 := r.getRequestMeta(tt.args.config) + if got != tt.want { + t.Errorf("getRequestMeta() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("getRequestMeta() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestFilesResponse_GetBody(t *testing.T) { + type fields struct { + Data []File + Object string + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &FilesResponse{ + Data: tt.fields.Data, + Object: tt.fields.Object, + } + if got := r.GetBody(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetBody() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSearchRequest_attachResponse(t *testing.T) { + type fields struct { + target string + Documents []string + Query string + File string + ReturnMetadata bool + MaxRerank int32 + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &SearchRequest{ + target: tt.fields.target, + Documents: tt.fields.Documents, + Query: tt.fields.Query, + File: tt.fields.File, + ReturnMetadata: tt.fields.ReturnMetadata, + MaxRerank: tt.fields.MaxRerank, + } + if got := r.attachResponse(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("attachResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSearchRequest_getRequestMeta(t *testing.T) { + type fields struct { + target string + Documents []string + Query string + File string + ReturnMetadata bool + MaxRerank int32 + } + type args struct { + config RequestConfig + } + tests := []struct { + name string + fields fields + args args + want string + want1 string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &SearchRequest{ + target: tt.fields.target, + Documents: tt.fields.Documents, + Query: tt.fields.Query, + File: tt.fields.File, + ReturnMetadata: tt.fields.ReturnMetadata, + MaxRerank: tt.fields.MaxRerank, + } + got, got1 := r.getRequestMeta(tt.args.config) + if got != tt.want { + t.Errorf("getRequestMeta() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("getRequestMeta() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestSearchResponse_GetBody(t *testing.T) { + type fields struct { + Data []SearchData + Object string + } + tests := []struct { + name string + fields fields + want Response + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &SearchResponse{ + Data: tt.fields.Data, + Object: tt.fields.Object, + } + if got := r.GetBody(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetBody() = %v, want %v", got, tt.want) + } + }) + } +}