Merge branch 'master' into shapeinference-pad
commit d4f8fef947
				|  | @ -327,10 +327,10 @@ ONNX BatchNormalization operation | |||
| #### Results: | ||||
| 
 | ||||
| 1. `Y`: memref of any type values or tensor of any type values | ||||
| 1. `out_mean`: memref of any type values or tensor of any type values | ||||
| 1. `out_var`: memref of any type values or tensor of any type values | ||||
| 1. `saved_mean`: memref of any type values or tensor of any type values | ||||
| 1. `saved_var`: memref of any type values or tensor of any type values | ||||
| 1. `out_mean`: memref of any type values or tensor of any type values or none type | ||||
| 1. `out_var`: memref of any type values or tensor of any type values or none type | ||||
| 1. `saved_mean`: memref of any type values or tensor of any type values or none type | ||||
| 1. `saved_var`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| ### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp) | ||||
| ONNX BatchNormalization operation in test mode | ||||
|  | @ -375,12 +375,12 @@ ONNX BitShift operation | |||
| 
 | ||||
| 
 | ||||
| "Bitwise shift operator performs element-wise operation. For each input element, if the" | ||||
| " attribute "direction" is "RIGHT", this operator moves its binary representation toward" | ||||
| " the right side so that the input value is effectively decreased. If the attribute "direction"" | ||||
| " is "LEFT", bits of binary representation moves toward the left side, which results the" | ||||
| " attribute \"direction\" is \"RIGHT\", this operator moves its binary representation toward" | ||||
| " the right side so that the input value is effectively decreased. If the attribute \"direction\"" | ||||
| " is \"LEFT\", bits of binary representation moves toward the left side, which results the" | ||||
| " increase of its actual value. The input X is the tensor to be shifted and another input" | ||||
| " Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4]," | ||||
| " and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with" | ||||
| " Y specifies the amounts of shifting. For example, if \"direction\" is \"Right\", X is [1, 4]," | ||||
| " and S is [1, 1], the corresponding output Z would be [0, 2]. If \"direction\" is \"LEFT\" with" | ||||
| " X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]." | ||||
| " " | ||||
| " Because this operator supports Numpy-style broadcasting, X's and Y's shapes are" | ||||
|  | @ -413,15 +413,15 @@ ONNX Cast operation | |||
| "the converted type. The 'to' argument must be one of the data types specified" | ||||
| "in the 'DataType' enum field in the TensorProto message." | ||||
| "" | ||||
| "Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations" | ||||
| "(e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may" | ||||
| "Casting from string tensor in plain (e.g., \"3.14\" and \"1000\") and scientific numeric representations" | ||||
| "(e.g., \"1e-5\" and \"1E8\") to float types is supported. For example, converting string \"100.5\" to an integer may" | ||||
| "result 100. There are some string literals reserved for special floating-point values;" | ||||
| ""+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively." | ||||
| "Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly," | ||||
| "this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors" | ||||
| "to string tensors, plain floating-point representation (such as "314.15926") would be used. " | ||||
| "Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases " | ||||
| "of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior." | ||||
| "\"+INF\" (and \"INF\"), \"-INF\", and \"NaN\" are positive infinity, negative infinity, and not-a-number, respectively." | ||||
| "Any string which can exactly match \"+INF\" in a case-insensitive way would be mapped to positive infinite. Similarly," | ||||
| "this case-insensitive rule is applied to \"INF\" and \"NaN\". When casting from numeric tensors" | ||||
| "to string tensors, plain floating-point representation (such as \"314.15926\") would be used. " | ||||
| "Converting non-numerical-literal string such as \"Hello World!\" is an undefined behavior. Cases " | ||||
| "of converting string representing floating-point arithmetic value, such as \"2.718\", to INT is an undefined behavior." | ||||
| "" | ||||
| "Conversion from a numerical type to any numerical type is always allowed." | ||||
| "User must be aware of precision loss and value change caused by range difference between two types." | ||||
|  | @ -476,8 +476,8 @@ ONNX Clip operation | |||
| #### Operands: | ||||
| 
 | ||||
| 1. `input`: memref of any type values or tensor of any type values | ||||
| 1. `min`: memref of any type values or tensor of any type values | ||||
| 1. `max`: memref of any type values or tensor of any type values | ||||
| 1. `min`: memref of any type values or tensor of any type values or none type | ||||
| 1. `max`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -618,8 +618,8 @@ ONNX ConvInteger operation | |||
| 
 | ||||
| 1. `x`: memref of any type values or tensor of any type values | ||||
| 1. `w`: memref of any type values or tensor of any type values | ||||
| 1. `x_zero_point`: memref of any type values or tensor of any type values | ||||
| 1. `w_zero_point`: memref of any type values or tensor of any type values | ||||
| 1. `x_zero_point`: memref of any type values or tensor of any type values or none type | ||||
| 1. `w_zero_point`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -678,7 +678,7 @@ ONNX Conv operation | |||
| 
 | ||||
| 1. `X`: memref of any type values or tensor of any type values | ||||
| 1. `W`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -720,7 +720,7 @@ ONNX ConvTranspose operation | |||
| 
 | ||||
| 1. `X`: memref of any type values or tensor of any type values | ||||
| 1. `W`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -884,7 +884,7 @@ ONNX DequantizeLinear operation | |||
| 
 | ||||
| 1. `x`: memref of any type values or tensor of any type values | ||||
| 1. `x_scale`: memref of any type values or tensor of any type values | ||||
| 1. `x_zero_point`: memref of any type values or tensor of any type values | ||||
| 1. `x_zero_point`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -964,7 +964,7 @@ ONNX Dropout operation | |||
| #### Results: | ||||
| 
 | ||||
| 1. `output`: memref of any type values or tensor of any type values | ||||
| 1. `mask`: memref of any type values or tensor of any type values | ||||
| 1. `mask`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| ### onnx.DynamicQuantizeLinear (ONNXDynamicQuantizeLinearOp) | ||||
| ONNX DynamicQuantizeLinear operation | ||||
|  | @ -1297,9 +1297,9 @@ ONNX GRU operation | |||
| 1. `X`: memref of any type values or tensor of any type values | ||||
| 1. `W`: memref of any type values or tensor of any type values | ||||
| 1. `R`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values | ||||
| 1. `sequence_lens`: memref of any type values or tensor of any type values | ||||
| 1. `initial_h`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values or none type | ||||
| 1. `sequence_lens`: memref of any type values or tensor of any type values or none type | ||||
| 1. `initial_h`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -1315,8 +1315,8 @@ ONNX GRU operation | |||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
| 1. `Y`: memref of any type values or tensor of any type values | ||||
| 1. `Y_h`: memref of any type values or tensor of any type values | ||||
| 1. `Y`: memref of any type values or tensor of any type values or none type | ||||
| 1. `Y_h`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| ### onnx.GatherElements (ONNXGatherElementsOp) | ||||
| ONNX GatherElements operation | ||||
|  | @ -1609,7 +1609,7 @@ ONNX Gemm operation | |||
| 
 | ||||
| 1. `A`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values | ||||
| 1. `C`: memref of any type values or tensor of any type values | ||||
| 1. `C`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -2013,11 +2013,11 @@ ONNX LSTM operation | |||
| 1. `X`: memref of any type values or tensor of any type values | ||||
| 1. `W`: memref of any type values or tensor of any type values | ||||
| 1. `R`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values | ||||
| 1. `sequence_lens`: memref of any type values or tensor of any type values | ||||
| 1. `initial_h`: memref of any type values or tensor of any type values | ||||
| 1. `initial_c`: memref of any type values or tensor of any type values | ||||
| 1. `P`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values or none type | ||||
| 1. `sequence_lens`: memref of any type values or tensor of any type values or none type | ||||
| 1. `initial_h`: memref of any type values or tensor of any type values or none type | ||||
| 1. `initial_c`: memref of any type values or tensor of any type values or none type | ||||
| 1. `P`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -2033,9 +2033,9 @@ ONNX LSTM operation | |||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
| 1. `Y`: memref of any type values or tensor of any type values | ||||
| 1. `Y_h`: memref of any type values or tensor of any type values | ||||
| 1. `Y_c`: memref of any type values or tensor of any type values | ||||
| 1. `Y`: memref of any type values or tensor of any type values or none type | ||||
| 1. `Y_h`: memref of any type values or tensor of any type values or none type | ||||
| 1. `Y_c`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| ### onnx.LeakyRelu (ONNXLeakyReluOp) | ||||
| ONNX LeakyRelu operation | ||||
|  | @ -2160,24 +2160,24 @@ ONNX Loop operation | |||
| "" | ||||
| "    Operator inputs defined as (max_trip_count, condition_var)." | ||||
| "" | ||||
| "    input ("", ""):" | ||||
| "    input (\"\", \"\"):" | ||||
| "        for (int i=0; ; ++i) {" | ||||
| "          cond = ... // Note this value is ignored, but is required in the body" | ||||
| "        }" | ||||
| "" | ||||
| "    input ("", cond) // Note this is analogous to a while loop" | ||||
| "    input (\"\", cond) // Note this is analogous to a while loop" | ||||
| "        bool cond = ...;" | ||||
| "        for (int i=0; cond; ++i) {" | ||||
| "          cond = ...;" | ||||
| "        }" | ||||
| "" | ||||
| "    input ("", 1) // Note this is analogous to a do-while loop" | ||||
| "    input (\"\", 1) // Note this is analogous to a do-while loop" | ||||
| "        bool cond = true" | ||||
| "        for (int i=0; cond; ++i) {" | ||||
| "          cond = ...;" | ||||
| "        }" | ||||
| "" | ||||
| "    input (trip_count, "") // Note this is analogous to a for loop" | ||||
| "    input (trip_count, \"\") // Note this is analogous to a for loop" | ||||
| "        int trip_count = ..." | ||||
| "        for (int i=0; i < trip_count; ++i) {" | ||||
| "          cond = ...; // ignored" | ||||
|  | @ -2203,15 +2203,15 @@ ONNX Loop operation | |||
| "    }" | ||||
| "" | ||||
| "    graph body-net (" | ||||
| "      %i[INT32, scalar]           // iteration number" | ||||
| "      %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used" | ||||
| "      %b_in[INT32, scalar]        // incoming value of loop-carried-dependency b" | ||||
| "      %i[INT32, scalar]" | ||||
| "      %keepgoing[BOOL, scalar]" | ||||
| "      %b[INT32, scalar]" | ||||
| "    ) {" | ||||
| "      %my_local = Add(%a, %b_in)" | ||||
| "      %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b" | ||||
| "      %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition" | ||||
| "      %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated" | ||||
| "      return %keepgoing_out, %b_out, %user_defined_val" | ||||
| "      %my_local = Add(%a, %b)" | ||||
| "      %b_out = Sub(%a, %b)" | ||||
| "      %keepgoing_out = Greater(%my_local, %b_out)" | ||||
| "      %user_defined_vals = Add(%b, %b)" | ||||
| "      return %keepgoing_out, %b_out, %user_defined_vals" | ||||
| "    }" | ||||
| "" | ||||
| "*Sample equivalent C code*" | ||||
|  | @ -2226,51 +2226,31 @@ ONNX Loop operation | |||
| "      const int max_trip_count = 10; // Analogous to input M" | ||||
| "      int user_defined_vals[]; // Imagine this is resizable" | ||||
| "      /* End implicitly-defined code */" | ||||
| "      /* initialize loop-carried variables and scan-output variables */" | ||||
| "      bool keepgoing_out = keepgoing" | ||||
| "      int b_out = b" | ||||
| "" | ||||
| "      for (int i=0; i < max_trip_count && keepgoing_out; ++i) {" | ||||
| "        /* Implicitly-defined code: bind actual parameter values" | ||||
| "           to formal parameter variables of loop-body */" | ||||
| "        bool keepgoing_in = keepgoing_out; " | ||||
| "        bool b_in = b_out;" | ||||
| "" | ||||
| "      for (int i=0; i < max_trip_count && keepgoing; ++i) {" | ||||
| "        /* User-defined code (loop body) */" | ||||
| "        int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine" | ||||
| "        b_out = a - b_in;" | ||||
| "        keepgoing_out = my_local > b_out; " | ||||
| "        user_defined_val = b_in + b_in; // b_in and b_out are different variables" | ||||
| "        int my_local = a + b; // Reading values in the enclosing scope is fine" | ||||
| "        b = a - b; // writes fine if we specify b as a loop-carried dependency" | ||||
| "        keepgoing = my_local > b; // keepgoing is a loop-carried dependency" | ||||
| "        user_defined_vals[i] = b + b;" | ||||
| "        /* End user-defined code */" | ||||
| "" | ||||
| "        /* Implicitly defined-code */" | ||||
| "        user_defined_vals[i] = user_defined_val // accumulate scan-output values" | ||||
| "      }" | ||||
| "      // int t = my_local; // Can't do this. my_local is not accessible here." | ||||
| "      // my_local = 123; // Can't do this. my_local was defined in the the body" | ||||
| "" | ||||
| "      // The values below are bound to the output variables of the loop and therefore accessible" | ||||
| "      // b_out; user_defined_vals; keepgoing_out;" | ||||
| "      // These below values are live-out from the loop and therefore accessible" | ||||
| "      b_out; user_defined_vals; keepgoing_out;" | ||||
| "    }" | ||||
| "" | ||||
| "There are several things of note in this code snippet:" | ||||
| "" | ||||
| "1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can" | ||||
| "1) Values from the enclosing scope (i.e. variable a here) are in scope and can" | ||||
| "   be referenced in the inputs of the loop." | ||||
| "2) Any values computed in the loop body that needs to be used in a subsequent" | ||||
| "   iteration or after the loop are modelled using a pair of variables in the loop-body," | ||||
| "   consisting of an input variable (eg., b_in) and an output variable (eg., b_out)." | ||||
| "   These are referred to as loop-carried dependences. The loop operation node" | ||||
| "   supplies the input value of the input variable for the first iteration, and" | ||||
| "   returns the output value of the output variable produced by the final" | ||||
| "   iteration." | ||||
| "3) Scan_output variables are used to implicitly concatenate values computed across" | ||||
| "   all the iterations. In the above example, the value of user_defined_val computed" | ||||
| "   over all iterations are concatenated and returned as the value of user_defined_vals" | ||||
| "   after the loop." | ||||
| "4) Values created in the body cannot be accessed in the enclosing scope," | ||||
| "   except using the mechanism described above." | ||||
| "2) Any variables which you wish to make available in the enclosing scope (i.e." | ||||
| "   the variables b and keepgoing) must be declared as either loop-carried" | ||||
| "   dependencies (both at the op inputs and output and at the body net input and" | ||||
| "   output) or scan_outputs." | ||||
| "3) Values created in the body cannot be accessed in the enclosing scope." | ||||
| "" | ||||
| "Note that the semantics of this op support "diagonal" or "wavefront" execution." | ||||
| "Note that the semantics of this op support \"diagonal\" or \"wavefront\" execution." | ||||
| "(See Step 3 here for an example:" | ||||
| "https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)." | ||||
| "Frontends should emit multi-layer RNNs as a series of While operators (with" | ||||
|  | @ -2280,8 +2260,8 @@ ONNX Loop operation | |||
| 
 | ||||
| #### Operands: | ||||
| 
 | ||||
| 1. `M`: memref of any type values or tensor of any type values | ||||
| 1. `cond`: memref of any type values or tensor of any type values | ||||
| 1. `M`: memref of any type values or tensor of any type values or none type | ||||
| 1. `cond`: memref of any type values or tensor of any type values or none type | ||||
| 1. `v_initial`: memref of any type values or tensor of any type values | ||||
| 
 | ||||
| #### Attributes: | ||||
|  | @ -2360,8 +2340,8 @@ ONNX MatMulInteger operation | |||
| 
 | ||||
| 1. `A`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values | ||||
| 1. `a_zero_point`: memref of any type values or tensor of any type values | ||||
| 1. `b_zero_point`: memref of any type values or tensor of any type values | ||||
| 1. `a_zero_point`: memref of any type values or tensor of any type values or none type | ||||
| 1. `b_zero_point`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -2444,7 +2424,7 @@ ONNX MaxPool operation | |||
| " ```" | ||||
| " pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i]" | ||||
| " ```" | ||||
| " The output of each pooling window is maximum number of elements exclude pad. " | ||||
| " The output of each pooling window is maximum number of elements exclude pad." | ||||
| " " | ||||
| 
 | ||||
| #### Operands: | ||||
|  | @ -2466,7 +2446,7 @@ ONNX MaxPool operation | |||
| #### Results: | ||||
| 
 | ||||
| 1. `Y`: memref of any type values or tensor of any type values | ||||
| 1. `Indices`: memref of any type values or tensor of any type values | ||||
| 1. `Indices`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| ### onnx.MaxPoolSingleOut (ONNXMaxPoolSingleOutOp) | ||||
| ONNX MaxPool operation with a single output. | ||||
|  | @ -2552,7 +2532,7 @@ ONNX MaxUnpool operation | |||
| 
 | ||||
| 1. `X`: memref of any type values or tensor of any type values | ||||
| 1. `I`: memref of any type values or tensor of any type values | ||||
| 1. `output_shape`: memref of any type values or tensor of any type values | ||||
| 1. `output_shape`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -2752,9 +2732,9 @@ ONNX NonMaxSuppression operation | |||
| 
 | ||||
| 1. `boxes`: memref of any type values or tensor of any type values | ||||
| 1. `scores`: memref of any type values or tensor of any type values | ||||
| 1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values | ||||
| 1. `iou_threshold`: memref of any type values or tensor of any type values | ||||
| 1. `score_threshold`: memref of any type values or tensor of any type values | ||||
| 1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values or none type | ||||
| 1. `iou_threshold`: memref of any type values or tensor of any type values or none type | ||||
| 1. `score_threshold`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -3067,7 +3047,7 @@ ONNX Pad operation | |||
| 
 | ||||
| 1. `data`: memref of any type values or tensor of any type values | ||||
| 1. `pads`: memref of any type values or tensor of any type values | ||||
| 1. `constant_value`: memref of any type values or tensor of any type values | ||||
| 1. `constant_value`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -3124,7 +3104,7 @@ ONNX QLinearConv operation | |||
| 1. `w_zero_point`: memref of any type values or tensor of any type values | ||||
| 1. `y_scale`: memref of any type values or tensor of any type values | ||||
| 1. `y_zero_point`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -3188,7 +3168,7 @@ ONNX QuantizeLinear operation | |||
| 
 | ||||
| 1. `x`: memref of any type values or tensor of any type values | ||||
| 1. `y_scale`: memref of any type values or tensor of any type values | ||||
| 1. `y_zero_point`: memref of any type values or tensor of any type values | ||||
| 1. `y_zero_point`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -3270,9 +3250,9 @@ ONNX RNN operation | |||
| 1. `X`: memref of any type values or tensor of any type values | ||||
| 1. `W`: memref of any type values or tensor of any type values | ||||
| 1. `R`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values | ||||
| 1. `sequence_lens`: memref of any type values or tensor of any type values | ||||
| 1. `initial_h`: memref of any type values or tensor of any type values | ||||
| 1. `B`: memref of any type values or tensor of any type values or none type | ||||
| 1. `sequence_lens`: memref of any type values or tensor of any type values or none type | ||||
| 1. `initial_h`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -3287,8 +3267,8 @@ ONNX RNN operation | |||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
| 1. `Y`: memref of any type values or tensor of any type values | ||||
| 1. `Y_h`: memref of any type values or tensor of any type values | ||||
| 1. `Y`: memref of any type values or tensor of any type values or none type | ||||
| 1. `Y_h`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| ### onnx.RandomNormalLike (ONNXRandomNormalLikeOp) | ||||
| ONNX RandomNormalLike operation | ||||
|  | @ -3813,14 +3793,14 @@ ONNX Resize operation | |||
| 
 | ||||
| "Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor." | ||||
| "Each dimension value of the output tensor is:" | ||||
| "  output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified." | ||||
| "  output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified." | ||||
| 
 | ||||
| #### Operands: | ||||
| 
 | ||||
| 1. `X`: memref of any type values or tensor of any type values | ||||
| 1. `roi`: memref of any type values or tensor of any type values | ||||
| 1. `scales`: memref of any type values or tensor of any type values | ||||
| 1. `sizes`: memref of any type values or tensor of any type values | ||||
| 1. `sizes`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -4438,7 +4418,7 @@ ONNX SequenceErase operation | |||
| #### Operands: | ||||
| 
 | ||||
| 1. `input_sequence`: memref of any type values or tensor of any type values | ||||
| 1. `position`: memref of any type values or tensor of any type values | ||||
| 1. `position`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -4463,7 +4443,7 @@ ONNX SequenceInsert operation | |||
| 
 | ||||
| 1. `input_sequence`: memref of any type values or tensor of any type values | ||||
| 1. `tensor`: memref of any type values or tensor of any type values | ||||
| 1. `position`: memref of any type values or tensor of any type values | ||||
| 1. `position`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -4680,8 +4660,8 @@ ONNX Slice operation | |||
| 1. `data`: memref of any type values or tensor of any type values | ||||
| 1. `starts`: memref of any type values or tensor of any type values | ||||
| 1. `ends`: memref of any type values or tensor of any type values | ||||
| 1. `axes`: memref of any type values or tensor of any type values | ||||
| 1. `steps`: memref of any type values or tensor of any type values | ||||
| 1. `axes`: memref of any type values or tensor of any type values or none type | ||||
| 1. `steps`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -4834,7 +4814,7 @@ ONNX SplitToSequence operation | |||
| #### Operands: | ||||
| 
 | ||||
| 1. `input`: memref of any type values or tensor of any type values | ||||
| 1. `split`: memref of any type values or tensor of any type values | ||||
| 1. `split`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| #### Attributes: | ||||
| 
 | ||||
|  | @ -4902,9 +4882,9 @@ ONNX StringNormalizer operation | |||
| "StringNormalization performs string operations for basic cleaning." | ||||
| "This operator has only one input (denoted by X) and only one output" | ||||
| "(denoted by Y). This operator first examines the elements in the X," | ||||
| "and removes elements specified in "stopwords" attribute." | ||||
| "and removes elements specified in \"stopwords\" attribute." | ||||
| "After removing stop words, the intermediate result can be further lowercased," | ||||
| "uppercased, or just returned depending the "case_change_action" attribute." | ||||
| "uppercased, or just returned depending the \"case_change_action\" attribute." | ||||
| "This operator only accepts [C]- and [1, C]-tensor." | ||||
| "If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]" | ||||
| "if input shape is [C] and shape [1, 1] if input shape is [1, C]." | ||||
|  | @ -5034,8 +5014,8 @@ ONNX TfIdfVectorizer operation | |||
| "respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output." | ||||
| "Note that we may consider all skips up to S when generating the n-grams." | ||||
| "" | ||||
| "The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and" | ||||
| "the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF"," | ||||
| "The examples used above are true if mode is \"TF\". If mode is \"IDF\", all the counts larger than 1 would be truncated to 1 and" | ||||
| "the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is \"TFIDF\"," | ||||
| "this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute." | ||||
| "" | ||||
| "Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor." | ||||
|  | @ -5123,9 +5103,9 @@ ONNX TopK operation | |||
| "   contains the indices of the top k elements (original indices from the input" | ||||
| "   tensor)." | ||||
| "" | ||||
| "If "largest" is 1 (the default value) then the k largest elements are returned." | ||||
| "If "sorted" is 1 (the default value) then the resulting k elements will be sorted." | ||||
| "If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined." | ||||
| "If \"largest\" is 1 (the default value) then the k largest elements are returned." | ||||
| "If \"sorted\" is 1 (the default value) then the resulting k elements will be sorted." | ||||
| "If \"sorted\" is 0, order of returned 'Values' and 'Indices' are undefined." | ||||
| "" | ||||
| "Given two equivalent values, this operator uses the indices along the axis as" | ||||
| " a tiebreaker. That is, the element with the lower index will appear first." | ||||
|  | @ -5184,7 +5164,7 @@ ONNX Unique operation | |||
| "This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. " | ||||
| "The first output tensor 'Y' contains all unique values or subtensors of the input. " | ||||
| "The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. " | ||||
| "The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". " | ||||
| "The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. \". " | ||||
| "The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. " | ||||
| "" | ||||
| "Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. " | ||||
|  | @ -5268,9 +5248,9 @@ ONNX Unique operation | |||
| #### Results: | ||||
| 
 | ||||
| 1. `Y`: memref of any type values or tensor of any type values | ||||
| 1. `indices`: memref of any type values or tensor of any type values | ||||
| 1. `inverse_indices`: memref of any type values or tensor of any type values | ||||
| 1. `counts`: memref of any type values or tensor of any type values | ||||
| 1. `indices`: memref of any type values or tensor of any type values or none type | ||||
| 1. `inverse_indices`: memref of any type values or tensor of any type values or none type | ||||
| 1. `counts`: memref of any type values or tensor of any type values or none type | ||||
| 
 | ||||
| ### onnx.Unsqueeze (ONNXUnsqueezeOp) | ||||
| ONNX Unsqueeze operation | ||||
doc/gen_doc.py: 1041 lines changed (file diff suppressed because it is too large).
							|  | @ -121,6 +121,7 @@ private: | |||
|   mlir::MLIRContext &context_; | ||||
|   mlir::ModuleOp module_; | ||||
|   mlir::OpBuilder builder_; | ||||
|   mlir::Value none_; | ||||
|   // mapping between string name and symbol
 | ||||
|   OnnxOnnfSymbolMapping frontend_symbols_; | ||||
| 
 | ||||
|  | @ -287,8 +288,8 @@ private: | |||
|     } | ||||
|   } | ||||
| 
 | ||||
|   std::vector<mlir::NamedAttribute> ImportNodeAttributes( | ||||
|       const onnx::NodeProto &node) { | ||||
|   std::vector<mlir::NamedAttribute> | ||||
|   ImportNodeAttributes(const onnx::NodeProto &node) { | ||||
|     std::vector<mlir::NamedAttribute> attributes; | ||||
|     for (int i = 0; i < node.attribute_size(); ++i) { | ||||
|       auto attr = node.attribute(i); | ||||
|  | @ -317,21 +318,11 @@ private: | |||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be
 | ||||
|   // combined with 'if constexpr' the issue is the type of the output is
 | ||||
|   // different. alternative way to use variadic output for all the op
 | ||||
| 
 | ||||
|   /*!
 | ||||
|    * Important onnx node which generates only one output | ||||
|    * @param node onnx node | ||||
|    * @param nIn number of expected inputs | ||||
|    * @param nOut number of expected outputs | ||||
|    * @param attrs  list of desription for attributes with format {name, type, | ||||
|    * default} | ||||
|    */ | ||||
|   template <typename T> | ||||
|   void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut, | ||||
|                         bool variadicIn = false, bool variadicOut = false) { | ||||
|   void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1, | ||||
|                       int expectedNumResults = -1) { | ||||
|     bool variadicIn = expectedNumOperands == -1; | ||||
|     bool variadicOut = expectedNumResults == -1; | ||||
|     std::vector<mlir::Value> inputs; | ||||
|     for (const auto &item : node.input()) { | ||||
|       if (frontend_symbols_.ContainKey(legalize_name(item))) { | ||||
|  | @ -339,6 +330,10 @@ private: | |||
|       } | ||||
|     } | ||||
| 
 | ||||
|     if (!variadicIn) | ||||
|       for (auto i = inputs.size(); i < expectedNumOperands; i++) | ||||
|         inputs.emplace_back(none_); | ||||
| 
 | ||||
|     std::vector<mlir::Type> outputTypes; | ||||
|     for (auto item : node.output()) { | ||||
|       outputTypes.push_back( | ||||
|  | @ -347,49 +342,11 @@ private: | |||
| 
 | ||||
|     auto attributes = ImportNodeAttributes(node); | ||||
| 
 | ||||
|     llvm::StringRef OpName = node.op_type(); | ||||
|     if ((variadicIn || nIn == inputs.size()) && | ||||
|         (variadicOut || nOut == outputTypes.size())) { | ||||
|       auto op = | ||||
|           builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes); | ||||
|       frontend_symbols_.AddMapping(legalize_name(node.output()[0]), | ||||
|                                    op.getResult()); | ||||
|     } else { | ||||
|       ImportNodeGeneric(node); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   template <typename T> | ||||
|   void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut, | ||||
|                               bool variadicIn = false, | ||||
|                               bool variadicOut = false) { | ||||
|     std::vector<mlir::Value> inputs; | ||||
|     for (const auto &item : node.input()) { | ||||
|       if (frontend_symbols_.ContainKey(legalize_name(item))) { | ||||
|         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     std::vector<mlir::Type> outputTypes; | ||||
|     for (auto item : node.output()) { | ||||
|       outputTypes.push_back( | ||||
|           mlir::UnrankedTensorType::get(builder_.getF32Type())); | ||||
|     } | ||||
| 
 | ||||
|     auto attributes = ImportNodeAttributes(node); | ||||
| 
 | ||||
|     llvm::StringRef OpName = node.op_type(); | ||||
| 
 | ||||
|     if ((variadicIn || nIn == inputs.size()) && | ||||
|         (variadicOut || nOut == outputTypes.size())) { | ||||
|       auto op = | ||||
|           builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes); | ||||
|       for (int i = 0; i < node.output().size(); i++) { | ||||
|         frontend_symbols_.AddMapping(legalize_name(node.output()[i]), | ||||
|                                      op.getResult(i)); | ||||
|       } | ||||
|     } else { | ||||
|       ImportNodeGeneric(node); | ||||
|     // TODO: Handle optional inputs.
 | ||||
|     auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes); | ||||
|     for (int i = 0; i < node.output().size(); i++) { | ||||
|       frontend_symbols_.AddMapping(legalize_name(node.output()[i]), | ||||
|                                    *(op.getODSResults(i).begin())); | ||||
|     } | ||||
|   } | ||||
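For illustration, here is how this convention plays out for an op with trailing optional inputs; the call site below is hypothetical, but Clip's optional `min`/`max` operands are documented in the generated op docs above:

```cpp
// Suppose an ONNX Clip node provides only its first input. Clip declares
// three operands (input, min, max), so the generated dispatcher calls:
buildOperation<mlir::ONNXClipOp>(node, /*expectedNumOperands=*/3,
                                 /*expectedNumResults=*/1);
// inputs.size() == 1 < 3, so the loop above appends none_ twice and the
// created ONNXClipOp receives (input, none_, none_).
```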
| 
 | ||||
|  | @ -398,8 +355,7 @@ private: | |||
|    * c++ does not allow template specialization inside a class scope | ||||
|    * a specialized function is used | ||||
|    */ | ||||
|   void | ||||
|   ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) { | ||||
|   void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) { | ||||
|     // Conv has attribute dilations, kernel_shape, pads, the default value of
 | ||||
|     // which  is determined by the shape of first argument. However, since the
 | ||||
|     // shape is unknown now, these attributes can be not generated auto
 | ||||
|  | @ -413,24 +369,20 @@ private: | |||
|     int nOps = node.input().size(); | ||||
| 
 | ||||
|     if (nOps == 2) | ||||
|       ImportNodeOneOut<mlir::ONNXConvNoBiasOp>( | ||||
|           node, nOps, nOut); | ||||
|       buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut); | ||||
|     else | ||||
|       ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut); | ||||
|       buildOperation<mlir::ONNXConvOp>(node, nOps, nOut); | ||||
|   } | ||||
| 
 | ||||
|   /*!
 | ||||
|    * Special handle for MaxPool operations. | ||||
|    */ | ||||
|   void ImportNodeMaxPool( | ||||
|       onnx::NodeProto node, int nIn, int nOut) { | ||||
|   void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) { | ||||
|     int nOuts = node.output().size(); | ||||
|     if (nOuts == 1) { | ||||
|       ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>( | ||||
|           node, nIn, nOuts); | ||||
|       buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts); | ||||
|     } else { | ||||
|       ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>( | ||||
|           node, nIn, nOuts); | ||||
|       buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|  | @ -441,23 +393,10 @@ private: | |||
|     int nOuts = node.output().size(); | ||||
|     if (nOuts == 1) { | ||||
|       // Test mode with one output.
 | ||||
|       ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, | ||||
|                                                                nOuts); | ||||
|       buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts); | ||||
|     } else { | ||||
|       // Training mode with four trailing optional outputs. Not handled yet.
 | ||||
|       ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /*!
 | ||||
|    * Special handle for Gemm operations. | ||||
|    */ | ||||
|   void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) { | ||||
|     int nOps = node.input().size(); | ||||
|     if (nOps == 2) { | ||||
|       ImportNodeOneOut<mlir::ONNXGemmNoBiasOp>(node, 2, nOut); | ||||
|     } else { | ||||
|       ImportNodeOneOut<mlir::ONNXGemmOp>(node, nIn, nOut); | ||||
|       buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|  | @ -467,28 +406,14 @@ private: | |||
|   void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) { | ||||
|     int nOps = node.input().size(); | ||||
|     if (nOps == 2) { | ||||
|       ImportNodeOneOut<mlir::ONNXPadConstantValueOp>(node, 2, nOut); | ||||
|       buildOperation<mlir::ONNXPadConstantValueOp>(node, 2, nOut); | ||||
|     } else { | ||||
|       ImportNodeOneOut<mlir::ONNXPadOp>(node, nIn, nOut); | ||||
|       buildOperation<mlir::ONNXPadOp>(node, nIn, nOut); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   void ImportNode(const onnx::NodeProto &node) { | ||||
|     std::vector<mlir::Value> inputs; | ||||
|     for (const auto &item : node.input()) { | ||||
|       if (frontend_symbols_.ContainKey(legalize_name(item))) { | ||||
|         inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     std::vector<mlir::Type> outputTypes; | ||||
|     for (auto item : node.output()) { | ||||
|       outputTypes.push_back( | ||||
|           mlir::UnrankedTensorType::get(builder_.getF32Type())); | ||||
|     } | ||||
| 
 | ||||
|     std::vector<mlir::NamedAttribute> attributes; | ||||
|     llvm::StringRef OpName = node.op_type(); | ||||
|     llvm::StringRef opName = node.op_type(); | ||||
| 
 | ||||
|     // the following code is generated by gen_doc.py
 | ||||
|     // refer to dialect/onnx/onnx.td for details
 | ||||
|  | @ -555,9 +480,11 @@ private: | |||
|       ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it)); | ||||
|     } | ||||
| 
 | ||||
|     // import nodes in the graph
 | ||||
|     auto node = graph.node(); | ||||
|     for (const auto &item : node) { | ||||
|     // Create a NoneTyped constant.
 | ||||
|     none_ = | ||||
|         builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr()); | ||||
|     // Import nodes in the graph.
 | ||||
|     for (const auto &item : graph.node()) { | ||||
|       ImportNode(item); | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,320 +1,319 @@ | |||
| //********************************************************
 | ||||
| //   Warning: Do not modify this file directly
 | ||||
| //   This file is automatically generated via script
 | ||||
| //   Details can be found in doc/readonnxdefs.md
 | ||||
| //   This file is generated on UTC-02/24/2020, 06:29:01.
 | ||||
| //   Do not modify this file directly.
 | ||||
| //   This file is automatically generated via script.
 | ||||
| //   Details can be found in doc/readonnxdefs.md .
 | ||||
| //********************************************************
 | ||||
| 
 | ||||
|     if (OpName == "DUMMY") { | ||||
|     }else if (OpName == "Abs") { | ||||
|        ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1); | ||||
|     }else if (OpName == "Acos") { | ||||
|        ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1); | ||||
|     }else if (OpName == "Acosh") { | ||||
|        ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1); | ||||
|     }else if (OpName == "Add") { | ||||
|        ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1); | ||||
|     }else if (OpName == "And") { | ||||
|        ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1); | ||||
|     }else if (OpName == "ArgMax") { | ||||
|        ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1); | ||||
|     }else if (OpName == "ArgMin") { | ||||
|        ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1); | ||||
|     }else if (OpName == "Asin") { | ||||
|        ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1); | ||||
|     }else if (OpName == "Asinh") { | ||||
|        ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1); | ||||
|     }else if (OpName == "Atan") { | ||||
|        ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1); | ||||
|     }else if (OpName == "Atanh") { | ||||
|        ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1); | ||||
|     }else if (OpName == "AveragePool") { | ||||
|        ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1); | ||||
|     }else if (OpName == "BatchNormalization") { | ||||
|        ImportNodeBatchNormalization(node, 5, 5); | ||||
|     }else if (OpName == "BitShift") { | ||||
|        ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1); | ||||
|     }else if (OpName == "Cast") { | ||||
|        ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1); | ||||
|     }else if (OpName == "Ceil") { | ||||
|        ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1); | ||||
|     }else if (OpName == "Clip") { | ||||
|        ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1); | ||||
|     }else if (OpName == "Compress") { | ||||
|        ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1); | ||||
|     }else if (OpName == "Concat") { | ||||
|        ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, true, false); | ||||
|     }else if (OpName == "ConcatFromSequence") { | ||||
|        ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1); | ||||
|     }else if (OpName == "Constant") { | ||||
|        ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1); | ||||
|     }else if (OpName == "ConstantOfShape") { | ||||
|        ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1); | ||||
|     }else if (OpName == "Conv") { | ||||
|        ImportNodeConv(node, 3, 1); | ||||
|     }else if (OpName == "ConvInteger") { | ||||
|        ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1); | ||||
|     }else if (OpName == "ConvTranspose") { | ||||
|        ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1); | ||||
|     }else if (OpName == "Cos") { | ||||
|        ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1); | ||||
|     }else if (OpName == "Cosh") { | ||||
|        ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1); | ||||
|     }else if (OpName == "CumSum") { | ||||
|        ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1); | ||||
|     }else if (OpName == "DepthToSpace") { | ||||
|        ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1); | ||||
|     }else if (OpName == "DequantizeLinear") { | ||||
|        ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1); | ||||
|     }else if (OpName == "Det") { | ||||
|        ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1); | ||||
|     }else if (OpName == "Div") { | ||||
|        ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1); | ||||
|     }else if (OpName == "Dropout") { | ||||
|        ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2); | ||||
|     }else if (OpName == "DynamicQuantizeLinear") { | ||||
|        ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3); | ||||
|     }else if (OpName == "Elu") { | ||||
|        ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1); | ||||
|     }else if (OpName == "Equal") { | ||||
|        ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1); | ||||
|     }else if (OpName == "Erf") { | ||||
|        ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1); | ||||
|     }else if (OpName == "Exp") { | ||||
|        ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1); | ||||
|     }else if (OpName == "Expand") { | ||||
|        ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1); | ||||
|     }else if (OpName == "EyeLike") { | ||||
|        ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1); | ||||
|     }else if (OpName == "Flatten") { | ||||
|        ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1); | ||||
|     }else if (OpName == "Floor") { | ||||
|        ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1); | ||||
|     }else if (OpName == "GRU") { | ||||
|        ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2); | ||||
|     }else if (OpName == "Gather") { | ||||
|        ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1); | ||||
|     }else if (OpName == "GatherElements") { | ||||
|        ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1); | ||||
|     }else if (OpName == "GatherND") { | ||||
|        ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1); | ||||
|     }else if (OpName == "Gemm") { | ||||
|        ImportNodeGemm(node, 3, 1); | ||||
|     }else if (OpName == "GlobalAveragePool") { | ||||
|        ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1); | ||||
|     }else if (OpName == "GlobalLpPool") { | ||||
|        ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1); | ||||
|     }else if (OpName == "GlobalMaxPool") { | ||||
|        ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1); | ||||
|     }else if (OpName == "Greater") { | ||||
|        ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1); | ||||
|     }else if (OpName == "HardSigmoid") { | ||||
|        ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1); | ||||
|     }else if (OpName == "Hardmax") { | ||||
|        ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1); | ||||
|     }else if (OpName == "Identity") { | ||||
|        ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1); | ||||
|     }else if (OpName == "If") { | ||||
|        ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1); | ||||
|     }else if (OpName == "InstanceNormalization") { | ||||
|        ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1); | ||||
|     }else if (OpName == "IsInf") { | ||||
|        ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1); | ||||
|     }else if (OpName == "IsNaN") { | ||||
|        ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1); | ||||
|     }else if (OpName == "LRN") { | ||||
|        ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1); | ||||
|     }else if (OpName == "LSTM") { | ||||
|        ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3); | ||||
|     }else if (OpName == "LeakyRelu") { | ||||
|        ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1); | ||||
|     }else if (OpName == "Less") { | ||||
|        ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1); | ||||
|     }else if (OpName == "Log") { | ||||
|        ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1); | ||||
|     }else if (OpName == "LogSoftmax") { | ||||
|        ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1); | ||||
|     }else if (OpName == "Loop") { | ||||
|        ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1); | ||||
|     }else if (OpName == "LpNormalization") { | ||||
|        ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1); | ||||
|     }else if (OpName == "LpPool") { | ||||
|        ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1); | ||||
|     }else if (OpName == "MatMul") { | ||||
|        ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1); | ||||
|     }else if (OpName == "MatMulInteger") { | ||||
|        ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1); | ||||
|     }else if (OpName == "Max") { | ||||
|        ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, true, false); | ||||
|     }else if (OpName == "MaxPool") { | ||||
|        ImportNodeMaxPool(node, 1, 2); | ||||
|     }else if (OpName == "MaxRoiPool") { | ||||
|        ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1); | ||||
|     }else if (OpName == "MaxUnpool") { | ||||
|        ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1); | ||||
|     }else if (OpName == "Mean") { | ||||
|        ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, true, false); | ||||
|     }else if (OpName == "MeanVarianceNormalization") { | ||||
|        ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1); | ||||
|     }else if (OpName == "Min") { | ||||
|        ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, true, false); | ||||
|     }else if (OpName == "Mod") { | ||||
|        ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1); | ||||
|     }else if (OpName == "Mul") { | ||||
|        ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1); | ||||
|     }else if (OpName == "Multinomial") { | ||||
|        ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1); | ||||
|     }else if (OpName == "Neg") { | ||||
|        ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1); | ||||
|     }else if (OpName == "NonMaxSuppression") { | ||||
|        ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1); | ||||
|     }else if (OpName == "NonZero") { | ||||
|        ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1); | ||||
|     }else if (OpName == "Not") { | ||||
|        ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1); | ||||
|     }else if (OpName == "OneHot") { | ||||
|        ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1); | ||||
|     }else if (OpName == "Or") { | ||||
|        ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1); | ||||
|     }else if (OpName == "PRelu") { | ||||
|        ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1); | ||||
|     }else if (OpName == "Pad") { | ||||
|        ImportNodePad(node, 3, 1); | ||||
|     }else if (OpName == "Pow") { | ||||
|        ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1); | ||||
|     }else if (OpName == "QLinearConv") { | ||||
|        ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1); | ||||
|     }else if (OpName == "QLinearMatMul") { | ||||
|        ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1); | ||||
|     }else if (OpName == "QuantizeLinear") { | ||||
|        ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1); | ||||
|     }else if (OpName == "RNN") { | ||||
|        ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2); | ||||
|     }else if (OpName == "RandomNormal") { | ||||
|        ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1); | ||||
|     }else if (OpName == "RandomNormalLike") { | ||||
|        ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1); | ||||
|     }else if (OpName == "RandomUniform") { | ||||
|        ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1); | ||||
|     }else if (OpName == "RandomUniformLike") { | ||||
|        ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1); | ||||
|     }else if (OpName == "Range") { | ||||
|        ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1); | ||||
|     }else if (OpName == "Reciprocal") { | ||||
|        ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1); | ||||
|     }else if (OpName == "ReduceL1") { | ||||
|        ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1); | ||||
|     }else if (OpName == "ReduceL2") { | ||||
|        ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1); | ||||
|     }else if (OpName == "ReduceLogSum") { | ||||
|        ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1); | ||||
|     }else if (OpName == "ReduceLogSumExp") { | ||||
|        ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1); | ||||
|     }else if (OpName == "ReduceMax") { | ||||
|        ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1); | ||||
|     }else if (OpName == "ReduceMean") { | ||||
|        ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1); | ||||
|     }else if (OpName == "ReduceMin") { | ||||
|        ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1); | ||||
|     }else if (OpName == "ReduceProd") { | ||||
|        ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1); | ||||
|     }else if (OpName == "ReduceSum") { | ||||
|        ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1); | ||||
|     }else if (OpName == "ReduceSumSquare") { | ||||
|        ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1); | ||||
|     }else if (OpName == "Relu") { | ||||
|        ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1); | ||||
|     }else if (OpName == "Reshape") { | ||||
|        ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1); | ||||
|     }else if (OpName == "Resize") { | ||||
|        ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1); | ||||
|     }else if (OpName == "ReverseSequence") { | ||||
|        ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1); | ||||
|     }else if (OpName == "RoiAlign") { | ||||
|        ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1); | ||||
|     }else if (OpName == "Round") { | ||||
|        ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1); | ||||
|     }else if (OpName == "Scan") { | ||||
|        ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1); | ||||
|     }else if (OpName == "Scatter") { | ||||
|        ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1); | ||||
|     }else if (OpName == "ScatterElements") { | ||||
|        ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1); | ||||
|     }else if (OpName == "ScatterND") { | ||||
|        ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1); | ||||
|     }else if (OpName == "Selu") { | ||||
|        ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1); | ||||
|     }else if (OpName == "SequenceAt") { | ||||
|        ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1); | ||||
|     }else if (OpName == "SequenceConstruct") { | ||||
|        ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, true, false); | ||||
|     }else if (OpName == "SequenceEmpty") { | ||||
|        ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1); | ||||
|     }else if (OpName == "SequenceErase") { | ||||
|        ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1); | ||||
|     }else if (OpName == "SequenceInsert") { | ||||
|        ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1); | ||||
|     }else if (OpName == "SequenceLength") { | ||||
|        ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1); | ||||
|     }else if (OpName == "Shape") { | ||||
|        ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1); | ||||
|     }else if (OpName == "Shrink") { | ||||
|        ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1); | ||||
|     }else if (OpName == "Sigmoid") { | ||||
|        ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1); | ||||
|     }else if (OpName == "Sign") { | ||||
|        ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1); | ||||
|     }else if (OpName == "Sin") { | ||||
|        ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1); | ||||
|     }else if (OpName == "Sinh") { | ||||
|        ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1); | ||||
|     }else if (OpName == "Size") { | ||||
|        ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1); | ||||
|     }else if (OpName == "Slice") { | ||||
|        ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1); | ||||
|     }else if (OpName == "Softmax") { | ||||
|        ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1); | ||||
|     }else if (OpName == "Softplus") { | ||||
|        ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1); | ||||
|     }else if (OpName == "Softsign") { | ||||
|        ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1); | ||||
|     }else if (OpName == "SpaceToDepth") { | ||||
|        ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1); | ||||
|     }else if (OpName == "Split") { | ||||
|        ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1); | ||||
|     }else if (OpName == "SplitToSequence") { | ||||
|        ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1); | ||||
|     }else if (OpName == "Sqrt") { | ||||
|        ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1); | ||||
|     }else if (OpName == "Squeeze") { | ||||
|        ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1); | ||||
|     }else if (OpName == "StringNormalizer") { | ||||
|        ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1); | ||||
|     }else if (OpName == "Sub") { | ||||
|        ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1); | ||||
|     }else if (OpName == "Sum") { | ||||
|        ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, true, false); | ||||
|     }else if (OpName == "Tan") { | ||||
|        ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1); | ||||
|     }else if (OpName == "Tanh") { | ||||
|        ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1); | ||||
|     }else if (OpName == "TfIdfVectorizer") { | ||||
|        ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1); | ||||
|     }else if (OpName == "ThresholdedRelu") { | ||||
|        ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1); | ||||
|     }else if (OpName == "Tile") { | ||||
|        ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1); | ||||
|     }else if (OpName == "TopK") { | ||||
|        ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2); | ||||
|     }else if (OpName == "Transpose") { | ||||
|        ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1); | ||||
|     }else if (OpName == "Unique") { | ||||
|        ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4); | ||||
|     }else if (OpName == "Unsqueeze") { | ||||
|        ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1); | ||||
|     }else if (OpName == "Upsample") { | ||||
|        ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1); | ||||
|     }else if (OpName == "Where") { | ||||
|        ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1); | ||||
|     }else if (OpName == "Xor") { | ||||
|        ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1); | ||||
|     } | ||||
| if (opName == "Abs") | ||||
|   return buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Acos") | ||||
|   return buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Acosh") | ||||
|   return buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Add") | ||||
|   return buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "And") | ||||
|   return buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "ArgMax") | ||||
|   return buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "ArgMin") | ||||
|   return buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Asin") | ||||
|   return buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Asinh") | ||||
|   return buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Atan") | ||||
|   return buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Atanh") | ||||
|   return buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "AveragePool") | ||||
|   return buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "BatchNormalization") | ||||
|   return ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5); | ||||
| if (opName == "BitShift") | ||||
|   return buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Cast") | ||||
|   return buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Ceil") | ||||
|   return buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Clip") | ||||
|   return buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "Compress") | ||||
|   return buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Concat") | ||||
|   return buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
| if (opName == "ConcatFromSequence") | ||||
|   return buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Constant") | ||||
|   return buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); | ||||
| if (opName == "ConstantOfShape") | ||||
|   return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Conv") | ||||
|   return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "ConvInteger") | ||||
|   return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); | ||||
| if (opName == "ConvTranspose") | ||||
|   return buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "Cos") | ||||
|   return buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Cosh") | ||||
|   return buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "CumSum") | ||||
|   return buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "DepthToSpace") | ||||
|   return buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "DequantizeLinear") | ||||
|   return buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "Det") | ||||
|   return buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Div") | ||||
|   return buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Dropout") | ||||
|   return buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); | ||||
| if (opName == "DynamicQuantizeLinear") | ||||
|   return buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3); | ||||
| if (opName == "Elu") | ||||
|   return buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Equal") | ||||
|   return buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Erf") | ||||
|   return buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Exp") | ||||
|   return buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Expand") | ||||
|   return buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "EyeLike") | ||||
|   return buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Flatten") | ||||
|   return buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Floor") | ||||
|   return buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "GRU") | ||||
|   return buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2); | ||||
| if (opName == "Gather") | ||||
|   return buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "GatherElements") | ||||
|   return buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "GatherND") | ||||
|   return buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Gemm") | ||||
|   return buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "GlobalAveragePool") | ||||
|   return buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "GlobalLpPool") | ||||
|   return buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "GlobalMaxPool") | ||||
|   return buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Greater") | ||||
|   return buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "HardSigmoid") | ||||
|   return buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Hardmax") | ||||
|   return buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Identity") | ||||
|   return buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "If") | ||||
|   return buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1); | ||||
| if (opName == "InstanceNormalization") | ||||
|   return buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "IsInf") | ||||
|   return buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "IsNaN") | ||||
|   return buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "LRN") | ||||
|   return buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "LSTM") | ||||
|   return buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3); | ||||
| if (opName == "LeakyRelu") | ||||
|   return buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Less") | ||||
|   return buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Log") | ||||
|   return buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "LogSoftmax") | ||||
|   return buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Loop") | ||||
|   return buildOperation<mlir::ONNXLoopOp>(node); | ||||
| if (opName == "LpNormalization") | ||||
|   return buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "LpPool") | ||||
|   return buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "MatMul") | ||||
|   return buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "MatMulInteger") | ||||
|   return buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); | ||||
| if (opName == "Max") | ||||
|   return buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
| if (opName == "MaxPool") | ||||
|   return ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); | ||||
| if (opName == "MaxRoiPool") | ||||
|   return buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "MaxUnpool") | ||||
|   return buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "Mean") | ||||
|   return buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
| if (opName == "MeanVarianceNormalization") | ||||
|   return buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Min") | ||||
|   return buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
| if (opName == "Mod") | ||||
|   return buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Mul") | ||||
|   return buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Multinomial") | ||||
|   return buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Neg") | ||||
|   return buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "NonMaxSuppression") | ||||
|   return buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1); | ||||
| if (opName == "NonZero") | ||||
|   return buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Not") | ||||
|   return buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "OneHot") | ||||
|   return buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "Or") | ||||
|   return buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "PRelu") | ||||
|   return buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Pad") | ||||
|   return ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "Pow") | ||||
|   return buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "QLinearConv") | ||||
|   return buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1); | ||||
| if (opName == "QLinearMatMul") | ||||
|   return buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1); | ||||
| if (opName == "QuantizeLinear") | ||||
|   return buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "RNN") | ||||
|   return buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2); | ||||
| if (opName == "RandomNormal") | ||||
|   return buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); | ||||
| if (opName == "RandomNormalLike") | ||||
|   return buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "RandomUniform") | ||||
|   return buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); | ||||
| if (opName == "RandomUniformLike") | ||||
|   return buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Range") | ||||
|   return buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "Reciprocal") | ||||
|   return buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "ReduceL1") | ||||
|   return buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "ReduceL2") | ||||
|   return buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "ReduceLogSum") | ||||
|   return buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "ReduceLogSumExp") | ||||
|   return buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "ReduceMax") | ||||
|   return buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "ReduceMean") | ||||
|   return buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "ReduceMin") | ||||
|   return buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "ReduceProd") | ||||
|   return buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "ReduceSum") | ||||
|   return buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "ReduceSumSquare") | ||||
|   return buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Relu") | ||||
|   return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Reshape") | ||||
|   return buildOperation<mlir::ONNXReshapeOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Resize") | ||||
|   return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); | ||||
| if (opName == "ReverseSequence") | ||||
|   return buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "RoiAlign") | ||||
|   return buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "Round") | ||||
|   return buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Scan") | ||||
|   return buildOperation<mlir::ONNXScanOp>(node); | ||||
| if (opName == "Scatter") | ||||
|   return buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "ScatterElements") | ||||
|   return buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "ScatterND") | ||||
|   return buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "Selu") | ||||
|   return buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "SequenceAt") | ||||
|   return buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "SequenceConstruct") | ||||
|   return buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
| if (opName == "SequenceEmpty") | ||||
|   return buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); | ||||
| if (opName == "SequenceErase") | ||||
|   return buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "SequenceInsert") | ||||
|   return buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "SequenceLength") | ||||
|   return buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Shape") | ||||
|   return buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Shrink") | ||||
|   return buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Sigmoid") | ||||
|   return buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Sign") | ||||
|   return buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Sin") | ||||
|   return buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Sinh") | ||||
|   return buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Size") | ||||
|   return buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Slice") | ||||
|   return buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1); | ||||
| if (opName == "Softmax") | ||||
|   return buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Softplus") | ||||
|   return buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Softsign") | ||||
|   return buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "SpaceToDepth") | ||||
|   return buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Split") | ||||
|   return buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1); | ||||
| if (opName == "SplitToSequence") | ||||
|   return buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Sqrt") | ||||
|   return buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Squeeze") | ||||
|   return buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "StringNormalizer") | ||||
|   return buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Sub") | ||||
|   return buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Sum") | ||||
|   return buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
| if (opName == "Tan") | ||||
|   return buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Tanh") | ||||
|   return buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "TfIdfVectorizer") | ||||
|   return buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "ThresholdedRelu") | ||||
|   return buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Tile") | ||||
|   return buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "TopK") | ||||
|   return buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2); | ||||
| if (opName == "Transpose") | ||||
|   return buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Unique") | ||||
|   return buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4); | ||||
| if (opName == "Unsqueeze") | ||||
|   return buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
| if (opName == "Upsample") | ||||
|   return buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
| if (opName == "Where") | ||||
|   return buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
| if (opName == "Xor") | ||||
|   return buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|  |  | |||
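A minimal, self-contained sketch of the dispatch shape used above: every ONNX node is funneled through a single buildOperation<Op>(node, expected_num_operands, expected_num_results) helper, with -1 marking a variadic count (see Concat, Max, Sum, and the Split/If results). The toy Node struct, op types, and error handling below are illustrative assumptions, not the importer's real classes:

#include <cstdio>
#include <string>
#include <vector>

// Toy stand-in for an ONNX protobuf node (assumption for illustration).
struct Node {
  std::string op_name;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Arity-checked template dispatch, mirroring the buildOperation<Op> calls
// above: -1 means "variadic, accept any count". The real importer also
// materializes MLIR operations and handles optional operands; both are
// elided here.
template <typename Op>
void buildOperation(const Node &node, int expectedNumOperands = -1,
                    int expectedNumResults = -1) {
  if (expectedNumOperands != -1 &&
      static_cast<int>(node.inputs.size()) != expectedNumOperands)
    std::fprintf(stderr, "%s: unexpected operand count\n",
                 node.op_name.c_str());
  if (expectedNumResults != -1 &&
      static_cast<int>(node.outputs.size()) != expectedNumResults)
    std::fprintf(stderr, "%s: unexpected result count\n",
                 node.op_name.c_str());
  std::printf("built %s\n", Op::name());
}

struct AddOp { static const char *name() { return "onnx.Add"; } };
struct SumOp { static const char *name() { return "onnx.Sum"; } }; // variadic

int main() {
  buildOperation<AddOp>({"Add", {"a", "b"}, {"c"}},
                        /*expectedNumOperands=*/2, /*expectedNumResults=*/1);
  buildOperation<SumOp>({"Sum", {"a", "b", "c"}, {"d"}},
                        /*expectedNumOperands=*/-1, /*expectedNumResults=*/1);
}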
|  | @ -17,20 +17,24 @@ struct ONNXGemmOpLowering : public ConversionPattern { | |||
|   matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||
|                   ConversionPatternRewriter &rewriter) const final { | ||||
|     auto loc = op->getLoc(); | ||||
|     auto has_bias = (operands.size() == 3); | ||||
|     // The first predicate is unnecessary when we remove ONNXGemmNoBiasOp.
 | ||||
|     bool hasBias = (operands.size() == 3) && | ||||
|                    (!op->getOperand(2).getType().isa<NoneType>()); | ||||
| 
 | ||||
|     Value A, B, C; | ||||
|     A = operands[0]; | ||||
|     B = operands[1]; | ||||
|     if (has_bias) | ||||
|     if (hasBias) | ||||
|       C = operands[2]; | ||||
| 
 | ||||
|     auto memRefType = convertToMemRefType(*op->result_type_begin()); | ||||
| 
 | ||||
|     auto alphaAttr = FloatAttr::get(memRefType.getElementType(), | ||||
|         llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat()); | ||||
|     auto betaAttr = FloatAttr::get(memRefType.getElementType(), | ||||
|         llvm::dyn_cast<GemmOp>(op).beta().convertToFloat()); | ||||
|     auto alphaAttr = | ||||
|         FloatAttr::get(memRefType.getElementType(), | ||||
|                        llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat()); | ||||
|     auto betaAttr = | ||||
|         FloatAttr::get(memRefType.getElementType(), | ||||
|                        llvm::dyn_cast<GemmOp>(op).beta().convertToFloat()); | ||||
|     auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); | ||||
|     auto beta = rewriter.create<ConstantOp>(loc, betaAttr); | ||||
| 
 | ||||
|  | @ -68,8 +72,8 @@ struct ONNXGemmOpLowering : public ConversionPattern { | |||
|     // Define loops.
 | ||||
|     std::vector<Value> originalLoops; | ||||
|     std::vector<Value> optimizedLoops; | ||||
|     Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, | ||||
|             optimizedLoops, numLoops); | ||||
|     Block *optimizationBlock = | ||||
|         defineLoops(rewriter, loc, originalLoops, optimizedLoops, numLoops); | ||||
| 
 | ||||
|     // We have two Krnl loops:
 | ||||
|     // - Outer loop iterates over the output matrix dimensions, and
 | ||||
|  | @ -83,8 +87,7 @@ struct ONNXGemmOpLowering : public ConversionPattern { | |||
|       outerLoops.push_back(originalLoops[i]); | ||||
|       optimizedOuterLoops.push_back(optimizedLoops[i]); | ||||
|     } | ||||
|     KrnlIterateOperandPack outerPack(rewriter, outerLoops, | ||||
|                                       optimizedOuterLoops); | ||||
|     KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops); | ||||
|     // Induction variables for the outer loops
 | ||||
|     for (int i = 0; i < 2; ++i) | ||||
|       addDimensionToPack(rewriter, loc, outerPack, alloc, i); | ||||
|  | @ -106,20 +109,19 @@ struct ONNXGemmOpLowering : public ConversionPattern { | |||
|     int64_t K_B_Idx = (isTransB) ? 1 : 0; | ||||
|     reductionPack.pushConstantBound(0); | ||||
|     if (ATy.getShape()[K_A_Idx] != -1) | ||||
|         reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]); | ||||
|       reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]); | ||||
|     else if (BTy.getShape()[K_B_Idx] != -1) | ||||
|       reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]); | ||||
|     else | ||||
|       if (BTy.getShape()[K_B_Idx] != -1) | ||||
|         reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]); | ||||
|       else | ||||
|         reductionPack.pushOperandBound( | ||||
|             rewriter.create<DimOp>(loc, B, K_B_Idx).getResult()); | ||||
|       reductionPack.pushOperandBound( | ||||
|           rewriter.create<DimOp>(loc, B, K_B_Idx).getResult()); | ||||
| 
 | ||||
|     // Get run-time dimension information for unknown dimensions used for
 | ||||
|     // broadcasting.
 | ||||
|     // GemmOp supports unidirectional broadcasting from C to A*B.
 | ||||
|     // Hence, it must be enough to get broadcasting information for C only.
 | ||||
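|     // Illustration (assumed semantics, not from this diff): for C of | ||||
|     // shape <? x N>, the map would hold, for axis 0, a runtime value | ||||
|     // recording whether that dimension is 1, so the loads below can pick | ||||
|     // index 0 when C is broadcast along the rows of A*B. | ||||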
|     std::map<int, Value> broadcastedDimInfo; | ||||
|     if (has_bias) { | ||||
|     if (hasBias) { | ||||
|       auto shape = C.getType().cast<MemRefType>().getShape(); | ||||
|       for (int i = 0; i < shape.size(); ++i) { | ||||
|         if (shape[i] < 0) { | ||||
|  | @ -162,7 +164,7 @@ struct ONNXGemmOpLowering : public ConversionPattern { | |||
|     // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
 | ||||
|     auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs); | ||||
|     auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB); | ||||
|     if (has_bias) { | ||||
|     if (hasBias) { | ||||
|       auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C, | ||||
|                                                 broadcastedDimInfo); | ||||
|       auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs); | ||||
|  | @ -210,8 +212,8 @@ struct ONNXGemmOpLowering : public ConversionPattern { | |||
|   } | ||||
| }; | ||||
| 
 | ||||
| void populateLoweringONNXGemmOpPattern( | ||||
|     OwningRewritePatternList &patterns, MLIRContext *ctx) { | ||||
| void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns, | ||||
|                                        MLIRContext *ctx) { | ||||
|   patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx); | ||||
|   patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx); | ||||
| } | ||||
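For reference, the computation being lowered here is the ONNX Gemm contract, with the bias term now optional per the hasBias change above:

    Y = \alpha \,\mathrm{op}(A)\,\mathrm{op}(B) + \beta\, C,
    \qquad
    \mathrm{op}(X) = \begin{cases} X^{T} & \text{if the corresponding trans attribute is } 1 \\ X & \text{otherwise} \end{cases}

where C, when present, is unidirectionally broadcast to the shape of op(A) op(B).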
|  |  | |||
|  | @ -120,25 +120,19 @@ void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); } | |||
| // Tanh
 | ||||
| /// Infer the output shape of the ONNXTanhOp. This method is required by the
 | ||||
| /// shape inference interface.
 | ||||
| void ONNXTanhOp::inferShapes() { | ||||
|   getResult().setType(getOperand().getType()); | ||||
| } | ||||
| void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Sinh
 | ||||
| /// Infer the output shape of the ONNXSinhOp. This method is required by the
 | ||||
| /// shape inference interface.
 | ||||
| void ONNXSinhOp::inferShapes() { | ||||
|   getResult().setType(getOperand().getType()); | ||||
| } | ||||
| void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Cosh
 | ||||
| /// Infer the output shape of the ONNXCoshOp. This method is required by the
 | ||||
| /// shape inference interface.
 | ||||
| void ONNXCoshOp::inferShapes() { | ||||
|   getResult().setType(getOperand().getType()); | ||||
| } | ||||
| void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Cos
 | ||||
|  | @ -178,9 +172,7 @@ void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); } | |||
| // Relu
 | ||||
| /// Infer the output shape of the ONNXReluOp. This method is required by the
 | ||||
| /// shape inference interface.
 | ||||
| void ONNXReluOp::inferShapes() { | ||||
|   getResult().setType(getOperand().getType()); | ||||
| } | ||||
| void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // LeakyRelu
 | ||||
|  | @ -194,9 +186,7 @@ void ONNXLeakyReluOp::inferShapes() { | |||
| // Selu
 | ||||
| /// Infer the output shape of the ONNXSeluOp. This method is required by
 | ||||
| /// the shape inference interface.
 | ||||
| void ONNXSeluOp::inferShapes() { | ||||
|   getResult().setType(getOperand().getType()); | ||||
| } | ||||
| void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Reciprocal
 | ||||
|  | @ -234,17 +224,13 @@ void ONNXSoftsignOp::inferShapes() { | |||
| // Sqrt
 | ||||
| /// Infer the output shape of the ONNXSqrtOp. This method is required by
 | ||||
| /// the shape inference interface.
 | ||||
| void ONNXSqrtOp::inferShapes() { | ||||
|   getResult().setType(getOperand().getType()); | ||||
| } | ||||
| void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Sign
 | ||||
| /// Infer the output shape of the ONNXSignOp. This method is required by
 | ||||
| /// the shape inference interface.
 | ||||
| void ONNXSignOp::inferShapes() { | ||||
|   getResult().setType(getOperand().getType()); | ||||
| } | ||||
| void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Add
 | ||||
|  | @ -423,8 +409,7 @@ void ONNXMatMulOp::inferShapes() { | |||
|     // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
 | ||||
|     // need to be removed after the multiplication but cannot be removed if all
 | ||||
|     // sizes are 1.
 | ||||
|     if (lhsShape[0] != -1 && rhsShape[0] != -1 && | ||||
|         lhsShape[0] != rhsShape[0]) | ||||
|     if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0]) | ||||
|       emitError("Attempt to multiply incompatible matrices."); | ||||
|     dims.emplace_back(1); | ||||
|   } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) { | ||||
|  | @ -541,14 +526,14 @@ void ONNXMatMulOp::inferShapes() { | |||
| // Gemm
 | ||||
| 
 | ||||
| void ONNXGemmOp::inferShapes() { | ||||
|   bool hasBias = !getOperand(2).getType().isa<NoneType>(); | ||||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!getOperand(0).getType().isa<RankedTensorType>() || | ||||
|       !getOperand(1).getType().isa<RankedTensorType>() || | ||||
|       !getOperand(2).getType().isa<RankedTensorType>()) | ||||
|       (hasBias && !getOperand(2).getType().isa<RankedTensorType>())) | ||||
|     return; | ||||
|   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); | ||||
|   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); | ||||
|   auto biasTy = getOperand(2).getType().cast<RankedTensorType>(); | ||||
| 
 | ||||
|   int64_t M, N, K_A, K_B; | ||||
|   M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1]; | ||||
|  | @ -560,15 +545,18 @@ void ONNXGemmOp::inferShapes() { | |||
|     emitError("Tensor shapes mismatched."); | ||||
|   } | ||||
| 
 | ||||
|   // Check whether bias is unidirectional broadcasting or not.
 | ||||
|   auto shape = biasTy.getShape(); | ||||
|   int rank = shape.size(); | ||||
|   if ((rank > 2) || | ||||
|       (rank >= 1 && shape[rank - 1] != -1 && N != -1 && N != shape[rank - 1] && | ||||
|        shape[rank - 1] != 1) || | ||||
|       (rank == 2 && shape[rank - 2] != -1 && M != -1 && M != shape[rank - 2] && | ||||
|        shape[rank - 2] != 1)) { | ||||
|     emitError("Bias shape mismatched."); | ||||
|   if (hasBias) { | ||||
|     // Check whether bias is unidirectional broadcasting or not.
 | ||||
|     auto biasTy = getOperand(2).getType().cast<RankedTensorType>(); | ||||
|     auto shape = biasTy.getShape(); | ||||
|     int rank = shape.size(); | ||||
|     if ((rank > 2) || | ||||
|         (rank >= 1 && shape[rank - 1] != -1 && N != -1 && | ||||
|          N != shape[rank - 1] && shape[rank - 1] != 1) || | ||||
|         (rank == 2 && shape[rank - 2] != -1 && M != -1 && | ||||
|          M != shape[rank - 2] && shape[rank - 2] != 1)) { | ||||
|       emitError("Bias shape mismatched."); | ||||
|     } | ||||
|   } | ||||
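|   // Examples of biases accepted by the check above (unidirectional | ||||
|   // broadcast to (M, N)): shapes (N), (1, N), (M, 1), (1, 1), or a scalar. | ||||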
| 
 | ||||
|   SmallVector<int64_t, 2> dims; | ||||
|  | @ -713,7 +701,6 @@ void ONNXTransposeOp::inferShapes() { | |||
|   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| // ReduceMax
 | ||||
|  | @ -801,7 +788,8 @@ void ONNXConvNoBiasOp::inferShapes() { | |||
|   // Required attribute auto_pad defaults to NOTSET.
 | ||||
|   auto autoPad = auto_pad(); | ||||
|   // Group is a required attribute and should have default value of 1.
 | ||||
|   int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
 | ||||
|   int64_t group = | ||||
|       ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
 | ||||
|   // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
 | ||||
|   if (dataShape[1] != (weightShape[1] * group)) | ||||
|     emitError("Channel dimension mismatch."); | ||||
|  | @ -859,8 +847,10 @@ void ONNXConvNoBiasOp::inferShapes() { | |||
|     if (dilations.getValue().size() != nDims) | ||||
|       emitError("dilations length incompatible with spatial dimensions."); | ||||
|     for (int i = 0; i < nDims; ++i) | ||||
|       kernelDims[i] = (kernelDims[i] + 1) * | ||||
|           (dilations.getValue()[i]).cast<IntegerAttr>().getInt() - 1; | ||||
|       kernelDims[i] = | ||||
|           (kernelDims[i] + 1) * | ||||
|               (dilations.getValue()[i]).cast<IntegerAttr>().getInt() - | ||||
|           1; | ||||
|   } | ||||
| 
 | ||||
|   // Subtract kernel dimensions from input data dimensions.
 | ||||
|  | @ -906,8 +896,7 @@ void ONNXConvNoBiasOp::inferShapes() { | |||
|     if (strides.getValue().size() != nDims) | ||||
|       emitError("strides length incompatible with spatial dimensions."); | ||||
|     for (int i = 0; i < nDims; ++i) { | ||||
|       int64_t stride = | ||||
|           strides.getValue()[i].cast<IntegerAttr>().getInt(); | ||||
|       int64_t stride = strides.getValue()[i].cast<IntegerAttr>().getInt(); | ||||
|       outSpatialDims[i] = floor(outSpatialDims[i] / stride); | ||||
|     } | ||||
|   } | ||||
|  | @ -934,12 +923,13 @@ void ONNXMaxPoolSingleOutOp::inferShapes() { | |||
|   auto xRank = xShape.size(); | ||||
| 
 | ||||
|   // 2) analyse parameters
 | ||||
|   // get kernel sizes from kernel_shape attribute 
 | ||||
|   // get kernel sizes from kernel_shape attribute
 | ||||
|   auto kernelShape = kernel_shape(); | ||||
|   if (!kernelShape) | ||||
|     emitError("kernel_shape is a mandatory attribute for which there is no default."); | ||||
|     emitError( | ||||
|         "kernel_shape is a mandatory attribute for which there is no default."); | ||||
|   auto kernelShapeArray = kernelShape.getValue(); | ||||
|   auto kernelRank = kernelShape.size();  | ||||
|   auto kernelRank = kernelShape.size(); | ||||
|   if (kernelRank > xRank) | ||||
|     emitError("kernel_shape spatial dimension is too large."); | ||||
|   auto kernelOffset = xRank - kernelRank; | ||||
|  | @ -951,41 +941,42 @@ void ONNXMaxPoolSingleOutOp::inferShapes() { | |||
|   SmallVector<int64_t, 4> actualDilations; | ||||
|   auto dilationsOpt = dilations(); | ||||
|   if (dilationsOpt.hasValue()) { | ||||
|     auto dilationsArray = dilationsOpt.getValue().getValue(); // opt -> attr -> array
 | ||||
|     auto dilationsArray = | ||||
|         dilationsOpt.getValue().getValue(); // opt -> attr -> array
 | ||||
|     if (dilationsArray.size() != kernelRank) | ||||
|         emitError("dialation rank is not the same as the spatial rank."); | ||||
|       emitError("dialation rank is not the same as the spatial rank."); | ||||
|     // fill in the actual values
 | ||||
|     for (int i = 0; i < kernelRank; ++i) { | ||||
|       int64_t d = (dilationsArray[i]).cast<IntegerAttr>().getInt(); | ||||
|       if (d < 1)  | ||||
|       if (d < 1) | ||||
|         emitError("dialation value must be nonzero positive."); | ||||
|       actualDilations.emplace_back(d); | ||||
|     } | ||||
|   } else { | ||||
|     for(int i=0; i < kernelRank; ++i) { | ||||
|       actualDilations.emplace_back(1);       | ||||
|     for (int i = 0; i < kernelRank; ++i) { | ||||
|       actualDilations.emplace_back(1); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // storage order
 | ||||
|    | ||||
| 
 | ||||
|   // strides
 | ||||
|   SmallVector<int64_t, 4> actualStrides; | ||||
|   auto stridesOpt = strides(); | ||||
|   if (stridesOpt.hasValue()) { | ||||
|     auto stridesArray = stridesOpt.getValue().getValue(); | ||||
|     if (stridesArray.size() != kernelRank) | ||||
|         emitError("strides rank is not the same as the spatial rank."); | ||||
|       emitError("strides rank is not the same as the spatial rank."); | ||||
|     // fill in the actual values
 | ||||
|     for (int i = 0; i < kernelRank; ++i) { | ||||
|       int64_t s = (stridesArray[i]).cast<IntegerAttr>().getInt(); | ||||
|       if (s < 1)  | ||||
|       if (s < 1) | ||||
|         emitError("strides value must be nonzero positive."); | ||||
|       actualStrides.emplace_back(s); | ||||
|     } | ||||
|   } else { | ||||
|     for(int i=0; i < kernelRank; ++i) { | ||||
|       actualStrides.emplace_back(1);       | ||||
|     for (int i = 0; i < kernelRank; ++i) { | ||||
|       actualStrides.emplace_back(1); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|  | @ -1002,9 +993,9 @@ void ONNXMaxPoolSingleOutOp::inferShapes() { | |||
|       if (padsArray.size() != 2 * kernelRank) | ||||
|         emitError("pads rank is not twice the spatial rank."); | ||||
|       // fill in the actual values
 | ||||
|       for (int i = 0; i < 2*kernelRank; ++i) { | ||||
|       for (int i = 0; i < 2 * kernelRank; ++i) { | ||||
|         int64_t p = (padsArray[i]).cast<IntegerAttr>().getInt(); | ||||
|         if (p < 0)  | ||||
|         if (p < 0) | ||||
|           emitError("pads value must be nonnegative."); | ||||
|         actualPads.emplace_back(p); | ||||
|       } | ||||
|  | @ -1016,24 +1007,26 @@ void ONNXMaxPoolSingleOutOp::inferShapes() { | |||
|     defaultPads = true; | ||||
|   } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { | ||||
|     // init pad with zero
 | ||||
|     for(int i=0; i<2*kernelRank; ++i) { | ||||
|     for (int i = 0; i < 2 * kernelRank; ++i) { | ||||
|       actualPads.emplace_back(0); | ||||
|     } | ||||
|     for(int i=0; i<kernelRank; ++i) { | ||||
|       auto inputSpatialShape = xShape[kernelOffset  + i]; | ||||
|       auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt(); | ||||
|     for (int i = 0; i < kernelRank; ++i) { | ||||
|       auto inputSpatialShape = xShape[kernelOffset + i]; | ||||
|       auto kernelSpatialShape = | ||||
|           (kernelShapeArray[i]).cast<IntegerAttr>().getInt(); | ||||
|       auto dilations = actualDilations[i]; | ||||
|       auto strideSpatialShape = actualStrides[i]; | ||||
|       int64_t outputSpatialShape = ceil((1.0 * inputSpatialShape) / | ||||
|         (1.0 * strideSpatialShape)); | ||||
|       auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape +  | ||||
|         ((kernelSpatialShape - 1) * dilations + 1) - inputSpatialShape; | ||||
|       int64_t outputSpatialShape = | ||||
|           ceil((1.0 * inputSpatialShape) / (1.0 * strideSpatialShape)); | ||||
|       auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape + | ||||
|                       ((kernelSpatialShape - 1) * dilations + 1) - | ||||
|                       inputSpatialShape; | ||||
|       actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2; | ||||
|       if (sumOfPad % 2 != 0) { | ||||
|         if (autoPad == "SAME_UPPER") { | ||||
|           actualPads[kernelRank + i] += 1; | ||||
|         } else { | ||||
|           actualPads[i] += 1;           | ||||
|           actualPads[i] += 1; | ||||
|         } | ||||
|       } | ||||
|     } | ||||
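|     // Worked example with illustrative numbers: input dim 5, kernel 3, | ||||
|     // stride 2, dilation 1 under SAME_UPPER gives output ceil(5/2) = 3, | ||||
|     // sumOfPad = (3-1)*2 + ((3-1)*1+1) - 5 = 2, hence pads [1, 1]; the | ||||
|     // output-shape formula below then yields floor((5+2-3)/2) + 1 = 3. | ||||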
|  | @ -1042,24 +1035,26 @@ void ONNXMaxPoolSingleOutOp::inferShapes() { | |||
|   } | ||||
|   // handle case where default pad values must be used
 | ||||
|   if (defaultPads) { | ||||
|     for(int i=0; i<2*kernelRank; ++i) { | ||||
|     for (int i = 0; i < 2 * kernelRank; ++i) { | ||||
|       actualPads.emplace_back(0); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // initialize output shape 
 | ||||
|   // initialize output shape
 | ||||
|   SmallVector<int64_t, 4> yShape(xShape.begin(), xShape.end()); | ||||
|   // for all kernel dimensions
 | ||||
|   for(int i=0; i<kernelRank; ++i) { | ||||
|     auto inputSpatialShape = xShape[kernelOffset  + i]; | ||||
|     auto padShape = actualPads[i] + actualPads[kernelRank+i]; | ||||
|     auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt(); | ||||
|   for (int i = 0; i < kernelRank; ++i) { | ||||
|     auto inputSpatialShape = xShape[kernelOffset + i]; | ||||
|     auto padShape = actualPads[i] + actualPads[kernelRank + i]; | ||||
|     auto kernelSpatialShape = | ||||
|         (kernelShapeArray[i]).cast<IntegerAttr>().getInt(); | ||||
|     auto dilations = actualDilations[i]; | ||||
|     auto strideSpatialShape = actualStrides[i]; | ||||
|     ///output_spatial_shape[i] = ceil( (input_spatial_shape[i] + pad_shape[i] - 
 | ||||
|     //  ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1)
 | ||||
|     double numerator = inputSpatialShape + padShape -  | ||||
|       ((kernelSpatialShape - 1) * dilations + 1); | ||||
|     /// output_spatial_shape[i] = ceil( (input_spatial_shape[i] + pad_shape[i] -
 | ||||
|     //  ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) /
 | ||||
|     //  strides_spatial_shape[i] + 1)
 | ||||
|     double numerator = inputSpatialShape + padShape - | ||||
|                        ((kernelSpatialShape - 1) * dilations + 1); | ||||
|     double denominator = strideSpatialShape; | ||||
|     int64_t res; | ||||
|     if (ceilMode) { | ||||
|  |  | |||
										
											
File diff suppressed because it is too large
							|  | @ -127,6 +127,10 @@ int main(int argc, char *argv[]) { | |||
| 
 | ||||
|   if (emissionTarget >= EmitMLIR) { | ||||
|     pm.addPass(mlir::createLowerToKrnlPass()); | ||||
|     // An additional pass of canonicalization is helpful because lowering
 | ||||
|     // from ONNX dialect to Standard dialect exposes additional canonicalization
 | ||||
|     // opportunities.
 | ||||
|     pm.addPass(mlir::createCanonicalizerPass()); | ||||
|     pm.addPass(mlir::createLowerKrnlPass()); | ||||
|   } | ||||
| 
 | ||||
|  |  | |||
|  | @ -28,6 +28,11 @@ void ONNXAddOp::getCanonicalizationPatterns( | |||
|     OwningRewritePatternList& results, MLIRContext* context) { | ||||
|   results.insert<MulAddToGemmOptPattern>(context); | ||||
| } | ||||
| 
 | ||||
| void ONNXGemmOp::getCanonicalizationPatterns( | ||||
|         OwningRewritePatternList& results, MLIRContext* context) { | ||||
|     results.insert<FuseGemmFollowedByAddition>(context); | ||||
| } | ||||
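| // Note: no extra wiring is needed for the Gemm fusion registered above; | ||||
| // MLIR's generic canonicalizer gathers per-op patterns, so a driver only | ||||
| // adds the pass (as in the main driver earlier in this commit): | ||||
| //   pm.addPass(mlir::createCanonicalizerPass()); | ||||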
| /// on the ONNXIdentityOp.
 | ||||
| void ONNXIdentityOp::getCanonicalizationPatterns( | ||||
|     OwningRewritePatternList& results, MLIRContext* context) { | ||||
|  |  | |||
|  | @ -26,6 +26,7 @@ include "dialect/onnx/onnx.td" | |||
| 
 | ||||
| def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>; | ||||
| class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>; | ||||
| def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>; | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===// | ||||
| // Pattern-Match and Rewrite | ||||
|  | @ -41,6 +42,11 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), | |||
|                                  (ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)), | ||||
|                                  [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>; | ||||
| 
 | ||||
| // onnx.add(onnx.Gemm(%X, %Y, None), %Z) = onnx.Gemm(%X, %Y, %Z) | ||||
| def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none, $alpha, $beta, $transA, $transB), $bias), | ||||
|                                      (ONNXGemmOp $m1, $m2, $bias, $alpha, $beta, $transA, $transB), | ||||
|                                      [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2), (HasNoneType $none)]>; | ||||
| 
 | ||||
| // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X) | ||||
| def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg), | ||||
|                                      (replaceWithValue $arg)>; | ||||
|  |  | |||
|  | @ -101,3 +101,14 @@ func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32> | |||
|   // CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32> | ||||
|   // CHECK-NEXT: return %1 : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: @test_gemm_add_fusion(%{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128xf32>) -> tensor<*xf32> { | ||||
| func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128xf32>) -> tensor<*xf32> { | ||||
|   %cst = constant unit | ||||
|   %0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128xf32>, tensor<128x128xf32>, none) -> tensor<*xf32> | ||||
|   %1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32> | ||||
|   return %1 : tensor<*xf32> | ||||
| 
 | ||||
|   // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32> | ||||
|   // CHECK-NEXT: return [[GEMM]] : tensor<*xf32> | ||||
| } | ||||