[MLIR][KernelGen] Add erf kernel and missing lowering for f16 type
PiperOrigin-RevId: 352416184
This commit is contained in:
		
							parent
							
								
									9e07bdf4ea
								
							
						
					
					
						commit
						96fb617413
					
				|  | @ -90,6 +90,8 @@ Value MaterializePolynomialApproximation( | |||
| 
 | ||||
| Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter, | ||||
|                                      Location loc, Value operand) { | ||||
|   assert(operand.getType().cast<RankedTensorType>().getElementType().isF32() && | ||||
|          "expect f32 element type"); | ||||
|   const std::vector<float> kAlpha{ | ||||
|       -2.72614225801306e-10f, 2.77068142495902e-08f,  -2.10102402082508e-06f, | ||||
|       -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f, | ||||
|  | @ -121,14 +123,28 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> { | |||
|   LogicalResult matchAndRewrite( | ||||
|       ErfOp op, ArrayRef<Value> operands, | ||||
|       ConversionPatternRewriter &rewriter) const override { | ||||
|     Type ty = getElementTypeOrSelf(op.getType()); | ||||
| 
 | ||||
|     // For now, we support only f32.
 | ||||
|     if (!ty.isF32()) return failure(); | ||||
| 
 | ||||
|     Location loc = op.getLoc(); | ||||
|     ErfOp::Adaptor transformed(operands); | ||||
|     rewriter.replaceOp(op, MaterializeErfApproximationF32( | ||||
|                                rewriter, op.getLoc(), transformed.operand())); | ||||
|     Value x = transformed.operand(); | ||||
|     Type ty = x.getType().cast<RankedTensorType>().getElementType(); | ||||
| 
 | ||||
|     // For now, we support only f32 and f16.
 | ||||
|     if (!ty.isF32() && !ty.isF16()) return failure(); | ||||
| 
 | ||||
|     // Cast argument to f32 tensor if needed.
 | ||||
|     assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point"); | ||||
|     if (ty.isF16()) { | ||||
|       x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type()); | ||||
|     } | ||||
| 
 | ||||
|     Value result = MaterializeErfApproximationF32(rewriter, loc, x); | ||||
| 
 | ||||
|     // Cast back if needed.
 | ||||
|     if (ty.isF16()) { | ||||
|       result = rewriter.create<mhlo::ConvertOp>(loc, result, ty); | ||||
|     } | ||||
| 
 | ||||
|     rewriter.replaceOp(op, result); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
|  |  | |||
|  | @ -86,3 +86,13 @@ func @erf_f32(%arg : tensor<f32>) -> tensor<f32> { | |||
|   %1 = "chlo.erf"(%arg) : (tensor<f32>) -> tensor<f32> | ||||
|   return %1 : tensor<f32> | ||||
| } | ||||
| 
 | ||||
// Test case for "chlo.erf" applied to an f16 tensor argument.
// The FileCheck directives in this function expect, in order: a convert of the
// f16 argument to tensor<f32>, a later convert from tensor<f32> back to
// tensor<f16>, and a return of that converted value. The f32 operations
// produced between the two converts are not matched by this test.
| // CHECK-LABEL: @erf_f16 | ||||
| // CHECK-SAME: %[[ARG:.*]]: tensor<f16> | ||||
| func @erf_f16(%arg : tensor<f16>) -> tensor<f16> { | ||||
|   // CHECK: "mhlo.convert"(%[[ARG]]) : (tensor<f16>) -> tensor<f32> | ||||
|   // CHECK: %[[RESULT:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16> | ||||
|   // CHECK: return %[[RESULT]] | ||||
|   %1 = "chlo.erf"(%arg) : (tensor<f16>) -> tensor<f16> | ||||
|   return %1 : tensor<f16> | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue