
[Flux] Port Flux Core Model #1864

Status: Open. DavidLandup0 wants to merge 69 commits into master from feature/flux.

Changes from 67 commits

Commits (69)
286f4b2
starter commit - ported time embeddings to keras ops
DavidLandup0 Sep 23, 2024
244f013
add mlpembedder
DavidLandup0 Sep 23, 2024
480ad24
add RMS Norm re-implementation
DavidLandup0 Sep 23, 2024
2782242
add qknorm reimplementation
DavidLandup0 Sep 23, 2024
48c82e6
add rope, scaled dot product attention and self attention
DavidLandup0 Sep 26, 2024
513e370
modulation layer
DavidLandup0 Sep 29, 2024
8ccbb26
fix typing
DavidLandup0 Sep 29, 2024
c88c949
add double stream block
DavidLandup0 Sep 29, 2024
2bc150e
adjustments to doublestreamblock
DavidLandup0 Sep 30, 2024
969d508
add single stream layer
DavidLandup0 Oct 2, 2024
77c9297
update layers and add flux core model
DavidLandup0 Oct 4, 2024
35769ab
functions to layers
DavidLandup0 Oct 5, 2024
13d46c4
refactor layer usage
DavidLandup0 Oct 5, 2024
c00c6a5
refactor layer usage
DavidLandup0 Oct 5, 2024
05a1e3f
position math args in call()
DavidLandup0 Oct 5, 2024
f076006
name arguments
DavidLandup0 Oct 5, 2024
f9fc4a4
fix arg name
DavidLandup0 Oct 5, 2024
f2f2c96
start adding conversion script utils
DavidLandup0 Oct 5, 2024
311d342
change reshape into rearrange
DavidLandup0 Oct 6, 2024
db14c01
add rest of weight conversion and remove redundant shape extraction
DavidLandup0 Oct 6, 2024
c5b37c6
fix mlpembedder arg
DavidLandup0 Oct 6, 2024
8d3a385
remove redundant args
DavidLandup0 Oct 6, 2024
fa5379e
fix params. to self.
DavidLandup0 Oct 6, 2024
34e2477
add license
DavidLandup0 Oct 6, 2024
cdd397a
add einops
DavidLandup0 Oct 6, 2024
8169aa4
fix default arg
DavidLandup0 Oct 6, 2024
b1caa7f
expand docstrings
DavidLandup0 Oct 6, 2024
76eae83
tanh to gelu
DavidLandup0 Oct 6, 2024
c0236ac
refactor weight conversion into tools
DavidLandup0 Oct 6, 2024
b418659
update weight conversion
DavidLandup0 Oct 7, 2024
99839af
add stand-in presets until weights are uploaded
DavidLandup0 Oct 7, 2024
ac5c4b1
set float32 to t.dtype in timestep embedding
DavidLandup0 Oct 7, 2024
89dc08c
update more float32s into dynamic types
DavidLandup0 Oct 7, 2024
d3de26b
dtype
DavidLandup0 Oct 7, 2024
9d4aa22
dtype
DavidLandup0 Oct 7, 2024
dbddde7
enable float16 mode
DavidLandup0 Oct 7, 2024
b3c75a9
update conversion script to not require flux repo
DavidLandup0 Oct 7, 2024
4333bab
add build() methods to avoid running dummy input through model
DavidLandup0 Oct 7, 2024
199ba1c
update build call
DavidLandup0 Oct 7, 2024
a8de665
fix build calls
DavidLandup0 Oct 7, 2024
efe993a
style
DavidLandup0 Oct 7, 2024
ff118bb
change dummy call into build() call
DavidLandup0 Oct 7, 2024
da78707
Merge branch 'master' into feature/flux
DavidLandup0 Oct 8, 2024
a3ccf6d
reference einops issue
DavidLandup0 Oct 8, 2024
f88e1e9
address docstring comments in flux layers
DavidLandup0 Oct 8, 2024
6e2c320
address docstring comments in flux maths
DavidLandup0 Oct 8, 2024
b407ffc
remove numpy
DavidLandup0 Oct 8, 2024
ac43081
add docstrings for flux model
DavidLandup0 Oct 8, 2024
4b585a0
qkv bias -> use_bias
DavidLandup0 Oct 8, 2024
a2facb2
docstring updates
DavidLandup0 Oct 8, 2024
bd2ebe2
remove type hints
DavidLandup0 Oct 8, 2024
f48bbd2
all img->image, txt->text
DavidLandup0 Oct 8, 2024
cbad326
functional subclassing model
DavidLandup0 Oct 14, 2024
eeb8e0d
shape fixes
DavidLandup0 Oct 15, 2024
330ed70
format
DavidLandup0 Oct 15, 2024
9233411
self.hidden_size -> self.dim
DavidLandup0 Oct 15, 2024
ed2badc
einops rearrange
DavidLandup0 Oct 15, 2024
a65424b
remove build method
DavidLandup0 Oct 15, 2024
cb11e28
ops to rearrange
DavidLandup0 Oct 15, 2024
f478f39
remove build
DavidLandup0 Oct 15, 2024
3b5cb4d
rearrange -> symbolic_rearrange
DavidLandup0 Oct 15, 2024
40178e1
turn timesteps and guidance into inputs
DavidLandup0 Oct 15, 2024
078459d
basic preprocessor flow
DavidLandup0 Oct 15, 2024
0003b08
refactor layer names in conversion script
DavidLandup0 Oct 15, 2024
71b564f
add backbone tests
DavidLandup0 Oct 15, 2024
7aa93a2
raise not implemented on encode, encode_text, etc. methods
DavidLandup0 Oct 15, 2024
b05c94b
styling
DavidLandup0 Oct 15, 2024
94f9ffb
fix shape hack with a cleaner alternative
DavidLandup0 Oct 16, 2024
adeb842
remove unused attributes, fix tests
DavidLandup0 Oct 16, 2024
5 changes: 5 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -153,6 +153,11 @@
)
from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer
from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
from keras_hub.src.models.flux.flux_model import FluxBackbone
from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage
from keras_hub.src.models.flux.flux_text_to_image_preprocessor import (
    FluxTextToImagePreprocessor,
)
from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
5 changes: 5 additions & 0 deletions keras_hub/src/models/flux/__init__.py
@@ -0,0 +1,5 @@
from keras_hub.src.models.flux.flux_model import FluxBackbone
from keras_hub.src.models.flux.flux_presets import presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(presets, FluxBackbone)
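
Registering the presets is what makes the model loadable through KerasHub's standard `from_preset` flow. A minimal usage sketch follows; the preset name below is a placeholder, since the PR adds stand-in presets until real weights are uploaded:

from keras_hub.models import FluxBackbone

# Placeholder preset name; actual names land once weights are uploaded.
backbone = FluxBackbone.from_preset("flux_base")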
88 changes: 88 additions & 0 deletions keras_hub/src/models/flux/flux_backbone_test.py
@@ -0,0 +1,88 @@
import pytest
from keras import ops

from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder
from keras_hub.src.models.flux.flux_model import FluxBackbone
from keras_hub.src.models.vae.vae_backbone import VAEBackbone
from keras_hub.src.tests.test_case import TestCase


class FluxBackboneTest(TestCase):
    def setUp(self):
        # Review note (DavidLandup0, Collaborator/Author): the VAE and CLIP
        # models will be part of the generation pipeline, so these are added
        # preemptively and unused for now.
        vae = VAEBackbone(
            [32, 32, 32, 32],
            [1, 1, 1, 1],
            [32, 32, 32, 32],
            [1, 1, 1, 1],
            # Use `mode` to generate a deterministic output.
            sampler_method="mode",
            name="vae",
        )
        clip_l = CLIPTextEncoder(
            20, 32, 32, 2, 2, 64, "quick_gelu", -2, name="clip_l"
        )
        self.init_kwargs = {
            "input_channels": 256,
            "hidden_size": 1024,
            "mlp_ratio": 2.0,
            "num_heads": 8,
            "depth": 4,
            "depth_single_blocks": 8,
            "axes_dim": [16, 56, 56],
            "theta": 10_000,
            "use_bias": True,
            "guidance_embed": True,
            "image_shape": (32, 256),
            "text_shape": (32, 256),
            "image_ids_shape": (32, 3),
            "text_ids_shape": (32, 3),
            "timestep_shape": (128,),
            "y_shape": (256,),
            "guidance_shape": (128,),
        }

        self.pipeline_models = {
            "vae": vae,
            "clip_l": clip_l,
        }

        input_data = {
            "image": ops.ones((1, 32, 256)),
            "image_ids": ops.ones((1, 32, 3)),
            "text": ops.ones((1, 32, 256)),
            "text_ids": ops.ones((1, 32, 3)),
            "y": ops.ones((1, 256)),
            # Name is set, but for some reason it's overridden.
            "keras_tensor_8CLONE": ops.ones((32,)),
            "keras_tensor_9CLONE": ops.ones((32,)),
        }

        self.input_data = [
            input_data["image"],
            input_data["image_ids"],
            input_data["text"],
            input_data["text_ids"],
            input_data["y"],
            input_data["keras_tensor_8CLONE"],
            input_data["keras_tensor_9CLONE"],
        ]

    # backbone.predict() will complain about data cardinality:
    # all inputs have a batch size of 1, but the timesteps and
    # guidance are unbatched, so the cardinality doesn't match.
    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=FluxBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=[32, 32, 256],
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=FluxBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
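
As for the cardinality note in setUp(): one plausible workaround, sketched under the unverified assumption that the backbone also accepts batched timestep and guidance tensors, is to give all seven inputs a consistent batch dimension:

import numpy as np

# All inputs share a batch dimension of 1, including timesteps and guidance,
# so predict() no longer sees mismatched cardinality. Shapes mirror setUp().
batched_input_data = [
    np.ones((1, 32, 256), "float32"),  # image
    np.ones((1, 32, 3), "float32"),    # image_ids
    np.ones((1, 32, 256), "float32"),  # text
    np.ones((1, 32, 3), "float32"),    # text_ids
    np.ones((1, 256), "float32"),      # y
    np.ones((1, 32), "float32"),       # timesteps, now batched
    np.ones((1, 32), "float32"),       # guidance, now batched
]
# backbone.predict(batched_input_data)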