Modified HLOAbsOp lowering for differing types.
PiperOrigin-RevId: 323082107
This commit is contained in:
		
							parent
							
								
									882468da13
								
							
						
					
					
						commit
						8023baa959
					
				|  | @ -89,12 +89,10 @@ def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), | ||||||
| // Absolute value is evaluated as: | // Absolute value is evaluated as: | ||||||
| //   result = sqrt(val.real * val.real + val.imag * val.imag) | //   result = sqrt(val.real * val.real + val.imag * val.imag) | ||||||
| def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), | def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), | ||||||
|           (HLO_ComplexOp |  | ||||||
|            (HLO_SqrtOp |            (HLO_SqrtOp | ||||||
|              (HLO_AddOp |              (HLO_AddOp | ||||||
|               (HLO_MulOp (HLO_RealOp:$real $val), $real), |               (HLO_MulOp (HLO_RealOp:$real $val), $real), | ||||||
|               (HLO_MulOp (HLO_ImagOp:$imag $val), $imag))), |               (HLO_MulOp (HLO_ImagOp:$imag $val), $imag)))>; | ||||||
|            (HLO_ConstOp (ConstantSplat<"0"> $real)))>; |  | ||||||
| 
 | 
 | ||||||
| // Exponential can be lowered to an exponential on the real component and a | // Exponential can be lowered to an exponential on the real component and a | ||||||
| // sum of sinusoids of the imaginary component, which equates to a normal | // sum of sinusoids of the imaginary component, which equates to a normal | ||||||
|  |  | ||||||
|  | @ -182,11 +182,10 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { | ||||||
|   // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1 |   // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1 | ||||||
|   // CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]] |   // CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]] | ||||||
|   // CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]]) |   // CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]]) | ||||||
|   %1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) |   %1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) | ||||||
|   %2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) |  | ||||||
| 
 | 
 | ||||||
|   // CHECK: return [[VAL3]] |   // CHECK: return [[VAL3]] | ||||||
|   return %2 : tensor<2xf32> |   return %1 : tensor<2xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // CHECK-LABEL: @exp | // CHECK-LABEL: @exp | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue