add SAM model (#1847)
* add SAM model

* code reformat

* update docstring

* add image_segmenter file

* add init file

* move to correct dir

* address some review comments

* update docstring

* code reformat

* add layer test for prompt encoder

* address review comments

* explain reason

* address review comments

* explain reason

* update docstring
divyashreepathihalli authored Sep 24, 2024
1 parent 9b1cf95 commit 3fbbeea
Showing 15 changed files with 2,098 additions and 12 deletions.
2 changes: 2 additions & 0 deletions keras_hub/api/layers/__init__.py
@@ -56,6 +56,8 @@
from keras_hub.src.models.resnet.resnet_image_converter import (
ResNetImageConverter,
)
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
from keras_hub.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
)
3 changes: 3 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -175,6 +175,7 @@
from keras_hub.src.models.image_classifier_preprocessor import (
ImageClassifierPreprocessor,
)
from keras_hub.src.models.image_segmenter import ImageSegmenter
from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM
from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
@@ -255,6 +256,8 @@
RobertaTextClassifierPreprocessor as RobertaPreprocessor,
)
from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
from keras_hub.src.models.sam.sam_backbone import SAMBackbone
from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
from keras_hub.src.models.t5.t5_backbone import T5Backbone
86 changes: 86 additions & 0 deletions keras_hub/src/models/image_segmenter.py
@@ -0,0 +1,86 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.task import Task


@keras_hub_export("keras_hub.models.ImageSegmenter")
class ImageSegmenter(Task):
"""Base class for all image segmentation tasks.
`ImageSegmenter` tasks wrap a `keras_hub.models.Backbone` and
a `keras_hub.models.Preprocessor` to create a model that can be used for
image segmentation.
All `ImageSegmenter` tasks include a `from_preset()` constructor which can
be used to load a pre-trained config and weights.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default compilation.
self.compile()

def compile(
self,
optimizer="auto",
loss="auto",
*,
metrics="auto",
**kwargs,
):
"""Configures the `ImageSegmenter` task for training.
The `ImageSegmenter` task extends the default compilation signature of
`keras.Model.compile` with defaults for `optimizer`, `loss`, and
`metrics`. To override these defaults, pass any value
to these arguments during compilation.
Args:
optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
instance. Defaults to `"auto"`, which uses the default optimizer
for the given model and task. See `keras.Model.compile` and
`keras.optimizers` for more info on possible `optimizer` values.
loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
Defaults to `"auto"`, where a
`keras.losses.CategoricalCrossentropy` loss will be
applied for the segmentation task. See
`keras.Model.compile` and `keras.losses` for more info on
possible `loss` values.
metrics: `"auto"`, or a list of metrics to be evaluated by
the model during training and testing. Defaults to `"auto"`,
where a `keras.metrics.CategoricalAccuracy` will be
applied to track the accuracy of the model during training.
See `keras.Model.compile` and `keras.metrics` for
more info on possible `metrics` values.
**kwargs: See `keras.Model.compile` for a full list of arguments
supported by the compile method.
"""
if optimizer == "auto":
optimizer = keras.optimizers.Adam(5e-5)
if loss == "auto":
activation = getattr(self, "activation", None)
activation = keras.activations.get(activation)
from_logits = activation != keras.activations.softmax
loss = keras.losses.CategoricalCrossentropy(from_logits=from_logits)
if metrics == "auto":
metrics = [keras.metrics.CategoricalAccuracy()]
super().compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
**kwargs,
)
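
The `"auto"` defaults above apply only when no explicit values are passed to `compile()`. Below is a minimal sketch of overriding them on a concrete segmenter; the `SAMImageSegmenter` class is added later in this change, and the preset name `"sam_base_sa1b"` is an assumption used purely for illustration.

```python
import keras
import keras_hub

# Minimal sketch: load a SAM segmenter from a preset and override the "auto"
# compile defaults. The preset name below is an assumption for illustration.
segmenter = keras_hub.models.SAMImageSegmenter.from_preset("sam_base_sa1b")

# Explicit values replace the defaults set by the base class in __init__.
segmenter.compile(
    optimizer=keras.optimizers.Adam(1e-4),
    loss=keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=[keras.metrics.CategoricalAccuracy()],
)
```

Passing `loss="auto"` again would re-derive `from_logits` from the model's final activation, as in the base class implementation above.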
13 changes: 13 additions & 0 deletions keras_hub/src/models/sam/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
153 changes: 153 additions & 0 deletions keras_hub/src/models/sam/sam_backbone.py
@@ -0,0 +1,153 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.SAMBackbone")
class SAMBackbone(Backbone):
"""A backbone for the Segment Anything Model (SAM).
Args:
image_encoder: `keras_hub.models.ViTDetBackbone`. A feature extractor for
the input images.
prompt_encoder: `keras_hub.layers.SAMPromptEncoder`. A Keras layer to
compute embeddings for point, box, and mask prompts.
mask_decoder: `keras_hub.layers.SAMMaskDecoder`. A Keras layer to
generate segmentation masks given the embeddings generated by the
backbone and the prompt encoder.
dtype: The dtype of the layer weights.
Example:
```python
import numpy as np

import keras_hub

image_size = 128
batch_size = 2
input_data = {
"images": np.ones(
(batch_size, image_size, image_size, 3),
dtype="float32",
),
"points": np.ones((batch_size, 1, 2), dtype="float32"),
"labels": np.ones((batch_size, 1), dtype="float32"),
"boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
"masks": np.zeros(
(batch_size, 0, image_size, image_size, 1)
),
}
image_encoder = keras_hub.models.ViTDetBackbone(
hidden_size=16,
num_layers=16,
intermediate_dim=16 * 4,
num_heads=16,
global_attention_layer_indices=[2, 5, 8, 11],
patch_size=16,
num_output_channels=8,
window_size=2,
image_shape=(image_size, image_size, 3),
)
prompt_encoder = keras_hub.layers.SAMPromptEncoder(
hidden_size=8,
image_embedding_size=(8, 8),
input_image_size=(
image_size,
image_size,
),
mask_in_channels=16,
)
mask_decoder = keras_hub.layers.SAMMaskDecoder(
num_layers=2,
hidden_size=8,
intermediate_dim=32,
num_heads=8,
embedding_dim=8,
num_multimask_outputs=3,
iou_head_depth=3,
iou_head_hidden_dim=8,
)
backbone = keras_hub.models.SAMBackbone(
image_encoder=image_encoder,
prompt_encoder=prompt_encoder,
mask_decoder=mask_decoder,
)
backbone(input_data)
```
"""

def __init__(
self,
image_encoder,
prompt_encoder,
mask_decoder,
dtype=None,
**kwargs,
):
# === Layers ===
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
# === Functional Model ===
image_input = self.image_encoder.input

inputs = {
"images": image_input,
"points": keras.Input(shape=[None, 2], name="points"),
"labels": keras.Input(shape=[None], name="labels"),
"boxes": keras.Input(shape=[None, 2, 2], name="boxes"),
"masks": keras.Input(shape=[None, None, None, 1], name="masks"),
}
image_embeddings = self.image_encoder.output
prompt_embeddings = self.prompt_encoder(**inputs)
outputs = {
"image_embeddings": image_embeddings,
}
outputs.update(prompt_embeddings)
super().__init__(
inputs=inputs,
outputs=outputs,
dtype=dtype,
**kwargs,
)

def get_config(self):
config = super().get_config()
config.update(
{
"image_encoder": keras.layers.serialize(self.image_encoder),
"prompt_encoder": keras.layers.serialize(self.prompt_encoder),
"mask_decoder": keras.layers.serialize(self.mask_decoder),
}
)
return config

@classmethod
def from_config(cls, config):
config.update(
{
"image_encoder": keras.layers.deserialize(
config["image_encoder"]
),
"prompt_encoder": keras.layers.deserialize(
config["prompt_encoder"]
),
"mask_decoder": keras.layers.deserialize(
config["mask_decoder"]
),
}
)

return super().from_config(config)
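
Because `get_config()` serializes the three sub-models with `keras.layers.serialize` and `from_config()` deserializes them before delegating to the base class, the backbone round-trips through a plain config dict. A minimal sketch, assuming `backbone` and `input_data` are the objects built in the docstring example above:

```python
import keras_hub

# Minimal sketch of the config round trip; `backbone` and `input_data` are
# assumed to come from the docstring example above.
config = backbone.get_config()
restored = keras_hub.models.SAMBackbone.from_config(config)

# The image encoder, prompt encoder, and mask decoder are rebuilt from their
# serialized sub-configs, so the restored backbone accepts the same inputs.
outputs = restored(input_data)
print(sorted(outputs.keys()))
```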
90 changes: 90 additions & 0 deletions keras_hub/src/models/sam/sam_backbone_test.py
@@ -0,0 +1,90 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from keras_hub.src.models.sam.sam_backbone import SAMBackbone
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
from keras_hub.src.tests.test_case import TestCase


class SAMBackboneTest(TestCase):
def setUp(self):
self.batch_size = 2
self.image_size = 16
self.image_encoder = ViTDetBackbone(
hidden_size=16,
num_layers=16,
intermediate_dim=16 * 4,
num_heads=16,
global_attention_layer_indices=[2, 5, 8, 11],
patch_size=16,
num_output_channels=8,
window_size=2,
image_shape=(self.image_size, self.image_size, 3),
)
self.prompt_encoder = SAMPromptEncoder(
hidden_size=8,
image_embedding_size=(8, 8),
input_image_size=(
self.image_size,
self.image_size,
),
mask_in_channels=16,
)
self.mask_decoder = SAMMaskDecoder(
num_layers=2,
hidden_size=8,
intermediate_dim=32,
num_heads=8,
embedding_dim=8,
num_multimask_outputs=3,
iou_head_depth=3,
iou_head_hidden_dim=8,
)
self.init_kwargs = {
"image_encoder": self.image_encoder,
"prompt_encoder": self.prompt_encoder,
"mask_decoder": self.mask_decoder,
}

self.input_data = {
"images": np.ones(
(self.batch_size, self.image_size, self.image_size, 3),
dtype="float32",
),
"points": np.ones((self.batch_size, 1, 2), dtype="float32"),
"labels": np.ones((self.batch_size, 1), dtype="float32"),
"boxes": np.ones((self.batch_size, 1, 2, 2), dtype="float32"),
"masks": np.zeros(
(self.batch_size, 0, self.image_size, self.image_size, 1)
),
}

def test_backbone_basics(self):
self.run_backbone_test(
cls=SAMBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape={
"image_embeddings": (2, 1, 1, 8),
"prompt_sparse_embeddings": (2, 3, 8),
"prompt_dense_embeddings": (2, 8, 8, 8),
"prompt_dense_positional_embeddings": (1, 8, 8, 8),
},
run_mixed_precision_check=False,
run_quantization_check=False,
)