parent
a6fdebdc6c
commit
1800f44a29
|
@ -265,7 +265,9 @@ def HLO_SinOp: HLO_UnaryElementwiseOp<"sine",
|
|||
|
||||
def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt",
|
||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
|
||||
BASE_HLO_SqrtOp;
|
||||
BASE_HLO_SqrtOp {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
|
||||
[NoSideEffect, SameOperandsAndResultType],
|
||||
|
|
|
@ -1821,6 +1821,35 @@ static LogicalResult Verify(CaseOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SqrtOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto val = operands[0].dyn_cast_or_null<DenseElementsAttr>();
|
||||
if (!val) return {};
|
||||
|
||||
auto type = getElementTypeOrSelf(getType());
|
||||
if (!type.isF32() && !type.isF64()) return {};
|
||||
|
||||
auto shaped_type = getType().cast<ShapedType>();
|
||||
if (!shaped_type.hasStaticShape()) return {};
|
||||
|
||||
int bit_width = type.getIntOrFloatBitWidth();
|
||||
llvm::SmallVector<APFloat, 4> values;
|
||||
values.reserve(val.getNumElements());
|
||||
for (auto it : val.getFloatValues()) {
|
||||
double value = bit_width == 32 ? it.convertToFloat() : it.convertToDouble();
|
||||
if (value < 0) return {};
|
||||
value = std::sqrt(value);
|
||||
if (bit_width == 32)
|
||||
values.emplace_back(static_cast<float>(value));
|
||||
else
|
||||
values.emplace_back(value);
|
||||
}
|
||||
return DenseFPElementsAttr::get(shaped_type, values);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UnaryOps
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -851,3 +851,29 @@ func @fold_negate_float() -> tensor<4xf32> {
|
|||
return %1 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_sqrt_f32_constants
|
||||
func @fold_sqrt_f32_constants() -> tensor<4xf32> {
|
||||
%0 = mhlo.constant dense<1.0> : tensor<4xf32>
|
||||
%1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK: mhlo.constant dense<1.000000e+00> : tensor<4xf32>
|
||||
// CHECK-NOT: mhlo.sqrt
|
||||
return %1 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_sqrt_f64_constants
|
||||
func @fold_sqrt_f64_constants() -> tensor<4xf64> {
|
||||
%0 = mhlo.constant dense<[1.0, 4.0, 9.0, 16.0]> : tensor<4xf64>
|
||||
%1 = "mhlo.sqrt"(%0) : (tensor<4xf64>) -> tensor<4xf64>
|
||||
// CHECK: mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64>
|
||||
// CHECK-NOT: mhlo.sqrt
|
||||
return %1 : tensor<4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @not_fold_sqrt_neg_constants
|
||||
func @not_fold_sqrt_neg_constants() -> tensor<4xf32> {
|
||||
%0 = mhlo.constant dense<-1.0> : tensor<4xf32>
|
||||
%1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32>
|
||||
// CHECK: mhlo.sqrt
|
||||
return %1 : tensor<4xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue