Skip to content

Commit

Permalink
bug fix for num_steps=1 (#2373)
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin authored Mar 7, 2024
1 parent 9dd547a commit 5faae37
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
6 changes: 5 additions & 1 deletion keras_cv/models/stable_diffusion/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,11 @@ def generate_image(

# Iterative reverse diffusion stage
num_timesteps = 1000
ratio = (num_timesteps - 1) / (num_steps - 1)
ratio = (
(num_timesteps - 1) / (num_steps - 1)
if num_steps > 1
else num_timesteps
)
timesteps = (np.arange(0, num_steps) * ratio).round().astype(np.int64)

alphas, alphas_prev = self._get_initial_alphas(timesteps)
Expand Down
8 changes: 8 additions & 0 deletions keras_cv/models/stable_diffusion/stable_diffusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ def test_text_tokenizer_golden_value(self):
[49406, 320, 27111, 9038, 320],
)

@pytest.mark.extra_large
def test_num_steps_equal_to_one_no_error(self):
stablediff = StableDiffusion(128, 128)
_ = stablediff.generate_image(
stablediff.encode_text("thou shall not render"),
num_steps=1,
)

@pytest.mark.extra_large
def test_mixed_precision(self):
try:
Expand Down

0 comments on commit 5faae37

Please sign in to comment.