Removed Op(Complex, Real) lowering to address complex type inference issue

Lowerings that depended on operations mixing real and complex types could fail to infer the correct intermediate type. These mixed-type operations are not legal to generate in the first place, so the patterns that relied on them have been removed. Updated the tests to validate this.

PiperOrigin-RevId: 341128903
parent 3dcd8b4ba2
commit a926e0f040
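For context, the removed Op(Complex, Real) multiply patterns encoded the identity (a + ib) * r = (a * r) + i(b * r), i.e. a real multiplicand distributed over both components. A minimal illustrative sketch of that identity in plain Python (not part of the commit, and not the MLIR patterns themselves):

    # Hypothetical illustration: the identity the removed Op(Complex, Real)
    # multiply patterns relied on.
    def mul_complex_by_real(c: complex, r: float) -> complex:
        # Distribute the real multiplicand over both components:
        #   (a + ib) * r = (a * r) + i(b * r)
        return complex(c.real * r, c.imag * r)

    assert mul_complex_by_real(3 + 4j, 2.0) == (3 + 4j) * 2.0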
@@ -51,40 +51,22 @@ def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs,
               (HLO_MulOp $lhs_real, $rhs_imag),
               (HLO_MulOp $lhs_imag, $rhs_real)))>;
 
-// Multiplication between a complex and real tensor can be distributed by
-// applying the real multiplicant to both the real and complex component.
-//
-// Note that the sourcep pattern is not legal according to the HLO dialect but
-// instead handle intermediates generated by other patterns.
-def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
-          (HLO_ComplexOp
-            (HLO_MulOp (HLO_RealOp $lhs), $rhs),
-            (HLO_MulOp (HLO_ImagOp $lhs), $rhs))>;
-
-def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs),
-          (HLO_ComplexOp
-            (HLO_MulOp $lhs, (HLO_RealOp $rhs)),
-            (HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>;
-
 // Division is performed by normalizing the denominator by multiplying by the
 // conjugate of the rhs.
 //   numerator = lhs * conj(rhs)
 //   denominator = rhs * conj(rhs)
 def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs),
-          (HLO_DivOp
-            (HLO_MulOp:$num $lhs,
-              (HLO_ComplexOp:$conj
-                (HLO_RealOp $rhs),
-                (HLO_NegOp (HLO_ImagOp $rhs)))),
-            (HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>;
-
-def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
-          (HLO_ComplexOp
-            (HLO_DivOp (HLO_RealOp $lhs), $rhs),
-            (HLO_DivOp (HLO_ImagOp $lhs), $rhs))>;
+          (HLO_ComplexOp
+            (HLO_DivOp
+              (HLO_RealOp (HLO_MulOp:$num $lhs,
+                (HLO_ComplexOp:$conj
+                  (HLO_RealOp $rhs),
+                  (HLO_NegOp (HLO_ImagOp $rhs))))),
+              (HLO_AddOp:$den
+                (HLO_MulOp (HLO_RealOp $rhs), (HLO_RealOp $rhs)),
+                (HLO_MulOp (HLO_ImagOp $rhs), (HLO_ImagOp $rhs)))),
+            (HLO_DivOp (HLO_ImagOp $num), $den))>;
 
 
 // Absolute value is evaluated as:
 //   result = sqrt(val.real * val.real + val.imag * val.imag)
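The rewritten HLO_DivOp pattern above keeps every intermediate real-valued: the numerator is scaled by conj(rhs), and each component is divided by rhs.real * rhs.real + rhs.imag * rhs.imag, so no Op(Complex, Real) intermediate is produced. A minimal numeric sketch of that expansion in plain Python (illustrative only, not the MLIR pattern):

    # Hypothetical check of the expansion used by the new division pattern:
    #   lhs / rhs == (lhs * conj(rhs)) / (rhs.real^2 + rhs.imag^2)
    def div_complex(lhs: complex, rhs: complex) -> complex:
        num = lhs * complex(rhs.real, -rhs.imag)          # lhs * conj(rhs)
        den = rhs.real * rhs.real + rhs.imag * rhs.imag   # real-valued denominator
        return complex(num.real / den, num.imag / den)

    assert abs(div_complex(1 + 2j, 3 - 4j) - (1 + 2j) / (3 - 4j)) < 1e-12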
@@ -98,10 +80,10 @@ def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
 // sum of sinusoids of the imaginary component, which equates to a normal
 // exponential operator multiplied by Euler's formula.
 //
-// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * (Cos(b) + iSin(b))
+// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * Cos(b) + Exp(a) * iSin(b))
 def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val),
-          (HLO_MulOp
-            (HLO_ExpOp (HLO_RealOp $val)),
-            (HLO_ComplexOp
-              (HLO_CosOp (HLO_ImagOp:$imag $val)),
-              (HLO_SinOp $imag)))>;
+          (HLO_ComplexOp
+            (HLO_MulOp
+              (HLO_CosOp (HLO_ImagOp:$imag $val)),
+              (HLO_ExpOp:$exp (HLO_RealOp:$real $val))),
+            (HLO_MulOp (HLO_SinOp $imag), $exp))>;
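The reworked HLO_ExpOp pattern folds the exponential of the real component into both sinusoids before assembling the complex result, following Exp(a + ib) = Exp(a) * Cos(b) + i * Exp(a) * Sin(b). A small numeric sketch in plain Python (illustrative only, not the MLIR pattern):

    import cmath, math

    # Hypothetical check of the Euler expansion used by the reworked pattern.
    def exp_complex(val: complex) -> complex:
        scale = math.exp(val.real)                     # Exp(a)
        return complex(scale * math.cos(val.imag),     # Exp(a) * Cos(b)
                       scale * math.sin(val.imag))     # Exp(a) * Sin(b)

    assert abs(exp_complex(0.5 + 1.25j) - cmath.exp(0.5 + 1.25j)) < 1e-12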
@@ -114,8 +114,8 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
   // Compute the real valued denominator as rhs * con(rhs):
   //   denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
   // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
-  // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
-  // CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]
+  // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, %arg3
+  // CHECK-DAG: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]]
 
   // Compute the numerator's imaginary component:
   //   numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
@@ -153,8 +153,8 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
   // Compute the real valued denominator as rhs * con(rhs):
   //   denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
   // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
-  // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
-  // CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]
+  // CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, %arg3
+  // CHECK-DAG: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]]
 
   // Compute the numerator's imaginary component:
   //   numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
@@ -165,6 +165,7 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
   // Divide the numerator by the real valued denominator.
   // CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
   // CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
 
   %4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
+
   %5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
@@ -192,32 +193,48 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
 func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
   %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
 
-  // CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
-  // CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
-  // CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
-  // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
-  // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
+  // CHECK-DAG: [[EXP:%.+]] = "mhlo.exponential"(%arg0)
+  // CHECK-DAG: [[COS:%.+]] = "mhlo.cosine"(%arg1)
+  // CHECK-DAG: [[SIN:%.+]] = "mhlo.sine"(%arg1)
+  // CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]]
+  // CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]]
   %1 = "mhlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
 
   %2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
   %3 = "mhlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
 
-  // CHECK: return [[VAL3]], [[VAL4]]
+  // CHECK: [[OUTR]], [[OUTI]]
   return %2, %3 : tensor<2xf32>, tensor<2xf32>
 }
 
-// CHECK-LABEL: @exp_unranked
-func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
-  %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
-
-  // CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
-  // CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
-  // CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
-  // CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
-  // CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
-  %1 = "mhlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
-  %2 = "mhlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
-  %3 = "mhlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
-  // CHECK: return [[VAL3]], [[VAL4]]
-  return %2, %3 : tensor<*xf32>, tensor<*xf32>
+// CHECK-LABEL: @exp_complex
+func @exp_complex(%arg0 : tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) {
+  // CHECK-DAG: [[REAL:%.+]] = "mhlo.real"(%arg0)
+  // CHECK-DAG: [[IMAG:%.+]] = "mhlo.imag"(%arg0)
+  // CHECK-DAG: [[EXP:%.+]] = "mhlo.exponential"([[REAL]])
+  // CHECK-DAG: [[COS:%.+]] = "mhlo.cosine"([[IMAG]])
+  // CHECK-DAG: [[SIN:%.+]] = "mhlo.sine"([[IMAG]])
+  // CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]]
+  // CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]]
+  // CHECK-DAG: [[OUT:%.+]] = "mhlo.complex"([[OUTR]], [[OUTI]])
+  %0 = "mhlo.exponential"(%arg0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
+
+  // CHECK: [[OUT]]
+  return %0 : tensor<2xcomplex<f32>>
+}
+
+// CHECK-LABEL: @exp_unranked
+func @exp_unranked(%arg0 : tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>) {
+  // CHECK-DAG: [[REAL:%.+]] = "mhlo.real"(%arg0)
+  // CHECK-DAG: [[IMAG:%.+]] = "mhlo.imag"(%arg0)
+  // CHECK-DAG: [[EXP:%.+]] = "mhlo.exponential"([[REAL]])
+  // CHECK-DAG: [[COS:%.+]] = "mhlo.cosine"([[IMAG]])
+  // CHECK-DAG: [[SIN:%.+]] = "mhlo.sine"([[IMAG]])
+  // CHECK-DAG: [[OUTR:%.+]] = mhlo.multiply [[COS]], [[EXP]]
+  // CHECK-DAG: [[OUTI:%.+]] = mhlo.multiply [[SIN]], [[EXP]]
+  // CHECK-DAG: [[OUT:%.+]] = "mhlo.complex"([[OUTR]], [[OUTI]])
+  %0 = "mhlo.exponential"(%arg0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
+
+  // CHECK: [[OUT]]
+  return %0 : tensor<*xcomplex<f32>>
 }