parent
a6fdebdc6c
commit
1800f44a29
|
@ -265,7 +265,9 @@ def HLO_SinOp: HLO_UnaryElementwiseOp<"sine",
|
||||||
|
|
||||||
// Unary element-wise "sqrt" op; operand and result are constrained to
// HLO_FpOrComplexTensor and must have the same type.
def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt",
|
def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
|
||||||
BASE_HLO_SqrtOp;
|
BASE_HLO_SqrtOp {
|
||||||
|
// Request a fold hook for this op; the implementation is SqrtOp::fold.
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
|
def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
|
||||||
[NoSideEffect, SameOperandsAndResultType],
|
[NoSideEffect, SameOperandsAndResultType],
|
||||||
|
|
|
@ -1821,6 +1821,35 @@ static LogicalResult Verify(CaseOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// SqrtOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Constant-folds sqrt over a dense floating-point constant operand.
//
// Returns an empty OpFoldResult (no fold) when:
//   * the operand is not a DenseElementsAttr,
//   * the result element type is neither f32 nor f64,
//   * the result type does not have a static shape, or
//   * any element of the operand compares less than zero.
// Otherwise returns a DenseFPElementsAttr of the result type holding the
// element-wise square roots.
OpFoldResult SqrtOp::fold(ArrayRef<Attribute> operands) {
  // Only a constant dense operand can be folded.
  auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  if (!input) return {};

  // Folding is implemented for f32 and f64 element types only.
  auto element_type = getElementTypeOrSelf(getType());
  if (!(element_type.isF32() || element_type.isF64())) return {};

  auto result_type = getType().cast<ShapedType>();
  if (!result_type.hasStaticShape()) return {};

  const bool is_f32 = element_type.getIntOrFloatBitWidth() == 32;

  llvm::SmallVector<APFloat, 4> folded;
  folded.reserve(input.getNumElements());
  for (auto element : input.getFloatValues()) {
    // The computation is carried out in double for both widths; f32 inputs
    // are widened here and narrowed again when stored.
    double operand =
        is_f32 ? element.convertToFloat() : element.convertToDouble();

    // A negative element aborts the whole fold and leaves the op in place.
    if (operand < 0) return {};

    const double root = std::sqrt(operand);
    if (is_f32) {
      folded.emplace_back(static_cast<float>(root));
    } else {
      folded.emplace_back(root);
    }
  }

  return DenseFPElementsAttr::get(result_type, folded);
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// UnaryOps
|
// UnaryOps
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -851,3 +851,29 @@ func @fold_negate_float() -> tensor<4xf32> {
|
||||||
return %1 : tensor<4xf32>
|
return %1 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies that mhlo.sqrt of a splat f32 constant (all 1.0) is folded into a
// constant and that no mhlo.sqrt op remains afterwards.
// CHECK-LABEL: func @fold_sqrt_f32_constants
|
||||||
|
func @fold_sqrt_f32_constants() -> tensor<4xf32> {
|
||||||
|
%0 = mhlo.constant dense<1.0> : tensor<4xf32>
|
||||||
|
%1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32>
|
||||||
|
// CHECK: mhlo.constant dense<1.000000e+00> : tensor<4xf32>
|
||||||
|
// CHECK-NOT: mhlo.sqrt
|
||||||
|
return %1 : tensor<4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verifies that mhlo.sqrt of a non-splat f64 constant [1, 4, 9, 16] is folded
// element-wise to [1, 2, 3, 4] and that no mhlo.sqrt op remains afterwards.
// CHECK-LABEL: func @fold_sqrt_f64_constants
|
||||||
|
func @fold_sqrt_f64_constants() -> tensor<4xf64> {
|
||||||
|
%0 = mhlo.constant dense<[1.0, 4.0, 9.0, 16.0]> : tensor<4xf64>
|
||||||
|
%1 = "mhlo.sqrt"(%0) : (tensor<4xf64>) -> tensor<4xf64>
|
||||||
|
// CHECK: mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64>
|
||||||
|
// CHECK-NOT: mhlo.sqrt
|
||||||
|
return %1 : tensor<4xf64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// Negative case: mhlo.sqrt of a negative f32 constant must NOT be folded; the
// original constant and the mhlo.sqrt op are both expected to remain.
// CHECK-LABEL: func @not_fold_sqrt_neg_constants
|
||||||
|
func @not_fold_sqrt_neg_constants() -> tensor<4xf32> {
|
||||||
|
%0 = mhlo.constant dense<-1.0> : tensor<4xf32>
|
||||||
|
%1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32>
|
||||||
|
// CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32>
|
||||||
|
// CHECK: mhlo.sqrt
|
||||||
|
return %1 : tensor<4xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue