ScalerOp support non-float input using CastOp (#231)

* move scalerop to decompose

* change clang format

* change clang format

* add shape inference for scaler op

* fixing generated onnxop

* generate onnx.md

* redo get onnx.md and onnxop.td.inc using onnx 1.6

* Add shape inference for scaler op

* add benefit for scaler decompose and simplify scaler shape inference

* add scaler decompose benefit num and simplify shape inference

* add cast builder

Co-authored-by: chentong319 <chentong@us.ibm.com>
This commit is contained in:
Anh Leu 2020-07-24 09:57:52 -05:00 committed by GitHub
parent 2e8f012195
commit e631283c71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 134 additions and 49 deletions

View File

@ -494,6 +494,14 @@ def ONNXCastOp:ONNX_Op<"Cast",
return resultTypes;
}
}];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value input, IntegerAttr to", [{
auto toAttr = to.getValue().getSExtValue();
auto resultType = mlir::UnrankedTensorType::get(
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr)));
build(builder, state, resultType, input, to);
}] >
];
}
def ONNXCeilOp:ONNX_Op<"Ceil",
@ -648,17 +656,17 @@ def ONNXConstantOp:ONNX_Op<"Constant",
return resultTypes;
}
}];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{
if (value) {
auto tensorType = value.getType();
build(builder, state, tensorType, sparse_value, value);
} else {
auto tensorType = sparse_value.getType();
build(builder, state, tensorType, sparse_value, value);
}
}]>
];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{
if (value) {
auto tensorType = value.getType();
build(builder, state, tensorType, sparse_value, value);
} else {
auto tensorType = sparse_value.getType();
build(builder, state, tensorType, sparse_value, value);
}
}]>
];
}
def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",

View File

@ -63,10 +63,6 @@ def AttributeIsNull :
Constraint<CPred<"! ($_self)">,
"Attribute is null">;
def AttributeNotNull :
Constraint<CPred<" ($_self)">,
"Attribute exists">;
def HasFloatType : Constraint<CPred<"(($_self).getType().dyn_cast<ShapedType>().getElementType().isF32())">>;
def GetNullAttr :
@ -83,20 +79,30 @@ def ScalerNullPattern : Pat<
(ONNXScalerOp $x, $a, $b),
(replaceWithValue $x),
[(HasFloatType:$x), (AttributeIsNull:$a), (AttributeIsNull:$b)],
(addBenefit 4)>;
(addBenefit 5)>;
// No attribute, input x not float type
def ScalerNullPattern2 : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXCastOp $x, (ScalerT)),
[(AttributeIsNull:$a), (AttributeIsNull:$b)],
(addBenefit 3)>;
(addBenefit 4)>;
// No scale
def ScalerNoScalePattern : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXSubOp $x,
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
[(HasFloatType:$x), (AttributeIsNull:$b)],
(addBenefit 3)>;
// No scale, input x not float type
def ScalerNoScalePattern2 : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXSubOp
(ONNXCastOp $x, (ScalerT)),
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))
),
[(AttributeIsNull:$b)],
(addBenefit 2)>;
@ -105,6 +111,15 @@ def ScalerNoOffsetPattern : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXMulOp $x,
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
[(HasFloatType:$x), (AttributeIsNull:$a)],
(addBenefit 3)>;
// No offset, input x not float type
def ScalerNoOffsetPattern2 : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXMulOp
(ONNXCastOp $x, (ScalerT)),
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
[(AttributeIsNull:$a)],
(addBenefit 2)>;
@ -115,7 +130,17 @@ def ScalerPattern : Pat<
(ONNXSubOp $x,
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
[],
[(HasFloatType:$x)],
(addBenefit 1)>;
// Normal ONNXScalerOp, input x not float type
def ScalerPattern2 : Pat<
(ONNXScalerOp $x, $a, $b),
(ONNXMulOp
(ONNXSubOp (ONNXCastOp $x, (ScalerT)),
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
(ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
[],
(addBenefit 0)>;
#endif // ONNX_DECOMPOSE

View File

@ -78,8 +78,8 @@ func @test_scaler_null(%arg0: tensor<3xi32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) : (tensor<3xi32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<3xf32>
// CHECK-NEXT: return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<3xf32>
// CHECK-NEXT: return %0 : tensor<3xf32>
}
// -----
@ -90,9 +90,23 @@ func @test_scaler_no_offset(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %1 = "onnx.Mul"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %1 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %1 = "onnx.Mul"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %1 : tensor<3xf32>
}
// -----
// scaler no offset, int input
// CHECK-LABEL: func @test_scaler_no_offset2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
func @test_scaler_no_offset2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
// CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %2 = "onnx.Mul"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %2 : tensor<3xf32>
}
// -----
@ -103,9 +117,23 @@ func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %1 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %1 : tensor<3xf32>
}
// -----
// scaler no scale, int input
// CHECK-LABEL: func @test_scaler_no_scale2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
func @test_scaler_no_scale2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
//CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
//CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
//CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
//CHECK-NEXT: return %2 : tensor<3xf32>
}
// -----
@ -116,11 +144,27 @@ func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %3 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %3 : tensor<3xf32>
}
// -----
// normal scaler, int input
// CHECK-LABEL: func @test_scaler_normal2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
func @test_scaler_normal2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
// CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32>
// CHECK-NEXT: %3 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: %4 = "onnx.Mul"(%2, %3) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %4 : tensor<3xf32>
}
// -----
@ -131,9 +175,9 @@ func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32], scale = [3.125000e-02 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1986.99939> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
// CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<3.125000e-02> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %3 : tensor<3xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1986.99939> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
// CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<3.125000e-02> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
// CHECK-NEXT: return %3 : tensor<3xf32>
}

View File

@ -313,19 +313,27 @@ custom_builder_ops_list = custom_builder_unranked_ops_list + custom_builder_broa
#a dictionary to add any special definition for an operation
custom_definition_misc = dict([ ('Constant',
''' let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{
if (value) {
auto tensorType = value.getType();
build(builder, state, tensorType, sparse_value, value);
} else {
auto tensorType = sparse_value.getType();
build(builder, state, tensorType, sparse_value, value);
}
}]>
];'''
)])
''' let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{
if (value) {
auto tensorType = value.getType();
build(builder, state, tensorType, sparse_value, value);
} else {
auto tensorType = sparse_value.getType();
build(builder, state, tensorType, sparse_value, value);
}
}]>
];'''),
('Cast',
''' let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value input, IntegerAttr to", [{
auto toAttr = to.getValue().getSExtValue();
auto resultType = mlir::UnrankedTensorType::get(
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr)));
build(builder, state, resultType, input, to);
}] >
];'''
)])
onnx_types = (
'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',