misc: update ai-worker dependency consumption
kyriediculous committed Sep 30, 2024
1 parent 7b658d7 · commit a17f4ba
Showing 10 changed files with 33 additions and 36 deletions.
4 changes: 2 additions & 2 deletions cmd/livepeer/starter/starter.go
@@ -1320,11 +1320,11 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
}
n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice)

case "llm-generate":
case "llm":
_, ok := capabilityConstraints[core.Capability_LLM]
if !ok {
aiCaps = append(aiCaps, core.Capability_LLM)
capabilityConstraints[core.Capability_LLM] = &core.PerCapabilityConstraints{
capabilityConstraints[core.Capability_LLM] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}
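
The renamed "llm" case keeps the guarded-registration pattern the other pipelines use: append the capability on first sight, then attach per-model constraints to the shared entry. A self-contained sketch of that pattern with simplified stand-in types; only the Models field is visible in this diff, and the Warm field is an assumption for illustration:

package main

import "fmt"

// Stand-ins for the core types in the hunk above; only Models appears in the
// diff, the Warm field is assumed.
type ModelConstraint struct{ Warm bool }
type CapabilityConstraints struct{ Models map[string]*ModelConstraint }
type Capability int

const Capability_LLM Capability = 31 // arbitrary value for the sketch

func main() {
	var aiCaps []Capability
	constraints := map[Capability]*CapabilityConstraints{}

	for _, modelID := range []string{"meta-llama/llama-3.1-8B-Instruct"} {
		// Register the capability only the first time the "llm" pipeline
		// appears in the model config, then reuse the constraints entry.
		if _, ok := constraints[Capability_LLM]; !ok {
			aiCaps = append(aiCaps, Capability_LLM)
			constraints[Capability_LLM] = &CapabilityConstraints{
				Models: make(map[string]*ModelConstraint),
			}
		}
		constraints[Capability_LLM].Models[modelID] = &ModelConstraint{Warm: true}
	}
	fmt.Println(len(aiCaps), len(constraints[Capability_LLM].Models)) // 1 1
}
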
2 changes: 1 addition & 1 deletion core/ai.go
@@ -22,7 +22,7 @@ type AI interface {
ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
LlmGenerate(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
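
The interface method rename ripples through every implementer. A minimal stub satisfying the new signature, mirroring the test doubles this commit adds in server/rpc_test.go; the ai-worker import path is assumed:

package core

import (
	"context"

	"github.com/livepeer/ai-worker/worker" // import path assumed
)

// stubAI satisfies the renamed LLM method from the AI interface above.
type stubAI struct{}

func (s stubAI) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
	// Non-streaming callers receive a *worker.LLMResponse; streaming callers
	// receive a channel instead (see the orchestrator comment further down).
	return &worker.LLMResponse{}, nil
}
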
4 changes: 2 additions & 2 deletions core/capabilities.go
@@ -78,7 +78,7 @@ const (
Capability_ImageToVideo
Capability_Upscale
Capability_AudioToText
Capability_LlmGenerate
Capability_LLM
Capability_SegmentAnything2
)

@@ -116,7 +116,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_ImageToVideo: "Image to video",
Capability_Upscale: "Upscale",
Capability_AudioToText: "Audio to text",
Capability_LlmGenerate: "LLM Generate",
Capability_LLM: "Large Language Model",
Capability_SegmentAnything2: "Segment anything 2",
}

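
Both the capability constant and its display name change here. A hypothetical helper around the lookup table, not part of this commit, shows the intended use inside the core package:

// Hypothetical helper (not in this commit) resolving a capability to its
// display name; the fallback string is our own choice.
func capabilityName(c Capability) string {
	if name, ok := CapabilityNameLookup[c]; ok {
		return name // capabilityName(Capability_LLM) == "Large Language Model"
	}
	return "unknown capability"
}
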
13 changes: 2 additions & 11 deletions core/orchestrator.go
@@ -130,15 +130,10 @@ func (orch *orchestrator) AudioToText(ctx context.Context, req worker.GenAudioTo
return orch.node.AudioToText(ctx, req)
}

<<<<<<< HEAD
// Return type is LlmResponse, but a stream is available as well as chan(string)
func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
return orch.node.llmGenerate(ctx, req)
=======
// Return type is LLMResponse, but a stream is available as well as chan(string)
func (orch *orchestrator) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
return orch.node.LLM(ctx, req)
>>>>>>> 9b6dc285 (misc: update ai-worker dependency)
return orch.node.AIWorker.LLM(ctx, req)

}

func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
@@ -1073,10 +1068,6 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVi
return &worker.ImageResponse{Images: videos}, nil
}

func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
return n.AIWorker.LlmGenerate(ctx, req)
}

func (rtm *RemoteTranscoderManager) transcoderResults(tcID int64, res *RemoteTranscoderResult) {
remoteChan, err := rtm.getTaskChan(tcID)
if err != nil {
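
The first hunk above deletes merge-conflict markers that had been committed earlier and settles on delegating straight to orch.node.AIWorker.LLM; the second drops the now-redundant private llmGenerate wrapper. Because the method returns interface{}, callers type-switch on the result. A hedged sketch of that consumption, against the server.Orchestrator interface shown later in this diff, assuming a chan string stream per the "chan(string)" comment and a Response field on worker.LLMResponse (both assumptions):

func printLLMResult(ctx context.Context, orch Orchestrator, req worker.GenLLMFormdataRequestBody) error {
	res, err := orch.LLM(ctx, req)
	if err != nil {
		return err
	}
	switch v := res.(type) {
	case *worker.LLMResponse:
		fmt.Println(v.Response) // complete non-streaming answer; field name assumed
	case chan string:
		for chunk := range v {
			fmt.Print(chunk) // stream chunks as they arrive; element type assumed
		}
	}
	return nil
}
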
4 changes: 2 additions & 2 deletions server/ai_http.go
@@ -44,7 +44,7 @@ func startAIServer(lp lphttp) error {
lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo()))
lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
lp.transRPC.Handle("/llm-generate", oapiReqValidator(lp.LlmGenerate()))
lp.transRPC.Handle("/llm", oapiReqValidator(lp.LLM()))
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2()))

return nil
@@ -330,7 +330,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
}
outPixels *= 1000 // Convert to milliseconds
case worker.GenLLMFormdataRequestBody:
pipeline = "llm-generate"
pipeline = "llm"
cap = core.Capability_LLM
modelID = *v.ModelId
submitFn = func(ctx context.Context) (interface{}, error) {
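
Client-side, only the path changes. A hedged sketch of posting a request to the renamed route; the form field names and the listen address are assumptions, since only the /llm path and the multipart encoding come from this commit:

package main

import (
	"bytes"
	"fmt"
	"mime/multipart"
	"net/http"
)

func main() {
	var buf bytes.Buffer
	mw := multipart.NewWriter(&buf)
	mw.WriteField("prompt", "Say hello")                          // field name assumed
	mw.WriteField("model_id", "meta-llama/llama-3.1-8B-Instruct") // default model from ai_process.go
	mw.Close()

	resp, err := http.Post("http://localhost:8935/llm", mw.FormDataContentType(), &buf)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
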
4 changes: 2 additions & 2 deletions server/ai_mediaserver.go
@@ -70,7 +70,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo()))
ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText()))
ls.HTTPMux.Handle("/llm-generate", oapiReqValidator(ls.LlmGenerate()))
ls.HTTPMux.Handle("/llm", oapiReqValidator(ls.LLM()))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2()))

return nil
@@ -425,7 +425,7 @@ func (ls *LivepeerServer) LLM() http.Handler {
}

start := time.Now()
resp, err := processLlmGenerate(ctx, params, req)
resp, err := processLLM(ctx, params, req)
if err != nil {
var e *ServiceUnavailableError
if errors.As(err, &e) {
28 changes: 14 additions & 14 deletions server/ai_process.go
@@ -31,7 +31,7 @@ const defaultImageToImageModelID = "stabilityai/sdxl-turbo"
const defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-xt"
const defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler"
const defaultAudioToTextModelID = "openai/whisper-large-v3"
const defaultLlmGenerateModelID = "meta-llama/llama-3.1-8B-Instruct"
const defaultLLMModelID = "meta-llama/llama-3.1-8B-Instruct"
const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large"

type ServiceUnavailableError struct {
@@ -821,7 +821,7 @@ func CalculateLLMLatencyScore(took time.Duration, tokensUsed int) float64 {
return took.Seconds() / float64(tokensUsed)
}

func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
resp, err := processAIRequest(ctx, params, req)
if err != nil {
return nil, err
@@ -843,20 +843,20 @@ func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.
return llmResp, nil
}

func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
var buf bytes.Buffer
mw, err := worker.NewLLMMultipartWriter(&buf, req)
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, nil)
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, nil)
}
return nil, err
}

client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}
@@ -869,7 +869,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, int64(*req.MaxTokens))
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}
@@ -879,7 +879,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
resp, err := client.GenLLMWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}
@@ -932,7 +932,7 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r
if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
}
monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
monitor.AIRequestFinished(ctx, "llm", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
}
}()

@@ -944,15 +944,15 @@ func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *A
defer body.Close()
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}

var res worker.LLMResponse
if err := json.Unmarshal(data, &res); err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}
@@ -965,7 +965,7 @@ func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *A
if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
}
monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
monitor.AIRequestFinished(ctx, "llm", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
}

return &res, nil
@@ -1023,13 +1023,13 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
return submitAudioToText(ctx, params, sess, v)
}
case worker.GenLLMFormdataRequestBody:
cap = core.Capability_LlmGenerate
modelID = defaultLlmGenerateModelID
cap = core.Capability_LLM
modelID = defaultLLMModelID
if v.ModelId != nil {
modelID = *v.ModelId
}
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitLlmGenerate(ctx, params, sess, v)
return submitLLM(ctx, params, sess, v)
}
case worker.GenSegmentAnything2MultipartRequestBody:
cap = core.Capability_SegmentAnything2
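
Beyond the label renames, this file defines the latency score used in the monitoring calls above: seconds elapsed divided by tokens used, so lower is better. A quick worked example inside the server package:

// 2s across 100 tokens → 0.02 seconds per token.
score := CalculateLLMLatencyScore(2*time.Second, 100)
fmt.Println(score) // 0.02
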
2 changes: 1 addition & 1 deletion server/ai_process_test.go
@@ -13,7 +13,7 @@ func Test_submitLLM(t *testing.T) {
ctx context.Context
params aiRequestParams
sess *AISession
req worker.LLMFormdataRequestBody
req worker.GenLLMFormdataRequestBody
}
tests := []struct {
name string
2 changes: 1 addition & 1 deletion server/rpc.go
@@ -68,7 +68,7 @@ type Orchestrator interface {
ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error)
Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
LlmGenerate(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error)
LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error)
SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
}

6 changes: 6 additions & 0 deletions server/rpc_test.go
@@ -202,6 +202,9 @@ func (r *stubOrchestrator) Upscale(ctx context.Context, req worker.GenUpscaleMul
func (r *stubOrchestrator) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
return nil, nil
}
func (r *stubOrchestrator) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
return nil, nil
}
func (r *stubOrchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return nil, nil
}
@@ -1388,6 +1391,9 @@ func (r *mockOrchestrator) Upscale(ctx context.Context, req worker.GenUpscaleMul
func (r *mockOrchestrator) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
return nil, nil
}
func (r *mockOrchestrator) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
return nil, nil
}
func (r *mockOrchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return nil, nil
}
