Add CastOp lowering (#259)
* move scalerop to decompose * change clang format * change clang format * add shape inference for scaler op * fixing generated onnxop * generate onnx.md * add benefit for scaler decompose and simplify scaler shape inference * cast rewrite only for float * add cast op same type rewrite rule * working on cast lowering * cast lowering working * correct onnx version * update onnx md * add test for tensor<10xf64>
This commit is contained in:
		
							parent
							
								
									e1386b0689
								
							
						
					
					
						commit
						2ee725d939
					
				|  | @ -57,6 +57,15 @@ static onnx::TensorProto::DataType llvmTypeToOnnxType( | ||||||
|     return onnx::TensorProto::UINT32; |     return onnx::TensorProto::UINT32; | ||||||
|   if (elemType.isUnsignedInteger(64)) |   if (elemType.isUnsignedInteger(64)) | ||||||
|     return onnx::TensorProto::INT64; |     return onnx::TensorProto::INT64; | ||||||
|  |   // LLVM Dialect does not have signed/unsigned int, only signless int
 | ||||||
|  |   if (elemType.isIntegerTy(8)) | ||||||
|  |     return onnx::TensorProto::INT8; | ||||||
|  |   if (elemType.isIntegerTy(16)) | ||||||
|  |     return onnx::TensorProto::INT16; | ||||||
|  |   if (elemType.isIntegerTy(32)) | ||||||
|  |     return onnx::TensorProto::INT32; | ||||||
|  |   if (elemType.isIntegerTy(64)) | ||||||
|  |     return onnx::TensorProto::INT64; | ||||||
|   // Complex types don't seem to exist in LLVM Dialect.
 |   // Complex types don't seem to exist in LLVM Dialect.
 | ||||||
|   elemType.dump(); |   elemType.dump(); | ||||||
|   llvm_unreachable("Unexpected LLVM type, cannot be converted to ONNX type."); |   llvm_unreachable("Unexpected LLVM type, cannot be converted to ONNX type."); | ||||||
|  |  | ||||||
|  | @ -84,6 +84,50 @@ struct ScalarOp<ONNXSqrtOp> { | ||||||
|   using IOp = SqrtOp; // not use
 |   using IOp = SqrtOp; // not use
 | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | // Scalar unary ops for lowering ONNXCastOp
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | template <> | ||||||
|  | Value emitScalarOpFor<ONNXCastOp>(ConversionPatternRewriter &rewriter, | ||||||
|  |     Location loc, Operation *op, Type elementType, | ||||||
|  |     ArrayRef<Value> scalarOperands) { | ||||||
|  |   ONNXCastOp castOp = llvm::dyn_cast<ONNXCastOp>(op); | ||||||
|  |   auto mlirtype = convertONNXTypeToMLIRType(rewriter, | ||||||
|  |       static_cast<onnx::TensorProto_DataType>(castOp.toAttr().getInt())); | ||||||
|  |   Value operand = scalarOperands[0]; | ||||||
|  |   auto origtype = operand.getType(); | ||||||
|  | 
 | ||||||
|  |   // check output type is the same as expected output type
 | ||||||
|  |   if (elementType != mlirtype) | ||||||
|  |     llvm_unreachable("output type different from expected output type"); | ||||||
|  | 
 | ||||||
|  |   // if same input and output type, return input
 | ||||||
|  |   if (origtype == elementType) | ||||||
|  |     return operand; | ||||||
|  | 
 | ||||||
|  |   if (origtype.isa<FloatType>()) { | ||||||
|  |     // cast from floating-point type to integer type
 | ||||||
|  |     if (elementType.isa<IntegerType>()) | ||||||
|  |       return rewriter.create<FPToSIOp>(loc, elementType, operand); | ||||||
|  |     // cast from floating-point type to other floating-point type
 | ||||||
|  |     else if (elementType.isa<FloatType>()) { | ||||||
|  |       // cast from floating-point to wider floating-point
 | ||||||
|  |       if (origtype.getIntOrFloatBitWidth() < | ||||||
|  |           elementType.getIntOrFloatBitWidth()) | ||||||
|  |         return rewriter.create<FPExtOp>(loc, elementType, operand); | ||||||
|  |       // cast from floating-point to narrower floating-point
 | ||||||
|  |       else | ||||||
|  |         return rewriter.create<FPTruncOp>(loc, elementType, operand); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   // int to float
 | ||||||
|  |   else if (origtype.isa<IntegerType>()) { | ||||||
|  |     if (elementType.isa<FloatType>()) | ||||||
|  |       return rewriter.create<SIToFPOp>(loc, elementType, operand); | ||||||
|  |   } | ||||||
|  |   llvm_unreachable("unsupported element type"); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| // Scalar unary ops for lowering ONNXSinhOp
 | // Scalar unary ops for lowering ONNXSinhOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
|  | @ -665,5 +709,6 @@ void populateLoweringONNXElementwiseOpPattern( | ||||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>, |       ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>, | ||||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>, |       ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>, | ||||||
|       ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>, |       ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>, | ||||||
|  |       ONNXElementwiseUnaryOpLowering<mlir::ONNXCastOp>, | ||||||
|       ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx); |       ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -2018,3 +2018,88 @@ func @test_split_unknown_dimension(%arg0 : tensor<?x?x64xf32>) -> (tensor<*xf32> | ||||||
|   // CHECK: } |   // CHECK: } | ||||||
|   // CHECK: return [[RES_0]], [[RES_1]] : memref<?x2x64xf32>, memref<?x30x64xf32> |   // CHECK: return [[RES_0]], [[RES_1]] : memref<?x2x64xf32>, memref<?x30x64xf32> | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | func @cast_lowering_sametype(%arg0: tensor<f32>) -> tensor<f32> { | ||||||
|  |   %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f32>) -> tensor<f32> | ||||||
|  |   "std.return"(%0) : (tensor<f32>) -> () | ||||||
|  | 
 | ||||||
|  |   // CHECK-LABEL: cast_lowering_sametype | ||||||
|  |   // CHECK: [[RES:%.+]] = alloc() : memref<f32> | ||||||
|  |   // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f32> | ||||||
|  |   // CHECK: affine.store [[LOAD]], [[RES]][] : memref<f32> | ||||||
|  |   // CHECK: return [[RES]] : memref<f32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | func @cast_lowering_intfloat(%arg0: tensor<i64>) -> tensor<f32> { | ||||||
|  |   %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<i64>) -> tensor<f32> | ||||||
|  |   "std.return"(%0) : (tensor<f32>) -> () | ||||||
|  | 
 | ||||||
|  |   // CHECK-LABEL: cast_lowering_intfloat | ||||||
|  |   // CHECK: [[RES:%.+]] = alloc() : memref<f32> | ||||||
|  |   // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<i64> | ||||||
|  |   // CHECK: [[VAL:%.+]] = sitofp [[LOAD]] : i64 to f32 | ||||||
|  |   // CHECK: affine.store [[VAL]], [[RES]][] : memref<f32> | ||||||
|  |   // CHECK: return [[RES]] : memref<f32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | func @cast_lowering_floatint(%arg0: tensor<f32>) -> tensor<i64> { | ||||||
|  |   %0 = "onnx.Cast"(%arg0) {to = 7 : i64} : (tensor<f32>) -> tensor<i64> | ||||||
|  |   "std.return"(%0) : (tensor<i64>) -> () | ||||||
|  | 
 | ||||||
|  |   // CHECK-LABEL: cast_lowering_floatint | ||||||
|  |   // CHECK: [[RES:%.+]] = alloc() : memref<i64> | ||||||
|  |   // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f32> | ||||||
|  |   // CHECK: [[VAL:%.+]] = fptosi [[LOAD]] : f32 to i64 | ||||||
|  |   // CHECK: affine.store [[VAL]], [[RES]][] : memref<i64> | ||||||
|  |   // CHECK: return [[RES]] : memref<i64> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | func @cast_lowering_f16f32(%arg0: tensor<f16>) -> tensor<f32> { | ||||||
|  |   %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f16>) -> tensor<f32> | ||||||
|  |   "std.return"(%0) : (tensor<f32>) -> () | ||||||
|  | 
 | ||||||
|  |   // CHECK-LABEL: cast_lowering_f16f32 | ||||||
|  |   // CHECK: [[RES:%.+]] = alloc() : memref<f32> | ||||||
|  |   // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f16> | ||||||
|  |   // CHECK: [[VAL:%.+]] = fpext [[LOAD]] : f16 to f32 | ||||||
|  |   // CHECK: affine.store [[VAL]], [[RES]][] : memref<f32> | ||||||
|  |   // CHECK: return [[RES]] : memref<f32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | func @cast_lowering_f64f32(%arg0: tensor<f64>) -> tensor<f32> { | ||||||
|  |   %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f64>) -> tensor<f32> | ||||||
|  |   "std.return"(%0) : (tensor<f32>) -> () | ||||||
|  | 
 | ||||||
|  |   // CHECK-LABEL: cast_lowering_f64f32 | ||||||
|  |   // CHECK: [[RES:%.+]] = alloc() : memref<f32> | ||||||
|  |   // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f64> | ||||||
|  |   // CHECK: [[VAL:%.+]] = fptrunc [[LOAD]] : f64 to f32 | ||||||
|  |   // CHECK: affine.store [[VAL]], [[RES]][] : memref<f32> | ||||||
|  |   // CHECK: return [[RES]] : memref<f32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | func @cast_lowering_f64f32_10(%arg0: tensor<10xf64>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<10xf64>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | 
 | ||||||
|  |   // CHECK-LABEL: cast_lowering_f64f32_10 | ||||||
|  |   // CHECK: [[RES:%.+]] = alloc() : memref<10xf32> | ||||||
|  |   // CHECK: [[DEF_LOOPS:%.+]] = krnl.define_loops 1 | ||||||
|  |   // CHECK: krnl.iterate([[DEF_LOOPS]]) with ([[DEF_LOOPS]] -> %arg1 = 0 to 10) { | ||||||
|  |   // CHECK: [[LOAD1:%.+]] = affine.load %arg0[%arg1] : memref<10xf64> | ||||||
|  |   // CHECK: [[FPTRUNC:%.+]] = fptrunc [[LOAD1]] : f64 to f32 | ||||||
|  |   // CHECK: affine.store [[FPTRUNC]], [[RES]][%arg1] : memref<10xf32> | ||||||
|  |   // CHECK: return [[RES]] : memref<10xf32> | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue