Skip to content

Commit

Permalink
wip: llm pipeline with stream support
Browse files Browse the repository at this point in the history
  • Loading branch information
kyriediculous committed Jul 31, 2024
1 parent 8c6bd5c commit 936b939
Show file tree
Hide file tree
Showing 10 changed files with 364 additions and 4 deletions.
13 changes: 13 additions & 0 deletions cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,19 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
capabilityConstraints[core.Capability_AudioToText].Models[config.ModelID] = modelConstraint

n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice)

case "llm-generate":
_, ok := capabilityConstraints[core.Capability_LlmGenerate]
if !ok {
aiCaps = append(aiCaps, core.Capability_LlmGenerate)
capabilityConstraints[core.Capability_LlmGenerate] = &core.PerCapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}

capabilityConstraints[core.Capability_LlmGenerate].Models[config.ModelID] = modelConstraint

n.SetBasePriceForCap("default", core.Capability_LlmGenerate, config.ModelID, autoPrice)
}

if len(aiCaps) > 0 {
Expand Down
1 change: 1 addition & 0 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type AI interface {
ImageToVideo(context.Context, worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error)
LlmGenerate(context.Context, worker.LlmGenerateFormdataRequestBody) (interface{}, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(pipeline, modelID string) bool
Expand Down
2 changes: 2 additions & 0 deletions core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ const (
Capability_ImageToVideo
Capability_Upscale
Capability_AudioToText
Capability_LlmGenerate
)

var CapabilityNameLookup = map[Capability]string{
Expand Down Expand Up @@ -114,6 +115,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_ImageToVideo: "Image to video",
Capability_Upscale: "Upscale",
Capability_AudioToText: "Audio to text",
Capability_LlmGenerate: "LLM Generate",
}

var CapabilityTestLookup = map[Capability]CapabilityTest{
Expand Down
9 changes: 9 additions & 0 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ func (orch *orchestrator) AudioToText(ctx context.Context, req worker.AudioToTex
return orch.node.AudioToText(ctx, req)
}

// LlmGenerate delegates an LLM generation request to the node's AI worker.
// The result is either a non-streaming response or, when the request asked
// for streaming, a channel of chunks (chan worker.LlmStreamChunk — see the
// type switch in server/ai_http.go); callers type-switch on the interface{}.
func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
	return orch.node.llmGenerate(ctx, req)
}

func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error {
if orch.node == nil || orch.node.Recipient == nil {
return nil
Expand Down Expand Up @@ -1033,6 +1038,10 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideo
return &worker.ImageResponse{Images: videos}, nil
}

// llmGenerate forwards an LLM generation request to the node's AI worker.
// The returned interface{} is either a plain response or a stream channel
// (chan worker.LlmStreamChunk) — consumers type-switch on the result.
func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
	return n.AIWorker.LlmGenerate(ctx, req)
}

func (rtm *RemoteTranscoderManager) transcoderResults(tcID int64, res *RemoteTranscoderResult) {
remoteChan, err := rtm.getTaskChan(tcID)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,5 @@ require (
lukechampine.com/blake3 v1.2.1 // indirect
rsc.io/tmplfunc v0.0.3 // indirect
)

replace github.com/livepeer/ai-worker => /Users/nico/livepool/ai-worker
69 changes: 66 additions & 3 deletions server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func startAIServer(lp lphttp) error {
lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo()))
lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
lp.transRPC.Handle("/llm-generate", oapiReqValidator(lp.LlmGenerate()))

return nil
}
Expand Down Expand Up @@ -157,6 +158,29 @@ func (h *lphttp) AudioToText() http.Handler {
})
}

// LlmGenerate returns the orchestrator-side handler for /llm-generate.
// It binds the incoming multipart form into the generated request body and
// hands the request to handleAIRequest for payment processing and execution.
func (h *lphttp) LlmGenerate() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		orch := h.orchestrator

		remoteAddr := getRemoteAddr(r)
		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

		multiRdr, err := r.MultipartReader()
		if err != nil {
			respondWithError(w, err.Error(), http.StatusBadRequest)
			return
		}

		var req worker.LlmGenerateFormdataRequestBody
		// A form that fails to bind is a malformed client request, not a
		// server fault: respond 400, consistent with the MultipartReader
		// error above and with the gateway handler in ai_mediaserver.go.
		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
			respondWithError(w, err.Error(), http.StatusBadRequest)
			return
		}

		handleAIRequest(ctx, w, r, orch, req)
	})
}

