[MLIR][KernelGen] Add erfc kernel for f16

PiperOrigin-RevId: 353209468
This commit is contained in:
A. Unique TensorFlower 2021-01-22 03:37:28 -08:00 committed by TensorFlow MLIR Team
parent ef8ccdaebc
commit ae2d46414d
2 changed files with 26 additions and 3 deletions

View File

@ -445,15 +445,28 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
Value x = transformed.operand(); Value x = transformed.operand();
Type ty = x.getType().cast<ShapedType>().getElementType(); Type ty = x.getType().cast<ShapedType>().getElementType();
// For now, we support only f64 and f32. // For now, we support only f64, f32, and f16.
if (!ty.isF64() && !ty.isF32()) return failure(); if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure();
if (ty.isF64()) { if (ty.isF64()) {
rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x)); rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x));
return success(); return success();
} }
rewriter.replaceOp(op, MaterializeErfcApproximationF32(rewriter, loc, x)); // Cast argument to f32 tensor if needed.
assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point");
if (ty.isF16()) {
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
}
Value result = MaterializeErfcApproximationF32(rewriter, loc, x);
// Cast back if needed.
if (ty.isF16()) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
}
rewriter.replaceOp(op, result);
return success(); return success();
} }
}; };

View File

@ -642,3 +642,13 @@ func @erfc_f32(%arg : tensor<f32>) -> tensor<f32> {
%1 = "chlo.erfc"(%arg) : (tensor<f32>) -> tensor<f32> %1 = "chlo.erfc"(%arg) : (tensor<f32>) -> tensor<f32>
return %1 : tensor<f32> return %1 : tensor<f32>
} }
// CHECK-LABEL: @erfc_f16
// CHECK-SAME: %[[ARG:.*]]: tensor<f16>
func @erfc_f16(%arg : tensor<f16>) -> tensor<f16> {
// CHECK: "mhlo.convert"(%[[ARG]]) : (tensor<f16>) -> tensor<f32>
// CHECK: %[[RESULT:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
// CHECK: return %[[RESULT]]
%1 = "chlo.erfc"(%arg) : (tensor<f16>) -> tensor<f16>
return %1 : tensor<f16>
}