diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td index fb4bf3a..a48abb6 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -24,16 +24,17 @@ include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td" //===----------------------------------------------------------------------===// // Expand acos to MHLO dialect as follows: -// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 +// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1 // = pi if x == -1 def : Pat<(HLOClient_AcosOp $input), (HLO_SelectOp - (HLO_CompareOp $input, - (HLO_ConstantLike<"0"> $input), + (HLO_CompareOp + $input, + (HLO_ConstantLike<"-1"> $input), HLO_COMPARISON_DIRECTION_NE ), (HLO_MulOp - (HLO_ConstantLike<"2.0f"> $input), + (HLO_ConstantLike<"2"> $input), (HLO_Atan2Op (HLO_SqrtOp (HLO_SubOp @@ -47,7 +48,8 @@ def : Pat<(HLOClient_AcosOp $input), ) ) ), - (HLO_ConstantLike<"M_PI"> $input))>; + (HLO_ConstantLike<"M_PI"> $input) + )>; // Express `atan` as // atan(x) = atan2(x, 1)