Skip to content

Commit

Permalink
Stable diffusion small fixes (#836)
Browse files Browse the repository at this point in the history
* Remove numpy
Add regex dependency in codespace/docker

* Update stable_diffusion.py

* Fix format
  • Loading branch information
bhack authored Sep 25, 2022
1 parent d067696 commit c39df7c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
1 change: 1 addition & 0 deletions .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ RUN pip install \
flake8 \
isort \
pytest \
regex \
tensorflow_datasets
32 changes: 16 additions & 16 deletions keras_cv/models/generative/stable_diffusion/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import math

import numpy as np
import tensorflow as tf
from tensorflow import keras

Expand Down Expand Up @@ -126,24 +125,25 @@ def text_to_image(
f"Prompt is too long (should be <= {MAX_PROMPT_LENGTH} tokens)"
)
phrase = inputs + [49407] * (MAX_PROMPT_LENGTH - len(inputs))
phrase = np.array(phrase)[None].astype("int32")
phrase = np.repeat(phrase, batch_size, axis=0)
phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
phrase = tf.repeat(phrase, batch_size, axis=0)

# Encode prompt tokens + positions into a "context" vector
pos_ids = np.array(list(range(MAX_PROMPT_LENGTH)))[None].astype("int32")
pos_ids = np.repeat(pos_ids, batch_size, axis=0)
pos_ids = tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)
pos_ids = tf.repeat(pos_ids, batch_size, axis=0)
context = self.text_encoder.predict_on_batch([phrase, pos_ids])

# Encode unconditional tokens + positions as "unconditional context"
unconditional_tokens = np.array(_UNCONDITIONAL_TOKENS)[None].astype("int32")
unconditional_tokens = np.repeat(unconditional_tokens, batch_size, axis=0)
self.unconditional_tokens = tf.convert_to_tensor(unconditional_tokens)
unconditional_tokens = tf.convert_to_tensor(
[_UNCONDITIONAL_TOKENS], dtype=tf.int32
)
self.unconditional_tokens = tf.repeat(unconditional_tokens, batch_size, axis=0)
unconditional_context = self.text_encoder.predict_on_batch(
[self.unconditional_tokens, pos_ids]
)

# Iterative reverse diffusion stage
timesteps = np.arange(1, 1000, 1000 // num_steps)
timesteps = tf.range(1, 1000, 1000 // num_steps)
latent, alphas, alphas_prev = self._get_initial_parameters(
timesteps, batch_size, seed
)
Expand All @@ -168,17 +168,17 @@ def text_to_image(
# Decoding stage
decoded = self.decoder.predict_on_batch(latent)
decoded = ((decoded + 1) / 2) * 255
return np.clip(decoded, 0, 255).astype("uint8")
return tf.cast(tf.clip_by_value(decoded, 0, 255), dtype=tf.uint8)

def _get_timestep_embedding(self, timestep, batch_size, dim=320, max_period=10000):
half = dim // 2
freqs = np.exp(
-math.log(max_period) * np.arange(0, half, dtype="float32") / half
freqs = tf.math.exp(
-math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
)
args = np.array([timestep]) * freqs
embedding = np.concatenate([np.cos(args), np.sin(args)])
embedding = tf.convert_to_tensor(embedding.reshape(1, -1))
return np.repeat(embedding, batch_size, axis=0)
args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
embedding = tf.reshape(embedding, [1, -1])
return tf.repeat(embedding, batch_size, axis=0)

def _get_initial_parameters(self, timesteps, batch_size, seed=None):
alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
Expand Down

0 comments on commit c39df7c

Please sign in to comment.