Skip to content

Commit

Permalink
Fix tests, address review
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw committed Oct 2, 2024
1 parent 80146c7 commit b3d95a3
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 24 deletions.
2 changes: 1 addition & 1 deletion keras_hub/src/layers/preprocessing/image_converter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math

import numpy as np
import keras
import numpy as np
from keras import ops

from keras_hub.src.api_export import keras_hub_export
Expand Down
17 changes: 9 additions & 8 deletions keras_hub/src/models/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class ImageClassifier(Task):
To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
labels where `x` is a tensor and `y` is an integer from `[0, num_classes)`.
All `ImageClassifier` tasks include a `from_preset()` constructor which can
be used to load a pre-trained config and weights.
Args:
backbone: A `keras_hub.models.Backbone` instance or a `keras.Model`.
Expand All @@ -23,18 +25,13 @@ class ImageClassifier(Task):
a `keras.Layer` instance, or a callable. If `None` no preprocessing
will be applied to the inputs.
pooling: `"avg"` or `"max"`. The type of pooling to apply on backbone
output. Default to average pooling.
output. Defaults to average pooling.
activation: `None`, str, or callable. The activation function to use on
the `Dense` layer. Set `activation=None` to return the output
logits. Defaults to `"softmax"`.
head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
dtype to use for the classification head's computations and weights.
To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
where `x` is a tensor and `y` is a integer from `[0, num_classes)`.
All `ImageClassifier` tasks include a `from_preset()` constructor which can
be used to load a pre-trained config and weights.
Examples:
Call `predict()` to run inference.
Expand Down Expand Up @@ -109,11 +106,15 @@ def __init__(
self.preprocessor = preprocessor
if pooling == "avg":
self.pooler = keras.layers.GlobalAveragePooling2D(
data_format, dtype=head_dtype
data_format,
dtype=head_dtype,
name="pooler",
)
elif pooling == "max":
self.pooler = keras.layers.GlobalMaxPooling2D(
data_format, dtype=head_dtype
data_format,
dtype=head_dtype,
name="pooler",
)
else:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
from keras_hub.src.models.causal_lm import CausalLM
from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_hub.src.models.preprocessor import Preprocessor
from keras_hub.src.models.image_classifier import ImageClassifier
from keras_hub.src.models.preprocessor import Preprocessor
from keras_hub.src.models.task import Task
from keras_hub.src.models.text_classifier import TextClassifier
from keras_hub.src.tests.test_case import TestCase
Expand Down
14 changes: 1 addition & 13 deletions keras_hub/src/models/vgg/vgg_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,6 @@ class VGGBackbone(Backbone):
blocks per VGG block. For both VGG16 and VGG19 this is [
64, 128, 256, 512, 512].
image_shape: tuple, optional shape tuple, defaults to (224, 224, 3).
pooling: bool, Optional pooling mode for feature extraction
when `include_top` is `False`.
- `None` means that the output of the model will be
the 4D tensor output of the
last convolutional block.
- `avg` means that global average pooling
will be applied to the output of the
last convolutional block, and thus
the output of the model will be a 2D tensor.
- `max` means that global max pooling will
be applied.
Examples:
```python
Expand All @@ -46,7 +35,6 @@ class VGGBackbone(Backbone):
stackwise_num_repeats = [2, 2, 3, 3, 3],
stackwise_num_filters = [64, 128, 256, 512, 512],
image_shape = (224, 224, 3),
pooling = "avg",
)
model(input_data)
```
Expand All @@ -56,7 +44,7 @@ def __init__(
self,
stackwise_num_repeats,
stackwise_num_filters,
image_shape=(224, 224, 3),
image_shape=(None, None, 3),
**kwargs,
):

Expand Down
172 changes: 172 additions & 0 deletions keras_hub/src/models/vgg/vgg_image_classifier.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,180 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_classifier import ImageClassifier
from keras_hub.src.models.task import Task
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone


@keras_hub_export("keras_hub.models.VGGImageClassifier")
class VGGImageClassifier(ImageClassifier):
    """VGG image classification task.

    `VGGImageClassifier` tasks wrap a `keras_hub.models.VGGBackbone` and
    a `keras_hub.models.Preprocessor` to create a model that can be used for
    image classification. `VGGImageClassifier` tasks take an additional
    `num_classes` argument, controlling the number of predicted output classes.

    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
    where `x` is a tensor of images and `y` is an integer from
    `[0, num_classes)`.

    Note that unlike `keras_hub.models.ImageClassifier`, the
    `VGGImageClassifier` allows and defaults to `pooling="flatten"`, where
    backbone outputs are flattened and passed through two intermediate dense
    layers before the final output projection.

    Args:
        backbone: A `keras_hub.models.VGGBackbone` instance or a `keras.Model`.
        num_classes: int. The number of classes to predict.
        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
            a `keras.Layer` instance, or a callable. If `None` no preprocessing
            will be applied to the inputs.
        pooling: `"flatten"`, `"avg"`, or `"max"`. The type of pooling to apply
            on backbone output. The default is flatten to match the original
            VGG implementation, where backbone outputs will be flattened and
            passed through two dense layers with a `"relu"` activation.
        pooling_hidden_dim: the output feature size of the pooling dense layers.
            This only applies when `pooling="flatten"`. Defaults to `4096`.
        activation: `None`, str, or callable. The activation function to use on
            the final `Dense` layer. Defaults to `None`, which returns the
            output logits.
        head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
            dtype to use for the classification head's computations and weights.
            If `None`, the backbone's dtype policy is used.

    Raises:
        ValueError: If `pooling` is not one of `"flatten"`, `"avg"` or `"max"`.

    Examples:

    Call `predict()` to run inference.
    ```python
    # Load preset and run inference
    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
    classifier = keras_hub.models.VGGImageClassifier.from_preset(
        "vgg_16_imagenet"
    )
    classifier.predict(images)
    ```

    Call `fit()` on a single batch.
    ```python
    # Load preset and train
    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
    labels = [0, 3]
    classifier = keras_hub.models.VGGImageClassifier.from_preset(
        "vgg_16_imagenet"
    )
    classifier.fit(x=images, y=labels, batch_size=2)
    ```

    Call `fit()` with custom loss, optimizer and backbone.
    ```python
    classifier = keras_hub.models.VGGImageClassifier.from_preset(
        "vgg_16_imagenet"
    )
    classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(5e-5),
    )
    classifier.backbone.trainable = False
    classifier.fit(x=images, y=labels, batch_size=2)
    ```

    Custom backbone.
    ```python
    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
    labels = [0, 3]
    backbone = keras_hub.models.VGGBackbone(
        stackwise_num_repeats = [2, 2, 3, 3, 3],
        stackwise_num_filters = [64, 128, 256, 512, 512],
        image_shape = (224, 224, 3),
    )
    classifier = keras_hub.models.VGGImageClassifier(
        backbone=backbone,
        num_classes=4,
    )
    classifier.fit(x=images, y=labels, batch_size=2)
    ```
    """

    backbone_cls = VGGBackbone

    def __init__(
        self,
        backbone,
        num_classes,
        preprocessor=None,
        pooling="flatten",
        pooling_hidden_dim=4096,
        activation=None,
        head_dtype=None,
        **kwargs,
    ):
        # Fall back to the backbone's dtype policy so the head matches the
        # backbone unless the caller explicitly overrides it.
        head_dtype = head_dtype or backbone.dtype_policy
        # Not every backbone exposes `data_format`; `None` defers to the
        # pooling layer's own default.
        data_format = getattr(backbone, "data_format", None)

        # === Layers ===
        self.backbone = backbone
        self.preprocessor = preprocessor
        if pooling == "avg":
            self.pooler = keras.layers.GlobalAveragePooling2D(
                data_format,
                dtype=head_dtype,
                name="pooler",
            )
        elif pooling == "max":
            self.pooler = keras.layers.GlobalMaxPooling2D(
                data_format,
                dtype=head_dtype,
                name="pooler",
            )
        elif pooling == "flatten":
            # The hidden dense layers are part of the classification head, so
            # they follow `head_dtype` like the pooling and output layers do.
            self.pooler = keras.Sequential(
                [
                    keras.layers.Flatten(name="flatten"),
                    keras.layers.Dense(
                        pooling_hidden_dim,
                        activation="relu",
                        dtype=head_dtype,
                    ),
                    keras.layers.Dense(
                        pooling_hidden_dim,
                        activation="relu",
                        dtype=head_dtype,
                    ),
                ],
                name="pooler",
            )
        else:
            raise ValueError(
                "Unknown `pooling` type. Pooling should be one of "
                "`'flatten'`, `'avg'`, or `'max'`. "
                f"Received: pooling={pooling}."
            )
        self.output_dense = keras.layers.Dense(
            num_classes,
            activation=activation,
            dtype=head_dtype,
            name="predictions",
        )

        # === Functional Model ===
        inputs = self.backbone.input
        x = self.backbone(inputs)
        x = self.pooler(x)
        outputs = self.output_dense(x)
        # Skip the parent class functional model: `ImageClassifier.__init__`
        # would build its own head, so initialize from `Task` directly.
        Task.__init__(
            self,
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )

        # === Config ===
        self.num_classes = num_classes
        self.activation = activation
        self.pooling = pooling
        self.pooling_hidden_dim = pooling_hidden_dim

    def get_config(self):
        """Return the constructor config, extended with the head settings."""
        # Backbone serialized in `super`
        config = super().get_config()
        config.update(
            {
                "num_classes": self.num_classes,
                "pooling": self.pooling,
                "activation": self.activation,
                "pooling_hidden_dim": self.pooling_hidden_dim,
            }
        )
        return config
2 changes: 1 addition & 1 deletion keras_hub/src/models/vgg/vgg_image_classifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def setUp(self):
"backbone": self.backbone,
"num_classes": 2,
"activation": "softmax",
"pooling": "max",
"pooling": "flatten",
}
self.train_data = (
self.images,
Expand Down

0 comments on commit b3d95a3

Please sign in to comment.