ScalerOp support non-float input using CastOp (#231)

* move scalerop to decompose

* change clang format

* change clang format

* add shape inference for scaler op

* fixing generated onnxop

* generate onnx.md

* redo get onnx.md and onnxop.td.inc using onnx 1.6

* Add shape inference for scaler op

* add benefit for scaler decompose and simplify scaler shape inference

* add scaler decompose benefit num and simplify shape inference

* add cast builder

Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
Anh Leu 2020-07-24 09:57:52 -05:00 committed by GitHub
parent 2e8f012195
commit e631283c71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 134 additions and 49 deletions

View File

@ -494,6 +494,14 @@ def ONNXCastOp:ONNX_Op<"Cast",
return resultTypes; return resultTypes;
} }
}]; }];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value input, IntegerAttr to", [{
auto toAttr = to.getValue().getSExtValue();
auto resultType = mlir::UnrankedTensorType::get(
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr)));
build(builder, state, resultType, input, to);
}] >
];
} }
def ONNXCeilOp:ONNX_Op<"Ceil", def ONNXCeilOp:ONNX_Op<"Ceil",

View File

@ -63,10 +63,6 @@ def AttributeIsNull :
Constraint<CPred<"! ($_self)">, Constraint<CPred<"! ($_self)">,
"Attribute is null">; "Attribute is null">;
def AttributeNotNull :
Constraint<CPred<" ($_self)">,
"Attribute exists">;
def HasFloatType : Constraint<CPred<"(($_self).getType().dyn_cast<ShapedType>().getElementType().isF32())">>; def HasFloatType : Constraint<CPred<"(($_self).getType().dyn_cast<ShapedType>().getElementType().isF32())">>;
def GetNullAttr : def GetNullAttr :
@ -83,20 +79,30 @@ def ScalerNullPattern : Pat<
(ONNXScalerOp $x, $a, $b), (ONNXScalerOp $x, $a, $b),
(replaceWithValue $x), (replaceWithValue $x),
[(HasFloatType:$x), (AttributeIsNull:$a), (AttributeIsNull:$b)], [(HasFloatType:$x), (AttributeIsNull:$a), (AttributeIsNull:$b)],
(addBenefit 4)>; (addBenefit 5)>;
// No attribute, input x not float type // No attribute, input x not float type
def ScalerNullPattern2 : Pat< def ScalerNullPattern2 : Pat<
(ONNXScalerOp $x, $a, $b), (ONNXScalerOp $x, $a, $b),
(ONNXCastOp $x, (ScalerT)), (ONNXCastOp $x, (ScalerT)),
[(AttributeIsNull:$a), (AttributeIsNull:$b)], [(AttributeIsNull:$a), (AttributeIsNull:$b)],
(addBenefit 3)>; (addBenefit 4)>;
// No scale // No scale
def ScalerNoScalePattern : Pat< def ScalerNoScalePattern : Pat<
(ONNXScalerOp $x, $a, $b), (ONNXScalerOp $x, $a, $b),
(ONNXSubOp $x, (ONNXSubOp $x,
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))), (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
[(HasFloatType:$x), (AttributeIsNull:$b)],
(addBenefit 3)>;
// No scale, input x not float type
def ScalerNoScalePattern2 : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXSubOp
(ONNXCastOp $x, (ScalerT)),
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))
),
[(AttributeIsNull:$b)], [(AttributeIsNull:$b)],
(addBenefit 2)>; (addBenefit 2)>;
@ -105,6 +111,15 @@ def ScalerNoOffsetPattern : Pat<
(ONNXScalerOp $x, $a, $b), (ONNXScalerOp $x, $a, $b),
(ONNXMulOp $x, (ONNXMulOp $x,
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))), (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
[(HasFloatType:$x), (AttributeIsNull:$a)],
(addBenefit 3)>;
// No offset, input x not float type
def ScalerNoOffsetPattern2 : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXMulOp
(ONNXCastOp $x, (ScalerT)),
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
[(AttributeIsNull:$a)], [(AttributeIsNull:$a)],
(addBenefit 2)>; (addBenefit 2)>;
@ -115,7 +130,17 @@ def ScalerPattern : Pat<
(ONNXSubOp $x, (ONNXSubOp $x,
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))), (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))), (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
[], [(HasFloatType:$x)],
(addBenefit 1)>; (addBenefit 1)>;
// Normal ONNXScalerOp, input x not float type
def ScalerPattern2 : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXMulOp
(ONNXSubOp (ONNXCastOp $x, (ScalerT)),
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
[],
(addBenefit 0)>;
#endif // ONNX_DECOMPOSE #endif // ONNX_DECOMPOSE

View File

@ -97,6 +97,20 @@ func @test_scaler_no_offset(%arg0: tensor<3xf32>) -> tensor<3xf32> {
// ----- // -----
// scaler no offset, int input
// CHECK-LABEL: func @test_scaler_no_offset2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
func @test_scaler_no_offset2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
// CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %2 = "onnx.Mul"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %2 : tensor<3xf32>
}
// -----
// scaler no scale // scaler no scale
// CHECK-LABEL: func @test_scaler_no_scale(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> { // CHECK-LABEL: func @test_scaler_no_scale(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> { func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
@ -110,6 +124,20 @@ func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
// ----- // -----
// scaler no scale, int input
// CHECK-LABEL: func @test_scaler_no_scale2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
func @test_scaler_no_scale2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
//CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
//CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
//CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
//CHECK-NEXT: return %2 : tensor<3xf32>
}
// -----
// normal scaler // normal scaler
// CHECK-LABEL: func @test_scaler_normal(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> { // CHECK-LABEL: func @test_scaler_normal(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> { func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
@ -125,6 +153,22 @@ func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
// ----- // -----
// normal scaler, int input
// CHECK-LABEL: func @test_scaler_normal2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
func @test_scaler_normal2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
// CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32>
// CHECK-NEXT: %3 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %4 = "onnx.Mul"(%2, %3) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %4 : tensor<3xf32>
}
// -----
// normal scaler with constant offset and scale // normal scaler with constant offset and scale
// CHECK-LABEL: func @test_scaler_constant(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> { // CHECK-LABEL: func @test_scaler_constant(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> {
func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> { func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> {

View File

@ -323,10 +323,18 @@ custom_definition_misc = dict([ ('Constant',
build(builder, state, tensorType, sparse_value, value); build(builder, state, tensorType, sparse_value, value);
} }
}]> }]>
];'''),
('Cast',
''' let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value input, IntegerAttr to", [{
auto toAttr = to.getValue().getSExtValue();
auto resultType = mlir::UnrankedTensorType::get(
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr)));
build(builder, state, resultType, input, to);
}] >
];''' ];'''
)]) )])
onnx_types = ( onnx_types = (
'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16', 'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
'float', 'double', 'complex64', 'complex128', 'string' 'float', 'double', 'complex64', 'complex128', 'string'