diff --git a/keras_cv/models/stable_diffusion/stable_diffusion.py b/keras_cv/models/stable_diffusion/stable_diffusion.py index a68923dc78..bced10ccf1 100644 --- a/keras_cv/models/stable_diffusion/stable_diffusion.py +++ b/keras_cv/models/stable_diffusion/stable_diffusion.py @@ -210,7 +210,11 @@ def generate_image( # Iterative reverse diffusion stage num_timesteps = 1000 - ratio = (num_timesteps - 1) / (num_steps - 1) + ratio = ( + (num_timesteps - 1) / (num_steps - 1) + if num_steps > 1 + else num_timesteps + ) timesteps = (np.arange(0, num_steps) * ratio).round().astype(np.int64) alphas, alphas_prev = self._get_initial_alphas(timesteps) diff --git a/keras_cv/models/stable_diffusion/stable_diffusion_test.py b/keras_cv/models/stable_diffusion/stable_diffusion_test.py index edd8681483..23f57569d3 100644 --- a/keras_cv/models/stable_diffusion/stable_diffusion_test.py +++ b/keras_cv/models/stable_diffusion/stable_diffusion_test.py @@ -72,6 +72,14 @@ def test_text_tokenizer_golden_value(self): [49406, 320, 27111, 9038, 320], ) + @pytest.mark.extra_large + def test_num_steps_equal_to_one_no_error(self): + stablediff = StableDiffusion(128, 128) + _ = stablediff.generate_image( + stablediff.encode_text("thou shall not render"), + num_steps=1, + ) + @pytest.mark.extra_large def test_mixed_precision(self): try: