[MLIR][KernelGen] Add erf kernel and missing lowering for f16 type
PiperOrigin-RevId: 352416184
parent 9e07bdf4ea
commit 96fb617413
@@ -90,6 +90,8 @@ Value MaterializePolynomialApproximation(
 
 Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
                                      Location loc, Value operand) {
+  assert(operand.getType().cast<RankedTensorType>().getElementType().isF32() &&
+         "expect f32 element type");
   const std::vector<float> kAlpha{
       -2.72614225801306e-10f, 2.77068142495902e-08f,  -2.10102402082508e-06f,
       -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
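
Note: the assertion added above makes the f32-only precondition of MaterializeErfApproximationF32 explicit; its coefficient tables (kAlpha, truncated here by the hunk boundary) are only valid for f32. MaterializePolynomialApproximation, visible only in the hunk header, presumably folds such a coefficient vector into the IR Horner-style. A minimal scalar C++ sketch of that evaluation scheme follows; the function name and the highest-degree-first coefficient order are assumptions, since the body is outside this diff.

#include <vector>

// Hypothetical scalar analogue of MaterializePolynomialApproximation.
// Evaluates c[0]*x^(n-1) + ... + c[n-1] via Horner's rule. The real
// implementation would emit mhlo multiply/add ops on tensors instead of
// computing a float directly.
float EvaluatePolynomial(float x, const std::vector<float> &coefficients) {
  float poly = 0.0f;
  for (float c : coefficients) {
    poly = poly * x + c;  // Fold in the next lower-degree coefficient.
  }
  return poly;
}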
@@ -121,14 +123,28 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> {
   LogicalResult matchAndRewrite(
       ErfOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    Type ty = getElementTypeOrSelf(op.getType());
-
-    // For now, we support only f32.
-    if (!ty.isF32()) return failure();
-
+    Location loc = op.getLoc();
     ErfOp::Adaptor transformed(operands);
-    rewriter.replaceOp(op, MaterializeErfApproximationF32(
-                               rewriter, op.getLoc(), transformed.operand()));
+    Value x = transformed.operand();
+    Type ty = x.getType().cast<RankedTensorType>().getElementType();
+
+    // For now, we support only f32 and f16.
+    if (!ty.isF32() && !ty.isF16()) return failure();
+
+    // Cast argument to f32 tensor if needed.
+    assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point");
+    if (ty.isF16()) {
+      x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
+    }
+
+    Value result = MaterializeErfApproximationF32(rewriter, loc, x);
+
+    // Cast back if needed.
+    if (ty.isF16()) {
+      result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
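
Note: the rewrite above is the usual promote-compute-demote pattern: the polynomial approximation is only valid in f32, so an f16 argument is widened on entry and the result narrowed back on exit. If more f32-only approximations gain f16 support later, the bracketing converts could be factored into a reusable helper. A hedged sketch, assuming the surrounding file's includes; MaterializeWithUpcast and its signature are hypothetical, not part of this commit:

// Hypothetical helper: wraps an f32-only materialization so it also
// accepts narrower float element types by converting in and out. Callers
// are expected to have filtered out unsupported element types already.
Value MaterializeWithUpcast(
    ConversionPatternRewriter &rewriter, Location loc, Value operand,
    llvm::function_ref<Value(ConversionPatternRewriter &, Location, Value)>
        materialize_f32) {
  Type ty = operand.getType().cast<RankedTensorType>().getElementType();
  bool needs_cast = !ty.isF32();
  // Promote the operand to an f32 tensor if its element type is narrower.
  if (needs_cast) {
    operand =
        rewriter.create<mhlo::ConvertOp>(loc, operand, rewriter.getF32Type());
  }
  Value result = materialize_f32(rewriter, loc, operand);
  // Demote the result back to the original element type if we promoted.
  if (needs_cast) {
    result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
  }
  return result;
}

With such a helper, the body of matchAndRewrite would reduce to a single call, e.g. MaterializeWithUpcast(rewriter, loc, x, &MaterializeErfApproximationF32), after the same f32/f16 support check.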
@@ -86,3 +86,13 @@ func @erf_f32(%arg : tensor<f32>) -> tensor<f32> {
   %1 = "chlo.erf"(%arg) : (tensor<f32>) -> tensor<f32>
   return %1 : tensor<f32>
 }
+
+// CHECK-LABEL: @erf_f16
+// CHECK-SAME: %[[ARG:.*]]: tensor<f16>
+func @erf_f16(%arg : tensor<f16>) -> tensor<f16> {
+  // CHECK: "mhlo.convert"(%[[ARG]]) : (tensor<f16>) -> tensor<f32>
+  // CHECK: %[[RESULT:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
+  // CHECK: return %[[RESULT]]
+  %1 = "chlo.erf"(%arg) : (tensor<f16>) -> tensor<f16>
+  return %1 : tensor<f16>
+}
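
Note: the new test pins down only the observable contract of the f16 lowering: a widening mhlo.convert of the argument, a narrowing mhlo.convert producing the result, and the return; the f32 polynomial expansion in between is deliberately left unchecked. The file's RUN line sits outside this hunk; with the repository's tooling it would presumably be something like mlir-hlo-opt --chlo-legalize-to-hlo %s | FileCheck %s, though the exact binary and flag are not shown in this diff.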