[MLIR][CHLO] Simplify conversions with upcast

PiperOrigin-RevId: 354975366
This commit is contained in:
A. Unique TensorFlower 2021-02-01 10:46:41 -08:00 committed by TensorFlow MLIR Team
parent 8e3890e8e8
commit 816d279be3
1 changed files with 26 additions and 43 deletions

View File

@ -399,6 +399,22 @@ Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
erfc_approx);
}
Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
Value arg, FloatType min_precision_ty,
Value callback(ConversionPatternRewriter &,
Location, Value)) {
auto original_ty = getElementTypeOrSelf(arg.getType()).cast<FloatType>();
bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth();
if (needs_upcast) {
arg = rewriter.create<mhlo::ConvertOp>(loc, arg, min_precision_ty);
}
Value result = callback(rewriter, loc, arg);
if (needs_upcast) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty);
}
return result;
}
struct ConvertErfOp : public OpConversionPattern<ErfOp> {
using OpConversionPattern<ErfOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
@ -417,20 +433,9 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> {
return success();
}
// Cast argument to f32 tensor if needed.
assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point");
if (ty.isF16()) {
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
}
Value result = MaterializeErfApproximationF32(rewriter, loc, x);
// Cast back if needed.
if (ty.isF16()) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
}
rewriter.replaceOp(op, result);
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(),
&MaterializeErfApproximationF32));
return success();
}
};
@ -453,20 +458,9 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
return success();
}
// Cast argument to f32 tensor if needed.
assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point");
if (ty.isF16()) {
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
}
Value result = MaterializeErfcApproximationF32(rewriter, loc, x);
// Cast back if needed.
if (ty.isF16()) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
}
rewriter.replaceOp(op, result);
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(),
&MaterializeErfcApproximationF32));
return success();
}
};
@ -636,22 +630,11 @@ struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
LogicalResult matchAndRewrite(
LgammaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
LgammaOp::Adaptor transformed(operands);
Value x = transformed.operand();
Type ty = getElementTypeOrSelf(op.getType());
if (ty.isF32() || ty.isF64()) {
rewriter.replaceOp(op, MaterializeLgamma(rewriter, loc, x));
return success();
}
// Materialize lgamma with upcast to f32.
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
Value result = MaterializeLgamma(rewriter, loc, x);
result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
rewriter.replaceOp(op, result);
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
min_precision_ty, &MaterializeLgamma));
return success();
}
};