diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index ed62ef8..37f11de 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -353,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp { + let hasFolder = 1; +} def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 6607c9d..0a5bb0e 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -2001,6 +2001,23 @@ struct divide { APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); } }; +template +struct remainder : std::modulus {}; + +template <> +struct remainder { + APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); } +}; + +template <> +struct remainder { + APFloat operator()(const APFloat& a, const APFloat& b) const { + APFloat result(a); + result.remainder(b); + return result; + } +}; + template struct max { T operator()(const T& a, const T& b) const { return std::max(a, b); } @@ -2042,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus); BINARY_FOLDER(SubOp, std::minus); BINARY_FOLDER(MulOp, std::multiplies); BINARY_FOLDER(DivOp, divide); +BINARY_FOLDER(RemOp, remainder); BINARY_FOLDER(MaxOp, max); BINARY_FOLDER(MinOp, min); diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index b065138..974585b 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -63,6 +63,24 @@ func @divide_fold_float() -> tensor<4xf64> { return %2 : tensor<4xf64> } +// CHECK-LABEL: remainder_fold_int +func @remainder_fold_int() -> tensor<4xi32> { + %0 = mhlo.constant dense<[5, 66, 5, 1]> : tensor<4xi32> + %1 = mhlo.constant dense<[3, 5, 1, 2]> : tensor<4xi32> + // CHECK: mhlo.constant dense<[2, 1, 0, 1]> + %2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>) + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: remainder_fold_float +func @remainder_fold_float() -> tensor<4xf32> { + %0 = mhlo.constant dense<[7.0, 66.5, 5.0, 3.1]> : tensor<4xf32> + %1 = mhlo.constant dense<[3.0, 5.0, 1.0, 2.6]> : tensor<4xf32> + // CHECK: mhlo.constant dense<[1.000000e+00, 1.500000e+00, 0.000000e+00, 5.000000e-01]> + %2 = "mhlo.remainder"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %2 : tensor<4xf32> +} + // CHECK-LABEL: max_scalar_fold func @max_scalar_fold() -> tensor<4xi64> { %0 = mhlo.constant dense<7> : tensor<4xi64>