diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index cd484ec..9f5386f 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -494,6 +494,14 @@ def ONNXCastOp:ONNX_Op<"Cast",
       return resultTypes;
     }
   }];
+  let builders = [
+  OpBuilder<"OpBuilder &builder, OperationState &state, Value input, IntegerAttr to", [{
+   auto toAttr = to.getValue().getSExtValue();
+   auto resultType = mlir::UnrankedTensorType::get(
+     convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr)));
+   build(builder, state, resultType, input, to);
+  }] >
+  ];
 }
 
 def ONNXCeilOp:ONNX_Op<"Ceil",
@@ -648,17 +656,17 @@ def ONNXConstantOp:ONNX_Op<"Constant",
       return resultTypes;
     }
   }];
-  let builders = [
-  OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{
-   if (value) {
-    auto tensorType = value.getType();
-    build(builder, state, tensorType, sparse_value, value);
-   } else {
-    auto tensorType = sparse_value.getType();
-    build(builder, state, tensorType, sparse_value, value);
-   }
-  }]>
-  ];
+  let builders = [
+  OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{
+   if (value) {
+    auto tensorType = value.getType();
+    build(builder, state, tensorType, sparse_value, value);
+   } else {
+    auto tensorType = sparse_value.getType();
+    build(builder, state, tensorType, sparse_value, value);
+   }
+  }]>
+  ];
 }
 
 def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
diff --git a/src/Transform/ONNX/Decompose.td b/src/Transform/ONNX/Decompose.td
index 730000d..4ca6023 100644
--- a/src/Transform/ONNX/Decompose.td
+++ b/src/Transform/ONNX/Decompose.td
@@ -63,10 +63,6 @@
 def AttributeIsNull :
     Constraint<CPred<"! ($_self)">, "Attribute is null">;
 
-def AttributeNotNull :
-    Constraint<CPred<" ($_self)">,
-      "Attribute exists">;
-
 def HasFloatType : Constraint<CPred<"($_self.getType().cast<TensorType>().getElementType().isF32())">>;
 
 def GetNullAttr :
@@ -83,20 +79,30 @@
 def ScalerNullPattern : Pat<
   (ONNXScalerOp $x, $a, $b),
   (replaceWithValue $x),
   [(HasFloatType:$x), (AttributeIsNull:$a), (AttributeIsNull:$b)],
-  (addBenefit 4)>;
+  (addBenefit 5)>;
 
 // No attribute, input x not float type
 def ScalerNullPattern2 : Pat<
   (ONNXScalerOp $x, $a, $b),
   (ONNXCastOp $x, (ScalerT)),
   [(AttributeIsNull:$a), (AttributeIsNull:$b)],
-  (addBenefit 3)>;
+  (addBenefit 4)>;
 
 // No scale
 def ScalerNoScalePattern : Pat<
   (ONNXScalerOp $x, $a, $b),
   (ONNXSubOp $x, (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
+  [(HasFloatType:$x), (AttributeIsNull:$b)],
+  (addBenefit 3)>;
+
+// No scale, input x not float type
+def ScalerNoScalePattern2 : Pat<
+  (ONNXScalerOp $x, $a, $b),
+  (ONNXSubOp
+    (ONNXCastOp $x, (ScalerT)),
+    (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))
+  ),
   [(AttributeIsNull:$b)],
   (addBenefit 2)>;
 
@@ -105,6 +111,15 @@
 def ScalerNoOffsetPattern : Pat<
   (ONNXScalerOp $x, $a, $b),
   (ONNXMulOp $x, (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
+  [(HasFloatType:$x), (AttributeIsNull:$a)],
+  (addBenefit 3)>;
+
+// No offset, input x not float type
+def ScalerNoOffsetPattern2 : Pat<
+  (ONNXScalerOp $x, $a, $b),
+  (ONNXMulOp
+    (ONNXCastOp $x, (ScalerT)),
+    (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
   [(AttributeIsNull:$a)],
   (addBenefit 2)>;
 
@@ -115,7 +130,17 @@ def ScalerPattern : Pat<
   (ONNXScalerOp $x, $a, $b),
   (ONNXMulOp
     (ONNXSubOp $x, (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
     (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
-  [],
+  [(HasFloatType:$x)],
   (addBenefit 1)>;
+// Normal ONNXScalerOp, input x not float type
+def ScalerPattern2 : Pat<
+  (ONNXScalerOp $x, $a, $b),
+  (ONNXMulOp
+    (ONNXSubOp (ONNXCastOp $x, (ScalerT)),
+      (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $a))),
+    (ONNXConstantOp (GetNullAttr), (createDenseArrayAttr $b))),
+  [],
+  (addBenefit 0)>;
+
 #endif // ONNX_DECOMPOSE
diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir
index d792af0..a81bf94 100644
--- a/test/mlir/onnx/onnx_decompose.mlir
+++ b/test/mlir/onnx/onnx_decompose.mlir
@@ -78,8 +78,8 @@ func @test_scaler_null(%arg0: tensor<3xi32>) -> tensor<3xf32> {
   %0 = "onnx.Scaler"(%arg0) : (tensor<3xi32>) -> tensor<3xf32>
   return %0 : tensor<3xf32>
 
-  // CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<3xf32>
-  // CHECK-NEXT: return %0 : tensor<3xf32>
+  // CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %0 : tensor<3xf32>
 }
 
 // -----
@@ -90,9 +90,23 @@ func @test_scaler_no_offset(%arg0: tensor<3xf32>) -> tensor<3xf32> {
   %0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
   return %0 : tensor<3xf32>
 
-  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
-  // CHECK-NEXT: %1 = "onnx.Mul"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
-  // CHECK-NEXT: return %1 : tensor<3xf32>
+  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK-NEXT: %1 = "onnx.Mul"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %1 : tensor<3xf32>
+}
+
+// -----
+
+// scaler no offset, int input
+// CHECK-LABEL: func @test_scaler_no_offset2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
+func @test_scaler_no_offset2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
+  %0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+
+  // CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
+  // CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK-NEXT: %2 = "onnx.Mul"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %2 : tensor<3xf32>
 }
 
 // -----
@@ -103,9 +117,23 @@ func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> {
   %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
   return %0 : tensor<3xf32>
 
-  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
-  // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
-  // CHECK-NEXT: return %1 : tensor<3xf32>
+  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %1 : tensor<3xf32>
+}
+
+// -----
+
+// scaler no scale, int input
+// CHECK-LABEL: func @test_scaler_no_scale2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
+func @test_scaler_no_scale2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
+  %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+
+  //CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
+  //CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
+  //CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
+  //CHECK-NEXT: return %2 : tensor<3xf32>
 }
 
 // -----
@@ -116,11 +144,27 @@ func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> {
   %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
   return %0 : tensor<3xf32>
 
-  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
-  // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
-  // CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
-  // CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
-  // CHECK-NEXT: return %3 : tensor<3xf32>
+  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %3 : tensor<3xf32>
+}
+
+// -----
+
+// normal scaler, int input
+// CHECK-LABEL: func @test_scaler_normal2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> {
+func @test_scaler_normal2(%arg0: tensor<3xi32>) -> tensor<3xf32> {
+  %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xi32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+
+  // CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<3xi32>) -> tensor<*xf32>
+  // CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: %3 = "onnx.Constant"() {value = dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32>} : () -> tensor<3xf32>
+  // CHECK-NEXT: %4 = "onnx.Mul"(%2, %3) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %4 : tensor<3xf32>
 }
 
 // -----
@@ -131,9 +175,9 @@ func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> {
   %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32], scale = [3.125000e-02 : f32]} : (tensor<3xf32>) -> tensor<3xf32>
   return %0 : tensor<3xf32>
 
-  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1986.99939> : tensor<1xf32>} : () -> tensor<1xf32>
-  // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
-  // CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<3.125000e-02> : tensor<1xf32>} : () -> tensor<1xf32>
-  // CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
-  // CHECK-NEXT: return %3 : tensor<3xf32>
+  // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1986.99939> : tensor<1xf32>} : () -> tensor<1xf32>
+  // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: %2 = "onnx.Constant"() {value = dense<3.125000e-02> : tensor<1xf32>} : () -> tensor<1xf32>
+  // CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
+  // CHECK-NEXT: return %3 : tensor<3xf32>
 }
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index 264a324..83258db 100644
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -313,19 +313,27 @@ custom_builder_ops_list = custom_builder_unranked_ops_list + custom_builder_broa
 #a dictionary to add any special definition for an operation
 custom_definition_misc = dict([
         ('Constant',
-        ''' let builders = [
-        OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{
-         if (value) {
-          auto tensorType = value.getType();
-          build(builder, state, tensorType, sparse_value, value);
-         } else {
-          auto tensorType = sparse_value.getType();
-          build(builder, state, tensorType, sparse_value, value);
-         }
-        }]>
-        ];'''
-        )])
-
+        ''' let builders = [
+        OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{
+         if (value) {
+          auto tensorType = value.getType();
+          build(builder, state, tensorType, sparse_value, value);
+         } else {
+          auto tensorType = sparse_value.getType();
+          build(builder, state, tensorType, sparse_value, value);
+         }
+        }]>
+        ];'''),
+        ('Cast',
+        ''' let builders = [
+        OpBuilder<"OpBuilder &builder, OperationState &state, Value input, IntegerAttr to", [{
+         auto toAttr = to.getValue().getSExtValue();
+         auto resultType = mlir::UnrankedTensorType::get(
+           convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr)));
+         build(builder, state, resultType, input, to);
+        }] >
+        ];'''
+        )])
 onnx_types = (
     'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
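
For reference: the ONNXCastOp builder added above lets callers create "onnx.Cast" without spelling out a result type; the builder derives an unranked tensor type from the `to` attribute via convertONNXTypeToMLIRType, which is why the new CHECK lines expect tensor<*xf32> results. A minimal C++ usage sketch, assuming a rewrite-pattern context where `rewriter`, `loc`, and `input` are already in scope (these names are illustrative and not taken from this patch):

    // Cast `input` to f32. The ONNX TensorProto data-type code for FLOAT is 1,
    // matching the `{to = 1 : i64}` attribute checked in the tests above.
    // No result type is passed; the custom builder computes tensor<*xf32> itself.
    mlir::IntegerAttr toF32 =
        rewriter.getIntegerAttr(rewriter.getIntegerType(64), /*FLOAT=*/1);
    mlir::Value casted =
        rewriter.create<mlir::ONNXCastOp>(loc, input, toF32).getResult();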