diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go
index 95d00e47eb..0d6e9920a8 100755
--- a/cmd/livepeer/starter/starter.go
+++ b/cmd/livepeer/starter/starter.go
@@ -1207,6 +1207,19 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
 				capabilityConstraints[core.Capability_AudioToText].Models[config.ModelID] = modelConstraint
 
 				n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice)
+
+			case "llm-generate":
+				_, ok := capabilityConstraints[core.Capability_LlmGenerate]
+				if !ok {
+					aiCaps = append(aiCaps, core.Capability_LlmGenerate)
+					capabilityConstraints[core.Capability_LlmGenerate] = &core.PerCapabilityConstraints{
+						Models: make(map[string]*core.ModelConstraint),
+					}
+				}
+
+				capabilityConstraints[core.Capability_LlmGenerate].Models[config.ModelID] = modelConstraint
+
+				n.SetBasePriceForCap("default", core.Capability_LlmGenerate, config.ModelID, autoPrice)
 			}
 
 			if len(aiCaps) > 0 {
diff --git a/core/ai.go b/core/ai.go
index 93966bd418..45b10d324a 100644
--- a/core/ai.go
+++ b/core/ai.go
@@ -19,6 +19,7 @@ type AI interface {
 	ImageToVideo(context.Context, worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
 	Upscale(context.Context, worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(context.Context, worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error)
+	LlmGenerate(context.Context, worker.LlmGenerateFormdataRequestBody) (interface{}, error)
 	Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
 	Stop(context.Context) error
 	HasCapacity(pipeline, modelID string) bool
diff --git a/core/capabilities.go b/core/capabilities.go
index 57ec639b2d..ef9f28a401 100644
--- a/core/capabilities.go
+++ b/core/capabilities.go
@@ -78,6 +78,7 @@ const (
 	Capability_ImageToVideo
 	Capability_Upscale
 	Capability_AudioToText
+	Capability_LlmGenerate
 )
 
 var CapabilityNameLookup = map[Capability]string{
@@ -114,6 +115,7 @@ var CapabilityNameLookup = map[Capability]string{
 	Capability_ImageToVideo: "Image to video",
 	Capability_Upscale:      "Upscale",
 	Capability_AudioToText:  "Audio to text",
+	Capability_LlmGenerate:  "LLM Generate",
 }
 
 var CapabilityTestLookup = map[Capability]CapabilityTest{
diff --git a/core/orchestrator.go b/core/orchestrator.go
index 70975ff3a6..fb8a048ec4 100644
--- a/core/orchestrator.go
+++ b/core/orchestrator.go
@@ -130,6 +130,11 @@ func (orch *orchestrator) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
 	return orch.node.AudioToText(ctx, req)
 }
 
+// LlmGenerate returns a *worker.LlmResponse for non-streaming requests, or a chan worker.LlmStreamChunk when streaming is requested.
+func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
+	return orch.node.llmGenerate(ctx, req)
+}
+
 func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error {
 	if orch.node == nil || orch.node.Recipient == nil {
 		return nil
@@ -1033,6 +1038,10 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
 	return &worker.ImageResponse{Images: videos}, nil
 }
 
+func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
+	return n.AIWorker.LlmGenerate(ctx, req)
+}
+
 func (rtm *RemoteTranscoderManager) transcoderResults(tcID int64, res *RemoteTranscoderResult) {
 	remoteChan, err := rtm.getTaskChan(tcID)
 	if err != nil {
diff --git a/go.mod b/go.mod
index 1a59fc36fc..25541b78a4 100644
--- a/go.mod
+++ b/go.mod
@@ -238,3 +238,5 @@ require (
 	lukechampine.com/blake3 v1.2.1 // indirect
 	rsc.io/tmplfunc v0.0.3 // indirect
 )
+
+replace github.com/livepeer/ai-worker => /Users/nico/livepool/ai-worker
diff --git a/server/ai_http.go b/server/ai_http.go
index eba099df35..34c359bc31 100644
--- a/server/ai_http.go
+++ b/server/ai_http.go
@@ -44,6 +44,7 @@ func startAIServer(lp lphttp) error {
 	lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo()))
 	lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))
 	lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
+	lp.transRPC.Handle("/llm-generate", oapiReqValidator(lp.LlmGenerate()))
 
 	return nil
 }
@@ -157,6 +158,29 @@ func (h *lphttp) AudioToText() http.Handler {
 	})
 }
 
+func (h *lphttp) LlmGenerate() http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		orch := h.orchestrator
+
+		remoteAddr := getRemoteAddr(r)
+		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
+
+		multiRdr, err := r.MultipartReader()
+		if err != nil {
+			respondWithError(w, err.Error(), http.StatusBadRequest)
+			return
+		}
+
+		var req worker.LlmGenerateFormdataRequestBody
+		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
+			respondWithError(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		handleAIRequest(ctx, w, r, orch, req)
+	})
+}
+
 func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
 	payment, err := getPayment(r.Header.Get(paymentHeader))
 	if err != nil {
@@ -270,6 +294,15 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
 			return
 		}
 		outPixels *= 1000 // Convert to milliseconds
+	case worker.LlmGenerateFormdataRequestBody:
+		pipeline = "llm-generate"
+		cap = core.Capability_LlmGenerate
+		modelID = *v.ModelId
+		submitFn = func(ctx context.Context) (interface{}, error) {
+			return orch.LlmGenerate(ctx, v)
+		}
+
+		// TODO: handle tokens for pricing
 	default:
 		respondWithError(w, "Unknown request type", http.StatusBadRequest)
 		return
@@ -351,7 +384,37 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
 		monitor.AIJobProcessed(ctx, pipeline, modelID, monitor.AIJobInfo{LatencyScore: latencyScore, PricePerUnit: pricePerAIUnit})
 	}
 
-	w.Header().Set("Content-Type", "application/json")
-	w.WriteHeader(http.StatusOK)
-	_ = json.NewEncoder(w).Encode(resp)
+	// Check if the response is a streaming response
+	if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
+		// Set headers for SSE
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.Header().Set("Cache-Control", "no-cache")
+		w.Header().Set("Connection", "keep-alive")
+
+		flusher, ok := w.(http.Flusher)
+		if !ok {
+			http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
+			return
+		}
+
+		for chunk := range streamChan {
+			data, err := json.Marshal(chunk)
+			if err != nil {
+				clog.Errorf(ctx, "Error marshaling stream chunk: %v", err)
+				continue
+			}
+
+			fmt.Fprintf(w, "data: %s\n\n", data)
+			flusher.Flush()
+
+			if chunk.Done {
+				break
+			}
+		}
+	} else {
+		// Non-streaming response
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusOK)
+		_ = json.NewEncoder(w).Encode(resp)
+	}
 }
diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go
index 29d33b9fed..b7bb41264c 100644
--- a/server/ai_mediaserver.go
+++ b/server/ai_mediaserver.go
@@ -5,6 +5,7 @@ import (
 	"context"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"net/http"
 	"time"
 
@@ -69,7 +70,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
 	ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo()))
 	ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())
 	ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText()))
-
+	ls.HTTPMux.Handle("/llm-generate", oapiReqValidator(ls.LlmGenerate()))
 	return nil
 }
 
@@ -374,6 +375,73 @@ func (ls *LivepeerServer) AudioToText() http.Handler {
 	})
 }
 
+func (ls *LivepeerServer) LlmGenerate() http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		remoteAddr := getRemoteAddr(r)
+		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
+		requestID := string(core.RandomManifestID())
+		ctx = clog.AddVal(ctx, "request_id", requestID)
+
+		var req worker.LlmGenerateFormdataRequestBody
+
+		multiRdr, err := r.MultipartReader()
+		if err != nil {
+			respondJsonError(ctx, w, err, http.StatusBadRequest)
+			return
+		}
+
+		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
+			respondJsonError(ctx, w, err, http.StatusBadRequest)
+			return
+		}
+
+		clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v", req.Prompt, *req.ModelId)
+
+		params := aiRequestParams{
+			node:        ls.LivepeerNode,
+			os:          drivers.NodeStorage.NewSession(requestID),
+			sessManager: ls.AISessionManager,
+		}
+
+		start := time.Now()
+		resp, err := processLlmGenerate(ctx, params, req)
+		if err != nil {
+			var e *ServiceUnavailableError
+			if errors.As(err, &e) {
+				respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
+				return
+			}
+			respondJsonError(ctx, w, err, http.StatusInternalServerError)
+			return
+		}
+
+		took := time.Since(start)
+		clog.V(common.VERBOSE).Infof(ctx, "Processed LlmGenerate request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took)
+
+		if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
+			// Handle streaming response (SSE)
+			w.Header().Set("Content-Type", "text/event-stream")
+			w.Header().Set("Cache-Control", "no-cache")
+			w.Header().Set("Connection", "keep-alive")
+
+			for chunk := range streamChan {
+				data, _ := json.Marshal(chunk)
+				fmt.Fprintf(w, "data: %s\n\n", data)
+				w.(http.Flusher).Flush()
+				if chunk.Done {
+					break
+				}
+			}
+		} else if llmResp, ok := resp.(*worker.LlmResponse); ok {
+			// Handle non-streaming response
+			w.Header().Set("Content-Type", "application/json")
+			json.NewEncoder(w).Encode(llmResp)
+		} else {
+			http.Error(w, "Unexpected response type", http.StatusInternalServerError)
+		}
+	})
+}
+
 func (ls *LivepeerServer) ImageToVideoResult() http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		remoteAddr := getRemoteAddr(r)
diff --git a/server/ai_process.go b/server/ai_process.go
index d286582215..a036b85eca 100644
--- a/server/ai_process.go
+++ b/server/ai_process.go
@@ -31,6 +31,7 @@ const defaultImageToImageModelID = "stabilityai/sdxl-turbo"
 const defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-xt"
 const defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler"
 const defaultAudioToTextModelID = "openai/whisper-large-v3"
+const defaultLlmGenerateModelID = "meta-llama/llama-3.1-8B-Instruct"
 
 type ServiceUnavailableError struct {
 	err error
@@ -679,6 +680,159 @@ func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISession, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
 	return &res, nil
 }
 
+func CalculateLlmGenerateLatencyScore(took time.Duration, tokensUsed int) float64 {
+	if tokensUsed <= 0 {
+		return 0
+	}
+
+	return took.Seconds() / float64(tokensUsed)
+}
+
+func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
+	resp, err := processAIRequest(ctx, params, req)
+	if err != nil {
+		return nil, err
+	}
+
+	if *req.Stream {
+		streamChan, ok := resp.(chan worker.LlmStreamChunk)
+		if !ok {
+			return nil, errors.New("unexpected response type for streaming request")
+		}
+		return streamChan, nil
+	}
+
+	llmResp, ok := resp.(*worker.LlmResponse)
+	if !ok {
+		return nil, errors.New("unexpected response type")
+	}
+
+	return llmResp, nil
+}
+
+func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
+	var buf bytes.Buffer
+	mw, err := worker.NewLlmGenerateMultipartWriter(&buf, req)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, nil)
+		}
+		return nil, err
+	}
+
+	client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	// TODO: calculate payment
+	setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, 0)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+	defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)
+
+	start := time.Now()
+	resp, err := client.LlmGenerateWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
+	}
+
+	if *req.Stream {
+		return handleSSEStream(ctx, resp.Body, sess, req, start)
+	}
+
+	return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
+}
+
+func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
+	streamChan := make(chan worker.LlmStreamChunk, 100)
+	go func() {
+		defer close(streamChan)
+		scanner := bufio.NewScanner(body)
+		var totalTokens int
+		for scanner.Scan() {
+			line := scanner.Text()
+			if strings.HasPrefix(line, "data: ") {
+				data := strings.TrimPrefix(line, "data: ")
+				if data == "[DONE]" {
+					streamChan <- worker.LlmStreamChunk{Done: true, TokensUsed: totalTokens}
+					break
+				}
+				var chunk worker.LlmStreamChunk
+				if err := json.Unmarshal([]byte(data), &chunk); err != nil {
+					clog.Errorf(ctx, "Error unmarshaling SSE data: %v", err)
+					continue
+				}
+				totalTokens += chunk.TokensUsed
+				streamChan <- chunk
+			}
+		}
+		if err := scanner.Err(); err != nil {
+			clog.Errorf(ctx, "Error reading SSE stream: %v", err)
+		}
+
+		took := time.Since(start)
+		sess.LatencyScore = CalculateLlmGenerateLatencyScore(took, totalTokens)
+
+		if monitor.Enabled {
+			var pricePerAIUnit float64
+			if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
+				pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
+			}
+			monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
+		}
+	}()
+
+	return streamChan, nil
+}
+
+func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateFormdataRequestBody, start time.Time) (*worker.LlmResponse, error) {
+	data, err := io.ReadAll(body)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	var res worker.LlmResponse
+	if err := json.Unmarshal(data, &res); err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	took := time.Since(start)
+	sess.LatencyScore = CalculateLlmGenerateLatencyScore(took, res.TokensUsed)
+
+	if monitor.Enabled {
+		var pricePerAIUnit float64
+		if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
+			pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
+		}
+		monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
+	}
+
+	return &res, nil
+}
+
 func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (interface{}, error) {
 	var cap core.Capability
 	var modelID string
@@ -730,6 +884,15 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (interface{}, error) {
 		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitAudioToText(ctx, params, sess, v)
 		}
+	case worker.LlmGenerateFormdataRequestBody:
+		cap = core.Capability_LlmGenerate
+		modelID = defaultLlmGenerateModelID
+		if v.ModelId != nil {
+			modelID = *v.ModelId
+		}
+		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
+			return submitLlmGenerate(ctx, params, sess, v)
+		}
 	default:
 		return nil, fmt.Errorf("unsupported request type %T", req)
 	}
diff --git a/server/ai_process_test.go b/server/ai_process_test.go
new file mode 100644
index 0000000000..e584637ef2
--- /dev/null
+++ b/server/ai_process_test.go
@@ -0,0 +1,38 @@
+package server
+
+import (
+	"context"
+	"reflect"
+	"testing"
+
+	"github.com/livepeer/ai-worker/worker"
+)
+
+func Test_submitLlmGenerate(t *testing.T) {
+	type args struct {
+		ctx    context.Context
+		params aiRequestParams
+		sess   *AISession
+		req    worker.LlmGenerateFormdataRequestBody
+	}
+	tests := []struct {
+		name    string
+		args    args
+		want    interface{}
+		wantErr bool
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := submitLlmGenerate(tt.args.ctx, tt.args.params, tt.args.sess, tt.args.req)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("submitLlmGenerate() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("submitLlmGenerate() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
diff --git a/server/rpc.go b/server/rpc.go
index 0c8b9f8066..897bb6d3ba 100644
--- a/server/rpc.go
+++ b/server/rpc.go
@@ -68,6 +68,7 @@ type Orchestrator interface {
 	ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error)
 	Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error)
+	LlmGenerate(ctx context.Context, req worker.LlmGenerateFormdataRequestBody) (interface{}, error)
 }
 
 // Balance describes methods for a session's balance maintenance
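
Reviewer note (not part of the diff): the new endpoint accepts multipart form data and returns either a single JSON worker.LlmResponse or an SSE stream of worker.LlmStreamChunk values. The sketch below is a minimal Go client for the streaming path, included purely for manual testing. The form field names (prompt, model_id, stream) are inferred from the LlmGenerateFormdataRequestBody binding and the gateway address localhost:8935 is assumed; the authoritative field names live in the ai-worker OpenAPI spec. Because the diff only reveals the Done and TokensUsed fields of worker.LlmStreamChunk, the sketch prints each raw "data:" payload instead of decoding it.

package main

import (
	"bufio"
	"bytes"
	"fmt"
	"mime/multipart"
	"net/http"
	"strings"
)

func main() {
	// Build the multipart body. Field names (prompt/model_id/stream) are an
	// assumption based on the generated LlmGenerateFormdataRequestBody binding.
	var buf bytes.Buffer
	mw := multipart.NewWriter(&buf)
	mw.WriteField("prompt", "Say hello in one sentence.")
	mw.WriteField("model_id", "meta-llama/llama-3.1-8B-Instruct")
	mw.WriteField("stream", "true")
	mw.Close()

	// Assumed gateway HTTP address; adjust to your node's -httpAddr.
	resp, err := http.Post("http://localhost:8935/llm-generate", mw.FormDataContentType(), &buf)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// With stream=true the handler emits Server-Sent Events: one JSON-encoded
	// LlmStreamChunk per "data: " line, terminated by a chunk with done=true.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		if data, ok := strings.CutPrefix(scanner.Text(), "data: "); ok {
			fmt.Println(data)
		}
	}
}

With stream omitted or false, the same POST returns one worker.LlmResponse JSON body instead. Also worth noting from the diff: CalculateLlmGenerateLatencyScore reports seconds per generated token, so lower is better; for example, 100 tokens produced in 2s scores 0.02.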