[MLIR] Canonicalization pattern for eliminating identity ops (#377)

* Canonicalization pattern for eliminating identity ops

* Add a test for the identity elimination rule

* Remove frontend from test

* Use CHECK-NEXT instead of CHECK
This commit is contained in:
TUNG LEDUC 2019-11-21 11:57:13 +09:00 committed by Tian Jin
parent bee32e2041
commit 53ab014a1d
5 changed files with 25 additions and 4 deletions

View File

@ -264,7 +264,7 @@ def collect_types(schema, input) :
def gen_schema(schema) : def gen_schema(schema) :
ShapeInferenceList=['Add', 'MatMul', 'Gemm'] ShapeInferenceList=['Add', 'MatMul', 'Gemm']
CanonicalList=['Add'] CanonicalList=['Add', 'Identity']
line_indent = ' ' line_indent = ' '
#s = 'def ONNX'+schema.name+str(schema.since_version)+'Op:ONNX_Op<"'+schema.name+'", \n' #s = 'def ONNX'+schema.name+str(schema.since_version)+'Op:ONNX_Op<"'+schema.name+'", \n'

View File

@ -1021,6 +1021,7 @@ def ONNXHardmaxOp:ONNX_Op<"Hardmax",
def ONNXIdentityOp:ONNX_Op<"Identity", def ONNXIdentityOp:ONNX_Op<"Identity",
[NoSideEffect]> { [NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX Identity operation"; let summary = "ONNX Identity operation";
let description = [{ let description = [{
"Identity operator" "Identity operator"
@ -2785,11 +2786,13 @@ def ONNXSliceOp:ONNX_Op<"Slice",
"Slices uses `starts`, `ends`, `axes` and `steps` inputs to specify the start and end" "Slices uses `starts`, `ends`, `axes` and `steps` inputs to specify the start and end"
"dimension and step for each axis in the list of axes, it uses this information to" "dimension and step for each axis in the list of axes, it uses this information to"
"slice the input `data` tensor. If a negative value is passed for any of the" "slice the input `data` tensor. If a negative value is passed for any of the"
"start or end indices, it represent number of elements before the end of that" "start or end indices, it represents number of elements before the end of that"
"dimension. If the value passed to start or end is larger than the `n` (the" "dimension. If the value passed to start or end is larger than the `n` (the"
"number of elements in this dimension), it represents `n`. For slicing to the" "number of elements in this dimension), it represents `n`. For slicing to the"
"end of a dimension with unknown size, it is recommended to pass in `INT_MAX`." "end of a dimension with unknown size, it is recommended to pass in `INT_MAX` "
"If a negative value is passed for step, it represents slicing backward." "when sclicing forward and 'INT_MIN' when slicing backward."
"If a negative value is passed for step, it represents slicing backward. "
"However step value cannot be 0."
"If `axes` are omitted, they are set to `[0, ..., ndim-1]`." "If `axes` are omitted, they are set to `[0, ..., ndim-1]`."
"If `steps` are omitted, they are set to `[1, ..., 1]` of length `len(starts)`" "If `steps` are omitted, they are set to `[1, ..., 1]` of length `len(starts)`"
"Example 1:" "Example 1:"

View File

@ -28,3 +28,8 @@ void ONNXAddOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) { OwningRewritePatternList& results, MLIRContext* context) {
results.insert<MulAddToGemmOptPattern>(context); results.insert<MulAddToGemmOptPattern>(context);
} }
/// on the ONNXIdentityOp.
void ONNXIdentityOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<IdentityEliminationPattern>(context);
}

View File

@ -35,4 +35,8 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
(ONNXFullGemmOp $m1, $m2, $m3), (ONNXFullGemmOp $m1, $m2, $m3),
[(HasOneUse $res)]>; [(HasOneUse $res)]>;
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
(replaceWithValue $arg)>;
#endif // ONNX_COMBINE #endif // ONNX_COMBINE

View File

@ -9,3 +9,12 @@ module {
"std.return"(%1) : (tensor<10x10xf32>) -> () "std.return"(%1) : (tensor<10x10xf32>) -> ()
} }
} }
func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-LABEL: test_identity_identity
// CHECK-NEXT: %{{[0-9]+}} = "onnx.Add"(%{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%0 = "onnx.Identity"(%a0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
%1 = "onnx.Identity"(%a1) : (tensor<10x10xf32>) -> tensor<10x10xf32>
%2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
"std.return"(%2) : (tensor<10x10xf32>) -> ()
}