-
Notifications
You must be signed in to change notification settings - Fork 234
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add SAM model * code reformat * update docstring * add image_segmenter file * add init file * move to correct dir * address some review comments * update docstring * code reformat * add layer test for prompt encoder * address review comments * explain reason * address review comments * explain reason * update docstring
- Loading branch information
1 parent
9b1cf95
commit 3fbbeea
Showing
15 changed files
with
2,098 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright 2024 The KerasHub Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import keras | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.task import Task | ||
|
||
|
||
@keras_hub_export("keras_hub.models.ImageSegmenter")
class ImageSegmenter(Task):
    """Base class for all image segmentation tasks.

    `ImageSegmenter` tasks wrap a `keras_hub.models.Task` and
    a `keras_hub.models.Preprocessor` to create a model that can be used for
    image segmentation.

    All `ImageSegmenter` tasks include a `from_preset()` constructor which can
    be used to load a pre-trained config and weights.

    Instances are compiled with the default (`"auto"`) optimizer, loss and
    metrics as soon as they are constructed; call `compile()` again to
    override any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Default compilation.
        self.compile()

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        *,
        metrics="auto",
        **kwargs,
    ):
        """Configures the `ImageSegmenter` task for training.

        The `ImageSegmenter` task extends the default compilation signature of
        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
        `metrics`. To override these defaults, pass any value
        to these arguments during compilation.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which uses
                `keras.optimizers.Adam` with a learning rate of `5e-5`. See
                `keras.Model.compile` and `keras.optimizers` for more info on
                possible `optimizer` values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, where a
                `keras.losses.CategoricalCrossentropy` loss will be
                applied for the segmentation task. The loss is created with
                `from_logits=True` unless the task has an `activation`
                attribute that resolves to `keras.activations.softmax`. See
                `keras.Model.compile` and `keras.losses` for more info on
                possible `loss` values.
            metrics: `"auto"`, or a list of metrics to be evaluated by
                the model during training and testing. Defaults to `"auto"`,
                where a `keras.metrics.CategoricalAccuracy` will be
                applied to track the accuracy of the model during training.
                See `keras.Model.compile` and `keras.metrics` for
                more info on possible `metrics` values.
            **kwargs: See `keras.Model.compile` for a full list of arguments
                supported by the compile method.
        """
        if optimizer == "auto":
            optimizer = keras.optimizers.Adam(5e-5)
        if loss == "auto":
            # Subclasses may or may not define `activation`; a missing
            # attribute is treated the same as `activation=None`.
            activation = getattr(self, "activation", None)
            activation = keras.activations.get(activation)
            # Only a softmax head emits probabilities; anything else is
            # treated as raw logits by the loss.
            from_logits = activation != keras.activations.softmax
            loss = keras.losses.CategoricalCrossentropy(from_logits=from_logits)
        if metrics == "auto":
            metrics = [keras.metrics.CategoricalAccuracy()]
        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            **kwargs,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2024 The KerasHub Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# Copyright 2024 The KerasHub Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import keras | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.backbone import Backbone | ||
|
||
|
||
@keras_hub_export("keras_hub.models.SAMBackbone")
class SAMBackbone(Backbone):
    """A backbone for the Segment Anything Model (SAM).

    The functional model takes a dict with the keys `"images"`, `"points"`,
    `"labels"`, `"boxes"` and `"masks"`. It returns a dict holding
    `"image_embeddings"` (the image encoder output) together with every entry
    of the dict returned by the prompt encoder.

    Args:
        image_encoder: `keras_hub.models.ViTDetBackbone`. A feature extractor
            for the input images.
        prompt_encoder: `keras_hub.layers.SAMPromptEncoder`. A Keras layer to
            compute embeddings for points, box, and mask prompt.
        mask_decoder: `keras_hub.layers.SAMMaskDecoder`. A Keras layer to
            generate segmentation masks given the embeddings generated by the
            backbone and the prompt encoder. It is stored on the backbone and
            serialized with it, but it is not called inside the backbone's
            functional graph.
        dtype: The dtype of the layer weights.

    Example:
    ```python
    import numpy as np
    import keras_hub

    image_size = 128
    batch_size = 2
    input_data = {
        "images": np.ones(
            (batch_size, image_size, image_size, 3),
            dtype="float32",
        ),
        "points": np.ones((batch_size, 1, 2), dtype="float32"),
        "labels": np.ones((batch_size, 1), dtype="float32"),
        "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
        "masks": np.zeros(
            (batch_size, 0, image_size, image_size, 1)
        ),
    }
    image_encoder = keras_hub.models.ViTDetBackbone(
        hidden_size=16,
        num_layers=16,
        intermediate_dim=16 * 4,
        num_heads=16,
        global_attention_layer_indices=[2, 5, 8, 11],
        patch_size=16,
        num_output_channels=8,
        window_size=2,
        image_shape=(image_size, image_size, 3),
    )
    prompt_encoder = keras_hub.layers.SAMPromptEncoder(
        hidden_size=8,
        image_embedding_size=(8, 8),
        input_image_size=(
            image_size,
            image_size,
        ),
        mask_in_channels=16,
    )
    mask_decoder = keras_hub.layers.SAMMaskDecoder(
        num_layers=2,
        hidden_size=8,
        intermediate_dim=32,
        num_heads=8,
        embedding_dim=8,
        num_multimask_outputs=3,
        iou_head_depth=3,
        iou_head_hidden_dim=8,
    )
    backbone = keras_hub.models.SAMBackbone(
        image_encoder=image_encoder,
        prompt_encoder=prompt_encoder,
        mask_decoder=mask_decoder,
    )
    backbone(input_data)
    ```
    """

    def __init__(
        self,
        image_encoder,
        prompt_encoder,
        mask_decoder,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder

        # === Functional model ===
        # Reuse the image encoder's own symbolic input/output so the backbone
        # graph is wired directly onto the encoder.
        image_input = self.image_encoder.input

        inputs = {
            "images": image_input,
            "points": keras.Input(shape=[None, 2], name="points"),
            "labels": keras.Input(shape=[None], name="labels"),
            "boxes": keras.Input(shape=[None, 2, 2], name="boxes"),
            "masks": keras.Input(shape=[None, None, None, 1], name="masks"),
        }
        image_embeddings = self.image_encoder.output
        # The prompt encoder receives every input (images included) as a
        # keyword argument and returns a dict of embeddings.
        prompt_embeddings = self.prompt_encoder(**inputs)
        outputs = {
            "image_embeddings": image_embeddings,
        }
        outputs.update(prompt_embeddings)
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            dtype=dtype,
            **kwargs,
        )

    def get_config(self):
        """Returns the base config extended with the serialized sub-layers."""
        config = super().get_config()
        config.update(
            {
                "image_encoder": keras.layers.serialize(self.image_encoder),
                "prompt_encoder": keras.layers.serialize(self.prompt_encoder),
                "mask_decoder": keras.layers.serialize(self.mask_decoder),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuilds a `SAMBackbone` from the output of `get_config()`.

        Args:
            config: dict. Must contain serialized `"image_encoder"`,
                `"prompt_encoder"` and `"mask_decoder"` entries. The dict
                passed in is left unmodified.

        Returns:
            Whatever the parent class's `from_config` builds from the config
            once the three sub-layers have been deserialized.
        """
        # Work on a shallow copy so that the caller's config dict is not
        # overwritten with live layer objects.
        config = dict(config)
        for key in ("image_encoder", "prompt_encoder", "mask_decoder"):
            config[key] = keras.layers.deserialize(config[key])

        return super().from_config(config)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright 2024 The KerasHub Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
