Modified HLOAbsOp lowering for differing types.
PiperOrigin-RevId: 323082107
This commit is contained in:
parent
882468da13
commit
8023baa959
|
@ -89,12 +89,10 @@ def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
|
|||
// Absolute value is evaluated as:
|
||||
// result = sqrt(val.real * val.real + val.imag * val.imag)
|
||||
def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
|
||||
(HLO_ComplexOp
|
||||
(HLO_SqrtOp
|
||||
(HLO_AddOp
|
||||
(HLO_MulOp (HLO_RealOp:$real $val), $real),
|
||||
(HLO_MulOp (HLO_ImagOp:$imag $val), $imag))),
|
||||
(HLO_ConstOp (ConstantSplat<"0"> $real)))>;
|
||||
(HLO_MulOp (HLO_ImagOp:$imag $val), $imag)))>;
|
||||
|
||||
// Exponential can be lowered to an exponential on the real component and a
|
||||
// sum of sinusoids of the imaginary component, which equates to a normal
|
||||
|
|
|
@ -182,11 +182,10 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
|
|||
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1
|
||||
// CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]])
|
||||
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return [[VAL3]]
|
||||
return %2 : tensor<2xf32>
|
||||
return %1 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @exp
|
||||
|
|
Loading…
Reference in New Issue