Restrict CHLO Acos and Sinh op lowering to non complex types

These are failing for complex types. Complex types require special handling. We have a fallback lowering for these ops so we can disable complex element types for now.

PiperOrigin-RevId: 348205002
This commit is contained in:
Smit Hinsu 2020-12-18 11:31:03 -08:00 committed by TensorFlow MLIR Team
parent 8d051723c0
commit 9466cffaf3
1 changed files with 12 additions and 2 deletions

View File

@ -23,10 +23,18 @@ include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
// Unary op patterns. // Unary op patterns.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def NonComplexElementType : Type<
CPred<"!$_self.cast<ShapedType>().getElementType().isa<ComplexType>()">,
"Non complex element type">;
// Expand acos to MHLO dialect as follows: // Expand acos to MHLO dialect as follows:
// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1 // acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1
// = pi if x == -1 // = pi if x == -1
def : Pat<(HLOClient_AcosOp $input), //
// TODO(hinsu): Support operands with complex element types separately using
// the following formula.
// acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x))))
def : Pat<(HLOClient_AcosOp NonComplexElementType:$input),
(HLO_SelectOp (HLO_SelectOp
(HLO_CompareOp (HLO_CompareOp
$input, $input,
@ -68,7 +76,9 @@ def : Pat<(HLOClient_ConjOp $v),
// Express `sinh` as // Express `sinh` as
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1 // sinh(x) = (e^x - e^-x) / 2 if |x| < 1
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. // = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
def : Pat<(HLOClient_SinhOp $input), // TODO(hinsu): Support operands with complex element types by always using the
// second formula. The compare op below is not legal for complex numbers.
def : Pat<(HLOClient_SinhOp NonComplexElementType:$input),
(HLO_SelectOp (HLO_SelectOp
(HLO_CompareOp (HLO_CompareOp
(HLO_AbsOp $input), (HLO_AbsOp $input),