Merge branch 'master' into shapeinference-pad

This commit is contained in:
Gheorghe-Teodor Bercea 2020-02-24 16:13:21 -05:00 committed by GitHub
commit d4f8fef947
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 3431 additions and 3804 deletions

View File

@ -327,10 +327,10 @@ ONNX BatchNormalization operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `out_mean`: memref of any type values or tensor of any type values
1. `out_var`: memref of any type values or tensor of any type values
1. `saved_mean`: memref of any type values or tensor of any type values
1. `saved_var`: memref of any type values or tensor of any type values
1. `out_mean`: memref of any type values or tensor of any type values or none type
1. `out_var`: memref of any type values or tensor of any type values or none type
1. `saved_mean`: memref of any type values or tensor of any type values or none type
1. `saved_var`: memref of any type values or tensor of any type values or none type
### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp)
ONNX BatchNormalization operation in test mode
@ -375,12 +375,12 @@ ONNX BitShift operation
"Bitwise shift operator performs element-wise operation. For each input element, if the"
" attribute "direction" is "RIGHT", this operator moves its binary representation toward"
" the right side so that the input value is effectively decreased. If the attribute "direction""
" is "LEFT", bits of binary representation moves toward the left side, which results the"
" attribute \"direction\" is \"RIGHT\", this operator moves its binary representation toward"
" the right side so that the input value is effectively decreased. If the attribute \"direction\""
" is \"LEFT\", bits of binary representation moves toward the left side, which results the"
" increase of its actual value. The input X is the tensor to be shifted and another input"
" Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4],"
" and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with"
" Y specifies the amounts of shifting. For example, if \"direction\" is \"Right\", X is [1, 4],"
" and S is [1, 1], the corresponding output Z would be [0, 2]. If \"direction\" is \"LEFT\" with"
" X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]."
" "
" Because this operator supports Numpy-style broadcasting, X's and Y's shapes are"
@ -413,15 +413,15 @@ ONNX Cast operation
"the converted type. The 'to' argument must be one of the data types specified"
"in the 'DataType' enum field in the TensorProto message."
""
"Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations"
"(e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may"
"Casting from string tensor in plain (e.g., \"3.14\" and \"1000\") and scientific numeric representations"
"(e.g., \"1e-5\" and \"1E8\") to float types is supported. For example, converting string \"100.5\" to an integer may"
"result 100. There are some string literals reserved for special floating-point values;"
""+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively."
"Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly,"
"this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors"
"to string tensors, plain floating-point representation (such as "314.15926") would be used. "
"Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases "
"of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior."
"\"+INF\" (and \"INF\"), \"-INF\", and \"NaN\" are positive infinity, negative infinity, and not-a-number, respectively."
"Any string which can exactly match \"+INF\" in a case-insensitive way would be mapped to positive infinite. Similarly,"
"this case-insensitive rule is applied to \"INF\" and \"NaN\". When casting from numeric tensors"
"to string tensors, plain floating-point representation (such as \"314.15926\") would be used. "
"Converting non-numerical-literal string such as \"Hello World!\" is an undefined behavior. Cases "
"of converting string representing floating-point arithmetic value, such as \"2.718\", to INT is an undefined behavior."
""
"Conversion from a numerical type to any numerical type is always allowed."
"User must be aware of precision loss and value change caused by range difference between two types."
@ -476,8 +476,8 @@ ONNX Clip operation
#### Operands:
1. `input`: memref of any type values or tensor of any type values
1. `min`: memref of any type values or tensor of any type values
1. `max`: memref of any type values or tensor of any type values
1. `min`: memref of any type values or tensor of any type values or none type
1. `max`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -618,8 +618,8 @@ ONNX ConvInteger operation
1. `x`: memref of any type values or tensor of any type values
1. `w`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values
1. `w_zero_point`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values or none type
1. `w_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -678,7 +678,7 @@ ONNX Conv operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -720,7 +720,7 @@ ONNX ConvTranspose operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -884,7 +884,7 @@ ONNX DequantizeLinear operation
1. `x`: memref of any type values or tensor of any type values
1. `x_scale`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values
1. `x_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -964,7 +964,7 @@ ONNX Dropout operation
#### Results:
1. `output`: memref of any type values or tensor of any type values
1. `mask`: memref of any type values or tensor of any type values
1. `mask`: memref of any type values or tensor of any type values or none type
### onnx.DynamicQuantizeLinear (ONNXDynamicQuantizeLinearOp)
ONNX DynamicQuantizeLinear operation
@ -1297,9 +1297,9 @@ ONNX GRU operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `sequence_lens`: memref of any type values or tensor of any type values
1. `initial_h`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -1315,8 +1315,8 @@ ONNX GRU operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `Y_h`: memref of any type values or tensor of any type values
1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values or none type
### onnx.GatherElements (ONNXGatherElementsOp)
ONNX GatherElements operation
@ -1609,7 +1609,7 @@ ONNX Gemm operation
1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `C`: memref of any type values or tensor of any type values
1. `C`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -2013,11 +2013,11 @@ ONNX LSTM operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `sequence_lens`: memref of any type values or tensor of any type values
1. `initial_h`: memref of any type values or tensor of any type values
1. `initial_c`: memref of any type values or tensor of any type values
1. `P`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values or none type
1. `initial_c`: memref of any type values or tensor of any type values or none type
1. `P`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -2033,9 +2033,9 @@ ONNX LSTM operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `Y_h`: memref of any type values or tensor of any type values
1. `Y_c`: memref of any type values or tensor of any type values
1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values or none type
1. `Y_c`: memref of any type values or tensor of any type values or none type
### onnx.LeakyRelu (ONNXLeakyReluOp)
ONNX LeakyRelu operation
@ -2160,24 +2160,24 @@ ONNX Loop operation
""
" Operator inputs defined as (max_trip_count, condition_var)."
""
" input ("", ""):"
" input (\"\", \"\"):"
" for (int i=0; ; ++i) {"
" cond = ... // Note this value is ignored, but is required in the body"
" }"
""
" input ("", cond) // Note this is analogous to a while loop"
" input (\"\", cond) // Note this is analogous to a while loop"
" bool cond = ...;"
" for (int i=0; cond; ++i) {"
" cond = ...;"
" }"
""
" input ("", 1) // Note this is analogous to a do-while loop"
" input (\"\", 1) // Note this is analogous to a do-while loop"
" bool cond = true"
" for (int i=0; cond; ++i) {"
" cond = ...;"
" }"
""
" input (trip_count, "") // Note this is analogous to a for loop"
" input (trip_count, \"\") // Note this is analogous to a for loop"
" int trip_count = ..."
" for (int i=0; i < trip_count; ++i) {"
" cond = ...; // ignored"
@ -2203,15 +2203,15 @@ ONNX Loop operation
" }"
""
" graph body-net ("
" %i[INT32, scalar] // iteration number"
" %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used"
" %b_in[INT32, scalar] // incoming value of loop-carried-dependency b"
" %i[INT32, scalar]"
" %keepgoing[BOOL, scalar]"
" %b[INT32, scalar]"
" ) {"
" %my_local = Add(%a, %b_in)"
" %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b"
" %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition"
" %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated"
" return %keepgoing_out, %b_out, %user_defined_val"
" %my_local = Add(%a, %b)"
" %b_out = Sub(%a, %b)"
" %keepgoing_out = Greater(%my_local, %b_out)"
" %user_defined_vals = Add(%b, %b)"
" return %keepgoing_out, %b_out, %user_defined_vals"
" }"
""
"*Sample equivalent C code*"
@ -2226,51 +2226,31 @@ ONNX Loop operation
" const int max_trip_count = 10; // Analogous to input M"
" int user_defined_vals[]; // Imagine this is resizable"
" /* End implicitly-defined code */"
" /* initialize loop-carried variables and scan-output variables */"
" bool keepgoing_out = keepgoing"
" int b_out = b"
""
" for (int i=0; i < max_trip_count && keepgoing_out; ++i) {"
" /* Implicitly-defined code: bind actual parameter values"
" to formal parameter variables of loop-body */"
" bool keepgoing_in = keepgoing_out; "
" bool b_in = b_out;"
""
" for (int i=0; i < max_trip_count && keepgoing; ++i) {"
" /* User-defined code (loop body) */"
" int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine"
" b_out = a - b_in;"
" keepgoing_out = my_local > b_out; "
" user_defined_val = b_in + b_in; // b_in and b_out are different variables"
" int my_local = a + b; // Reading values in the enclosing scope is fine"
" b = a - b; // writes fine if we specify b as a loop-carried dependency"
" keepgoing = my_local > b; // keepgoing is a loop-carried dependency"
" user_defined_vals[i] = b + b;"
" /* End user-defined code */"
""
" /* Implicitly defined-code */"
" user_defined_vals[i] = user_defined_val // accumulate scan-output values"
" }"
" // int t = my_local; // Can't do this. my_local is not accessible here."
" // my_local = 123; // Can't do this. my_local was defined in the the body"
""
" // The values below are bound to the output variables of the loop and therefore accessible"
" // b_out; user_defined_vals; keepgoing_out;"
" // These below values are live-out from the loop and therefore accessible"
" b_out; user_defined_vals; keepgoing_out;"
" }"
""
"There are several things of note in this code snippet:"
""
"1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can"
"1) Values from the enclosing scope (i.e. variable a here) are in scope and can"
" be referenced in the inputs of the loop."
"2) Any values computed in the loop body that needs to be used in a subsequent"
" iteration or after the loop are modelled using a pair of variables in the loop-body,"
" consisting of an input variable (eg., b_in) and an output variable (eg., b_out)."
" These are referred to as loop-carried dependences. The loop operation node"
" supplies the input value of the input variable for the first iteration, and"
" returns the output value of the output variable produced by the final"
" iteration."
"3) Scan_output variables are used to implicitly concatenate values computed across"
" all the iterations. In the above example, the value of user_defined_val computed"
" over all iterations are concatenated and returned as the value of user_defined_vals"
" after the loop."
"4) Values created in the body cannot be accessed in the enclosing scope,"
" except using the mechanism described above."
"2) Any variables which you wish to make available in the enclosing scope (i.e."
" the variables b and keepgoing) must be declared as either loop-carried"
" dependencies (both at the op inputs and output and at the body net input and"
" output) or scan_outputs."
"3) Values created in the body cannot be accessed in the enclosing scope."
""
"Note that the semantics of this op support "diagonal" or "wavefront" execution."
"Note that the semantics of this op support \"diagonal\" or \"wavefront\" execution."
"(See Step 3 here for an example:"
"https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)."
"Frontends should emit multi-layer RNNs as a series of While operators (with"
@ -2280,8 +2260,8 @@ ONNX Loop operation
#### Operands:
1. `M`: memref of any type values or tensor of any type values
1. `cond`: memref of any type values or tensor of any type values
1. `M`: memref of any type values or tensor of any type values or none type
1. `cond`: memref of any type values or tensor of any type values or none type
1. `v_initial`: memref of any type values or tensor of any type values
#### Attributes:
@ -2360,8 +2340,8 @@ ONNX MatMulInteger operation
1. `A`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `a_zero_point`: memref of any type values or tensor of any type values
1. `b_zero_point`: memref of any type values or tensor of any type values
1. `a_zero_point`: memref of any type values or tensor of any type values or none type
1. `b_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -2444,7 +2424,7 @@ ONNX MaxPool operation
" ```"
" pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i]"
" ```"
" The output of each pooling window is maximum number of elements exclude pad. "
" The output of each pooling window is maximum number of elements exclude pad."
" "
#### Operands:
@ -2466,7 +2446,7 @@ ONNX MaxPool operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `Indices`: memref of any type values or tensor of any type values
1. `Indices`: memref of any type values or tensor of any type values or none type
### onnx.MaxPoolSingleOut (ONNXMaxPoolSingleOutOp)
ONNX MaxPool operation with a single output.
@ -2552,7 +2532,7 @@ ONNX MaxUnpool operation
1. `X`: memref of any type values or tensor of any type values
1. `I`: memref of any type values or tensor of any type values
1. `output_shape`: memref of any type values or tensor of any type values
1. `output_shape`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -2752,9 +2732,9 @@ ONNX NonMaxSuppression operation
1. `boxes`: memref of any type values or tensor of any type values
1. `scores`: memref of any type values or tensor of any type values
1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values
1. `iou_threshold`: memref of any type values or tensor of any type values
1. `score_threshold`: memref of any type values or tensor of any type values
1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values or none type
1. `iou_threshold`: memref of any type values or tensor of any type values or none type
1. `score_threshold`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3067,7 +3047,7 @@ ONNX Pad operation
1. `data`: memref of any type values or tensor of any type values
1. `pads`: memref of any type values or tensor of any type values
1. `constant_value`: memref of any type values or tensor of any type values
1. `constant_value`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3124,7 +3104,7 @@ ONNX QLinearConv operation
1. `w_zero_point`: memref of any type values or tensor of any type values
1. `y_scale`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3188,7 +3168,7 @@ ONNX QuantizeLinear operation
1. `x`: memref of any type values or tensor of any type values
1. `y_scale`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values
1. `y_zero_point`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3270,9 +3250,9 @@ ONNX RNN operation
1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values
1. `R`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `sequence_lens`: memref of any type values or tensor of any type values
1. `initial_h`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values or none type
1. `sequence_lens`: memref of any type values or tensor of any type values or none type
1. `initial_h`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -3287,8 +3267,8 @@ ONNX RNN operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `Y_h`: memref of any type values or tensor of any type values
1. `Y`: memref of any type values or tensor of any type values or none type
1. `Y_h`: memref of any type values or tensor of any type values or none type
### onnx.RandomNormalLike (ONNXRandomNormalLikeOp)
ONNX RandomNormalLike operation
@ -3813,14 +3793,14 @@ ONNX Resize operation
"Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor."
"Each dimension value of the output tensor is:"
" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified."
" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified."
#### Operands:
1. `X`: memref of any type values or tensor of any type values
1. `roi`: memref of any type values or tensor of any type values
1. `scales`: memref of any type values or tensor of any type values
1. `sizes`: memref of any type values or tensor of any type values
1. `sizes`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4438,7 +4418,7 @@ ONNX SequenceErase operation
#### Operands:
1. `input_sequence`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4463,7 +4443,7 @@ ONNX SequenceInsert operation
1. `input_sequence`: memref of any type values or tensor of any type values
1. `tensor`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values
1. `position`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4680,8 +4660,8 @@ ONNX Slice operation
1. `data`: memref of any type values or tensor of any type values
1. `starts`: memref of any type values or tensor of any type values
1. `ends`: memref of any type values or tensor of any type values
1. `axes`: memref of any type values or tensor of any type values
1. `steps`: memref of any type values or tensor of any type values
1. `axes`: memref of any type values or tensor of any type values or none type
1. `steps`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4834,7 +4814,7 @@ ONNX SplitToSequence operation
#### Operands:
1. `input`: memref of any type values or tensor of any type values
1. `split`: memref of any type values or tensor of any type values
1. `split`: memref of any type values or tensor of any type values or none type
#### Attributes:
@ -4902,9 +4882,9 @@ ONNX StringNormalizer operation
"StringNormalization performs string operations for basic cleaning."
"This operator has only one input (denoted by X) and only one output"
"(denoted by Y). This operator first examines the elements in the X,"
"and removes elements specified in "stopwords" attribute."
"and removes elements specified in \"stopwords\" attribute."
"After removing stop words, the intermediate result can be further lowercased,"
"uppercased, or just returned depending the "case_change_action" attribute."
"uppercased, or just returned depending the \"case_change_action\" attribute."
"This operator only accepts [C]- and [1, C]-tensor."
"If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]"
"if input shape is [C] and shape [1, 1] if input shape is [1, C]."
@ -5034,8 +5014,8 @@ ONNX TfIdfVectorizer operation
"respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output."
"Note that we may consider all skips up to S when generating the n-grams."
""
"The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and"
"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF","
"The examples used above are true if mode is \"TF\". If mode is \"IDF\", all the counts larger than 1 would be truncated to 1 and"
"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is \"TFIDF\","
"this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute."
""
"Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor."
@ -5123,9 +5103,9 @@ ONNX TopK operation
" contains the indices of the top k elements (original indices from the input"
" tensor)."
""
"If "largest" is 1 (the default value) then the k largest elements are returned."
"If "sorted" is 1 (the default value) then the resulting k elements will be sorted."
"If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined."
"If \"largest\" is 1 (the default value) then the k largest elements are returned."
"If \"sorted\" is 1 (the default value) then the resulting k elements will be sorted."
"If \"sorted\" is 0, order of returned 'Values' and 'Indices' are undefined."
""
"Given two equivalent values, this operator uses the indices along the axis as"
" a tiebreaker. That is, the element with the lower index will appear first."
@ -5184,7 +5164,7 @@ ONNX Unique operation
"This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. "
"The first output tensor 'Y' contains all unique values or subtensors of the input. "
"The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. "
"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". "
"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. \". "
"The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. "
""
"Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. "
@ -5268,9 +5248,9 @@ ONNX Unique operation
#### Results:
1. `Y`: memref of any type values or tensor of any type values
1. `indices`: memref of any type values or tensor of any type values
1. `inverse_indices`: memref of any type values or tensor of any type values
1. `counts`: memref of any type values or tensor of any type values
1. `indices`: memref of any type values or tensor of any type values or none type
1. `inverse_indices`: memref of any type values or tensor of any type values or none type
1. `counts`: memref of any type values or tensor of any type values or none type
### onnx.Unsqueeze (ONNXUnsqueezeOp)
ONNX Unsqueeze operation

File diff suppressed because it is too large Load Diff

View File

@ -121,6 +121,7 @@ private:
mlir::MLIRContext &context_;
mlir::ModuleOp module_;
mlir::OpBuilder builder_;
mlir::Value none_;
// mapping between string name and symbol
OnnxOnnfSymbolMapping frontend_symbols_;
@ -287,8 +288,8 @@ private:
}
}
std::vector<mlir::NamedAttribute> ImportNodeAttributes(
const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute>
ImportNodeAttributes(const onnx::NodeProto &node) {
std::vector<mlir::NamedAttribute> attributes;
for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i);
@ -317,21 +318,11 @@ private:
}
}
// if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be
// combined with 'if constexpr' the issue is the type of the output is
// different. alternative way to use variadic output for all the op
/*!
* Important onnx node which generates only one output
* @param node onnx node
* @param nIn number of expected inputs
* @param nOut number of expected outputs
* @param attrs list of desription for attributes with format {name, type,
* default}
*/
template <typename T>
void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
bool variadicIn = false, bool variadicOut = false) {
void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
int expectedNumResults = -1) {
bool variadicIn = expectedNumOperands == -1;
bool variadicOut = expectedNumResults == -1;
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
@ -339,6 +330,10 @@ private:
}
}
if (!variadicIn)
for (auto i = inputs.size(); i < expectedNumOperands; i++)
inputs.emplace_back(none_);
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
@ -347,49 +342,11 @@ private:
auto attributes = ImportNodeAttributes(node);
llvm::StringRef OpName = node.op_type();
if ((variadicIn || nIn == inputs.size()) &&
(variadicOut || nOut == outputTypes.size())) {
auto op =
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
op.getResult());
} else {
ImportNodeGeneric(node);
}
}
template <typename T>
void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut,
bool variadicIn = false,
bool variadicOut = false) {
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
auto attributes = ImportNodeAttributes(node);
llvm::StringRef OpName = node.op_type();
if ((variadicIn || nIn == inputs.size()) &&
(variadicOut || nOut == outputTypes.size())) {
auto op =
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
for (int i = 0; i < node.output().size(); i++) {
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
op.getResult(i));
}
} else {
ImportNodeGeneric(node);
// TODO: Handle optional inputs.
auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
for (int i = 0; i < node.output().size(); i++) {
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
*(op.getODSResults(i).begin()));
}
}
@ -398,8 +355,7 @@ private:
* c++ does not allow template specialization inside a class scope
* a specialized function is used
*/
void
ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
// Conv has attribute dilations, kernel_shape, pads, the default value of
// which is determined by the shape of first argument. However, since the
// shape is unknown now, these attributes can be not generated auto
@ -413,24 +369,20 @@ private:
int nOps = node.input().size();
if (nOps == 2)
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(
node, nOps, nOut);
buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
else
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut);
buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
}
/*!
* Special handle for MaxPool operations.
*/
void ImportNodeMaxPool(
onnx::NodeProto node, int nIn, int nOut) {
void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) {
int nOuts = node.output().size();
if (nOuts == 1) {
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(
node, nIn, nOuts);
buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts);
} else {
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(
node, nIn, nOuts);
buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts);
}
}
@ -441,23 +393,10 @@ private:
int nOuts = node.output().size();
if (nOuts == 1) {
// Test mode with one output.
ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn,
nOuts);
buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts);
} else {
// Training mode with four trailing optional outputs. Not handled yet.
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
}
}
/*!
* Special handle for Gemm operations.
*/
void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size();
if (nOps == 2) {
ImportNodeOneOut<mlir::ONNXGemmNoBiasOp>(node, 2, nOut);
} else {
ImportNodeOneOut<mlir::ONNXGemmOp>(node, nIn, nOut);
buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
}
}
@ -467,28 +406,14 @@ private:
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size();
if (nOps == 2) {
ImportNodeOneOut<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
buildOperation<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
} else {
ImportNodeOneOut<mlir::ONNXPadOp>(node, nIn, nOut);
buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
}
}
void ImportNode(const onnx::NodeProto &node) {
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
std::vector<mlir::NamedAttribute> attributes;
llvm::StringRef OpName = node.op_type();
llvm::StringRef opName = node.op_type();
// the following code is generated by gen_doc.py
// refer to dialect/onnx/onnx.td for details
@ -555,9 +480,11 @@ private:
ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
}
// import nodes in the graph
auto node = graph.node();
for (const auto &item : node) {
// Create a NoneTyped constant.
none_ =
builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
// Import nodes in the graph.
for (const auto &item : graph.node()) {
ImportNode(item);
}

View File

@ -1,320 +1,319 @@
//********************************************************
// Warning: Do not modify this file directly
// This file is automatically generated via script
// Details can be found in doc/readonnxdefs.md
// This file is generated on UTC-02/24/2020, 06:29:01.
// Do not modify this file directly.
// This file is automatically generated via script.
// Details can be found in doc/readonnxdefs.md .
//********************************************************
if (OpName == "DUMMY") {
}else if (OpName == "Abs") {
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1);
}else if (OpName == "Acos") {
ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1);
}else if (OpName == "Acosh") {
ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1);
}else if (OpName == "Add") {
ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1);
}else if (OpName == "And") {
ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1);
}else if (OpName == "ArgMax") {
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1);
}else if (OpName == "ArgMin") {
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1);
}else if (OpName == "Asin") {
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1);
}else if (OpName == "Asinh") {
ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1);
}else if (OpName == "Atan") {
ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1);
}else if (OpName == "Atanh") {
ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1);
}else if (OpName == "AveragePool") {
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1);
}else if (OpName == "BatchNormalization") {
ImportNodeBatchNormalization(node, 5, 5);
}else if (OpName == "BitShift") {
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1);
}else if (OpName == "Cast") {
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1);
}else if (OpName == "Ceil") {
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1);
}else if (OpName == "Clip") {
ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1);
}else if (OpName == "Compress") {
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1);
}else if (OpName == "Concat") {
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, true, false);
}else if (OpName == "ConcatFromSequence") {
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1);
}else if (OpName == "Constant") {
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1);
}else if (OpName == "ConstantOfShape") {
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1);
}else if (OpName == "Conv") {
ImportNodeConv(node, 3, 1);
}else if (OpName == "ConvInteger") {
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1);
}else if (OpName == "ConvTranspose") {
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1);
}else if (OpName == "Cos") {
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1);
}else if (OpName == "Cosh") {
ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1);
}else if (OpName == "CumSum") {
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1);
}else if (OpName == "DepthToSpace") {
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1);
}else if (OpName == "DequantizeLinear") {
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1);
}else if (OpName == "Det") {
ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1);
}else if (OpName == "Div") {
ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1);
}else if (OpName == "Dropout") {
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2);
}else if (OpName == "DynamicQuantizeLinear") {
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3);
}else if (OpName == "Elu") {
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1);
}else if (OpName == "Equal") {
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1);
}else if (OpName == "Erf") {
ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1);
}else if (OpName == "Exp") {
ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1);
}else if (OpName == "Expand") {
ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1);
}else if (OpName == "EyeLike") {
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1);
}else if (OpName == "Flatten") {
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1);
}else if (OpName == "Floor") {
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1);
}else if (OpName == "GRU") {
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2);
}else if (OpName == "Gather") {
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1);
}else if (OpName == "GatherElements") {
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1);
}else if (OpName == "GatherND") {
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1);
}else if (OpName == "Gemm") {
ImportNodeGemm(node, 3, 1);
}else if (OpName == "GlobalAveragePool") {
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1);
}else if (OpName == "GlobalLpPool") {
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1);
}else if (OpName == "GlobalMaxPool") {
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1);
}else if (OpName == "Greater") {
ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1);
}else if (OpName == "HardSigmoid") {
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1);
}else if (OpName == "Hardmax") {
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1);
}else if (OpName == "Identity") {
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1);
}else if (OpName == "If") {
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1);
}else if (OpName == "InstanceNormalization") {
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1);
}else if (OpName == "IsInf") {
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1);
}else if (OpName == "IsNaN") {
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1);
}else if (OpName == "LRN") {
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1);
}else if (OpName == "LSTM") {
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3);
}else if (OpName == "LeakyRelu") {
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1);
}else if (OpName == "Less") {
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1);
}else if (OpName == "Log") {
ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1);
}else if (OpName == "LogSoftmax") {
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1);
}else if (OpName == "Loop") {
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1);
}else if (OpName == "LpNormalization") {
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1);
}else if (OpName == "LpPool") {
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1);
}else if (OpName == "MatMul") {
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1);
}else if (OpName == "MatMulInteger") {
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1);
}else if (OpName == "Max") {
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, true, false);
}else if (OpName == "MaxPool") {
ImportNodeMaxPool(node, 1, 2);
}else if (OpName == "MaxRoiPool") {
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1);
}else if (OpName == "MaxUnpool") {
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1);
}else if (OpName == "Mean") {
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, true, false);
}else if (OpName == "MeanVarianceNormalization") {
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1);
}else if (OpName == "Min") {
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, true, false);
}else if (OpName == "Mod") {
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1);
}else if (OpName == "Mul") {
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1);
}else if (OpName == "Multinomial") {
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1);
}else if (OpName == "Neg") {
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1);
}else if (OpName == "NonMaxSuppression") {
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1);
}else if (OpName == "NonZero") {
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1);
}else if (OpName == "Not") {
ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1);
}else if (OpName == "OneHot") {
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1);
}else if (OpName == "Or") {
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1);
}else if (OpName == "PRelu") {
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1);
}else if (OpName == "Pad") {
ImportNodePad(node, 3, 1);
}else if (OpName == "Pow") {
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1);
}else if (OpName == "QLinearConv") {
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1);
}else if (OpName == "QLinearMatMul") {
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1);
}else if (OpName == "QuantizeLinear") {
ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1);
}else if (OpName == "RNN") {
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2);
}else if (OpName == "RandomNormal") {
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1);
}else if (OpName == "RandomNormalLike") {
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1);
}else if (OpName == "RandomUniform") {
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1);
}else if (OpName == "RandomUniformLike") {
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1);
}else if (OpName == "Range") {
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1);
}else if (OpName == "Reciprocal") {
ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1);
}else if (OpName == "ReduceL1") {
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1);
}else if (OpName == "ReduceL2") {
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1);
}else if (OpName == "ReduceLogSum") {
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1);
}else if (OpName == "ReduceLogSumExp") {
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1);
}else if (OpName == "ReduceMax") {
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1);
}else if (OpName == "ReduceMean") {
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1);
}else if (OpName == "ReduceMin") {
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1);
}else if (OpName == "ReduceProd") {
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1);
}else if (OpName == "ReduceSum") {
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1);
}else if (OpName == "ReduceSumSquare") {
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1);
}else if (OpName == "Relu") {
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1);
}else if (OpName == "Reshape") {
ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1);
}else if (OpName == "Resize") {
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1);
}else if (OpName == "ReverseSequence") {
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1);
}else if (OpName == "RoiAlign") {
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1);
}else if (OpName == "Round") {
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1);
}else if (OpName == "Scan") {
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1);
}else if (OpName == "Scatter") {
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1);
}else if (OpName == "ScatterElements") {
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1);
}else if (OpName == "ScatterND") {
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1);
}else if (OpName == "Selu") {
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1);
}else if (OpName == "SequenceAt") {
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1);
}else if (OpName == "SequenceConstruct") {
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, true, false);
}else if (OpName == "SequenceEmpty") {
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1);
}else if (OpName == "SequenceErase") {
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1);
}else if (OpName == "SequenceInsert") {
ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1);
}else if (OpName == "SequenceLength") {
ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1);
}else if (OpName == "Shape") {
ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1);
}else if (OpName == "Shrink") {
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1);
}else if (OpName == "Sigmoid") {
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1);
}else if (OpName == "Sign") {
ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1);
}else if (OpName == "Sin") {
ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1);
}else if (OpName == "Sinh") {
ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1);
}else if (OpName == "Size") {
ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1);
}else if (OpName == "Slice") {
ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1);
}else if (OpName == "Softmax") {
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1);
}else if (OpName == "Softplus") {
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1);
}else if (OpName == "Softsign") {
ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1);
}else if (OpName == "SpaceToDepth") {
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1);
}else if (OpName == "Split") {
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1);
}else if (OpName == "SplitToSequence") {
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1);
}else if (OpName == "Sqrt") {
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1);
}else if (OpName == "Squeeze") {
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1);
}else if (OpName == "StringNormalizer") {
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1);
}else if (OpName == "Sub") {
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1);
}else if (OpName == "Sum") {
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, true, false);
}else if (OpName == "Tan") {
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1);
}else if (OpName == "Tanh") {
ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1);
}else if (OpName == "TfIdfVectorizer") {
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1);
}else if (OpName == "ThresholdedRelu") {
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1);
}else if (OpName == "Tile") {
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1);
}else if (OpName == "TopK") {
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2);
}else if (OpName == "Transpose") {
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1);
}else if (OpName == "Unique") {
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4);
}else if (OpName == "Unsqueeze") {
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1);
}else if (OpName == "Upsample") {
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1);
}else if (OpName == "Where") {
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1);
}else if (OpName == "Xor") {
ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1);
}
if (opName == "Abs")
return buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Acos")
return buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Acosh")
return buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Add")
return buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "And")
return buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "ArgMax")
return buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ArgMin")
return buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Asin")
return buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Asinh")
return buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Atan")
return buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Atanh")
return buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "AveragePool")
return buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "BatchNormalization")
return ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
if (opName == "BitShift")
return buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Cast")
return buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Ceil")
return buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Clip")
return buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Compress")
return buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Concat")
return buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "ConcatFromSequence")
return buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Constant")
return buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "ConstantOfShape")
return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Conv")
return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ConvInteger")
return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "ConvTranspose")
return buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Cos")
return buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Cosh")
return buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "CumSum")
return buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "DepthToSpace")
return buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "DequantizeLinear")
return buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Det")
return buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Div")
return buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Dropout")
return buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
if (opName == "DynamicQuantizeLinear")
return buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
if (opName == "Elu")
return buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Equal")
return buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Erf")
return buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Exp")
return buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Expand")
return buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "EyeLike")
return buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Flatten")
return buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Floor")
return buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "GRU")
return buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
if (opName == "Gather")
return buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "GatherElements")
return buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "GatherND")
return buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Gemm")
return buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "GlobalAveragePool")
return buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "GlobalLpPool")
return buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "GlobalMaxPool")
return buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Greater")
return buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "HardSigmoid")
return buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Hardmax")
return buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Identity")
return buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "If")
return buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
if (opName == "InstanceNormalization")
return buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "IsInf")
return buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "IsNaN")
return buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LRN")
return buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LSTM")
return buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
if (opName == "LeakyRelu")
return buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Less")
return buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Log")
return buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LogSoftmax")
return buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Loop")
return buildOperation<mlir::ONNXLoopOp>(node);
if (opName == "LpNormalization")
return buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LpPool")
return buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "MatMul")
return buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "MatMulInteger")
return buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "Max")
return buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "MaxPool")
return ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
if (opName == "MaxRoiPool")
return buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "MaxUnpool")
return buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Mean")
return buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "MeanVarianceNormalization")
return buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Min")
return buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "Mod")
return buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Mul")
return buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Multinomial")
return buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Neg")
return buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "NonMaxSuppression")
return buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
if (opName == "NonZero")
return buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Not")
return buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "OneHot")
return buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Or")
return buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "PRelu")
return buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Pad")
return ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Pow")
return buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "QLinearConv")
return buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
if (opName == "QLinearMatMul")
return buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
if (opName == "QuantizeLinear")
return buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "RNN")
return buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
if (opName == "RandomNormal")
return buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "RandomNormalLike")
return buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "RandomUniform")
return buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "RandomUniformLike")
return buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Range")
return buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Reciprocal")
return buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceL1")
return buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceL2")
return buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceLogSum")
return buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceLogSumExp")
return buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceMax")
return buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceMean")
return buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceMin")
return buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceProd")
return buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceSum")
return buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceSumSquare")
return buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Relu")
return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Reshape")
return buildOperation<mlir::ONNXReshapeOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Resize")
return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "ReverseSequence")
return buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "RoiAlign")
return buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Round")
return buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Scan")
return buildOperation<mlir::ONNXScanOp>(node);
if (opName == "Scatter")
return buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ScatterElements")
return buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ScatterND")
return buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Selu")
return buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "SequenceAt")
return buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "SequenceConstruct")
return buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "SequenceEmpty")
return buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "SequenceErase")
return buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "SequenceInsert")
return buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "SequenceLength")
return buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Shape")
return buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Shrink")
return buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sigmoid")
return buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sign")
return buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sin")
return buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sinh")
return buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Size")
return buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Slice")
return buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
if (opName == "Softmax")
return buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Softplus")
return buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Softsign")
return buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "SpaceToDepth")
return buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Split")
return buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
if (opName == "SplitToSequence")
return buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Sqrt")
return buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Squeeze")
return buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "StringNormalizer")
return buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sub")
return buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Sum")
return buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "Tan")
return buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Tanh")
return buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "TfIdfVectorizer")
return buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ThresholdedRelu")
return buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Tile")
return buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "TopK")
return buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
if (opName == "Transpose")
return buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Unique")
return buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
if (opName == "Unsqueeze")
return buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Upsample")
return buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Where")
return buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Xor")
return buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);

View File

@ -17,20 +17,24 @@ struct ONNXGemmOpLowering : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
auto has_bias = (operands.size() == 3);
// The first predicate is unnecessary when we remove ONXGemmNoBiasOp.
bool hasBias = (operands.size() == 3) &&
(!op->getOperand(2).getType().isa<NoneType>());
Value A, B, C;
A = operands[0];
B = operands[1];
if (has_bias)
if (hasBias)
C = operands[2];
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
auto betaAttr = FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
auto alphaAttr =
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
auto betaAttr =
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -68,8 +72,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// Define loops.
std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
optimizedLoops, numLoops);
Block *optimizationBlock =
defineLoops(rewriter, loc, originalLoops, optimizedLoops, numLoops);
// We have two Krnl loops:
// - Outer loop iterates over the output matrix dimensions, and
@ -83,8 +87,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]);
}
KrnlIterateOperandPack outerPack(rewriter, outerLoops,
optimizedOuterLoops);
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
// Induction variables for the outer loops
for (int i = 0; i < 2; ++i)
addDimensionToPack(rewriter, loc, outerPack, alloc, i);
@ -106,20 +109,19 @@ struct ONNXGemmOpLowering : public ConversionPattern {
int64_t K_B_Idx = (isTransB) ? 1 : 0;
reductionPack.pushConstantBound(0);
if (ATy.getShape()[K_A_Idx] != -1)
reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
else if (BTy.getShape()[K_B_Idx] != -1)
reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
else
if (BTy.getShape()[K_B_Idx] != -1)
reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
else
reductionPack.pushOperandBound(
rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
reductionPack.pushOperandBound(
rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
// Get run-time dimension information for unknown dimensions used for
// broadcasting.
// GemmOp supports unidirectional broadcasting from C to A*B.
// Hence, it must be enough to get broadcasting information for C only.
std::map<int, Value> broadcastedDimInfo;
if (has_bias) {
if (hasBias) {
auto shape = C.getType().cast<MemRefType>().getShape();
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] < 0) {
@ -162,7 +164,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
if (has_bias) {
if (hasBias) {
auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
broadcastedDimInfo);
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
@ -210,8 +212,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
}
};
void populateLoweringONNXGemmOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx) {
patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
}

View File

@ -120,25 +120,19 @@ void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
// Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface.
void ONNXTanhOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Sinh
/// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface.
void ONNXSinhOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Cosh
/// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface.
void ONNXCoshOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Cos
@ -178,9 +172,7 @@ void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
// Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface.
void ONNXReluOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// LeakyRelu
@ -194,9 +186,7 @@ void ONNXLeakyReluOp::inferShapes() {
// Selu
/// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface.
void ONNXSeluOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Reciprocal
@ -234,17 +224,13 @@ void ONNXSoftsignOp::inferShapes() {
// Sqrt
/// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface.
void ONNXSqrtOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Sign
/// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface.
void ONNXSignOp::inferShapes() {
getResult().setType(getOperand().getType());
}
void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Add
@ -423,8 +409,7 @@ void ONNXMatMulOp::inferShapes() {
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
// need to be removed after the multiplication but cannot be removed if all
// sizes are 1.
if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
lhsShape[0] != rhsShape[0])
if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
dims.emplace_back(1);
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
@ -541,14 +526,14 @@ void ONNXMatMulOp::inferShapes() {
// Gemm
void ONNXGemmOp::inferShapes() {
bool hasBias = !getOperand(2).getType().isa<NoneType>();
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() ||
!getOperand(2).getType().isa<RankedTensorType>())
(hasBias && !getOperand(2).getType().isa<RankedTensorType>()))
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
@ -560,15 +545,18 @@ void ONNXGemmOp::inferShapes() {
emitError("Tensor shapes mismatched.");
}
// Check whether bias is unidirectional broadcasting or not.
auto shape = biasTy.getShape();
int rank = shape.size();
if ((rank > 2) ||
(rank >= 1 && shape[rank - 1] != -1 && N != -1 && N != shape[rank - 1] &&
shape[rank - 1] != 1) ||
(rank == 2 && shape[rank - 2] != -1 && M != -1 && M != shape[rank - 2] &&
shape[rank - 2] != 1)) {
emitError("Bias shape mismatched.");
if (hasBias) {
// Check whether bias is unidirectional broadcasting or not.
auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
auto shape = biasTy.getShape();
int rank = shape.size();
if ((rank > 2) ||
(rank >= 1 && shape[rank - 1] != -1 && N != -1 &&
N != shape[rank - 1] && shape[rank - 1] != 1) ||
(rank == 2 && shape[rank - 2] != -1 && M != -1 &&
M != shape[rank - 2] && shape[rank - 2] != 1)) {
emitError("Bias shape mismatched.");
}
}
SmallVector<int64_t, 2> dims;
@ -713,7 +701,6 @@ void ONNXTransposeOp::inferShapes() {
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
//===----------------------------------------------------------------------===//
// ReduceMax
@ -801,7 +788,8 @@ void ONNXConvNoBiasOp::inferShapes() {
// Required attribute auto_pad defaults to NOTSET.
auto autoPad = auto_pad();
// Group is a required attribute and should have default value of 1.
int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
int64_t group =
ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (dataShape[1] != (weightShape[1] * group))
emitError("Channel dimension mismatch.");
@ -859,8 +847,10 @@ void ONNXConvNoBiasOp::inferShapes() {
if (dilations.getValue().size() != nDims)
emitError("dilations length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i)
kernelDims[i] = (kernelDims[i] + 1) *
(dilations.getValue()[i]).cast<IntegerAttr>().getInt() - 1;
kernelDims[i] =
(kernelDims[i] + 1) *
(dilations.getValue()[i]).cast<IntegerAttr>().getInt() -
1;
}
// Subtract kernel dimensions from input data dimensions.
@ -906,8 +896,7 @@ void ONNXConvNoBiasOp::inferShapes() {
if (strides.getValue().size() != nDims)
emitError("strides length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i) {
int64_t stride =
strides.getValue()[i].cast<IntegerAttr>().getInt();
int64_t stride = strides.getValue()[i].cast<IntegerAttr>().getInt();
outSpatialDims[i] = floor(outSpatialDims[i] / stride);
}
}
@ -934,12 +923,13 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
auto xRank = xShape.size();
// 2) analyse parameters
// get kernel sizes from kernel_shape attribute
// get kernel sizes from kernel_shape attribute
auto kernelShape = kernel_shape();
if (!kernelShape)
emitError("kernel_shape is a mandatory attribute for which there is no default.");
emitError(
"kernel_shape is a mandatory attribute for which there is no default.");
auto kernelShapeArray = kernelShape.getValue();
auto kernelRank = kernelShape.size();
auto kernelRank = kernelShape.size();
if (kernelRank > xRank)
emitError("kernel_shape spatial dimension is too large.");
auto kernelOffset = xRank - kernelRank;
@ -951,41 +941,42 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
SmallVector<int64_t, 4> actualDilations;
auto dilationsOpt = dilations();
if (dilationsOpt.hasValue()) {
auto dilationsArray = dilationsOpt.getValue().getValue(); // opt -> attr -> array
auto dilationsArray =
dilationsOpt.getValue().getValue(); // opt -> attr -> array
if (dilationsArray.size() != kernelRank)
emitError("dialation rank is not the same as the spatial rank.");
emitError("dialation rank is not the same as the spatial rank.");
// fill in the actual values
for (int i = 0; i < kernelRank; ++i) {
int64_t d = (dilationsArray[i]).cast<IntegerAttr>().getInt();
if (d < 1)
if (d < 1)
emitError("dialation value must be nonzero positive.");
actualDilations.emplace_back(d);
}
} else {
for(int i=0; i < kernelRank; ++i) {
actualDilations.emplace_back(1);
for (int i = 0; i < kernelRank; ++i) {
actualDilations.emplace_back(1);
}
}
// storage order
// strides
SmallVector<int64_t, 4> actualStrides;
auto stridesOpt = strides();
if (stridesOpt.hasValue()) {
auto stridesArray = stridesOpt.getValue().getValue();
if (stridesArray.size() != kernelRank)
emitError("strides rank is not the same as the spatial rank.");
emitError("strides rank is not the same as the spatial rank.");
// fill in the actual values
for (int i = 0; i < kernelRank; ++i) {
int64_t s = (stridesArray[i]).cast<IntegerAttr>().getInt();
if (s < 1)
if (s < 1)
emitError("strides value must be nonzero positive.");
actualStrides.emplace_back(s);
}
} else {
for(int i=0; i < kernelRank; ++i) {
actualStrides.emplace_back(1);
for (int i = 0; i < kernelRank; ++i) {
actualStrides.emplace_back(1);
}
}
@ -1002,9 +993,9 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
if (padsArray.size() != 2 * kernelRank)
emitError("pads rank is not twice the spatial rank.");
// fill in the actual values
for (int i = 0; i < 2*kernelRank; ++i) {
for (int i = 0; i < 2 * kernelRank; ++i) {
int64_t p = (padsArray[i]).cast<IntegerAttr>().getInt();
if (p < 0)
if (p < 0)
emitError("pads value must be nonnegative.");
actualPads.emplace_back(p);
}
@ -1016,24 +1007,26 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
defaultPads = true;
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
// init pad with zero
for(int i=0; i<2*kernelRank; ++i) {
for (int i = 0; i < 2 * kernelRank; ++i) {
actualPads.emplace_back(0);
}
for(int i=0; i<kernelRank; ++i) {
auto inputSpatialShape = xShape[kernelOffset + i];
auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt();
for (int i = 0; i < kernelRank; ++i) {
auto inputSpatialShape = xShape[kernelOffset + i];
auto kernelSpatialShape =
(kernelShapeArray[i]).cast<IntegerAttr>().getInt();
auto dilations = actualDilations[i];
auto strideSpatialShape = actualStrides[i];
int64_t outputSpatialShape = ceil((1.0 * inputSpatialShape) /
(1.0 * strideSpatialShape));
auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape +
((kernelSpatialShape - 1) * dilations + 1) - inputSpatialShape;
int64_t outputSpatialShape =
ceil((1.0 * inputSpatialShape) / (1.0 * strideSpatialShape));
auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape +
((kernelSpatialShape - 1) * dilations + 1) -
inputSpatialShape;
actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2;
if (sumOfPad % 2 != 0) {
if (autoPad == "SAME_UPPER") {
actualPads[kernelRank + i] += 1;
} else {
actualPads[i] += 1;
actualPads[i] += 1;
}
}
}
@ -1042,24 +1035,26 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
}
// handle case where default pad values must be used
if (defaultPads) {
for(int i=0; i<2*kernelRank; ++i) {
for (int i = 0; i < 2 * kernelRank; ++i) {
actualPads.emplace_back(0);
}
}
// initialize output shape
// initialize output shape
SmallVector<int64_t, 4> yShape(xShape.begin(), xShape.end());
// for all kernel dimensions
for(int i=0; i<kernelRank; ++i) {
auto inputSpatialShape = xShape[kernelOffset + i];
auto padShape = actualPads[i] + actualPads[kernelRank+i];
auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt();
for (int i = 0; i < kernelRank; ++i) {
auto inputSpatialShape = xShape[kernelOffset + i];
auto padShape = actualPads[i] + actualPads[kernelRank + i];
auto kernelSpatialShape =
(kernelShapeArray[i]).cast<IntegerAttr>().getInt();
auto dilations = actualDilations[i];
auto strideSpatialShape = actualStrides[i];
///output_spatial_shape[i] = ceil( (input_spatial_shape[i] + pad_shape[i] -
// ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1)
double numerator = inputSpatialShape + padShape -
((kernelSpatialShape - 1) * dilations + 1);
/// output_spatial_shape[i] = ceil( (input_spatial_shape[i] + pad_shape[i] -
// ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) /
// strides_spatial_shape[i] + 1)
double numerator = inputSpatialShape + padShape -
((kernelSpatialShape - 1) * dilations + 1);
double denominator = strideSpatialShape;
int64_t res;
if (ceilMode) {

File diff suppressed because it is too large Load Diff

View File

@ -127,6 +127,10 @@ int main(int argc, char *argv[]) {
if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass());
// An additional pass of canonicalization is helpful because lowering
// from ONNX dialect to Standard dialect exposes additional canonicalization
// oppertunities.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createLowerKrnlPass());
}

View File

@ -28,6 +28,11 @@ void ONNXAddOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<MulAddToGemmOptPattern>(context);
}
void ONNXGemmOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<FuseGemmFollowedByAddition>(context);
}
/// on the ONNXIdentityOp.
void ONNXIdentityOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {

View File

@ -26,6 +26,7 @@ include "dialect/onnx/onnx.td"
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite
@ -41,6 +42,11 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
(ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
[(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>;
// onnx.add(onnx.Gemm(%X, %Y, None), %Z) = onnx.Gemm(%X, %Y, %Z)
def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none, $alpha, $beta, $transA, $transB), $bias),
(ONNXGemmOp $m1, $m2, $bias, $alpha, $beta, $transA, $transB),
[(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2), (HasNoneType $none)]>;
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
(replaceWithValue $arg)>;

View File

@ -101,3 +101,14 @@ func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>
// CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
// CHECK-NEXT: return %1 : tensor<*xf32>
}
//CHECK-LABEL: @test_gemm_add_fusion(%{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128xf32>) -> tensor<*xf32> {
func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128xf32>) -> tensor<*xf32> {
%cst = constant unit
%0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128xf32>, tensor<128x128xf32>, none) -> tensor<*xf32>
%1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32>
// return [[GEMM]] : tensor<*xf32>
}