Add folder for mhlo::remainder
PiperOrigin-RevId: 335372628
This commit is contained in:
parent
c9e249c124
commit
7367eac074
|
@ -353,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
|
|||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
|
||||
|
||||
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
|
||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp;
|
||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
|
||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
|
||||
|
|
|
@ -2001,6 +2001,23 @@ struct divide<APInt> {
|
|||
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct remainder : std::modulus<T> {};
|
||||
|
||||
template <>
|
||||
struct remainder<APInt> {
|
||||
APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct remainder<APFloat> {
|
||||
APFloat operator()(const APFloat& a, const APFloat& b) const {
|
||||
APFloat result(a);
|
||||
result.remainder(b);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct max {
|
||||
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
|
||||
|
@ -2042,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus);
|
|||
BINARY_FOLDER(SubOp, std::minus);
|
||||
BINARY_FOLDER(MulOp, std::multiplies);
|
||||
BINARY_FOLDER(DivOp, divide);
|
||||
BINARY_FOLDER(RemOp, remainder);
|
||||
BINARY_FOLDER(MaxOp, max);
|
||||
BINARY_FOLDER(MinOp, min);
|
||||
|
||||
|
|
|
@ -63,6 +63,24 @@ func @divide_fold_float() -> tensor<4xf64> {
|
|||
return %2 : tensor<4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: remainder_fold_int
|
||||
func @remainder_fold_int() -> tensor<4xi32> {
|
||||
%0 = mhlo.constant dense<[5, 66, 5, 1]> : tensor<4xi32>
|
||||
%1 = mhlo.constant dense<[3, 5, 1, 2]> : tensor<4xi32>
|
||||
// CHECK: mhlo.constant dense<[2, 1, 0, 1]>
|
||||
%2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>)
|
||||
return %2 : tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: remainder_fold_float
|
||||
func @remainder_fold_float() -> tensor<4xf32> {
|
||||
%0 = mhlo.constant dense<[7.0, 66.5, 5.0, 3.1]> : tensor<4xf32>
|
||||
%1 = mhlo.constant dense<[3.0, 5.0, 1.0, 2.6]> : tensor<4xf32>
|
||||
// CHECK: mhlo.constant dense<[1.000000e+00, 1.500000e+00, 0.000000e+00, 5.000000e-01]>
|
||||
%2 = "mhlo.remainder"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
|
||||
return %2 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: max_scalar_fold
|
||||
func @max_scalar_fold() -> tensor<4xi64> {
|
||||
%0 = mhlo.constant dense<7> : tensor<4xi64>
|
||||
|
|
Loading…
Reference in New Issue