Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CLIPSeg] Make interpolate_pos_encoding default to True #34419

Merged
merged 6 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 31 additions & 58 deletions src/transformers/models/clipseg/modeling_clipseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig

Expand Down Expand Up @@ -164,62 +163,40 @@ def __init__(self, config: CLIPSegVisionConfig):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.

Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - 1
position_embedding = self.position_embedding.weight.unsqueeze(0)
num_positions = position_embedding.shape[1] - 1

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embedding(self.position_ids)

class_pos_embed = position_embedding[:, :1]
patch_pos_embed = position_embedding[:, 1:]

dim = embeddings.shape[-1]

new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
def interpolate_pos_encoding(self, new_size):
if len(new_size) != 2:
raise ValueError("new_size should consist of 2 values")

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
num_patches_one_direction = int(self.num_patches**0.5)
# we interpolate the position embeddings in 2D
a = self.position_embedding.weight[1:].T.view(
1, self.config.hidden_size, num_patches_one_direction, num_patches_one_direction
)
b = (
nn.functional.interpolate(a, new_size, mode="bicubic", align_corners=False)
.squeeze(0)
.view(self.config.hidden_size, new_size[0] * new_size[1])
.T
)
result = torch.cat([self.position_embedding.weight[:1], b])

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
return result

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
batch_size, _, height, width = pixel_values.shape
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
)
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = True) -> torch.Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

if embeddings.shape[1] != self.num_positions and interpolate_pos_encoding:
new_shape = int(math.sqrt(embeddings.shape[1] - 1))
embeddings = embeddings + self.interpolate_pos_encoding((new_shape, new_shape))
embeddings = embeddings.to(embeddings.dtype)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)

return embeddings


Expand Down Expand Up @@ -535,7 +512,7 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
interpolate_pos_encoding (`bool`, *optional*, defaults to `True`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Expand Down Expand Up @@ -574,7 +551,7 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
interpolate_pos_encoding (`bool`, *optional*, defaults to `True`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Expand Down Expand Up @@ -845,14 +822,13 @@ def __init__(self, config: CLIPSegVisionConfig):

@add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegVisionConfig)
# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor],
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
interpolate_pos_encoding: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Expand All @@ -864,9 +840,6 @@ def forward(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if pixel_values is None:
raise ValueError("You have to specify pixel_values")

hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)

Expand Down Expand Up @@ -912,7 +885,7 @@ def forward(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
interpolate_pos_encoding: Optional[bool] = True,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Expand Down Expand Up @@ -1035,7 +1008,7 @@ def get_image_features(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
interpolate_pos_encoding: bool = True,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
Expand Down Expand Up @@ -1091,7 +1064,7 @@ def forward(
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
interpolate_pos_encoding: bool = True,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CLIPSegOutput]:
r"""
Expand Down Expand Up @@ -1397,7 +1370,7 @@ def forward(
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
interpolate_pos_encoding: bool = True,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CLIPSegOutput]:
r"""
Expand Down
41 changes: 2 additions & 39 deletions tests/models/clipseg/test_modeling_clipseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,15 +796,15 @@ def test_inference_image_segmentation(self):

# forward pass
with torch.no_grad():
outputs = model(**inputs, interpolate_pos_encoding=True)
outputs = model(**inputs)

# verify the predicted masks
self.assertEqual(
outputs.logits.shape,
torch.Size((3, 352, 352)),
)
expected_masks_slice = torch.tensor(
[[-7.4613, -7.4785, -7.3627], [-7.3268, -7.0898, -7.1333], [-6.9838, -6.7900, -6.8913]]
[[-7.4613, -7.4785, -7.3628], [-7.3268, -7.0899, -7.1333], [-6.9838, -6.7900, -6.8913]]
).to(torch_device)

self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3))
Expand All @@ -814,40 +814,3 @@ def test_inference_image_segmentation(self):
expected_pooled_output = torch.tensor([0.5036, -0.2681, -0.2644]).to(torch_device)
self.assertTrue(torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3))
self.assertTrue(torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3))

@slow
def test_inference_interpolate_pos_encoding(self):
# ViT models have an `interpolate_pos_encoding` argument in their forward method,
# allowing to interpolate the pre-trained position embeddings in order to use
# the model on higher resolutions. The DINO model by Facebook AI leverages this
# to visualize self-attention on higher resolution images.
model = CLIPSegModel.from_pretrained("openai/clip-vit-base-patch32").to(torch_device)

processor = CLIPSegProcessor.from_pretrained(
"openai/clip-vit-base-patch32", size={"height": 180, "width": 180}, crop_size={"height": 180, "width": 180}
)

image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)

# interpolate_pos_encodiung false should return value error
with self.assertRaises(ValueError, msg="doesn't match model"):
with torch.no_grad():
model(**inputs, interpolate_pos_encoding=False)

# forward pass
with torch.no_grad():
outputs = model(**inputs, interpolate_pos_encoding=True)

# verify the logits
expected_shape = torch.Size((1, 26, 768))

self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)

expected_slice = torch.tensor(
[[-0.1538, 0.0322, -0.3235], [0.2893, 0.1135, -0.5708], [0.0461, 0.1540, -0.6018]]
).to(torch_device)

self.assertTrue(
torch.allclose(outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)
)
Loading