[MLIR] Canonicalization pattern for eliminating identity ops (#377)
* Canonicalization pattern for eliminating identity ops * Add a test for the identity elimination rule * Remove frontend from test * Use CHECK-NEXT instead of CHECK
This commit is contained in:
parent
bee32e2041
commit
53ab014a1d
|
@ -264,7 +264,7 @@ def collect_types(schema, input) :
|
||||||
|
|
||||||
def gen_schema(schema) :
|
def gen_schema(schema) :
|
||||||
ShapeInferenceList=['Add', 'MatMul', 'Gemm']
|
ShapeInferenceList=['Add', 'MatMul', 'Gemm']
|
||||||
CanonicalList=['Add']
|
CanonicalList=['Add', 'Identity']
|
||||||
line_indent = ' '
|
line_indent = ' '
|
||||||
|
|
||||||
#s = 'def ONNX'+schema.name+str(schema.since_version)+'Op:ONNX_Op<"'+schema.name+'", \n'
|
#s = 'def ONNX'+schema.name+str(schema.since_version)+'Op:ONNX_Op<"'+schema.name+'", \n'
|
||||||
|
|
|
@ -1021,6 +1021,7 @@ def ONNXHardmaxOp:ONNX_Op<"Hardmax",
|
||||||
|
|
||||||
def ONNXIdentityOp:ONNX_Op<"Identity",
|
def ONNXIdentityOp:ONNX_Op<"Identity",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
let summary = "ONNX Identity operation";
|
let summary = "ONNX Identity operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Identity operator"
|
"Identity operator"
|
||||||
|
@ -2785,11 +2786,13 @@ def ONNXSliceOp:ONNX_Op<"Slice",
|
||||||
"Slices uses `starts`, `ends`, `axes` and `steps` inputs to specify the start and end"
|
"Slices uses `starts`, `ends`, `axes` and `steps` inputs to specify the start and end"
|
||||||
"dimension and step for each axis in the list of axes, it uses this information to"
|
"dimension and step for each axis in the list of axes, it uses this information to"
|
||||||
"slice the input `data` tensor. If a negative value is passed for any of the"
|
"slice the input `data` tensor. If a negative value is passed for any of the"
|
||||||
"start or end indices, it represent number of elements before the end of that"
|
"start or end indices, it represents number of elements before the end of that"
|
||||||
"dimension. If the value passed to start or end is larger than the `n` (the"
|
"dimension. If the value passed to start or end is larger than the `n` (the"
|
||||||
"number of elements in this dimension), it represents `n`. For slicing to the"
|
"number of elements in this dimension), it represents `n`. For slicing to the"
|
||||||
"end of a dimension with unknown size, it is recommended to pass in `INT_MAX`."
|
"end of a dimension with unknown size, it is recommended to pass in `INT_MAX` "
|
||||||
|
"when sclicing forward and 'INT_MIN' when slicing backward."
|
||||||
"If a negative value is passed for step, it represents slicing backward. "
|
"If a negative value is passed for step, it represents slicing backward. "
|
||||||
|
"However step value cannot be 0."
|
||||||
"If `axes` are omitted, they are set to `[0, ..., ndim-1]`."
|
"If `axes` are omitted, they are set to `[0, ..., ndim-1]`."
|
||||||
"If `steps` are omitted, they are set to `[1, ..., 1]` of length `len(starts)`"
|
"If `steps` are omitted, they are set to `[1, ..., 1]` of length `len(starts)`"
|
||||||
"Example 1:"
|
"Example 1:"
|
||||||
|
|
|
@ -28,3 +28,8 @@ void ONNXAddOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList& results, MLIRContext* context) {
|
OwningRewritePatternList& results, MLIRContext* context) {
|
||||||
results.insert<MulAddToGemmOptPattern>(context);
|
results.insert<MulAddToGemmOptPattern>(context);
|
||||||
}
|
}
|
||||||
|
/// on the ONNXIdentityOp.
|
||||||
|
void ONNXIdentityOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList& results, MLIRContext* context) {
|
||||||
|
results.insert<IdentityEliminationPattern>(context);
|
||||||
|
}
|
||||||
|
|
|
@ -35,4 +35,8 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
|
||||||
(ONNXFullGemmOp $m1, $m2, $m3),
|
(ONNXFullGemmOp $m1, $m2, $m3),
|
||||||
[(HasOneUse $res)]>;
|
[(HasOneUse $res)]>;
|
||||||
|
|
||||||
|
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
|
||||||
|
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
|
||||||
|
(replaceWithValue $arg)>;
|
||||||
|
|
||||||
#endif // ONNX_COMBINE
|
#endif // ONNX_COMBINE
|
||||||
|
|
|
@ -9,3 +9,12 @@ module {
|
||||||
"std.return"(%1) : (tensor<10x10xf32>) -> ()
|
"std.return"(%1) : (tensor<10x10xf32>) -> ()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) -> tensor<10x10xf32> {
|
||||||
|
// CHECK-LABEL: test_identity_identity
|
||||||
|
// CHECK-NEXT: %{{[0-9]+}} = "onnx.Add"(%{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
|
%0 = "onnx.Identity"(%a0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
|
%1 = "onnx.Identity"(%a1) : (tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
|
%2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
|
"std.return"(%2) : (tensor<10x10xf32>) -> ()
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue