Modified HLOAbsOp lowering for differing types.

PiperOrigin-RevId: 323082107
This commit is contained in:
Robert Suderman 2020-07-24 15:17:48 -07:00 committed by TensorFlow MLIR Team
parent 882468da13
commit 8023baa959
2 changed files with 3 additions and 6 deletions

View File

@ -89,12 +89,10 @@ def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
// Absolute value is evaluated as:
// result = sqrt(val.real * val.real + val.imag * val.imag)
def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
(HLO_ComplexOp
(HLO_SqrtOp
(HLO_AddOp
(HLO_MulOp (HLO_RealOp:$real $val), $real),
(HLO_MulOp (HLO_ImagOp:$imag $val), $imag))),
(HLO_ConstOp (ConstantSplat<"0"> $real)))>;
(HLO_MulOp (HLO_ImagOp:$imag $val), $imag)))>;
// Exponential can be lowered to an exponential on the real component and a
// sum of sinusoids of the imaginary component, which equates to a normal

View File

@ -182,11 +182,10 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1
// CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]])
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]]
return %2 : tensor<2xf32>
return %1 : tensor<2xf32>
}
// CHECK-LABEL: @exp