func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
payment, err := getPayment(r.Header.Get(paymentHeader))
if err != nil {
Expand Down Expand Up @@ -270,6 +294,15 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}
outPixels *= 1000 // Convert to milliseconds
case worker.LlmGenerateFormdataRequestBody:
pipeline = "llm-generate"
cap = core.Capability_LlmGenerate
modelID = *v.ModelId
submitFn = func(ctx context.Context) (interface{}, error) {
return orch.LlmGenerate(ctx, v)
}

// TODO: handle tokens for pricing
default:
respondWithError(w, "Unknown request type", http.StatusBadRequest)
return
Expand Down Expand Up @@ -351,7 +384,37 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
monitor.AIJobProcessed(ctx, pipeline, modelID, monitor.AIJobInfo{LatencyScore: latencyScore, PricePerUnit: pricePerAIUnit})
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
// Check if the response is a streaming response
if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
// Set headers for SSE
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")

flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
return
}

for chunk := range streamChan {
data, err := json.Marshal(chunk)
if err != nil {
clog.Errorf(ctx, "Error marshaling stream chunk: %v", err)
continue
}

fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()

if chunk.Done {
break
}
}
} else {
// Non-streaming response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}
}
70 changes: 69 additions & 1 deletion server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"time"

Expand Down Expand Up @@ -69,7 +70,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo()))
ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText()))

ls.HTTPMux.Handle("/llm-generate", oapiReqValidator(ls.LlmGenerate()))
return nil
}

Expand Down Expand Up @@ -374,6 +375,73 @@ func (ls *LivepeerServer) AudioToText() http.Handler {
})
}

// LlmGenerate returns the gateway handler for /llm-generate. It binds the
// multipart request, dispatches it through the AI session manager, and relays
// the result to the client either as a single JSON document or as a
// Server-Sent Events stream when the worker returned a stream channel.
func (ls *LivepeerServer) LlmGenerate() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		remoteAddr := getRemoteAddr(r)
		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
		requestID := string(core.RandomManifestID())
		ctx = clog.AddVal(ctx, "request_id", requestID)

		var req worker.LlmGenerateFormdataRequestBody

		multiRdr, err := r.MultipartReader()
		if err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		// ModelId is a pointer in the generated request body; guard before
		// dereferencing so a request without model_id cannot panic the server.
		if req.ModelId == nil {
			respondJsonError(ctx, w, errors.New("model_id is required"), http.StatusBadRequest)
			return
		}

		clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v", req.Prompt, *req.ModelId)

		params := aiRequestParams{
			node:        ls.LivepeerNode,
			os:          drivers.NodeStorage.NewSession(requestID),
			sessManager: ls.AISessionManager,
		}

		start := time.Now()
		resp, err := processLlmGenerate(ctx, params, req)
		if err != nil {
			var e *ServiceUnavailableError
			if errors.As(err, &e) {
				respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
				return
			}
			respondJsonError(ctx, w, err, http.StatusInternalServerError)
			return
		}

		took := time.Since(start)
		clog.V(common.VERBOSE).Infof(ctx, "Processed LlmGenerate request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took)

		if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
			// Streaming response: relay each chunk as a Server-Sent Event.
			w.Header().Set("Content-Type", "text/event-stream")
			w.Header().Set("Cache-Control", "no-cache")
			w.Header().Set("Connection", "keep-alive")

			// Checked assertion: an unconditional w.(http.Flusher) panics on
			// ResponseWriters that do not support flushing. No body bytes have
			// been written yet, so an error response is still possible here.
			flusher, ok := w.(http.Flusher)
			if !ok {
				respondJsonError(ctx, w, errors.New("streaming unsupported"), http.StatusInternalServerError)
				return
			}

			for chunk := range streamChan {
				data, err := json.Marshal(chunk)
				if err != nil {
					// Skip an unencodable chunk rather than silently sending
					// garbage or aborting the whole stream.
					clog.Errorf(ctx, "Error marshaling stream chunk: %v", err)
					continue
				}
				fmt.Fprintf(w, "data: %s\n\n", data)
				flusher.Flush()
				if chunk.Done {
					break
				}
			}
		} else if llmResp, ok := resp.(*worker.LlmResponse); ok {
			// Non-streaming response: single JSON payload.
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(llmResp); err != nil {
				clog.Errorf(ctx, "Error encoding llm response: %v", err)
			}
		} else {
			http.Error(w, "Unexpected response type", http.StatusInternalServerError)
		}
	})
}

func (ls *LivepeerServer) ImageToVideoResult() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
Expand Down
Loading

0 comments on commit 936b939

Please sign in to comment.