misc: update ai-worker dependency and its usage
kyriediculous committed Sep 30, 2024
1 parent 4a9a4bc commit eca254f
Showing 12 changed files with 60 additions and 58 deletions.
12 changes: 6 additions & 6 deletions cmd/livepeer/starter/starter.go
@@ -1320,19 +1320,19 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
}
n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice)

case "llm-generate":
_, ok := capabilityConstraints[core.Capability_LlmGenerate]
case "llm":
_, ok := capabilityConstraints[core.Capability_LLM]
if !ok {
-aiCaps = append(aiCaps, core.Capability_LlmGenerate)
-capabilityConstraints[core.Capability_LlmGenerate] = &core.PerCapabilityConstraints{
+aiCaps = append(aiCaps, core.Capability_LLM)
+capabilityConstraints[core.Capability_LLM] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}

-capabilityConstraints[core.Capability_LlmGenerate].Models[config.ModelID] = modelConstraint
+capabilityConstraints[core.Capability_LLM].Models[config.ModelID] = modelConstraint

if *cfg.Network != "offchain" {
-n.SetBasePriceForCap("default", core.Capability_LlmGenerate, config.ModelID, autoPrice)
+n.SetBasePriceForCap("default", core.Capability_LLM, config.ModelID, autoPrice)
}
case "segment-anything-2":
_, ok := capabilityConstraints[core.Capability_SegmentAnything2]
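
The "llm" case above fills in per-capability model constraints. As a rough, non-authoritative sketch of the shapes involved (the type names appear in the diff; the Warm field and the constant value are assumptions):

package main

type Capability int

const Capability_LLM Capability = 5 // placeholder value, not the real constant

type ModelConstraint struct {
	Warm bool // assumption: whether the model is kept loaded on the worker
}

type CapabilityConstraints struct {
	Models map[string]*ModelConstraint // model ID -> constraint
}

// registerLLM mirrors the "llm" case: create the per-capability entry on
// first use, then record the model's constraint.
func registerLLM(cc map[Capability]*CapabilityConstraints, modelID string, mc *ModelConstraint) {
	if _, ok := cc[Capability_LLM]; !ok {
		cc[Capability_LLM] = &CapabilityConstraints{Models: make(map[string]*ModelConstraint)}
	}
	cc[Capability_LLM].Models[modelID] = mc
}

func main() {
	cc := map[Capability]*CapabilityConstraints{}
	registerLLM(cc, "meta-llama/llama-3.1-8B-Instruct", &ModelConstraint{Warm: true})
}
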
2 changes: 1 addition & 1 deletion core/ai.go
@@ -22,7 +22,7 @@ type AI interface {
ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
-LlmGenerate(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
+LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
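
Any fake or mock implementing the AI interface has to pick up the new method name. A minimal sketch of a test double for just this method (the worker types are the ones named in the diff; the stub itself is illustrative):

package mocks

import (
	"context"

	"github.com/livepeer/ai-worker/worker"
)

// stubAI covers only the renamed method; the real AI interface has more
// methods, omitted here.
type stubAI struct{}

func (stubAI) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
	// interface{} because the worker may return either *worker.LLMResponse
	// or a streaming channel of worker.LlmStreamChunk.
	return &worker.LLMResponse{}, nil
}
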
4 changes: 2 additions & 2 deletions core/capabilities.go
@@ -78,7 +78,7 @@ const (
Capability_ImageToVideo
Capability_Upscale
Capability_AudioToText
-Capability_LlmGenerate
+Capability_LLM
Capability_SegmentAnything2
)

@@ -116,7 +116,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_ImageToVideo: "Image to video",
Capability_Upscale: "Upscale",
Capability_AudioToText: "Audio to text",
-Capability_LlmGenerate: "LLM Generate",
+Capability_LLM: "Large Language Model",
Capability_SegmentAnything2: "Segment anything 2",
}

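
Because the capability constants are declared positionally with iota, renaming the identifier in place keeps its numeric value, so existing orchestrators and gateways should still agree on the capability ID. A small illustration (the values here are examples, not the real constants):

package main

import "fmt"

type Capability int

// Positional declaration: each constant's value is its position in the block.
const (
	Capability_AudioToText Capability = iota // 0
	Capability_LLM                           // 1: renamed in place, value unchanged
	Capability_SegmentAnything2              // 2
)

func main() {
	fmt.Println(Capability_LLM) // 1: the slot Capability_LlmGenerate previously occupied
}
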
11 changes: 4 additions & 7 deletions core/orchestrator.go
@@ -130,9 +130,10 @@ func (orch *orchestrator) AudioToText(ctx context.Context, req worker.GenAudioTo
return orch.node.AudioToText(ctx, req)
}

-// Return type is LlmResponse, but a stream is available as well as chan(string)
-func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
-return orch.node.llmGenerate(ctx, req)
+// Return type is LLMResponse, but a stream is available as well as chan(string)
+func (orch *orchestrator) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
+return orch.node.AIWorker.LLM(ctx, req)
+
}

func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
@@ -1067,10 +1068,6 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVi
return &worker.ImageResponse{Images: videos}, nil
}

-func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
-return n.AIWorker.LlmGenerate(ctx, req)
-}
-
func (rtm *RemoteTranscoderManager) transcoderResults(tcID int64, res *RemoteTranscoderResult) {
remoteChan, err := rtm.getTaskChan(tcID)
if err != nil {
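
The orchestrator now dispatches straight to the worker, dropping the thin node-level wrapper, and the comment above is effectively the whole contract: the interface{} result is either a complete response or a stream. A hedged sketch of the type switch a caller performs (type names are from this diff; the receive-only channel direction follows the ai_http.go change below):

package sketch

import (
	"fmt"

	"github.com/livepeer/ai-worker/worker"
)

// handleLLMResult branches on the dynamic type of an orch.LLM result.
func handleLLMResult(res interface{}) error {
	switch v := res.(type) {
	case *worker.LLMResponse:
		_ = v // complete, non-streaming response
		return nil
	case <-chan worker.LlmStreamChunk:
		for chunk := range v {
			_ = chunk // consume stream chunks as they arrive
		}
		return nil
	default:
		return fmt.Errorf("unexpected LLM result type %T", res)
	}
}
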
4 changes: 1 addition & 3 deletions go.mod
@@ -13,7 +13,7 @@ require (
github.com/golang/protobuf v1.5.4
github.com/jaypipes/ghw v0.10.0
github.com/jaypipes/pcidb v1.0.0
-github.com/livepeer/ai-worker v0.6.0
+github.com/livepeer/ai-worker v0.7.0
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
github.com/livepeer/lpms v0.0.0-20240909171057-fe5aff1fa6a2
@@ -238,5 +238,3 @@ require (
lukechampine.com/blake3 v1.2.1 // indirect
rsc.io/tmplfunc v0.0.3 // indirect
)

-replace github.com/livepeer/ai-worker => /Users/nico/livepool/ai-worker
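
Alongside the version bump, note the dropped replace directive: it pointed the module at a checkout on the author's machine, which would break builds anywhere else, and with it gone builds resolve the published v0.7.0. For reference, the usual local-development round trip uses the real Go tooling like this (the ../ai-worker path is illustrative):

go mod edit -replace github.com/livepeer/ai-worker=../ai-worker   # point at a local checkout while developing
go mod edit -dropreplace github.com/livepeer/ai-worker            # remove before committing
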
4 changes: 2 additions & 2 deletions go.sum
@@ -623,8 +623,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
-github.com/livepeer/ai-worker v0.6.0 h1:sGldUavfbTXPQDKc1a80/zgK8G1VdYRAxiuFTP0YyOU=
-github.com/livepeer/ai-worker v0.6.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
+github.com/livepeer/ai-worker v0.7.0 h1:9z5Uz9WvKyQTXiurWim1ewDcVPLzz7EYZEfm2qtLAaw=
+github.com/livepeer/ai-worker v0.7.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
12 changes: 6 additions & 6 deletions server/ai_http.go
@@ -44,7 +44,7 @@ func startAIServer(lp lphttp) error {
lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo()))
lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
lp.transRPC.Handle("/llm-generate", oapiReqValidator(lp.LlmGenerate()))
lp.transRPC.Handle("/llm", oapiReqValidator(lp.LLM()))
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2()))

return nil
@@ -182,7 +182,7 @@ func (h *lphttp) SegmentAnything2() http.Handler {
})
}

-func (h *lphttp) LlmGenerate() http.Handler {
+func (h *lphttp) LLM() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orch := h.orchestrator

@@ -330,11 +330,11 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
}
outPixels *= 1000 // Convert to milliseconds
case worker.GenLLMFormdataRequestBody:
pipeline = "llm-generate"
cap = core.Capability_LlmGenerate
pipeline = "llm"
cap = core.Capability_LLM
modelID = *v.ModelId
submitFn = func(ctx context.Context) (interface{}, error) {
-return orch.LlmGenerate(ctx, v)
+return orch.LLM(ctx, v)
}

if v.MaxTokens == nil {
@@ -447,7 +447,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
}

// Check if the response is a streaming response
-if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
+if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok {
// Set headers for SSE
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
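
One behavioral detail in this file: the stream type assertion changed from chan worker.LlmStreamChunk to <-chan worker.LlmStreamChunk. Go type assertions match the exact dynamic type, so this only succeeds if the updated ai-worker now returns a receive-only channel, which the change implies but this diff does not show directly. A runnable illustration of the distinction:

package main

import "fmt"

func main() {
	// A bidirectional channel stored in an interface...
	var resp interface{} = make(chan int)
	_, ok := resp.(<-chan int)
	fmt.Println(ok) // false: chan int and <-chan int are distinct types

	// ...versus the same channel stored as receive-only.
	var ro <-chan int = make(chan int)
	resp = ro
	_, ok = resp.(<-chan int)
	fmt.Println(ok) // true
}
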
14 changes: 7 additions & 7 deletions server/ai_mediaserver.go
@@ -70,7 +70,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo()))
ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText()))
ls.HTTPMux.Handle("/llm-generate", oapiReqValidator(ls.LlmGenerate()))
ls.HTTPMux.Handle("/llm", oapiReqValidator(ls.LLM()))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2()))

return nil
@@ -396,14 +396,14 @@ func (ls *LivepeerServer) AudioToText() http.Handler {
})
}

-func (ls *LivepeerServer) LlmGenerate() http.Handler {
+func (ls *LivepeerServer) LLM() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
requestID := string(core.RandomManifestID())
ctx = clog.AddVal(ctx, "request_id", requestID)

-var req worker.LlmGenerateFormdataRequestBody
+var req worker.GenLLMFormdataRequestBody

multiRdr, err := r.MultipartReader()
if err != nil {
@@ -416,7 +416,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler {
return
}

clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v", req.Prompt, *req.ModelId)
clog.V(common.VERBOSE).Infof(ctx, "Received LLM request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, *req.Stream)

params := aiRequestParams{
node: ls.LivepeerNode,
@@ -425,7 +425,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler {
}

start := time.Now()
-resp, err := processLlmGenerate(ctx, params, req)
+resp, err := processLLM(ctx, params, req)
if err != nil {
var e *ServiceUnavailableError
if errors.As(err, &e) {
@@ -437,7 +437,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler {
}

took := time.Since(start)
clog.V(common.VERBOSE).Infof(ctx, "Processed LlmGenerate request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took)
clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took)

if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
// Handle streaming response (SSE)
Expand All @@ -453,7 +453,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler {
break
}
}
-} else if llmResp, ok := resp.(*worker.LlmResponse); ok {
+} else if llmResp, ok := resp.(*worker.LLMResponse); ok {
// Handle non-streaming response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(llmResp)
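
For context on the streaming branch above: each chunk is relayed to the client as a server-sent event, a "data:" line followed by a blank line, flushed immediately. A minimal sketch of the pattern (simplified, not the exact mediaserver code; the chunk type here is a stand-in for worker.LlmStreamChunk):

package sse

import (
	"encoding/json"
	"fmt"
	"net/http"
)

type chunk struct {
	Chunk string `json:"chunk"`
	Done  bool   `json:"done"`
}

// streamSSE relays chunks to the client as SSE frames until the channel
// closes or a chunk is marked done.
func streamSSE(w http.ResponseWriter, ch <-chan chunk) {
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}
	for c := range ch {
		data, _ := json.Marshal(c)
		fmt.Fprintf(w, "data: %s\n\n", data) // one SSE frame per chunk
		flusher.Flush()
		if c.Done {
			break
		}
	}
}
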
37 changes: 19 additions & 18 deletions server/ai_process.go
@@ -31,7 +31,7 @@ const defaultImageToImageModelID = "stabilityai/sdxl-turbo"
const defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-xt"
const defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler"
const defaultAudioToTextModelID = "openai/whisper-large-v3"
-const defaultLlmGenerateModelID = "meta-llama/llama-3.1-8B-Instruct"
+const defaultLLMModelID = "meta-llama/llama-3.1-8B-Instruct"
const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large"

type ServiceUnavailableError struct {
@@ -813,15 +813,15 @@ func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISess
return &res, nil
}

-func CalculateLlmGenerateLatencyScore(took time.Duration, tokensUsed int) float64 {
+func CalculateLLMLatencyScore(took time.Duration, tokensUsed int) float64 {
if tokensUsed <= 0 {
return 0
}

return took.Seconds() / float64(tokensUsed)
}

-func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
+func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
resp, err := processAIRequest(ctx, params, req)
if err != nil {
return nil, err
Expand All @@ -843,20 +843,20 @@ func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.
return llmResp, nil
}

-func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
+func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
var buf bytes.Buffer
mw, err := worker.NewLLMMultipartWriter(&buf, req)
if err != nil {
if monitor.Enabled {
-monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, nil)
+monitor.AIRequestError(err.Error(), "llm", *req.ModelId, nil)
}
return nil, err
}

client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
if err != nil {
if monitor.Enabled {
-monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}
@@ -869,7 +869,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, int64(*req.MaxTokens))
if err != nil {
if monitor.Enabled {
-monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}
@@ -879,11 +879,10 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
resp, err := client.GenLLMWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
if err != nil {
if monitor.Enabled {
-monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}
-defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -901,6 +900,7 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r
streamChan := make(chan worker.LlmStreamChunk, 100)
go func() {
defer close(streamChan)
+defer body.Close()
scanner := bufio.NewScanner(body)
var totalTokens int
for scanner.Scan() {
@@ -925,14 +925,14 @@ }
}

took := time.Since(start)
-sess.LatencyScore = CalculateLlmGenerateLatencyScore(took, totalTokens)
+sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens)

if monitor.Enabled {
var pricePerAIUnit float64
if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
}
-monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
+monitor.AIRequestFinished(ctx, "llm", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
}
}()

@@ -941,30 +941,31 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r

func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMFormdataRequestBody, start time.Time) (*worker.LLMResponse, error) {
data, err := io.ReadAll(body)
+defer body.Close()
if err != nil {
if monitor.Enabled {
-monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}

var res worker.LLMResponse
if err := json.Unmarshal(data, &res); err != nil {
if monitor.Enabled {
-monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+monitor.AIRequestError(err.Error(), "llm", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}

took := time.Since(start)
-sess.LatencyScore = CalculateLlmGenerateLatencyScore(took, res.TokensUsed)
+sess.LatencyScore = CalculateLLMLatencyScore(took, res.TokensUsed)

if monitor.Enabled {
var pricePerAIUnit float64
if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
}
-monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
+monitor.AIRequestFinished(ctx, "llm", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
}

return &res, nil
@@ -1022,13 +1023,13 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
return submitAudioToText(ctx, params, sess, v)
}
case worker.GenLLMFormdataRequestBody:
-cap = core.Capability_LlmGenerate
-modelID = defaultLlmGenerateModelID
+cap = core.Capability_LLM
+modelID = defaultLLMModelID
if v.ModelId != nil {
modelID = *v.ModelId
}
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
-return submitLlmGenerate(ctx, params, sess, v)
+return submitLLM(ctx, params, sess, v)
}
case worker.GenSegmentAnything2MultipartRequestBody:
cap = core.Capability_SegmentAnything2
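
Beyond the renames, the substantive change in this file is body-lifetime handling: submitLLM no longer defers resp.Body.Close() itself. For streaming responses the body must outlive the function, so handleSSEStream's goroutine now owns and closes it, while handleNonStreamingResponse closes it around its synchronous read. A simplified, runnable sketch of that ownership pattern (not the exact ai_process.go code):

package main

import (
	"bufio"
	"fmt"
	"io"
	"strings"
)

// streamLines hands the body to a goroutine that both drains and closes
// it, so the caller can return while the stream is still flowing.
func streamLines(body io.ReadCloser) <-chan string {
	out := make(chan string, 100)
	go func() {
		defer close(out)
		defer body.Close() // the goroutine, not the caller, owns the body
		scanner := bufio.NewScanner(body)
		for scanner.Scan() {
			out <- scanner.Text()
		}
	}()
	return out
}

func main() {
	body := io.NopCloser(strings.NewReader("data: a\ndata: b\n"))
	for line := range streamLines(body) {
		fmt.Println(line)
	}
}
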
10 changes: 5 additions & 5 deletions server/ai_process_test.go
@@ -8,12 +8,12 @@ import (
"github.com/livepeer/ai-worker/worker"
)

-func Test_submitLlmGenerate(t *testing.T) {
+func Test_submitLLM(t *testing.T) {
type args struct {
ctx context.Context
params aiRequestParams
sess *AISession
-req worker.LlmGenerateFormdataRequestBody
+req worker.GenLLMFormdataRequestBody
}
tests := []struct {
name string
@@ -25,13 +25,13 @@ func Test_submitLlmGenerate(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
-got, err := submitLlmGenerate(tt.args.ctx, tt.args.params, tt.args.sess, tt.args.req)
+got, err := submitLLM(tt.args.ctx, tt.args.params, tt.args.sess, tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("submitLlmGenerate() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("submitLLM() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("submitLlmGenerate() = %v, want %v", got, tt.want)
t.Errorf("submitLLM() = %v, want %v", got, tt.want)
}
})
}
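
The test table itself is still empty; the TODO survives the rename. Purely as a hypothetical sketch, a first entry in the tests slice might look like this (field values invented; args is the struct defined above):

{
	name: "empty request returns error",
	args: args{
		ctx: context.Background(),
		req: worker.GenLLMFormdataRequestBody{},
	},
	want:    nil,
	wantErr: true,
},
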
2 changes: 1 addition & 1 deletion server/rpc.go
@@ -68,7 +68,7 @@ type Orchestrator interface {
ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error)
Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
-LlmGenerate(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error)
+LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error)
SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
}
