[MLIR][CHLO] Simplify conversions with upcast
PiperOrigin-RevId: 354975366
This commit is contained in:
parent
8e3890e8e8
commit
816d279be3
|
@ -399,6 +399,22 @@ Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
|
||||||
erfc_approx);
|
erfc_approx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs `callback` on `arg`, temporarily widening the value when its float
// element type is narrower than `min_precision_ty`.
//
// If the element type of `arg` is at least as wide as `min_precision_ty`, the
// callback is invoked on `arg` directly and its result is returned unchanged.
// Otherwise `arg` is converted to `min_precision_ty` with an mhlo::ConvertOp,
// the callback is invoked on the converted value, and the callback's result is
// converted back to the original element type with a second mhlo::ConvertOp.
//
// The element type of `arg` is cast to FloatType unconditionally, so callers
// are expected to pass floating-point (or floating-point element) values only.
Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
                            Value arg, FloatType min_precision_ty,
                            Value callback(ConversionPatternRewriter &,
                                           Location, Value)) {
  FloatType arg_element_ty =
      getElementTypeOrSelf(arg.getType()).cast<FloatType>();

  // Already precise enough: no conversions are needed around the callback.
  if (arg_element_ty.getWidth() >= min_precision_ty.getWidth()) {
    return callback(rewriter, loc, arg);
  }

  // Widen, materialize at the higher precision, then narrow back.
  Value widened_arg =
      rewriter.create<mhlo::ConvertOp>(loc, arg, min_precision_ty);
  Value widened_result = callback(rewriter, loc, widened_arg);
  return rewriter.create<mhlo::ConvertOp>(loc, widened_result, arg_element_ty);
}
|
||||||
|
|
||||||
struct ConvertErfOp : public OpConversionPattern<ErfOp> {
|
struct ConvertErfOp : public OpConversionPattern<ErfOp> {
|
||||||
using OpConversionPattern<ErfOp>::OpConversionPattern;
|
using OpConversionPattern<ErfOp>::OpConversionPattern;
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
|
@ -417,20 +433,9 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cast argument to f32 tensor if needed.
|
rewriter.replaceOp(
|
||||||
assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point");
|
op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(),
|
||||||
if (ty.isF16()) {
|
&MaterializeErfApproximationF32));
|
||||||
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
|
|
||||||
}
|
|
||||||
|
|
||||||
Value result = MaterializeErfApproximationF32(rewriter, loc, x);
|
|
||||||
|
|
||||||
// Cast back if needed.
|
|
||||||
if (ty.isF16()) {
|
|
||||||
result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, result);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -453,20 +458,9 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cast argument to f32 tensor if needed.
|
rewriter.replaceOp(
|
||||||
assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point");
|
op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(),
|
||||||
if (ty.isF16()) {
|
&MaterializeErfcApproximationF32));
|
||||||
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
|
|
||||||
}
|
|
||||||
|
|
||||||
Value result = MaterializeErfcApproximationF32(rewriter, loc, x);
|
|
||||||
|
|
||||||
// Cast back if needed.
|
|
||||||
if (ty.isF16()) {
|
|
||||||
result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, result);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -636,22 +630,11 @@ struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
LgammaOp op, ArrayRef<Value> operands,
|
LgammaOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
|
||||||
LgammaOp::Adaptor transformed(operands);
|
LgammaOp::Adaptor transformed(operands);
|
||||||
Value x = transformed.operand();
|
FloatType min_precision_ty = rewriter.getF32Type();
|
||||||
Type ty = getElementTypeOrSelf(op.getType());
|
rewriter.replaceOp(
|
||||||
|
op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
|
||||||
if (ty.isF32() || ty.isF64()) {
|
min_precision_ty, &MaterializeLgamma));
|
||||||
rewriter.replaceOp(op, MaterializeLgamma(rewriter, loc, x));
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Materialize lgamma with upcast to f32.
|
|
||||||
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
|
|
||||||
Value result = MaterializeLgamma(rewriter, loc, x);
|
|
||||||
result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, result);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue