Merge pull request #5 from invoke-ai/ryan/lora-injection-sd1
Add basic linear LoRA support for Stable Diffusion v1 UNet
RyanJDick authored Aug 3, 2023
2 parents 0528a35 + 5fdc63b commit 19f0766
Showing 14 changed files with 304 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
@@ -39,7 +39,7 @@ jobs:
ruff --format=github .
- name: Test with pytest
run: |
pytest tests --junitxml=junit/test-results-${{ matrix.python-version }}.xml -m "not cuda"
pytest tests --junitxml=junit/test-results-${{ matrix.python-version }}.xml -m "not cuda and not loads_model"
- name: Upload pytest test results
uses: actions/upload-artifact@v3
with:
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -17,6 +17,8 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"accelerate~=0.21.0",
"diffusers~=0.19.3",
"torch~=2.0.1",
]

@@ -51,4 +53,5 @@ line-length = 120
addopts = "--strict-markers"
markers = [
"cuda: marks tests that require a CUDA GPU",
"loads_model: marks tests that require a model from the HF hub",
]
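These markers are attached to individual tests with pytest.mark decorators, which is what allows the CI workflow above to deselect them with -m "not cuda and not loads_model". A minimal sketch of a hypothetical test that opts in to both markers (illustrative only, not part of this commit):

import pytest


@pytest.mark.cuda
@pytest.mark.loads_model
def test_hypothetical_gpu_and_model_case():
    # Deselected by `-m "not cuda and not loads_model"`; intended to run only when
    # a CUDA GPU is present and the HF hub model has been downloaded.
    ...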
Empty file.
43 changes: 43 additions & 0 deletions src/invoke_training/lora/injection/lora_layer_collection.py
@@ -0,0 +1,43 @@
import typing

import torch

from invoke_training.lora.layers import BaseLoRALayer


class LoRALayerCollection(torch.nn.Module):
"""A collection of LoRA layers (with names). Typically used to perform operations on a group of LoRA layers during
training.
"""

def __init__(self):
super().__init__()

# A torch.nn.ModuleDict may seem like a more natural choice here, but it does not allow keys that contain '.'
# characters. Using a standard python dict is also inconvenient, because it would be ignored by torch.nn.Module
# methods such as `.parameters()` and `.train()`.
self._layers = torch.nn.ModuleList()
self._names = []

def add_layer(self, layer: BaseLoRALayer, name: str):
self._layers.append(layer)
self._names.append(name)

def __len__(self):
return len(self._layers)

def get_lora_state_dict(self) -> typing.Dict[str, torch.Tensor]:
"""A custom alternative to .state_dict() that uses the layer names provided to add_layer(...) as key
prefixes.
"""
state_dict: typing.Dict[str, torch.Tensor] = {}

for name, layer in zip(self._names, self._layers):
layer_state_dict = layer.state_dict()
for key, state in layer_state_dict.items():
full_key = name + "." + key
if full_key in state_dict:
raise RuntimeError(f"Multiple state elements map to the same key: '{full_key}'.")
state_dict[full_key] = state

return state_dict
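To make the key-prefixing behaviour concrete, here is a minimal usage sketch of LoRALayerCollection. The LoRALinearLayer(8, 16) call mirrors the tests added later in this commit; the layer name is illustrative only.

from invoke_training.lora.injection.lora_layer_collection import LoRALayerCollection
from invoke_training.lora.layers import LoRALinearLayer

lora_layers = LoRALayerCollection()
lora_layers.add_layer(LoRALinearLayer(8, 16), "lora_unet.down_blocks.0.attn1.to_q")

# Each layer's state tensors are keyed by the name passed to add_layer(...):
#   'lora_unet.down_blocks.0.attn1.to_q._down.weight'
#   'lora_unet.down_blocks.0.attn1.to_q._up.weight'
#   'lora_unet.down_blocks.0.attn1.to_q.alpha'
state_dict = lora_layers.get_lora_state_dict()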
74 changes: 74 additions & 0 deletions src/invoke_training/lora/injection/stable_diffusion_v1.py
@@ -0,0 +1,74 @@
import typing

import torch
from diffusers.models import Transformer2DModel, UNet2DConditionModel
from diffusers.models.lora import LoRACompatibleLinear

from invoke_training.lora.injection.lora_layer_collection import LoRALayerCollection
from invoke_training.lora.injection.utils import inject_lora_layers
from invoke_training.lora.layers import LoRALinearLayer


def inject_lora_into_unet_sd1(unet: UNet2DConditionModel) -> LoRALayerCollection:
"""Inject LoRA layers into a Stable Diffusion v1 UNet model.
Args:
unet (UNet2DConditionModel): The UNet model to inject LoRA layers into.
Returns:
LoRALayerCollection: The LoRA layers that were added to the UNet.
"""

lora_layers = inject_lora_layers(
module=unet,
lora_map={torch.nn.Linear: LoRALinearLayer, LoRACompatibleLinear: LoRALinearLayer},
include_descendants_of={Transformer2DModel},
exclude_descendants_of=None,
prefix="lora_unet",
)

return lora_layers


def convert_lora_state_dict_to_kohya_format_sd1(
state_dict: typing.Dict[str, torch.Tensor]
) -> typing.Dict[str, torch.Tensor]:
"""Convert a Stable Diffusion v1 LoRA state_dict from internal invoke-training format to kohya_ss format.
Args:
state_dict (typing.Dict[str, torch.Tensor]): LoRA layer state_dict in invoke-training format.
Raises:
ValueError: If state_dict contains unexpected keys.
RuntimeError: If two input keys map to the same output kohya_ss key.
Returns:
typing.Dict[str, torch.Tensor]: LoRA layer state_dict in kohya_ss format.
"""
new_state_dict = {}

# The following logic converts state_dict keys from the internal invoke-training format to kohya_ss format.
# Example conversion:
# from: 'lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight'
# to: 'lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight'
for key, val in state_dict.items():
if key.endswith("._up.weight"):
key_start = key.removesuffix("._up.weight")
key_end = ".lora_up.weight"
elif key.endswith("._down.weight"):
key_start = key.removesuffix("._down.weight")
key_end = ".lora_down.weight"
elif key.endswith(".alpha"):
key_start = key.removesuffix(".alpha")
key_end = ".alpha"
else:
raise ValueError(f"Unexpected key in state_dict: '{key}'.")

new_key = key_start.replace(".", "_") + key_end

if new_key in new_state_dict:
raise RuntimeError("Multiple input keys map to the same kohya_ss key: '{new_key}'.")

new_state_dict[new_key] = val

return new_state_dict
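A hedged end-to-end sketch of how these two functions could fit into a training script. The model download, optimizer, and safetensors export below are assumptions for illustration, not part of this commit.

import torch
from diffusers.models import UNet2DConditionModel
from safetensors.torch import save_file

from invoke_training.lora.injection.stable_diffusion_v1 import (
    convert_lora_state_dict_to_kohya_format_sd1,
    inject_lora_into_unet_sd1,
)

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
lora_layers = inject_lora_into_unet_sd1(unet)

# Train only the injected LoRA parameters; the base UNet weights stay frozen.
optimizer = torch.optim.AdamW(lora_layers.parameters(), lr=1e-4)

# ... training loop ...

# Export in kohya_ss format so the weights load in common SD tooling.
kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(lora_layers.get_lora_state_dict())
save_file(kohya_state_dict, "lora.safetensors")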
src/invoke_training/lora/injection/utils.py
@@ -2,6 +2,7 @@

import torch

from invoke_training.lora.injection.lora_layer_collection import LoRALayerCollection
from invoke_training.lora.layers import BaseLoRALayer
from invoke_training.lora.lora_block import LoRABlock

@@ -20,11 +21,11 @@ def find_modules(
module (torch.nn.Module): The base module whose sub-modules will be searched.
targets (typing.Set[typing.Type[torch.nn.Module]]): The set of module types to search for.
include_descendants_of (typing.Set[typing.Type[torch.nn.Module]], optional): If set, then only
descendants of these types will be searched. 'exclude_descendants_of' takes precedence over
'include_descendants_of'.
descendants of these types (and their subclasses) will be searched. 'exclude_descendants_of' takes
precedence over 'include_descendants_of'.
exclude_descendants_of (typing.Set[typing.Type[torch.nn.Module]], optional): If set, then the
descendants of these types will be ignored in the search. 'exclude_descendants_of' takes precedence over
'include_descendants_of'.
descendants of these types (and their subclasses) will be ignored in the search. 'exclude_descendants_of'
takes precedence over 'include_descendants_of'.
memo (typing.Set[torch.nn.Module], optional): A memo to store the set of modules already visited in the search.
memo is typically only set in recursive calls of this function.
prefix (str, optional): A prefix that will be added to the module name.
@@ -79,7 +80,8 @@ def inject_lora_layers(
lora_map: typing.Dict[type[torch.nn.Module], type[BaseLoRALayer]],
include_descendants_of: typing.Optional[typing.Set[typing.Type[torch.nn.Module]]] = None,
exclude_descendants_of: typing.Optional[typing.Set[typing.Type[torch.nn.Module]]] = None,
) -> torch.nn.ModuleDict:
prefix: str = "",
) -> LoRALayerCollection:
"""Iterates over all of the modules in 'module' and if they are present in 'replace_map' then replaces them with the
mapped LoRA layer type.
Args:
@@ -92,16 +94,18 @@
```
include_descendants_of (typing.Set[typing.Type[torch.nn.Module]], optional): Forwarded to find_modules(...).
exclude_descendants_of (typing.Set[typing.Type[torch.nn.Module]], optional): Forwarded to find_modules(...).
prefix (str, optional): A prefix that will be added to the names of all of the LoRA layers.
Returns:
torch.nn.ModuleDict: A ModuleDict of all of the LoRA layers that were injected into the module.
LoRALayerCollection: A collection of all of the LoRA layers that were injected into the module.
"""
lora_layers = torch.nn.ModuleDict()
lora_layers = LoRALayerCollection()

for name, parent, module in find_modules(
module=module,
targets=lora_map.keys(),
include_descendants_of=include_descendants_of,
exclude_descendants_of=exclude_descendants_of,
prefix=prefix,
):
# Lookup the LoRA class to use.
lora_layer_cls = lora_map[type(module)]
@@ -120,6 +124,6 @@
lora_block,
)

lora_layers[name] = lora_layer
lora_layers.add_layer(lora_layer, name)

return lora_layers
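For the include/exclude semantics described in the find_modules docstring, a small self-contained sketch. The Block class and shapes are made up for illustration, and the expected behaviour is assumed from the docstring rather than taken from this commit.

import torch

from invoke_training.lora.injection.utils import find_modules


class Block(torch.nn.Sequential):
    """A stand-in for a container type like Transformer2DModel."""


model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),         # Not a descendant of Block -> not searched.
    Block(torch.nn.Linear(4, 4)),  # Descendant of Block -> searched and matched.
)

matches = list(
    find_modules(
        module=model,
        targets={torch.nn.Linear},
        include_descendants_of={Block},
        exclude_descendants_of=None,
    )
)
# Expect a single (name, parent, module) result: the Linear inside Block.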
6 changes: 4 additions & 2 deletions src/invoke_training/lora/layers/lora_linear_layer.py
@@ -40,7 +40,9 @@ def __init__(
self._down = torch.nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
self._up = torch.nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)

self._alpha = alpha
# Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict.
self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype))

self._rank = rank

self.reset_parameters()
@@ -86,6 +88,6 @@ def forward(self, input: torch.Tensor):
down_hidden = self._down(input)
up_hidden = self._up(down_hidden)

up_hidden *= self._alpha / self._rank
up_hidden *= self.alpha / self._rank

return up_hidden
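Because alpha is now a registered buffer rather than a plain attribute, it is saved and loaded with the layer but excluded from optimization. A quick sketch of what that implies; the constructor arguments follow the tests added in this commit.

import torch

from invoke_training.lora.layers import LoRALinearLayer

layer = LoRALinearLayer(8, 16)

# The buffer appears in the state_dict alongside the trainable weights...
assert set(layer.state_dict().keys()) == {"_down.weight", "_up.weight", "alpha"}
# ...but it is not returned by .parameters(), so optimizers will not update it.
assert all(name in ("_down.weight", "_up.weight") for name, _ in layer.named_parameters())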
2 changes: 1 addition & 1 deletion src/invoke_training/lora/lora_block.py
@@ -19,4 +19,4 @@ def __init__(self, original_module: torch.nn.Module, lora_layer: torch.nn.Module
self.lora_multiplier = lora_multiplier

def forward(self, input):
return self.original_module.forward(input) + self.lora_multiplier * self.lora_layer.forward(input)
return self.original_module(input) + self.lora_multiplier * self.lora_layer(input)
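The switch from .forward(...) to direct module calls means hooks registered on the wrapped modules still fire. A minimal sketch of wrapping a single Linear; the two-argument LoRABlock constructor and a default lora_multiplier of 1.0 are inferred from the tests in this commit, so treat them as assumptions.

import torch

from invoke_training.lora.layers import LoRALinearLayer
from invoke_training.lora.lora_block import LoRABlock

linear = torch.nn.Linear(8, 16)
block = LoRABlock(linear, LoRALinearLayer(8, 16))

x = torch.randn(2, 8)
y = block(x)  # original(x) + lora_multiplier * lora(x)
assert y.shape == (2, 16)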
Empty file.
37 changes: 37 additions & 0 deletions tests/invoke_training/lora/injection/test_lora_layer_collection.py
@@ -0,0 +1,37 @@
import pytest

from invoke_training.lora.injection.lora_layer_collection import LoRALayerCollection
from invoke_training.lora.layers import LoRALinearLayer


def test_lora_layer_collection_state_dict():
"""Test the behavior of LoRALayerCollection.get_lora_state_dict()."""
lora_layers = LoRALayerCollection()

lora_layers.add_layer(LoRALinearLayer(8, 16), "lora_layer_1")
lora_layers.add_layer(LoRALinearLayer(16, 32), "lora_layer_2")

state_dict = lora_layers.get_lora_state_dict()

expected_state_keys = {
"lora_layer_1._down.weight",
"lora_layer_1._up.weight",
"lora_layer_1.alpha",
"lora_layer_2._down.weight",
"lora_layer_2._up.weight",
"lora_layer_2.alpha",
}
assert set(state_dict.keys()) == expected_state_keys


def test_lora_layer_collection_state_dict_conflicting_keys():
"""Test that LoRALayerCollection.get_lora_state_dict() raises an exception if state Tensors have conflicting
keys.
"""
lora_layers = LoRALayerCollection()

lora_layers.add_layer(LoRALinearLayer(8, 16), "lora_layer_1")
lora_layers.add_layer(LoRALinearLayer(16, 32), "lora_layer_1") # Insert same layer type with same key.

with pytest.raises(RuntimeError):
_ = lora_layers.get_lora_state_dict()
99 changes: 99 additions & 0 deletions tests/invoke_training/lora/injection/test_stable_diffusion_v1.py
@@ -0,0 +1,99 @@
import pytest
import torch
from diffusers.models import UNet2DConditionModel

from invoke_training.lora.injection.stable_diffusion_v1 import (
convert_lora_state_dict_to_kohya_format_sd1,
inject_lora_into_unet_sd1,
)


@pytest.mark.loads_model
def test_inject_lora_into_unet_sd1_smoke():
"""Smoke test of inject_lora_into_unet_sd1(...) on full SD 1.5 model."""
unet = UNet2DConditionModel.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="unet",
local_files_only=True,
revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0",
)

lora_layers = inject_lora_into_unet_sd1(unet)

# These assertions are based on a manual check of the injected layers and comparison against the behaviour of
# kohya_ss. They are included here to force another manual review after any future breaking change.
assert len(lora_layers) == 160
# assert len(lora_layers) == 192 # TODO(ryand): Enable this check once conv layers are added.
for layer_name in lora_layers._names:
assert layer_name.endswith(("to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"))


@pytest.mark.loads_model
def test_convert_lora_state_dict_to_kohya_format_sd1_smoke():
"""Smoke test of convert_lora_state_dict_to_kohya_format_sd1(...) with full SD 1.5 model."""
unet = UNet2DConditionModel.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="unet",
local_files_only=True,
revision="c9ab35ff5f2c362e9e22fbafe278077e196057f0",
)

lora_layers = inject_lora_into_unet_sd1(unet)
lora_state_dict = lora_layers.get_lora_state_dict()
kohya_state_dict = convert_lora_state_dict_to_kohya_format_sd1(lora_state_dict)

# These assertions are based on a manual check of the injected layers and comparison against the behaviour of
# kohya_ss. They are included here to force another manual review after any future breaking change.
assert len(kohya_state_dict) == 160 * 3
for key in kohya_state_dict.keys():
assert key.startswith("lora_unet_")
assert key.endswith((".lora_down.weight", ".lora_up.weight", ".alpha"))


def test_convert_lora_state_dict_to_kohya_format_sd1():
"""Basic test of convert_lora_state_dict_to_kohya_format_sd1(...)."""
down_weight = torch.Tensor(4, 2)
up_weight = torch.Tensor(2, 4)
alpha = torch.Tensor([1.0])
in_state_dict = {
"lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight": down_weight,
"lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._up.weight": up_weight,
"lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.alpha": alpha,
}

out_state_dict = convert_lora_state_dict_to_kohya_format_sd1(in_state_dict)

expected_out_state_dict = {
"lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight": down_weight,
"lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight": up_weight,
"lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.alpha": alpha,
}

assert out_state_dict == expected_out_state_dict


def test_convert_lora_state_dict_to_kohya_format_sd1_unexpected_key():
"""Test that convert_lora_state_dict_to_kohya_format_sd1(...) raises an exception if it receives an unexpected
key.
"""
in_state_dict = {
"lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.unexpected": torch.Tensor(4, 2),
}

with pytest.raises(ValueError):
_ = convert_lora_state_dict_to_kohya_format_sd1(in_state_dict)


def test_convert_lora_state_dict_to_kohya_format_sd1_conflicting_keys():
"""Test that convert_lora_state_dict_to_kohya_format_sd1(...) raises an exception if multiple keys map to the same
output key.
"""
# Note: These keys differ only in their '.' and '_' characters, but they both map to the same kohya_ss
# output key.
in_state_dict = {
"lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q._down.weight": torch.Tensor(4, 2),
"lora_unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1_to_q._down.weight": torch.Tensor(4, 2),
}

with pytest.raises(RuntimeError):
_ = convert_lora_state_dict_to_kohya_format_sd1(in_state_dict)
@@ -1,8 +1,8 @@
import torch

from invoke_training.lora.injection.utils import find_modules, inject_lora_layers
from invoke_training.lora.layers import LoRALinearLayer
from invoke_training.lora.lora_block import LoRABlock
from invoke_training.lora.utils import find_modules, inject_lora_layers


def test_find_modules_simple():
@@ -191,9 +191,11 @@ def test_inject_lora_layers():
}
)

lora_layers = inject_lora_layers(module, {torch.nn.Linear: LoRALinearLayer})
lora_layers = inject_lora_layers(module, {torch.nn.Linear: LoRALinearLayer}, prefix="lora_unet")

assert len(lora_layers) == 1
assert all([k.startswith("lora_unet") for k in lora_layers.get_lora_state_dict()])

assert isinstance(module["linear1"], LoRABlock)
assert module["linear1"].original_module == linear1
assert module["linear1"].lora_layer._down.in_features == linear1.in_features