Add CastOp lowering (#259)
* move scalerop to decompose
* change clang format
* change clang format
* add shape inference for scaler op
* fixing generated onnxop
* generate onnx.md
* add benefit for scaler decompose and simplify scaler shape inference
* cast rewrite only for float
* add cast op same type rewrite rule
* working on cast lowering
* cast lowering working
* correct onnx version
* update onnx md
* add test for tensor<10xf64>
parent e1386b0689
commit 2ee725d939
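For reference in the tests below: the `to` attribute of `onnx.Cast` holds the integer value of the target `onnx::TensorProto_DataType` (1 = FLOAT/f32, 7 = INT64/i64, 10 = FLOAT16/f16, 11 = DOUBLE/f64). A minimal sketch of the kind of input this lowering handles (the function name is illustrative, not part of this commit):

func @cast_example(%arg0: tensor<2xi64>) -> tensor<2xf32> {
  // to = 1 requests FLOAT; with the lowering added here, the scalar body becomes a `sitofp`.
  %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<2xi64>) -> tensor<2xf32>
  "std.return"(%0) : (tensor<2xf32>) -> ()
}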
@@ -57,6 +57,15 @@ static onnx::TensorProto::DataType llvmTypeToOnnxType(
     return onnx::TensorProto::UINT32;
   if (elemType.isUnsignedInteger(64))
     return onnx::TensorProto::INT64;
+  // LLVM Dialect does not have signed/unsigned int, only signless int
+  if (elemType.isIntegerTy(8))
+    return onnx::TensorProto::INT8;
+  if (elemType.isIntegerTy(16))
+    return onnx::TensorProto::INT16;
+  if (elemType.isIntegerTy(32))
+    return onnx::TensorProto::INT32;
+  if (elemType.isIntegerTy(64))
+    return onnx::TensorProto::INT64;
   // Complex types don't seem to exist in LLVM Dialect.
   elemType.dump();
   llvm_unreachable("Unexpected LLVM type, cannot be converted to ONNX type.");
@@ -84,6 +84,50 @@ struct ScalarOp<ONNXSqrtOp> {
   using IOp = SqrtOp; // not used
 };
 
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXCastOp
+//===----------------------------------------------------------------------===//
+template <>
+Value emitScalarOpFor<ONNXCastOp>(ConversionPatternRewriter &rewriter,
+    Location loc, Operation *op, Type elementType,
+    ArrayRef<Value> scalarOperands) {
+  ONNXCastOp castOp = llvm::dyn_cast<ONNXCastOp>(op);
+  auto mlirtype = convertONNXTypeToMLIRType(rewriter,
+      static_cast<onnx::TensorProto_DataType>(castOp.toAttr().getInt()));
+  Value operand = scalarOperands[0];
+  auto origtype = operand.getType();
+
+  // Check that the output type matches the expected output type.
+  if (elementType != mlirtype)
+    llvm_unreachable("output type different from expected output type");
+
+  // If the input and output types are identical, the cast is a no-op.
+  if (origtype == elementType)
+    return operand;
+
+  if (origtype.isa<FloatType>()) {
+    // Cast from floating-point type to integer type.
+    if (elementType.isa<IntegerType>())
+      return rewriter.create<FPToSIOp>(loc, elementType, operand);
+    // Cast from floating-point type to another floating-point type.
+    else if (elementType.isa<FloatType>()) {
+      // Cast from floating-point to wider floating-point.
+      if (origtype.getIntOrFloatBitWidth() <
+          elementType.getIntOrFloatBitWidth())
+        return rewriter.create<FPExtOp>(loc, elementType, operand);
+      // Cast from floating-point to narrower floating-point.
+      else
+        return rewriter.create<FPTruncOp>(loc, elementType, operand);
+    }
+  }
+  // Cast from integer type to floating-point type.
+  else if (origtype.isa<IntegerType>()) {
+    if (elementType.isa<FloatType>())
+      return rewriter.create<SIToFPOp>(loc, elementType, operand);
+  }
+  llvm_unreachable("unsupported element type");
+}
+
 //===----------------------------------------------------------------------===//
 // Scalar unary ops for lowering ONNXSinhOp
 //===----------------------------------------------------------------------===//
@@ -665,5 +709,6 @@ void populateLoweringONNXElementwiseOpPattern(
       ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXCastOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx);
 }
@@ -2018,3 +2018,88 @@ func @test_split_unknown_dimension(%arg0 : tensor<?x?x64xf32>) -> (tensor<*xf32>
   // CHECK: }
   // CHECK: return [[RES_0]], [[RES_1]] : memref<?x2x64xf32>, memref<?x30x64xf32>
 }
+
+// -----
+
+func @cast_lowering_sametype(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f32>) -> tensor<f32>
+  "std.return"(%0) : (tensor<f32>) -> ()
+
+  // CHECK-LABEL: cast_lowering_sametype
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f32>
+  // CHECK: affine.store [[LOAD]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @cast_lowering_intfloat(%arg0: tensor<i64>) -> tensor<f32> {
+  %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<i64>) -> tensor<f32>
+  "std.return"(%0) : (tensor<f32>) -> ()
+
+  // CHECK-LABEL: cast_lowering_intfloat
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<i64>
+  // CHECK: [[VAL:%.+]] = sitofp [[LOAD]] : i64 to f32
+  // CHECK: affine.store [[VAL]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @cast_lowering_floatint(%arg0: tensor<f32>) -> tensor<i64> {
+  %0 = "onnx.Cast"(%arg0) {to = 7 : i64} : (tensor<f32>) -> tensor<i64>
+  "std.return"(%0) : (tensor<i64>) -> ()
+
+  // CHECK-LABEL: cast_lowering_floatint
+  // CHECK: [[RES:%.+]] = alloc() : memref<i64>
+  // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f32>
+  // CHECK: [[VAL:%.+]] = fptosi [[LOAD]] : f32 to i64
+  // CHECK: affine.store [[VAL]], [[RES]][] : memref<i64>
+  // CHECK: return [[RES]] : memref<i64>
+}
+
+// -----
+
+func @cast_lowering_f16f32(%arg0: tensor<f16>) -> tensor<f32> {
+  %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f16>) -> tensor<f32>
+  "std.return"(%0) : (tensor<f32>) -> ()
+
+  // CHECK-LABEL: cast_lowering_f16f32
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f16>
+  // CHECK: [[VAL:%.+]] = fpext [[LOAD]] : f16 to f32
+  // CHECK: affine.store [[VAL]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @cast_lowering_f64f32(%arg0: tensor<f64>) -> tensor<f32> {
+  %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f64>) -> tensor<f32>
+  "std.return"(%0) : (tensor<f32>) -> ()
+
+  // CHECK-LABEL: cast_lowering_f64f32
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f64>
+  // CHECK: [[VAL:%.+]] = fptrunc [[LOAD]] : f64 to f32
+  // CHECK: affine.store [[VAL]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @cast_lowering_f64f32_10(%arg0: tensor<10xf64>) -> tensor<*xf32> {
+  %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<10xf64>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: cast_lowering_f64f32_10
+  // CHECK: [[RES:%.+]] = alloc() : memref<10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]] = krnl.define_loops 1
+  // CHECK: krnl.iterate([[DEF_LOOPS]]) with ([[DEF_LOOPS]] -> %arg1 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = affine.load %arg0[%arg1] : memref<10xf64>
+  // CHECK: [[FPTRUNC:%.+]] = fptrunc [[LOAD1]] : f64 to f32
+  // CHECK: affine.store [[FPTRUNC]], [[RES]][%arg1] : memref<10xf32>
+  // CHECK: return [[RES]] : memref<10xf32>
+}