Add folder for mhlo::remainder
PiperOrigin-RevId: 335372628
This commit is contained in:
		
							parent
							
								
									c9e249c124
								
							
						
					
					
						commit
						7367eac074
					
				|  | @ -353,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power", | ||||||
|       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; |       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; | ||||||
| 
 | 
 | ||||||
| def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", | def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", | ||||||
|       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp; |       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp { | ||||||
|  |   let hasFolder = 1; | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", | def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", | ||||||
|       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; |       [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; | ||||||
|  |  | ||||||
|  | @ -2001,6 +2001,23 @@ struct divide<APInt> { | ||||||
|   APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); } |   APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | template <typename T> | ||||||
|  | struct remainder : std::modulus<T> {}; | ||||||
|  | 
 | ||||||
|  | template <> | ||||||
|  | struct remainder<APInt> { | ||||||
|  |   APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | template <> | ||||||
|  | struct remainder<APFloat> { | ||||||
|  |   APFloat operator()(const APFloat& a, const APFloat& b) const { | ||||||
|  |     APFloat result(a); | ||||||
|  |     result.remainder(b); | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
| template <typename T> | template <typename T> | ||||||
| struct max { | struct max { | ||||||
|   T operator()(const T& a, const T& b) const { return std::max<T>(a, b); } |   T operator()(const T& a, const T& b) const { return std::max<T>(a, b); } | ||||||
|  | @ -2042,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus); | ||||||
| BINARY_FOLDER(SubOp, std::minus); | BINARY_FOLDER(SubOp, std::minus); | ||||||
| BINARY_FOLDER(MulOp, std::multiplies); | BINARY_FOLDER(MulOp, std::multiplies); | ||||||
| BINARY_FOLDER(DivOp, divide); | BINARY_FOLDER(DivOp, divide); | ||||||
|  | BINARY_FOLDER(RemOp, remainder); | ||||||
| BINARY_FOLDER(MaxOp, max); | BINARY_FOLDER(MaxOp, max); | ||||||
| BINARY_FOLDER(MinOp, min); | BINARY_FOLDER(MinOp, min); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -63,6 +63,24 @@ func @divide_fold_float() -> tensor<4xf64> { | ||||||
|   return %2 : tensor<4xf64> |   return %2 : tensor<4xf64> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CHECK-LABEL: remainder_fold_int | ||||||
|  | func @remainder_fold_int() -> tensor<4xi32> { | ||||||
|  |   %0 = mhlo.constant dense<[5, 66, 5, 1]> : tensor<4xi32> | ||||||
|  |   %1 = mhlo.constant dense<[3, 5, 1, 2]> : tensor<4xi32> | ||||||
|  |   // CHECK: mhlo.constant dense<[2, 1, 0, 1]> | ||||||
|  |   %2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>) | ||||||
|  |   return %2 : tensor<4xi32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-LABEL: remainder_fold_float | ||||||
|  | func @remainder_fold_float() -> tensor<4xf32> { | ||||||
|  |   %0 = mhlo.constant dense<[7.0, 66.5, 5.0, 3.1]> : tensor<4xf32> | ||||||
|  |   %1 = mhlo.constant dense<[3.0, 5.0, 1.0, 2.6]> : tensor<4xf32> | ||||||
|  |   // CHECK: mhlo.constant dense<[1.000000e+00, 1.500000e+00, 0.000000e+00, 5.000000e-01]> | ||||||
|  |   %2 = "mhlo.remainder"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) | ||||||
|  |   return %2 : tensor<4xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // CHECK-LABEL: max_scalar_fold | // CHECK-LABEL: max_scalar_fold | ||||||
| func @max_scalar_fold() -> tensor<4xi64> { | func @max_scalar_fold() -> tensor<4xi64> { | ||||||
|   %0 = mhlo.constant dense<7> : tensor<4xi64> |   %0 = mhlo.constant dense<7> : tensor<4xi64> | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue