From f1718faa7ad14c9da165eb4120b4eed11945aa97 Mon Sep 17 00:00:00 2001 From: Yondon Fu Date: Mon, 11 Mar 2024 21:03:02 +0000 Subject: [PATCH] server: Add AISessionManager For managing the sessions per AI capability + model ID in a way that is compatible with existing broadcast session code --- server/ai_mediaserver.go | 5 +- server/ai_process.go | 42 +++--- server/ai_session.go | 314 +++++++++++++++++++++++++++++++++++++++ server/mediaserver.go | 5 + 4 files changed, 340 insertions(+), 26 deletions(-) create mode 100644 server/ai_session.go diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go index 74bf2ab89b..ad0e8756a8 100644 --- a/server/ai_mediaserver.go +++ b/server/ai_mediaserver.go @@ -87,8 +87,9 @@ func (ls *LivepeerServer) TextToImage() http.Handler { clog.V(common.VERBOSE).Infof(r.Context(), "Received TextToImage request prompt=%v model_id=%v", req.Prompt, *req.ModelId) params := aiRequestParams{ - node: ls.LivepeerNode, - os: drivers.NodeStorage.NewSession(requestID), + node: ls.LivepeerNode, + os: drivers.NodeStorage.NewSession(requestID), + sessManager: ls.AISessionManager, } start := time.Now() diff --git a/server/ai_process.go b/server/ai_process.go index 644e715dd2..19a7b7cdba 100644 --- a/server/ai_process.go +++ b/server/ai_process.go @@ -33,8 +33,9 @@ func (e *ServiceUnavailableError) Error() string { } type aiRequestParams struct { - node *core.LivepeerNode - os drivers.OSSession + node *core.LivepeerNode + os drivers.OSSession + sessManager *AISessionManager } func getOrchestratorsForAIRequest(ctx context.Context, params aiRequestParams, cap core.Capability, modelID string) ([]*net.OrchestratorInfo, error) { @@ -85,37 +86,30 @@ func processTextToImage(ctx context.Context, params aiRequestParams, req worker. modelID = *req.ModelId } - orchInfos, err := getOrchestratorsForAIRequest(ctx, params, core.Capability_TextToImage, modelID) - if err != nil { - return nil, err - } - - if len(orchInfos) == 0 { - return nil, &ServiceUnavailableError{err: errors.New("no orchestrators available")} - } - var resp *worker.ImageResponse - // Round robin up to maxProcessingRetries times - orchIdx := 0 tries := 0 for tries < maxProcessingRetries { - orchUrl := orchInfos[orchIdx].Transcoder + sess, err := params.sessManager.Select(ctx, core.Capability_TextToImage, modelID) + if err != nil { + return nil, err + } - var err error - resp, err = submitTextToImage(ctx, orchUrl, req) + if sess == nil { + break + } + + resp, err = submitTextToImage(ctx, params, sess, req) if err == nil { + params.sessManager.Complete(ctx, sess) break } - clog.Infof(ctx, "Error submitting TextToImage request try=%v orch=%v err=%v", tries, orchUrl, err) + clog.Infof(ctx, "Error submitting TextToImage request try=%v orch=%v err=%v", tries, sess.Transcoder(), err) + + params.sessManager.Remove(ctx, sess) tries++ - orchIdx++ - // Wrap back around - if orchIdx >= len(orchInfos) { - orchIdx = 0 - } } if resp == nil { @@ -143,8 +137,8 @@ func processTextToImage(ctx context.Context, params aiRequestParams, req worker. return resp, nil } -func submitTextToImage(ctx context.Context, url string, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) { - client, err := worker.NewClientWithResponses(url, worker.WithHTTPClient(httpClient)) +func submitTextToImage(ctx context.Context, params aiRequestParams, sess *AISession, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) { + client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient)) if err != nil { return nil, err } diff --git a/server/ai_session.go b/server/ai_session.go new file mode 100644 index 0000000000..fc4c347fd5 --- /dev/null +++ b/server/ai_session.go @@ -0,0 +1,314 @@ +package server + +import ( + "context" + "strconv" + "sync" + "time" + + "github.com/livepeer/go-livepeer/clog" + "github.com/livepeer/go-livepeer/common" + "github.com/livepeer/go-livepeer/core" + "github.com/livepeer/go-tools/drivers" +) + +type AISession struct { + *BroadcastSession + + // Fields used by AISessionSelector for session lifecycle management + Cap core.Capability + ModelID string + Warm bool +} + +type AISessionPool struct { + selector BroadcastSessionsSelector + sessMap map[string]*BroadcastSession + suspender *suspender +} + +func NewAISessionPool(selector BroadcastSessionsSelector, suspender *suspender) *AISessionPool { + return &AISessionPool{ + selector: selector, + sessMap: make(map[string]*BroadcastSession), + suspender: suspender, + } +} + +func (pool *AISessionPool) Select(ctx context.Context) *BroadcastSession { + for { + sess := pool.selector.Select(ctx) + if sess == nil { + return nil + } + + if _, ok := pool.sessMap[sess.Transcoder()]; !ok { + // If the session is not tracked by sessMap skip it + continue + } + + return sess + } +} + +func (pool *AISessionPool) Complete(sess *BroadcastSession) { + existingSess, ok := pool.sessMap[sess.Transcoder()] + if !ok { + // If the session is not tracked by sessMap, skip returning it to the selector + return + } + + if sess != existingSess { + // If the session is tracked by sessMap AND it is different from what is tracked by sessMap + // skip returning it to the selector + return + } + + pool.selector.Complete(sess) +} + +func (pool *AISessionPool) Add(sessions []*BroadcastSession) { + // If we try to add new sessions to the pool the suspender + // should treat this as a refresh + pool.suspender.signalRefresh() + + var uniqueSessions []*BroadcastSession + for _, sess := range sessions { + if _, ok := pool.sessMap[sess.Transcoder()]; ok { + // Skip the session if it is already tracked by sessMap + continue + } + + pool.sessMap[sess.Transcoder()] = sess + uniqueSessions = append(uniqueSessions, sess) + } + + pool.selector.Add(uniqueSessions) +} + +func (pool *AISessionPool) Remove(sess *BroadcastSession) { + delete(pool.sessMap, sess.Transcoder()) + + // Magic number for now + penalty := 3 + // If this method is called assume that the orch should be suspended + // as well + pool.suspender.suspend(sess.Transcoder(), penalty) +} + +type AISessionSelector struct { + // Pool of sessions with orchs that have the requested model warm + warmPool *AISessionPool + // Pool of sessions with orchs that have the requested model cold + coldPool *AISessionPool + // The time until the pools should be refreshed with orchs from discovery + ttl time.Duration + lastRefreshTime time.Time + + cap core.Capability + modelID string + + node *core.LivepeerNode + suspender *suspender + os drivers.OSSession +} + +func NewAISessionSelector(cap core.Capability, modelID string, node *core.LivepeerNode, ttl time.Duration) (*AISessionSelector, error) { + var stakeRdr stakeReader + if node.Eth != nil { + stakeRdr = &storeStakeReader{store: node.Database} + } + + suspender := newSuspender() + + // The latency score in this context is just the latency of the last completed request for a session + // The "good enough" latency score is set to 0.0 so the selector will always select unknown sessions first + minLS := 0.0 + warmPool := NewAISessionPool(NewMinLSSelector(stakeRdr, minLS, node.SelectionAlgorithm, node.OrchPerfScore), suspender) + coldPool := NewAISessionPool(NewMinLSSelector(stakeRdr, minLS, node.SelectionAlgorithm, node.OrchPerfScore), suspender) + sel := &AISessionSelector{ + warmPool: warmPool, + coldPool: coldPool, + ttl: ttl, + cap: cap, + modelID: modelID, + node: node, + suspender: suspender, + os: drivers.NodeStorage.NewSession(strconv.Itoa(int(cap)) + "_" + modelID), + } + + if err := sel.Refresh(context.Background()); err != nil { + return nil, err + } + + return sel, nil +} + +func (sel *AISessionSelector) Select(ctx context.Context) *AISession { + if time.Now().After(sel.lastRefreshTime.Add(sel.ttl)) { + if err := sel.Refresh(ctx); err != nil { + clog.Infof(ctx, "Error refreshing AISessionSelector err=%v", err) + } + } + + sess := sel.warmPool.Select(ctx) + if sess != nil { + return &AISession{BroadcastSession: sess, Cap: sel.cap, ModelID: sel.modelID, Warm: true} + } + + sess = sel.coldPool.Select(ctx) + if sess != nil { + return &AISession{BroadcastSession: sess, Cap: sel.cap, ModelID: sel.modelID, Warm: false} + } + + return nil +} + +func (sel *AISessionSelector) Complete(sess *AISession) { + if sess.Warm { + sel.warmPool.Complete(sess.BroadcastSession) + } else { + sel.coldPool.Complete(sess.BroadcastSession) + } +} + +func (sel *AISessionSelector) Remove(sess *AISession) { + if sess.Warm { + sel.warmPool.Remove(sess.BroadcastSession) + } else { + sel.coldPool.Remove(sess.BroadcastSession) + } +} + +func (sel *AISessionSelector) Refresh(ctx context.Context) error { + sessions, err := sel.getSessions(ctx) + if err != nil { + return err + } + + var warmSessions []*BroadcastSession + var coldSessions []*BroadcastSession + for _, sess := range sessions { + // If the constraints are missing for this capability skip this session + constraints, ok := sess.OrchestratorInfo.Capabilities.Constraints[uint32(sel.cap)] + if !ok { + continue + } + + // If the constraint for the modelID are missing skip this session + modelConstraint, ok := constraints.Models[sel.modelID] + if !ok { + continue + } + + if modelConstraint.Warm { + warmSessions = append(warmSessions, sess) + } else { + coldSessions = append(coldSessions, sess) + } + } + + sel.warmPool.Add(warmSessions) + sel.coldPool.Add(coldSessions) + + sel.lastRefreshTime = time.Now() + + return nil +} + +func (sel *AISessionSelector) getSessions(ctx context.Context) ([]*BroadcastSession, error) { + // No warm constraints applied here because we don't want to filter out orchs based on warm criteria at discovery time + // Instead, we want all orchs that support the model and then will filter for orchs that have a warm model separately + constraints := map[core.Capability]*core.Constraints{ + sel.cap: { + Models: map[string]*core.ModelConstraint{ + sel.modelID: { + Warm: false, + }, + }, + }, + } + caps := core.NewCapabilitiesWithConstraints(append(core.DefaultCapabilities(), sel.cap), nil, constraints) + + // Set numOrchs to the pool size so that discovery tries to find maximum # of compatible orchs within a timeout + numOrchs := sel.node.OrchestratorPool.Size() + + // Use a dummy manifestID specific to the capability + modelID + // Typically, a manifestID would identify a stream + // In the AI context, a manifestID can identify a capability + modelID and each + // request for the capability + modelID can be thought of as a part of the same "stream" + manifestID := strconv.Itoa(int(sel.cap)) + "_" + sel.modelID + streamParams := &core.StreamParameters{ + ManifestID: core.ManifestID(manifestID), + Capabilities: caps, + OS: sel.os, + } + return selectOrchestrator(ctx, sel.node, streamParams, numOrchs, sel.suspender, common.ScoreAtLeast(0)) +} + +type AISessionManager struct { + node *core.LivepeerNode + selectors map[string]*AISessionSelector + mu *sync.Mutex + ttl time.Duration +} + +func NewAISessionManager(node *core.LivepeerNode, ttl time.Duration) *AISessionManager { + return &AISessionManager{ + node: node, + selectors: make(map[string]*AISessionSelector), + mu: &sync.Mutex{}, + ttl: ttl, + } +} + +func (c *AISessionManager) Select(ctx context.Context, cap core.Capability, modelID string) (*AISession, error) { + sel, err := c.getSelector(ctx, cap, modelID) + if err != nil { + return nil, err + } + + return sel.Select(ctx), nil +} + +func (c *AISessionManager) Remove(ctx context.Context, sess *AISession) error { + sel, err := c.getSelector(ctx, sess.Cap, sess.ModelID) + if err != nil { + return err + } + + sel.Remove(sess) + + return nil +} + +func (c *AISessionManager) Complete(ctx context.Context, sess *AISession) error { + sel, err := c.getSelector(ctx, sess.Cap, sess.ModelID) + if err != nil { + return err + } + + sel.Complete(sess) + + return nil +} + +func (c *AISessionManager) getSelector(ctx context.Context, cap core.Capability, modelID string) (*AISessionSelector, error) { + c.mu.Lock() + defer c.mu.Unlock() + + cacheKey := strconv.Itoa(int(cap)) + "_" + modelID + sel, ok := c.selectors[cacheKey] + if !ok { + // Create the selector + var err error + sel, err = NewAISessionSelector(cap, modelID, c.node, c.ttl) + if err != nil { + return nil, err + } + + c.selectors[cacheKey] = sel + } + + return sel, nil +} diff --git a/server/mediaserver.go b/server/mediaserver.go index 522d7fc145..72a445ffb6 100644 --- a/server/mediaserver.go +++ b/server/mediaserver.go @@ -59,6 +59,8 @@ const StreamKeyBytes = 6 const SegLen = 2 * time.Second const BroadcastRetry = 15 * time.Second +const AISessionManagerTTL = 10 * time.Minute + var BroadcastJobVideoProfiles = []ffmpeg.VideoProfile{ffmpeg.P240p30fps4x3, ffmpeg.P360p30fps16x9} var AuthWebhookURL *url.URL @@ -108,6 +110,8 @@ type LivepeerServer struct { ExposeCurrentManifest bool recordingsAuthResponses *cache.Cache + AISessionManager *AISessionManager + // Thread sensitive fields. All accesses to the // following fields should be protected by `connectionLock` rtmpConnections map[core.ManifestID]*rtmpConnection @@ -181,6 +185,7 @@ func NewLivepeerServer(rtmpAddr string, lpNode *core.LivepeerNode, httpIngest bo rtmpConnections: make(map[core.ManifestID]*rtmpConnection), internalManifests: make(map[core.ManifestID]core.ManifestID), recordingsAuthResponses: cache.New(time.Hour, 2*time.Hour), + AISessionManager: NewAISessionManager(lpNode, AISessionManagerTTL), } if lpNode.NodeType == core.BroadcasterNode && httpIngest { opts.HttpMux.HandleFunc("/live/", ls.HandlePush)