[MLIR] Canonicalization pattern for eliminating identity ops (#377)
* Canonicalization pattern for eliminating identity ops * Add a test for the identity elimination rule * Remove frontend from test * Use CHECK-NEXT instead of CHECK
This commit is contained in:
parent
bee32e2041
commit
53ab014a1d
|
@ -264,7 +264,7 @@ def collect_types(schema, input) :
|
|||
|
||||
def gen_schema(schema) :
|
||||
ShapeInferenceList=['Add', 'MatMul', 'Gemm']
|
||||
CanonicalList=['Add']
|
||||
CanonicalList=['Add', 'Identity']
|
||||
line_indent = ' '
|
||||
|
||||
#s = 'def ONNX'+schema.name+str(schema.since_version)+'Op:ONNX_Op<"'+schema.name+'", \n'
|
||||
|
|
|
@ -1021,6 +1021,7 @@ def ONNXHardmaxOp:ONNX_Op<"Hardmax",
|
|||
|
||||
def ONNXIdentityOp:ONNX_Op<"Identity",
|
||||
[NoSideEffect]> {
|
||||
let hasCanonicalizer = 1;
|
||||
let summary = "ONNX Identity operation";
|
||||
let description = [{
|
||||
"Identity operator"
|
||||
|
@ -2785,11 +2786,13 @@ def ONNXSliceOp:ONNX_Op<"Slice",
|
|||
"Slices uses `starts`, `ends`, `axes` and `steps` inputs to specify the start and end"
|
||||
"dimension and step for each axis in the list of axes, it uses this information to"
|
||||
"slice the input `data` tensor. If a negative value is passed for any of the"
|
||||
"start or end indices, it represent number of elements before the end of that"
|
||||
"start or end indices, it represents number of elements before the end of that"
|
||||
"dimension. If the value passed to start or end is larger than the `n` (the"
|
||||
"number of elements in this dimension), it represents `n`. For slicing to the"
|
||||
"end of a dimension with unknown size, it is recommended to pass in `INT_MAX`."
|
||||
"If a negative value is passed for step, it represents slicing backward."
|
||||
"end of a dimension with unknown size, it is recommended to pass in `INT_MAX` "
|
||||
"when sclicing forward and 'INT_MIN' when slicing backward."
|
||||
"If a negative value is passed for step, it represents slicing backward. "
|
||||
"However step value cannot be 0."
|
||||
"If `axes` are omitted, they are set to `[0, ..., ndim-1]`."
|
||||
"If `steps` are omitted, they are set to `[1, ..., 1]` of length `len(starts)`"
|
||||
"Example 1:"
|
||||
|
|
|
@ -28,3 +28,8 @@ void ONNXAddOp::getCanonicalizationPatterns(
|
|||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<MulAddToGemmOptPattern>(context);
|
||||
}
|
||||
/// on the ONNXIdentityOp.
|
||||
void ONNXIdentityOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<IdentityEliminationPattern>(context);
|
||||
}
|
||||
|
|
|
@ -35,4 +35,8 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
|
|||
(ONNXFullGemmOp $m1, $m2, $m3),
|
||||
[(HasOneUse $res)]>;
|
||||
|
||||
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
|
||||
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
|
||||
(replaceWithValue $arg)>;
|
||||
|
||||
#endif // ONNX_COMBINE
|
||||
|
|
|
@ -9,3 +9,12 @@ module {
|
|||
"std.return"(%1) : (tensor<10x10xf32>) -> ()
|
||||
}
|
||||
}
|
||||
|
||||
func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) -> tensor<10x10xf32> {
|
||||
// CHECK-LABEL: test_identity_identity
|
||||
// CHECK-NEXT: %{{[0-9]+}} = "onnx.Add"(%{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
%0 = "onnx.Identity"(%a0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
%1 = "onnx.Identity"(%a1) : (tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
%2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
"std.return"(%2) : (tensor<10x10xf32>) -> ()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue