Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add conventions by extrinsic rotations #1364

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions pytorch3d/transforms/rotation_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,9 @@ def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch

Args:
euler_angles: Euler angles in radians as tensor of shape (..., 3).
convention: Convention string of three uppercase letters from
{"X", "Y", and "Z"}.
convention: Convention string of three letters from
{"X", "Y", "Z"} for intrinsic rotations, or {"x", "y", "z"}
for extrinsic rotations.

Returns:
Rotation matrices as tensor of shape (..., 3, 3).
Expand All @@ -210,12 +211,16 @@ def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch
if convention[1] in (convention[0], convention[2]):
raise ValueError(f"Invalid convention {convention}.")
for letter in convention:
if letter not in ("X", "Y", "Z"):
if letter.upper() not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
matrices = [
_axis_angle_rotation(c, e)
for c, e in zip(convention, torch.unbind(euler_angles, -1))
]
angles = torch.unbind(euler_angles, -1)
if convention == convention.lower():
# Convert extrinsic to intrinsic rotations
convention = convention[::-1].upper()
angles = angles[::-1]
elif convention != convention.upper():
raise ValueError(f"Invalid convention {convention}.")
matrices = [_axis_angle_rotation(c, e) for c, e in zip(convention, angles)]
# return functools.reduce(torch.matmul, matrices)
return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])

Expand Down Expand Up @@ -269,7 +274,9 @@ def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tenso

Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
convention: Convention string of three uppercase letters.
convention: Convention string of three letters from
{"X", "Y", "Z"} for intrinsic rotations, or {"x", "y", "z"}
for extrinsic rotations.

Returns:
Euler angles in radians as tensor of shape (..., 3).
Expand All @@ -279,8 +286,14 @@ def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tenso
if convention[1] in (convention[0], convention[2]):
raise ValueError(f"Invalid convention {convention}.")
for letter in convention:
if letter not in ("X", "Y", "Z"):
if letter.upper() not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
extrinsic = False
if convention == convention.lower():
extrinsic = True
convention = convention[::-1].upper()
elif convention != convention.upper():
raise ValueError(f"Invalid convention {convention}.")
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
i0 = _index_from_letter(convention[0])
Expand All @@ -302,6 +315,8 @@ def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tenso
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
),
)
if extrinsic:
o = o[::-1]
return torch.stack(o, -1)


Expand Down
40 changes: 28 additions & 12 deletions tests/test_rotation_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,38 @@ def test_quat_grad_exists(self):
[g] = torch.autograd.grad(modified.sum(), rotation)
self.assertTrue(torch.isfinite(g).all())

def _tait_bryan_conventions(self):
def _tait_bryan_intrinsic_conventions(self):
return map("".join, itertools.permutations("XYZ"))

def _proper_euler_conventions(self):
def _proper_euler_intrinsic_conventions(self):
letterpairs = itertools.permutations("XYZ", 2)
return (l0 + l1 + l0 for l0, l1 in letterpairs)

def _all_euler_intrinsic_angle_conventions(self):
return itertools.chain(
self._tait_bryan_intrinsic_conventions(),
self._proper_euler_intrinsic_conventions(),
)

def _all_euler_angle_conventions(self):
return itertools.chain(
self._tait_bryan_conventions(), self._proper_euler_conventions()
self._all_euler_intrinsic_angle_conventions(),
(c.lower() for c in self._all_euler_intrinsic_angle_conventions()),
)

def test_conventions(self):
"""The conventions listings have the right length."""
all = list(self._all_euler_angle_conventions())
all = list(self._all_euler_intrinsic_angle_conventions())
self.assertEqual(len(all), 12)
self.assertEqual(len(set(all)), 12)

data = random_rotations(13, dtype=torch.float64)
for convention_intr in self._all_euler_intrinsic_angle_conventions():
convention_extr = convention_intr[::-1].lower()
euler_angles_intr = matrix_to_euler_angles(data, convention_intr)
euler_angles_extr = matrix_to_euler_angles(data, convention_extr)
self.assertClose(euler_angles_intr, torch.flip(euler_angles_extr, [-1]))

def test_from_euler(self):
"""euler -> mtx -> euler"""
n_repetitions = 10
Expand All @@ -117,16 +131,18 @@ def test_from_euler(self):
data.uniform_(-math.pi, math.pi)

data[:, 1].uniform_(-half_pi + tolerance, half_pi - tolerance)
for convention in self._tait_bryan_conventions():
matrices = euler_angles_to_matrix(data, convention)
mdata = matrix_to_euler_angles(matrices, convention)
self.assertClose(data, mdata)
for convention_intr in self._tait_bryan_intrinsic_conventions():
for convention in [convention_intr, convention_intr.lower()]:
matrices = euler_angles_to_matrix(data, convention)
mdata = matrix_to_euler_angles(matrices, convention)
self.assertClose(data, mdata)

data[:, 1] += half_pi
for convention in self._proper_euler_conventions():
matrices = euler_angles_to_matrix(data, convention)
mdata = matrix_to_euler_angles(matrices, convention)
self.assertClose(data, mdata)
for convention_intr in self._proper_euler_intrinsic_conventions():
for convention in [convention_intr, convention_intr.lower()]:
matrices = euler_angles_to_matrix(data, convention)
mdata = matrix_to_euler_angles(matrices, convention)
self.assertClose(data, mdata)

def test_to_euler(self):
"""mtx -> euler -> mtx"""
Expand Down