feat: update OpenAI client to support new model behaviors
- Import `regexp` and `github.com/sashabaranov/go-openai` packages
- Add `CompletionTokensDetails` field to the `Usage` struct
- Change chat message role from `System` to `Assistant` in OpenAI client functions
- Adjust token settings for models matching `o1-(mini|preview)`: clear `MaxTokens` and pass the budget via `MaxCompletionsTokens` instead
- Modify `GetSummaryPrefix` to use a plain chat completion for `o1` models and the function-call path otherwise (see the sketch below)

Signed-off-by: appleboy <[email protected]>
appleboy committed Oct 2, 2024
1 parent db799fb commit 9a4481b
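
The core of the change, as a minimal self-contained sketch: `o1`-series models expect the token budget in a max-completion-tokens field rather than `MaxTokens`, so the client matches the model name and switches fields. The `chatRequest` struct and `buildRequest` helper below are hypothetical stand-ins, not the go-openai API; only the regex and the field names mirror the diff further down.

package main

import (
	"fmt"
	"regexp"
)

// Same pattern the commit introduces for detecting the o1 model family.
var checkO1Serial = regexp.MustCompile(`o1-(mini|preview)`)

// chatRequest is a hypothetical stand-in for the go-openai request type;
// the field names simply mirror the diff.
type chatRequest struct {
	Model                string
	MaxTokens            int
	MaxCompletionsTokens int
}

// buildRequest mirrors the gating logic added in this commit: for o1 models,
// zero out MaxTokens and carry the budget in MaxCompletionsTokens instead.
func buildRequest(model string, budget int) chatRequest {
	req := chatRequest{Model: model, MaxTokens: budget}
	if checkO1Serial.MatchString(model) {
		req.MaxTokens = 0
		req.MaxCompletionsTokens = budget
	}
	return req
}

func main() {
	fmt.Printf("%+v\n", buildRequest("o1-mini", 1024)) // budget moves to MaxCompletionsTokens
	fmt.Printf("%+v\n", buildRequest("gpt-4o", 1024))  // unchanged behavior
}

The same check gates `GetSummaryPrefix`, which falls back to a plain chat completion for these models, presumably because `o1-mini` and `o1-preview` did not support function calling at the time.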
Showing 2 changed files with 41 additions and 14 deletions.
9 changes: 6 additions & 3 deletions core/openai.go
@@ -2,12 +2,15 @@ package core
 
 import (
 	"context"
+
+	"github.com/sashabaranov/go-openai"
 )
 
 type Usage struct {
-	PromptTokens     int
-	CompletionTokens int
-	TotalTokens      int
+	PromptTokens            int
+	CompletionTokens        int
+	TotalTokens             int
+	CompletionTokensDetails *openai.CompletionTokensDetails
 }
 
 type Response struct {
46 changes: 35 additions & 11 deletions openai/openai.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net/http"
 	"net/url"
+	"regexp"
 
 	"github.com/appleboy/CodeGPT/core"

@@ -54,25 +55,36 @@ func (c *Client) Completion(ctx context.Context, content string) (*core.Response
 	return &core.Response{
 		Content: resp.Content,
 		Usage: core.Usage{
-			PromptTokens:     resp.Usage.PromptTokens,
-			CompletionTokens: resp.Usage.CompletionTokens,
-			TotalTokens:      resp.Usage.TotalTokens,
+			PromptTokens:            resp.Usage.PromptTokens,
+			CompletionTokens:        resp.Usage.CompletionTokens,
+			TotalTokens:             resp.Usage.TotalTokens,
+			CompletionTokensDetails: resp.Usage.CompletionTokensDetails,
 		},
 	}, nil
 }

 // GetSummaryPrefix is an API call to get a summary prefix using function call.
 func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Response, error) {
-	resp, err := c.CreateFunctionCall(ctx, content, SummaryPrefixFunc)
-	if err != nil || len(resp.Choices) != 1 {
-		return nil, err
+	var resp openai.ChatCompletionResponse
+	var err error
+	if checkO1Serial.MatchString(c.model) {
+		resp, err = c.CreateChatCompletion(ctx, content)
+		if err != nil || len(resp.Choices) != 1 {
+			return nil, err
+		}
+	} else {
+		resp, err = c.CreateFunctionCall(ctx, content, SummaryPrefixFunc)
+		if err != nil || len(resp.Choices) != 1 {
+			return nil, err
+		}
 	}
 
 	msg := resp.Choices[0].Message
 	usage := core.Usage{
-		PromptTokens:     resp.Usage.PromptTokens,
-		CompletionTokens: resp.Usage.CompletionTokens,
-		TotalTokens:      resp.Usage.TotalTokens,
+		PromptTokens:            resp.Usage.PromptTokens,
+		CompletionTokens:        resp.Usage.CompletionTokens,
+		TotalTokens:             resp.Usage.TotalTokens,
+		CompletionTokensDetails: resp.Usage.CompletionTokensDetails,
 	}
 	if len(msg.ToolCalls) == 0 {
 		return &core.Response{
@@ -88,6 +100,8 @@ func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Re
 	}, nil
 }
 
+var checkO1Serial = regexp.MustCompile(`o1-(mini|preview)`)
+
 // CreateChatCompletion is an API call to create a function call for a chat message.
 func (c *Client) CreateFunctionCall(
 	ctx context.Context,
@@ -108,7 +122,7 @@ func (c *Client) CreateFunctionCall(
 		PresencePenalty: c.presencePenalty,
 		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    openai.ChatMessageRoleSystem,
+				Role:    openai.ChatMessageRoleAssistant,
 				Content: "You are a helpful assistant.",
 			},
 			{
@@ -125,6 +139,11 @@ func (c *Client) CreateFunctionCall(
 		},
 	}
 
+	if checkO1Serial.MatchString(c.model) {
+		req.MaxTokens = 0
+		req.MaxCompletionsTokens = c.maxTokens
+	}
+
 	return c.client.CreateChatCompletion(ctx, req)
 }

@@ -142,7 +161,7 @@ func (c *Client) CreateChatCompletion(
 		PresencePenalty: c.presencePenalty,
 		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    openai.ChatMessageRoleSystem,
+				Role:    openai.ChatMessageRoleAssistant,
 				Content: "You are a helpful assistant.",
 			},
 			{
@@ -152,6 +171,11 @@
 		},
 	}
 
+	if checkO1Serial.MatchString(c.model) {
+		req.MaxTokens = 0
+		req.MaxCompletionsTokens = c.maxTokens
+	}
+
 	return c.client.CreateChatCompletion(ctx, req)
 }

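One practical note on the new `CompletionTokensDetails` field: it is a pointer and the API may omit the breakdown, so callers should nil-check before reading it. A minimal sketch with hypothetical stand-in types (the real ones are `core.Usage` and `openai.CompletionTokensDetails`; the `ReasoningTokens` field name is an assumption about the upstream struct):

package main

import "fmt"

// Hypothetical stand-ins for core.Usage and openai.CompletionTokensDetails.
type completionTokensDetails struct {
	ReasoningTokens int // assumed field; o1-style models report hidden reasoning tokens here
}

type usage struct {
	TotalTokens             int
	CompletionTokensDetails *completionTokensDetails // nil when the API omits the breakdown
}

// describeUsage reads the optional details defensively.
func describeUsage(u usage) string {
	if u.CompletionTokensDetails == nil {
		return fmt.Sprintf("total=%d", u.TotalTokens)
	}
	return fmt.Sprintf("total=%d reasoning=%d", u.TotalTokens, u.CompletionTokensDetails.ReasoningTokens)
}

func main() {
	fmt.Println(describeUsage(usage{TotalTokens: 900}))
	fmt.Println(describeUsage(usage{TotalTokens: 900, CompletionTokensDetails: &completionTokensDetails{ReasoningTokens: 300}}))
}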