|
||
from keras_hub.src.models.sam.sam_backbone import SAMBackbone | ||
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder | ||
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder | ||
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class SAMBackboneTest(TestCase):
    """Basic backbone checks for `SAMBackbone` using tiny sub-components."""

    def setUp(self):
        self.batch_size = 2
        self.image_size = 16
        batch = self.batch_size
        side = self.image_size

        # Small ViTDet image encoder; the patch size equals the image side.
        self.image_encoder = ViTDetBackbone(
            image_shape=(side, side, 3),
            patch_size=16,
            hidden_size=16,
            num_layers=16,
            num_heads=16,
            intermediate_dim=16 * 4,
            global_attention_layer_indices=[2, 5, 8, 11],
            window_size=2,
            num_output_channels=8,
        )
        self.prompt_encoder = SAMPromptEncoder(
            hidden_size=8,
            image_embedding_size=(8, 8),
            input_image_size=(side, side),
            mask_in_channels=16,
        )
        self.mask_decoder = SAMMaskDecoder(
            hidden_size=8,
            embedding_dim=8,
            num_layers=2,
            num_heads=8,
            intermediate_dim=32,
            num_multimask_outputs=3,
            iou_head_depth=3,
            iou_head_hidden_dim=8,
        )
        self.init_kwargs = dict(
            image_encoder=self.image_encoder,
            prompt_encoder=self.prompt_encoder,
            mask_decoder=self.mask_decoder,
        )

        # One point prompt and one box prompt per image; the mask prompt is
        # empty (zero-length second axis).
        images = np.ones((batch, side, side, 3), dtype="float32")
        points = np.ones((batch, 1, 2), dtype="float32")
        labels = np.ones((batch, 1), dtype="float32")
        boxes = np.ones((batch, 1, 2, 2), dtype="float32")
        masks = np.zeros((batch, 0, side, side, 1))
        self.input_data = {
            "images": images,
            "points": points,
            "labels": labels,
            "boxes": boxes,
            "masks": masks,
        }

    def test_backbone_basics(self):
        expected_shapes = {
            "image_embeddings": (2, 1, 1, 8),
            "prompt_sparse_embeddings": (2, 3, 8),
            "prompt_dense_embeddings": (2, 8, 8, 8),
            "prompt_dense_positional_embeddings": (1, 8, 8, 8),
        }
        self.run_backbone_test(
            cls=SAMBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=expected_shapes,
            run_mixed_precision_check=False,
            run_quantization_check=False,
        )
Oops, something went wrong.