Support Optional Inputs (#94)

* 1. Combine variadicIn/Out with expectedNumOperands/Results to simplify import function arguments (a sketch of the new convention follows these notes).
2. General improvements to code readability in gen_doc.py.

* Update ONNX Dialect doc.

* Remove redundant code in ImportNode.

* Prettify op_build_table.inc.

* Remove irrelevant code in gen_doc.py.

* Refactor code to be more readable.

* Further refactoring for readability improvements.

* Allow Gemm to have an optional operand (the bias term), and include an example of a declarative optimization pattern targeting Gemm with the bias term omitted (see the second sketch below).

* Make shape inference/lowering of the Gemm op compatible with the optional operand declaration.

* Apply canonicalization again after lowering from onnx -> std dialects.

* Make hasBias handle the GemmNoBias op case as well.

* Update doc.

* Add a canonicalization test.

* Remove the special handler for importing the Gemm op, as it is now redundant.
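
A minimal, MLIR-free sketch of the convention introduced by the first item above: an expected operand count of -1 marks a variadic operand list, and omitted trailing optional inputs are padded with a none placeholder so every op sees the operand count it declares. The `Value` type and `collectOperands` helper here are illustrative stand-ins, not the actual importer API.

```cpp
// Sketch only: mirrors the padding behaviour of buildOperation in the
// importer, without any MLIR dependency.
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

struct Value {
  std::string name; // an empty name plays the role of the none_ placeholder
  bool isNone() const { return name.empty(); }
};

std::vector<Value> collectOperands(const std::vector<std::string> &nodeInputs,
                                   int expectedNumOperands) {
  std::vector<Value> inputs;
  for (const auto &item : nodeInputs)
    inputs.push_back(Value{item});
  bool variadicIn = expectedNumOperands == -1; // -1 means variadic
  if (!variadicIn)
    while ((int)inputs.size() < expectedNumOperands)
      inputs.push_back(Value{}); // pad omitted optional inputs with none
  return inputs;
}

int main() {
  // Gemm declares three operands (A, B, C); here the bias C was omitted.
  auto ops = collectOperands({"A", "B"}, /*expectedNumOperands=*/3);
  assert(ops.size() == 3 && ops[2].isNone());
  std::cout << "operand count: " << ops.size() << "\n";
}
```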
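A second sketch, for the Gemm-related items: an omitted optional operand is materialized as a NoneType value, so "does this op have a bias" reduces to a type test on the third operand. The `hasBias` helper and the header paths are assumptions for illustration (header locations vary across MLIR versions); the helpers in the actual lowering code may differ.

```cpp
#include "mlir/IR/Operation.h"     // assumed header layout for 2020-era MLIR
#include "mlir/IR/StandardTypes.h" // provides mlir::NoneType in this era

// Illustrative only: true when the optional bias operand C of a Gemm-like op
// (operands A, B, C) is actually present, i.e. not the none placeholder.
static bool hasBias(mlir::Operation *op) {
  if (op->getNumOperands() < 3)
    return false;
  mlir::Value c = op->getOperand(2);
  return !c.getType().isa<mlir::NoneType>();
}
```

A canonicalization pattern can then match a Gemm whose bias operand has none type and rewrite it to the bias-free form; that is the kind of rewrite the declarative pattern mentioned above expresses.
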
Tian Jin 2020-02-24 23:46:48 +08:00 committed by GitHub
parent 479dd5e35a
commit 9c398c0121
11 changed files with 3431 additions and 3804 deletions


@ -327,10 +327,10 @@ ONNX BatchNormalization operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `out_mean`: memref of any type values or tensor of any type values
1. `out_var`: memref of any type values or tensor of any type values
1. `saved_mean`: memref of any type values or tensor of any type values
1. `saved_var`: memref of any type values or tensor of any type values
1. `out_mean`: memref of any type values or tensor of any type values or none type
1. `out_var`: memref of any type values or tensor of any type values or none type
1. `saved_mean`: memref of any type values or tensor of any type values or none type
1. `saved_var`: memref of any type values or tensor of any type values or none type
### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp)
ONNX BatchNormalization operation in test mode
@ -375,12 +375,12 @@ ONNX BitShift operation
"Bitwise shift operator performs element-wise operation. For each input element, if the"
" attribute "direction" is "RIGHT", this operator moves its binary representation toward"
" the right side so that the input value is effectively decreased. If the attribute "direction""
" is "LEFT", bits of binary representation moves toward the left side, which results the"
" attribute \"direction\" is \"RIGHT\", this operator moves its binary representation toward"
" the right side so that the input value is effectively decreased. If the attribute \"direction\""
" is \"LEFT\", bits of binary representation moves toward the left side, which results the"
" increase of its actual value. The input X is the tensor to be shifted and another input"
" Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4],"
" and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with"
" Y specifies the amounts of shifting. For example, if \"direction\" is \"Right\", X is [1, 4],"
" and S is [1, 1], the corresponding output Z would be [0, 2]. If \"direction\" is \"LEFT\" with"
" X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]."
" "
" Because this operator supports Numpy-style broadcasting, X's and Y's shapes are"
@ -413,15 +413,15 @@ ONNX Cast operation
"the converted type. The 'to' argument must be one of the data types specified"
"in the 'DataType' enum field in the TensorProto message."
""
"Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations"
"(e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may"
"Casting from string tensor in plain (e.g., \"3.14\" and \"1000\") and scientific numeric representations"
"(e.g., \"1e-5\" and \"1E8\") to float types is supported. For example, converting string \"100.5\" to an integer may"
"result 100. There are some string literals reserved for special floating-point values;"
""+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively."
"Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly,"
"this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors"
"to string tensors, plain floating-point representation (such as "314.15926") would be used. "
"Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases "
"of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior."
"\"+INF\" (and \"INF\"), \"-INF\", and \"NaN\" are positive infinity, negative infinity, and not-a-number, respectively."
"Any string which can exactly match \"+INF\" in a case-insensitive way would be mapped to positive infinite. Similarly,"
"this case-insensitive rule is applied to \"INF\" and \"NaN\". When casting from numeric tensors"
"to string tensors, plain floating-point representation (such as \"314.15926\") would be used. "
"Converting non-numerical-literal string such as \"Hello World!\" is an undefined behavior. Cases "
"of converting string representing floating-point arithmetic value, such as \"2.718\", to INT is an undefined behavior."
""
"Conversion from a numerical type to any numerical type is always allowed."
"User must be aware of precision loss and value change caused by range difference between two types."
@ -476,8 +476,8 @@ ONNX Clip operation
#### Operands:
1. `input`: memref of any type values or tensor of any type values
1. `min`: memref of any type values or tensor of any type values
1. `max`: memref of any type values or tensor of any type values
1. `min`: memref of any type values or tensor of any type values or none type
1. `max`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -618,8 +618,8 @@ ONNX ConvInteger operation
1. `x`: memref of any type values or tensor of any type values
1. `w`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values
1. `w_zero_point`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values or none type
1. `w_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -678,7 +678,7 @@ ONNX Conv operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -720,7 +720,7 @@ ONNX ConvTranspose operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -884,7 +884,7 @@ ONNX DequantizeLinear operation
1. `x`: memref of any type values or tensor of any type values
1. `x_scale`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -964,7 +964,7 @@ ONNX Dropout operation
#### Results:
1. `output`: memref of any type values or tensor of any type values
1. `mask`: memref of any type values or tensor of any type values
1. `mask`: memref of any type values or tensor of any type values or none type
### onnx.DynamicQuantizeLinear (ONNXDynamicQuantizeLinearOp)
ONNX DynamicQuantizeLinear operation
@ -1297,9 +1297,9 @@ ONNX GRU operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `sequence_lens`: memref of any type values or tensor of any type values
1. `initial_h`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -1315,8 +1315,8 @@ ONNX GRU operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `Y_h`: memref of any type values or tensor of any type values
1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values or none type
### onnx.GatherElements (ONNXGatherElementsOp)
ONNX GatherElements operation
@ -1609,7 +1609,7 @@ ONNX Gemm operation
1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `C`: memref of any type values or tensor of any type values
1. `C`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -2013,11 +2013,11 @@ ONNX LSTM operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `sequence_lens`: memref of any type values or tensor of any type values
1. `initial_h`: memref of any type values or tensor of any type values
1. `initial_c`: memref of any type values or tensor of any type values
1. `P`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values or none type
1. `initial_c`: memref of any type values or tensor of any type values or none type
1. `P`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -2033,9 +2033,9 @@ ONNX LSTM operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `Y_h`: memref of any type values or tensor of any type values
1. `Y_c`: memref of any type values or tensor of any type values
1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values or none type
1. `Y_c`: memref of any type values or tensor of any type values or none type
### onnx.LeakyRelu (ONNXLeakyReluOp)
ONNX LeakyRelu operation
@ -2160,24 +2160,24 @@ ONNX Loop operation
""
" Operator inputs defined as (max_trip_count, condition_var)."
""
" input ("", ""):"
" input (\"\", \"\"):"
" for (int i=0; ; ++i) {"
" cond = ... // Note this value is ignored, but is required in the body"
" }"
""
" input ("", cond) // Note this is analogous to a while loop"
" input (\"\", cond) // Note this is analogous to a while loop"
" bool cond = ...;"
" for (int i=0; cond; ++i) {"
" cond = ...;"
" }"
""
" input ("", 1) // Note this is analogous to a do-while loop"
" input (\"\", 1) // Note this is analogous to a do-while loop"
" bool cond = true"
" for (int i=0; cond; ++i) {"
" cond = ...;"
" }"
""
" input (trip_count, "") // Note this is analogous to a for loop"
" input (trip_count, \"\") // Note this is analogous to a for loop"
" int trip_count = ..."
" for (int i=0; i < trip_count; ++i) {"
" cond = ...; // ignored"
@ -2203,15 +2203,15 @@ ONNX Loop operation
" }"
""
" graph body-net ("
" %i[INT32, scalar] // iteration number"
" %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used"
" %b_in[INT32, scalar] // incoming value of loop-carried-dependency b"
" %i[INT32, scalar]"
" %keepgoing[BOOL, scalar]"
" %b[INT32, scalar]"
" ) {"
" %my_local = Add(%a, %b_in)"
" %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b"
" %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition"
" %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated"
" return %keepgoing_out, %b_out, %user_defined_val"
" %my_local = Add(%a, %b)"
" %b_out = Sub(%a, %b)"
" %keepgoing_out = Greater(%my_local, %b_out)"
" %user_defined_vals = Add(%b, %b)"
" return %keepgoing_out, %b_out, %user_defined_vals"
" }"
""
"*Sample equivalent C code*"
@ -2226,51 +2226,31 @@ ONNX Loop operation
" const int max_trip_count = 10; // Analogous to input M"
" int user_defined_vals[]; // Imagine this is resizable"
" /* End implicitly-defined code */"
" /* initialize loop-carried variables and scan-output variables */"
" bool keepgoing_out = keepgoing"
" int b_out = b"
""
" for (int i=0; i < max_trip_count && keepgoing_out; ++i) {"
" /* Implicitly-defined code: bind actual parameter values"
" to formal parameter variables of loop-body */"
" bool keepgoing_in = keepgoing_out; "
" bool b_in = b_out;"
""
" for (int i=0; i < max_trip_count && keepgoing; ++i) {"
" /* User-defined code (loop body) */"
" int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine"
" b_out = a - b_in;"
" keepgoing_out = my_local > b_out; "
" user_defined_val = b_in + b_in; // b_in and b_out are different variables"
" int my_local = a + b; // Reading values in the enclosing scope is fine"
" b = a - b; // writes fine if we specify b as a loop-carried dependency"
" keepgoing = my_local > b; // keepgoing is a loop-carried dependency"
" user_defined_vals[i] = b + b;"
" /* End user-defined code */"
""
" /* Implicitly defined-code */"
" user_defined_vals[i] = user_defined_val // accumulate scan-output values"
" }"
" // int t = my_local; // Can't do this. my_local is not accessible here."
" // my_local = 123; // Can't do this. my_local was defined in the the body"
""
" // The values below are bound to the output variables of the loop and therefore accessible"
" // b_out; user_defined_vals; keepgoing_out;"
" // These below values are live-out from the loop and therefore accessible"
" b_out; user_defined_vals; keepgoing_out;"
" }"
""
"There are several things of note in this code snippet:"
""
"1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can"
"1) Values from the enclosing scope (i.e. variable a here) are in scope and can"
" be referenced in the inputs of the loop."
"2) Any values computed in the loop body that needs to be used in a subsequent"
" iteration or after the loop are modelled using a pair of variables in the loop-body,"
" consisting of an input variable (eg., b_in) and an output variable (eg., b_out)."
" These are referred to as loop-carried dependences. The loop operation node"
" supplies the input value of the input variable for the first iteration, and"
" returns the output value of the output variable produced by the final"
" iteration."
"3) Scan_output variables are used to implicitly concatenate values computed across"
" all the iterations. In the above example, the value of user_defined_val computed"
" over all iterations are concatenated and returned as the value of user_defined_vals"
" after the loop."
"4) Values created in the body cannot be accessed in the enclosing scope,"
" except using the mechanism described above."
"2) Any variables which you wish to make available in the enclosing scope (i.e."
" the variables b and keepgoing) must be declared as either loop-carried"
" dependencies (both at the op inputs and output and at the body net input and"
" output) or scan_outputs."
"3) Values created in the body cannot be accessed in the enclosing scope."
""
"Note that the semantics of this op support "diagonal" or "wavefront" execution."
"Note that the semantics of this op support \"diagonal\" or \"wavefront\" execution."
"(See Step 3 here for an example:"
"https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)."
"Frontends should emit multi-layer RNNs as a series of While operators (with"
@ -2280,8 +2260,8 @@ ONNX Loop operation
#### Operands:
1. `M`: memref of any type values or tensor of any type values
1. `cond`: memref of any type values or tensor of any type values
1. `M`: memref of any type values or tensor of any type values or none type
1. `cond`: memref of any type values or tensor of any type values or none type
1. `v_initial`: memref of any type values or tensor of any type values
#### Attributes:
@ -2360,8 +2340,8 @@ ONNX MatMulInteger operation
1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `a_zero_point`: memref of any type values or tensor of any type values
1. `b_zero_point`: memref of any type values or tensor of any type values
1. `a_zero_point`: memref of any type values or tensor of any type values or none type
1. `b_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -2466,7 +2446,7 @@ ONNX MaxPool operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `Indices`: memref of any type values or tensor of any type values
1. `Indices`: memref of any type values or tensor of any type values or none type
### onnx.MaxPoolSingleOut (ONNXMaxPoolSingleOutOp)
ONNX MaxPool operation with a single output.
@ -2552,7 +2532,7 @@ ONNX MaxUnpool operation
1. `X`: memref of any type values or tensor of any type values
1. `I`: memref of any type values or tensor of any type values
1. `output_shape`: memref of any type values or tensor of any type values
1. `output_shape`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -2752,9 +2732,9 @@ ONNX NonMaxSuppression operation
1. `boxes`: memref of any type values or tensor of any type values
1. `scores`: memref of any type values or tensor of any type values
1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values
1. `iou_threshold`: memref of any type values or tensor of any type values
1. `score_threshold`: memref of any type values or tensor of any type values
1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values or none type
1. `iou_threshold`: memref of any type values or tensor of any type values or none type
1. `score_threshold`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3041,7 +3021,7 @@ ONNX Pad operation
1. `data`: memref of any type values or tensor of any type values
1. `pads`: memref of any type values or tensor of any type values
1. `constant_value`: memref of any type values or tensor of any type values
1. `constant_value`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3098,7 +3078,7 @@ ONNX QLinearConv operation
1. `w_zero_point`: memref of any type values or tensor of any type values
1. `y_scale`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3162,7 +3142,7 @@ ONNX QuantizeLinear operation
1. `x`: memref of any type values or tensor of any type values
1. `y_scale`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3244,9 +3224,9 @@ ONNX RNN operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `sequence_lens`: memref of any type values or tensor of any type values
1. `initial_h`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3261,8 +3241,8 @@ ONNX RNN operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `Y_h`: memref of any type values or tensor of any type values
1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values or none type
### onnx.RandomNormalLike (ONNXRandomNormalLikeOp)
ONNX RandomNormalLike operation
@ -3787,14 +3767,14 @@ ONNX Resize operation
"Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor."
"Each dimension value of the output tensor is:"
" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified."
" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified."
#### Operands:
1. `X`: memref of any type values or tensor of any type values
1. `roi`: memref of any type values or tensor of any type values
1. `scales`: memref of any type values or tensor of any type values
1. `sizes`: memref of any type values or tensor of any type values
1. `sizes`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4412,7 +4392,7 @@ ONNX SequenceErase operation
#### Operands:
1. `input_sequence`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4437,7 +4417,7 @@ ONNX SequenceInsert operation
1. `input_sequence`: memref of any type values or tensor of any type values
1. `tensor`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4654,8 +4634,8 @@ ONNX Slice operation
1. `data`: memref of any type values or tensor of any type values
1. `starts`: memref of any type values or tensor of any type values
1. `ends`: memref of any type values or tensor of any type values
1. `axes`: memref of any type values or tensor of any type values
1. `steps`: memref of any type values or tensor of any type values
1. `axes`: memref of any type values or tensor of any type values or none type
1. `steps`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4808,7 +4788,7 @@ ONNX SplitToSequence operation
#### Operands:
1. `input`: memref of any type values or tensor of any type values
1. `split`: memref of any type values or tensor of any type values
1. `split`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4876,9 +4856,9 @@ ONNX StringNormalizer operation
"StringNormalization performs string operations for basic cleaning."
"This operator has only one input (denoted by X) and only one output"
"(denoted by Y). This operator first examines the elements in the X,"
"and removes elements specified in "stopwords" attribute."
"and removes elements specified in \"stopwords\" attribute."
"After removing stop words, the intermediate result can be further lowercased,"
"uppercased, or just returned depending the "case_change_action" attribute."
"uppercased, or just returned depending the \"case_change_action\" attribute."
"This operator only accepts [C]- and [1, C]-tensor."
"If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]"
"if input shape is [C] and shape [1, 1] if input shape is [1, C]."
@ -5008,8 +4988,8 @@ ONNX TfIdfVectorizer operation
"respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output."
"Note that we may consider all skips up to S when generating the n-grams."
""
"The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and"
"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF","
"The examples used above are true if mode is \"TF\". If mode is \"IDF\", all the counts larger than 1 would be truncated to 1 and"
"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is \"TFIDF\","
"this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute."
""
"Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor."
@ -5097,9 +5077,9 @@ ONNX TopK operation
" contains the indices of the top k elements (original indices from the input"
" tensor)."
""
"If "largest" is 1 (the default value) then the k largest elements are returned."
"If "sorted" is 1 (the default value) then the resulting k elements will be sorted."
"If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined."
"If \"largest\" is 1 (the default value) then the k largest elements are returned."
"If \"sorted\" is 1 (the default value) then the resulting k elements will be sorted."
"If \"sorted\" is 0, order of returned 'Values' and 'Indices' are undefined."
""
"Given two equivalent values, this operator uses the indices along the axis as"
" a tiebreaker. That is, the element with the lower index will appear first."
@ -5158,7 +5138,7 @@ ONNX Unique operation
"This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. "
"The first output tensor 'Y' contains all unique values or subtensors of the input. "
"The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. "
"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". "
"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. \". "
"The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. "
""
"Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. "
@ -5242,9 +5222,9 @@ ONNX Unique operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `indices`: memref of any type values or tensor of any type values
1. `inverse_indices`: memref of any type values or tensor of any type values
1. `counts`: memref of any type values or tensor of any type values
1. `indices`: memref of any type values or tensor of any type values or none type
1. `inverse_indices`: memref of any type values or tensor of any type values or none type
1. `counts`: memref of any type values or tensor of any type values or none type
### onnx.Unsqueeze (ONNXUnsqueezeOp)
ONNX Unsqueeze operation

File diff suppressed because it is too large.


@ -121,6 +121,7 @@ private:
mlir::MLIRContext &context_;
mlir::ModuleOp module_;
mlir::OpBuilder builder_;
mlir::Value none_;
// mapping between string name and symbol
OnnxOnnfSymbolMapping frontend_symbols_;
@ -287,8 +288,8 @@ private:
}
}
std::vector<mlir::NamedAttribute> ImportNodeAttributes(
const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute>
ImportNodeAttributes(const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute> attributes;
for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i);
@ -317,21 +318,11 @@ private:
}
}
// if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be
// combined with 'if constexpr' the issue is the type of the output is
// different. alternative way to use variadic output for all the op
/*!
* Important onnx node which generates only one output
* @param node onnx node
* @param nIn number of expected inputs
* @param nOut number of expected outputs
* @param attrs list of desription for attributes with format {name, type,
* default}
*/
template <typename T>
void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
bool variadicIn = false, bool variadicOut = false) {
void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
int expectedNumResults = -1) {
bool variadicIn = expectedNumOperands == -1;
bool variadicOut = expectedNumResults == -1;
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
@ -339,6 +330,10 @@ private:
}
}
if (!variadicIn)
for (auto i = inputs.size(); i < expectedNumOperands; i++)
inputs.emplace_back(none_);
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
@ -347,49 +342,11 @@ private:
auto attributes = ImportNodeAttributes(node);
llvm::StringRef OpName = node.op_type();
if ((variadicIn || nIn == inputs.size()) &&
(variadicOut || nOut == outputTypes.size())) {
auto op =
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
op.getResult());
} else {
ImportNodeGeneric(node);
}
}
template <typename T>
void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut,
bool variadicIn = false,
bool variadicOut = false) {
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
auto attributes = ImportNodeAttributes(node);
llvm::StringRef OpName = node.op_type();
if ((variadicIn || nIn == inputs.size()) &&
(variadicOut || nOut == outputTypes.size())) {
auto op =
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
// TODO: Handle optional inputs.
auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
for (int i = 0; i < node.output().size(); i++) {
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
op.getResult(i));
}
} else {
ImportNodeGeneric(node);
*(op.getODSResults(i).begin()));
}
}
@ -398,8 +355,7 @@ private:
* c++ does not allow template specialization inside a class scope
* a specialized function is used
*/
void
ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
// Conv has attribute dilations, kernel_shape, pads, the default value of
// which is determined by the shape of first argument. However, since the
// shape is unknown now, these attributes can be not generated auto
@ -413,24 +369,20 @@ private:
int nOps = node.input().size();
if (nOps == 2)
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(
node, nOps, nOut);
buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
else
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut);
buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
}
/*!
* Special handle for MaxPool operations.
*/
void ImportNodeMaxPool(
onnx::NodeProto node, int nIn, int nOut) {
void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) {
int nOuts = node.output().size();
if (nOuts == 1) {
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(
node, nIn, nOuts);
buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts);
} else {
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(
node, nIn, nOuts);
buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts);
}
}
@ -441,23 +393,10 @@ private:
int nOuts = node.output().size();
if (nOuts == 1) {
// Test mode with one output.
ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn,
nOuts);
buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts);
} else {
// Training mode with four trailing optional outputs. Not handled yet.
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
}
}
/*!
* Special handle for Gemm operations.
*/
void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size();
if (nOps == 2) {
ImportNodeOneOut<mlir::ONNXGemmNoBiasOp>(node, 2, nOut);
} else {
ImportNodeOneOut<mlir::ONNXGemmOp>(node, nIn, nOut);
buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
}
}
@ -467,28 +406,14 @@ private:
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size();
if (nOps == 2) {
ImportNodeOneOut<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
buildOperation<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
} else {
ImportNodeOneOut<mlir::ONNXPadOp>(node, nIn, nOut);
buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
}
}
void ImportNode(const onnx::NodeProto &node) {
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
std::vector<mlir::NamedAttribute> attributes;
llvm::StringRef OpName = node.op_type();
llvm::StringRef opName = node.op_type();
// the following code is generated by gen_doc.py
// refer to dialect/onnx/onnx.td for details
@ -555,9 +480,11 @@ private:
ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
}
// import nodes in the graph
auto node = graph.node();
for (const auto &item : node) {
// Create a NoneTyped constant.
none_ =
builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
// Import nodes in the graph.
for (const auto &item : graph.node()) {
ImportNode(item);
}


@ -1,320 +1,319 @@
//********************************************************
// Warning: Do not modify this file directly
// This file is automatically generated via script
// Details can be found in doc/readonnxdefs.md
// This file is generated on UTC-02/24/2020, 06:29:01.
// Do not modify this file directly.
// This file is automatically generated via script.
// Details can be found in doc/readonnxdefs.md .
//********************************************************
if (OpName == "DUMMY") {
}else if (OpName == "Abs") {
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1);
}else if (OpName == "Acos") {
ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1);
}else if (OpName == "Acosh") {
ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1);
}else if (OpName == "Add") {
ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1);
}else if (OpName == "And") {
ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1);
}else if (OpName == "ArgMax") {
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1);
}else if (OpName == "ArgMin") {
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1);
}else if (OpName == "Asin") {
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1);
}else if (OpName == "Asinh") {
ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1);
}else if (OpName == "Atan") {
ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1);
}else if (OpName == "Atanh") {
ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1);
}else if (OpName == "AveragePool") {
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1);
}else if (OpName == "BatchNormalization") {
ImportNodeBatchNormalization(node, 5, 5);
}else if (OpName == "BitShift") {
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1);
}else if (OpName == "Cast") {
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1);
}else if (OpName == "Ceil") {
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1);
}else if (OpName == "Clip") {
ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1);
}else if (OpName == "Compress") {
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1);
}else if (OpName == "Concat") {
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, true, false);
}else if (OpName == "ConcatFromSequence") {
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1);
}else if (OpName == "Constant") {
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1);
}else if (OpName == "ConstantOfShape") {
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1);
}else if (OpName == "Conv") {
ImportNodeConv(node, 3, 1);
}else if (OpName == "ConvInteger") {
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1);
}else if (OpName == "ConvTranspose") {
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1);
}else if (OpName == "Cos") {
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1);
}else if (OpName == "Cosh") {
ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1);
}else if (OpName == "CumSum") {
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1);
}else if (OpName == "DepthToSpace") {
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1);
}else if (OpName == "DequantizeLinear") {
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1);
}else if (OpName == "Det") {
ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1);
}else if (OpName == "Div") {
ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1);
}else if (OpName == "Dropout") {
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2);
}else if (OpName == "DynamicQuantizeLinear") {
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3);
}else if (OpName == "Elu") {
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1);
}else if (OpName == "Equal") {
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1);
}else if (OpName == "Erf") {
ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1);
}else if (OpName == "Exp") {
ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1);
}else if (OpName == "Expand") {
ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1);
}else if (OpName == "EyeLike") {
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1);
}else if (OpName == "Flatten") {
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1);
}else if (OpName == "Floor") {
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1);
}else if (OpName == "GRU") {
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2);
}else if (OpName == "Gather") {
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1);
}else if (OpName == "GatherElements") {
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1);
}else if (OpName == "GatherND") {
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1);
}else if (OpName == "Gemm") {
ImportNodeGemm(node, 3, 1);
}else if (OpName == "GlobalAveragePool") {
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1);
}else if (OpName == "GlobalLpPool") {
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1);
}else if (OpName == "GlobalMaxPool") {
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1);
}else if (OpName == "Greater") {
ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1);
}else if (OpName == "HardSigmoid") {
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1);
}else if (OpName == "Hardmax") {
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1);
}else if (OpName == "Identity") {
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1);
}else if (OpName == "If") {
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1);
}else if (OpName == "InstanceNormalization") {
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1);
}else if (OpName == "IsInf") {
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1);
}else if (OpName == "IsNaN") {
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1);
}else if (OpName == "LRN") {
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1);
}else if (OpName == "LSTM") {
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3);
}else if (OpName == "LeakyRelu") {
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1);
}else if (OpName == "Less") {
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1);
}else if (OpName == "Log") {
ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1);
}else if (OpName == "LogSoftmax") {
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1);
}else if (OpName == "Loop") {
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1);
}else if (OpName == "LpNormalization") {
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1);
}else if (OpName == "LpPool") {
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1);
}else if (OpName == "MatMul") {
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1);
}else if (OpName == "MatMulInteger") {
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1);
}else if (OpName == "Max") {
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, true, false);
}else if (OpName == "MaxPool") {
ImportNodeMaxPool(node, 1, 2);
}else if (OpName == "MaxRoiPool") {
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1);
}else if (OpName == "MaxUnpool") {
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1);
}else if (OpName == "Mean") {
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, true, false);
}else if (OpName == "MeanVarianceNormalization") {
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1);
}else if (OpName == "Min") {
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, true, false);
}else if (OpName == "Mod") {
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1);
}else if (OpName == "Mul") {
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1);
}else if (OpName == "Multinomial") {
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1);
}else if (OpName == "Neg") {
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1);
}else if (OpName == "NonMaxSuppression") {
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1);
}else if (OpName == "NonZero") {
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1);
}else if (OpName == "Not") {
ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1);
}else if (OpName == "OneHot") {
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1);
}else if (OpName == "Or") {
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1);
}else if (OpName == "PRelu") {
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1);
}else if (OpName == "Pad") {
ImportNodePad(node, 3, 1);
}else if (OpName == "Pow") {
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1);
}else if (OpName == "QLinearConv") {
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1);
}else if (OpName == "QLinearMatMul") {
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1);
}else if (OpName == "QuantizeLinear") {
ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1);
}else if (OpName == "RNN") {
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2);
}else if (OpName == "RandomNormal") {
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1);
}else if (OpName == "RandomNormalLike") {
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1);
}else if (OpName == "RandomUniform") {
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1);
}else if (OpName == "RandomUniformLike") {
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1);
}else if (OpName == "Range") {
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1);
}else if (OpName == "Reciprocal") {
ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1);
}else if (OpName == "ReduceL1") {
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1);
}else if (OpName == "ReduceL2") {
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1);
}else if (OpName == "ReduceLogSum") {
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1);
}else if (OpName == "ReduceLogSumExp") {
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1);
}else if (OpName == "ReduceMax") {
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1);
}else if (OpName == "ReduceMean") {
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1);
}else if (OpName == "ReduceMin") {
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1);
}else if (OpName == "ReduceProd") {
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1);
}else if (OpName == "ReduceSum") {
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1);
}else if (OpName == "ReduceSumSquare") {
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1);
}else if (OpName == "Relu") {
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1);
}else if (OpName == "Reshape") {
ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1);
}else if (OpName == "Resize") {
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1);
}else if (OpName == "ReverseSequence") {
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1);
}else if (OpName == "RoiAlign") {
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1);
}else if (OpName == "Round") {
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1);
}else if (OpName == "Scan") {
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1);
}else if (OpName == "Scatter") {
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1);
}else if (OpName == "ScatterElements") {
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1);
}else if (OpName == "ScatterND") {
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1);
}else if (OpName == "Selu") {
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1);
}else if (OpName == "SequenceAt") {
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1);
}else if (OpName == "SequenceConstruct") {
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, true, false);
}else if (OpName == "SequenceEmpty") {
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1);
}else if (OpName == "SequenceErase") {
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1);
}else if (OpName == "SequenceInsert") {
ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1);
}else if (OpName == "SequenceLength") {
ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1);
}else if (OpName == "Shape") {
ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1);
}else if (OpName == "Shrink") {
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1);
}else if (OpName == "Sigmoid") {
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1);
}else if (OpName == "Sign") {
ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1);
}else if (OpName == "Sin") {
ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1);
}else if (OpName == "Sinh") {
ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1);
}else if (OpName == "Size") {
ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1);
}else if (OpName == "Slice") {
ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1);
}else if (OpName == "Softmax") {
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1);
}else if (OpName == "Softplus") {
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1);
}else if (OpName == "Softsign") {
ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1);
}else if (OpName == "SpaceToDepth") {
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1);
}else if (OpName == "Split") {
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1);
}else if (OpName == "SplitToSequence") {
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1);
}else if (OpName == "Sqrt") {
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1);
}else if (OpName == "Squeeze") {
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1);
}else if (OpName == "StringNormalizer") {
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1);
}else if (OpName == "Sub") {
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1);
}else if (OpName == "Sum") {
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, true, false);
}else if (OpName == "Tan") {
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1);
}else if (OpName == "Tanh") {
ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1);
}else if (OpName == "TfIdfVectorizer") {
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1);
}else if (OpName == "ThresholdedRelu") {
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1);
}else if (OpName == "Tile") {
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1);
}else if (OpName == "TopK") {
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2);
}else if (OpName == "Transpose") {
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1);
}else if (OpName == "Unique") {
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4);
}else if (OpName == "Unsqueeze") {
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1);
}else if (OpName == "Upsample") {
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1);
}else if (OpName == "Where") {
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1);
}else if (OpName == "Xor") {
ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1);
}
if (opName == "Abs")
return buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Acos")
return buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Acosh")
return buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Add")
return buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "And")
return buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "ArgMax")
return buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ArgMin")
return buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Asin")
return buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Asinh")
return buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Atan")
return buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Atanh")
return buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "AveragePool")
return buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "BatchNormalization")
return ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
if (opName == "BitShift")
return buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Cast")
return buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Ceil")
return buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Clip")
return buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Compress")
return buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Concat")
return buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "ConcatFromSequence")
return buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Constant")
return buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "ConstantOfShape")
return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Conv")
return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ConvInteger")
return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "ConvTranspose")
return buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Cos")
return buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Cosh")
return buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "CumSum")
return buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "DepthToSpace")
return buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "DequantizeLinear")
return buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Det")
return buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Div")
return buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Dropout")
return buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
if (opName == "DynamicQuantizeLinear")
return buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
if (opName == "Elu")
return buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Equal")
return buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Erf")
return buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Exp")
return buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Expand")
return buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "EyeLike")
return buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Flatten")
return buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Floor")
return buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "GRU")
return buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
if (opName == "Gather")
return buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "GatherElements")
return buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "GatherND")
return buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Gemm")
return buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "GlobalAveragePool")
return buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "GlobalLpPool")
return buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "GlobalMaxPool")
return buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Greater")
return buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "HardSigmoid")
return buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Hardmax")
return buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Identity")
return buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "If")
return buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
if (opName == "InstanceNormalization")
return buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "IsInf")
return buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "IsNaN")
return buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LRN")
return buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LSTM")
return buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
if (opName == "LeakyRelu")
return buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Less")
return buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Log")
return buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LogSoftmax")
return buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Loop")
return buildOperation<mlir::ONNXLoopOp>(node);
if (opName == "LpNormalization")
return buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LpPool")
return buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "MatMul")
return buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "MatMulInteger")
return buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "Max")
return buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "MaxPool")
return ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
if (opName == "MaxRoiPool")
return buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "MaxUnpool")
return buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Mean")
return buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "MeanVarianceNormalization")
return buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Min")
return buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "Mod")
return buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Mul")
return buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Multinomial")
return buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Neg")
return buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "NonMaxSuppression")
return buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
if (opName == "NonZero")
return buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Not")
return buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "OneHot")
return buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Or")
return buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "PRelu")
return buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Pad")
return ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Pow")
return buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "QLinearConv")
return buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
if (opName == "QLinearMatMul")
return buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
if (opName == "QuantizeLinear")
return buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "RNN")
return buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
if (opName == "RandomNormal")
return buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "RandomNormalLike")
return buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "RandomUniform")
return buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "RandomUniformLike")
return buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Range")
return buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Reciprocal")
return buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceL1")
return buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceL2")
return buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceLogSum")
return buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceLogSumExp")
return buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceMax")
return buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceMean")
return buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceMin")
return buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceProd")
return buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceSum")
return buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceSumSquare")
return buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Relu")
return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Reshape")
return buildOperation<mlir::ONNXReshapeOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Resize")
return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "ReverseSequence")
return buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "RoiAlign")
return buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Round")
return buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Scan")
return buildOperation<mlir::ONNXScanOp>(node);
if (opName == "Scatter")
return buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ScatterElements")
return buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ScatterND")
return buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Selu")
return buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "SequenceAt")
return buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "SequenceConstruct")
return buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "SequenceEmpty")
return buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "SequenceErase")
return buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "SequenceInsert")
return buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "SequenceLength")
return buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Shape")
return buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Shrink")
return buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sigmoid")
return buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sign")
return buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sin")
return buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sinh")
return buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Size")
return buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Slice")
return buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
if (opName == "Softmax")
return buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Softplus")
return buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Softsign")
return buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "SpaceToDepth")
return buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Split")
return buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
if (opName == "SplitToSequence")
return buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Sqrt")
return buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Squeeze")
return buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "StringNormalizer")
return buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sub")
return buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Sum")
return buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "Tan")
return buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Tanh")
return buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "TfIdfVectorizer")
return buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ThresholdedRelu")
return buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Tile")
return buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "TopK")
return buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
if (opName == "Transpose")
return buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Unique")
return buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
if (opName == "Unsqueeze")
return buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Upsample")
return buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Where")
return buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Xor")
return buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
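In the dispatch table above, expected_num_operands is -1 for variadic ops and a fixed count otherwise, so optional trailing inputs that a model omits still need to occupy their operand slot. A minimal sketch of how such padding could work, assuming the omitted-input placeholder is a none-typed value created elsewhere (the helper name and signature are illustrative, not the importer's actual API):

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Value.h"

// Hypothetical helper: pad the operand list gathered from an ONNX node up to
// the count the MLIR op expects. Omitted optional inputs are filled with a
// shared none-typed placeholder so operand positions stay fixed; a count of -1
// means the op is variadic and the list is taken as-is.
static void padMissingOptionalOperands(llvm::SmallVectorImpl<mlir::Value> &operands,
                                       int expectedNumOperands,
                                       mlir::Value nonePlaceholder) {
  if (expectedNumOperands < 0)
    return; // variadic operand list: nothing to pad
  while ((int)operands.size() < expectedNumOperands)
    operands.push_back(nonePlaceholder);
}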

View File

@ -17,19 +17,23 @@ struct ONNXGemmOpLowering : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
auto has_bias = (operands.size() == 3);
// The first predicate is unnecessary once we remove ONNXGemmNoBiasOp.
bool hasBias = (operands.size() == 3) &&
(!op->getOperand(2).getType().isa<NoneType>());
Value A, B, C;
A = operands[0];
B = operands[1];
if (has_bias)
if (hasBias)
C = operands[2];
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
auto alphaAttr =
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
auto betaAttr = FloatAttr::get(memRefType.getElementType(),
auto betaAttr =
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -68,8 +72,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// Define loops.
std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
optimizedLoops, numLoops);
Block *optimizationBlock =
defineLoops(rewriter, loc, originalLoops, optimizedLoops, numLoops);
// We have two Krnl loops:
// - Outer loop iterates over the output matrix dimensions, and
@ -83,8 +87,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]);
}
KrnlIterateOperandPack outerPack(rewriter, outerLoops,
optimizedOuterLoops);
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
// Induction variables for the outer loops
for (int i = 0; i < 2; ++i)
addDimensionToPack(rewriter, loc, outerPack, alloc, i);
@ -107,8 +110,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
reductionPack.pushConstantBound(0);
if (ATy.getShape()[K_A_Idx] != -1)
reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
else
if (BTy.getShape()[K_B_Idx] != -1)
else if (BTy.getShape()[K_B_Idx] != -1)
reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
else
reductionPack.pushOperandBound(
@ -119,7 +121,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// GemmOp supports unidirectional broadcasting from C to A*B.
// Hence, it must be enough to get broadcasting information for C only.
std::map<int, Value> broadcastedDimInfo;
if (has_bias) {
if (hasBias) {
auto shape = C.getType().cast<MemRefType>().getShape();
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] < 0) {
@ -162,7 +164,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
if (has_bias) {
if (hasBias) {
auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
broadcastedDimInfo);
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
@ -210,8 +212,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
}
};
void populateLoweringONNXGemmOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx) {
patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
}
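Since the lowering above is instantiated for both ONNXGemmOp and ONNXGemmNoBiasOp, the bias test has to cover two situations: an op that simply has no third operand, and an op whose third operand is the none-typed placeholder for an omitted bias. A self-contained sketch of that predicate (the helper name is illustrative; header paths depend on the MLIR version assumed here):

#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"

// Illustrative predicate: a Gemm-style op carries a bias only if a third
// operand exists and it is not the none-typed placeholder used for omitted
// optional inputs.
static bool gemmHasBias(mlir::Operation *op) {
  return op->getNumOperands() == 3 &&
         !op->getOperand(2).getType().isa<mlir::NoneType>();
}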

View File

@ -120,25 +120,19 @@ void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
// Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface.
void ONNXTanhOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Sinh
/// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface.
void ONNXSinhOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Cosh
/// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface.
void ONNXCoshOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Cos
@ -178,9 +172,7 @@ void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
// Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface.
void ONNXReluOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// LeakyRelu
@ -194,9 +186,7 @@ void ONNXLeakyReluOp::inferShapes() {
// Selu
/// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface.
void ONNXSeluOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Reciprocal
@ -234,17 +224,13 @@ void ONNXSoftsignOp::inferShapes() {
// Sqrt
/// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface.
void ONNXSqrtOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Sign
/// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface.
void ONNXSignOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Add
@ -423,8 +409,7 @@ void ONNXMatMulOp::inferShapes() {
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
// need to be removed after the multiplication but cannot be removed if all
// sizes are 1.
if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
lhsShape[0] != rhsShape[0])
if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
dims.emplace_back(1);
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
@ -541,14 +526,14 @@ void ONNXMatMulOp::inferShapes() {
// Gemm
void ONNXGemmOp::inferShapes() {
bool hasBias = !getOperand(2).getType().isa<NoneType>();
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() ||
!getOperand(2).getType().isa<RankedTensorType>())
(hasBias && !getOperand(2).getType().isa<RankedTensorType>()))
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
@ -560,16 +545,19 @@ void ONNXGemmOp::inferShapes() {
emitError("Tensor shapes mismatched.");
}
if (hasBias) {
// Check whether bias is unidirectional broadcasting or not.
auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
auto shape = biasTy.getShape();
int rank = shape.size();
if ((rank > 2) ||
(rank >= 1 && shape[rank - 1] != -1 && N != -1 && N != shape[rank - 1] &&
shape[rank - 1] != 1) ||
(rank == 2 && shape[rank - 2] != -1 && M != -1 && M != shape[rank - 2] &&
shape[rank - 2] != 1)) {
(rank >= 1 && shape[rank - 1] != -1 && N != -1 &&
N != shape[rank - 1] && shape[rank - 1] != 1) ||
(rank == 2 && shape[rank - 2] != -1 && M != -1 &&
M != shape[rank - 2] && shape[rank - 2] != 1)) {
emitError("Bias shape mismatched.");
}
}
SmallVector<int64_t, 2> dims;
dims.emplace_back(M);
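The bias check above implements ONNX's unidirectional broadcasting of C onto the M x N result of A*B: C may have rank 0, 1, or 2, and each of its dimensions, matched from the right, must be 1, equal to the corresponding result dimension, or dynamic. A standalone sketch with a few worked shapes (plain C++, with -1 standing for a dynamic dimension as in the code above):

#include <vector>

// Returns true when a bias of shape cShape broadcasts onto an M x N result.
static bool biasBroadcastsTo(const std::vector<long long> &cShape,
                             long long M, long long N) {
  int rank = (int)cShape.size();
  if (rank > 2)
    return false;
  // The trailing dimension must be 1, N, or dynamic (N itself may be dynamic).
  if (rank >= 1 && cShape[rank - 1] != -1 && N != -1 &&
      cShape[rank - 1] != N && cShape[rank - 1] != 1)
    return false;
  // With rank 2, the leading dimension must be 1, M, or dynamic.
  if (rank == 2 && cShape[0] != -1 && M != -1 &&
      cShape[0] != M && cShape[0] != 1)
    return false;
  return true;
}
// For M = 128, N = 64: {}, {64}, {1, 64}, {128, 1}, and {128, 64} broadcast,
// while {32, 64} and {128, 32} do not.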
@ -713,7 +701,6 @@ void ONNXTransposeOp::inferShapes() {
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
//===----------------------------------------------------------------------===//
// ReduceMax
@ -801,7 +788,8 @@ void ONNXConvNoBiasOp::inferShapes() {
// Required attribute auto_pad defaults to NOTSET.
auto autoPad = auto_pad();
// Group is a required attribute and should have default value of 1.
int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
int64_t group =
ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (dataShape[1] != (weightShape[1] * group))
emitError("Channel dimension mismatch.");
@ -859,8 +847,10 @@ void ONNXConvNoBiasOp::inferShapes() {
if (dilations.getValue().size() != nDims)
emitError("dilations length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i)
kernelDims[i] = (kernelDims[i] + 1) *
(dilations.getValue()[i]).cast<IntegerAttr>().getInt() - 1;
kernelDims[i] =
(kernelDims[i] + 1) *
(dilations.getValue()[i]).cast<IntegerAttr>().getInt() -
1;
}
// Subtract kernel dimensions from input data dimensions.
@ -906,8 +896,7 @@ void ONNXConvNoBiasOp::inferShapes() {
if (strides.getValue().size() != nDims)
emitError("strides length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i) {
int64_t stride =
strides.getValue()[i].cast<IntegerAttr>().getInt();
int64_t stride = strides.getValue()[i].cast<IntegerAttr>().getInt();
outSpatialDims[i] = floor(outSpatialDims[i] / stride);
}
}
@ -937,7 +926,8 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
// get kernel sizes from kernel_shape attribute
auto kernelShape = kernel_shape();
if (!kernelShape)
emitError("kernel_shape is a mandatory attribute for which there is no default.");
emitError(
"kernel_shape is a mandatory attribute for which there is no default.");
auto kernelShapeArray = kernelShape.getValue();
auto kernelRank = kernelShape.size();
if (kernelRank > xRank)
@ -951,7 +941,8 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
SmallVector<int64_t, 4> actualDilations;
auto dilationsOpt = dilations();
if (dilationsOpt.hasValue()) {
auto dilationsArray = dilationsOpt.getValue().getValue(); // opt -> attr -> array
auto dilationsArray =
dilationsOpt.getValue().getValue(); // opt -> attr -> array
if (dilationsArray.size() != kernelRank)
emitError("dialation rank is not the same as the spatial rank.");
// fill in the actual values
@ -1021,13 +1012,15 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
}
for (int i = 0; i < kernelRank; ++i) {
auto inputSpatialShape = xShape[kernelOffset + i];
auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt();
auto kernelSpatialShape =
(kernelShapeArray[i]).cast<IntegerAttr>().getInt();
auto dilations = actualDilations[i];
auto strideSpatialShape = actualStrides[i];
int64_t outputSpatialShape = ceil((1.0 * inputSpatialShape) /
(1.0 * strideSpatialShape));
int64_t outputSpatialShape =
ceil((1.0 * inputSpatialShape) / (1.0 * strideSpatialShape));
auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape +
((kernelSpatialShape - 1) * dilations + 1) - inputSpatialShape;
((kernelSpatialShape - 1) * dilations + 1) -
inputSpatialShape;
actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2;
if (sumOfPad % 2 != 0) {
if (autoPad == "SAME_UPPER") {
@ -1053,11 +1046,13 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
for (int i = 0; i < kernelRank; ++i) {
auto inputSpatialShape = xShape[kernelOffset + i];
auto padShape = actualPads[i] + actualPads[kernelRank + i];
auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt();
auto kernelSpatialShape =
(kernelShapeArray[i]).cast<IntegerAttr>().getInt();
auto dilations = actualDilations[i];
auto strideSpatialShape = actualStrides[i];
/// output_spatial_shape[i] = ceil( (input_spatial_shape[i] + pad_shape[i] -
// ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1)
// ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) /
// strides_spatial_shape[i] + 1)
double numerator = inputSpatialShape + padShape -
((kernelSpatialShape - 1) * dilations + 1);
double denominator = strideSpatialShape;
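As a quick sanity check of the formula in the comment above, the sketch below plugs in made-up numbers (input spatial size 32, total padding 2, kernel 3, dilation 1, stride 2); the values are illustrative only:

#include <cmath>
#include <cstdio>

int main() {
  // output_spatial_shape[i] =
  //   ceil((input + pads - ((kernel - 1) * dilation + 1)) / stride + 1)
  double input = 32, pads = 2, kernel = 3, dilation = 1, stride = 2;
  double numerator = input + pads - ((kernel - 1) * dilation + 1); // 31
  double output = std::ceil(numerator / stride + 1);               // ceil(16.5) = 17
  std::printf("output spatial dim = %g\n", output);
  return 0;
}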

View File

@ -1,7 +1,8 @@
//********************************************************
// Warning: Do not modify this file directly
// This file is automatically generated via script
// Details can be found in doc/readonnxdefs.md
// This file is generated on UTC-02/24/2020, 06:44:13.
// Do not modify this file directly.
// This file is automatically generated via script.
// Details can be found in doc/readonnxdefs.md .
//********************************************************
def ONNXAbsOp:ONNX_Op<"Abs",
@ -213,10 +214,10 @@ def ONNXBatchNormalizationOp:ONNX_Op<"BatchNormalization",
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
DefaultValuedAttr<F32Attr, "0.9">:$momentum);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$out_mean,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$out_var,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$saved_mean,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$saved_var);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$out_mean,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$out_var,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$saved_mean,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$saved_var);
}
def ONNXBitShiftOp:ONNX_Op<"BitShift",
@ -224,12 +225,12 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift",
let summary = "ONNX BitShift operation";
let description = [{
"Bitwise shift operator performs element-wise operation. For each input element, if the"
" attribute "direction" is "RIGHT", this operator moves its binary representation toward"
" the right side so that the input value is effectively decreased. If the attribute "direction""
" is "LEFT", bits of binary representation moves toward the left side, which results the"
" attribute \"direction\" is \"RIGHT\", this operator moves its binary representation toward"
" the right side so that the input value is effectively decreased. If the attribute \"direction\""
" is \"LEFT\", bits of binary representation moves toward the left side, which results the"
" increase of its actual value. The input X is the tensor to be shifted and another input"
" Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4],"
" and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with"
" Y specifies the amounts of shifting. For example, if \"direction\" is \"Right\", X is [1, 4],"
" and S is [1, 1], the corresponding output Z would be [0, 2]. If \"direction\" is \"LEFT\" with"
" X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]."
" "
" Because this operator supports Numpy-style broadcasting, X's and Y's shapes are"
@ -251,15 +252,15 @@ def ONNXCastOp:ONNX_Op<"Cast",
"the converted type. The 'to' argument must be one of the data types specified"
"in the 'DataType' enum field in the TensorProto message."
""
"Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations"
"(e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may"
"Casting from string tensor in plain (e.g., \"3.14\" and \"1000\") and scientific numeric representations"
"(e.g., \"1e-5\" and \"1E8\") to float types is supported. For example, converting string \"100.5\" to an integer may"
"result 100. There are some string literals reserved for special floating-point values;"
""+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively."
"Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly,"
"this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors"
"to string tensors, plain floating-point representation (such as "314.15926") would be used. "
"Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases "
"of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior."
"\"+INF\" (and \"INF\"), \"-INF\", and \"NaN\" are positive infinity, negative infinity, and not-a-number, respectively."
"Any string which can exactly match \"+INF\" in a case-insensitive way would be mapped to positive infinite. Similarly,"
"this case-insensitive rule is applied to \"INF\" and \"NaN\". When casting from numeric tensors"
"to string tensors, plain floating-point representation (such as \"314.15926\") would be used. "
"Converting non-numerical-literal string such as \"Hello World!\" is an undefined behavior. Cases "
"of converting string representing floating-point arithmetic value, such as \"2.718\", to INT is an undefined behavior."
""
"Conversion from a numerical type to any numerical type is always allowed."
"User must be aware of precision loss and value change caused by range difference between two types."
@ -292,8 +293,8 @@ def ONNXClipOp:ONNX_Op<"Clip",
"numeric_limits::lowest() and numeric_limits::max(), respectively."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$min,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$max);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$min,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$max);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
}
@ -370,7 +371,7 @@ def ONNXConvOp:ONNX_Op<"Conv",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B,
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
OptionalAttr<I64ArrayAttr>:$dilations,
DefaultValuedAttr<I64Attr, "1">:$group,
@ -389,8 +390,8 @@ def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$w,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$x_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$w_zero_point,
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
OptionalAttr<I64ArrayAttr>:$dilations,
DefaultValuedAttr<I64Attr, "1">:$group,
@ -421,7 +422,7 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B,
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
OptionalAttr<I64ArrayAttr>:$dilations,
DefaultValuedAttr<I64Attr, "1">:$group,
@ -534,7 +535,7 @@ def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_scale,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_zero_point);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$x_zero_point);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$y);
}
@ -579,7 +580,7 @@ def ONNXDropoutOp:ONNX_Op<"Dropout",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
DefaultValuedAttr<F32Attr, "0.5">:$ratio);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$mask);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$mask);
}
def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear",
@ -817,9 +818,9 @@ def ONNXGRUOp:ONNX_Op<"GRU",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$R,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_h,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$initial_h,
OptionalAttr<F32ArrayAttr>:$activation_alpha,
OptionalAttr<F32ArrayAttr>:$activation_beta,
OptionalAttr<StrArrayAttr>:$activations,
@ -827,8 +828,8 @@ def ONNXGRUOp:ONNX_Op<"GRU",
DefaultValuedAttr<StrAttr, "forward">:$direction,
OptionalAttr<I64Attr>:$hidden_size,
DefaultValuedAttr<I64Attr, "0">:$linear_before_reset);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y_h);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y_h);
}
def ONNXGatherOp:ONNX_Op<"Gather",
@ -1042,6 +1043,7 @@ def ONNXGatherNDOp:ONNX_Op<"GatherND",
def ONNXGemmOp:ONNX_Op<"Gemm",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1;
let summary = "ONNX Gemm operation";
let description = [{
"General Matrix multiplication:"
@ -1060,7 +1062,7 @@ def ONNXGemmOp:ONNX_Op<"Gemm",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$C,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$C,
DefaultValuedAttr<F32Attr, "1.0">:$alpha,
DefaultValuedAttr<F32Attr, "1.0">:$beta,
DefaultValuedAttr<I64Attr, "0">:$transA,
@ -1332,11 +1334,11 @@ def ONNXLSTMOp:ONNX_Op<"LSTM",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$R,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_h,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_c,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$P,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$initial_h,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$initial_c,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$P,
OptionalAttr<F32ArrayAttr>:$activation_alpha,
OptionalAttr<F32ArrayAttr>:$activation_beta,
OptionalAttr<StrArrayAttr>:$activations,
@ -1344,9 +1346,9 @@ def ONNXLSTMOp:ONNX_Op<"LSTM",
DefaultValuedAttr<StrAttr, "forward">:$direction,
OptionalAttr<I64Attr>:$hidden_size,
DefaultValuedAttr<I64Attr, "0">:$input_forget);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y_h,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y_c);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y_h,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y_c);
}
def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu",
@ -1430,24 +1432,24 @@ def ONNXLoopOp:ONNX_Op<"Loop",
""
" Operator inputs defined as (max_trip_count, condition_var)."
""
" input ("", ""):"
" input (\"\", \"\"):"
" for (int i=0; ; ++i) {"
" cond = ... // Note this value is ignored, but is required in the body"
" }"
""
" input ("", cond) // Note this is analogous to a while loop"
" input (\"\", cond) // Note this is analogous to a while loop"
" bool cond = ...;"
" for (int i=0; cond; ++i) {"
" cond = ...;"
" }"
""
" input ("", 1) // Note this is analogous to a do-while loop"
" input (\"\", 1) // Note this is analogous to a do-while loop"
" bool cond = true"
" for (int i=0; cond; ++i) {"
" cond = ...;"
" }"
""
" input (trip_count, "") // Note this is analogous to a for loop"
" input (trip_count, \"\") // Note this is analogous to a for loop"
" int trip_count = ..."
" for (int i=0; i < trip_count; ++i) {"
" cond = ...; // ignored"
@ -1473,15 +1475,15 @@ def ONNXLoopOp:ONNX_Op<"Loop",
" }"
""
" graph body-net ("
" %i[INT32, scalar] // iteration number"
" %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used"
" %b_in[INT32, scalar] // incoming value of loop-carried-dependency b"
" %i[INT32, scalar]"
" %keepgoing[BOOL, scalar]"
" %b[INT32, scalar]"
" ) {"
" %my_local = Add(%a, %b_in)"
" %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b"
" %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition"
" %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated"
" return %keepgoing_out, %b_out, %user_defined_val"
" %my_local = Add(%a, %b)"
" %b_out = Sub(%a, %b)"
" %keepgoing_out = Greater(%my_local, %b_out)"
" %user_defined_vals = Add(%b, %b)"
" return %keepgoing_out, %b_out, %user_defined_vals"
" }"
""
"*Sample equivalent C code*"
@ -1496,51 +1498,31 @@ def ONNXLoopOp:ONNX_Op<"Loop",
" const int max_trip_count = 10; // Analogous to input M"
" int user_defined_vals[]; // Imagine this is resizable"
" /* End implicitly-defined code */"
" /* initialize loop-carried variables and scan-output variables */"
" bool keepgoing_out = keepgoing"
" int b_out = b"
""
" for (int i=0; i < max_trip_count && keepgoing_out; ++i) {"
" /* Implicitly-defined code: bind actual parameter values"
" to formal parameter variables of loop-body */"
" bool keepgoing_in = keepgoing_out; "
" bool b_in = b_out;"
""
" for (int i=0; i < max_trip_count && keepgoing; ++i) {"
" /* User-defined code (loop body) */"
" int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine"
" b_out = a - b_in;"
" keepgoing_out = my_local > b_out; "
" user_defined_val = b_in + b_in; // b_in and b_out are different variables"
" int my_local = a + b; // Reading values in the enclosing scope is fine"
" b = a - b; // writes fine if we specify b as a loop-carried dependency"
" keepgoing = my_local > b; // keepgoing is a loop-carried dependency"
" user_defined_vals[i] = b + b;"
" /* End user-defined code */"
""
" /* Implicitly defined-code */"
" user_defined_vals[i] = user_defined_val // accumulate scan-output values"
" }"
" // int t = my_local; // Can't do this. my_local is not accessible here."
" // my_local = 123; // Can't do this. my_local was defined in the the body"
""
" // The values below are bound to the output variables of the loop and therefore accessible"
" // b_out; user_defined_vals; keepgoing_out;"
" // These below values are live-out from the loop and therefore accessible"
" b_out; user_defined_vals; keepgoing_out;"
" }"
""
"There are several things of note in this code snippet:"
""
"1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can"
"1) Values from the enclosing scope (i.e. variable a here) are in scope and can"
" be referenced in the inputs of the loop."
"2) Any values computed in the loop body that needs to be used in a subsequent"
" iteration or after the loop are modelled using a pair of variables in the loop-body,"
" consisting of an input variable (eg., b_in) and an output variable (eg., b_out)."
" These are referred to as loop-carried dependences. The loop operation node"
" supplies the input value of the input variable for the first iteration, and"
" returns the output value of the output variable produced by the final"
" iteration."
"3) Scan_output variables are used to implicitly concatenate values computed across"
" all the iterations. In the above example, the value of user_defined_val computed"
" over all iterations are concatenated and returned as the value of user_defined_vals"
" after the loop."
"4) Values created in the body cannot be accessed in the enclosing scope,"
" except using the mechanism described above."
"2) Any variables which you wish to make available in the enclosing scope (i.e."
" the variables b and keepgoing) must be declared as either loop-carried"
" dependencies (both at the op inputs and output and at the body net input and"
" output) or scan_outputs."
"3) Values created in the body cannot be accessed in the enclosing scope."
""
"Note that the semantics of this op support "diagonal" or "wavefront" execution."
"Note that the semantics of this op support \"diagonal\" or \"wavefront\" execution."
"(See Step 3 here for an example:"
"https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)."
"Frontends should emit multi-layer RNNs as a series of While operators (with"
@ -1548,8 +1530,8 @@ def ONNXLoopOp:ONNX_Op<"Loop",
"the scan_outputs from the previous layer, possibly going through several"
"point-wise operators (e.g. dropout, residual connections, linear layer)."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$M,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$cond,
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$M,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$cond,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$v_initial,
AnyAttr:$body);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$v_final_and_scan_outputs);
@ -1606,8 +1588,8 @@ def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$a_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$b_zero_point);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$a_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$b_zero_point);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
}
@ -1666,7 +1648,7 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool",
DefaultValuedAttr<I64Attr, "0">:$storage_order,
OptionalAttr<I64ArrayAttr>:$strides);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Indices);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Indices);
}
def ONNXMaxRoiPoolOp:ONNX_Op<"MaxRoiPool",
@ -1709,7 +1691,7 @@ def ONNXMaxUnpoolOp:ONNX_Op<"MaxUnpool",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$I,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_shape,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$output_shape,
I64ArrayAttr:$kernel_shape,
OptionalAttr<I64ArrayAttr>:$pads,
OptionalAttr<I64ArrayAttr>:$strides);
@ -1841,9 +1823,9 @@ def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$boxes,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$scores,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$max_output_boxes_per_class,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$iou_threshold,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$score_threshold,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$max_output_boxes_per_class,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$iou_threshold,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$score_threshold,
DefaultValuedAttr<I64Attr, "0">:$center_point_box);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$selected_indices);
}
@ -2018,7 +2000,7 @@ def ONNXPadOp:ONNX_Op<"Pad",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$pads,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$constant_value,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$constant_value,
DefaultValuedAttr<StrAttr, "constant">:$mode);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
}
@ -2055,7 +2037,7 @@ def ONNXQLinearConvOp:ONNX_Op<"QLinearConv",
AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_zero_point,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B,
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
OptionalAttr<I64ArrayAttr>:$dilations,
DefaultValuedAttr<I64Attr, "1">:$group,
@ -2099,7 +2081,7 @@ def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_zero_point);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$y_zero_point);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$y);
}
@ -2172,17 +2154,17 @@ def ONNXRNNOp:ONNX_Op<"RNN",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$R,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_h,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$sequence_lens,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$initial_h,
OptionalAttr<F32ArrayAttr>:$activation_alpha,
OptionalAttr<F32ArrayAttr>:$activation_beta,
DefaultValuedAttr<StrArrayAttr, "{\"Tanh\", \"Tanh\"}">:$activations,
OptionalAttr<F32Attr>:$clip,
DefaultValuedAttr<StrAttr, "forward">:$direction,
OptionalAttr<I64Attr>:$hidden_size);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y_h);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y_h);
}
def ONNXRandomNormalOp:ONNX_Op<"RandomNormal",
@ -2545,12 +2527,12 @@ def ONNXResizeOp:ONNX_Op<"Resize",
let description = [{
"Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor."
"Each dimension value of the output tensor is:"
" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified."
" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$roi,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$scales,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$sizes,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$sizes,
DefaultValuedAttr<StrAttr, "half_pixel">:$coordinate_transformation_mode,
DefaultValuedAttr<F32Attr, "-0.75">:$cubic_coeff_a,
DefaultValuedAttr<I64Attr, "0">:$exclude_outside,
@ -3044,7 +3026,7 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase",
"'position' is optional, by default it erases the last tensor from 'input_sequence'."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$position);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$position);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
}
@ -3060,7 +3042,7 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert",
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$position);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$position);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
}
@ -3194,8 +3176,8 @@ def ONNXSliceOp:ONNX_Op<"Slice",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$starts,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$ends,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$axes,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$steps);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$axes,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$steps);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
}
@ -3269,7 +3251,7 @@ def ONNXSplitOp:ONNX_Op<"Split",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input,
DefaultValuedAttr<I64Attr, "0">:$axis,
OptionalAttr<I64ArrayAttr>:$split);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$outputs);
let results = (outs Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$outputs);
}
def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
@ -3288,7 +3270,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
"dimension size of input tensor on 'axis'."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$split,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$split,
DefaultValuedAttr<I64Attr, "0">:$axis,
DefaultValuedAttr<I64Attr, "1">:$keepdims);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence);
@ -3327,9 +3309,9 @@ def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer",
"StringNormalization performs string operations for basic cleaning."
"This operator has only one input (denoted by X) and only one output"
"(denoted by Y). This operator first examines the elements in the X,"
"and removes elements specified in "stopwords" attribute."
"and removes elements specified in \"stopwords\" attribute."
"After removing stop words, the intermediate result can be further lowercased,"
"uppercased, or just returned depending the "case_change_action" attribute."
"uppercased, or just returned depending the \"case_change_action\" attribute."
"This operator only accepts [C]- and [1, C]-tensor."
"If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]"
"if input shape is [C] and shape [1, 1] if input shape is [1, C]."
@ -3412,8 +3394,8 @@ def ONNXTfIdfVectorizerOp:ONNX_Op<"TfIdfVectorizer",
"respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output."
"Note that we may consider all skips up to S when generating the n-grams."
""
"The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and"
"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF","
"The examples used above are true if mode is \"TF\". If mode is \"IDF\", all the counts larger than 1 would be truncated to 1 and"
"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is \"TFIDF\","
"this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute."
""
"Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor."
@ -3470,9 +3452,9 @@ def ONNXTopKOp:ONNX_Op<"TopK",
" contains the indices of the top k elements (original indices from the input"
" tensor)."
""
"If "largest" is 1 (the default value) then the k largest elements are returned."
"If "sorted" is 1 (the default value) then the resulting k elements will be sorted."
"If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined."
"If \"largest\" is 1 (the default value) then the k largest elements are returned."
"If \"sorted\" is 1 (the default value) then the resulting k elements will be sorted."
"If \"sorted\" is 0, order of returned 'Values' and 'Indices' are undefined."
""
"Given two equivalent values, this operator uses the indices along the axis as"
" a tiebreaker. That is, the element with the lower index will appear first."
@ -3509,7 +3491,7 @@ def ONNXUniqueOp:ONNX_Op<"Unique",
"This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. "
"The first output tensor 'Y' contains all unique values or subtensors of the input. "
"The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. "
"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". "
"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. \". "
"The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. "
""
"Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. "
@ -3583,9 +3565,9 @@ def ONNXUniqueOp:ONNX_Op<"Unique",
OptionalAttr<I64Attr>:$axis,
DefaultValuedAttr<I64Attr, "1">:$sorted);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$inverse_indices,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$counts);
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$indices,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$inverse_indices,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$counts);
}
def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze",
@ -3652,3 +3634,4 @@ def ONNXXorOp:ONNX_Op<"Xor",
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C);
}

View File

@ -127,6 +127,10 @@ int main(int argc, char *argv[]) {
if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass());
// An additional pass of canonicalization is helpful because lowering
// from the ONNX dialect to the Standard dialect exposes additional
// canonicalization opportunities.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createLowerKrnlPass());
}

View File

@ -28,6 +28,11 @@ void ONNXAddOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<MulAddToGemmOptPattern>(context);
}
void ONNXGemmOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<FuseGemmFollowedByAddition>(context);
}
/// on the ONNXIdentityOp.
void ONNXIdentityOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {

View File

@ -26,6 +26,7 @@ include "dialect/onnx/onnx.td"
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite
@ -41,6 +42,11 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
(ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
[(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>;
// onnx.Add(onnx.Gemm(%X, %Y, None), %Z) = onnx.Gemm(%X, %Y, %Z)
def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none, $alpha, $beta, $transA, $transB), $bias),
(ONNXGemmOp $m1, $m2, $bias, $alpha, $beta, $transA, $transB),
[(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2), (HasNoneType $none)]>;
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
(replaceWithValue $arg)>;

View File

@ -101,3 +101,14 @@ func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>
// CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
// CHECK-NEXT: return %1 : tensor<*xf32>
}
//CHECK-LABEL: @test_gemm_add_fusion(%{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128xf32>) -> tensor<*xf32> {
func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128xf32>) -> tensor<*xf32> {
%cst = constant unit
%0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128xf32>, tensor<128x128xf32>, none) -> tensor<*xf32>
%1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32>
// CHECK-NEXT: return [[GEMM]] : tensor<*xf32>
}