From 20cdf41a3773955ed7d3584cab4499f433f6130a Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Mon, 4 Mar 2024 12:36:43 -0800 Subject: [PATCH 01/13] Implemented Coca architecture --- keras_cv/layers/attention_pooling.py | 27 ++++ keras_cv/layers/transformer_encoder.py | 5 +- .../models/feature_extractor/CoCa/__init__.py | 0 .../feature_extractor/CoCa/coca_model.py | 133 ++++++++++++++++++ 4 files changed, 162 insertions(+), 3 deletions(-) create mode 100644 keras_cv/layers/attention_pooling.py create mode 100644 keras_cv/models/feature_extractor/CoCa/__init__.py create mode 100644 keras_cv/models/feature_extractor/CoCa/coca_model.py diff --git a/keras_cv/layers/attention_pooling.py b/keras_cv/layers/attention_pooling.py new file mode 100644 index 0000000000..a3bf2a2556 --- /dev/null +++ b/keras_cv/layers/attention_pooling.py @@ -0,0 +1,27 @@ +from keras import layers + + +class AttentionPooling(layers.Layer): + + # TODO: Add args + def __init__(self, + proj_dim, + num_heads, + **kwargs): + super().__init__(self, **kwargs) + + self.proj_dim = proj_dim + self.num_heads = num_heads + + + def build(self, input_shape): + self.multi_head_attn = layers.MultiHeadAttention( + self.num_heads, + self.proj_dim + ) + + self.layer_norm = layers.LayerNormalization() + + def call(self, query, value, *args, **kwargs): + x = self.multi_head_attn(query, value) + return self.layer_norm(x) \ No newline at end of file diff --git a/keras_cv/layers/transformer_encoder.py b/keras_cv/layers/transformer_encoder.py index 152fe354f8..7d6674b9d6 100644 --- a/keras_cv/layers/transformer_encoder.py +++ b/keras_cv/layers/transformer_encoder.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from tensorflow import keras -from tensorflow.keras import layers +import keras +from keras import layers from keras_cv.api_export import keras_cv_export diff --git a/keras_cv/models/feature_extractor/CoCa/__init__.py b/keras_cv/models/feature_extractor/CoCa/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_cv/models/feature_extractor/CoCa/coca_model.py b/keras_cv/models/feature_extractor/CoCa/coca_model.py new file mode 100644 index 0000000000..796b4bd395 --- /dev/null +++ b/keras_cv/models/feature_extractor/CoCa/coca_model.py @@ -0,0 +1,133 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from keras import Sequential +from keras_cv.api_export import keras_cv_export +from keras_nlp.layers import RotaryEmbedding, TransformerDecoder +from keras_cv.layers import TransformerEncoder as CVTransformerEncoder +from keras_cv.models.task import Task +from keras_cv.layers.attention_pooling import AttentionPooling +from keras_cv.layers.vit_layers import PatchingAndEmbedding + + +@keras_cv_export(["keras_cv.models.CoCa"]) +class CoCa(Task): + def __init__(self, + img_query_dim, + text_proj_dim, + img_patch_size=18, + encoder_depth=40, + encoder_heads=16, + encoder_intermediate_dim=6144, + encoder_width=1408, + decoder_intermediate_dim=5632, + unimodal_decoder_heads=18, + multimodal_decoder_heads=18, + con_queries=1, + cap_queries=256, + con_heads=16, + cap_heads=16, + cap_loss_weight=0.5, + con_loss_weight=0.5, + **kwargs): + super().__init__(**kwargs) + + self.img_patch_size = img_patch_size + self.img_query_dim = img_query_dim + + self.encoder_depth = encoder_depth + self.encoder_heads = encoder_heads + self.encoder_width = encoder_width + self.encoder_intermediate_dim = encoder_intermediate_dim + + self.text_proj_dim = text_proj_dim + self.decoder_intermediate_dim = decoder_intermediate_dim + self.unimodal_decoder_heads = unimodal_decoder_heads + self.multimodal_decoder_heads = multimodal_decoder_heads + + self.con_queries = con_queries + self.con_heads = con_heads + self.con_loss_weight = con_loss_weight + + self.cap_queries = cap_queries + self.cap_heads = cap_heads + self.cap_loss_weight = cap_loss_weight + + def build(self, input_shape): + super().build(input_shape) + + self.image_patching = PatchingAndEmbedding(self.encoder_width, self.img_patch_size) + self.image_encoder = Sequential([ + CVTransformerEncoder(self.img_query_dim, self.encoder_heads, self.encoder_intermediate_dim) + for _ in range(self.encoder_depth) + ]) + + self.cls_token = self.add_weight(shape=[1, 1, self.text_proj_dim], name="cls_token", trainable=True) + + self.text_embedding = RotaryEmbedding() + self.unimodal_text_decoder = Sequential([ + TransformerDecoder(self.decoder_intermediate_dim, self.unimodal_decoder_heads) + for _ in range(self.con_queries) + ]) + self.multimodal_text_decoder = TransformerDecoder( + self.decoder_intermediate_dim, + self.multimodal_decoder_heads + ) + + self.con_query = self.add_weight(shape=[1, 1, self.con_queries], trainable=True) + self.cap_query = self.add_weight(shape=[1, 1, self.cap_queries], trainable=True) + + self.con_attn_pooling = AttentionPooling(self.img_query_dim, self.con_heads) + self.cap_attn_pooling = AttentionPooling(self.img_query_dim, self.cap_heads) + + def call(self, images, texts): + """ + Forward pass of the Coca Model + + :param images: [batch_size, height, width, channels] representing images + :param texts: Tensor, typically represented as [batch_size, sequence_length, feature_length] or + [batch_size, sequence_length, num_heads, feature_length]. The sequence_length and/or feature_length + are required. + :return: output of the captioning Transformer Decoder with captioning cross-attention + """ + img_encoding = self.image_patching(images) + img_encoding = self.image_encoder(img_encoding) # [batch, patches_len+1, img_query_dim] + + # This is only needed for loss calculations + # con_feature = self.con_attn_pooling(self.con_query, img_encoding) + cap_feature = self.cap_attn_pooling(self.cap_query, img_encoding) + + text_tokens = np.concatenate(texts, self.cls_token) + mask = np.concatenate((np.ones_like(texts), np.zeros_like(self.cls_token))) + + embed_text = self.text_embedding(text_tokens) + unimodal_out = self.unimodal_text_decoder(embed_text, attention_mask=mask) + multimodal_out = self.multimodal_text_decoder(unimodal_out[:, :-1, :], + encoder_sequence=cap_feature, + decoder_attention_mask=mask) + + return multimodal_out + + def get_config(self): + config = super().get_config() + config.update( + { + "img_query_dim": self.img_query_dim, + "img_patch_size": self.img_patch_size, + "text_proj_dim": self.text_proj_dim, + "cap_loss_weight": self.cap_loss_weight, + "con_loss_weight": self.con_loss_weight + } + ) + return config From b8c0ba45705ac2a20358e4f79c7ded69245a90c5 Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Mon, 4 Mar 2024 12:48:52 -0800 Subject: [PATCH 02/13] Minor clean-up --- keras_cv/layers/attention_pooling.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/keras_cv/layers/attention_pooling.py b/keras_cv/layers/attention_pooling.py index a3bf2a2556..6d162c79fe 100644 --- a/keras_cv/layers/attention_pooling.py +++ b/keras_cv/layers/attention_pooling.py @@ -2,8 +2,12 @@ class AttentionPooling(layers.Layer): + """Implements the Pooled Attention Layer used in "CoCa": Contrastive Captioners are Image-Text Foundation Models" + (https://arxiv.org/pdf/2205.01917.pdf), consisting of a Multiheaded Attention followed by Layer Normalization. - # TODO: Add args + :param proj_dim: The dimensions of the attention heads + :param num_heads: The number of attention heads in the multi-headed attention layer + """ def __init__(self, proj_dim, num_heads, @@ -13,7 +17,6 @@ def __init__(self, self.proj_dim = proj_dim self.num_heads = num_heads - def build(self, input_shape): self.multi_head_attn = layers.MultiHeadAttention( self.num_heads, @@ -22,6 +25,6 @@ def build(self, input_shape): self.layer_norm = layers.LayerNormalization() - def call(self, query, value, *args, **kwargs): + def call(self, query, value): x = self.multi_head_attn(query, value) - return self.layer_norm(x) \ No newline at end of file + return self.layer_norm(x) From bbe17c47caf4b518c543c62a17c4496c086914c7 Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Mon, 4 Mar 2024 12:59:12 -0800 Subject: [PATCH 03/13] Fixed depth of decoders --- .../feature_extractor/CoCa/coca_model.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/keras_cv/models/feature_extractor/CoCa/coca_model.py b/keras_cv/models/feature_extractor/CoCa/coca_model.py index 796b4bd395..fb813e5d38 100644 --- a/keras_cv/models/feature_extractor/CoCa/coca_model.py +++ b/keras_cv/models/feature_extractor/CoCa/coca_model.py @@ -31,9 +31,11 @@ def __init__(self, encoder_heads=16, encoder_intermediate_dim=6144, encoder_width=1408, + unimodal_decoder_depth=18, + multimodal_decoder_depth=18, decoder_intermediate_dim=5632, - unimodal_decoder_heads=18, - multimodal_decoder_heads=18, + unimodal_decoder_heads=16, + multimodal_decoder_heads=16, con_queries=1, cap_queries=256, con_heads=16, @@ -52,6 +54,8 @@ def __init__(self, self.encoder_intermediate_dim = encoder_intermediate_dim self.text_proj_dim = text_proj_dim + self.unimodal_decoder_depth = unimodal_decoder_depth + self.multimodal_decoder_depth = multimodal_decoder_depth self.decoder_intermediate_dim = decoder_intermediate_dim self.unimodal_decoder_heads = unimodal_decoder_heads self.multimodal_decoder_heads = multimodal_decoder_heads @@ -78,12 +82,12 @@ def build(self, input_shape): self.text_embedding = RotaryEmbedding() self.unimodal_text_decoder = Sequential([ TransformerDecoder(self.decoder_intermediate_dim, self.unimodal_decoder_heads) - for _ in range(self.con_queries) + for _ in range(self.unimodal_decoder_depth) + ]) + self.multimodal_text_decoder = Sequential([ + TransformerDecoder(self.decoder_intermediate_dim, self.multimodal_decoder_heads) + for _ in range(self.multimodal_decoder_depth) ]) - self.multimodal_text_decoder = TransformerDecoder( - self.decoder_intermediate_dim, - self.multimodal_decoder_heads - ) self.con_query = self.add_weight(shape=[1, 1, self.con_queries], trainable=True) self.cap_query = self.add_weight(shape=[1, 1, self.cap_queries], trainable=True) From 202526ff3cd0141870b21e0e10a40386399e59ea Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Mon, 4 Mar 2024 15:43:37 -0800 Subject: [PATCH 04/13] Updated config to match args --- .../models/feature_extractor/CoCa/coca_model.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/keras_cv/models/feature_extractor/CoCa/coca_model.py b/keras_cv/models/feature_extractor/CoCa/coca_model.py index fb813e5d38..ffea88fb60 100644 --- a/keras_cv/models/feature_extractor/CoCa/coca_model.py +++ b/keras_cv/models/feature_extractor/CoCa/coca_model.py @@ -127,11 +127,24 @@ def get_config(self): config = super().get_config() config.update( { - "img_query_dim": self.img_query_dim, "img_patch_size": self.img_patch_size, + "img_query_dim": self.img_query_dim, + "encoder_depth": self.encoder_depth, + "encoder_heads": self.encoder_heads, + "encoder_width": self.encoder_width, + "encoder_intermediate_dim": self.encoder_intermediate_dim, "text_proj_dim": self.text_proj_dim, + "unimodal_decoder_depth": self.unimodal_decoder_depth, + "multimodal_decoder_depth": self.multimodal_decoder_depth, + "decoder_intermediate_dim": self.decoder_intermediate_dim, + "unimodal_decoder_heads": self.unimodal_decoder_heads, + "multimodal_decoder_heads": self.multimodal_decoder_heads, + "con_queries": self.con_queries, + "con_heads": self.con_heads, + "con_loss_weight": self.con_loss_weight, + "cap_queries": self.cap_queries, + "cap_heads": self.cap_heads, "cap_loss_weight": self.cap_loss_weight, - "con_loss_weight": self.con_loss_weight } ) return config From 367dd394384533e417b709108298fb3654e4b547 Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Tue, 5 Mar 2024 14:08:18 -0800 Subject: [PATCH 05/13] Moved layer definitions to build and added build calls for each layer --- keras_cv/layers/attention_pooling.py | 12 ++- .../feature_extractor/CoCa/coca_model.py | 77 +++++++++++++++---- 2 files changed, 72 insertions(+), 17 deletions(-) diff --git a/keras_cv/layers/attention_pooling.py b/keras_cv/layers/attention_pooling.py index 6d162c79fe..d1be136f9e 100644 --- a/keras_cv/layers/attention_pooling.py +++ b/keras_cv/layers/attention_pooling.py @@ -5,8 +5,9 @@ class AttentionPooling(layers.Layer): """Implements the Pooled Attention Layer used in "CoCa": Contrastive Captioners are Image-Text Foundation Models" (https://arxiv.org/pdf/2205.01917.pdf), consisting of a Multiheaded Attention followed by Layer Normalization. - :param proj_dim: The dimensions of the attention heads - :param num_heads: The number of attention heads in the multi-headed attention layer + Args: + proj_dim: The dimensions of the attention heads + num_heads: The number of attention heads in the multi-headed attention layer """ def __init__(self, proj_dim, @@ -17,7 +18,6 @@ def __init__(self, self.proj_dim = proj_dim self.num_heads = num_heads - def build(self, input_shape): self.multi_head_attn = layers.MultiHeadAttention( self.num_heads, self.proj_dim @@ -25,6 +25,12 @@ def build(self, input_shape): self.layer_norm = layers.LayerNormalization() + def build(self, input_shape): + super().build(input_shape) + + self.multi_head_attn.build(input_shape) + self.layer_norm.build(input_shape) + def call(self, query, value): x = self.multi_head_attn(query, value) return self.layer_norm(x) diff --git a/keras_cv/models/feature_extractor/CoCa/coca_model.py b/keras_cv/models/feature_extractor/CoCa/coca_model.py index ffea88fb60..4313277d0e 100644 --- a/keras_cv/models/feature_extractor/CoCa/coca_model.py +++ b/keras_cv/models/feature_extractor/CoCa/coca_model.py @@ -23,6 +23,9 @@ @keras_cv_export(["keras_cv.models.CoCa"]) class CoCa(Task): + """ Contrastive Captioner foundational model implementation. + + CoCa Paper: https://arxiv.org/pdf/2205.01917.pdf""" def __init__(self, img_query_dim, text_proj_dim, @@ -68,17 +71,13 @@ def __init__(self, self.cap_heads = cap_heads self.cap_loss_weight = cap_loss_weight - def build(self, input_shape): - super().build(input_shape) - + # Layer Definitions self.image_patching = PatchingAndEmbedding(self.encoder_width, self.img_patch_size) self.image_encoder = Sequential([ CVTransformerEncoder(self.img_query_dim, self.encoder_heads, self.encoder_intermediate_dim) for _ in range(self.encoder_depth) ]) - self.cls_token = self.add_weight(shape=[1, 1, self.text_proj_dim], name="cls_token", trainable=True) - self.text_embedding = RotaryEmbedding() self.unimodal_text_decoder = Sequential([ TransformerDecoder(self.decoder_intermediate_dim, self.unimodal_decoder_heads) @@ -89,21 +88,71 @@ def build(self, input_shape): for _ in range(self.multimodal_decoder_depth) ]) - self.con_query = self.add_weight(shape=[1, 1, self.con_queries], trainable=True) - self.cap_query = self.add_weight(shape=[1, 1, self.cap_queries], trainable=True) - self.con_attn_pooling = AttentionPooling(self.img_query_dim, self.con_heads) self.cap_attn_pooling = AttentionPooling(self.img_query_dim, self.cap_heads) + # These are learnable weights defined in build as per Keras recommendations + self.cls_token = None + self.con_query = None + self.cap_query = None + + def build(self, input_shape): + super().build(input_shape) + + # Validate Input Shape + if len(input_shape) < 2: + raise ValueError("Build arguments to CoCa expected to contain shapes of both text and image data; " + f"got {len(input_shape)} shapes.") + + images_shape = input_shape[0] + text_shape = input_shape[1] + + if len(images_shape) != 4: + raise ValueError("Image shape expected to be of shape [batch_size, height, width, channels]. Instead got " + f"shape: {images_shape}") + elif len(text_shape) != 2: + raise ValueError("Text shape expected to be of shape [batch_size, context_length]. Instead got shape" + f": {text_shape}") + + text_dim = text_shape[1] + batch_size = images_shape[0] + if batch_size != text_shape[0]: + raise ValueError(f"Differing batch sizes between images and texts input. {batch_size} vs {text_shape[0]}") + + # Build Layers + self.image_patching.build(images_shape) + self.image_encoder.build((batch_size, self.image_patching.num_patches, self.encoder_width)) + + text_shape_with_cls_token = [s for s in text_shape] + text_shape_with_cls_token[-1] += 1 + self.text_embedding.build(text_shape_with_cls_token) + + self.unimodal_text_decoder.build(text_shape_with_cls_token) + + self.con_attn_pooling.build((batch_size, text_dim, self.con_queries)) + self.cap_attn_pooling.build((batch_size, text_dim, self.cap_queries)) + + self.multimodal_text_decoder.build((batch_size, self.image_patching.num_patches, self.encoder_width), + text_shape_with_cls_token) + + # Learnable Weights + self.cls_token = self.add_weight(shape=(batch_size, 1, text_dim), name="cls_token", trainable=True) + + self.con_query = self.add_weight(shape=(batch_size, text_dim, self.con_queries), trainable=True) + self.cap_query = self.add_weight(shape=(batch_size, text_dim, self.cap_queries), trainable=True) + def call(self, images, texts): """ - Forward pass of the Coca Model + Forward pass of the Coca Model from raw image and text data + + Args: + images: [batch_size, height, width, channels] representing images + texts: Tensor, typically represented as [batch_size, sequence_length, feature_length] or + [batch_size, sequence_length, num_heads, feature_length]. The sequence_length and/or feature_length + are required. - :param images: [batch_size, height, width, channels] representing images - :param texts: Tensor, typically represented as [batch_size, sequence_length, feature_length] or - [batch_size, sequence_length, num_heads, feature_length]. The sequence_length and/or feature_length - are required. - :return: output of the captioning Transformer Decoder with captioning cross-attention + Returns: + Output: Output of the captioning Transformer Decoder with captioning cross-attention """ img_encoding = self.image_patching(images) img_encoding = self.image_encoder(img_encoding) # [batch, patches_len+1, img_query_dim] From 80ea7d32f21bf51a9a6a6ee7057f3ef15b3b51aa Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Tue, 5 Mar 2024 14:29:39 -0800 Subject: [PATCH 06/13] Unabbreviated 'contrastive' and 'captioning' --- .../feature_extractor/CoCa/coca_model.py | 60 ++++++++++--------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/keras_cv/models/feature_extractor/CoCa/coca_model.py b/keras_cv/models/feature_extractor/CoCa/coca_model.py index 4313277d0e..3fe5fe071c 100644 --- a/keras_cv/models/feature_extractor/CoCa/coca_model.py +++ b/keras_cv/models/feature_extractor/CoCa/coca_model.py @@ -39,12 +39,12 @@ def __init__(self, decoder_intermediate_dim=5632, unimodal_decoder_heads=16, multimodal_decoder_heads=16, - con_queries=1, - cap_queries=256, - con_heads=16, - cap_heads=16, - cap_loss_weight=0.5, - con_loss_weight=0.5, + contrastive_query_length=1, + captioning_query_length=256, + contrastive_attn_heads=16, + captioning_attn_heads=16, + captioning_loss_weight=0.5, + contrastive_loss_weight=0.5, **kwargs): super().__init__(**kwargs) @@ -63,13 +63,13 @@ def __init__(self, self.unimodal_decoder_heads = unimodal_decoder_heads self.multimodal_decoder_heads = multimodal_decoder_heads - self.con_queries = con_queries - self.con_heads = con_heads - self.con_loss_weight = con_loss_weight + self.contrastive_query_length = contrastive_query_length + self.contrastive_attn_heads = contrastive_attn_heads + self.contrastive_loss_weight = contrastive_loss_weight - self.cap_queries = cap_queries - self.cap_heads = cap_heads - self.cap_loss_weight = cap_loss_weight + self.captioning_query_length = captioning_query_length + self.captioning_attn_heads = captioning_attn_heads + self.captioning_loss_weight = captioning_loss_weight # Layer Definitions self.image_patching = PatchingAndEmbedding(self.encoder_width, self.img_patch_size) @@ -88,13 +88,13 @@ def __init__(self, for _ in range(self.multimodal_decoder_depth) ]) - self.con_attn_pooling = AttentionPooling(self.img_query_dim, self.con_heads) - self.cap_attn_pooling = AttentionPooling(self.img_query_dim, self.cap_heads) + self.contrastive_attn_pooling = AttentionPooling(self.img_query_dim, self.contrastive_attn_heads) + self.captioning_attn_pooling = AttentionPooling(self.img_query_dim, self.captioning_attn_heads) # These are learnable weights defined in build as per Keras recommendations self.cls_token = None - self.con_query = None - self.cap_query = None + self.contrastive_query = None + self.captioning_query = None def build(self, input_shape): super().build(input_shape) @@ -129,8 +129,8 @@ def build(self, input_shape): self.unimodal_text_decoder.build(text_shape_with_cls_token) - self.con_attn_pooling.build((batch_size, text_dim, self.con_queries)) - self.cap_attn_pooling.build((batch_size, text_dim, self.cap_queries)) + self.contrastive_attn_pooling.build((batch_size, text_dim, self.contrastive_query_length)) + self.captioning_attn_pooling.build((batch_size, text_dim, self.captioning_query_length)) self.multimodal_text_decoder.build((batch_size, self.image_patching.num_patches, self.encoder_width), text_shape_with_cls_token) @@ -138,8 +138,10 @@ def build(self, input_shape): # Learnable Weights self.cls_token = self.add_weight(shape=(batch_size, 1, text_dim), name="cls_token", trainable=True) - self.con_query = self.add_weight(shape=(batch_size, text_dim, self.con_queries), trainable=True) - self.cap_query = self.add_weight(shape=(batch_size, text_dim, self.cap_queries), trainable=True) + self.contrastive_query = self.add_weight(shape=(batch_size, text_dim, self.contrastive_query_length), + trainable=True) + self.captioning_query = self.add_weight(shape=(batch_size, text_dim, self.captioning_query_length), + trainable=True) def call(self, images, texts): """ @@ -158,8 +160,8 @@ def call(self, images, texts): img_encoding = self.image_encoder(img_encoding) # [batch, patches_len+1, img_query_dim] # This is only needed for loss calculations - # con_feature = self.con_attn_pooling(self.con_query, img_encoding) - cap_feature = self.cap_attn_pooling(self.cap_query, img_encoding) + # contrastive_feature = self.con_attn_pooling(self.contrastive_query, img_encoding) + captioning_feature = self.captioning_attn_pooling(self.captioning_query, img_encoding) text_tokens = np.concatenate(texts, self.cls_token) mask = np.concatenate((np.ones_like(texts), np.zeros_like(self.cls_token))) @@ -167,7 +169,7 @@ def call(self, images, texts): embed_text = self.text_embedding(text_tokens) unimodal_out = self.unimodal_text_decoder(embed_text, attention_mask=mask) multimodal_out = self.multimodal_text_decoder(unimodal_out[:, :-1, :], - encoder_sequence=cap_feature, + encoder_sequence=captioning_feature, decoder_attention_mask=mask) return multimodal_out @@ -188,12 +190,12 @@ def get_config(self): "decoder_intermediate_dim": self.decoder_intermediate_dim, "unimodal_decoder_heads": self.unimodal_decoder_heads, "multimodal_decoder_heads": self.multimodal_decoder_heads, - "con_queries": self.con_queries, - "con_heads": self.con_heads, - "con_loss_weight": self.con_loss_weight, - "cap_queries": self.cap_queries, - "cap_heads": self.cap_heads, - "cap_loss_weight": self.cap_loss_weight, + "contrastive_query_length": self.contrastive_query_length, + "contrastive_attn_heads": self.contrastive_attn_heads, + "contrastive_loss_weight": self.contrastive_loss_weight, + "captioning_query_length": self.captioning_query_length, + "captioning_attn_heads": self.captioning_attn_heads, + "captioning_loss_weight": self.captioning_loss_weight, } ) return config From 3feacb6c7e0f2f8618ad7f6ed2aa65a840c4afec Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Tue, 5 Mar 2024 15:35:58 -0800 Subject: [PATCH 07/13] Improved documentation and added output sizing to call(), also built each layer with expected sizing --- .../feature_extractor/CoCa/coca_model.py | 87 ++++++++++++++----- 1 file changed, 65 insertions(+), 22 deletions(-) diff --git a/keras_cv/models/feature_extractor/CoCa/coca_model.py b/keras_cv/models/feature_extractor/CoCa/coca_model.py index 3fe5fe071c..964b9bd28a 100644 --- a/keras_cv/models/feature_extractor/CoCa/coca_model.py +++ b/keras_cv/models/feature_extractor/CoCa/coca_model.py @@ -23,12 +23,7 @@ @keras_cv_export(["keras_cv.models.CoCa"]) class CoCa(Task): - """ Contrastive Captioner foundational model implementation. - - CoCa Paper: https://arxiv.org/pdf/2205.01917.pdf""" def __init__(self, - img_query_dim, - text_proj_dim, img_patch_size=18, encoder_depth=40, encoder_heads=16, @@ -43,20 +38,18 @@ def __init__(self, captioning_query_length=256, contrastive_attn_heads=16, captioning_attn_heads=16, - captioning_loss_weight=0.5, contrastive_loss_weight=0.5, + captioning_loss_weight=0.5, **kwargs): super().__init__(**kwargs) self.img_patch_size = img_patch_size - self.img_query_dim = img_query_dim self.encoder_depth = encoder_depth self.encoder_heads = encoder_heads self.encoder_width = encoder_width self.encoder_intermediate_dim = encoder_intermediate_dim - self.text_proj_dim = text_proj_dim self.unimodal_decoder_depth = unimodal_decoder_depth self.multimodal_decoder_depth = multimodal_decoder_depth self.decoder_intermediate_dim = decoder_intermediate_dim @@ -74,7 +67,7 @@ def __init__(self, # Layer Definitions self.image_patching = PatchingAndEmbedding(self.encoder_width, self.img_patch_size) self.image_encoder = Sequential([ - CVTransformerEncoder(self.img_query_dim, self.encoder_heads, self.encoder_intermediate_dim) + CVTransformerEncoder(self.encoder_width, self.encoder_heads, self.encoder_intermediate_dim) for _ in range(self.encoder_depth) ]) @@ -88,13 +81,56 @@ def __init__(self, for _ in range(self.multimodal_decoder_depth) ]) - self.contrastive_attn_pooling = AttentionPooling(self.img_query_dim, self.contrastive_attn_heads) - self.captioning_attn_pooling = AttentionPooling(self.img_query_dim, self.captioning_attn_heads) + self.contrastive_attn_pooling = AttentionPooling(self.encoder_width, self.contrastive_attn_heads) + self.captioning_attn_pooling = AttentionPooling(self.encoder_width, self.captioning_attn_heads) # These are learnable weights defined in build as per Keras recommendations self.cls_token = None self.contrastive_query = None self.captioning_query = None + """ Contrastive Captioner foundational model implementation. + + This model implements the "Contrastive Captioners are image-Text Foundational Models" by Yu, et al. + (https://arxiv.org/pdf/2205.01917.pdf). In short, the CoCa model combines the ideas of Contrastive techniques + such as CLIP, with Generative Captioning approaches such as SimVLM. + + The architecture of clip can be described as an Image Visual Transformer Encoder in parallel to self-attention-only + Text Transformer Decoder, the outputs of both of which are passed into a multimodal Transformer Decoder. The + contrastive loss from the ViT and the uni-modal Text Decoder is combined with a captioning loss from the multi-modal + Decoder in order to produce the combined total loss. + + Basic Usage: + ```python + + images = ... # [batch_size, height, width, channel] + text = ... # [batch_size, text_dim, sequence_length] + + coca = CoCa() + + # [batch_size, sequence_length, captioning_query_length] + output = coca(images, text) + ``` + + All default arguments should be consistent with the original paper's details. + + Args: + img_patch_size: N of each NxN patch generated from linearization of the input images + encoder_depth: number of image encoder blocks + encoder_heads: number of attention heads used in each image encoder block + encoder_intermediate_dim: dimensionality of the encoder blocks' intermediate representation (MLP dimensionality) + encoder_width: dimensionality of the encoder's projection, consistent with wording used in CoCa paper. + unimodal_decoder_depth: number of decoder blocks used for text self-attention/embedding + multimodal_decoder_depth: number of decoder blocks used for image-text cross-attention and captioning + decoder_intermediate_dim: dimensionality of the decoder blocks' MLPs + unimodal_decoder_heads: number of attention heads in the unimodal decoder + multimodal_decoder_heads: number of attention heads in the multimodal decoder + contrastive_query_length: number of tokens to use to represent contrastive query + captioning_query_length: number of tokens to use to represent captioning query + contrastive_attn_heads: number of attention heads used for the contrastive attention pooling + captioning_attn_heads: number of attention heads used for the captioning attention pooling + contrastive_loss_weight: weighting of contrastive loss + captioning_loss_weight: weighting of captioning loss + """ def build(self, input_shape): super().build(input_shape) @@ -121,26 +157,29 @@ def build(self, input_shape): # Build Layers self.image_patching.build(images_shape) - self.image_encoder.build((batch_size, self.image_patching.num_patches, self.encoder_width)) + + # Add 1 for CLs token appended by patching + num_patches = (images_shape[1] // self.img_patch_size) * (images_shape[2] // self.img_patch_size) + 1 + self.image_encoder.build((batch_size, self.encoder_width, num_patches)) text_shape_with_cls_token = [s for s in text_shape] - text_shape_with_cls_token[-1] += 1 + text_shape_with_cls_token[1] += 1 self.text_embedding.build(text_shape_with_cls_token) self.unimodal_text_decoder.build(text_shape_with_cls_token) - self.contrastive_attn_pooling.build((batch_size, text_dim, self.contrastive_query_length)) - self.captioning_attn_pooling.build((batch_size, text_dim, self.captioning_query_length)) + self.contrastive_attn_pooling.build((batch_size, num_patches, self.encoder_width)) + self.captioning_attn_pooling.build((batch_size, num_patches, self.encoder_width)) - self.multimodal_text_decoder.build((batch_size, self.image_patching.num_patches, self.encoder_width), + self.multimodal_text_decoder.build((batch_size, self.encoder_width, self.captioning_query_length), text_shape_with_cls_token) # Learnable Weights self.cls_token = self.add_weight(shape=(batch_size, 1, text_dim), name="cls_token", trainable=True) - self.contrastive_query = self.add_weight(shape=(batch_size, text_dim, self.contrastive_query_length), + self.contrastive_query = self.add_weight(shape=(batch_size, self.encoder_width, self.contrastive_query_length), trainable=True) - self.captioning_query = self.add_weight(shape=(batch_size, text_dim, self.captioning_query_length), + self.captioning_query = self.add_weight(shape=(batch_size, self.encoder_width, self.captioning_query_length), trainable=True) def call(self, images, texts): @@ -156,18 +195,24 @@ def call(self, images, texts): Returns: Output: Output of the captioning Transformer Decoder with captioning cross-attention """ - img_encoding = self.image_patching(images) - img_encoding = self.image_encoder(img_encoding) # [batch, patches_len+1, img_query_dim] + img_encoding = self.image_patching(images) # [batch_size, encoder_width, img_patches_len+1] + img_encoding = self.image_encoder(img_encoding) # [batch_size, img_patches_len+1, encoder_width] # This is only needed for loss calculations # contrastive_feature = self.con_attn_pooling(self.contrastive_query, img_encoding) + + # [batch_size, encoder_width, captioning_query_length] captioning_feature = self.captioning_attn_pooling(self.captioning_query, img_encoding) + # [batch_size, sequence_length+1, text_dim] text_tokens = np.concatenate(texts, self.cls_token) mask = np.concatenate((np.ones_like(texts), np.zeros_like(self.cls_token))) + # [batch_size, sequence_length+1, text_dim] embed_text = self.text_embedding(text_tokens) unimodal_out = self.unimodal_text_decoder(embed_text, attention_mask=mask) + + # [batch_size, sequence_length, captioning_query_length], notice we remove the CLs token multimodal_out = self.multimodal_text_decoder(unimodal_out[:, :-1, :], encoder_sequence=captioning_feature, decoder_attention_mask=mask) @@ -179,12 +224,10 @@ def get_config(self): config.update( { "img_patch_size": self.img_patch_size, - "img_query_dim": self.img_query_dim, "encoder_depth": self.encoder_depth, "encoder_heads": self.encoder_heads, "encoder_width": self.encoder_width, "encoder_intermediate_dim": self.encoder_intermediate_dim, - "text_proj_dim": self.text_proj_dim, "unimodal_decoder_depth": self.unimodal_decoder_depth, "multimodal_decoder_depth": self.multimodal_decoder_depth, "decoder_intermediate_dim": self.decoder_intermediate_dim, From f15408f2e7713dc39a6bc6cb9dcad49ee5869edb Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Tue, 5 Mar 2024 15:42:46 -0800 Subject: [PATCH 08/13] Lowercased coca model directory and added to kokoro build --- .kokoro/github/ubuntu/gpu/build.sh | 2 ++ keras_cv/layers/attention_pooling.py | 2 +- .../feature_extractor/{CoCa => coca}/__init__.py | 0 .../feature_extractor/{CoCa => coca}/coca_model.py | 10 +++++----- 4 files changed, 8 insertions(+), 6 deletions(-) rename keras_cv/models/feature_extractor/{CoCa => coca}/__init__.py (100%) rename keras_cv/models/feature_extractor/{CoCa => coca}/coca_model.py (98%) diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index 76ac0631b4..6bd9d341a6 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -69,6 +69,7 @@ then keras_cv/models/object_detection/retinanet \ keras_cv/models/object_detection/yolo_v8 \ keras_cv/models/object_detection_3d \ + keras_cv/models/feature_extractor/coca \ keras_cv/models/segmentation \ keras_cv/models/stable_diffusion else @@ -83,6 +84,7 @@ else keras_cv/models/object_detection/retinanet \ keras_cv/models/object_detection/yolo_v8 \ keras_cv/models/object_detection_3d \ + keras_cv/models/feature_extractor/coca \ keras_cv/models/segmentation \ keras_cv/models/stable_diffusion fi \ No newline at end of file diff --git a/keras_cv/layers/attention_pooling.py b/keras_cv/layers/attention_pooling.py index d1be136f9e..41323435cc 100644 --- a/keras_cv/layers/attention_pooling.py +++ b/keras_cv/layers/attention_pooling.py @@ -2,7 +2,7 @@ class AttentionPooling(layers.Layer): - """Implements the Pooled Attention Layer used in "CoCa": Contrastive Captioners are Image-Text Foundation Models" + """Implements the Pooled Attention Layer used in "coca": Contrastive Captioners are Image-Text Foundation Models" (https://arxiv.org/pdf/2205.01917.pdf), consisting of a Multiheaded Attention followed by Layer Normalization. Args: diff --git a/keras_cv/models/feature_extractor/CoCa/__init__.py b/keras_cv/models/feature_extractor/coca/__init__.py similarity index 100% rename from keras_cv/models/feature_extractor/CoCa/__init__.py rename to keras_cv/models/feature_extractor/coca/__init__.py diff --git a/keras_cv/models/feature_extractor/CoCa/coca_model.py b/keras_cv/models/feature_extractor/coca/coca_model.py similarity index 98% rename from keras_cv/models/feature_extractor/CoCa/coca_model.py rename to keras_cv/models/feature_extractor/coca/coca_model.py index 964b9bd28a..f0509f9f7c 100644 --- a/keras_cv/models/feature_extractor/CoCa/coca_model.py +++ b/keras_cv/models/feature_extractor/coca/coca_model.py @@ -21,7 +21,7 @@ from keras_cv.layers.vit_layers import PatchingAndEmbedding -@keras_cv_export(["keras_cv.models.CoCa"]) +@keras_cv_export(["keras_cv.models.coca"]) class CoCa(Task): def __init__(self, img_patch_size=18, @@ -91,7 +91,7 @@ def __init__(self, """ Contrastive Captioner foundational model implementation. This model implements the "Contrastive Captioners are image-Text Foundational Models" by Yu, et al. - (https://arxiv.org/pdf/2205.01917.pdf). In short, the CoCa model combines the ideas of Contrastive techniques + (https://arxiv.org/pdf/2205.01917.pdf). In short, the coca model combines the ideas of Contrastive techniques such as CLIP, with Generative Captioning approaches such as SimVLM. The architecture of clip can be described as an Image Visual Transformer Encoder in parallel to self-attention-only @@ -105,7 +105,7 @@ def __init__(self, images = ... # [batch_size, height, width, channel] text = ... # [batch_size, text_dim, sequence_length] - coca = CoCa() + coca = coca() # [batch_size, sequence_length, captioning_query_length] output = coca(images, text) @@ -118,7 +118,7 @@ def __init__(self, encoder_depth: number of image encoder blocks encoder_heads: number of attention heads used in each image encoder block encoder_intermediate_dim: dimensionality of the encoder blocks' intermediate representation (MLP dimensionality) - encoder_width: dimensionality of the encoder's projection, consistent with wording used in CoCa paper. + encoder_width: dimensionality of the encoder's projection, consistent with wording used in coca paper. unimodal_decoder_depth: number of decoder blocks used for text self-attention/embedding multimodal_decoder_depth: number of decoder blocks used for image-text cross-attention and captioning decoder_intermediate_dim: dimensionality of the decoder blocks' MLPs @@ -137,7 +137,7 @@ def build(self, input_shape): # Validate Input Shape if len(input_shape) < 2: - raise ValueError("Build arguments to CoCa expected to contain shapes of both text and image data; " + raise ValueError("Build arguments to coca expected to contain shapes of both text and image data; " f"got {len(input_shape)} shapes.") images_shape = input_shape[0] From 960873fec97de103a5e4fb9c240067fa427f7494 Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Mon, 11 Mar 2024 09:22:59 -0700 Subject: [PATCH 09/13] Addressed comments by Matt; reformatted as well --- keras_cv/layers/attention_pooling.py | 13 +- .../feature_extractor/coca/coca_model.py | 257 +++++++++++------- 2 files changed, 167 insertions(+), 103 deletions(-) diff --git a/keras_cv/layers/attention_pooling.py b/keras_cv/layers/attention_pooling.py index 41323435cc..2c5a3a03ba 100644 --- a/keras_cv/layers/attention_pooling.py +++ b/keras_cv/layers/attention_pooling.py @@ -6,21 +6,18 @@ class AttentionPooling(layers.Layer): (https://arxiv.org/pdf/2205.01917.pdf), consisting of a Multiheaded Attention followed by Layer Normalization. Args: - proj_dim: The dimensions of the attention heads + head_dim: The dimensions of the attention heads num_heads: The number of attention heads in the multi-headed attention layer """ - def __init__(self, - proj_dim, - num_heads, - **kwargs): + + def __init__(self, head_dim, num_heads, **kwargs): super().__init__(self, **kwargs) - self.proj_dim = proj_dim + self.head_dim = head_dim self.num_heads = num_heads self.multi_head_attn = layers.MultiHeadAttention( - self.num_heads, - self.proj_dim + self.num_heads, self.head_dim ) self.layer_norm = layers.LayerNormalization() diff --git a/keras_cv/models/feature_extractor/coca/coca_model.py b/keras_cv/models/feature_extractor/coca/coca_model.py index f0509f9f7c..b3ae0a950c 100644 --- a/keras_cv/models/feature_extractor/coca/coca_model.py +++ b/keras_cv/models/feature_extractor/coca/coca_model.py @@ -11,84 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np from keras import Sequential +from keras_nlp.layers import RotaryEmbedding +from keras_nlp.layers import TransformerDecoder + from keras_cv.api_export import keras_cv_export -from keras_nlp.layers import RotaryEmbedding, TransformerDecoder +from keras_cv.backend import ops from keras_cv.layers import TransformerEncoder as CVTransformerEncoder -from keras_cv.models.task import Task from keras_cv.layers.attention_pooling import AttentionPooling from keras_cv.layers.vit_layers import PatchingAndEmbedding +from keras_cv.models.task import Task @keras_cv_export(["keras_cv.models.coca"]) class CoCa(Task): - def __init__(self, - img_patch_size=18, - encoder_depth=40, - encoder_heads=16, - encoder_intermediate_dim=6144, - encoder_width=1408, - unimodal_decoder_depth=18, - multimodal_decoder_depth=18, - decoder_intermediate_dim=5632, - unimodal_decoder_heads=16, - multimodal_decoder_heads=16, - contrastive_query_length=1, - captioning_query_length=256, - contrastive_attn_heads=16, - captioning_attn_heads=16, - contrastive_loss_weight=0.5, - captioning_loss_weight=0.5, - **kwargs): - super().__init__(**kwargs) - - self.img_patch_size = img_patch_size - - self.encoder_depth = encoder_depth - self.encoder_heads = encoder_heads - self.encoder_width = encoder_width - self.encoder_intermediate_dim = encoder_intermediate_dim - - self.unimodal_decoder_depth = unimodal_decoder_depth - self.multimodal_decoder_depth = multimodal_decoder_depth - self.decoder_intermediate_dim = decoder_intermediate_dim - self.unimodal_decoder_heads = unimodal_decoder_heads - self.multimodal_decoder_heads = multimodal_decoder_heads - - self.contrastive_query_length = contrastive_query_length - self.contrastive_attn_heads = contrastive_attn_heads - self.contrastive_loss_weight = contrastive_loss_weight - - self.captioning_query_length = captioning_query_length - self.captioning_attn_heads = captioning_attn_heads - self.captioning_loss_weight = captioning_loss_weight - - # Layer Definitions - self.image_patching = PatchingAndEmbedding(self.encoder_width, self.img_patch_size) - self.image_encoder = Sequential([ - CVTransformerEncoder(self.encoder_width, self.encoder_heads, self.encoder_intermediate_dim) - for _ in range(self.encoder_depth) - ]) - - self.text_embedding = RotaryEmbedding() - self.unimodal_text_decoder = Sequential([ - TransformerDecoder(self.decoder_intermediate_dim, self.unimodal_decoder_heads) - for _ in range(self.unimodal_decoder_depth) - ]) - self.multimodal_text_decoder = Sequential([ - TransformerDecoder(self.decoder_intermediate_dim, self.multimodal_decoder_heads) - for _ in range(self.multimodal_decoder_depth) - ]) - - self.contrastive_attn_pooling = AttentionPooling(self.encoder_width, self.contrastive_attn_heads) - self.captioning_attn_pooling = AttentionPooling(self.encoder_width, self.captioning_attn_heads) - - # These are learnable weights defined in build as per Keras recommendations - self.cls_token = None - self.contrastive_query = None - self.captioning_query = None - """ Contrastive Captioner foundational model implementation. + """Contrastive Captioner foundational model implementation. This model implements the "Contrastive Captioners are image-Text Foundational Models" by Yu, et al. (https://arxiv.org/pdf/2205.01917.pdf). In short, the coca model combines the ideas of Contrastive techniques @@ -132,34 +69,132 @@ def __init__(self, captioning_loss_weight: weighting of captioning loss """ + def __init__( + self, + img_patch_size=18, + encoder_depth=40, + encoder_heads=16, + encoder_intermediate_dim=6144, + encoder_width=1408, + unimodal_decoder_depth=18, + multimodal_decoder_depth=18, + decoder_intermediate_dim=5632, + unimodal_decoder_heads=16, + multimodal_decoder_heads=16, + contrastive_query_length=1, + captioning_query_length=256, + contrastive_attn_heads=16, + captioning_attn_heads=16, + contrastive_loss_weight=0.5, + captioning_loss_weight=0.5, + **kwargs, + ): + super().__init__(**kwargs) + + self.img_patch_size = img_patch_size + + self.encoder_depth = encoder_depth + self.encoder_heads = encoder_heads + self.encoder_width = encoder_width + self.encoder_intermediate_dim = encoder_intermediate_dim + + self.unimodal_decoder_depth = unimodal_decoder_depth + self.multimodal_decoder_depth = multimodal_decoder_depth + self.decoder_intermediate_dim = decoder_intermediate_dim + self.unimodal_decoder_heads = unimodal_decoder_heads + self.multimodal_decoder_heads = multimodal_decoder_heads + + self.contrastive_query_length = contrastive_query_length + self.contrastive_attn_heads = contrastive_attn_heads + self.contrastive_loss_weight = contrastive_loss_weight + + self.captioning_query_length = captioning_query_length + self.captioning_attn_heads = captioning_attn_heads + self.captioning_loss_weight = captioning_loss_weight + + # Layer Definitions + self.image_patching = PatchingAndEmbedding( + self.encoder_width, self.img_patch_size + ) + self.image_encoder = Sequential( + [ + CVTransformerEncoder( + self.encoder_width, + self.encoder_heads, + self.encoder_intermediate_dim, + ) + for _ in range(self.encoder_depth) + ] + ) + + self.text_embedding = RotaryEmbedding() + self.unimodal_text_decoder = Sequential( + [ + TransformerDecoder( + self.decoder_intermediate_dim, self.unimodal_decoder_heads + ) + for _ in range(self.unimodal_decoder_depth) + ] + ) + self.multimodal_text_decoder = Sequential( + [ + TransformerDecoder( + self.decoder_intermediate_dim, self.multimodal_decoder_heads + ) + for _ in range(self.multimodal_decoder_depth) + ] + ) + + self.contrastive_attn_pooling = AttentionPooling( + self.encoder_width, self.contrastive_attn_heads + ) + self.captioning_attn_pooling = AttentionPooling( + self.encoder_width, self.captioning_attn_heads + ) + + # These are learnable weights defined in build as per Keras recommendations + self.cls_token = None + self.contrastive_query = None + self.captioning_query = None + def build(self, input_shape): super().build(input_shape) # Validate Input Shape if len(input_shape) < 2: - raise ValueError("Build arguments to coca expected to contain shapes of both text and image data; " - f"got {len(input_shape)} shapes.") + raise ValueError( + "Build arguments to coca expected to contain shapes of both text and image data; " + f"got {len(input_shape)} shapes." + ) images_shape = input_shape[0] text_shape = input_shape[1] if len(images_shape) != 4: - raise ValueError("Image shape expected to be of shape [batch_size, height, width, channels]. Instead got " - f"shape: {images_shape}") + raise ValueError( + "Image shape expected to be of shape [batch_size, height, width, channels]. Instead got " + f"shape: {images_shape}" + ) elif len(text_shape) != 2: - raise ValueError("Text shape expected to be of shape [batch_size, context_length]. Instead got shape" - f": {text_shape}") + raise ValueError( + "Text shape expected to be of shape [batch_size, context_length]. Instead got shape" + f": {text_shape}" + ) text_dim = text_shape[1] batch_size = images_shape[0] if batch_size != text_shape[0]: - raise ValueError(f"Differing batch sizes between images and texts input. {batch_size} vs {text_shape[0]}") + raise ValueError( + f"Differing batch sizes between images and texts input. {batch_size} vs {text_shape[0]}" + ) # Build Layers self.image_patching.build(images_shape) # Add 1 for CLs token appended by patching - num_patches = (images_shape[1] // self.img_patch_size) * (images_shape[2] // self.img_patch_size) + 1 + num_patches = (images_shape[1] // self.img_patch_size) * ( + images_shape[2] // self.img_patch_size + ) + 1 self.image_encoder.build((batch_size, self.encoder_width, num_patches)) text_shape_with_cls_token = [s for s in text_shape] @@ -168,19 +203,39 @@ def build(self, input_shape): self.unimodal_text_decoder.build(text_shape_with_cls_token) - self.contrastive_attn_pooling.build((batch_size, num_patches, self.encoder_width)) - self.captioning_attn_pooling.build((batch_size, num_patches, self.encoder_width)) + self.contrastive_attn_pooling.build( + (batch_size, num_patches, self.encoder_width) + ) + self.captioning_attn_pooling.build( + (batch_size, num_patches, self.encoder_width) + ) - self.multimodal_text_decoder.build((batch_size, self.encoder_width, self.captioning_query_length), - text_shape_with_cls_token) + self.multimodal_text_decoder.build( + (batch_size, self.encoder_width, self.captioning_query_length), + text_shape_with_cls_token, + ) # Learnable Weights - self.cls_token = self.add_weight(shape=(batch_size, 1, text_dim), name="cls_token", trainable=True) + self.cls_token = self.add_weight( + shape=(batch_size, 1, text_dim), name="cls_token", trainable=True + ) - self.contrastive_query = self.add_weight(shape=(batch_size, self.encoder_width, self.contrastive_query_length), - trainable=True) - self.captioning_query = self.add_weight(shape=(batch_size, self.encoder_width, self.captioning_query_length), - trainable=True) + self.contrastive_query = self.add_weight( + shape=( + batch_size, + self.encoder_width, + self.contrastive_query_length, + ), + trainable=True, + ) + self.captioning_query = self.add_weight( + shape=( + batch_size, + self.encoder_width, + self.captioning_query_length, + ), + trainable=True, + ) def call(self, images, texts): """ @@ -195,27 +250,39 @@ def call(self, images, texts): Returns: Output: Output of the captioning Transformer Decoder with captioning cross-attention """ - img_encoding = self.image_patching(images) # [batch_size, encoder_width, img_patches_len+1] - img_encoding = self.image_encoder(img_encoding) # [batch_size, img_patches_len+1, encoder_width] + img_encoding = self.image_patching( + images + ) # [batch_size, encoder_width, img_patches_len+1] + img_encoding = self.image_encoder( + img_encoding + ) # [batch_size, img_patches_len+1, encoder_width] # This is only needed for loss calculations # contrastive_feature = self.con_attn_pooling(self.contrastive_query, img_encoding) # [batch_size, encoder_width, captioning_query_length] - captioning_feature = self.captioning_attn_pooling(self.captioning_query, img_encoding) + captioning_feature = self.captioning_attn_pooling( + self.captioning_query, img_encoding + ) # [batch_size, sequence_length+1, text_dim] - text_tokens = np.concatenate(texts, self.cls_token) - mask = np.concatenate((np.ones_like(texts), np.zeros_like(self.cls_token))) + text_tokens = ops.concatenate(texts, self.cls_token) + mask = ops.concatenate( + (ops.ones_like(texts), ops.zeros_like(self.cls_token)) + ) # [batch_size, sequence_length+1, text_dim] embed_text = self.text_embedding(text_tokens) - unimodal_out = self.unimodal_text_decoder(embed_text, attention_mask=mask) + unimodal_out = self.unimodal_text_decoder( + embed_text, attention_mask=mask + ) # [batch_size, sequence_length, captioning_query_length], notice we remove the CLs token - multimodal_out = self.multimodal_text_decoder(unimodal_out[:, :-1, :], - encoder_sequence=captioning_feature, - decoder_attention_mask=mask) + multimodal_out = self.multimodal_text_decoder( + unimodal_out[:, :-1, :], + encoder_sequence=captioning_feature, + decoder_attention_mask=mask, + ) return multimodal_out From 33cff54d06dd609fe8faddb4a4bef6cee5118310 Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Wed, 13 Mar 2024 12:30:00 -0700 Subject: [PATCH 10/13] Addressed comments related to attn pooling size, attn pooling name --- keras_cv/models/feature_extractor/coca/__init__.py | 13 +++++++++++++ .../feature_extractor/coca/coca_layers.py} | 12 +++++++++--- .../models/feature_extractor/coca/coca_model.py | 12 +++++++----- 3 files changed, 29 insertions(+), 8 deletions(-) rename keras_cv/{layers/attention_pooling.py => models/feature_extractor/coca/coca_layers.py} (70%) diff --git a/keras_cv/models/feature_extractor/coca/__init__.py b/keras_cv/models/feature_extractor/coca/__init__.py index e69de29bb2..3992ffb59a 100644 --- a/keras_cv/models/feature_extractor/coca/__init__.py +++ b/keras_cv/models/feature_extractor/coca/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_cv/layers/attention_pooling.py b/keras_cv/models/feature_extractor/coca/coca_layers.py similarity index 70% rename from keras_cv/layers/attention_pooling.py rename to keras_cv/models/feature_extractor/coca/coca_layers.py index 2c5a3a03ba..829b91e6c5 100644 --- a/keras_cv/layers/attention_pooling.py +++ b/keras_cv/models/feature_extractor/coca/coca_layers.py @@ -1,7 +1,7 @@ from keras import layers -class AttentionPooling(layers.Layer): +class CoCaAttentionPooling(layers.Layer): """Implements the Pooled Attention Layer used in "coca": Contrastive Captioners are Image-Text Foundation Models" (https://arxiv.org/pdf/2205.01917.pdf), consisting of a Multiheaded Attention followed by Layer Normalization. @@ -25,8 +25,14 @@ def __init__(self, head_dim, num_heads, **kwargs): def build(self, input_shape): super().build(input_shape) - self.multi_head_attn.build(input_shape) - self.layer_norm.build(input_shape) + if(len(input_shape) < 2): + raise ValueError("Building CoCa Attention Pooling requires input shape of shape (query_shape, value_shape)") + + query_shape = input_shape[0] + value_shape = input_shape[1] + + self.multi_head_attn._build_from_signature(query_shape, value_shape) + self.layer_norm.build(query_shape) def call(self, query, value): x = self.multi_head_attn(query, value) diff --git a/keras_cv/models/feature_extractor/coca/coca_model.py b/keras_cv/models/feature_extractor/coca/coca_model.py index b3ae0a950c..bd6e765f8c 100644 --- a/keras_cv/models/feature_extractor/coca/coca_model.py +++ b/keras_cv/models/feature_extractor/coca/coca_model.py @@ -18,7 +18,7 @@ from keras_cv.api_export import keras_cv_export from keras_cv.backend import ops from keras_cv.layers import TransformerEncoder as CVTransformerEncoder -from keras_cv.layers.attention_pooling import AttentionPooling +from keras_cv.models.feature_extractor.coca.coca_layers import CoCaAttentionPooling from keras_cv.layers.vit_layers import PatchingAndEmbedding from keras_cv.models.task import Task @@ -145,10 +145,10 @@ def __init__( ] ) - self.contrastive_attn_pooling = AttentionPooling( + self.contrastive_attn_pooling = CoCaAttentionPooling( self.encoder_width, self.contrastive_attn_heads ) - self.captioning_attn_pooling = AttentionPooling( + self.captioning_attn_pooling = CoCaAttentionPooling( self.encoder_width, self.captioning_attn_heads ) @@ -204,10 +204,12 @@ def build(self, input_shape): self.unimodal_text_decoder.build(text_shape_with_cls_token) self.contrastive_attn_pooling.build( - (batch_size, num_patches, self.encoder_width) + ((batch_size, self.encoder_width, self.contrastive_query_length), + (batch_size, num_patches, self.encoder_width)) ) self.captioning_attn_pooling.build( - (batch_size, num_patches, self.encoder_width) + ((batch_size, self.encoder_width, self.captioning_query_length), + (batch_size, num_patches, self.encoder_width)) ) self.multimodal_text_decoder.build( From 145d7b572d49a0968cf499351f2de0c7fcd5a4cf Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Tue, 19 Mar 2024 13:46:24 -0700 Subject: [PATCH 11/13] Wrote a test for coca saving and loading, which prompted some model changes --- .../models/feature_extractor/coca/__init__.py | 3 + .../feature_extractor/coca/coca_layers.py | 4 +- .../feature_extractor/coca/coca_model.py | 98 ++++++++++--------- .../feature_extractor/coca/coca_model_test.py | 24 +++++ 4 files changed, 81 insertions(+), 48 deletions(-) create mode 100644 keras_cv/models/feature_extractor/coca/coca_model_test.py diff --git a/keras_cv/models/feature_extractor/coca/__init__.py b/keras_cv/models/feature_extractor/coca/__init__.py index 3992ffb59a..5372894aca 100644 --- a/keras_cv/models/feature_extractor/coca/__init__.py +++ b/keras_cv/models/feature_extractor/coca/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_cv.models.feature_extractor.coca.coca_model import CoCa +from keras_cv.models.feature_extractor.coca.coca_layers import CoCaAttentionPooling diff --git a/keras_cv/models/feature_extractor/coca/coca_layers.py b/keras_cv/models/feature_extractor/coca/coca_layers.py index 829b91e6c5..25bbbc1a60 100644 --- a/keras_cv/models/feature_extractor/coca/coca_layers.py +++ b/keras_cv/models/feature_extractor/coca/coca_layers.py @@ -11,7 +11,7 @@ class CoCaAttentionPooling(layers.Layer): """ def __init__(self, head_dim, num_heads, **kwargs): - super().__init__(self, **kwargs) + super().__init__(**kwargs) self.head_dim = head_dim self.num_heads = num_heads @@ -23,7 +23,7 @@ def __init__(self, head_dim, num_heads, **kwargs): self.layer_norm = layers.LayerNormalization() def build(self, input_shape): - super().build(input_shape) + # super().build(input_shape) if(len(input_shape) < 2): raise ValueError("Building CoCa Attention Pooling requires input shape of shape (query_shape, value_shape)") diff --git a/keras_cv/models/feature_extractor/coca/coca_model.py b/keras_cv/models/feature_extractor/coca/coca_model.py index bd6e765f8c..0fd6fcf1bc 100644 --- a/keras_cv/models/feature_extractor/coca/coca_model.py +++ b/keras_cv/models/feature_extractor/coca/coca_model.py @@ -70,24 +70,24 @@ class CoCa(Task): """ def __init__( - self, - img_patch_size=18, - encoder_depth=40, - encoder_heads=16, - encoder_intermediate_dim=6144, - encoder_width=1408, - unimodal_decoder_depth=18, - multimodal_decoder_depth=18, - decoder_intermediate_dim=5632, - unimodal_decoder_heads=16, - multimodal_decoder_heads=16, - contrastive_query_length=1, - captioning_query_length=256, - contrastive_attn_heads=16, - captioning_attn_heads=16, - contrastive_loss_weight=0.5, - captioning_loss_weight=0.5, - **kwargs, + self, + img_patch_size=18, + encoder_depth=40, + encoder_heads=16, + encoder_intermediate_dim=6144, + encoder_width=1408, + unimodal_decoder_depth=18, + multimodal_decoder_depth=18, + decoder_intermediate_dim=5632, + unimodal_decoder_heads=16, + multimodal_decoder_heads=16, + contrastive_query_length=1, + captioning_query_length=256, + contrastive_attn_heads=16, + captioning_attn_heads=16, + contrastive_loss_weight=0.5, + captioning_loss_weight=0.5, + **kwargs, ): super().__init__(**kwargs) @@ -136,14 +136,12 @@ def __init__( for _ in range(self.unimodal_decoder_depth) ] ) - self.multimodal_text_decoder = Sequential( - [ - TransformerDecoder( - self.decoder_intermediate_dim, self.multimodal_decoder_heads - ) - for _ in range(self.multimodal_decoder_depth) - ] - ) + self.multimodal_text_decoders = [ + TransformerDecoder( + self.decoder_intermediate_dim, self.multimodal_decoder_heads + ) + for _ in range(self.multimodal_decoder_depth) + ] self.contrastive_attn_pooling = CoCaAttentionPooling( self.encoder_width, self.contrastive_attn_heads @@ -158,8 +156,6 @@ def __init__( self.captioning_query = None def build(self, input_shape): - super().build(input_shape) - # Validate Input Shape if len(input_shape) < 2: raise ValueError( @@ -175,13 +171,13 @@ def build(self, input_shape): "Image shape expected to be of shape [batch_size, height, width, channels]. Instead got " f"shape: {images_shape}" ) - elif len(text_shape) != 2: + elif len(text_shape) != 3: raise ValueError( - "Text shape expected to be of shape [batch_size, context_length]. Instead got shape" + "Text shape expected to be of shape [batch_size, context_length, text_dim]. Instead got shape" f": {text_shape}" ) - text_dim = text_shape[1] + text_dim = text_shape[-1] batch_size = images_shape[0] if batch_size != text_shape[0]: raise ValueError( @@ -193,9 +189,9 @@ def build(self, input_shape): # Add 1 for CLs token appended by patching num_patches = (images_shape[1] // self.img_patch_size) * ( - images_shape[2] // self.img_patch_size + images_shape[2] // self.img_patch_size ) + 1 - self.image_encoder.build((batch_size, self.encoder_width, num_patches)) + self.image_encoder.build((batch_size, num_patches, self.encoder_width)) text_shape_with_cls_token = [s for s in text_shape] text_shape_with_cls_token[1] += 1 @@ -212,10 +208,9 @@ def build(self, input_shape): (batch_size, num_patches, self.encoder_width)) ) - self.multimodal_text_decoder.build( - (batch_size, self.encoder_width, self.captioning_query_length), - text_shape_with_cls_token, - ) + for text_decoder in self.multimodal_text_decoders: + text_decoder.build((batch_size, self.encoder_width, self.captioning_query_length), + text_shape) # Learnable Weights self.cls_token = self.add_weight( @@ -239,22 +234,23 @@ def build(self, input_shape): trainable=True, ) + self.built = True + def call(self, images, texts): """ Forward pass of the Coca Model from raw image and text data Args: images: [batch_size, height, width, channels] representing images - texts: Tensor, typically represented as [batch_size, sequence_length, feature_length] or - [batch_size, sequence_length, num_heads, feature_length]. The sequence_length and/or feature_length - are required. + texts: Tensor, typically represented as [batch_size, sequence_length, feature_length]. + The sequence_length and/or feature_length are required. Returns: Output: Output of the captioning Transformer Decoder with captioning cross-attention """ img_encoding = self.image_patching( images - ) # [batch_size, encoder_width, img_patches_len+1] + ) # [batch_size, img_patches_len+1, encoder_width] img_encoding = self.image_encoder( img_encoding ) # [batch_size, img_patches_len+1, encoder_width] @@ -280,11 +276,13 @@ def call(self, images, texts): ) # [batch_size, sequence_length, captioning_query_length], notice we remove the CLs token - multimodal_out = self.multimodal_text_decoder( - unimodal_out[:, :-1, :], - encoder_sequence=captioning_feature, - decoder_attention_mask=mask, - ) + multimodal_out = unimodal_out[:, :-1, :] + for decoder in self.multimodal_text_decoders: + multimodal_out = decoder( + multimodal_out, + encoder_sequence=captioning_feature, + decoder_attention_mask=mask + ) return multimodal_out @@ -311,3 +309,11 @@ def get_config(self): } ) return config + + @classmethod + def from_config(cls, config): + return cls(**config) + + def load_own_variables(self, store): + print(store) + super().load_own_variables(store) \ No newline at end of file diff --git a/keras_cv/models/feature_extractor/coca/coca_model_test.py b/keras_cv/models/feature_extractor/coca/coca_model_test.py new file mode 100644 index 0000000000..2cea8c5791 --- /dev/null +++ b/keras_cv/models/feature_extractor/coca/coca_model_test.py @@ -0,0 +1,24 @@ +import keras.saving +import numpy as np +import pytest +import os + +from keras_cv.models.feature_extractor.coca import CoCa +from keras_cv.tests.test_case import TestCase + +class CoCaTest(TestCase): + + @pytest.mark.large + def test_coca_model_save(self): + # TODO: Transformer encoder breaks if you have project dim < num heads + model = CoCa() + model.build(((1, 512, 512, 3), (1, 1, 48))) + + save_path = os.path.join(self.get_temp_dir(), "coca.keras") + model.save(save_path) + + restored_model = keras.models.load_model(save_path, custom_objects={"CoCa": CoCa}) + + self.assertIsInstance(restored_model, CoCa) + + From e8623a9afb980bf8a5b6270f4fbfb039c2825ef9 Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Tue, 26 Mar 2024 11:41:31 -0700 Subject: [PATCH 12/13] Updated to functional model --- .../feature_extractor/coca/coca_model.py | 128 ++++++------------ 1 file changed, 42 insertions(+), 86 deletions(-) diff --git a/keras_cv/models/feature_extractor/coca/coca_model.py b/keras_cv/models/feature_extractor/coca/coca_model.py index 0fd6fcf1bc..5b5aadc003 100644 --- a/keras_cv/models/feature_extractor/coca/coca_model.py +++ b/keras_cv/models/feature_extractor/coca/coca_model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import keras from keras import Sequential from keras_nlp.layers import RotaryEmbedding from keras_nlp.layers import TransformerDecoder @@ -91,6 +92,9 @@ def __init__( ): super().__init__(**kwargs) + # + # Save Details + # self.img_patch_size = img_patch_size self.encoder_depth = encoder_depth @@ -112,7 +116,9 @@ def __init__( self.captioning_attn_heads = captioning_attn_heads self.captioning_loss_weight = captioning_loss_weight + # # Layer Definitions + # self.image_patching = PatchingAndEmbedding( self.encoder_width, self.img_patch_size ) @@ -151,75 +157,31 @@ def __init__( ) # These are learnable weights defined in build as per Keras recommendations - self.cls_token = None self.contrastive_query = None self.captioning_query = None - def build(self, input_shape): - # Validate Input Shape - if len(input_shape) < 2: - raise ValueError( - "Build arguments to coca expected to contain shapes of both text and image data; " - f"got {len(input_shape)} shapes." - ) - - images_shape = input_shape[0] - text_shape = input_shape[1] - - if len(images_shape) != 4: - raise ValueError( - "Image shape expected to be of shape [batch_size, height, width, channels]. Instead got " - f"shape: {images_shape}" - ) - elif len(text_shape) != 3: - raise ValueError( - "Text shape expected to be of shape [batch_size, context_length, text_dim]. Instead got shape" - f": {text_shape}" - ) - - text_dim = text_shape[-1] - batch_size = images_shape[0] - if batch_size != text_shape[0]: - raise ValueError( - f"Differing batch sizes between images and texts input. {batch_size} vs {text_shape[0]}" - ) - - # Build Layers - self.image_patching.build(images_shape) - - # Add 1 for CLs token appended by patching - num_patches = (images_shape[1] // self.img_patch_size) * ( - images_shape[2] // self.img_patch_size - ) + 1 - self.image_encoder.build((batch_size, num_patches, self.encoder_width)) - - text_shape_with_cls_token = [s for s in text_shape] - text_shape_with_cls_token[1] += 1 - self.text_embedding.build(text_shape_with_cls_token) - - self.unimodal_text_decoder.build(text_shape_with_cls_token) - - self.contrastive_attn_pooling.build( - ((batch_size, self.encoder_width, self.contrastive_query_length), - (batch_size, num_patches, self.encoder_width)) + # + # Functional Model + # + images = keras.Input( + shape=(None,), dtype="int32", name="images" ) - self.captioning_attn_pooling.build( - ((batch_size, self.encoder_width, self.captioning_query_length), - (batch_size, num_patches, self.encoder_width)) + + captions = keras.Input( + shape=(None,), dtype="int32", name="caption" ) - for text_decoder in self.multimodal_text_decoders: - text_decoder.build((batch_size, self.encoder_width, self.captioning_query_length), - text_shape) + img_encoding = self.image_patching( + images + ) # [batch_size, img_patches_len+1, encoder_width] + img_encoding = self.image_encoder( + img_encoding + ) # [batch_size, img_patches_len+1, encoder_width] # Learnable Weights - self.cls_token = self.add_weight( - shape=(batch_size, 1, text_dim), name="cls_token", trainable=True - ) - self.contrastive_query = self.add_weight( shape=( - batch_size, + None, self.encoder_width, self.contrastive_query_length, ), @@ -227,46 +189,30 @@ def build(self, input_shape): ) self.captioning_query = self.add_weight( shape=( - batch_size, + None, self.encoder_width, self.captioning_query_length, ), trainable=True, ) - self.built = True - - def call(self, images, texts): - """ - Forward pass of the Coca Model from raw image and text data - - Args: - images: [batch_size, height, width, channels] representing images - texts: Tensor, typically represented as [batch_size, sequence_length, feature_length]. - The sequence_length and/or feature_length are required. - - Returns: - Output: Output of the captioning Transformer Decoder with captioning cross-attention - """ - img_encoding = self.image_patching( - images - ) # [batch_size, img_patches_len+1, encoder_width] - img_encoding = self.image_encoder( - img_encoding - ) # [batch_size, img_patches_len+1, encoder_width] - - # This is only needed for loss calculations - # contrastive_feature = self.con_attn_pooling(self.contrastive_query, img_encoding) + # This is for contrastive loss; [batch_size, encoder_width, contrastive_query_length] + contrastive_feature = self.con_attn_pooling(self.contrastive_query, img_encoding) # [batch_size, encoder_width, captioning_query_length] captioning_feature = self.captioning_attn_pooling( self.captioning_query, img_encoding ) + # Learnable CLs Token + self.cls_token = self.add_weight( + shape=(None, 1, ), name="cls_token", trainable=True + ) + # [batch_size, sequence_length+1, text_dim] - text_tokens = ops.concatenate(texts, self.cls_token) + text_tokens = ops.concatenate(captions, self.cls_token) mask = ops.concatenate( - (ops.ones_like(texts), ops.zeros_like(self.cls_token)) + (ops.ones_like(captions), ops.zeros_like(self.cls_token)) ) # [batch_size, sequence_length+1, text_dim] @@ -284,7 +230,17 @@ def call(self, images, texts): decoder_attention_mask=mask ) - return multimodal_out + super().__init__( + inputs={ + "images": images, + "captions": captions, + }, + outputs={ + "multimodal_out": multimodal_out, + "contrastive_feature": contrastive_feature + }, + ) + def get_config(self): config = super().get_config() From c9e1ec10148571db5d780950738e634d8155afbc Mon Sep 17 00:00:00 2001 From: Varun Singh Date: Thu, 28 Mar 2024 11:50:37 -0700 Subject: [PATCH 13/13] added size inputs for functional model --- .../feature_extractor/coca/coca_model.py | 28 +++++++++++-------- .../feature_extractor/coca/coca_model_test.py | 1 - 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/keras_cv/models/feature_extractor/coca/coca_model.py b/keras_cv/models/feature_extractor/coca/coca_model.py index 5b5aadc003..707cf2eb43 100644 --- a/keras_cv/models/feature_extractor/coca/coca_model.py +++ b/keras_cv/models/feature_extractor/coca/coca_model.py @@ -52,6 +52,8 @@ class CoCa(Task): All default arguments should be consistent with the original paper's details. Args: + img_shape: The shape of a single image, typically expressed as [height, weight, channels] + caption_shape: The shape of a single caption, typically expressed as [sequence_length, text_dim] img_patch_size: N of each NxN patch generated from linearization of the input images encoder_depth: number of image encoder blocks encoder_heads: number of attention heads used in each image encoder block @@ -72,6 +74,8 @@ class CoCa(Task): def __init__( self, + img_shape=(512, 512, 3), + caption_shape = (10, 48), img_patch_size=18, encoder_depth=40, encoder_heads=16, @@ -95,6 +99,9 @@ def __init__( # # Save Details # + self.img_shape = img_shape + self.caption_shape = caption_shape + self.img_patch_size = img_patch_size self.encoder_depth = encoder_depth @@ -163,13 +170,8 @@ def __init__( # # Functional Model # - images = keras.Input( - shape=(None,), dtype="int32", name="images" - ) - - captions = keras.Input( - shape=(None,), dtype="int32", name="caption" - ) + images = keras.Input(shape=self.img_shape, name="images") + captions = keras.Input(shape=self.caption_shape, name="caption") img_encoding = self.image_patching( images @@ -182,31 +184,31 @@ def __init__( self.contrastive_query = self.add_weight( shape=( None, - self.encoder_width, self.contrastive_query_length, + self.encoder_width, ), trainable=True, ) self.captioning_query = self.add_weight( shape=( None, - self.encoder_width, self.captioning_query_length, + self.encoder_width, ), trainable=True, ) - # This is for contrastive loss; [batch_size, encoder_width, contrastive_query_length] + # This is for contrastive loss; [batch_size, contrastive_query_length, encoder_width] contrastive_feature = self.con_attn_pooling(self.contrastive_query, img_encoding) - # [batch_size, encoder_width, captioning_query_length] + # [batch_size, captioning_query_length, encoder_width] captioning_feature = self.captioning_attn_pooling( self.captioning_query, img_encoding ) # Learnable CLs Token self.cls_token = self.add_weight( - shape=(None, 1, ), name="cls_token", trainable=True + shape=(None, 1, self.caption_shape[-1]), name="cls_token", trainable=True ) # [batch_size, sequence_length+1, text_dim] @@ -246,6 +248,8 @@ def get_config(self): config = super().get_config() config.update( { + "img_shape": self.img_shape, + "caption_shape": self.caption_shape, "img_patch_size": self.img_patch_size, "encoder_depth": self.encoder_depth, "encoder_heads": self.encoder_heads, diff --git a/keras_cv/models/feature_extractor/coca/coca_model_test.py b/keras_cv/models/feature_extractor/coca/coca_model_test.py index 2cea8c5791..f9c99f903e 100644 --- a/keras_cv/models/feature_extractor/coca/coca_model_test.py +++ b/keras_cv/models/feature_extractor/coca/coca_model_test.py @@ -12,7 +12,6 @@ class CoCaTest(TestCase): def test_coca_model_save(self): # TODO: Transformer encoder breaks if you have project dim < num heads model = CoCa() - model.build(((1, 512, 512, 3), (1, 1, 48))) save_path = os.path.join(self.get_temp_dir(), "coca.keras") model.save(save_path)