Add folder for mhlo::remainder
PiperOrigin-RevId: 335372628
This commit is contained in:
		
							parent
							
								
									c9e249c124
								
							
						
					
					
						commit
						7367eac074
					
				|  | @ -353,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power", | |||
|       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; | ||||
| 
 | ||||
| def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", | ||||
|       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp; | ||||
|       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp { | ||||
|   let hasFolder = 1; | ||||
| } | ||||
| 
 | ||||
| def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", | ||||
|       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; | ||||
|  |  | |||
|  | @ -2001,6 +2001,23 @@ struct divide<APInt> { | |||
|   APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); } | ||||
| }; | ||||
| 
 | ||||
| template <typename T> | ||||
| struct remainder : std::modulus<T> {}; | ||||
| 
 | ||||
| template <> | ||||
| struct remainder<APInt> { | ||||
|   APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); } | ||||
| }; | ||||
| 
 | ||||
| template <> | ||||
| struct remainder<APFloat> { | ||||
|   APFloat operator()(const APFloat& a, const APFloat& b) const { | ||||
|     APFloat result(a); | ||||
|     result.remainder(b); | ||||
|     return result; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| template <typename T> | ||||
| struct max { | ||||
|   T operator()(const T& a, const T& b) const { return std::max<T>(a, b); } | ||||
|  | @ -2042,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus); | |||
| BINARY_FOLDER(SubOp, std::minus); | ||||
| BINARY_FOLDER(MulOp, std::multiplies); | ||||
| BINARY_FOLDER(DivOp, divide); | ||||
| BINARY_FOLDER(RemOp, remainder); | ||||
| BINARY_FOLDER(MaxOp, max); | ||||
| BINARY_FOLDER(MinOp, min); | ||||
| 
 | ||||
|  |  | |||
|  | @ -63,6 +63,24 @@ func @divide_fold_float() -> tensor<4xf64> { | |||
|   return %2 : tensor<4xf64> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: remainder_fold_int | ||||
| func @remainder_fold_int() -> tensor<4xi32> { | ||||
|   %0 = mhlo.constant dense<[5, 66, 5, 1]> : tensor<4xi32> | ||||
|   %1 = mhlo.constant dense<[3, 5, 1, 2]> : tensor<4xi32> | ||||
|   // CHECK: mhlo.constant dense<[2, 1, 0, 1]> | ||||
|   %2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>) | ||||
|   return %2 : tensor<4xi32> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: remainder_fold_float | ||||
| func @remainder_fold_float() -> tensor<4xf32> { | ||||
|   %0 = mhlo.constant dense<[7.0, 66.5, 5.0, 3.1]> : tensor<4xf32> | ||||
|   %1 = mhlo.constant dense<[3.0, 5.0, 1.0, 2.6]> : tensor<4xf32> | ||||
|   // CHECK: mhlo.constant dense<[1.000000e+00, 1.500000e+00, 0.000000e+00, 5.000000e-01]> | ||||
|   %2 = "mhlo.remainder"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) | ||||
|   return %2 : tensor<4xf32> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: max_scalar_fold | ||||
| func @max_scalar_fold() -> tensor<4xi64> { | ||||
|   %0 = mhlo.constant dense<7> : tensor<4xi64> | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue