[MLIR][KernelGen] Add erf kernel and missing lowering for f16 type

PiperOrigin-RevId: 352416184
This commit is contained in:
A. Unique TensorFlower 2021-01-18 08:19:56 -08:00 committed by TensorFlow MLIR Team
parent 9e07bdf4ea
commit 96fb617413
2 changed files with 33 additions and 7 deletions

View File

@ -90,6 +90,8 @@ Value MaterializePolynomialApproximation(
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, Value operand) {
assert(operand.getType().cast<RankedTensorType>().getElementType().isF32() &&
"expect f32 element type");
const std::vector<float> kAlpha{
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
-5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
@ -121,14 +123,28 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> {
LogicalResult matchAndRewrite(
ErfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Type ty = getElementTypeOrSelf(op.getType());
// For now, we support only f32.
if (!ty.isF32()) return failure();
Location loc = op.getLoc();
ErfOp::Adaptor transformed(operands);
rewriter.replaceOp(op, MaterializeErfApproximationF32(
rewriter, op.getLoc(), transformed.operand()));
Value x = transformed.operand();
Type ty = x.getType().cast<RankedTensorType>().getElementType();
// For now, we support only f32 and f16.
if (!ty.isF32() && !ty.isF16()) return failure();
// Cast argument to f32 tensor if needed.
assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point");
if (ty.isF16()) {
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
}
Value result = MaterializeErfApproximationF32(rewriter, loc, x);
// Cast back if needed.
if (ty.isF16()) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
}
rewriter.replaceOp(op, result);
return success();
}
};

View File

@ -86,3 +86,13 @@ func @erf_f32(%arg : tensor<f32>) -> tensor<f32> {
%1 = "chlo.erf"(%arg) : (tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
}
// CHECK-LABEL: @erf_f16
// CHECK-SAME: %[[ARG:.*]]: tensor<f16>
func @erf_f16(%arg : tensor<f16>) -> tensor<f16> {
// CHECK: "mhlo.convert"(%[[ARG]]) : (tensor<f16>) -> tensor<f32>
// CHECK: %[[RESULT:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
// CHECK: return %[[RESULT]]
%1 = "chlo.erf"(%arg) : (tensor<f16>) -> tensor<f16>
return %1 : tensor<f16>
}