Skip to content

Commit

Permalink
Merge branch 'ai-video' into rafal/ai-video-fix-unit-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
leszko authored Oct 2, 2024
2 parents 835e681 + 80c0ac9 commit c3e5b39
Show file tree
Hide file tree
Showing 12 changed files with 382 additions and 6 deletions.
16 changes: 16 additions & 0 deletions cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,22 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice)
}
n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice)

case "llm":
_, ok := capabilityConstraints[core.Capability_LLM]
if !ok {
aiCaps = append(aiCaps, core.Capability_LLM)
capabilityConstraints[core.Capability_LLM] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}

capabilityConstraints[core.Capability_LLM].Models[config.ModelID] = modelConstraint

if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_LLM, config.ModelID, autoPrice)
}
case "segment-anything-2":
_, ok := capabilityConstraints[core.Capability_SegmentAnything2]
if !ok {
Expand Down
1 change: 1 addition & 0 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type AI interface {
ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
Expand Down
2 changes: 2 additions & 0 deletions core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ const (
Capability_ImageToVideo
Capability_Upscale
Capability_AudioToText
Capability_LLM
Capability_SegmentAnything2
)

Expand Down Expand Up @@ -115,6 +116,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_ImageToVideo: "Image to video",
Capability_Upscale: "Upscale",
Capability_AudioToText: "Audio to text",
Capability_LLM: "Large language model",
Capability_SegmentAnything2: "Segment anything 2",
}

Expand Down
6 changes: 6 additions & 0 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ func (orch *orchestrator) AudioToText(ctx context.Context, req worker.GenAudioTo
return orch.node.AudioToText(ctx, req)
}

// LLM forwards a large language model request to the node's AI worker.
//
// The result is either a *worker.LLMResponse (non-streaming) or a channel of
// worker.LlmStreamChunk when the request asked for a streamed response;
// callers type-switch on the returned interface{} to tell the two apart.
func (orch *orchestrator) LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error) {
	return orch.node.AIWorker.LLM(ctx, req)
}

func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return orch.node.SegmentAnything2(ctx, req)
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/golang/protobuf v1.5.4
github.com/jaypipes/ghw v0.10.0
github.com/jaypipes/pcidb v1.0.0
github.com/livepeer/ai-worker v0.6.0
github.com/livepeer/ai-worker v0.7.0
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
github.com/livepeer/lpms v0.0.0-20240909171057-fe5aff1fa6a2
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
github.com/livepeer/ai-worker v0.6.0 h1:sGldUavfbTXPQDKc1a80/zgK8G1VdYRAxiuFTP0YyOU=
github.com/livepeer/ai-worker v0.6.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/ai-worker v0.7.0 h1:9z5Uz9WvKyQTXiurWim1ewDcVPLzz7EYZEfm2qtLAaw=
github.com/livepeer/ai-worker v0.7.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
Expand Down
75 changes: 72 additions & 3 deletions server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func startAIServer(lp lphttp) error {
lp.transRPC.Handle("/image-to-video", oapiReqValidator(lp.ImageToVideo()))
lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
lp.transRPC.Handle("/llm", oapiReqValidator(lp.LLM()))
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2()))

return nil
Expand Down Expand Up @@ -181,6 +182,29 @@ func (h *lphttp) SegmentAnything2() http.Handler {
})
}

// LLM handles orchestrator-side requests for the "llm" pipeline. The
// multipart form body is bound into a GenLLMFormdataRequestBody and handed to
// handleAIRequest, which performs payment validation and dispatch.
func (h *lphttp) LLM() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		orch := h.orchestrator

		remoteAddr := getRemoteAddr(r)
		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

		multiRdr, err := r.MultipartReader()
		if err != nil {
			respondWithError(w, err.Error(), http.StatusBadRequest)
			return
		}

		var req worker.GenLLMFormdataRequestBody
		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
			// A body that fails to bind is malformed client input, not a
			// server fault — report 400 to match the MultipartReader error
			// path above (was http.StatusInternalServerError).
			respondWithError(w, err.Error(), http.StatusBadRequest)
			return
		}

		handleAIRequest(ctx, w, r, orch, req)
	})
}

func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
payment, err := getPayment(r.Header.Get(paymentHeader))
if err != nil {
Expand Down Expand Up @@ -305,6 +329,21 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}
outPixels *= 1000 // Convert to milliseconds
case worker.GenLLMFormdataRequestBody:
pipeline = "llm"
cap = core.Capability_LLM
modelID = *v.ModelId
submitFn = func(ctx context.Context) (interface{}, error) {
return orch.LLM(ctx, v)
}

