[MLIR][KernelGen] Add erf kernel and missing lowering for f16 type
PiperOrigin-RevId: 352416184
commit 96fb617413
parent 9e07bdf4ea
@@ -90,6 +90,8 @@ Value MaterializePolynomialApproximation(
 
 Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
                                      Location loc, Value operand) {
+  assert(operand.getType().cast<RankedTensorType>().getElementType().isF32() &&
+         "expect f32 element type");
   const std::vector<float> kAlpha{
       -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
       -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
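Note: the kAlpha coefficients above are the numerator of a rational approximation erf(x) ~= x * P(x^2) / Q(x^2), which MaterializePolynomialApproximation evaluates as a Horner-style chain of multiply-adds. Below is a plain-C++ sketch of that scheme. The final kAlpha coefficient, the kBeta denominator coefficients, and the clamp bounds are not visible in this hunk; they are assumptions taken from Eigen's generic_fast_erf_float, which the visible coefficients match.

    #include <vector>

    // Horner-style evaluation, mirroring the multiply-add chain that
    // MaterializePolynomialApproximation emits (coefficients ordered from
    // highest to lowest degree).
    float EvaluatePolynomial(float x, const std::vector<float> &coefficients) {
      float poly = 0.0f;
      for (float c : coefficients) poly = poly * x + c;
      return poly;
    }

    // Sketch of the f32 erf approximation: erf(x) ~= x * P(x^2) / Q(x^2) on a
    // clamped input. The last kAlpha value, kBeta, and the clamp bounds are
    // assumptions (from Eigen's generic_fast_erf_float); only the kAlpha
    // prefix appears in the diff above.
    float ErfApproximationF32(float x) {
      const std::vector<float> kAlpha{
          -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
          -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
          -1.60960333262415e-02f};
      const std::vector<float> kBeta{
          -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
          -7.37332916720468e-03f, -1.42647390514189e-02f};
      x = x < -4.0f ? -4.0f : (x > 4.0f ? 4.0f : x);  // clamp to [-4, 4]
      const float x_sq = x * x;
      return x * EvaluatePolynomial(x_sq, kAlpha) /
             EvaluatePolynomial(x_sq, kBeta);
    }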
@@ -121,14 +123,28 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> {
   LogicalResult matchAndRewrite(
       ErfOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    Type ty = getElementTypeOrSelf(op.getType());
-
-    // For now, we support only f32.
-    if (!ty.isF32()) return failure();
-
+    Location loc = op.getLoc();
     ErfOp::Adaptor transformed(operands);
-    rewriter.replaceOp(op, MaterializeErfApproximationF32(
-                               rewriter, op.getLoc(), transformed.operand()));
+    Value x = transformed.operand();
+    Type ty = x.getType().cast<RankedTensorType>().getElementType();
+
+    // For now, we support only f32 and f16.
+    if (!ty.isF32() && !ty.isF16()) return failure();
+
+    // Cast argument to f32 tensor if needed.
+    assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point");
+    if (ty.isF16()) {
+      x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
+    }
+
+    Value result = MaterializeErfApproximationF32(rewriter, loc, x);
+
+    // Cast back if needed.
+    if (ty.isF16()) {
+      result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
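The rewritten pattern reduces the f16 case to the f32 kernel by widening, computing, and narrowing. A minimal sketch of the same shape in plain C++, assuming a C++23 toolchain that provides std::float16_t in <stdfloat> and reusing the ErfApproximationF32 sketch above:

    #include <stdfloat>

    // f16 erf via the f32 approximation; the two casts play the role of the
    // two mhlo::ConvertOps in the pattern above. std::float16_t is an
    // assumption here (C++23); the MLIR pattern is not tied to it.
    std::float16_t ErfF16(std::float16_t x) {
      const float x32 = static_cast<float>(x);          // cast argument to f32
      const float result32 = ErfApproximationF32(x32);  // compute in f32
      return static_cast<std::float16_t>(result32);     // cast result back to f16
    }

Computing directly in f16 would lose most of the approximation's accuracy: f16 has only an 11-bit significand while the coefficients span several orders of magnitude, so evaluating in f32 and truncating the final result is both simpler and more accurate.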
@@ -86,3 +86,13 @@ func @erf_f32(%arg : tensor<f32>) -> tensor<f32> {
   %1 = "chlo.erf"(%arg) : (tensor<f32>) -> tensor<f32>
   return %1 : tensor<f32>
 }
+
+// CHECK-LABEL: @erf_f16
+// CHECK-SAME: %[[ARG:.*]]: tensor<f16>
+func @erf_f16(%arg : tensor<f16>) -> tensor<f16> {
+  // CHECK: "mhlo.convert"(%[[ARG]]) : (tensor<f16>) -> tensor<f32>
+  // CHECK: %[[RESULT:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
+  // CHECK: return %[[RESULT]]
+  %1 = "chlo.erf"(%arg) : (tensor<f16>) -> tensor<f16>
+  return %1 : tensor<f16>
+}
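For a quick numerical sanity check of the C++ sketches above (purely illustrative; the real coverage for this change is the FileCheck test in the hunk above), one can compare against the libm reference:

    #include <cmath>
    #include <cstdio>

    // Compares the ErfApproximationF32 sketch from earlier against std::erf
    // at a few sample points.
    int main() {
      const float xs[] = {-3.0f, -1.0f, -0.5f, 0.0f, 0.5f, 1.0f, 3.0f};
      for (float x : xs) {
        std::printf("x = %5.2f  approx = % .6f  std::erf = % .6f\n", x,
                    ErfApproximationF32(x), std::erf(x));
      }
      return 0;
    }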