From a926e0f0404fc541723faadf1e65c9877e867a9f Mon Sep 17 00:00:00 2001 From: Robert Suderman Date: Fri, 6 Nov 2020 15:23:11 -0800 Subject: [PATCH] Removed Op(Complex, Real) lowering to address complex type inference issue Lowerings that depended on operations between real and complex types may not infer the correct intermediate type. Removing these operations as they are not technically legally generated operations. Updated tests to validate this. PiperOrigin-RevId: 341128903 --- .../mhlo/transforms/lower_complex_patterns.td | 46 ++++--------- tests/lower-complex.mlir | 65 ++++++++++++------- 2 files changed, 55 insertions(+), 56 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/lower_complex_patterns.td b/lib/Dialect/mhlo/transforms/lower_complex_patterns.td index 2cc97c9..d132297 100644 --- a/lib/Dialect/mhlo/transforms/lower_complex_patterns.td +++ b/lib/Dialect/mhlo/transforms/lower_complex_patterns.td @@ -51,40 +51,22 @@ def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, (HLO_MulOp $lhs_real, $rhs_imag), (HLO_MulOp $lhs_imag, $rhs_real)))>; -// Multiplication between a complex and real tensor can be distributed by -// applying the real multiplicant to both the real and complex component. -// -// Note that the sourcep pattern is not legal according to the HLO dialect but -// instead handle intermediates generated by other patterns. -def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), - (HLO_ComplexOp - (HLO_MulOp (HLO_RealOp $lhs), $rhs), - (HLO_MulOp (HLO_ImagOp $lhs), $rhs))>; - -def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs), - (HLO_ComplexOp - (HLO_MulOp $lhs, (HLO_RealOp $rhs)), - (HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>; - // Division is performed by normalizing the denominator by multiplying by the // conjugate of the rhs. // numerator = lhs * conj(rhs) // denominator = rhs * conj(rhs) def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs), - (HLO_DivOp - (HLO_MulOp:$num $lhs, - (HLO_ComplexOp:$conj - (HLO_RealOp $rhs), - (HLO_NegOp (HLO_ImagOp $rhs)))), - (HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>; - - -def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), (HLO_ComplexOp - (HLO_DivOp (HLO_RealOp $lhs), $rhs), - (HLO_DivOp (HLO_ImagOp $lhs), $rhs))>; - + (HLO_DivOp + (HLO_RealOp (HLO_MulOp:$num $lhs, + (HLO_ComplexOp:$conj + (HLO_RealOp $rhs), + (HLO_NegOp (HLO_ImagOp $rhs))))), + (HLO_AddOp:$den + (HLO_MulOp (HLO_RealOp $rhs), (HLO_RealOp $rhs)), + (HLO_MulOp (HLO_ImagOp $rhs), (HLO_ImagOp $rhs)))), + (HLO_DivOp (HLO_ImagOp $num), $den))>; // Absolute value is evaluated as: // result = sqrt(val.real * val.real + val.imag * val.imag) @@ -98,10 +80,10 @@ def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), // sum of sinusoids of the imaginary component, which equates to a normal // exponential operator multiplied by Euler's formula. // -// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * (Cos(b) + iSin(b)) +// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * Cos(b) + Exp(a) * iSin(b)) def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val), - (HLO_MulOp - (HLO_ExpOp (HLO_RealOp $val)), - (HLO_ComplexOp + (HLO_ComplexOp + (HLO_MulOp (HLO_CosOp (HLO_ImagOp:$imag $val)), - (HLO_SinOp $imag)))>; + (HLO_ExpOp:$exp (HLO_RealOp:$real $val))), + (HLO_MulOp (HLO_SinOp $imag), $exp))>; diff --git a/tests/lower-complex.mlir b/tests/lower-complex.mlir index b9c91d6..141c238 100644 --- a/tests/lower-complex.mlir +++ b/tests/lower-complex.mlir @@ -114,8 +114,8 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % // Compute the real valued denominator as rhs * con(rhs): // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]] + // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, %arg3 + // CHECK-DAG: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]] // Compute the numerator's imaginary component: // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag @@ -153,8 +153,8 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< // Compute the real valued denominator as rhs * con(rhs): // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]] + // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, %arg3 + // CHECK-DAG: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]] // Compute the numerator's imaginary component: // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag @@ -165,6 +165,7 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor< // Divide the numerator by the real valued denominator. // CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]] // CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]] + %4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) %5 = "mhlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) @@ -192,32 +193,48 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0) - // CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1) - // CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1) - // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]] + // CHECK-DAG: [[EXP:%.+]] = "mhlo.exponential"(%arg0) + // CHECK-DAG: [[COS:%.+]] = "mhlo.cosine"(%arg1) + // CHECK-DAG: [[SIN:%.+]] = "mhlo.sine"(%arg1) + // CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]] + // CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]] %1 = "mhlo.exponential"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) + %2 = "mhlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) %3 = "mhlo.imag"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) - // CHECK: return [[VAL3]], [[VAL4]] + // CHECK: [[OUTR]], [[OUTI]] return %2, %3 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @exp_unranked -func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { - %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) +// CHECK-LABEL: @exp_complex +func @exp_complex(%arg0 : tensor<2xcomplex>) -> (tensor<2xcomplex>) { + // CHECK-DAG: [[REAL:%.+]] = "mhlo.real"(%arg0) + // CHECK-DAG: [[IMAG:%.+]] = "mhlo.imag"(%arg0) + // CHECK-DAG: [[EXP:%.+]] = "mhlo.exponential"([[REAL]]) + // CHECK-DAG: [[COS:%.+]] = "mhlo.cosine"([[IMAG]]) + // CHECK-DAG: [[SIN:%.+]] = "mhlo.sine"([[IMAG]]) + // CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]] + // CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]] + // CHECK-DAG: [[OUT:%.+]] = "mhlo.complex"([[OUTR]], [[OUTI]]) + %0 = "mhlo.exponential"(%arg0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) - // CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0) - // CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1) - // CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1) - // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]] - %1 = "mhlo.exponential"(%0) : (tensor<*xcomplex>) -> (tensor<*xcomplex>) - %2 = "mhlo.real"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) - %3 = "mhlo.imag"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) - - // CHECK: return [[VAL3]], [[VAL4]] - return %2, %3 : tensor<*xf32>, tensor<*xf32> + // CHECK: [[OUT]] + return %0 : tensor<2xcomplex> +} + +// CHECK-LABEL: @exp_unranked +func @exp_unranked(%arg0 : tensor<*xcomplex>) -> (tensor<*xcomplex>) { + // CHECK-DAG: [[REAL:%.+]] = "mhlo.real"(%arg0) + // CHECK-DAG: [[IMAG:%.+]] = "mhlo.imag"(%arg0) + // CHECK-DAG: [[EXP:%.+]] = "mhlo.exponential"([[REAL]]) + // CHECK-DAG: [[COS:%.+]] = "mhlo.cosine"([[IMAG]]) + // CHECK-DAG: [[SIN:%.+]] = "mhlo.sine"([[IMAG]]) + // CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]] + // CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]] + // CHECK-DAG: [[OUT:%.+]] = "mhlo.complex"([[OUTR]], [[OUTI]]) + %0 = "mhlo.exponential"(%arg0) : (tensor<*xcomplex>) -> (tensor<*xcomplex>) + + // CHECK: [[OUT]] + return %0 : tensor<*xcomplex> }