Add folder to mhlo::round_nearest_afz

PiperOrigin-RevId: 337823786
This commit is contained in:
A. Unique TensorFlower 2020-10-19 03:44:20 -07:00 committed by TensorFlow MLIR Team
parent 92c531cf36
commit 4a18aa41ee
3 changed files with 27 additions and 1 deletions

View File

@ -241,7 +241,9 @@ def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
}
def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp;
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp {
let hasFolder = 1;
}
def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt",
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,

View File

@ -1933,6 +1933,14 @@ static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
return DenseElementsAttr::get(type, values);
}
struct round {
APFloat operator()(const APFloat& f) {
APFloat r = f;
r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway);
return r;
}
};
#define UNARY_FOLDER(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
@ -1942,7 +1950,15 @@ static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
return {}; \
}
#define UNARY_FOLDER_FLOAT(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \
return {}; \
}
UNARY_FOLDER(NegOp, std::negate);
UNARY_FOLDER_FLOAT(RoundOp, round);
//===----------------------------------------------------------------------===//
// BinaryOps

View File

@ -81,6 +81,14 @@ func @remainder_fold_float() -> tensor<4xf32> {
return %2 : tensor<4xf32>
}
// CHECK-LABEL: round_fold
func @round_fold() -> tensor<4xf32> {
%0 = mhlo.constant dense<[-1.5, -0.1, 1.1, 2.5]> : tensor<4xf32>
%1 = "mhlo.round_nearest_afz"(%0) : (tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
// CHECK: mhlo.constant dense<[-2.000000e+00, -0.000000e+00, 1.000000e+00, 3.000000e+00]>
}
// CHECK-LABEL: max_scalar_fold
func @max_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<7> : tensor<4xi64>