Add folder for mhlo::remainder
PiperOrigin-RevId: 335372628
This commit is contained in:
parent
c9e249c124
commit
7367eac074
|
@ -353,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
|
||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
|
||||||
|
|
||||||
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
|
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
|
||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp;
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp {
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
|
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
|
||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
|
||||||
|
|
|
@ -2001,6 +2001,23 @@ struct divide<APInt> {
|
||||||
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
|
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct remainder : std::modulus<T> {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct remainder<APInt> {
|
||||||
|
APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct remainder<APFloat> {
|
||||||
|
APFloat operator()(const APFloat& a, const APFloat& b) const {
|
||||||
|
APFloat result(a);
|
||||||
|
result.remainder(b);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct max {
|
struct max {
|
||||||
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
|
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
|
||||||
|
@ -2042,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus);
|
||||||
BINARY_FOLDER(SubOp, std::minus);
|
BINARY_FOLDER(SubOp, std::minus);
|
||||||
BINARY_FOLDER(MulOp, std::multiplies);
|
BINARY_FOLDER(MulOp, std::multiplies);
|
||||||
BINARY_FOLDER(DivOp, divide);
|
BINARY_FOLDER(DivOp, divide);
|
||||||
|
BINARY_FOLDER(RemOp, remainder);
|
||||||
BINARY_FOLDER(MaxOp, max);
|
BINARY_FOLDER(MaxOp, max);
|
||||||
BINARY_FOLDER(MinOp, min);
|
BINARY_FOLDER(MinOp, min);
|
||||||
|
|
||||||
|
|
|
@ -63,6 +63,24 @@ func @divide_fold_float() -> tensor<4xf64> {
|
||||||
return %2 : tensor<4xf64>
|
return %2 : tensor<4xf64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: remainder_fold_int
|
||||||
|
func @remainder_fold_int() -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<[5, 66, 5, 1]> : tensor<4xi32>
|
||||||
|
%1 = mhlo.constant dense<[3, 5, 1, 2]> : tensor<4xi32>
|
||||||
|
// CHECK: mhlo.constant dense<[2, 1, 0, 1]>
|
||||||
|
%2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>)
|
||||||
|
return %2 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: remainder_fold_float
|
||||||
|
func @remainder_fold_float() -> tensor<4xf32> {
|
||||||
|
%0 = mhlo.constant dense<[7.0, 66.5, 5.0, 3.1]> : tensor<4xf32>
|
||||||
|
%1 = mhlo.constant dense<[3.0, 5.0, 1.0, 2.6]> : tensor<4xf32>
|
||||||
|
// CHECK: mhlo.constant dense<[1.000000e+00, 1.500000e+00, 0.000000e+00, 5.000000e-01]>
|
||||||
|
%2 = "mhlo.remainder"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
|
||||||
|
return %2 : tensor<4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: max_scalar_fold
|
// CHECK-LABEL: max_scalar_fold
|
||||||
func @max_scalar_fold() -> tensor<4xi64> {
|
func @max_scalar_fold() -> tensor<4xi64> {
|
||||||
%0 = mhlo.constant dense<7> : tensor<4xi64>
|
%0 = mhlo.constant dense<7> : tensor<4xi64>
|
||||||
|
|
Loading…
Reference in New Issue