Constprop2 (#167)

* initial const prop attempt

* added support for broadcast ops

* adde all binary broadcast ops into custom builders with precise type

* added test example

* working

* format

* fixed suggestion by Tung, start woring on unary

* added subtraction and neg the right way, and added elementwise mul too

* formatting changes

* format

* format

* added instructions to add new optimizations

* added propagation rules that always migrate constants toward the root of the expression, using assoc and commutativity

* format comment
This commit is contained in:
Alexandre Eichenberger 2020-06-15 14:56:15 -04:00 committed by GitHub
parent a7781791e9
commit 742e817722
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 128 additions and 21 deletions

View File

@ -85,8 +85,42 @@ def AddConstAssociative1 : Pat<
// To add(x, add(c1, c2)). // To add(x, add(c1, c2)).
(ONNXAddOp (ONNXAddOp
$x, $x,
(ONNXAddOp $c1, $c2))>; (ONNXAddOp $c1, $c2)),
[(IsNotAConstant:$x)]>;
def AddConstAssociative2 : Pat<
// From add(add(x, c), y).
(ONNXAddOp
(ONNXAddOp $x,(ONNXConstantOp:$c $_, $_)),
$y),
// To add(add(x, y), c).
(ONNXAddOp
(ONNXAddOp $x, $y),
$c),
[(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
def AddConstAssociative3 : Pat<
// From add(x, add(y, c)).
(ONNXAddOp
$x,
(ONNXAddOp $y,(ONNXConstantOp:$c $_, $_))),
// To add(add(x, y), c).
(ONNXAddOp
(ONNXAddOp $x, $y),
$c),
[(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
def AddConstAssociative4 : Pat<
// From add(add(x, c1), add(y, c2)).
(ONNXAddOp
(ONNXAddOp $x,(ONNXConstantOp:$c1 $_, $_)),
(ONNXAddOp $y,(ONNXConstantOp:$c2 $_, $_))),
// To add(add(x, y), c1+c2).
(ONNXAddOp
(ONNXAddOp $x, $y),
(ONNXAddOp $c1, $c2)),
[(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
// Constant Propagation for Add // Constant Propagation for Add
def AddConstProp : Pat< def AddConstProp : Pat<
// From add(c1, c2). // From add(c1, c2).
@ -150,7 +184,41 @@ def MulConstAssociative1 : Pat<
// To mul(x, mul(c1, c2)). // To mul(x, mul(c1, c2)).
(ONNXMulOp (ONNXMulOp
$x, $x,
(ONNXMulOp $c1, $c2))>; (ONNXMulOp $c1, $c2)),
[(IsNotAConstant:$x)]>;
def MulConstAssociative2 : Pat<
// From mul(mul(x, c), y).
(ONNXMulOp
(ONNXMulOp $x,(ONNXConstantOp:$c $_, $_)),
$y),
// To mul(mul(x, y), c).
(ONNXMulOp
(ONNXMulOp $x, $y),
$c),
[(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
def MulConstAssociative3 : Pat<
// From mul(x, mul(y, c)).
(ONNXMulOp
$x,
(ONNXMulOp $y,(ONNXConstantOp:$c $_, $_))),
// To mul(mul(x, y), c).
(ONNXMulOp
(ONNXMulOp $x, $y),
$c),
[(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
def MulConstAssociative4 : Pat<
// From mul(mul(x, c1), mul(y, c2)).
(ONNXMulOp
(ONNXMulOp $x,(ONNXConstantOp:$c1 $_, $_)),
(ONNXMulOp $y,(ONNXConstantOp:$c2 $_, $_))),
// To mul(mul(x, y), c1+c2).
(ONNXMulOp
(ONNXMulOp $x, $y),
(ONNXMulOp $c1, $c2)),
[(IsNotAConstant:$x), (IsNotAConstant:$y)]>;
// Constant Propagation for Mul // Constant Propagation for Mul
def MulConstProp : Pat< def MulConstProp : Pat<

View File

@ -1,10 +1,11 @@
// RUN: onnx-mlir-opt --constprop-onnx %s -split-input-file | FileCheck %s // RUN: onnx-mlir-opt --constprop-onnx %s -split-input-file | FileCheck %s
// =============================================================================
/// MUL tests (same as add, so have only one). //===----------------------------------------------------------------------===//
/// ADD tests
/// Test ConstantOp assoc for add /// Test ConstantOp assoc for add
// -----
// CHECK-LABEL: @test_add_constant_1(%arg0: tensor<3xf32>) -> tensor<3xf32> // CHECK-LABEL: @test_add_constant_1(%arg0: tensor<3xf32>) -> tensor<3xf32>
func @test_add_constant_1(%arg0 : tensor<3xf32>) -> tensor<3xf32> { func @test_add_constant_1(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32> %0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32>
@ -15,6 +16,7 @@ func @test_add_constant_1(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
} }
/// Test ConstantOp assoc for add /// Test ConstantOp assoc for add
// -----
// CHECK-LABEL: @test_add_constant_2(%arg0: tensor<3xf32>) -> tensor<3xf32> // CHECK-LABEL: @test_add_constant_2(%arg0: tensor<3xf32>) -> tensor<3xf32>
func @test_add_constant_2(%arg0 : tensor<3xf32>) -> tensor<3xf32> { func @test_add_constant_2(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32> %0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32>
@ -25,6 +27,7 @@ func @test_add_constant_2(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
} }
/// Change (x+c1)+c2 to x+(c1+c2) /// Change (x+c1)+c2 to x+(c1+c2)
// -----
// CHECK-LABEL: @test_add_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32> // CHECK-LABEL: @test_add_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32>
func @test_add_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> { func @test_add_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
@ -37,7 +40,8 @@ func @test_add_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
} }
/// Same test as above, but with a use of an intermediary result /// Same test as above, but with a use of an intermediary result
/// change (x+c1)+c2 + (x+c1) to x+(c1+c2) + (x+c1) /// change (x+c1)+c2 + (x+c1) to x+x + (c1+c2+c3)
// -----
// CHECK-LABEL: @test_add_constant_4(%arg0: tensor<3xi32>) -> tensor<3xi32> // CHECK-LABEL: @test_add_constant_4(%arg0: tensor<3xi32>) -> tensor<3xi32>
func @test_add_constant_4(%arg0 : tensor<3xi32>) -> tensor<3xi32> { func @test_add_constant_4(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
@ -46,15 +50,30 @@ func @test_add_constant_4(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
%3 = "onnx.Add"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> %3 = "onnx.Add"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
%4 = "onnx.Add"(%2, %3) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> %4 = "onnx.Add"(%2, %3) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
"std.return"(%4) : (tensor<3xi32>) -> () "std.return"(%4) : (tensor<3xi32>) -> ()
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, %arg0) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[10, 13, 16]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK-NEXT: [[CONST2:%.+]] = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32> // CHECK-NEXT: [[ADD2:%.+]] = "onnx.Add"([[ADD1]], [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// CHECK-NEXT: [[ADD2:%.+]] = "onnx.Add"(%arg0, [[CONST2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> }
// CHECK-NEXT: [[ADD3:%.+]] = "onnx.Add"([[ADD1]], [[ADD2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
/// Change (x+c0)+y + (z+c1) to (x+y)+z + (c1+c2)
// -----
// CHECK-LABEL: @test_add_constant_5(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<3xi32>
func @test_add_constant_5(%arg0 : tensor<3xi32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<3xi32> {
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32>
%2 = "onnx.Add"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
%3 = "onnx.Add"(%2, %arg1) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
%4 = "onnx.Add"(%1, %arg2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
%5 = "onnx.Add"(%3, %4) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
"std.return"(%5) : (tensor<3xi32>) -> ()
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// CHECK-NEXT: [[ADD2:%.+]] = "onnx.Add"([[ADD1]], %arg2) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK-NEXT: [[ADD3:%.+]] = "onnx.Add"([[ADD2]], [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
} }
/// Test broadcast 1 -> 2d /// Test broadcast 1 -> 2d
// -----
// CHECK-LABEL: @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-LABEL: @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { func @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32> %0 = "onnx.Constant"() {value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32>
@ -67,7 +86,7 @@ func @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
} }
/// Test broadcast 2d (size one) -> 2d /// Test broadcast 2d (size one) -> 2d
// -----
// CHECK-LABEL: @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-LABEL: @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { func @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[[1]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32> %0 = "onnx.Constant"() {value = dense<[[1]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
@ -80,8 +99,8 @@ func @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
} }
/// check 1d -> 2d /// check 1d -> 2d
// -----
// CHECK-LABEL: @test_broadcast_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-LABEL: @test_broadcast_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_broadcast_3(%arg0 : tensor<3x2xi32>) -> tensor<3x2xi32> { func @test_broadcast_3(%arg0 : tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[[1], [2], [3]]> : tensor<3x1xi32>} : () -> tensor<3x1xi32> %0 = "onnx.Constant"() {value = dense<[[1], [2], [3]]> : tensor<3x1xi32>} : () -> tensor<3x1xi32>
%1 = "onnx.Constant"() {value = dense<[[10, 11], [21, 22], [31, 32]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> %1 = "onnx.Constant"() {value = dense<[[10, 11], [21, 22], [31, 32]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
@ -92,10 +111,12 @@ func @test_broadcast_3(%arg0 : tensor<3x2xi32>) -> tensor<3x2xi32> {
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
} }
// =============================================================================
/// MUL tests (same as add, so have only one). //===----------------------------------------------------------------------===//
/// MUL tests (same as add, so have only two).
/// Change (x*c1)*c2 to x*(c1*c2) /// Change (x*c1)*c2 to x*(c1*c2)
// -----
// CHECK-LABEL: @test_mul_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32> // CHECK-LABEL: @test_mul_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32>
func @test_mul_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> { func @test_mul_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
@ -107,12 +128,28 @@ func @test_mul_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
// CHECK-NEXT: [[MUL1:%.+]] = "onnx.Mul"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> // CHECK-NEXT: [[MUL1:%.+]] = "onnx.Mul"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
} }
// ============================================================================= /// Change (x*c0)*y * (z*c1) to (x*y)*z * (c1*c2)
// -----
// CHECK-LABEL: @test_mul_constant_5(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<3xi32>
func @test_mul_constant_5(%arg0 : tensor<3xi32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<3xi32> {
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32>
%2 = "onnx.Mul"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
%3 = "onnx.Mul"(%2, %arg1) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
%4 = "onnx.Mul"(%1, %arg2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
%5 = "onnx.Mul"(%3, %4) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
"std.return"(%5) : (tensor<3xi32>) -> ()
// CHECK-NEXT: [[MUL1:%.+]] = "onnx.Mul"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// CHECK-NEXT: [[MUL2:%.+]] = "onnx.Mul"([[MUL1]], %arg2) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[0, 11, 24]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK-NEXT: [[MUL3:%.+]] = "onnx.Mul"([[MUL2]], [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
}
//===----------------------------------------------------------------------===//
/// SUB and NEG tests. /// SUB and NEG tests.
// check of sub two constants // check of sub two constants
// -----
// CHECK-LABEL: @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-LABEL: @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { func @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
@ -123,7 +160,7 @@ func @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
} }
/// check sub to add of negative /// check sub to add of negative
// -----
// CHECK-LABEL: @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-LABEL: @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { func @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
@ -133,6 +170,7 @@ func @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
} }
// -----
// CHECK-LABEL: @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-LABEL: @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { func @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
@ -144,6 +182,7 @@ func @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
} }
// -----
// CHECK-LABEL: @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-LABEL: @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>