diff --git a/src/compiler/dialect/onnx/gen_doc.py b/src/compiler/dialect/onnx/gen_doc.py index ad0b827..1370187 100644 --- a/src/compiler/dialect/onnx/gen_doc.py +++ b/src/compiler/dialect/onnx/gen_doc.py @@ -264,7 +264,7 @@ def collect_types(schema, input) : def gen_schema(schema) : ShapeInferenceList=['Add', 'MatMul', 'Gemm'] - CanonicalList=['Add'] + CanonicalList=['Add', 'Identity'] line_indent = ' ' #s = 'def ONNX'+schema.name+str(schema.since_version)+'Op:ONNX_Op<"'+schema.name+'", \n' diff --git a/src/compiler/dialect/onnx/onnxop.inc b/src/compiler/dialect/onnx/onnxop.inc index a0b0f20..6cfafdb 100644 --- a/src/compiler/dialect/onnx/onnxop.inc +++ b/src/compiler/dialect/onnx/onnxop.inc @@ -1021,6 +1021,7 @@ def ONNXHardmaxOp:ONNX_Op<"Hardmax", def ONNXIdentityOp:ONNX_Op<"Identity", [NoSideEffect]> { + let hasCanonicalizer = 1; let summary = "ONNX Identity operation"; let description = [{ "Identity operator" @@ -2785,11 +2786,13 @@ def ONNXSliceOp:ONNX_Op<"Slice", "Slices uses `starts`, `ends`, `axes` and `steps` inputs to specify the start and end" "dimension and step for each axis in the list of axes, it uses this information to" "slice the input `data` tensor. If a negative value is passed for any of the" - "start or end indices, it represent number of elements before the end of that" + "start or end indices, it represents number of elements before the end of that" "dimension. If the value passed to start or end is larger than the `n` (the" "number of elements in this dimension), it represents `n`. For slicing to the" - "end of a dimension with unknown size, it is recommended to pass in `INT_MAX`." - "If a negative value is passed for step, it represents slicing backward." + "end of a dimension with unknown size, it is recommended to pass in `INT_MAX` " + "when sclicing forward and 'INT_MIN' when slicing backward." + "If a negative value is passed for step, it represents slicing backward. " + "However step value cannot be 0." "If `axes` are omitted, they are set to `[0, ..., ndim-1]`." "If `steps` are omitted, they are set to `[1, ..., 1]` of length `len(starts)`" "Example 1:" diff --git a/src/compiler/pass/onnx_combine.cpp b/src/compiler/pass/onnx_combine.cpp index 4709f8d..ef75579 100644 --- a/src/compiler/pass/onnx_combine.cpp +++ b/src/compiler/pass/onnx_combine.cpp @@ -28,3 +28,8 @@ void ONNXAddOp::getCanonicalizationPatterns( OwningRewritePatternList& results, MLIRContext* context) { results.insert(context); } +/// on the ONNXIdentityOp. +void ONNXIdentityOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} diff --git a/src/compiler/pass/onnx_combine.td b/src/compiler/pass/onnx_combine.td index 946991d..199e27a 100644 --- a/src/compiler/pass/onnx_combine.td +++ b/src/compiler/pass/onnx_combine.td @@ -35,4 +35,8 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), (ONNXFullGemmOp $m1, $m2, $m3), [(HasOneUse $res)]>; +// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X) +def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg), + (replaceWithValue $arg)>; + #endif // ONNX_COMBINE diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index 1cf1a89..b833728 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -9,3 +9,12 @@ module { "std.return"(%1) : (tensor<10x10xf32>) -> () } } + +func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) -> tensor<10x10xf32> { + // CHECK-LABEL: test_identity_identity + // CHECK-NEXT: %{{[0-9]+}} = "onnx.Add"(%{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %0 = "onnx.Identity"(%a0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %1 = "onnx.Identity"(%a1) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + "std.return"(%2) : (tensor<10x10xf32>) -> () +}