Removed Op(Complex, Real) lowering to address complex type inference issue

Lowerings that depended on operations between real and complex types may
not infer the correct intermediate type. Removing these operations as
they are not technically legally generated operations. Updated tests
to validate this.

PiperOrigin-RevId: 341128903
This commit is contained in:
Robert Suderman 2020-11-06 15:23:11 -08:00 committed by TensorFlow MLIR Team
parent 3dcd8b4ba2
commit a926e0f040
2 changed files with 55 additions and 56 deletions

View File

@ -51,40 +51,22 @@ def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs,
(HLO_MulOp $lhs_real, $rhs_imag),
(HLO_MulOp $lhs_imag, $rhs_real)))>;
// Multiplication between a complex and real tensor can be distributed by
// applying the real multiplicant to both the real and complex component.
//
// Note that the sourcep pattern is not legal according to the HLO dialect but
// instead handle intermediates generated by other patterns.
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
(HLO_ComplexOp
(HLO_MulOp (HLO_RealOp $lhs), $rhs),
(HLO_MulOp (HLO_ImagOp $lhs), $rhs))>;
def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs),
(HLO_ComplexOp
(HLO_MulOp $lhs, (HLO_RealOp $rhs)),
(HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>;
// Division is performed by normalizing the denominator by multiplying by the
// conjugate of the rhs.
// numerator = lhs * conj(rhs)
// denominator = rhs * conj(rhs)
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs),
(HLO_ComplexOp
(HLO_DivOp
(HLO_MulOp:$num $lhs,
(HLO_RealOp (HLO_MulOp:$num $lhs,
(HLO_ComplexOp:$conj
(HLO_RealOp $rhs),
(HLO_NegOp (HLO_ImagOp $rhs)))),
(HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>;
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
(HLO_ComplexOp
(HLO_DivOp (HLO_RealOp $lhs), $rhs),
(HLO_DivOp (HLO_ImagOp $lhs), $rhs))>;
(HLO_NegOp (HLO_ImagOp $rhs))))),
(HLO_AddOp:$den
(HLO_MulOp (HLO_RealOp $rhs), (HLO_RealOp $rhs)),
(HLO_MulOp (HLO_ImagOp $rhs), (HLO_ImagOp $rhs)))),
(HLO_DivOp (HLO_ImagOp $num), $den))>;
// Absolute value is evaluated as:
// result = sqrt(val.real * val.real + val.imag * val.imag)
@ -98,10 +80,10 @@ def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
// sum of sinusoids of the imaginary component, which equates to a normal
// exponential operator multiplied by Euler's formula.
//
// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * (Cos(b) + iSin(b))
// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * Cos(b) + Exp(a) * iSin(b))
def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val),
(HLO_MulOp
(HLO_ExpOp (HLO_RealOp $val)),
(HLO_ComplexOp
(HLO_MulOp
(HLO_CosOp (HLO_ImagOp:$imag $val)),
(HLO_SinOp $imag)))>;
(HLO_ExpOp:$exp (HLO_RealOp:$real $val))),
(HLO_MulOp (HLO_SinOp $imag), $exp))>;

View File

@ -114,8 +114,8 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
// Compute the real valued denominator as rhs * con(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, %arg3
// CHECK-DAG: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
@ -153,8 +153,8 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
// Compute the real valued denominator as rhs * con(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, %arg3
// CHECK-DAG: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
@ -165,6 +165,7 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
%4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
@ -192,32 +193,48 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
// CHECK-DAG: [[EXP:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[COS:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[SIN:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]]
// CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]]
%1 = "mhlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%3 = "mhlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]], [[VAL4]]
// CHECK: [[OUTR]], [[OUTI]]
return %2, %3 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @exp_unranked
func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-LABEL: @exp_complex
func @exp_complex(%arg0 : tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) {
// CHECK-DAG: [[REAL:%.+]] = "mhlo.real"(%arg0)
// CHECK-DAG: [[IMAG:%.+]] = "mhlo.imag"(%arg0)
// CHECK-DAG: [[EXP:%.+]] = "mhlo.exponential"([[REAL]])
// CHECK-DAG: [[COS:%.+]] = "mhlo.cosine"([[IMAG]])
// CHECK-DAG: [[SIN:%.+]] = "mhlo.sine"([[IMAG]])
// CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]]
// CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]]
// CHECK-DAG: [[OUT:%.+]] = "mhlo.complex"([[OUTR]], [[OUTI]])
%0 = "mhlo.exponential"(%arg0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
%1 = "mhlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%3 = "mhlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<*xf32>, tensor<*xf32>
// CHECK: [[OUT]]
return %0 : tensor<2xcomplex<f32>>
}
// CHECK-LABEL: @exp_unranked
func @exp_unranked(%arg0 : tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) {
// CHECK-DAG: [[REAL:%.+]] = "mhlo.real"(%arg0)
// CHECK-DAG: [[IMAG:%.+]] = "mhlo.imag"(%arg0)
// CHECK-DAG: [[EXP:%.+]] = "mhlo.exponential"([[REAL]])
// CHECK-DAG: [[COS:%.+]] = "mhlo.cosine"([[IMAG]])
// CHECK-DAG: [[SIN:%.+]] = "mhlo.sine"([[IMAG]])
// CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]]
// CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]]
// CHECK-DAG: [[OUT:%.+]] = "mhlo.complex"([[OUTR]], [[OUTI]])
%0 = "mhlo.exponential"(%arg0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
// CHECK: [[OUT]]
return %0 : tensor<*xcomplex<f32>>
}