[MLIR][KernelGen] Add erfc kernel for f16
PiperOrigin-RevId: 353209468
This commit is contained in:
		
							parent
							
								
									ef8ccdaebc
								
							
						
					
					
						commit
						ae2d46414d
					
				|  | @ -445,15 +445,28 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> { | |||
|     Value x = transformed.operand(); | ||||
|     Type ty = x.getType().cast<ShapedType>().getElementType(); | ||||
| 
 | ||||
|     // For now, we support only f64 and f32.
 | ||||
|     if (!ty.isF64() && !ty.isF32()) return failure(); | ||||
|     // For now, we support only f64, f32, and f16.
 | ||||
|     if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure(); | ||||
| 
 | ||||
|     if (ty.isF64()) { | ||||
|       rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x)); | ||||
|       return success(); | ||||
|     } | ||||
| 
 | ||||
|     rewriter.replaceOp(op, MaterializeErfcApproximationF32(rewriter, loc, x)); | ||||
|     // Cast argument to f32 tensor if needed.
 | ||||
|     assert((ty.isF16() || ty.isF32()) && "expect f16 or f32 at this point"); | ||||
|     if (ty.isF16()) { | ||||
|       x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type()); | ||||
|     } | ||||
| 
 | ||||
|     Value result = MaterializeErfcApproximationF32(rewriter, loc, x); | ||||
| 
 | ||||
|     // Cast back if needed.
 | ||||
|     if (ty.isF16()) { | ||||
|       result = rewriter.create<mhlo::ConvertOp>(loc, result, ty); | ||||
|     } | ||||
| 
 | ||||
|     rewriter.replaceOp(op, result); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
|  |  | |||
|  | @ -642,3 +642,13 @@ func @erfc_f32(%arg : tensor<f32>) -> tensor<f32> { | |||
|   %1 = "chlo.erfc"(%arg) : (tensor<f32>) -> tensor<f32> | ||||
|   return %1 : tensor<f32> | ||||
| } | ||||
| 
 | ||||
| // Test for chlo.erfc on a tensor<f16> operand. The expected output converts | ||||
| // the f16 argument to f32 via mhlo.convert, and later converts an f32 value | ||||
| // back to f16 via mhlo.convert; that final f16 value is what gets returned. | ||||
| // The ops between the two converts are not matched by this test. | ||||
| // CHECK-LABEL: @erfc_f16 | ||||
| // CHECK-SAME: %[[ARG:.*]]: tensor<f16> | ||||
| func @erfc_f16(%arg : tensor<f16>) -> tensor<f16> { | ||||
|   // CHECK: "mhlo.convert"(%[[ARG]]) : (tensor<f16>) -> tensor<f32> | ||||
|   // CHECK: %[[RESULT:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16> | ||||
|   // CHECK: return %[[RESULT]] | ||||
|   %1 = "chlo.erfc"(%arg) : (tensor<f16>) -> tensor<f16> | ||||
|   return %1 : tensor<f16> | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue