diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp
index 2763bea..c3f5a60 100644
--- a/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp
+++ b/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp
@@ -57,6 +57,15 @@ static onnx::TensorProto::DataType llvmTypeToOnnxType(
     return onnx::TensorProto::UINT32;
   if (elemType.isUnsignedInteger(64))
     return onnx::TensorProto::INT64;
+  // LLVM Dialect does not have signed/unsigned int, only signless int
+  if (elemType.isIntegerTy(8))
+    return onnx::TensorProto::INT8;
+  if (elemType.isIntegerTy(16))
+    return onnx::TensorProto::INT16;
+  if (elemType.isIntegerTy(32))
+    return onnx::TensorProto::INT32;
+  if (elemType.isIntegerTy(64))
+    return onnx::TensorProto::INT64;
   // Complex types don't seem to exist in LLVM Dialect.
   elemType.dump();
   llvm_unreachable("Unexpected LLVM type, cannot be converted to ONNX type.");
diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
index 3bc2222..5a2d8e4 100644
--- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
+++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -84,6 +84,50 @@ struct ScalarOp<ONNXSqrtOp> {
   using IOp = SqrtOp; // not use
 };
 
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXCastOp
+//===----------------------------------------------------------------------===//
+template <>
+Value emitScalarOpFor<ONNXCastOp>(ConversionPatternRewriter &rewriter,
+    Location loc, Operation *op, Type elementType,
+    ArrayRef<Value> scalarOperands) {
+  ONNXCastOp castOp = llvm::dyn_cast<ONNXCastOp>(op);
+  auto mlirtype = convertONNXTypeToMLIRType(rewriter,
+      static_cast<onnx::TensorProto_DataType>(castOp.toAttr().getInt()));
+  Value operand = scalarOperands[0];
+  auto origtype = operand.getType();
+
+  // check output type is the same as expected output type
+  if (elementType != mlirtype)
+    llvm_unreachable("output type different from expected output type");
+
+  // if same input and output type, return input
+  if (origtype == elementType)
+    return operand;
+
+  if (origtype.isa<FloatType>()) {
+    // cast from floating-point type to integer type
+    if (elementType.isa<IntegerType>())
+      return rewriter.create<FPToSIOp>(loc, elementType, operand);
+    // cast from floating-point type to other floating-point type
+    else if (elementType.isa<FloatType>()) {
+      // cast from floating-point to wider floating-point
+      if (origtype.getIntOrFloatBitWidth() <
+          elementType.getIntOrFloatBitWidth())
+        return rewriter.create<FPExtOp>(loc, elementType, operand);
+      // cast from floating-point to narrower floating-point
+      else
+        return rewriter.create<FPTruncOp>(loc, elementType, operand);
+    }
+  }
+  // int to float
+  else if (origtype.isa<IntegerType>()) {
+    if (elementType.isa<FloatType>())
+      return rewriter.create<SIToFPOp>(loc, elementType, operand);
+  }
+  llvm_unreachable("unsupported element type");
+}
+
 //===----------------------------------------------------------------------===//
 // Scalar unary ops for lowering ONNXSinhOp
 //===----------------------------------------------------------------------===//
@@ -665,5 +709,6 @@ void populateLoweringONNXElementwiseOpPattern(
       ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXCastOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx);
 }
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index ab303d0..690a991 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -2018,3 +2018,88 @@ func @test_split_unknown_dimension(%arg0 : tensor) -> (tensor<*xf32>
   // CHECK: }
   // CHECK: return [[RES_0]], [[RES_1]] : memref, memref
 }
+
+// -----
+
+func @cast_lowering_sametype(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f32>) -> tensor<f32>
+  "std.return"(%0) : (tensor<f32>) -> ()
+
+  // CHECK-LABEL: cast_lowering_sametype
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f32>
+  // CHECK: affine.store [[LOAD]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @cast_lowering_intfloat(%arg0: tensor<i64>) -> tensor<f32> {
+  %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<i64>) -> tensor<f32>
+  "std.return"(%0) : (tensor<f32>) -> ()
+
+  // CHECK-LABEL: cast_lowering_intfloat
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<i64>
+  // CHECK: [[VAL:%.+]] = sitofp [[LOAD]] : i64 to f32
+  // CHECK: affine.store [[VAL]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @cast_lowering_floatint(%arg0: tensor<f32>) -> tensor<i64> {
+  %0 = "onnx.Cast"(%arg0) {to = 7 : i64} : (tensor<f32>) -> tensor<i64>
+  "std.return"(%0) : (tensor<i64>) -> ()
+
+  // CHECK-LABEL: cast_lowering_floatint
+  // CHECK: [[RES:%.+]] = alloc() : memref<i64>
+  // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f32>
+  // CHECK: [[VAL:%.+]] = fptosi [[LOAD]] : f32 to i64
+  // CHECK: affine.store [[VAL]], [[RES]][] : memref<i64>
+  // CHECK: return [[RES]] : memref<i64>
+}
+
+// -----
+
+func @cast_lowering_f16f32(%arg0: tensor<f16>) -> tensor<f32> {
+  %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f16>) -> tensor<f32>
+  "std.return"(%0) : (tensor<f32>) -> ()
+
+  // CHECK-LABEL: cast_lowering_f16f32
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f16>
+  // CHECK: [[VAL:%.+]] = fpext [[LOAD]] : f16 to f32
+  // CHECK: affine.store [[VAL]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @cast_lowering_f64f32(%arg0: tensor<f64>) -> tensor<f32> {
+  %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f64>) -> tensor<f32>
+  "std.return"(%0) : (tensor<f32>) -> ()
+
+  // CHECK-LABEL: cast_lowering_f64f32
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f64>
+  // CHECK: [[VAL:%.+]] = fptrunc [[LOAD]] : f64 to f32
+  // CHECK: affine.store [[VAL]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @cast_lowering_f64f32_10(%arg0: tensor<10xf64>) -> tensor<*xf32> {
+  %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<10xf64>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: cast_lowering_f64f32_10
+  // CHECK: [[RES:%.+]] = alloc() : memref<10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]] = krnl.define_loops 1
+  // CHECK: krnl.iterate([[DEF_LOOPS]]) with ([[DEF_LOOPS]] -> %arg1 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = affine.load %arg0[%arg1] : memref<10xf64>
+  // CHECK: [[FPTRUNC:%.+]] = fptrunc [[LOAD1]] : f64 to f32
+  // CHECK: affine.store [[FPTRUNC]], [[RES]][%arg1] : memref<10xf32>
+  // CHECK: return [[RES]] : memref<10xf32>
+}