[MLIR][CHLO] Simplify conversions with upcast
PiperOrigin-RevId: 354975366
This commit is contained in:
		
							parent
							
								
									8e3890e8e8
								
							
						
					
					
						commit
						816d279be3
					
				|  | @ -399,6 +399,22 @@ Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter, | |||
|                                          erfc_approx); | ||||
| } | ||||
| 
 | ||||
| Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc, | ||||
|                             Value arg, FloatType min_precision_ty, | ||||
|                             Value callback(ConversionPatternRewriter &, | ||||
|                                            Location, Value)) { | ||||
|   auto original_ty = getElementTypeOrSelf(arg.getType()).cast<FloatType>(); | ||||
|   bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth(); | ||||
|   if (needs_upcast) { | ||||
|     arg = rewriter.create<mhlo::ConvertOp>(loc, arg, min_precision_ty); | ||||
|   } | ||||
|   Value result = callback(rewriter, loc, arg); | ||||
|   if (needs_upcast) { | ||||
|     result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty); | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| 
 | ||||
| struct ConvertErfOp : public OpConversionPattern<ErfOp> { | ||||
|   using OpConversionPattern<ErfOp>::OpConversionPattern; | ||||
|   LogicalResult matchAndRewrite( | ||||
|  | @ -417,20 +433,9 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> { | |||
|       return success(); | ||||
|     } | ||||
| 
 | ||||
|     // Cast argument to f32 tensor if needed.
 | ||||
|     assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point"); | ||||
|     if (ty.isF16()) { | ||||
|       x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type()); | ||||
|     } | ||||
| 
 | ||||
|     Value result = MaterializeErfApproximationF32(rewriter, loc, x); | ||||
| 
 | ||||
|     // Cast back if needed.
 | ||||
|     if (ty.isF16()) { | ||||
|       result = rewriter.create<mhlo::ConvertOp>(loc, result, ty); | ||||
|     } | ||||
| 
 | ||||
|     rewriter.replaceOp(op, result); | ||||
|     rewriter.replaceOp( | ||||
|         op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(), | ||||
|                                   &MaterializeErfApproximationF32)); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
|  | @ -453,20 +458,9 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> { | |||
|       return success(); | ||||
|     } | ||||
| 
 | ||||
|     // Cast argument to f32 tensor if needed.
 | ||||
|     assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point"); | ||||
|     if (ty.isF16()) { | ||||
|       x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type()); | ||||
|     } | ||||
| 
 | ||||
|     Value result = MaterializeErfcApproximationF32(rewriter, loc, x); | ||||
| 
 | ||||
|     // Cast back if needed.
 | ||||
|     if (ty.isF16()) { | ||||
|       result = rewriter.create<mhlo::ConvertOp>(loc, result, ty); | ||||
|     } | ||||
| 
 | ||||
|     rewriter.replaceOp(op, result); | ||||
|     rewriter.replaceOp( | ||||
|         op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(), | ||||
|                                   &MaterializeErfcApproximationF32)); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
|  | @ -636,22 +630,11 @@ struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> { | |||
|   LogicalResult matchAndRewrite( | ||||
|       LgammaOp op, ArrayRef<Value> operands, | ||||
|       ConversionPatternRewriter &rewriter) const override { | ||||
|     Location loc = op.getLoc(); | ||||
|     LgammaOp::Adaptor transformed(operands); | ||||
|     Value x = transformed.operand(); | ||||
|     Type ty = getElementTypeOrSelf(op.getType()); | ||||
| 
 | ||||
|     if (ty.isF32() || ty.isF64()) { | ||||
|       rewriter.replaceOp(op, MaterializeLgamma(rewriter, loc, x)); | ||||
|       return success(); | ||||
|     } | ||||
| 
 | ||||
|     // Materialize lgamma with upcast to f32.
 | ||||
|     x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type()); | ||||
|     Value result = MaterializeLgamma(rewriter, loc, x); | ||||
|     result = rewriter.create<mhlo::ConvertOp>(loc, result, ty); | ||||
| 
 | ||||
|     rewriter.replaceOp(op, result); | ||||
|     FloatType min_precision_ty = rewriter.getF32Type(); | ||||
|     rewriter.replaceOp( | ||||
|         op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(), | ||||
|                                   min_precision_ty, &MaterializeLgamma)); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue