Add folder to mhlo::round_nearest_afz
PiperOrigin-RevId: 337823786
This commit is contained in:
parent
92c531cf36
commit
4a18aa41ee
|
@ -241,7 +241,9 @@ def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz",
|
def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp;
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp {
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt",
|
def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
|
||||||
|
|
|
@ -1933,6 +1933,14 @@ static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
|
||||||
return DenseElementsAttr::get(type, values);
|
return DenseElementsAttr::get(type, values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct round {
|
||||||
|
APFloat operator()(const APFloat& f) {
|
||||||
|
APFloat r = f;
|
||||||
|
r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#define UNARY_FOLDER(Op, Func) \
|
#define UNARY_FOLDER(Op, Func) \
|
||||||
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
|
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
|
||||||
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
|
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
|
||||||
|
@ -1942,7 +1950,15 @@ static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
|
||||||
return {}; \
|
return {}; \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define UNARY_FOLDER_FLOAT(Op, Func) \
|
||||||
|
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
|
||||||
|
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
|
||||||
|
return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \
|
||||||
|
return {}; \
|
||||||
|
}
|
||||||
|
|
||||||
UNARY_FOLDER(NegOp, std::negate);
|
UNARY_FOLDER(NegOp, std::negate);
|
||||||
|
UNARY_FOLDER_FLOAT(RoundOp, round);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// BinaryOps
|
// BinaryOps
|
||||||
|
|
|
@ -81,6 +81,14 @@ func @remainder_fold_float() -> tensor<4xf32> {
|
||||||
return %2 : tensor<4xf32>
|
return %2 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: round_fold
|
||||||
|
func @round_fold() -> tensor<4xf32> {
|
||||||
|
%0 = mhlo.constant dense<[-1.5, -0.1, 1.1, 2.5]> : tensor<4xf32>
|
||||||
|
%1 = "mhlo.round_nearest_afz"(%0) : (tensor<4xf32>) -> tensor<4xf32>
|
||||||
|
return %1 : tensor<4xf32>
|
||||||
|
// CHECK: mhlo.constant dense<[-2.000000e+00, -0.000000e+00, 1.000000e+00, 3.000000e+00]>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: max_scalar_fold
|
// CHECK-LABEL: max_scalar_fold
|
||||||
func @max_scalar_fold() -> tensor<4xi64> {
|
func @max_scalar_fold() -> tensor<4xi64> {
|
||||||
%0 = mhlo.constant dense<7> : tensor<4xi64>
|
%0 = mhlo.constant dense<7> : tensor<4xi64>
|
||||||
|
|
Loading…
Reference in New Issue