Add sqrt folder.

PiperOrigin-RevId: 331974344
This commit is contained in:
Hanhan Wang 2020-09-16 04:03:40 -07:00 committed by TensorFlow MLIR Team
parent a6fdebdc6c
commit 1800f44a29
3 changed files with 58 additions and 1 deletions

View File

@ -265,7 +265,9 @@ def HLO_SinOp: HLO_UnaryElementwiseOp<"sine",
def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt", def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt",
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
BASE_HLO_SqrtOp; BASE_HLO_SqrtOp {
let hasFolder = 1;
}
def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
[NoSideEffect, SameOperandsAndResultType], [NoSideEffect, SameOperandsAndResultType],

View File

@ -1821,6 +1821,35 @@ static LogicalResult Verify(CaseOp op) {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// SqrtOp
//===----------------------------------------------------------------------===//
OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
auto val = operands[0].dyn_cast_or_null<DenseElementsAttr>();
if (!val) return {};
auto type = getElementTypeOrSelf(getType());
if (!type.isF32() && !type.isF64()) return {};
auto shaped_type = getType().cast<ShapedType>();
if (!shaped_type.hasStaticShape()) return {};
int bit_width = type.getIntOrFloatBitWidth();
llvm::SmallVector<APFloat, 4> values;
values.reserve(val.getNumElements());
for (auto it : val.getFloatValues()) {
double value = bit_width == 32 ? it.convertToFloat() : it.convertToDouble();
if (value < 0) return {};
value = std::sqrt(value);
if (bit_width == 32)
values.emplace_back(static_cast<float>(value));
else
values.emplace_back(value);
}
return DenseFPElementsAttr::get(shaped_type, values);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// UnaryOps // UnaryOps
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -851,3 +851,29 @@ func @fold_negate_float() -> tensor<4xf32> {
return %1 : tensor<4xf32> return %1 : tensor<4xf32>
} }
// CHECK-LABEL: func @fold_sqrt_f32_constants
func @fold_sqrt_f32_constants() -> tensor<4xf32> {
%0 = mhlo.constant dense<1.0> : tensor<4xf32>
%1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: mhlo.constant dense<1.000000e+00> : tensor<4xf32>
// CHECK-NOT: mhlo.sqrt
return %1 : tensor<4xf32>
}
// CHECK-LABEL: func @fold_sqrt_f64_constants
func @fold_sqrt_f64_constants() -> tensor<4xf64> {
%0 = mhlo.constant dense<[1.0, 4.0, 9.0, 16.0]> : tensor<4xf64>
%1 = "mhlo.sqrt"(%0) : (tensor<4xf64>) -> tensor<4xf64>
// CHECK: mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64>
// CHECK-NOT: mhlo.sqrt
return %1 : tensor<4xf64>
}
// CHECK-LABEL: func @not_fold_sqrt_neg_constants
func @not_fold_sqrt_neg_constants() -> tensor<4xf32> {
%0 = mhlo.constant dense<-1.0> : tensor<4xf32>
%1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32>
// CHECK: mhlo.sqrt
return %1 : tensor<4xf32>
}