[Semantic Segmentation] - Add SegFormer Architecture, Weight Conversion Script and Presets #1883

Open
wants to merge 68 commits into
base: master
716ae64
initial commit - tf-based, kcv
DavidLandup0 Sep 26, 2024
71bd40b
porting to keras_hub structure - removing aliases, presets, etc.
DavidLandup0 Sep 27, 2024
8894a86
enable instantiation of segformer backbone with custom MiT backbone
DavidLandup0 Sep 27, 2024
b66c659
remove num_classes from backbone
DavidLandup0 Sep 27, 2024
392ec36
fix input
DavidLandup0 Sep 29, 2024
d80d8d0
add imports to __init__
DavidLandup0 Sep 29, 2024
538adf7
Merge branch 'master' into feature/segformer
DavidLandup0 Sep 29, 2024
1571677
update preset
DavidLandup0 Sep 29, 2024
4b82a16
update docstrings
DavidLandup0 Sep 29, 2024
9b260e7
add basic tests
DavidLandup0 Sep 29, 2024
b93954f
remove redundant imports
DavidLandup0 Sep 29, 2024
159dca5
update docstrings
DavidLandup0 Sep 29, 2024
3ec02dd
remove unused import
DavidLandup0 Sep 29, 2024
7b6286e
running api_gen.py
DavidLandup0 Sep 29, 2024
c40fdcd
undo refactor of mit
DavidLandup0 Sep 29, 2024
9a13544
update docstrings
DavidLandup0 Sep 29, 2024
4dc3fff
add presets for mit
DavidLandup0 Sep 29, 2024
191656c
add standin paths
DavidLandup0 Sep 29, 2024
9e47564
add presets for segformer backbone
DavidLandup0 Sep 30, 2024
98bb69d
register presets in __init__.py
DavidLandup0 Sep 30, 2024
21ed167
addressing comments
DavidLandup0 Oct 1, 2024
f6720ac
addressing comments
DavidLandup0 Oct 1, 2024
b0806f2
addressing comments
DavidLandup0 Oct 1, 2024
22df93b
merge master branch into feature branch
DavidLandup0 Oct 1, 2024
0549be7
update most tests
DavidLandup0 Oct 1, 2024
4f66776
add remaining tests
DavidLandup0 Oct 2, 2024
f0b3e56
remove copyright
DavidLandup0 Oct 2, 2024
8c36b6e
fix test
DavidLandup0 Oct 2, 2024
9e1a9d6
override from_config
DavidLandup0 Oct 2, 2024
8d95091
Merge branch 'master' into feature/mit_presets
DavidLandup0 Oct 7, 2024
0c92729
fix op in overlapping patching and embedding, start adding conversion…
DavidLandup0 Oct 7, 2024
6638cb1
style
DavidLandup0 Oct 7, 2024
9a3f82d
add padding to MiT patchingandembedding
DavidLandup0 Oct 7, 2024
76a6dd2
update to support other presets
DavidLandup0 Oct 7, 2024
7b06c79
update conversin script
DavidLandup0 Oct 8, 2024
6e9728f
fix link for b5
DavidLandup0 Oct 8, 2024
7119022
add cityscapes weights
DavidLandup0 Oct 8, 2024
8ea5f63
update presets
DavidLandup0 Oct 8, 2024
0705748
update presets
DavidLandup0 Oct 8, 2024
eb1c236
update conversion script to make directories
DavidLandup0 Oct 8, 2024
fc42598
use save_preset
DavidLandup0 Oct 8, 2024
4274c60
change name of output dir
DavidLandup0 Oct 8, 2024
dc72ea7
add preprocessor flow
DavidLandup0 Oct 8, 2024
65f1822
api gen and add preprocessor to mits
DavidLandup0 Oct 8, 2024
d1bc073
merge master into feature branch
DavidLandup0 Oct 8, 2024
000d7d0
conform to new image classifier style
DavidLandup0 Oct 8, 2024
fa34f9e
format
DavidLandup0 Oct 8, 2024
e3c6dc6
resizing image converter -> ImageConverter
DavidLandup0 Oct 8, 2024
3483c10
Merge branch 'feature/mit_presets' into feature/segformer
DavidLandup0 Oct 8, 2024
38dadef
merge mit branch into segformer branch
DavidLandup0 Oct 8, 2024
983642c
add preprocessor and converter
DavidLandup0 Oct 8, 2024
2a8ffcb
address comments
DavidLandup0 Oct 9, 2024
34e8701
Merge branch 'feature/mit_presets' into feature/segformer
DavidLandup0 Oct 9, 2024
5b5eb93
clarify backbone usage
DavidLandup0 Oct 9, 2024
fcdadb3
add conversion script
DavidLandup0 Oct 9, 2024
68cef65
numerical equivalence changes
DavidLandup0 Oct 14, 2024
205ae4a
fix numerical inaccuracies
DavidLandup0 Oct 15, 2024
8b5fa44
update conversion script
DavidLandup0 Oct 15, 2024
79e05f8
merge master branch into feature branch
DavidLandup0 Oct 15, 2024
de17c03
update conversion script
DavidLandup0 Oct 15, 2024
04ba1eb
remove transpose
DavidLandup0 Oct 16, 2024
9e04b6e
add preprocessor to segformer class
DavidLandup0 Oct 16, 2024
e9e8ed5
fix preset path
DavidLandup0 Oct 16, 2024
a7a21f6
update test shape
DavidLandup0 Oct 16, 2024
28e1297
update presets
DavidLandup0 Oct 16, 2024
fc8fffe
update test shape
DavidLandup0 Oct 16, 2024
fa89a09
expand docstrings
DavidLandup0 Oct 16, 2024
6e3d3d1
add rescaling and normalization to preprocessor
DavidLandup0 Oct 16, 2024
3 changes: 3 additions & 0 deletions keras_hub/api/layers/__init__.py
@@ -52,6 +52,9 @@
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
from keras_hub.src.models.segformer.segformer_image_converter import (
SegFormerImageConverter,
)
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageConverter
from keras_hub.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
7 changes: 7 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -270,6 +270,13 @@
from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
SAMImageSegmenterPreprocessor,
)
from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
from keras_hub.src.models.segformer.segformer_image_segmenter import (
SegFormerImageSegmenter,
)
from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (
SegFormerImageSegmenterPreprocessor,
)
from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
13 changes: 12 additions & 1 deletion keras_hub/src/models/mix_transformer/mix_transformer_backbone.py
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
import numpy as np
from keras import ops
@@ -104,7 +115,7 @@ def __init__(
]
transformer_blocks.append(transformer_block)
cur += depths[i]
layer_norms.append(keras.layers.LayerNormalization())
layer_norms.append(keras.layers.LayerNormalization(epsilon=1e-5))

# === Functional Model ===
image_input = keras.layers.Input(shape=image_shape)
@@ -11,7 +11,7 @@ class MiTBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"depths": [2, 2],
"image_shape": (16, 16, 3),
"image_shape": (32, 32, 3),
"hidden_dims": [4, 8],
"num_layers": 2,
"blockwise_num_heads": [1, 2],
@@ -20,7 +20,7 @@ def setUp(self):
"patch_sizes": [7, 3],
"strides": [4, 2],
}
self.input_size = 16
self.input_size = 32
self.input_data = np.ones(
(2, self.input_size, self.input_size, 3), dtype="float32"
)
@@ -30,9 +30,9 @@ def test_backbone_basics(self):
cls=MiTBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 2, 2, 8),
expected_output_shape=(2, 4, 4, 8),
expected_pyramid_output_keys=["P1", "P2"],
expected_pyramid_image_sizes=[(4, 4), (2, 2)],
expected_pyramid_image_sizes=[(8, 8), (4, 4)],
run_quantization_check=False,
run_mixed_precision_check=False,
run_data_format_check=False,
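
The old and new expected pyramid sizes in this test are consistent with simple stride arithmetic. The sketch below assumes each patch-embedding stage pads its input so that the output spatial size is `ceil(input / stride)`; that padding behavior is an assumption made for illustration, and `pyramid_sizes` is a hypothetical helper, not a function from the PR.

```python
import math


def pyramid_sizes(input_size, strides):
    """Return the (height, width) of each pyramid level for a square input.

    Assumes every stage reduces the spatial size to ceil(size / stride).
    """
    sizes = []
    size = input_size
    for stride in strides:
        size = math.ceil(size / stride)
        sizes.append((size, size))
    return sizes


# Strides from the test's init_kwargs: [4, 2]
new_sizes = pyramid_sizes(32, [4, 2])  # updated 32x32 input
old_sizes = pyramid_sizes(16, [4, 2])  # previous 16x16 input
print(new_sizes, old_sizes)
```

Under that assumption the 32x32 input gives levels of 8x8 and 4x4, and the earlier 16x16 input gives 4x4 and 2x2, matching both versions of `expected_pyramid_image_sizes`.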
@@ -13,11 +13,11 @@
class MiTImageClassifierTest(TestCase):
def setUp(self):
# Setup model.
self.images = np.ones((2, 16, 16, 3), dtype="float32")
self.images = np.ones((2, 32, 32, 3), dtype="float32")
self.labels = [0, 3]
self.backbone = MiTBackbone(
depths=[2, 2, 2, 2],
image_shape=(16, 16, 3),
image_shape=(32, 32, 3),
hidden_dims=[4, 8],
num_layers=2,
blockwise_num_heads=[1, 2],
@@ -44,7 +44,7 @@ def test_classifier_basics(self):
cls=MiTImageClassifier,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
expected_output_shape=(2, 2),
expected_output_shape=(4, 4),
)

@pytest.mark.large
16 changes: 9 additions & 7 deletions keras_hub/src/models/mix_transformer/mix_transformer_layers.py
@@ -183,20 +183,21 @@ def __init__(self, project_dim, num_heads, sr_ratio):
self.k = keras.layers.Dense(project_dim)
self.v = keras.layers.Dense(project_dim)
self.proj = keras.layers.Dense(project_dim)
self.dropout = keras.layers.Dropout(0.1)
self.proj_drop = keras.layers.Dropout(0.1)

if sr_ratio > 1:
self.sr = keras.layers.Conv2D(
filters=project_dim,
kernel_size=sr_ratio,
strides=sr_ratio,
padding="same",
)
self.norm = keras.layers.LayerNormalization()
self.norm = keras.layers.LayerNormalization(epsilon=1e-5)

def call(self, x):
input_shape = ops.shape(x)
H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1]))
B, C = input_shape[0], input_shape[2]
B, N, C = input_shape[0], input_shape[1], input_shape[2]

q = self.q(x)
q = ops.reshape(
@@ -212,12 +213,11 @@ def call(self, x):

if self.sr_ratio > 1:
x = ops.reshape(
ops.transpose(x, [0, 2, 1]),
x,
(B, H, W, C),
)
x = self.sr(x)
x = ops.reshape(x, [input_shape[0], input_shape[2], -1])
x = ops.transpose(x, [0, 2, 1])
x = ops.reshape(x, [B, -1, C])
x = self.norm(x)

k = self.k(x)
@@ -241,14 +241,16 @@ def call(self, x):

attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale
attn = ops.nn.softmax(attn, axis=-1)
attn = self.dropout(attn)

attn = attn @ v
attn = ops.reshape(
ops.transpose(attn, [0, 2, 1, 3]),
[input_shape[0], input_shape[1], input_shape[2]],
[B, N, C],
)

x = self.proj(attn)
x = self.proj_drop(x)
return x
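
In the `sr_ratio > 1` branch of this hunk, the `(B, N, C)` token tensor is now reshaped straight to `(B, H, W, C)` instead of being transposed first. A minimal sketch of the difference, using plain Python lists rather than `keras.ops`, and assuming a channels-last layout with tokens stored in row-major order. The helper `reshape_flat` and the toy values are illustrative only and are not part of the PR.

```python
def reshape_flat(flat, shape):
    """Reshape a flat list into nested lists in row-major order."""
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [
        reshape_flat(flat[i * step:(i + 1) * step], shape[1:])
        for i in range(shape[0])
    ]


H, W, C = 2, 2, 3
N = H * W

# tokens[n][c] = 10 * n + c, so each token's channel vector is recognizable.
tokens = [[10 * n + c for c in range(C)] for n in range(N)]

# New path: reshape (N, C) directly to (H, W, C).
flat = [value for token in tokens for value in token]
direct = reshape_flat(flat, (H, W, C))

# Old path: transpose (N, C) -> (C, N), then reshape to (H, W, C).
transposed = [[tokens[n][c] for n in range(N)] for c in range(C)]
flat_t = [value for row in transposed for value in row]
scrambled = reshape_flat(flat_t, (H, W, C))

print(direct[0][1])     # channel vector of token 1, kept intact
print(scrambled[0][1])  # values drawn from several tokens
```

With the direct reshape, position `(h, w)` holds the full channel vector of token `h * W + w`. With the transpose applied first, the same position mixes values from different tokens, which is one plausible reading of why the transpose was dropped for a channels-last `Conv2D`.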


@@ -76,7 +76,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b5_ade20k_512/1",
"kaggle_handle": "kaggle://keras/mit/keras/mit_b5_ade20k_640/1",
},
"mit_b0_cityscapes_1024": {
"metadata": {
34 changes: 34 additions & 0 deletions keras_hub/src/models/segformer/__init__.py
@@ -0,0 +1,34 @@
# Licensed under the Apache License, Version 2.0 (the "License");
Collaborator commented: remove copyright banner
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
from keras_hub.src.models.segformer.segformer_backbone_presets import (
presets as backbone_presets,
)
from keras_hub.src.models.segformer.segformer_image_segmenter import (
SegFormerImageSegmenter,
)
from keras_hub.src.models.segformer.segformer_presets import presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(presets, SegFormerImageSegmenter)
register_presets(backbone_presets, SegFormerBackbone)
191 changes: 191 additions & 0 deletions keras_hub/src/models/segformer/segformer_backbone.py
@@ -0,0 +1,191 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Collaborator commented: remove copyright banner
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.mix_transformer.mix_transformer_backbone import (
MiTBackbone,
)


@keras_hub_export("keras_hub.models.SegFormerBackbone")
class SegFormerBackbone(Backbone):
"""A Keras model implementing the SegFormer architecture for semantic segmentation.

This class implements the majority of the SegFormer architecture described in
[SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers]
(https://arxiv.org/abs/2105.15203) and [based on the TensorFlow implementation from DeepVision]
(https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).

SegFormers are meant to be used with the MixTransformer (MiT) encoder family, and
use a very lightweight all-MLP decoder head.

The MiT encoder uses a hierarchical transformer which outputs features at multiple scales,
similar to the hierarchical outputs typically associated with CNNs.

Args:
image_encoder: `keras.Model`. The backbone network for the model that is
used as a feature extractor for the SegFormer encoder.
Should be used with the MiT backbone model
(`keras_hub.models.MiTBackbone`) which was created
specifically for SegFormers.
projection_filters: int, number of filters in the
convolution layer projecting the concatenated features into
a segmentation map.

Example:

Using the class with a custom `backbone`:

```python
import keras_hub

backbone = keras_hub.models.MiTBackbone(
depths=[2, 2, 2, 2],
image_shape=(224, 224, 3),
hidden_dims=[32, 64, 160, 256],
num_layers=4,
blockwise_num_heads=[1, 2, 5, 8],
blockwise_sr_ratios=[8, 4, 2, 1],
max_drop_path_rate=0.1,
patch_sizes=[7, 3, 3, 3],
strides=[4, 2, 2, 2],
)

segformer_backbone = keras_hub.models.SegFormerBackbone(image_encoder=backbone, projection_filters=256)
```

Using the class with a preset `backbone`:

```python
import keras_hub

backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
segformer_backbone = keras_hub.models.SegFormerBackbone(image_encoder=backbone, projection_filters=256)
```

"""

backbone_cls = MiTBackbone
Collaborator commented: not needed

def __init__(
self,
image_encoder,
projection_filters,
**kwargs,
):
# `keras.Model` subclasses `keras.layers.Layer`, so the original two-part
# check reduces to requiring a `keras.Model` instance.
if not isinstance(image_encoder, keras.Model):
raise ValueError(
"Argument `image_encoder` must be a `keras.Model` instance. "
f"Received instead image_encoder={image_encoder} "
f"(of type {type(image_encoder)})."
)

# === Layers ===
inputs = keras.layers.Input(shape=image_encoder.input.shape[1:])

self.feature_extractor = keras.Model(
image_encoder.inputs, image_encoder.pyramid_outputs
)

features = self.feature_extractor(inputs)
# Get height and width of level one output
_, height, width, _ = features["P1"].shape

self.mlp_blocks = []

for feature_dim, feature in zip(image_encoder.hidden_dims, features):
self.mlp_blocks.append(
keras.layers.Dense(
projection_filters, name=f"linear_{feature_dim}"
)
)

self.resizing = keras.layers.Resizing(
height, width, interpolation="bilinear"
)
self.concat = keras.layers.Concatenate(axis=-1)
self.linear_fuse = keras.Sequential(
[
keras.layers.Conv2D(
filters=projection_filters, kernel_size=1, use_bias=False
),
keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9),
keras.layers.Activation("relu"),
]
)

# === Functional Model ===
# Project all multi-level outputs onto
# the same dimensionality and feature map shape
multi_layer_outs = []
for index, (feature_dim, feature) in enumerate(
zip(image_encoder.hidden_dims, features)
):
out = self.mlp_blocks[index](features[feature])
out = self.resizing(out)
multi_layer_outs.append(out)

# Concat now-equal feature maps
concatenated_outs = self.concat(multi_layer_outs[::-1])

# Fuse concatenated features into a segmentation map
seg = self.linear_fuse(concatenated_outs)

super().__init__(
inputs=inputs,
outputs=seg,
**kwargs,
)

# === Config ===
self.projection_filters = projection_filters
self.image_encoder = image_encoder

def get_config(self):
config = super().get_config()
config.update(
{
"projection_filters": self.projection_filters,
"image_encoder": keras.saving.serialize_keras_object(
self.image_encoder
),
}
)
return config

@classmethod
def from_config(cls, config):
if "image_encoder" in config and isinstance(
config["image_encoder"], dict
):
config["image_encoder"] = keras.layers.deserialize(
config["image_encoder"]
)
return super().from_config(config)
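
The `get_config` / `from_config` pair above stores the nested encoder as a plain dictionary and rebuilds it before the constructor runs. A minimal sketch of that round trip in plain Python, without Keras. `ToyEncoder` and `ToyBackbone` are hypothetical stand-ins for `MiTBackbone` and `SegFormerBackbone`, and their `get_config` calls stand in for `keras.saving.serialize_keras_object`.

```python
class ToyEncoder:
    """Hypothetical stand-in for the image encoder."""

    def __init__(self, hidden_dims):
        self.hidden_dims = list(hidden_dims)

    def get_config(self):
        return {"hidden_dims": self.hidden_dims}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class ToyBackbone:
    """Hypothetical stand-in for the backbone that wraps an encoder."""

    def __init__(self, image_encoder, projection_filters):
        self.image_encoder = image_encoder
        self.projection_filters = projection_filters

    def get_config(self):
        # Serialize the nested object to a dict so the config is plain data.
        return {
            "projection_filters": self.projection_filters,
            "image_encoder": self.image_encoder.get_config(),
        }

    @classmethod
    def from_config(cls, config):
        config = dict(config)  # avoid mutating the caller's dict
        # Rebuild the nested encoder only if it arrived as a serialized dict.
        if isinstance(config.get("image_encoder"), dict):
            config["image_encoder"] = ToyEncoder.from_config(
                config["image_encoder"]
            )
        return cls(**config)


original = ToyBackbone(ToyEncoder([32, 64, 160, 256]), projection_filters=256)
restored = ToyBackbone.from_config(original.get_config())
print(type(restored.image_encoder).__name__, restored.projection_filters)
```

The `isinstance(..., dict)` guard mirrors the one in the PR: it lets `from_config` accept either a serialized encoder or an already-constructed one.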