diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index b5f73bf5b..d7fea675a 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -197,8 +197,9 @@ def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch Args: euler_angles: Euler angles in radians as tensor of shape (..., 3). - convention: Convention string of three uppercase letters from - {"X", "Y", and "Z"}. + convention: Convention string of three letters from + {"X", "Y", "Z"} for intrinsic rotations, or {"x", "y", "z"} + for extrinsic rotations. Returns: Rotation matrices as tensor of shape (..., 3, 3). @@ -210,12 +211,16 @@ def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch if convention[1] in (convention[0], convention[2]): raise ValueError(f"Invalid convention {convention}.") for letter in convention: - if letter not in ("X", "Y", "Z"): + if letter.upper() not in ("X", "Y", "Z"): raise ValueError(f"Invalid letter {letter} in convention string.") - matrices = [ - _axis_angle_rotation(c, e) - for c, e in zip(convention, torch.unbind(euler_angles, -1)) - ] + angles = torch.unbind(euler_angles, -1) + if convention == convention.lower(): + # Convert extrinsic to intrinsic rotations + convention = convention[::-1].upper() + angles = angles[::-1] + elif convention != convention.upper(): + raise ValueError(f"Invalid convention {convention}.") + matrices = [_axis_angle_rotation(c, e) for c, e in zip(convention, angles)] # return functools.reduce(torch.matmul, matrices) return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) @@ -269,7 +274,9 @@ def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tenso Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). - convention: Convention string of three uppercase letters. + convention: Convention string of three letters from + {"X", "Y", "Z"} for intrinsic rotations, or {"x", "y", "z"} + for extrinsic rotations. Returns: Euler angles in radians as tensor of shape (..., 3). @@ -279,8 +286,14 @@ def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tenso if convention[1] in (convention[0], convention[2]): raise ValueError(f"Invalid convention {convention}.") for letter in convention: - if letter not in ("X", "Y", "Z"): + if letter.upper() not in ("X", "Y", "Z"): raise ValueError(f"Invalid letter {letter} in convention string.") + extrinsic = False + if convention == convention.lower(): + extrinsic = True + convention = convention[::-1].upper() + elif convention != convention.upper(): + raise ValueError(f"Invalid convention {convention}.") if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") i0 = _index_from_letter(convention[0]) @@ -302,6 +315,8 @@ def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tenso convention[2], convention[1], matrix[..., i0, :], True, tait_bryan ), ) + if extrinsic: + o = o[::-1] return torch.stack(o, -1) diff --git a/tests/test_rotation_conversions.py b/tests/test_rotation_conversions.py index 7090d3ca3..142005c04 100644 --- a/tests/test_rotation_conversions.py +++ b/tests/test_rotation_conversions.py @@ -88,24 +88,38 @@ def test_quat_grad_exists(self): [g] = torch.autograd.grad(modified.sum(), rotation) self.assertTrue(torch.isfinite(g).all()) - def _tait_bryan_conventions(self): + def _tait_bryan_intrinsic_conventions(self): return map("".join, itertools.permutations("XYZ")) - def _proper_euler_conventions(self): + def _proper_euler_intrinsic_conventions(self): letterpairs = itertools.permutations("XYZ", 2) return (l0 + l1 + l0 for l0, l1 in letterpairs) + def _all_euler_intrinsic_angle_conventions(self): + return itertools.chain( + self._tait_bryan_intrinsic_conventions(), + self._proper_euler_intrinsic_conventions(), + ) + def _all_euler_angle_conventions(self): return itertools.chain( - self._tait_bryan_conventions(), self._proper_euler_conventions() + self._all_euler_intrinsic_angle_conventions(), + (c.lower() for c in self._all_euler_intrinsic_angle_conventions()), ) def test_conventions(self): """The conventions listings have the right length.""" - all = list(self._all_euler_angle_conventions()) + all = list(self._all_euler_intrinsic_angle_conventions()) self.assertEqual(len(all), 12) self.assertEqual(len(set(all)), 12) + data = random_rotations(13, dtype=torch.float64) + for convention_intr in self._all_euler_intrinsic_angle_conventions(): + convention_extr = convention_intr[::-1].lower() + euler_angles_intr = matrix_to_euler_angles(data, convention_intr) + euler_angles_extr = matrix_to_euler_angles(data, convention_extr) + self.assertClose(euler_angles_intr, torch.flip(euler_angles_extr, [-1])) + def test_from_euler(self): """euler -> mtx -> euler""" n_repetitions = 10 @@ -117,16 +131,18 @@ def test_from_euler(self): data.uniform_(-math.pi, math.pi) data[:, 1].uniform_(-half_pi + tolerance, half_pi - tolerance) - for convention in self._tait_bryan_conventions(): - matrices = euler_angles_to_matrix(data, convention) - mdata = matrix_to_euler_angles(matrices, convention) - self.assertClose(data, mdata) + for convention_intr in self._tait_bryan_intrinsic_conventions(): + for convention in [convention_intr, convention_intr.lower()]: + matrices = euler_angles_to_matrix(data, convention) + mdata = matrix_to_euler_angles(matrices, convention) + self.assertClose(data, mdata) data[:, 1] += half_pi - for convention in self._proper_euler_conventions(): - matrices = euler_angles_to_matrix(data, convention) - mdata = matrix_to_euler_angles(matrices, convention) - self.assertClose(data, mdata) + for convention_intr in self._proper_euler_intrinsic_conventions(): + for convention in [convention_intr, convention_intr.lower()]: + matrices = euler_angles_to_matrix(data, convention) + mdata = matrix_to_euler_angles(matrices, convention) + self.assertClose(data, mdata) def test_to_euler(self): """mtx -> euler -> mtx"""