From 742e817722165dc3eef92bdcd16f61583e752561 Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Mon, 15 Jun 2020 14:56:15 -0400 Subject: [PATCH] Constprop2 (#167) * initial const prop attempt * added support for broadcast ops * added all binary broadcast ops into custom builders with precise type * added test example * working * format * fixed suggestion by Tung, start working on unary * added subtraction and neg the right way, and added elementwise mul too * formatting changes * format * format * added instructions to add new optimizations * added propagation rules that always migrate constants toward the root of the expression, using assoc and commutativity * format comment --- src/Transform/ONNX/ConstProp.td | 72 +++++++++++++++++++++++++++- test/mlir/onnx/onnx_constprop.mlir | 77 ++++++++++++++++++++++-------- 2 files changed, 128 insertions(+), 21 deletions(-) diff --git a/src/Transform/ONNX/ConstProp.td b/src/Transform/ONNX/ConstProp.td index 649bc57..122b713 100644 --- a/src/Transform/ONNX/ConstProp.td +++ b/src/Transform/ONNX/ConstProp.td @@ -85,8 +85,42 @@ def AddConstAssociative1 : Pat< // To add(x, add(c1, c2)). (ONNXAddOp $x, - (ONNXAddOp $c1, $c2))>; + (ONNXAddOp $c1, $c2)), + [(IsNotAConstant:$x)]>; +def AddConstAssociative2 : Pat< + // From add(add(x, c), y). + (ONNXAddOp + (ONNXAddOp $x,(ONNXConstantOp:$c $_, $_)), + $y), + // To add(add(x, y), c). + (ONNXAddOp + (ONNXAddOp $x, $y), + $c), + [(IsNotAConstant:$x), (IsNotAConstant:$y)]>; + +def AddConstAssociative3 : Pat< + // From add(x, add(y, c)). + (ONNXAddOp + $x, + (ONNXAddOp $y,(ONNXConstantOp:$c $_, $_))), + // To add(add(x, y), c). + (ONNXAddOp + (ONNXAddOp $x, $y), + $c), + [(IsNotAConstant:$x), (IsNotAConstant:$y)]>; + +def AddConstAssociative4 : Pat< + // From add(add(x, c1), add(y, c2)). + (ONNXAddOp + (ONNXAddOp $x,(ONNXConstantOp:$c1 $_, $_)), + (ONNXAddOp $y,(ONNXConstantOp:$c2 $_, $_))), + // To add(add(x, y), c1+c2). 
+ (ONNXAddOp + (ONNXAddOp $x, $y), + (ONNXAddOp $c1, $c2)), + [(IsNotAConstant:$x), (IsNotAConstant:$y)]>; + // Constant Propagation for Add def AddConstProp : Pat< // From add(c1, c2). @@ -150,7 +184,41 @@ def MulConstAssociative1 : Pat< // To mul(x, mul(c1, c2)). (ONNXMulOp $x, - (ONNXMulOp $c1, $c2))>; + (ONNXMulOp $c1, $c2)), + [(IsNotAConstant:$x)]>; + +def MulConstAssociative2 : Pat< + // From mul(mul(x, c), y). + (ONNXMulOp + (ONNXMulOp $x,(ONNXConstantOp:$c $_, $_)), + $y), + // To mul(mul(x, y), c). + (ONNXMulOp + (ONNXMulOp $x, $y), + $c), + [(IsNotAConstant:$x), (IsNotAConstant:$y)]>; + +def MulConstAssociative3 : Pat< + // From mul(x, mul(y, c)). + (ONNXMulOp + $x, + (ONNXMulOp $y,(ONNXConstantOp:$c $_, $_))), + // To mul(mul(x, y), c). + (ONNXMulOp + (ONNXMulOp $x, $y), + $c), + [(IsNotAConstant:$x), (IsNotAConstant:$y)]>; + +def MulConstAssociative4 : Pat< + // From mul(mul(x, c1), mul(y, c2)). + (ONNXMulOp + (ONNXMulOp $x,(ONNXConstantOp:$c1 $_, $_)), + (ONNXMulOp $y,(ONNXConstantOp:$c2 $_, $_))), + // To mul(mul(x, y), c1*c2). + (ONNXMulOp + (ONNXMulOp $x, $y), + (ONNXMulOp $c1, $c2)), + [(IsNotAConstant:$x), (IsNotAConstant:$y)]>; // Constant Propagation for Mul def MulConstProp : Pat< diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir index 6bbfdb8..aba651f 100644 --- a/test/mlir/onnx/onnx_constprop.mlir +++ b/test/mlir/onnx/onnx_constprop.mlir @@ -1,10 +1,11 @@ // RUN: onnx-mlir-opt --constprop-onnx %s -split-input-file | FileCheck %s -// ============================================================================= -/// MUL tests (same as add, so have only one). 
+ +//===----------------------------------------------------------------------===// +/// ADD tests /// Test ConstantOp assoc for add - +// ----- // CHECK-LABEL: @test_add_constant_1(%arg0: tensor<3xf32>) -> tensor<3xf32> func @test_add_constant_1(%arg0 : tensor<3xf32>) -> tensor<3xf32> { %0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32> @@ -15,6 +16,7 @@ func @test_add_constant_1(%arg0 : tensor<3xf32>) -> tensor<3xf32> { } /// Test ConstantOp assoc for add +// ----- // CHECK-LABEL: @test_add_constant_2(%arg0: tensor<3xf32>) -> tensor<3xf32> func @test_add_constant_2(%arg0 : tensor<3xf32>) -> tensor<3xf32> { %0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32> @@ -25,6 +27,7 @@ func @test_add_constant_2(%arg0 : tensor<3xf32>) -> tensor<3xf32> { } /// Change (x+c1)+c2 to x+(c1+c2) +// ----- // CHECK-LABEL: @test_add_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32> func @test_add_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> { %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> @@ -37,7 +40,8 @@ func @test_add_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> { } /// Same test as above, but with a use of an intermediary result -/// change (x+c1)+c2 + (x+c1) to x+(c1+c2) + (x+c1) +/// change (x+c1)+c2 + (x+c1) to x+x + (c1+c2+c1) +// ----- // CHECK-LABEL: @test_add_constant_4(%arg0: tensor<3xi32>) -> tensor<3xi32> func @test_add_constant_4(%arg0 : tensor<3xi32>) -> tensor<3xi32> { %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> @@ -46,15 +50,30 @@ func @test_add_constant_4(%arg0 : tensor<3xi32>) -> tensor<3xi32> { %3 = "onnx.Add"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> %4 = "onnx.Add"(%2, %3) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> "std.return"(%4) : (tensor<3xi32>) -> () - // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> 
tensor<3xi32> - // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> - // CHECK-NEXT: [[CONST2:%.+]] = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32> - // CHECK-NEXT: [[ADD2:%.+]] = "onnx.Add"(%arg0, [[CONST2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> - // CHECK-NEXT: [[ADD3:%.+]] = "onnx.Add"([[ADD1]], [[ADD2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, %arg0) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[10, 13, 16]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK-NEXT: [[ADD2:%.+]] = "onnx.Add"([[ADD1]], [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> +} + +/// Change (x+c0)+y + (z+c1) to (x+y)+z + (c0+c1) +// ----- +// CHECK-LABEL: @test_add_constant_5(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<3xi32> +func @test_add_constant_5(%arg0 : tensor<3xi32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<3xi32> { + %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32> + %2 = "onnx.Add"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + %3 = "onnx.Add"(%2, %arg1) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + %4 = "onnx.Add"(%1, %arg2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + %5 = "onnx.Add"(%3, %4) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + "std.return"(%5) : (tensor<3xi32>) -> () + // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + // CHECK-NEXT: [[ADD2:%.+]] = "onnx.Add"([[ADD1]], %arg2) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32> + // 
CHECK-NEXT: [[ADD3:%.+]] = "onnx.Add"([[ADD2]], [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> } /// Test broadcast 1 -> 2d - +// ----- // CHECK-LABEL: @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> func @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { %0 = "onnx.Constant"() {value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32> @@ -67,7 +86,7 @@ func @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { } /// Test broadcast 2d (size one) -> 2d - +// ----- // CHECK-LABEL: @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> func @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { %0 = "onnx.Constant"() {value = dense<[[1]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32> @@ -80,8 +99,8 @@ func @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { } /// check 1d -> 2d - - // CHECK-LABEL: @test_broadcast_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> +// ----- +// CHECK-LABEL: @test_broadcast_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> func @test_broadcast_3(%arg0 : tensor<3x2xi32>) -> tensor<3x2xi32> { %0 = "onnx.Constant"() {value = dense<[[1], [2], [3]]> : tensor<3x1xi32>} : () -> tensor<3x1xi32> %1 = "onnx.Constant"() {value = dense<[[10, 11], [21, 22], [31, 32]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> @@ -92,10 +111,12 @@ func @test_broadcast_3(%arg0 : tensor<3x2xi32>) -> tensor<3x2xi32> { // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> } -// ============================================================================= -/// MUL tests (same as add, so have only one). + +//===----------------------------------------------------------------------===// +/// MUL tests (same as add, so have only two). 
/// Change (x*c1)*c2 to x*(c1*c2) +// ----- // CHECK-LABEL: @test_mul_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32> func @test_mul_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> { %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> @@ -107,12 +128,28 @@ func @test_mul_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> { // CHECK-NEXT: [[MUL1:%.+]] = "onnx.Mul"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> } -// ============================================================================= +/// Change (x*c0)*y * (z*c1) to (x*y)*z * (c0*c1) +// ----- +// CHECK-LABEL: @test_mul_constant_5(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<3xi32> +func @test_mul_constant_5(%arg0 : tensor<3xi32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<3xi32> { + %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32> + %2 = "onnx.Mul"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + %3 = "onnx.Mul"(%2, %arg1) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + %4 = "onnx.Mul"(%1, %arg2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + %5 = "onnx.Mul"(%3, %4) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + "std.return"(%5) : (tensor<3xi32>) -> () + // CHECK-NEXT: [[MUL1:%.+]] = "onnx.Mul"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + // CHECK-NEXT: [[MUL2:%.+]] = "onnx.Mul"([[MUL1]], %arg2) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[0, 11, 24]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK-NEXT: [[MUL3:%.+]] = "onnx.Mul"([[MUL2]], [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> +} + +//===----------------------------------------------------------------------===// /// SUB and NEG tests. 
- // check of sub two constants - +// ----- // CHECK-LABEL: @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> func @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> @@ -123,7 +160,7 @@ func @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { } /// check sub to add of negative - +// ----- // CHECK-LABEL: @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> func @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> @@ -133,6 +170,7 @@ func @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> } +// ----- // CHECK-LABEL: @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> func @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> @@ -144,6 +182,7 @@ func @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> } +// ----- // CHECK-LABEL: @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>