Add CastOp lowering (#259)
* move scalerop to decompose * change clang format * change clang format * add shape inference for scaler op * fixing generated onnxop * generate onnx.md * add benefit for scaler decompose and simplify scaler shape inference * cast rewrite only for float * add cast op same type rewrite rule * working on cast lowering * cast lowering working * correct onnx version * update onnx md * add test for tensor<10xf64>
This commit is contained in:
parent
e1386b0689
commit
2ee725d939
|
@ -57,6 +57,15 @@ static onnx::TensorProto::DataType llvmTypeToOnnxType(
|
|||
return onnx::TensorProto::UINT32;
|
||||
if (elemType.isUnsignedInteger(64))
|
||||
return onnx::TensorProto::INT64;
|
||||
// LLVM Dialect does not have signed/unsigned int, only signless int
|
||||
if (elemType.isIntegerTy(8))
|
||||
return onnx::TensorProto::INT8;
|
||||
if (elemType.isIntegerTy(16))
|
||||
return onnx::TensorProto::INT16;
|
||||
if (elemType.isIntegerTy(32))
|
||||
return onnx::TensorProto::INT32;
|
||||
if (elemType.isIntegerTy(64))
|
||||
return onnx::TensorProto::INT64;
|
||||
// Complex types don't seem to exist in LLVM Dialect.
|
||||
elemType.dump();
|
||||
llvm_unreachable("Unexpected LLVM type, cannot be converted to ONNX type.");
|
||||
|
|
|
@ -84,6 +84,50 @@ struct ScalarOp<ONNXSqrtOp> {
|
|||
using IOp = SqrtOp; // not use
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Scalar unary ops for lowering ONNXCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value emitScalarOpFor<ONNXCastOp>(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Operation *op, Type elementType,
|
||||
ArrayRef<Value> scalarOperands) {
|
||||
ONNXCastOp castOp = llvm::dyn_cast<ONNXCastOp>(op);
|
||||
auto mlirtype = convertONNXTypeToMLIRType(rewriter,
|
||||
static_cast<onnx::TensorProto_DataType>(castOp.toAttr().getInt()));
|
||||
Value operand = scalarOperands[0];
|
||||
auto origtype = operand.getType();
|
||||
|
||||
// check output type is the same as expected output type
|
||||
if (elementType != mlirtype)
|
||||
llvm_unreachable("output type different from expected output type");
|
||||
|
||||
// if same input and output type, return input
|
||||
if (origtype == elementType)
|
||||
return operand;
|
||||
|
||||
if (origtype.isa<FloatType>()) {
|
||||
// cast from floating-point type to integer type
|
||||
if (elementType.isa<IntegerType>())
|
||||
return rewriter.create<FPToSIOp>(loc, elementType, operand);
|
||||
// cast from floating-point type to other floating-point type
|
||||
else if (elementType.isa<FloatType>()) {
|
||||
// cast from floating-point to wider floating-point
|
||||
if (origtype.getIntOrFloatBitWidth() <
|
||||
elementType.getIntOrFloatBitWidth())
|
||||
return rewriter.create<FPExtOp>(loc, elementType, operand);
|
||||
// cast from floating-point to narrower floating-point
|
||||
else
|
||||
return rewriter.create<FPTruncOp>(loc, elementType, operand);
|
||||
}
|
||||
}
|
||||
// int to float
|
||||
else if (origtype.isa<IntegerType>()) {
|
||||
if (elementType.isa<FloatType>())
|
||||
return rewriter.create<SIToFPOp>(loc, elementType, operand);
|
||||
}
|
||||
llvm_unreachable("unsupported element type");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Scalar unary ops for lowering ONNXSinhOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -665,5 +709,6 @@ void populateLoweringONNXElementwiseOpPattern(
|
|||
ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
|
||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
|
||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
|
||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXCastOp>,
|
||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx);
|
||||
}
|
||||
|
|
|
@ -2018,3 +2018,88 @@ func @test_split_unknown_dimension(%arg0 : tensor<?x?x64xf32>) -> (tensor<*xf32>
|
|||
// CHECK: }
|
||||
// CHECK: return [[RES_0]], [[RES_1]] : memref<?x2x64xf32>, memref<?x30x64xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cast_lowering_sametype(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f32>) -> tensor<f32>
|
||||
"std.return"(%0) : (tensor<f32>) -> ()
|
||||
|
||||
// CHECK-LABEL: cast_lowering_sametype
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f32>
|
||||
// CHECK: affine.store [[LOAD]], [[RES]][] : memref<f32>
|
||||
// CHECK: return [[RES]] : memref<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cast_lowering_intfloat(%arg0: tensor<i64>) -> tensor<f32> {
|
||||
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<i64>) -> tensor<f32>
|
||||
"std.return"(%0) : (tensor<f32>) -> ()
|
||||
|
||||
// CHECK-LABEL: cast_lowering_intfloat
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<i64>
|
||||
// CHECK: [[VAL:%.+]] = sitofp [[LOAD]] : i64 to f32
|
||||
// CHECK: affine.store [[VAL]], [[RES]][] : memref<f32>
|
||||
// CHECK: return [[RES]] : memref<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cast_lowering_floatint(%arg0: tensor<f32>) -> tensor<i64> {
|
||||
%0 = "onnx.Cast"(%arg0) {to = 7 : i64} : (tensor<f32>) -> tensor<i64>
|
||||
"std.return"(%0) : (tensor<i64>) -> ()
|
||||
|
||||
// CHECK-LABEL: cast_lowering_floatint
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<i64>
|
||||
// CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f32>
|
||||
// CHECK: [[VAL:%.+]] = fptosi [[LOAD]] : f32 to i64
|
||||
// CHECK: affine.store [[VAL]], [[RES]][] : memref<i64>
|
||||
// CHECK: return [[RES]] : memref<i64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cast_lowering_f16f32(%arg0: tensor<f16>) -> tensor<f32> {
|
||||
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f16>) -> tensor<f32>
|
||||
"std.return"(%0) : (tensor<f32>) -> ()
|
||||
|
||||
// CHECK-LABEL: cast_lowering_f16f32
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f16>
|
||||
// CHECK: [[VAL:%.+]] = fpext [[LOAD]] : f16 to f32
|
||||
// CHECK: affine.store [[VAL]], [[RES]][] : memref<f32>
|
||||
// CHECK: return [[RES]] : memref<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cast_lowering_f64f32(%arg0: tensor<f64>) -> tensor<f32> {
|
||||
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f64>) -> tensor<f32>
|
||||
"std.return"(%0) : (tensor<f32>) -> ()
|
||||
|
||||
// CHECK-LABEL: cast_lowering_f64f32
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f64>
|
||||
// CHECK: [[VAL:%.+]] = fptrunc [[LOAD]] : f64 to f32
|
||||
// CHECK: affine.store [[VAL]], [[RES]][] : memref<f32>
|
||||
// CHECK: return [[RES]] : memref<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cast_lowering_f64f32_10(%arg0: tensor<10xf64>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<10xf64>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: cast_lowering_f64f32_10
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<10xf32>
|
||||
// CHECK: [[DEF_LOOPS:%.+]] = krnl.define_loops 1
|
||||
// CHECK: krnl.iterate([[DEF_LOOPS]]) with ([[DEF_LOOPS]] -> %arg1 = 0 to 10) {
|
||||
// CHECK: [[LOAD1:%.+]] = affine.load %arg0[%arg1] : memref<10xf64>
|
||||
// CHECK: [[FPTRUNC:%.+]] = fptrunc [[LOAD1]] : f64 to f32
|
||||
// CHECK: affine.store [[FPTRUNC]], [[RES]][%arg1] : memref<10xf32>
|
||||
// CHECK: return [[RES]] : memref<10xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue