Support Optional Inputs (#94)

* 1. Combine variadicIn/Out with expectedNumOperands/Results to simplify import function arguments.
2. Generic improvements to code readability in gen_doc.py.

* Update ONNX Dialect doc.

* Remove redundant code in ImportNode.

* Prettify op_build_table.inc.

* 1. Remove irrelevant code in gen_doc.py

* Refactor code to be more readable.

* Further refactoring for readability improvements.

* Allow gemm to have an optional operand (bias term), and include an example of declarative optimization pattern targeting gemm with bias term ommitted.

* Make shape inference/lowering of gemm op compatible with optional operand declaration.

* Apply canonicalization again after lowering from onnx -> std dialects.

* Make hasBias compatible with the situation of GemmNoBias op.

* Update doc.

* Add a canonicalization test.

* Remove special handler for importing Gemm op, as it's redundant now.
This commit is contained in:
Tian Jin 2020-02-24 23:46:48 +08:00 committed by GitHub
parent 479dd5e35a
commit 9c398c0121
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 3431 additions and 3804 deletions

View File

@ -327,10 +327,10 @@ ONNX BatchNormalization operation
#### Results: #### Results:
1. `Y`: memref of any type values or tensor of any type values 1. `Y`: memref of any type values or tensor of any type values
1. `out_mean`: memref of any type values or tensor of any type values 1. `out_mean`: memref of any type values or tensor of any type values or none type
1. `out_var`: memref of any type values or tensor of any type values 1. `out_var`: memref of any type values or tensor of any type values or none type
1. `saved_mean`: memref of any type values or tensor of any type values 1. `saved_mean`: memref of any type values or tensor of any type values or none type
1. `saved_var`: memref of any type values or tensor of any type values 1. `saved_var`: memref of any type values or tensor of any type values or none type
### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp) ### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp)
ONNX BatchNormalization operation in test mode ONNX BatchNormalization operation in test mode
@ -375,12 +375,12 @@ ONNX BitShift operation
"Bitwise shift operator performs element-wise operation. For each input element, if the" "Bitwise shift operator performs element-wise operation. For each input element, if the"
" attribute "direction" is "RIGHT", this operator moves its binary representation toward" " attribute \"direction\" is \"RIGHT\", this operator moves its binary representation toward"
" the right side so that the input value is effectively decreased. If the attribute "direction"" " the right side so that the input value is effectively decreased. If the attribute \"direction\""
" is "LEFT", bits of binary representation moves toward the left side, which results the" " is \"LEFT\", bits of binary representation moves toward the left side, which results the"
" increase of its actual value. The input X is the tensor to be shifted and another input" " increase of its actual value. The input X is the tensor to be shifted and another input"
" Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4]," " Y specifies the amounts of shifting. For example, if \"direction\" is \"Right\", X is [1, 4],"
" and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with" " and S is [1, 1], the corresponding output Z would be [0, 2]. If \"direction\" is \"LEFT\" with"
" X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]." " X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]."
" " " "
" Because this operator supports Numpy-style broadcasting, X's and Y's shapes are" " Because this operator supports Numpy-style broadcasting, X's and Y's shapes are"
@ -413,15 +413,15 @@ ONNX Cast operation
"the converted type. The 'to' argument must be one of the data types specified" "the converted type. The 'to' argument must be one of the data types specified"
"in the 'DataType' enum field in the TensorProto message." "in the 'DataType' enum field in the TensorProto message."
"" ""
"Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations" "Casting from string tensor in plain (e.g., \"3.14\" and \"1000\") and scientific numeric representations"
"(e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may" "(e.g., \"1e-5\" and \"1E8\") to float types is supported. For example, converting string \"100.5\" to an integer may"
"result 100. There are some string literals reserved for special floating-point values;" "result 100. There are some string literals reserved for special floating-point values;"
""+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively." "\"+INF\" (and \"INF\"), \"-INF\", and \"NaN\" are positive infinity, negative infinity, and not-a-number, respectively."
"Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly," "Any string which can exactly match \"+INF\" in a case-insensitive way would be mapped to positive infinite. Similarly,"
"this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors" "this case-insensitive rule is applied to \"INF\" and \"NaN\". When casting from numeric tensors"
"to string tensors, plain floating-point representation (such as "314.15926") would be used. " "to string tensors, plain floating-point representation (such as \"314.15926\") would be used. "
"Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases " "Converting non-numerical-literal string such as \"Hello World!\" is an undefined behavior. Cases "
"of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior." "of converting string representing floating-point arithmetic value, such as \"2.718\", to INT is an undefined behavior."
"" ""
"Conversion from a numerical type to any numerical type is always allowed." "Conversion from a numerical type to any numerical type is always allowed."
"User must be aware of precision loss and value change caused by range difference between two types." "User must be aware of precision loss and value change caused by range difference between two types."
@ -476,8 +476,8 @@ ONNX Clip operation
#### Operands: #### Operands:
1. `input`: memref of any type values or tensor of any type values 1. `input`: memref of any type values or tensor of any type values
1. `min`: memref of any type values or tensor of any type values 1. `min`: memref of any type values or tensor of any type values or none type
1. `max`: memref of any type values or tensor of any type values 1. `max`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -618,8 +618,8 @@ ONNX ConvInteger operation
1. `x`: memref of any type values or tensor of any type values 1. `x`: memref of any type values or tensor of any type values
1. `w`: memref of any type values or tensor of any type values 1. `w`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values 1. `x_zero_point`: memref of any type values or tensor of any type values or none type
1. `w_zero_point`: memref of any type values or tensor of any type values 1. `w_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -678,7 +678,7 @@ ONNX Conv operation
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values 1. `W`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -720,7 +720,7 @@ ONNX ConvTranspose operation
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values 1. `W`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -884,7 +884,7 @@ ONNX DequantizeLinear operation
1. `x`: memref of any type values or tensor of any type values 1. `x`: memref of any type values or tensor of any type values
1. `x_scale`: memref of any type values or tensor of any type values 1. `x_scale`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values 1. `x_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -964,7 +964,7 @@ ONNX Dropout operation
#### Results: #### Results:
1. `output`: memref of any type values or tensor of any type values 1. `output`: memref of any type values or tensor of any type values
1. `mask`: memref of any type values or tensor of any type values 1. `mask`: memref of any type values or tensor of any type values or none type
### onnx.DynamicQuantizeLinear (ONNXDynamicQuantizeLinearOp) ### onnx.DynamicQuantizeLinear (ONNXDynamicQuantizeLinearOp)
ONNX DynamicQuantizeLinear operation ONNX DynamicQuantizeLinear operation
@ -1297,9 +1297,9 @@ ONNX GRU operation
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values 1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values 1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values 1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values 1. `initial_h`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -1315,8 +1315,8 @@ ONNX GRU operation
#### Results: #### Results:
1. `Y`: memref of any type values or tensor of any type values 1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values 1. `Y_h`: memref of any type values or tensor of any type values or none type
### onnx.GatherElements (ONNXGatherElementsOp) ### onnx.GatherElements (ONNXGatherElementsOp)
ONNX GatherElements operation ONNX GatherElements operation
@ -1609,7 +1609,7 @@ ONNX Gemm operation
1. `A`: memref of any type values or tensor of any type values 1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values
1. `C`: memref of any type values or tensor of any type values 1. `C`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -2013,11 +2013,11 @@ ONNX LSTM operation
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values 1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values 1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values 1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values 1. `initial_h`: memref of any type values or tensor of any type values or none type
1. `initial_c`: memref of any type values or tensor of any type values 1. `initial_c`: memref of any type values or tensor of any type values or none type
1. `P`: memref of any type values or tensor of any type values 1. `P`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -2033,9 +2033,9 @@ ONNX LSTM operation
#### Results: #### Results:
1. `Y`: memref of any type values or tensor of any type values 1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values 1. `Y_h`: memref of any type values or tensor of any type values or none type
1. `Y_c`: memref of any type values or tensor of any type values 1. `Y_c`: memref of any type values or tensor of any type values or none type
### onnx.LeakyRelu (ONNXLeakyReluOp) ### onnx.LeakyRelu (ONNXLeakyReluOp)
ONNX LeakyRelu operation ONNX LeakyRelu operation
@ -2160,24 +2160,24 @@ ONNX Loop operation
"" ""
" Operator inputs defined as (max_trip_count, condition_var)." " Operator inputs defined as (max_trip_count, condition_var)."
"" ""
" input ("", ""):" " input (\"\", \"\"):"
" for (int i=0; ; ++i) {" " for (int i=0; ; ++i) {"
" cond = ... // Note this value is ignored, but is required in the body" " cond = ... // Note this value is ignored, but is required in the body"
" }" " }"
"" ""
" input ("", cond) // Note this is analogous to a while loop" " input (\"\", cond) // Note this is analogous to a while loop"
" bool cond = ...;" " bool cond = ...;"
" for (int i=0; cond; ++i) {" " for (int i=0; cond; ++i) {"
" cond = ...;" " cond = ...;"
" }" " }"
"" ""
" input ("", 1) // Note this is analogous to a do-while loop" " input (\"\", 1) // Note this is analogous to a do-while loop"
" bool cond = true" " bool cond = true"
" for (int i=0; cond; ++i) {" " for (int i=0; cond; ++i) {"
" cond = ...;" " cond = ...;"
" }" " }"
"" ""
" input (trip_count, "") // Note this is analogous to a for loop" " input (trip_count, \"\") // Note this is analogous to a for loop"
" int trip_count = ..." " int trip_count = ..."
" for (int i=0; i < trip_count; ++i) {" " for (int i=0; i < trip_count; ++i) {"
" cond = ...; // ignored" " cond = ...; // ignored"
@ -2203,15 +2203,15 @@ ONNX Loop operation
" }" " }"
"" ""
" graph body-net (" " graph body-net ("
" %i[INT32, scalar] // iteration number" " %i[INT32, scalar]"
" %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used" " %keepgoing[BOOL, scalar]"
" %b_in[INT32, scalar] // incoming value of loop-carried-dependency b" " %b[INT32, scalar]"
" ) {" " ) {"
" %my_local = Add(%a, %b_in)" " %my_local = Add(%a, %b)"
" %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b" " %b_out = Sub(%a, %b)"
" %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition" " %keepgoing_out = Greater(%my_local, %b_out)"
" %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated" " %user_defined_vals = Add(%b, %b)"
" return %keepgoing_out, %b_out, %user_defined_val" " return %keepgoing_out, %b_out, %user_defined_vals"
" }" " }"
"" ""
"*Sample equivalent C code*" "*Sample equivalent C code*"
@ -2226,51 +2226,31 @@ ONNX Loop operation
" const int max_trip_count = 10; // Analogous to input M" " const int max_trip_count = 10; // Analogous to input M"
" int user_defined_vals[]; // Imagine this is resizable" " int user_defined_vals[]; // Imagine this is resizable"
" /* End implicitly-defined code */" " /* End implicitly-defined code */"
" /* initialize loop-carried variables and scan-output variables */" " for (int i=0; i < max_trip_count && keepgoing; ++i) {"
" bool keepgoing_out = keepgoing"
" int b_out = b"
""
" for (int i=0; i < max_trip_count && keepgoing_out; ++i) {"
" /* Implicitly-defined code: bind actual parameter values"
" to formal parameter variables of loop-body */"
" bool keepgoing_in = keepgoing_out; "
" bool b_in = b_out;"
""
" /* User-defined code (loop body) */" " /* User-defined code (loop body) */"
" int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine" " int my_local = a + b; // Reading values in the enclosing scope is fine"
" b_out = a - b_in;" " b = a - b; // writes fine if we specify b as a loop-carried dependency"
" keepgoing_out = my_local > b_out; " " keepgoing = my_local > b; // keepgoing is a loop-carried dependency"
" user_defined_val = b_in + b_in; // b_in and b_out are different variables" " user_defined_vals[i] = b + b;"
" /* End user-defined code */" " /* End user-defined code */"
""
" /* Implicitly defined-code */"
" user_defined_vals[i] = user_defined_val // accumulate scan-output values"
" }" " }"
" // int t = my_local; // Can't do this. my_local is not accessible here." " // my_local = 123; // Can't do this. my_local was defined in the the body"
"" ""
" // The values below are bound to the output variables of the loop and therefore accessible" " // These below values are live-out from the loop and therefore accessible"
" // b_out; user_defined_vals; keepgoing_out;" " b_out; user_defined_vals; keepgoing_out;"
" }" " }"
"" ""
"There are several things of note in this code snippet:" "There are several things of note in this code snippet:"
"" ""
"1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can" "1) Values from the enclosing scope (i.e. variable a here) are in scope and can"
" be referenced in the inputs of the loop." " be referenced in the inputs of the loop."
"2) Any values computed in the loop body that needs to be used in a subsequent" "2) Any variables which you wish to make available in the enclosing scope (i.e."
" iteration or after the loop are modelled using a pair of variables in the loop-body," " the variables b and keepgoing) must be declared as either loop-carried"
" consisting of an input variable (eg., b_in) and an output variable (eg., b_out)." " dependencies (both at the op inputs and output and at the body net input and"
" These are referred to as loop-carried dependences. The loop operation node" " output) or scan_outputs."
" supplies the input value of the input variable for the first iteration, and" "3) Values created in the body cannot be accessed in the enclosing scope."
" returns the output value of the output variable produced by the final"
" iteration."
"3) Scan_output variables are used to implicitly concatenate values computed across"
" all the iterations. In the above example, the value of user_defined_val computed"
" over all iterations are concatenated and returned as the value of user_defined_vals"
" after the loop."
"4) Values created in the body cannot be accessed in the enclosing scope,"
" except using the mechanism described above."
"" ""
"Note that the semantics of this op support "diagonal" or "wavefront" execution." "Note that the semantics of this op support \"diagonal\" or \"wavefront\" execution."
"(See Step 3 here for an example:" "(See Step 3 here for an example:"
"https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)." "https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)."
"Frontends should emit multi-layer RNNs as a series of While operators (with" "Frontends should emit multi-layer RNNs as a series of While operators (with"
@ -2280,8 +2260,8 @@ ONNX Loop operation
#### Operands: #### Operands:
1. `M`: memref of any type values or tensor of any type values 1. `M`: memref of any type values or tensor of any type values or none type
1. `cond`: memref of any type values or tensor of any type values 1. `cond`: memref of any type values or tensor of any type values or none type
1. `v_initial`: memref of any type values or tensor of any type values 1. `v_initial`: memref of any type values or tensor of any type values
#### Attributes: #### Attributes:
@ -2360,8 +2340,8 @@ ONNX MatMulInteger operation
1. `A`: memref of any type values or tensor of any type values 1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values
1. `a_zero_point`: memref of any type values or tensor of any type values 1. `a_zero_point`: memref of any type values or tensor of any type values or none type
1. `b_zero_point`: memref of any type values or tensor of any type values 1. `b_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -2444,7 +2424,7 @@ ONNX MaxPool operation
" ```" " ```"
" pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i]" " pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i]"
" ```" " ```"
" The output of each pooling window is maximum number of elements exclude pad. " " The output of each pooling window is maximum number of elements exclude pad."
" " " "
#### Operands: #### Operands:
@ -2466,7 +2446,7 @@ ONNX MaxPool operation
#### Results: #### Results:
1. `Y`: memref of any type values or tensor of any type values 1. `Y`: memref of any type values or tensor of any type values
1. `Indices`: memref of any type values or tensor of any type values 1. `Indices`: memref of any type values or tensor of any type values or none type
### onnx.MaxPoolSingleOut (ONNXMaxPoolSingleOutOp) ### onnx.MaxPoolSingleOut (ONNXMaxPoolSingleOutOp)
ONNX MaxPool operation with a single output. ONNX MaxPool operation with a single output.
@ -2552,7 +2532,7 @@ ONNX MaxUnpool operation
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `I`: memref of any type values or tensor of any type values 1. `I`: memref of any type values or tensor of any type values
1. `output_shape`: memref of any type values or tensor of any type values 1. `output_shape`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -2752,9 +2732,9 @@ ONNX NonMaxSuppression operation
1. `boxes`: memref of any type values or tensor of any type values 1. `boxes`: memref of any type values or tensor of any type values
1. `scores`: memref of any type values or tensor of any type values 1. `scores`: memref of any type values or tensor of any type values
1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values 1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values or none type
1. `iou_threshold`: memref of any type values or tensor of any type values 1. `iou_threshold`: memref of any type values or tensor of any type values or none type
1. `score_threshold`: memref of any type values or tensor of any type values 1. `score_threshold`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -3041,7 +3021,7 @@ ONNX Pad operation
1. `data`: memref of any type values or tensor of any type values 1. `data`: memref of any type values or tensor of any type values
1. `pads`: memref of any type values or tensor of any type values 1. `pads`: memref of any type values or tensor of any type values
1. `constant_value`: memref of any type values or tensor of any type values 1. `constant_value`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -3098,7 +3078,7 @@ ONNX QLinearConv operation
1. `w_zero_point`: memref of any type values or tensor of any type values 1. `w_zero_point`: memref of any type values or tensor of any type values
1. `y_scale`: memref of any type values or tensor of any type values 1. `y_scale`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values 1. `y_zero_point`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -3162,7 +3142,7 @@ ONNX QuantizeLinear operation
1. `x`: memref of any type values or tensor of any type values 1. `x`: memref of any type values or tensor of any type values
1. `y_scale`: memref of any type values or tensor of any type values 1. `y_scale`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values 1. `y_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -3244,9 +3224,9 @@ ONNX RNN operation
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values 1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values 1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values 1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values 1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values 1. `initial_h`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -3261,8 +3241,8 @@ ONNX RNN operation
#### Results: #### Results:
1. `Y`: memref of any type values or tensor of any type values 1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values 1. `Y_h`: memref of any type values or tensor of any type values or none type
### onnx.RandomNormalLike (ONNXRandomNormalLikeOp) ### onnx.RandomNormalLike (ONNXRandomNormalLikeOp)
ONNX RandomNormalLike operation ONNX RandomNormalLike operation
@ -3787,14 +3767,14 @@ ONNX Resize operation
"Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor." "Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor."
"Each dimension value of the output tensor is:" "Each dimension value of the output tensor is:"
" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified." " output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified."
#### Operands: #### Operands:
1. `X`: memref of any type values or tensor of any type values 1. `X`: memref of any type values or tensor of any type values
1. `roi`: memref of any type values or tensor of any type values 1. `roi`: memref of any type values or tensor of any type values
1. `scales`: memref of any type values or tensor of any type values 1. `scales`: memref of any type values or tensor of any type values
1. `sizes`: memref of any type values or tensor of any type values 1. `sizes`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -4412,7 +4392,7 @@ ONNX SequenceErase operation
#### Operands: #### Operands:
1. `input_sequence`: memref of any type values or tensor of any type values 1. `input_sequence`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values 1. `position`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -4437,7 +4417,7 @@ ONNX SequenceInsert operation
1. `input_sequence`: memref of any type values or tensor of any type values 1. `input_sequence`: memref of any type values or tensor of any type values
1. `tensor`: memref of any type values or tensor of any type values 1. `tensor`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values 1. `position`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -4654,8 +4634,8 @@ ONNX Slice operation
1. `data`: memref of any type values or tensor of any type values 1. `data`: memref of any type values or tensor of any type values
1. `starts`: memref of any type values or tensor of any type values 1. `starts`: memref of any type values or tensor of any type values
1. `ends`: memref of any type values or tensor of any type values 1. `ends`: memref of any type values or tensor of any type values
1. `axes`: memref of any type values or tensor of any type values 1. `axes`: memref of any type values or tensor of any type values or none type
1. `steps`: memref of any type values or tensor of any type values 1. `steps`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -4808,7 +4788,7 @@ ONNX SplitToSequence operation
#### Operands: #### Operands:
1. `input`: memref of any type values or tensor of any type values 1. `input`: memref of any type values or tensor of any type values
1. `split`: memref of any type values or tensor of any type values 1. `split`: memref of any type values or tensor of any type values or none type
#### Attributes: #### Attributes:
@ -4876,9 +4856,9 @@ ONNX StringNormalizer operation
"StringNormalization performs string operations for basic cleaning." "StringNormalization performs string operations for basic cleaning."
"This operator has only one input (denoted by X) and only one output" "This operator has only one input (denoted by X) and only one output"
"(denoted by Y). This operator first examines the elements in the X," "(denoted by Y). This operator first examines the elements in the X,"
"and removes elements specified in "stopwords" attribute." "and removes elements specified in \"stopwords\" attribute."
"After removing stop words, the intermediate result can be further lowercased," "After removing stop words, the intermediate result can be further lowercased,"
"uppercased, or just returned depending the "case_change_action" attribute." "uppercased, or just returned depending the \"case_change_action\" attribute."
"This operator only accepts [C]- and [1, C]-tensor." "This operator only accepts [C]- and [1, C]-tensor."
"If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]" "If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]"
"if input shape is [C] and shape [1, 1] if input shape is [1, C]." "if input shape is [C] and shape [1, 1] if input shape is [1, C]."
@ -5008,8 +4988,8 @@ ONNX TfIdfVectorizer operation
"respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output." "respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output."
"Note that we may consider all skips up to S when generating the n-grams." "Note that we may consider all skips up to S when generating the n-grams."
"" ""
"The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and" "The examples used above are true if mode is \"TF\". If mode is \"IDF\", all the counts larger than 1 would be truncated to 1 and"
"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF"," "the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is \"TFIDF\","
"this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute." "this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute."
"" ""
"Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor." "Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor."
@ -5097,9 +5077,9 @@ ONNX TopK operation
" contains the indices of the top k elements (original indices from the input" " contains the indices of the top k elements (original indices from the input"
" tensor)." " tensor)."
"" ""
"If "largest" is 1 (the default value) then the k largest elements are returned." "If \"largest\" is 1 (the default value) then the k largest elements are returned."
"If "sorted" is 1 (the default value) then the resulting k elements will be sorted." "If \"sorted\" is 1 (the default value) then the resulting k elements will be sorted."
"If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined." "If \"sorted\" is 0, order of returned 'Values' and 'Indices' are undefined."
"" ""
"Given two equivalent values, this operator uses the indices along the axis as" "Given two equivalent values, this operator uses the indices along the axis as"
" a tiebreaker. That is, the element with the lower index will appear first." " a tiebreaker. That is, the element with the lower index will appear first."
@ -5158,7 +5138,7 @@ ONNX Unique operation
"This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. " "This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. "
"The first output tensor 'Y' contains all unique values or subtensors of the input. " "The first output tensor 'Y' contains all unique values or subtensors of the input. "
"The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. " "The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. "
"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". " "The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. \". "
"The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. " "The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. "
"" ""
"Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. " "Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. "
@ -5242,9 +5222,9 @@ ONNX Unique operation
#### Results: #### Results:
1. `Y`: memref of any type values or tensor of any type values 1. `Y`: memref of any type values or tensor of any type values
1. `indices`: memref of any type values or tensor of any type values 1. `indices`: memref of any type values or tensor of any type values or none type
1. `inverse_indices`: memref of any type values or tensor of any type values 1. `inverse_indices`: memref of any type values or tensor of any type values or none type
1. `counts`: memref of any type values or tensor of any type values 1. `counts`: memref of any type values or tensor of any type values or none type
### onnx.Unsqueeze (ONNXUnsqueezeOp) ### onnx.Unsqueeze (ONNXUnsqueezeOp)
ONNX Unsqueeze operation ONNX Unsqueeze operation

File diff suppressed because it is too large Load Diff

View File

@ -121,6 +121,7 @@ private:
mlir::MLIRContext &context_; mlir::MLIRContext &context_;
mlir::ModuleOp module_; mlir::ModuleOp module_;
mlir::OpBuilder builder_; mlir::OpBuilder builder_;
mlir::Value none_;
// mapping between string name and symbol // mapping between string name and symbol
OnnxOnnfSymbolMapping frontend_symbols_; OnnxOnnfSymbolMapping frontend_symbols_;
@ -287,8 +288,8 @@ private:
} }
} }
std::vector<mlir::NamedAttribute> ImportNodeAttributes( std::vector<mlir::NamedAttribute>
const onnx::NodeProto &node) { ImportNodeAttributes(const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute> attributes; std::vector<mlir::NamedAttribute> attributes;
for (int i = 0; i < node.attribute_size(); ++i) { for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i); auto attr = node.attribute(i);
@ -317,21 +318,11 @@ private:
} }
} }
// if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be
// combined with 'if constexpr' the issue is the type of the output is
// different. alternative way to use variadic output for all the op
/*!
* Important onnx node which generates only one output
* @param node onnx node
* @param nIn number of expected inputs
* @param nOut number of expected outputs
* @param attrs list of desription for attributes with format {name, type,
* default}
*/
template <typename T> template <typename T>
void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut, void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
bool variadicIn = false, bool variadicOut = false) { int expectedNumResults = -1) {
bool variadicIn = expectedNumOperands == -1;
bool variadicOut = expectedNumResults == -1;
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) { for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
@ -339,6 +330,10 @@ private:
} }
} }
if (!variadicIn)
for (auto i = inputs.size(); i < expectedNumOperands; i++)
inputs.emplace_back(none_);
std::vector<mlir::Type> outputTypes; std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) { for (auto item : node.output()) {
outputTypes.push_back( outputTypes.push_back(
@ -347,49 +342,11 @@ private:
auto attributes = ImportNodeAttributes(node); auto attributes = ImportNodeAttributes(node);
llvm::StringRef OpName = node.op_type(); // TODO: Handle optional inputs.
if ((variadicIn || nIn == inputs.size()) && auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
(variadicOut || nOut == outputTypes.size())) { for (int i = 0; i < node.output().size(); i++) {
auto op = frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes); *(op.getODSResults(i).begin()));
frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
op.getResult());
} else {
ImportNodeGeneric(node);
}
}
template <typename T>
void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut,
bool variadicIn = false,
bool variadicOut = false) {
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
auto attributes = ImportNodeAttributes(node);
llvm::StringRef OpName = node.op_type();
if ((variadicIn || nIn == inputs.size()) &&
(variadicOut || nOut == outputTypes.size())) {
auto op =
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
for (int i = 0; i < node.output().size(); i++) {
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
op.getResult(i));
}
} else {
ImportNodeGeneric(node);
} }
} }
@ -398,8 +355,7 @@ private:
* c++ does not allow template specialization inside a class scope * c++ does not allow template specialization inside a class scope
* a specialized function is used * a specialized function is used
*/ */
void void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
// Conv has attribute dilations, kernel_shape, pads, the default value of // Conv has attribute dilations, kernel_shape, pads, the default value of
// which is determined by the shape of first argument. However, since the // which is determined by the shape of first argument. However, since the
// shape is unknown now, these attributes can be not generated auto // shape is unknown now, these attributes can be not generated auto
@ -413,24 +369,20 @@ private:
int nOps = node.input().size(); int nOps = node.input().size();
if (nOps == 2) if (nOps == 2)
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>( buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
node, nOps, nOut);
else else
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut); buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
} }
/*! /*!
* Special handle for MaxPool operations. * Special handle for MaxPool operations.
*/ */
void ImportNodeMaxPool( void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) {
onnx::NodeProto node, int nIn, int nOut) {
int nOuts = node.output().size(); int nOuts = node.output().size();
if (nOuts == 1) { if (nOuts == 1) {
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>( buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts);
node, nIn, nOuts);
} else { } else {
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>( buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts);
node, nIn, nOuts);
} }
} }
@ -441,23 +393,10 @@ private:
int nOuts = node.output().size(); int nOuts = node.output().size();
if (nOuts == 1) { if (nOuts == 1) {
// Test mode with one output. // Test mode with one output.
ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts);
nOuts);
} else { } else {
// Training mode with four trailing optional outputs. Not handled yet. // Training mode with four trailing optional outputs. Not handled yet.
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts); buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
}
}
/*!
* Special handle for Gemm operations.
*/
void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size();
if (nOps == 2) {
ImportNodeOneOut<mlir::ONNXGemmNoBiasOp>(node, 2, nOut);
} else {
ImportNodeOneOut<mlir::ONNXGemmOp>(node, nIn, nOut);
} }
} }
@ -467,28 +406,14 @@ private:
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) { void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size(); int nOps = node.input().size();
if (nOps == 2) { if (nOps == 2) {
ImportNodeOneOut<mlir::ONNXPadConstantValueOp>(node, 2, nOut); buildOperation<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
} else { } else {
ImportNodeOneOut<mlir::ONNXPadOp>(node, nIn, nOut); buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
} }
} }
void ImportNode(const onnx::NodeProto &node) { void ImportNode(const onnx::NodeProto &node) {
std::vector<mlir::Value> inputs; llvm::StringRef opName = node.op_type();
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
std::vector<mlir::NamedAttribute> attributes;
llvm::StringRef OpName = node.op_type();
// the following code is generated by gen_doc.py // the following code is generated by gen_doc.py
// refer to dialect/onnx/onnx.td for details // refer to dialect/onnx/onnx.td for details
@ -555,9 +480,11 @@ private:
ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it)); ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
} }
// import nodes in the graph // Create a NoneTyped constant.
auto node = graph.node(); none_ =
for (const auto &item : node) { builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
// Import nodes in the graph.
for (const auto &item : graph.node()) {
ImportNode(item); ImportNode(item);
} }

View File

@ -1,320 +1,319 @@
//******************************************************** //********************************************************
// Warning: Do not modify this file directly // This file is generated on UTC-02/24/2020, 06:29:01.
// This file is automatically generated via script // Do not modify this file directly.
// Details can be found in doc/readonnxdefs.md // This file is automatically generated via script.
// Details can be found in doc/readonnxdefs.md .
//******************************************************** //********************************************************
if (OpName == "DUMMY") { if (opName == "Abs")
}else if (OpName == "Abs") { return buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1); if (opName == "Acos")
}else if (OpName == "Acos") { return buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1); if (opName == "Acosh")
}else if (OpName == "Acosh") { return buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1); if (opName == "Add")
}else if (OpName == "Add") { return buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1); if (opName == "And")
}else if (OpName == "And") { return buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1); if (opName == "ArgMax")
}else if (OpName == "ArgMax") { return buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1); if (opName == "ArgMin")
}else if (OpName == "ArgMin") { return buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1); if (opName == "Asin")
}else if (OpName == "Asin") { return buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1); if (opName == "Asinh")
}else if (OpName == "Asinh") { return buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1); if (opName == "Atan")
}else if (OpName == "Atan") { return buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1); if (opName == "Atanh")
}else if (OpName == "Atanh") { return buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1); if (opName == "AveragePool")
}else if (OpName == "AveragePool") { return buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1); if (opName == "BatchNormalization")
}else if (OpName == "BatchNormalization") { return ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
ImportNodeBatchNormalization(node, 5, 5); if (opName == "BitShift")
}else if (OpName == "BitShift") { return buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1); if (opName == "Cast")
}else if (OpName == "Cast") { return buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1); if (opName == "Ceil")
}else if (OpName == "Ceil") { return buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1); if (opName == "Clip")
}else if (OpName == "Clip") { return buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1); if (opName == "Compress")
}else if (OpName == "Compress") { return buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1); if (opName == "Concat")
}else if (OpName == "Concat") { return buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, true, false); if (opName == "ConcatFromSequence")
}else if (OpName == "ConcatFromSequence") { return buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1); if (opName == "Constant")
}else if (OpName == "Constant") { return buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1); if (opName == "ConstantOfShape")
}else if (OpName == "ConstantOfShape") { return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1); if (opName == "Conv")
}else if (OpName == "Conv") { return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeConv(node, 3, 1); if (opName == "ConvInteger")
}else if (OpName == "ConvInteger") { return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1); if (opName == "ConvTranspose")
}else if (OpName == "ConvTranspose") { return buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1); if (opName == "Cos")
}else if (OpName == "Cos") { return buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1); if (opName == "Cosh")
}else if (OpName == "Cosh") { return buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1); if (opName == "CumSum")
}else if (OpName == "CumSum") { return buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1); if (opName == "DepthToSpace")
}else if (OpName == "DepthToSpace") { return buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1); if (opName == "DequantizeLinear")
}else if (OpName == "DequantizeLinear") { return buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1); if (opName == "Det")
}else if (OpName == "Det") { return buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1); if (opName == "Div")
}else if (OpName == "Div") { return buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1); if (opName == "Dropout")
}else if (OpName == "Dropout") { return buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2); if (opName == "DynamicQuantizeLinear")
}else if (OpName == "DynamicQuantizeLinear") { return buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3); if (opName == "Elu")
}else if (OpName == "Elu") { return buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1); if (opName == "Equal")
}else if (OpName == "Equal") { return buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1); if (opName == "Erf")
}else if (OpName == "Erf") { return buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1); if (opName == "Exp")
}else if (OpName == "Exp") { return buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1); if (opName == "Expand")
}else if (OpName == "Expand") { return buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1); if (opName == "EyeLike")
}else if (OpName == "EyeLike") { return buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1); if (opName == "Flatten")
}else if (OpName == "Flatten") { return buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1); if (opName == "Floor")
}else if (OpName == "Floor") { return buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1); if (opName == "GRU")
}else if (OpName == "GRU") { return buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2); if (opName == "Gather")
}else if (OpName == "Gather") { return buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1); if (opName == "GatherElements")
}else if (OpName == "GatherElements") { return buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1); if (opName == "GatherND")
}else if (OpName == "GatherND") { return buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1); if (opName == "Gemm")
}else if (OpName == "Gemm") { return buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeGemm(node, 3, 1); if (opName == "GlobalAveragePool")
}else if (OpName == "GlobalAveragePool") { return buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1); if (opName == "GlobalLpPool")
}else if (OpName == "GlobalLpPool") { return buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1); if (opName == "GlobalMaxPool")
}else if (OpName == "GlobalMaxPool") { return buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1); if (opName == "Greater")
}else if (OpName == "Greater") { return buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1); if (opName == "HardSigmoid")
}else if (OpName == "HardSigmoid") { return buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1); if (opName == "Hardmax")
}else if (OpName == "Hardmax") { return buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1); if (opName == "Identity")
}else if (OpName == "Identity") { return buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1); if (opName == "If")
}else if (OpName == "If") { return buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1); if (opName == "InstanceNormalization")
}else if (OpName == "InstanceNormalization") { return buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1); if (opName == "IsInf")
}else if (OpName == "IsInf") { return buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1); if (opName == "IsNaN")
}else if (OpName == "IsNaN") { return buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1); if (opName == "LRN")
}else if (OpName == "LRN") { return buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1); if (opName == "LSTM")
}else if (OpName == "LSTM") { return buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3); if (opName == "LeakyRelu")
}else if (OpName == "LeakyRelu") { return buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1); if (opName == "Less")
}else if (OpName == "Less") { return buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1); if (opName == "Log")
}else if (OpName == "Log") { return buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1); if (opName == "LogSoftmax")
}else if (OpName == "LogSoftmax") { return buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1); if (opName == "Loop")
}else if (OpName == "Loop") { return buildOperation<mlir::ONNXLoopOp>(node);
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1); if (opName == "LpNormalization")
}else if (OpName == "LpNormalization") { return buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1); if (opName == "LpPool")
}else if (OpName == "LpPool") { return buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1); if (opName == "MatMul")
}else if (OpName == "MatMul") { return buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1); if (opName == "MatMulInteger")
}else if (OpName == "MatMulInteger") { return buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1); if (opName == "Max")
}else if (OpName == "Max") { return buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, true, false); if (opName == "MaxPool")
}else if (OpName == "MaxPool") { return ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
ImportNodeMaxPool(node, 1, 2); if (opName == "MaxRoiPool")
}else if (OpName == "MaxRoiPool") { return buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1); if (opName == "MaxUnpool")
}else if (OpName == "MaxUnpool") { return buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1); if (opName == "Mean")
}else if (OpName == "Mean") { return buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, true, false); if (opName == "MeanVarianceNormalization")
}else if (OpName == "MeanVarianceNormalization") { return buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1); if (opName == "Min")
}else if (OpName == "Min") { return buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, true, false); if (opName == "Mod")
}else if (OpName == "Mod") { return buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1); if (opName == "Mul")
}else if (OpName == "Mul") { return buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1); if (opName == "Multinomial")
}else if (OpName == "Multinomial") { return buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1); if (opName == "Neg")
}else if (OpName == "Neg") { return buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1); if (opName == "NonMaxSuppression")
}else if (OpName == "NonMaxSuppression") { return buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1); if (opName == "NonZero")
}else if (OpName == "NonZero") { return buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1); if (opName == "Not")
}else if (OpName == "Not") { return buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1); if (opName == "OneHot")
}else if (OpName == "OneHot") { return buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1); if (opName == "Or")
}else if (OpName == "Or") { return buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1); if (opName == "PRelu")
}else if (OpName == "PRelu") { return buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1); if (opName == "Pad")
}else if (OpName == "Pad") { return ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodePad(node, 3, 1); if (opName == "Pow")
}else if (OpName == "Pow") { return buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1); if (opName == "QLinearConv")
}else if (OpName == "QLinearConv") { return buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1); if (opName == "QLinearMatMul")
}else if (OpName == "QLinearMatMul") { return buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1); if (opName == "QuantizeLinear")
}else if (OpName == "QuantizeLinear") { return buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1); if (opName == "RNN")
}else if (OpName == "RNN") { return buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2); if (opName == "RandomNormal")
}else if (OpName == "RandomNormal") { return buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1); if (opName == "RandomNormalLike")
}else if (OpName == "RandomNormalLike") { return buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1); if (opName == "RandomUniform")
}else if (OpName == "RandomUniform") { return buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1); if (opName == "RandomUniformLike")
}else if (OpName == "RandomUniformLike") { return buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1); if (opName == "Range")
}else if (OpName == "Range") { return buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1); if (opName == "Reciprocal")
}else if (OpName == "Reciprocal") { return buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1); if (opName == "ReduceL1")
}else if (OpName == "ReduceL1") { return buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1); if (opName == "ReduceL2")
}else if (OpName == "ReduceL2") { return buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1); if (opName == "ReduceLogSum")
}else if (OpName == "ReduceLogSum") { return buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1); if (opName == "ReduceLogSumExp")
}else if (OpName == "ReduceLogSumExp") { return buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1); if (opName == "ReduceMax")
}else if (OpName == "ReduceMax") { return buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1); if (opName == "ReduceMean")
}else if (OpName == "ReduceMean") { return buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1); if (opName == "ReduceMin")
}else if (OpName == "ReduceMin") { return buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1); if (opName == "ReduceProd")
}else if (OpName == "ReduceProd") { return buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1); if (opName == "ReduceSum")
}else if (OpName == "ReduceSum") { return buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1); if (opName == "ReduceSumSquare")
}else if (OpName == "ReduceSumSquare") { return buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1); if (opName == "Relu")
}else if (OpName == "Relu") { return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1); if (opName == "Reshape")
}else if (OpName == "Reshape") { return buildOperation<mlir::ONNXReshapeOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1); if (opName == "Resize")
}else if (OpName == "Resize") { return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1); if (opName == "ReverseSequence")
}else if (OpName == "ReverseSequence") { return buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1); if (opName == "RoiAlign")
}else if (OpName == "RoiAlign") { return buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1); if (opName == "Round")
}else if (OpName == "Round") { return buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1); if (opName == "Scan")
}else if (OpName == "Scan") { return buildOperation<mlir::ONNXScanOp>(node);
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1); if (opName == "Scatter")
}else if (OpName == "Scatter") { return buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1); if (opName == "ScatterElements")
}else if (OpName == "ScatterElements") { return buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1); if (opName == "ScatterND")
}else if (OpName == "ScatterND") { return buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1); if (opName == "Selu")
}else if (OpName == "Selu") { return buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1); if (opName == "SequenceAt")
}else if (OpName == "SequenceAt") { return buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1); if (opName == "SequenceConstruct")
}else if (OpName == "SequenceConstruct") { return buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, true, false); if (opName == "SequenceEmpty")
}else if (OpName == "SequenceEmpty") { return buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1); if (opName == "SequenceErase")
}else if (OpName == "SequenceErase") { return buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1); if (opName == "SequenceInsert")
}else if (OpName == "SequenceInsert") { return buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1); if (opName == "SequenceLength")
}else if (OpName == "SequenceLength") { return buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1); if (opName == "Shape")
}else if (OpName == "Shape") { return buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1); if (opName == "Shrink")
}else if (OpName == "Shrink") { return buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1); if (opName == "Sigmoid")
}else if (OpName == "Sigmoid") { return buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1); if (opName == "Sign")
}else if (OpName == "Sign") { return buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1); if (opName == "Sin")
}else if (OpName == "Sin") { return buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1); if (opName == "Sinh")
}else if (OpName == "Sinh") { return buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1); if (opName == "Size")
}else if (OpName == "Size") { return buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1); if (opName == "Slice")
}else if (OpName == "Slice") { return buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1); if (opName == "Softmax")
}else if (OpName == "Softmax") { return buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1); if (opName == "Softplus")
}else if (OpName == "Softplus") { return buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1); if (opName == "Softsign")
}else if (OpName == "Softsign") { return buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1); if (opName == "SpaceToDepth")
}else if (OpName == "SpaceToDepth") { return buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1); if (opName == "Split")
}else if (OpName == "Split") { return buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1); if (opName == "SplitToSequence")
}else if (OpName == "SplitToSequence") { return buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1); if (opName == "Sqrt")
}else if (OpName == "Sqrt") { return buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1); if (opName == "Squeeze")
}else if (OpName == "Squeeze") { return buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1); if (opName == "StringNormalizer")
}else if (OpName == "StringNormalizer") { return buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1); if (opName == "Sub")
}else if (OpName == "Sub") { return buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1); if (opName == "Sum")
}else if (OpName == "Sum") { return buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, true, false); if (opName == "Tan")
}else if (OpName == "Tan") { return buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1); if (opName == "Tanh")
}else if (OpName == "Tanh") { return buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1); if (opName == "TfIdfVectorizer")
}else if (OpName == "TfIdfVectorizer") { return buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1); if (opName == "ThresholdedRelu")
}else if (OpName == "ThresholdedRelu") { return buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1); if (opName == "Tile")
}else if (OpName == "Tile") { return buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1); if (opName == "TopK")
}else if (OpName == "TopK") { return buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2); if (opName == "Transpose")
}else if (OpName == "Transpose") { return buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1); if (opName == "Unique")
}else if (OpName == "Unique") { return buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4); if (opName == "Unsqueeze")
}else if (OpName == "Unsqueeze") { return buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1); if (opName == "Upsample")
}else if (OpName == "Upsample") { return buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1); if (opName == "Where")
}else if (OpName == "Where") { return buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1); if (opName == "Xor")
}else if (OpName == "Xor") { return buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1);
}

View File

@ -17,20 +17,24 @@ struct ONNXGemmOpLowering : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
auto has_bias = (operands.size() == 3); // The first predicate is unnecessary when we remove ONXGemmNoBiasOp.
bool hasBias = (operands.size() == 3) &&
(!op->getOperand(2).getType().isa<NoneType>());
Value A, B, C; Value A, B, C;
A = operands[0]; A = operands[0];
B = operands[1]; B = operands[1];
if (has_bias) if (hasBias)
C = operands[2]; C = operands[2];
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
auto alphaAttr = FloatAttr::get(memRefType.getElementType(), auto alphaAttr =
llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat()); FloatAttr::get(memRefType.getElementType(),
auto betaAttr = FloatAttr::get(memRefType.getElementType(), llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat()); auto betaAttr =
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr); auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -68,8 +72,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// Define loops. // Define loops.
std::vector<Value> originalLoops; std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops; std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, Block *optimizationBlock =
optimizedLoops, numLoops); defineLoops(rewriter, loc, originalLoops, optimizedLoops, numLoops);
// We have two Krnl loops: // We have two Krnl loops:
// - Outer loop iterates over the output matrix dimensions, and // - Outer loop iterates over the output matrix dimensions, and
@ -83,8 +87,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
outerLoops.push_back(originalLoops[i]); outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]); optimizedOuterLoops.push_back(optimizedLoops[i]);
} }
KrnlIterateOperandPack outerPack(rewriter, outerLoops, KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
optimizedOuterLoops);
// Induction variables for the outer loops // Induction variables for the outer loops
for (int i = 0; i < 2; ++i) for (int i = 0; i < 2; ++i)
addDimensionToPack(rewriter, loc, outerPack, alloc, i); addDimensionToPack(rewriter, loc, outerPack, alloc, i);
@ -106,20 +109,19 @@ struct ONNXGemmOpLowering : public ConversionPattern {
int64_t K_B_Idx = (isTransB) ? 1 : 0; int64_t K_B_Idx = (isTransB) ? 1 : 0;
reductionPack.pushConstantBound(0); reductionPack.pushConstantBound(0);
if (ATy.getShape()[K_A_Idx] != -1) if (ATy.getShape()[K_A_Idx] != -1)
reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]); reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
else if (BTy.getShape()[K_B_Idx] != -1)
reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
else else
if (BTy.getShape()[K_B_Idx] != -1) reductionPack.pushOperandBound(
reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]); rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
else
reductionPack.pushOperandBound(
rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
// Get run-time dimension information for unknown dimensions used for // Get run-time dimension information for unknown dimensions used for
// broadcasting. // broadcasting.
// GemmOp supports unidirectional broadcasting from C to A*B. // GemmOp supports unidirectional broadcasting from C to A*B.
// Hence, it must be enough to get broadcasting information for C only. // Hence, it must be enough to get broadcasting information for C only.
std::map<int, Value> broadcastedDimInfo; std::map<int, Value> broadcastedDimInfo;
if (has_bias) { if (hasBias) {
auto shape = C.getType().cast<MemRefType>().getShape(); auto shape = C.getType().cast<MemRefType>().getShape();
for (int i = 0; i < shape.size(); ++i) { for (int i = 0; i < shape.size(); ++i) {
if (shape[i] < 0) { if (shape[i] < 0) {
@ -162,7 +164,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting) // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs); auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB); auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
if (has_bias) { if (hasBias) {
auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C, auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
broadcastedDimInfo); broadcastedDimInfo);
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs); auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
@ -210,8 +212,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
} }
}; };
void populateLoweringONNXGemmOpPattern( void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
OwningRewritePatternList &patterns, MLIRContext *ctx) { MLIRContext *ctx) {
patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx); patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx); patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
} }

View File

@ -120,25 +120,19 @@ void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
// Tanh // Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the /// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXTanhOp::inferShapes() { void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sinh // Sinh
/// Infer the output shape of the ONNXSinhOp. This method is required by the /// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXSinhOp::inferShapes() { void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Cosh // Cosh
/// Infer the output shape of the ONNXCoshOp. This method is required by the /// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXCoshOp::inferShapes() { void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Cos // Cos
@ -178,9 +172,7 @@ void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
// Relu // Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the /// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXReluOp::inferShapes() { void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LeakyRelu // LeakyRelu
@ -194,9 +186,7 @@ void ONNXLeakyReluOp::inferShapes() {
// Selu // Selu
/// Infer the output shape of the ONNXSeluOp. This method is required by /// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSeluOp::inferShapes() { void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Reciprocal // Reciprocal
@ -234,17 +224,13 @@ void ONNXSoftsignOp::inferShapes() {
// Sqrt // Sqrt
/// Infer the output shape of the ONNXSqrtOp. This method is required by /// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSqrtOp::inferShapes() { void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sign // Sign
/// Infer the output shape of the ONNXSignOp. This method is required by /// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSignOp::inferShapes() { void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Add // Add
@ -423,8 +409,7 @@ void ONNXMatMulOp::inferShapes() {
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
// need to be removed after the multiplication but cannot be removed if all // need to be removed after the multiplication but cannot be removed if all
// sizes are 1. // sizes are 1.
if (lhsShape[0] != -1 && rhsShape[0] != -1 && if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
lhsShape[0] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices.");
dims.emplace_back(1); dims.emplace_back(1);
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) { } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
@ -541,14 +526,14 @@ void ONNXMatMulOp::inferShapes() {
// Gemm // Gemm
void ONNXGemmOp::inferShapes() { void ONNXGemmOp::inferShapes() {
bool hasBias = !getOperand(2).getType().isa<NoneType>();
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() || !getOperand(1).getType().isa<RankedTensorType>() ||
!getOperand(2).getType().isa<RankedTensorType>()) (hasBias && !getOperand(2).getType().isa<RankedTensorType>()))
return; return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B; int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1]; M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
@ -560,15 +545,18 @@ void ONNXGemmOp::inferShapes() {
emitError("Tensor shapes mismatched."); emitError("Tensor shapes mismatched.");
} }
// Check whether bias is unidirectional broadcasting or not. if (hasBias) {
auto shape = biasTy.getShape(); // Check whether bias is unidirectional broadcasting or not.
int rank = shape.size(); auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
if ((rank > 2) || auto shape = biasTy.getShape();
(rank >= 1 && shape[rank - 1] != -1 && N != -1 && N != shape[rank - 1] && int rank = shape.size();
shape[rank - 1] != 1) || if ((rank > 2) ||
(rank == 2 && shape[rank - 2] != -1 && M != -1 && M != shape[rank - 2] && (rank >= 1 && shape[rank - 1] != -1 && N != -1 &&
shape[rank - 2] != 1)) { N != shape[rank - 1] && shape[rank - 1] != 1) ||
emitError("Bias shape mismatched."); (rank == 2 && shape[rank - 2] != -1 && M != -1 &&
M != shape[rank - 2] && shape[rank - 2] != 1)) {
emitError("Bias shape mismatched.");
}
} }
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
@ -713,7 +701,6 @@ void ONNXTransposeOp::inferShapes() {
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReduceMax // ReduceMax
@ -801,7 +788,8 @@ void ONNXConvNoBiasOp::inferShapes() {
// Required attribute auto_pad defaults to NOTSET. // Required attribute auto_pad defaults to NOTSET.
auto autoPad = auto_pad(); auto autoPad = auto_pad();
// Group is a required attribute and should have default value of 1. // Group is a required attribute and should have default value of 1.
int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue(); int64_t group =
ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds. // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (dataShape[1] != (weightShape[1] * group)) if (dataShape[1] != (weightShape[1] * group))
emitError("Channel dimension mismatch."); emitError("Channel dimension mismatch.");
@ -859,8 +847,10 @@ void ONNXConvNoBiasOp::inferShapes() {
if (dilations.getValue().size() != nDims) if (dilations.getValue().size() != nDims)
emitError("dilations length incompatible with spatial dimensions."); emitError("dilations length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i) for (int i = 0; i < nDims; ++i)
kernelDims[i] = (kernelDims[i] + 1) * kernelDims[i] =
(dilations.getValue()[i]).cast<IntegerAttr>().getInt() - 1; (kernelDims[i] + 1) *
(dilations.getValue()[i]).cast<IntegerAttr>().getInt() -
1;
} }
// Subtract kernel dimensions from input data dimensions. // Subtract kernel dimensions from input data dimensions.
@ -906,8 +896,7 @@ void ONNXConvNoBiasOp::inferShapes() {
if (strides.getValue().size() != nDims) if (strides.getValue().size() != nDims)
emitError("strides length incompatible with spatial dimensions."); emitError("strides length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i) { for (int i = 0; i < nDims; ++i) {
int64_t stride = int64_t stride = strides.getValue()[i].cast<IntegerAttr>().getInt();
strides.getValue()[i].cast<IntegerAttr>().getInt();
outSpatialDims[i] = floor(outSpatialDims[i] / stride); outSpatialDims[i] = floor(outSpatialDims[i] / stride);
} }
} }
@ -934,12 +923,13 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
auto xRank = xShape.size(); auto xRank = xShape.size();
// 2) analyse parameters // 2) analyse parameters
// get kernel sizes from kernel_shape attribute // get kernel sizes from kernel_shape attribute
auto kernelShape = kernel_shape(); auto kernelShape = kernel_shape();
if (!kernelShape) if (!kernelShape)
emitError("kernel_shape is a mandatory attribute for which there is no default."); emitError(
"kernel_shape is a mandatory attribute for which there is no default.");
auto kernelShapeArray = kernelShape.getValue(); auto kernelShapeArray = kernelShape.getValue();
auto kernelRank = kernelShape.size(); auto kernelRank = kernelShape.size();
if (kernelRank > xRank) if (kernelRank > xRank)
emitError("kernel_shape spatial dimension is too large."); emitError("kernel_shape spatial dimension is too large.");
auto kernelOffset = xRank - kernelRank; auto kernelOffset = xRank - kernelRank;
@ -951,41 +941,42 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
SmallVector<int64_t, 4> actualDilations; SmallVector<int64_t, 4> actualDilations;
auto dilationsOpt = dilations(); auto dilationsOpt = dilations();
if (dilationsOpt.hasValue()) { if (dilationsOpt.hasValue()) {
auto dilationsArray = dilationsOpt.getValue().getValue(); // opt -> attr -> array auto dilationsArray =
dilationsOpt.getValue().getValue(); // opt -> attr -> array
if (dilationsArray.size() != kernelRank) if (dilationsArray.size() != kernelRank)
emitError("dialation rank is not the same as the spatial rank."); emitError("dialation rank is not the same as the spatial rank.");
// fill in the actual values // fill in the actual values
for (int i = 0; i < kernelRank; ++i) { for (int i = 0; i < kernelRank; ++i) {
int64_t d = (dilationsArray[i]).cast<IntegerAttr>().getInt(); int64_t d = (dilationsArray[i]).cast<IntegerAttr>().getInt();
if (d < 1) if (d < 1)
emitError("dialation value must be nonzero positive."); emitError("dialation value must be nonzero positive.");
actualDilations.emplace_back(d); actualDilations.emplace_back(d);
} }
} else { } else {
for(int i=0; i < kernelRank; ++i) { for (int i = 0; i < kernelRank; ++i) {
actualDilations.emplace_back(1); actualDilations.emplace_back(1);
} }
} }
// storage order // storage order
// strides // strides
SmallVector<int64_t, 4> actualStrides; SmallVector<int64_t, 4> actualStrides;
auto stridesOpt = strides(); auto stridesOpt = strides();
if (stridesOpt.hasValue()) { if (stridesOpt.hasValue()) {
auto stridesArray = stridesOpt.getValue().getValue(); auto stridesArray = stridesOpt.getValue().getValue();
if (stridesArray.size() != kernelRank) if (stridesArray.size() != kernelRank)
emitError("strides rank is not the same as the spatial rank."); emitError("strides rank is not the same as the spatial rank.");
// fill in the actual values // fill in the actual values
for (int i = 0; i < kernelRank; ++i) { for (int i = 0; i < kernelRank; ++i) {
int64_t s = (stridesArray[i]).cast<IntegerAttr>().getInt(); int64_t s = (stridesArray[i]).cast<IntegerAttr>().getInt();
if (s < 1) if (s < 1)
emitError("strides value must be nonzero positive."); emitError("strides value must be nonzero positive.");
actualStrides.emplace_back(s); actualStrides.emplace_back(s);
} }
} else { } else {
for(int i=0; i < kernelRank; ++i) { for (int i = 0; i < kernelRank; ++i) {
actualStrides.emplace_back(1); actualStrides.emplace_back(1);
} }
} }
@ -1002,9 +993,9 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
if (padsArray.size() != 2 * kernelRank) if (padsArray.size() != 2 * kernelRank)
emitError("pads rank is not twice the spatial rank."); emitError("pads rank is not twice the spatial rank.");
// fill in the actual values // fill in the actual values
for (int i = 0; i < 2*kernelRank; ++i) { for (int i = 0; i < 2 * kernelRank; ++i) {
int64_t p = (padsArray[i]).cast<IntegerAttr>().getInt(); int64_t p = (padsArray[i]).cast<IntegerAttr>().getInt();
if (p < 0) if (p < 0)
emitError("pads value must be nonnegative."); emitError("pads value must be nonnegative.");
actualPads.emplace_back(p); actualPads.emplace_back(p);
} }
@ -1016,24 +1007,26 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
defaultPads = true; defaultPads = true;
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
// init pad with zero // init pad with zero
for(int i=0; i<2*kernelRank; ++i) { for (int i = 0; i < 2 * kernelRank; ++i) {
actualPads.emplace_back(0); actualPads.emplace_back(0);
} }
for(int i=0; i<kernelRank; ++i) { for (int i = 0; i < kernelRank; ++i) {
auto inputSpatialShape = xShape[kernelOffset + i]; auto inputSpatialShape = xShape[kernelOffset + i];
auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt(); auto kernelSpatialShape =
(kernelShapeArray[i]).cast<IntegerAttr>().getInt();
auto dilations = actualDilations[i]; auto dilations = actualDilations[i];
auto strideSpatialShape = actualStrides[i]; auto strideSpatialShape = actualStrides[i];
int64_t outputSpatialShape = ceil((1.0 * inputSpatialShape) / int64_t outputSpatialShape =
(1.0 * strideSpatialShape)); ceil((1.0 * inputSpatialShape) / (1.0 * strideSpatialShape));
auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape + auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape +
((kernelSpatialShape - 1) * dilations + 1) - inputSpatialShape; ((kernelSpatialShape - 1) * dilations + 1) -
inputSpatialShape;
actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2; actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2;
if (sumOfPad % 2 != 0) { if (sumOfPad % 2 != 0) {
if (autoPad == "SAME_UPPER") { if (autoPad == "SAME_UPPER") {
actualPads[kernelRank + i] += 1; actualPads[kernelRank + i] += 1;
} else { } else {
actualPads[i] += 1; actualPads[i] += 1;
} }
} }
} }
@ -1042,24 +1035,26 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
} }
// handle case where default pad values must be used // handle case where default pad values must be used
if (defaultPads) { if (defaultPads) {
for(int i=0; i<2*kernelRank; ++i) { for (int i = 0; i < 2 * kernelRank; ++i) {
actualPads.emplace_back(0); actualPads.emplace_back(0);
} }
} }
// initialize output shape // initialize output shape
SmallVector<int64_t, 4> yShape(xShape.begin(), xShape.end()); SmallVector<int64_t, 4> yShape(xShape.begin(), xShape.end());
// for all kernel dimensions // for all kernel dimensions
for(int i=0; i<kernelRank; ++i) { for (int i = 0; i < kernelRank; ++i) {
auto inputSpatialShape = xShape[kernelOffset + i]; auto inputSpatialShape = xShape[kernelOffset + i];
auto padShape = actualPads[i] + actualPads[kernelRank+i]; auto padShape = actualPads[i] + actualPads[kernelRank + i];
auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt(); auto kernelSpatialShape =
(kernelShapeArray[i]).cast<IntegerAttr>().getInt();
auto dilations = actualDilations[i]; auto dilations = actualDilations[i];
auto strideSpatialShape = actualStrides[i]; auto strideSpatialShape = actualStrides[i];
///output_spatial_shape[i] = ceil( (input_spatial_shape[i] + pad_shape[i] - /// output_spatial_shape[i] = ceil( (input_spatial_shape[i] + pad_shape[i] -
// ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1) // ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) /
double numerator = inputSpatialShape + padShape - // strides_spatial_shape[i] + 1)
((kernelSpatialShape - 1) * dilations + 1); double numerator = inputSpatialShape + padShape -
((kernelSpatialShape - 1) * dilations + 1);
double denominator = strideSpatialShape; double denominator = strideSpatialShape;
int64_t res; int64_t res;
if (ceilMode) { if (ceilMode) {

File diff suppressed because it is too large Load Diff

View File

@ -127,6 +127,10 @@ int main(int argc, char *argv[]) {
if (emissionTarget >= EmitMLIR) { if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass()); pm.addPass(mlir::createLowerToKrnlPass());
// An additional pass of canonicalization is helpful because lowering
// from ONNX dialect to Standard dialect exposes additional canonicalization
// oppertunities.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createLowerKrnlPass()); pm.addPass(mlir::createLowerKrnlPass());
} }

View File

@ -28,6 +28,11 @@ void ONNXAddOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) { OwningRewritePatternList& results, MLIRContext* context) {
results.insert<MulAddToGemmOptPattern>(context); results.insert<MulAddToGemmOptPattern>(context);
} }
void ONNXGemmOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<FuseGemmFollowedByAddition>(context);
}
/// on the ONNXIdentityOp. /// on the ONNXIdentityOp.
void ONNXIdentityOp::getCanonicalizationPatterns( void ONNXIdentityOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) { OwningRewritePatternList& results, MLIRContext* context) {

View File

@ -26,6 +26,7 @@ include "dialect/onnx/onnx.td"
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>; def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>; class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite // Pattern-Match and Rewrite
@ -41,6 +42,11 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
(ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)), (ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
[(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>; [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>;
// onnx.add(onnx.Gemm(%X, %Y, None), %Z) = onnx.Gemm(%X, %Y, %Z)
def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none, $alpha, $beta, $transA, $transB), $bias),
(ONNXGemmOp $m1, $m2, $bias, $alpha, $beta, $transA, $transB),
[(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2), (HasNoneType $none)]>;
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X) // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg), def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
(replaceWithValue $arg)>; (replaceWithValue $arg)>;

View File

@ -101,3 +101,14 @@ func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>
// CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32> // CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
// CHECK-NEXT: return %1 : tensor<*xf32> // CHECK-NEXT: return %1 : tensor<*xf32>
} }
//CHECK-LABEL: @test_gemm_add_fusion(%{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128xf32>) -> tensor<*xf32> {
func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128xf32>) -> tensor<*xf32> {
%cst = constant unit
%0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128xf32>, tensor<128x128xf32>, none) -> tensor<*xf32>
%1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32>
// return [[GEMM]] : tensor<*xf32>
}