if v.MaxTokens == nil {
respondWithError(w, "MaxTokens not specified", http.StatusBadRequest)
return
}

// TODO: Improve pricing
outPixels = int64(*v.MaxTokens)
case worker.GenSegmentAnything2MultipartRequestBody:
pipeline = "segment-anything-2"
cap = core.Capability_SegmentAnything2
Expand Down Expand Up @@ -407,7 +446,37 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
monitor.AIJobProcessed(ctx, pipeline, modelID, monitor.AIJobInfo{LatencyScore: latencyScore, PricePerUnit: pricePerAIUnit})
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
// Check if the response is a streaming response
if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok {
// Set headers for SSE
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")

flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
return
}

for chunk := range streamChan {
data, err := json.Marshal(chunk)
if err != nil {
clog.Errorf(ctx, "Error marshaling stream chunk: %v", err)
continue
}

fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()

if chunk.Done {
break
}
}
} else {
// Non-streaming response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}
}
69 changes: 69 additions & 0 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"time"

Expand Down Expand Up @@ -69,6 +70,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo()))
ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText()))
ls.HTTPMux.Handle("/llm", oapiReqValidator(ls.LLM()))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2()))

return nil
Expand Down Expand Up @@ -394,6 +396,73 @@ func (ls *LivepeerServer) AudioToText() http.Handler {
})
}

// LLM handles gateway-side requests for the "llm" pipeline: it binds the
// multipart request body, routes it through processLLM, and relays the result
// to the client either as a single JSON object (*worker.LLMResponse) or as a
// server-sent-event stream of worker.LlmStreamChunk values.
func (ls *LivepeerServer) LLM() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		remoteAddr := getRemoteAddr(r)
		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
		requestID := string(core.RandomManifestID())
		ctx = clog.AddVal(ctx, "request_id", requestID)

		var req worker.GenLLMFormdataRequestBody

		multiRdr, err := r.MultipartReader()
		if err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		// ModelId and Stream are optional pointer fields; guard the
		// dereferences so a request that omits them cannot panic the handler.
		modelID := ""
		if req.ModelId != nil {
			modelID = *req.ModelId
		}
		stream := req.Stream != nil && *req.Stream
		clog.V(common.VERBOSE).Infof(ctx, "Received LLM request prompt=%v model_id=%v stream=%v", req.Prompt, modelID, stream)

		params := aiRequestParams{
			node:        ls.LivepeerNode,
			os:          drivers.NodeStorage.NewSession(requestID),
			sessManager: ls.AISessionManager,
		}

		start := time.Now()
		resp, err := processLLM(ctx, params, req)
		if err != nil {
			var e *ServiceUnavailableError
			if errors.As(err, &e) {
				respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
				return
			}
			respondJsonError(ctx, w, err, http.StatusInternalServerError)
			return
		}

		took := time.Since(start)
		clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request prompt=%v model_id=%v took=%v", req.Prompt, modelID, took)

		if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
			// Streaming response: relay each chunk as a server-sent event.
			w.Header().Set("Content-Type", "text/event-stream")
			w.Header().Set("Cache-Control", "no-cache")
			w.Header().Set("Connection", "keep-alive")

			flusher, ok := w.(http.Flusher)
			if !ok {
				// SSE requires flushing after each event; fail cleanly rather
				// than panicking on an unchecked type assertion.
				http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
				return
			}

			for chunk := range streamChan {
				data, err := json.Marshal(chunk)
				if err != nil {
					clog.Errorf(ctx, "Error marshaling stream chunk: %v", err)
					continue
				}
				fmt.Fprintf(w, "data: %s\n\n", data)
				flusher.Flush()
				if chunk.Done {
					break
				}
			}
		} else if llmResp, ok := resp.(*worker.LLMResponse); ok {
			// Non-streaming response: a single JSON body.
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(llmResp); err != nil {
				clog.Errorf(ctx, "Error encoding LLM response: %v", err)
			}
		} else {
			http.Error(w, "Unexpected response type", http.StatusInternalServerError)
		}
	})
}

func (ls *LivepeerServer) SegmentAnything2() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
Expand Down
Loading

0 comments on commit c3e5b39

Please sign in to comment.