Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert RandomZoom to backend-agnostic and improve affine_transform #574

Merged
merged 18 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 54 additions & 17 deletions keras_core/backend/torch/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def resize(
"constant": "zeros",
"nearest": "border",
# "wrap", not supported by torch
# "mirror", not supported by torch
"reflect": "reflection",
"mirror": "reflection", # torch's reflection is mirror in other backends
"reflect": "reflection", # if fill_mode==reflect, redirect to mirror
}


Expand Down Expand Up @@ -122,7 +122,7 @@ def _apply_grid_transform(
grid,
mode=interpolation,
padding_mode=fill_mode,
align_corners=False,
align_corners=True,
)
# Fill with required color
if fill_value is not None:
Expand Down Expand Up @@ -187,9 +187,9 @@ def affine_transform(
f"transform.shape={transform.shape}"
)

if fill_mode != "constant":
# the default fill_value of tnn.grid_sample is "zeros"
if fill_mode != "constant" or (fill_mode == "constant" and fill_value == 0):
fill_value = None
fill_mode = AFFINE_TRANSFORM_FILL_MODES[fill_mode]

# unbatched case
need_squeeze = False
Expand All @@ -202,23 +202,60 @@ def affine_transform(
if data_format == "channels_last":
image = image.permute((0, 3, 1, 2))

batch_size = image.shape[0]
h, w, c = image.shape[-2], image.shape[-1], image.shape[-3]

# get indices
shape = [h, w, c] # (H, W, C)
meshgrid = torch.meshgrid(
*[torch.arange(size) for size in shape], indexing="ij"
)
indices = torch.concatenate(
[torch.unsqueeze(x, dim=-1) for x in meshgrid], dim=-1
)
indices = torch.tile(indices, (batch_size, 1, 1, 1, 1))
indices = indices.to(transform)

# swap the values
a0 = transform[:, 0].clone()
a2 = transform[:, 2].clone()
b1 = transform[:, 4].clone()
b2 = transform[:, 5].clone()
transform[:, 0] = b1
transform[:, 2] = b2
transform[:, 4] = a0
transform[:, 5] = a2

# deal with transform
h, w = image.shape[2], image.shape[3]
theta = torch.zeros((image.shape[0], 2, 3)).to(transform)
theta[:, 0, 0] = transform[:, 0]
theta[:, 0, 1] = transform[:, 1] * h / w
theta[:, 0, 2] = (
transform[:, 2] * 2 / w + theta[:, 0, 0] + theta[:, 0, 1] - 1
transform = torch.nn.functional.pad(
transform, pad=[0, 1, 0, 0], mode="constant", value=1
)
theta[:, 1, 0] = transform[:, 3] * w / h
theta[:, 1, 1] = transform[:, 4]
theta[:, 1, 2] = (
transform[:, 5] * 2 / h + theta[:, 1, 0] + theta[:, 1, 1] - 1
transform = torch.reshape(transform, (batch_size, 3, 3))
offset = transform[:, 0:2, 2].clone()
offset = torch.nn.functional.pad(offset, pad=[0, 1, 0, 0])
transform[:, 0:2, 2] = 0

# transform the indices
coordinates = torch.einsum("Bhwij, Bjk -> Bhwik", indices, transform)
coordinates = torch.moveaxis(coordinates, source=-1, destination=1)
coordinates += torch.reshape(a=offset, shape=(*offset.shape, 1, 1, 1))
coordinates = coordinates[:, 0:2, ..., 0]
coordinates = coordinates.permute((0, 2, 3, 1))

# normalize coordinates
coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / (w - 1) * 2.0 - 1.0
coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / (h - 1) * 2.0 - 1.0
grid = torch.stack(
[coordinates[:, :, :, 1], coordinates[:, :, :, 0]], dim=-1
)

grid = tnn.affine_grid(theta, image.shape)
affined = _apply_grid_transform(
image, grid, interpolation, fill_mode, fill_value
image,
grid,
interpolation=interpolation,
# if fill_mode==reflect, redirect to mirror
fill_mode=AFFINE_TRANSFORM_FILL_MODES[fill_mode],
fill_value=fill_value,
)

if data_format == "channels_last":
Expand Down
28 changes: 17 additions & 11 deletions keras_core/layers/preprocessing/random_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,23 @@ class RandomTranslation(TFDataLayer):
left by 20%, and shifted right by 30%. `width_factor=0.2` results
in an output shifted left or right by 20%.
fill_mode: Points outside the boundaries of the input are filled
according to the given mode
(one of `{"constant", "reflect", "wrap", "nearest"}`).
- *reflect*: `(d c b a | a b c d | d c b a)` The input is extended
by reflecting about the edge of the last pixel.
- *constant*: `(k k k k | a b c d | k k k k)` The input is extended
by filling all values beyond the edge with the same constant
value k = 0.
- *wrap*: `(a b c d | a b c d | a b c d)` The input is extended by
wrapping around to the opposite edge.
- *nearest*: `(a a a a | a b c d | d d d d)` The input is extended
by the nearest pixel.
according to the given mode. Available methods are `"constant"`,
`"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`.
- `"reflect"`: `(d c b a | a b c d | d c b a)`
The input is extended by reflecting about the edge of the last
pixel.
- `"constant"`: `(k k k k | a b c d | k k k k)`
The input is extended by filling all values beyond
the edge with the same constant value k specified by
`fill_value`.
- `"wrap"`: `(a b c d | a b c d | a b c d)`
The input is extended by wrapping around to the opposite edge.
- `"nearest"`: `(a a a a | a b c d | d d d d)`
The input is extended by the nearest pixel.
Note that when using torch backend, `"reflect"` is redirected to
`"mirror"` `(c d c b | a b c d | c b a b)` because torch does not
support `"reflect"`.
Note that torch backend does not support `"wrap"`.
interpolation: Interpolation mode. Supported values: `"nearest"`,
`"bilinear"`.
seed: Integer. Used to create a random seed.
Expand Down
129 changes: 90 additions & 39 deletions keras_core/layers/preprocessing/random_translation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,27 @@ def test_random_translation_with_inference_mode(self):
@parameterized.parameters(["channels_first", "channels_last"])
def test_random_translation_up_numeric_reflect(self, data_format):
input_image = np.arange(0, 25)
expected_output = np.asarray(
[
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
[20, 21, 22, 23, 24],
]
)
if backend.backend() == "torch":
# redirect fill_mode=reflect to fill_mode=mirror
expected_output = np.asarray(
[
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
[15, 16, 17, 18, 19],
]
)
else:
expected_output = np.asarray(
[
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
[20, 21, 22, 23, 24],
]
)
if data_format == "channels_last":
input_image = np.reshape(input_image, (1, 5, 5, 1))
expected_output = backend.convert_to_tensor(
Expand Down Expand Up @@ -133,15 +145,27 @@ def test_random_translation_up_numeric_constant(self, data_format):
def test_random_translation_down_numeric_reflect(self, data_format):
input_image = np.arange(0, 25)
# Shifting by .2 * 5 = 1 pixel.
expected_output = np.asarray(
[
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
]
)
if backend.backend() == "torch":
# redirect fill_mode=reflect to fill_mode=mirror
expected_output = np.asarray(
[
[5, 6, 7, 8, 9],
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
]
)
else:
expected_output = np.asarray(
[
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
]
)
if data_format == "channels_last":
input_image = np.reshape(input_image, (1, 5, 5, 1))
expected_output = backend.convert_to_tensor(
Expand Down Expand Up @@ -172,18 +196,33 @@ def test_random_translation_asymmetric_size_numeric_reflect(
):
input_image = np.arange(0, 16)
# Shifting by .2 * 5 = 1 pixel.
expected_output = np.asarray(
[
[6, 7],
[4, 5],
[2, 3],
[0, 1],
[0, 1],
[2, 3],
[4, 5],
[6, 7],
]
)
if backend.backend() == "torch":
# redirect fill_mode=reflect to fill_mode=mirror
expected_output = np.asarray(
[
[8, 9],
[6, 7],
[4, 5],
[2, 3],
[0, 1],
[2, 3],
[4, 5],
[6, 7],
]
)
else:
expected_output = np.asarray(
[
[6, 7],
[4, 5],
[2, 3],
[0, 1],
[0, 1],
[2, 3],
[4, 5],
[6, 7],
]
)
if data_format == "channels_last":
input_image = np.reshape(input_image, (1, 8, 2, 1))
expected_output = backend.convert_to_tensor(
Expand Down Expand Up @@ -251,15 +290,27 @@ def test_random_translation_down_numeric_constant(self, data_format):
def test_random_translation_left_numeric_reflect(self, data_format):
input_image = np.arange(0, 25)
# Shifting by .2 * 5 = 1 pixel.
expected_output = np.asarray(
[
[1, 2, 3, 4, 4],
[6, 7, 8, 9, 9],
[11, 12, 13, 14, 14],
[16, 17, 18, 19, 19],
[21, 22, 23, 24, 24],
]
)
if backend.backend() == "torch":
# redirect fill_mode=reflect to fill_mode=mirror
expected_output = np.asarray(
[
[1, 2, 3, 4, 3],
[6, 7, 8, 9, 8],
[11, 12, 13, 14, 13],
[16, 17, 18, 19, 18],
[21, 22, 23, 24, 23],
]
)
else:
expected_output = np.asarray(
[
[1, 2, 3, 4, 4],
[6, 7, 8, 9, 9],
[11, 12, 13, 14, 14],
[16, 17, 18, 19, 19],
[21, 22, 23, 24, 24],
]
)
if data_format == "channels_last":
input_image = np.reshape(input_image, (1, 5, 5, 1))
expected_output = backend.convert_to_tensor(
Expand Down
Loading