
remove spatial_shapes
Signed-off-by: Phillip Kuznetsov <[email protected]>
philkuz committed Oct 29, 2024
1 parent 99e90b2 commit 104b16d
Showing 2 changed files with 5 additions and 12 deletions.
src/transformers/models/mask2former/modeling_mask2former.py (16 changes: 5 additions & 11 deletions)
@@ -926,7 +926,6 @@ def forward(
         encoder_attention_mask=None,
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
-        spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
@@ -959,7 +958,11 @@ def forward(
             )
         # batch_size, num_queries, n_heads, n_levels, n_points, 2
         if reference_points.shape[-1] == 2:
-            offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+            offset_normalizer = torch.tensor(
+                [[shape[1], shape[0]] for shape in spatial_shapes_list],
+                dtype=torch.long,
+                device=reference_points.device,
+            )
             sampling_locations = (
                 reference_points[:, :, None, :, None, :]
                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
@@ -1003,7 +1006,6 @@ def forward(
         attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor = None,
         reference_points=None,
-        spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
@@ -1018,8 +1020,6 @@ def forward(
                 Position embeddings, to be added to `hidden_states`.
             reference_points (`torch.FloatTensor`, *optional*):
                 Reference points.
-            spatial_shapes (`torch.LongTensor`, *optional*):
-                Spatial shapes of the backbone feature maps.
             spatial_shapes_list (`list` of `tuple`):
                 Spatial shapes of the backbone feature maps as a list of tuples.
             level_start_index (`torch.LongTensor`, *optional*):
@@ -1038,7 +1038,6 @@ def forward(
             encoder_attention_mask=attention_mask,
             position_embeddings=position_embeddings,
             reference_points=reference_points,
-            spatial_shapes=spatial_shapes,
             spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             output_attentions=output_attentions,
@@ -1128,7 +1127,6 @@ def forward(
         inputs_embeds=None,
         attention_mask=None,
         position_embeddings=None,
-        spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
@@ -1147,8 +1145,6 @@ def forward(
                 [What are attention masks?](../glossary#attention-mask)
             position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 Position embeddings that are added to the queries and keys in each self-attention layer.
-            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
-                Spatial shapes of each feature map.
             spatial_shapes_list (`list` of `tuple`):
                 Spatial shapes of each feature map as a list of tuples.
             level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
@@ -1185,7 +1181,6 @@ def forward(
                 attention_mask,
                 position_embeddings=position_embeddings,
                 reference_points=reference_points,
-                spatial_shapes=spatial_shapes,
                 spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 output_attentions=output_attentions,
@@ -1330,7 +1325,6 @@ def forward(
             inputs_embeds=input_embeds_flat,
             attention_mask=masks_flat,
             position_embeddings=level_pos_embed_flat,
-            spatial_shapes=spatial_shapes,
             spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             valid_ratios=valid_ratios,
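
The only functional change in the modeling file is how `offset_normalizer` is built inside the multi-scale deformable attention: instead of indexing a packed `spatial_shapes` tensor, it is now constructed directly from `spatial_shapes_list`, so the per-level shape information stays a plain Python list of `(height, width)` tuples throughout the encoder. A small self-contained sketch (the shape values below are made up, not from the diff) showing that the two forms produce the same normalizer:

```python
import torch

# Hypothetical feature-map sizes for three levels, as (height, width) tuples.
spatial_shapes_list = [(64, 64), (32, 32), (16, 16)]

# Old path: a (num_levels, 2) LongTensor of shapes had to be threaded through
# the encoder just to build the (width, height) normalizer.
spatial_shapes = torch.tensor(spatial_shapes_list, dtype=torch.long)
old_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)

# New path: build the normalizer straight from the Python list of tuples.
new_normalizer = torch.tensor(
    [[shape[1], shape[0]] for shape in spatial_shapes_list],
    dtype=torch.long,
)

assert torch.equal(old_normalizer, new_normalizer)
```

Dropping the tensor argument removes the last place the encoder depended on a runtime shape tensor; the test change below suggests this is tied to keeping the model traceable by `torch.export`.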
tests/models/mask2former/test_modeling_mask2former.py (1 change: 0 additions & 1 deletion)
@@ -483,7 +483,6 @@ def test_with_segmentation_maps_and_loss(self):
 
         self.assertTrue(outputs.loss is not None)
 
-    @slow
     def test_export(self):
         if version.parse(torch.__version__) < version.parse("2.4.0"):
             self.skipTest(reason="This test requires torch >= 2.4 to run.")
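
Removing `@slow` moves `test_export` from the nightly/slow job into the default test suite, so export compatibility is checked on every run (still guarded to torch >= 2.4). A rough, hypothetical sketch of the kind of call such a test exercises; the checkpoint name and input size are illustrative, not taken from the diff:

```python
import torch
from transformers import Mask2FormerForUniversalSegmentation

# Illustrative checkpoint; any Mask2Former checkpoint should behave the same way.
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-tiny-coco-instance"
).eval()

pixel_values = torch.randn(1, 3, 384, 384)

# Trace the model into an ExportedProgram (requires torch >= 2.4 here,
# matching the version guard in test_export).
exported = torch.export.export(model, args=(pixel_values,))

# The exported program can be re-materialized as a module and run.
outputs = exported.module()(pixel_values)
print(outputs.class_queries_logits.shape)
```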
