Add CastOp lowering (#259)

* move scalerop to decompose

* change clang format

* change clang format

* add shape inference for scaler op

* fixing generated onnxop

* generate onnx.md

* add benefit for scaler decompose and simplify scaler shape inference

* cast rewrite only for float

* add cast op same type rewrite rule

* working on cast lowering

* cast lowering working

* correct onnx version

* update onnx md

* add test for tensor<10xf64>
This commit is contained in:
Anh Leu 2020-08-11 15:07:13 -05:00 committed by GitHub
parent e1386b0689
commit 2ee725d939
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 139 additions and 0 deletions

View File

@ -57,6 +57,15 @@ static onnx::TensorProto::DataType llvmTypeToOnnxType(
return onnx::TensorProto::UINT32; return onnx::TensorProto::UINT32;
if (elemType.isUnsignedInteger(64)) if (elemType.isUnsignedInteger(64))
return onnx::TensorProto::INT64; return onnx::TensorProto::INT64;
// LLVM Dialect does not have signed/unsigned int, only signless int
if (elemType.isIntegerTy(8))
return onnx::TensorProto::INT8;
if (elemType.isIntegerTy(16))
return onnx::TensorProto::INT16;
if (elemType.isIntegerTy(32))
return onnx::TensorProto::INT32;
if (elemType.isIntegerTy(64))
return onnx::TensorProto::INT64;
// Complex types don't seem to exist in LLVM Dialect. // Complex types don't seem to exist in LLVM Dialect.
elemType.dump(); elemType.dump();
llvm_unreachable("Unexpected LLVM type, cannot be converted to ONNX type."); llvm_unreachable("Unexpected LLVM type, cannot be converted to ONNX type.");

View File

@ -84,6 +84,50 @@ struct ScalarOp<ONNXSqrtOp> {
using IOp = SqrtOp; // not use using IOp = SqrtOp; // not use
}; };
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXCastOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXCastOp>(ConversionPatternRewriter &rewriter,
Location loc, Operation *op, Type elementType,
ArrayRef<Value> scalarOperands) {
ONNXCastOp castOp = llvm::dyn_cast<ONNXCastOp>(op);
auto mlirtype = convertONNXTypeToMLIRType(rewriter,
static_cast<onnx::TensorProto_DataType>(castOp.toAttr().getInt()));
Value operand = scalarOperands[0];
auto origtype = operand.getType();
// check output type is the same as expected output type
if (elementType != mlirtype)
llvm_unreachable("output type different from expected output type");
// if same input and output type, return input
if (origtype == elementType)
return operand;
if (origtype.isa<FloatType>()) {
// cast from floating-point type to integer type
if (elementType.isa<IntegerType>())
return rewriter.create<FPToSIOp>(loc, elementType, operand);
// cast from floating-point type to other floating-point type
else if (elementType.isa<FloatType>()) {
// cast from floating-point to wider floating-point
if (origtype.getIntOrFloatBitWidth() <
elementType.getIntOrFloatBitWidth())
return rewriter.create<FPExtOp>(loc, elementType, operand);
// cast from floating-point to narrower floating-point
else
return rewriter.create<FPTruncOp>(loc, elementType, operand);
}
}
// int to float
else if (origtype.isa<IntegerType>()) {
if (elementType.isa<FloatType>())
return rewriter.create<SIToFPOp>(loc, elementType, operand);
}
llvm_unreachable("unsupported element type");
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSinhOp // Scalar unary ops for lowering ONNXSinhOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -665,5 +709,6 @@ void populateLoweringONNXElementwiseOpPattern(
ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXCastOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx); ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx);
} }

View File

