diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 35211ae..a15e240 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -445,15 +445,28 @@ struct ConvertErfcOp : public OpConversionPattern { Value x = transformed.operand(); Type ty = x.getType().cast().getElementType(); - // For now, we support only f64 and f32. - if (!ty.isF64() && !ty.isF32()) return failure(); + // For now, we support only f64, f32, and f16. + if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure(); if (ty.isF64()) { rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x)); return success(); } - rewriter.replaceOp(op, MaterializeErfcApproximationF32(rewriter, loc, x)); + // Cast argument to f32 tensor if needed. + assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point"); + if (ty.isF16()) { + x = rewriter.create(loc, x, rewriter.getF32Type()); + } + + Value result = MaterializeErfcApproximationF32(rewriter, loc, x); + + // Cast back if needed. + if (ty.isF16()) { + result = rewriter.create(loc, result, ty); + } + + rewriter.replaceOp(op, result); return success(); } }; diff --git a/tests/chlo_legalize_to_mhlo.mlir b/tests/chlo_legalize_to_mhlo.mlir index ba91341..0d29f9e 100644 --- a/tests/chlo_legalize_to_mhlo.mlir +++ b/tests/chlo_legalize_to_mhlo.mlir @@ -642,3 +642,13 @@ func @erfc_f32(%arg : tensor) -> tensor { %1 = "chlo.erfc"(%arg) : (tensor) -> tensor return %1 : tensor } + +// CHECK-LABEL: @erfc_f16 +// CHECK-SAME: %[[ARG:.*]]: tensor +func @erfc_f16(%arg : tensor) -> tensor { + // CHECK: "mhlo.convert"(%[[ARG]]) : (tensor) -> tensor + // CHECK: %[[RESULT:.*]] = "mhlo.convert"(%{{.*}}) : (tensor) -> tensor + // CHECK: return %[[RESULT]] + %1 = "chlo.erfc"(%arg) : (tensor) -> tensor + return %1 : tensor +}