Skip to content

Commit

Permalink
feat(ai): add 'num_inference_steps' to I2I,I2V and upscale pipeliens
Browse files Browse the repository at this point in the history
This commit adds support for the `num_inference_steps` parameter to the
I2I, I2V and upscale pipelines. It also fixes a incorrect latencyScore
calculation for the bytedance model.
  • Loading branch information
rickstaa committed Jul 17, 2024
1 parent 29d4603 commit 29669c0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
17 changes: 17 additions & 0 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"os"
"regexp"
"strconv"
"strings"

Expand Down Expand Up @@ -85,3 +86,19 @@ func ParseAIModelConfigs(config string) ([]AIModelConfig, error) {

return configs, nil
}

// parseStepsFromModelID parses the number of inference steps from the model ID suffix.
func ParseStepsFromModelID(modelID *string, defaultSteps float64) float64 {
numInferenceSteps := defaultSteps

// Regular expression to find "_<number>step" pattern anywhere in the model ID.
stepPattern := regexp.MustCompile(`_(\d+)step`)
matches := stepPattern.FindStringSubmatch(*modelID)
if len(matches) == 2 {
if parsedSteps, err := strconv.Atoi(matches[1]); err == nil {
numInferenceSteps = float64(parsedSteps)
}
}

return numInferenceSteps
}
34 changes: 30 additions & 4 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ func submitTextToImage(ctx context.Context, params aiRequestParams, sess *AISess
if req.NumInferenceSteps != nil {
numInferenceSteps = float64(*req.NumInferenceSteps)
}
// Handle special case for SDXL-Lightning model.
if strings.HasPrefix(*req.ModelId, "ByteDance/SDXL-Lightning") {
numInferenceSteps = core.ParseStepsFromModelID(req.ModelId, 8)
}

sess.LatencyScore = took.Seconds() / float64(outPixels) / (numImages * numInferenceSteps)

return resp.JSON200, nil
Expand Down Expand Up @@ -202,13 +207,22 @@ func submitImageToImage(ctx context.Context, params aiRequestParams, sess *AISes
}

// TODO: Refine this rough estimate in future iterations.
// TODO: Default values for the number of images is currently hardcoded.
// TODO: Default values for the number of images and inference steps are currently hardcoded.
// These should be managed by the nethttpmiddleware. Refer to issue LIV-412 for more details.
numImages := float64(1)
if req.NumImagesPerPrompt != nil {
numImages = float64(*req.NumImagesPerPrompt)
}
sess.LatencyScore = took.Seconds() / float64(outPixels) / numImages
numInferenceSteps := float64(100)
if req.NumInferenceSteps != nil {
numInferenceSteps = float64(*req.NumInferenceSteps)
}
// Handle special case for SDXL-Lightning model.
if strings.HasPrefix(*req.ModelId, "ByteDance/SDXL-Lightning") {
numInferenceSteps = core.ParseStepsFromModelID(req.ModelId, 8)
}

sess.LatencyScore = took.Seconds() / float64(outPixels) / (numImages * numInferenceSteps)

return resp.JSON200, nil
}
Expand Down Expand Up @@ -303,7 +317,13 @@ func submitImageToVideo(ctx context.Context, params aiRequestParams, sess *AISes
}

// TODO: Refine this rough estimate in future iterations
sess.LatencyScore = took.Seconds() / float64(outPixels)
// TODO: Default values for the number of inference steps is currently hardcoded.
// These should be managed by the nethttpmiddleware. Refer to issue LIV-412 for more details.
numInferenceSteps := float64(25)
if req.NumInferenceSteps != nil {
numInferenceSteps = float64(*req.NumInferenceSteps)
}
sess.LatencyScore = took.Seconds() / float64(outPixels) / numInferenceSteps

return &res, nil
}
Expand Down Expand Up @@ -383,7 +403,13 @@ func submitUpscale(ctx context.Context, params aiRequestParams, sess *AISession,
}

// TODO: Refine this rough estimate in future iterations
sess.LatencyScore = took.Seconds() / float64(outPixels)
// TODO: Default values for the number of inference steps is currently hardcoded.
// These should be managed by the nethttpmiddleware. Refer to issue LIV-412 for more details.
numInferenceSteps := float64(75)
if req.NumInferenceSteps != nil {
numInferenceSteps = float64(*req.NumInferenceSteps)
}
sess.LatencyScore = took.Seconds() / float64(outPixels) / numInferenceSteps

return resp.JSON200, nil
}
Expand Down

0 comments on commit 29669c0

Please sign in to comment.