@ -2018,3 +2018,88 @@ func @test_split_unknown_dimension(%arg0 : tensor<?x?x64xf32>) -> (tensor<*xf32>
// CHECK: } // CHECK: }
// CHECK: return [[RES_0]], [[RES_1]] : memref<?x2x64xf32>, memref<?x30x64xf32> // CHECK: return [[RES_0]], [[RES_1]] : memref<?x2x64xf32>, memref<?x30x64xf32>
} }
// -----
func @cast_lowering_sametype(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f32>) -> tensor<f32>
"std.return"(%0) : (tensor<f32>) -> ()
// CHECK-LABEL: cast_lowering_sametype
// CHECK: [[RES:%.+]] = alloc() : memref<f32>
// CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f32>
// CHECK: affine.store [[LOAD]], [[RES]][] : memref<f32>
// CHECK: return [[RES]] : memref<f32>
}
// -----
func @cast_lowering_intfloat(%arg0: tensor<i64>) -> tensor<f32> {
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<i64>) -> tensor<f32>
"std.return"(%0) : (tensor<f32>) -> ()
// CHECK-LABEL: cast_lowering_intfloat
// CHECK: [[RES:%.+]] = alloc() : memref<f32>
// CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<i64>
// CHECK: [[VAL:%.+]] = sitofp [[LOAD]] : i64 to f32
// CHECK: affine.store [[VAL]], [[RES]][] : memref<f32>
// CHECK: return [[RES]] : memref<f32>
}
// -----
func @cast_lowering_floatint(%arg0: tensor<f32>) -> tensor<i64> {
%0 = "onnx.Cast"(%arg0) {to = 7 : i64} : (tensor<f32>) -> tensor<i64>
"std.return"(%0) : (tensor<i64>) -> ()
// CHECK-LABEL: cast_lowering_floatint
// CHECK: [[RES:%.+]] = alloc() : memref<i64>
// CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f32>
// CHECK: [[VAL:%.+]] = fptosi [[LOAD]] : f32 to i64
// CHECK: affine.store [[VAL]], [[RES]][] : memref<i64>
// CHECK: return [[RES]] : memref<i64>
}
// -----
func @cast_lowering_f16f32(%arg0: tensor<f16>) -> tensor<f32> {
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f16>) -> tensor<f32>
"std.return"(%0) : (tensor<f32>) -> ()
// CHECK-LABEL: cast_lowering_f16f32
// CHECK: [[RES:%.+]] = alloc() : memref<f32>
// CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f16>
// CHECK: [[VAL:%.+]] = fpext [[LOAD]] : f16 to f32
// CHECK: affine.store [[VAL]], [[RES]][] : memref<f32>
// CHECK: return [[RES]] : memref<f32>
}
// -----
func @cast_lowering_f64f32(%arg0: tensor<f64>) -> tensor<f32> {
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<f64>) -> tensor<f32>
"std.return"(%0) : (tensor<f32>) -> ()
// CHECK-LABEL: cast_lowering_f64f32
// CHECK: [[RES:%.+]] = alloc() : memref<f32>
// CHECK: [[LOAD:%.+]] = affine.load %arg0[] : memref<f64>
// CHECK: [[VAL:%.+]] = fptrunc [[LOAD]] : f64 to f32
// CHECK: affine.store [[VAL]], [[RES]][] : memref<f32>
// CHECK: return [[RES]] : memref<f32>
}
// -----
func @cast_lowering_f64f32_10(%arg0: tensor<10xf64>) -> tensor<*xf32> {
%0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<10xf64>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: cast_lowering_f64f32_10
// CHECK: [[RES:%.+]] = alloc() : memref<10xf32>
// CHECK: [[DEF_LOOPS:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[DEF_LOOPS]]) with ([[DEF_LOOPS]] -> %arg1 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = affine.load %arg0[%arg1] : memref<10xf64>
// CHECK: [[FPTRUNC:%.+]] = fptrunc [[LOAD1]] : f64 to f32
// CHECK: affine.store [[FPTRUNC]], [[RES]][%arg1] : memref<10xf32>
// CHECK: return [[RES]] : memref<10xf32>
}