[Semantic Segmentation] - Add SegFormer Architecture, Weight Conversion Script and Presets #1883
Open
DavidLandup0 wants to merge 68 commits into keras-team:master from DavidLandup0:feature/segformer
+1,067 −16
68 commits
716ae64
initial commit - tf-based, kcv
DavidLandup0 71bd40b
porting to keras_hub structure - removing aliases, presets, etc.
DavidLandup0 8894a86
enable instantiation of segformer backbone with custom MiT backbone
DavidLandup0 b66c659
remove num_classes from backbone
DavidLandup0 392ec36
fix input
DavidLandup0 d80d8d0
add imports to __init__
DavidLandup0 538adf7
Merge branch 'master' into feature/segformer
DavidLandup0 1571677
update preset
DavidLandup0 4b82a16
update docstrings
DavidLandup0 9b260e7
add basic tests
DavidLandup0 b93954f
remove redundant imports
DavidLandup0 159dca5
update docstrings
DavidLandup0 3ec02dd
remove unused import
DavidLandup0 7b6286e
running api_gen.py
DavidLandup0 c40fdcd
undo refactor of mit
DavidLandup0 9a13544
update docstrings
DavidLandup0 4dc3fff
add presets for mit
DavidLandup0 191656c
add standin paths
DavidLandup0 9e47564
add presets for segformer backbone
DavidLandup0 98bb69d
register presets in __init__.py
DavidLandup0 21ed167
addressing comments
DavidLandup0 f6720ac
addressing comments
DavidLandup0 b0806f2
addressing comments
DavidLandup0 22df93b
merge master branch into feature branch
DavidLandup0 0549be7
update most tests
DavidLandup0 4f66776
add remaining tests
DavidLandup0 f0b3e56
remove copyright
DavidLandup0 8c36b6e
fix test
DavidLandup0 9e1a9d6
override from_config
DavidLandup0 8d95091
Merge branch 'master' into feature/mit_presets
DavidLandup0 0c92729
fix op in overlapping patching and embedding, start adding conversion…
DavidLandup0 6638cb1
style
DavidLandup0 9a3f82d
add padding to MiT patchingandembedding
DavidLandup0 76a6dd2
update to support other presets
DavidLandup0 7b06c79
update conversin script
DavidLandup0 6e9728f
fix link for b5
DavidLandup0 7119022
add cityscapes weights
DavidLandup0 8ea5f63
update presets
DavidLandup0 0705748
update presets
DavidLandup0 eb1c236
update conversion script to make directories
DavidLandup0 fc42598
use save_preset
DavidLandup0 4274c60
change name of output dir
DavidLandup0 dc72ea7
add preprocessor flow
DavidLandup0 65f1822
api gen and add preprocessor to mits
DavidLandup0 d1bc073
merge master into feature branch
DavidLandup0 000d7d0
conform to new image classifier style
DavidLandup0 fa34f9e
format
DavidLandup0 e3c6dc6
resizing image converter -> ImageConverter
DavidLandup0 3483c10
Merge branch 'feature/mit_presets' into feature/segformer
DavidLandup0 38dadef
merge mit branch into segformer branch
DavidLandup0 983642c
add preprocessor and converter
DavidLandup0 2a8ffcb
address comments
DavidLandup0 34e8701
Merge branch 'feature/mit_presets' into feature/segformer
DavidLandup0 5b5eb93
clarify backbone usage
DavidLandup0 fcdadb3
add conversion script
DavidLandup0 68cef65
numerical equivalence changes
DavidLandup0 205ae4a
fix numerical inaccuracies
DavidLandup0 8b5fa44
update conversion script
DavidLandup0 79e05f8
merge master branch into feature branch
DavidLandup0 de17c03
update conversion script
DavidLandup0 04ba1eb
remove transpose
DavidLandup0 9e04b6e
add preprocessor to segformer class
DavidLandup0 e9e8ed5
fix preset path
DavidLandup0 a7a21f6
update test shape
DavidLandup0 28e1297
update presets
DavidLandup0 fc8fffe
update test shape
DavidLandup0 fa89a09
expand docstrings
DavidLandup0 6e3d3d1
add rescaling and normalization to preprocessor
DavidLandup0
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone | ||
from keras_hub.src.models.segformer.segformer_backbone_presets import ( | ||
presets as backbone_presets, | ||
) | ||
from keras_hub.src.models.segformer.segformer_image_segmenter import ( | ||
SegFormerImageSegmenter, | ||
) | ||
from keras_hub.src.models.segformer.segformer_presets import presets | ||
from keras_hub.src.utils.preset_utils import register_presets | ||
|
||
register_presets(presets, SegFormerImageSegmenter) | ||
register_presets(backbone_presets, SegFormerBackbone) |
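
The two `register_presets` calls above associate each preset dictionary with the class that should load it. As a rough illustration of that idea only, here is a toy registry in plain Python. It is not the `keras_hub` implementation: the `_PRESETS` dictionary, the `lookup` helper, the stand-in classes, and the preset names are all hypothetical.

```python
# Toy stand-in for a preset registry. NOT the keras_hub implementation;
# every name below is hypothetical and exists only for illustration.
_PRESETS = {}


def register_presets(presets, model_cls):
    """Record which class is responsible for each preset name."""
    for name, metadata in presets.items():
        _PRESETS[name] = {"metadata": metadata, "cls": model_cls}


def lookup(name):
    """Return the registered entry, or fail with a readable error."""
    if name not in _PRESETS:
        raise ValueError(
            f"Unknown preset {name!r}. Known presets: {sorted(_PRESETS)}"
        )
    return _PRESETS[name]


class ToySegmenter:  # stand-in for a task model class
    pass


class ToyBackbone:  # stand-in for a backbone class
    pass


# Hypothetical preset dictionaries, mirroring the two calls in the diff.
presets = {"toy_segmenter_preset": {"description": "task preset"}}
backbone_presets = {"toy_backbone_preset": {"description": "backbone preset"}}

register_presets(presets, ToySegmenter)
register_presets(backbone_presets, ToyBackbone)

entry = lookup("toy_backbone_preset")
print(entry["cls"].__name__)
```

Keeping task presets and backbone presets in separate dictionaries, as the diff does, lets each name resolve to exactly one class.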
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
Review comment: remove copyright banner |
||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import keras | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.backbone import Backbone | ||
from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( | ||
MiTBackbone, | ||
) | ||
|
||
|
||
@keras_hub_export("keras_hub.models.SegFormerBackbone") | ||
class SegFormerBackbone(Backbone): | ||
"""A Keras model implementing the SegFormer architecture for semantic segmentation. | ||
|
||
This class implements the majority of the SegFormer architecture described in | ||
[SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers] | ||
(https://arxiv.org/abs/2105.15203) and [based on the TensorFlow implementation from DeepVision] | ||
(https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer). | ||
|
||
SegFormers are meant to be used with the MixTransformer (MiT) encoder family, and | ||
use a very lightweight all-MLP decoder head. | ||
|
||
The MiT encoder uses a hierarchical transformer which outputs features at multiple scales, | ||
similar to the hierarchical outputs typically associated with CNNs. | ||
|
||
Args: | ||
image_encoder: `keras.Model`. The backbone network for the model that is | ||
used as a feature extractor for the SegFormer encoder. | ||
Should be used with the MiT backbone model | ||
(`keras_hub.models.MiTBackbone`) which was created | ||
specifically for SegFormers. | ||
projection_filters: int, number of filters in the | ||
convolution layer projecting the concatenated features into | ||
a segmentation map. Required; the examples below use `256`. | ||
|
||
Example: | ||
|
||
Using the class with a custom `backbone`: | ||
|
||
```python | ||
import keras_hub | ||
|
||
backbone = keras_hub.models.MiTBackbone( | ||
depths=[2, 2, 2, 2], | ||
image_shape=(224, 224, 3), | ||
hidden_dims=[32, 64, 160, 256], | ||
num_layers=4, | ||
blockwise_num_heads=[1, 2, 5, 8], | ||
blockwise_sr_ratios=[8, 4, 2, 1], | ||
max_drop_path_rate=0.1, | ||
patch_sizes=[7, 3, 3, 3], | ||
strides=[4, 2, 2, 2], | ||
) | ||
|
||
segformer_backbone = keras_hub.models.SegFormerBackbone(image_encoder=backbone, projection_filters=256) | ||
``` | ||
|
||
Using the class with a preset `backbone`: | ||
|
||
```python | ||
import keras_hub | ||
|
||
backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512") | ||
segformer_backbone = keras_hub.models.SegFormerBackbone(image_encoder=backbone, projection_filters=256) | ||
``` | ||
|
||
""" | ||
|
||
backbone_cls = MiTBackbone | ||
Review comment: not needed |
||
|
||
def __init__( | ||
self, | ||
image_encoder, | ||
projection_filters, | ||
**kwargs, | ||
): | ||
# Accept any Keras layer or model; reject everything else. | ||
if not isinstance(image_encoder, (keras.layers.Layer, keras.Model)): | ||
raise ValueError( | ||
"Argument `image_encoder` must be a `keras.layers.Layer` instance " | ||
"or `keras.Model`. Received instead " | ||
f"image_encoder={image_encoder} (of type {type(image_encoder)})." | ||
) | ||
|
||
# === Layers === | ||
inputs = keras.layers.Input(shape=image_encoder.input.shape[1:]) | ||
|
||
self.feature_extractor = keras.Model( | ||
image_encoder.inputs, image_encoder.pyramid_outputs | ||
) | ||
|
||
features = self.feature_extractor(inputs) | ||
# Get height and width of level one output | ||
_, height, width, _ = features["P1"].shape | ||
|
||
self.mlp_blocks = [] | ||
|
||
for feature_dim, feature in zip(image_encoder.hidden_dims, features): | ||
self.mlp_blocks.append( | ||
keras.layers.Dense( | ||
projection_filters, name=f"linear_{feature_dim}" | ||
) | ||
) | ||
|
||
self.resizing = keras.layers.Resizing( | ||
height, width, interpolation="bilinear" | ||
) | ||
self.concat = keras.layers.Concatenate(axis=-1) | ||
self.linear_fuse = keras.Sequential( | ||
[ | ||
keras.layers.Conv2D( | ||
filters=projection_filters, kernel_size=1, use_bias=False | ||
), | ||
keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9), | ||
keras.layers.Activation("relu"), | ||
] | ||
) | ||
|
||
# === Functional Model === | ||
# Project all multi-level outputs onto | ||
# the same dimensionality and feature map shape | ||
multi_layer_outs = [] | ||
for index, (feature_dim, feature) in enumerate( | ||
zip(image_encoder.hidden_dims, features) | ||
): | ||
out = self.mlp_blocks[index](features[feature]) | ||
out = self.resizing(out) | ||
multi_layer_outs.append(out) | ||
|
||
# Concat now-equal feature maps | ||
concatenated_outs = self.concat(multi_layer_outs[::-1]) | ||
|
||
# Fuse concatenated features into a segmentation map | ||
seg = self.linear_fuse(concatenated_outs) | ||
|
||
super().__init__( | ||
inputs=inputs, | ||
outputs=seg, | ||
**kwargs, | ||
) | ||
|
||
# === Config === | ||
self.projection_filters = projection_filters | ||
self.image_encoder = image_encoder | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"projection_filters": self.projection_filters, | ||
"image_encoder": keras.saving.serialize_keras_object( | ||
self.image_encoder | ||
), | ||
} | ||
) | ||
return config | ||
|
||
@classmethod | ||
def from_config(cls, config): | ||
if "image_encoder" in config and isinstance( | ||
config["image_encoder"], dict | ||
): | ||
config["image_encoder"] = keras.layers.deserialize( | ||
config["image_encoder"] | ||
) | ||
return super().from_config(config) |
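
To make the decoder head in `SegFormerBackbone.__init__` easier to follow, the sketch below tracks tensor shapes only, in plain Python, using the MiT configuration from the docstring example. The helper names `pyramid_shapes` and `decoder_shapes` are hypothetical and not part of this PR. It also assumes each encoder stage divides the spatial size exactly by its stride, which holds here because 224 is divisible by 32.

```python
# Hypothetical helpers, not part of the PR: pure-Python shape bookkeeping
# for the SegFormer decoder head (no Keras required).


def pyramid_shapes(image_size, strides, hidden_dims):
    """Return (height, width, channels) for each encoder level P1, P2, ...

    Assumption: every stage divides the spatial size exactly by its stride.
    """
    shapes = {}
    size = image_size
    for level, (stride, dim) in enumerate(zip(strides, hidden_dims), start=1):
        size //= stride
        shapes[f"P{level}"] = (size, size, dim)
    return shapes


def decoder_shapes(shapes, projection_filters):
    """Follow the head: project, resize to P1, concatenate, fuse."""
    height, width, _ = shapes["P1"]
    # The Dense projection changes only the channel axis and the Resizing
    # layer changes only the spatial axes, so every level ends up with the
    # same (height, width, projection_filters) shape.
    projected = [(height, width, projection_filters) for _ in shapes]
    # Concatenation along the last axis sums the channel counts.
    concatenated = (height, width, sum(c for _, _, c in projected))
    # The 1x1 convolution in `linear_fuse` maps back to projection_filters.
    fused = (height, width, projection_filters)
    return concatenated, fused


shapes = pyramid_shapes(
    224, strides=[4, 2, 2, 2], hidden_dims=[32, 64, 160, 256]
)
concatenated, fused = decoder_shapes(shapes, projection_filters=256)
print(shapes)
print(concatenated, fused)
```

Under these assumptions the backbone output has one quarter of the input resolution, so a task head would still need to upsample to produce a full-size mask.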