Add folder for mhlo::remainder

PiperOrigin-RevId: 335372628
This commit is contained in:
A. Unique TensorFlower 2020-10-05 02:19:10 -07:00 committed by TensorFlow MLIR Team
parent c9e249c124
commit 7367eac074
3 changed files with 39 additions and 1 deletions

View File

@ -353,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp; [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp {
let hasFolder = 1;
}
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;

View File

@ -2001,6 +2001,23 @@ struct divide<APInt> {
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); } APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
}; };
template <typename T>
struct remainder : std::modulus<T> {};
template <>
struct remainder<APInt> {
APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
};
template <>
struct remainder<APFloat> {
APFloat operator()(const APFloat& a, const APFloat& b) const {
APFloat result(a);
result.remainder(b);
return result;
}
};
template <typename T> template <typename T>
struct max { struct max {
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); } T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
@ -2042,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus);
BINARY_FOLDER(SubOp, std::minus); BINARY_FOLDER(SubOp, std::minus);
BINARY_FOLDER(MulOp, std::multiplies); BINARY_FOLDER(MulOp, std::multiplies);
BINARY_FOLDER(DivOp, divide); BINARY_FOLDER(DivOp, divide);
BINARY_FOLDER(RemOp, remainder);
BINARY_FOLDER(MaxOp, max); BINARY_FOLDER(MaxOp, max);
BINARY_FOLDER(MinOp, min); BINARY_FOLDER(MinOp, min);

View File

@ -63,6 +63,24 @@ func @divide_fold_float() -> tensor<4xf64> {
return %2 : tensor<4xf64> return %2 : tensor<4xf64>
} }
// CHECK-LABEL: remainder_fold_int
func @remainder_fold_int() -> tensor<4xi32> {
%0 = mhlo.constant dense<[5, 66, 5, 1]> : tensor<4xi32>
%1 = mhlo.constant dense<[3, 5, 1, 2]> : tensor<4xi32>
// CHECK: mhlo.constant dense<[2, 1, 0, 1]>
%2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>)
return %2 : tensor<4xi32>
}
// CHECK-LABEL: remainder_fold_float
func @remainder_fold_float() -> tensor<4xf32> {
%0 = mhlo.constant dense<[7.0, 66.5, 5.0, 3.1]> : tensor<4xf32>
%1 = mhlo.constant dense<[3.0, 5.0, 1.0, 2.6]> : tensor<4xf32>
// CHECK: mhlo.constant dense<[1.000000e+00, 1.500000e+00, 0.000000e+00, 5.000000e-01]>
%2 = "mhlo.remainder"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
return %2 : tensor<4xf32>
}
// CHECK-LABEL: max_scalar_fold // CHECK-LABEL: max_scalar_fold
func @max_scalar_fold() -> tensor<4xi64> { func @max_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<7> : tensor<4xi64> %0 = mhlo.constant dense<7> : tensor<4xi64>