Merge branch 'master' into shapeinference-pad
commit d4f8fef947
@@ -327,10 +327,10 @@ ONNX BatchNormalization operation
 #### Results:
 
 1. `Y`: memref of any type values or tensor of any type values
-1. `out_mean`: memref of any type values or tensor of any type values
-1. `out_var`: memref of any type values or tensor of any type values
-1. `saved_mean`: memref of any type values or tensor of any type values
-1. `saved_var`: memref of any type values or tensor of any type values
+1. `out_mean`: memref of any type values or tensor of any type values or none type
+1. `out_var`: memref of any type values or tensor of any type values or none type
+1. `saved_mean`: memref of any type values or tensor of any type values or none type
+1. `saved_var`: memref of any type values or tensor of any type values or none type
 
 ### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp)
 ONNX BatchNormalization operation in test mode
@@ -375,12 +375,12 @@ ONNX BitShift operation
 
 
 "Bitwise shift operator performs element-wise operation. For each input element, if the"
-" attribute "direction" is "RIGHT", this operator moves its binary representation toward"
-" the right side so that the input value is effectively decreased. If the attribute "direction""
-" is "LEFT", bits of binary representation moves toward the left side, which results the"
+" attribute \"direction\" is \"RIGHT\", this operator moves its binary representation toward"
+" the right side so that the input value is effectively decreased. If the attribute \"direction\""
+" is \"LEFT\", bits of binary representation moves toward the left side, which results the"
 " increase of its actual value. The input X is the tensor to be shifted and another input"
-" Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4],"
-" and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with"
+" Y specifies the amounts of shifting. For example, if \"direction\" is \"Right\", X is [1, 4],"
+" and S is [1, 1], the corresponding output Z would be [0, 2]. If \"direction\" is \"LEFT\" with"
 " X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]."
 " "
 " Because this operator supports Numpy-style broadcasting, X's and Y's shapes are"
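To make the shift semantics quoted above concrete, here is a minimal sketch (illustrative C++, not part of this patch; Numpy-style broadcasting is omitted):

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Illustrative sketch of the BitShift semantics described in the hunk above;
// the function name is hypothetical and broadcasting is left out for brevity.
std::vector<uint64_t> bitShift(const std::vector<uint64_t> &X,
                               const std::vector<uint64_t> &S,
                               const std::string &direction) {
  std::vector<uint64_t> Z(X.size());
  for (size_t i = 0; i < X.size(); ++i)
    Z[i] = (direction == "RIGHT") ? (X[i] >> S[i]) : (X[i] << S[i]);
  return Z;
}
// bitShift({1, 4}, {1, 1}, "RIGHT") yields {0, 2}
// bitShift({1, 2}, {1, 2}, "LEFT")  yields {2, 8}
```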
@@ -413,15 +413,15 @@ ONNX Cast operation
 "the converted type. The 'to' argument must be one of the data types specified"
 "in the 'DataType' enum field in the TensorProto message."
 ""
-"Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations"
-"(e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may"
+"Casting from string tensor in plain (e.g., \"3.14\" and \"1000\") and scientific numeric representations"
+"(e.g., \"1e-5\" and \"1E8\") to float types is supported. For example, converting string \"100.5\" to an integer may"
 "result 100. There are some string literals reserved for special floating-point values;"
-""+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively."
-"Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly,"
-"this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors"
-"to string tensors, plain floating-point representation (such as "314.15926") would be used. "
-"Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases "
-"of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior."
+"\"+INF\" (and \"INF\"), \"-INF\", and \"NaN\" are positive infinity, negative infinity, and not-a-number, respectively."
+"Any string which can exactly match \"+INF\" in a case-insensitive way would be mapped to positive infinite. Similarly,"
+"this case-insensitive rule is applied to \"INF\" and \"NaN\". When casting from numeric tensors"
+"to string tensors, plain floating-point representation (such as \"314.15926\") would be used. "
+"Converting non-numerical-literal string such as \"Hello World!\" is an undefined behavior. Cases "
+"of converting string representing floating-point arithmetic value, such as \"2.718\", to INT is an undefined behavior."
 ""
 "Conversion from a numerical type to any numerical type is always allowed."
 "User must be aware of precision loss and value change caused by range difference between two types."
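A rough sketch of the string-to-float rules quoted above (a hypothetical helper, not onnx-mlir code, assuming the case-insensitive reserved-literal matching the text describes):

```cpp
#include <algorithm>
#include <cctype>
#include <cmath>
#include <string>

// Hypothetical illustration of the quoted casting rules: reserved literals
// "+INF"/"INF"/"-INF"/"NaN" match case-insensitively, everything else is
// parsed as a plain or scientific numeric literal.
float castStringToFloat(std::string s) {
  std::transform(s.begin(), s.end(), s.begin(),
                 [](unsigned char c) { return std::toupper(c); });
  if (s == "+INF" || s == "INF") return INFINITY;
  if (s == "-INF") return -INFINITY;
  if (s == "NAN") return std::nanf("");
  return std::stof(s); // plain ("3.14", "1000") or scientific ("1e-5", "1E8")
}
```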
@@ -476,8 +476,8 @@ ONNX Clip operation
 #### Operands:
 
 1. `input`: memref of any type values or tensor of any type values
-1. `min`: memref of any type values or tensor of any type values
-1. `max`: memref of any type values or tensor of any type values
+1. `min`: memref of any type values or tensor of any type values or none type
+1. `max`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
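The switch of `min`/`max` to "none type" reflects that both bounds are optional; a minimal sketch of the resulting semantics (illustrative only, not onnx-mlir code):

```cpp
#include <algorithm>
#include <limits>
#include <optional>

// Illustrative Clip with optional bounds: a missing bound (none) simply
// leaves that side unclamped.
float clip(float x, std::optional<float> min, std::optional<float> max) {
  float lo = min.value_or(std::numeric_limits<float>::lowest());
  float hi = max.value_or(std::numeric_limits<float>::max());
  return std::min(std::max(x, lo), hi);
}
```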
@@ -618,8 +618,8 @@ ONNX ConvInteger operation
 
 1. `x`: memref of any type values or tensor of any type values
 1. `w`: memref of any type values or tensor of any type values
-1. `x_zero_point`: memref of any type values or tensor of any type values
-1. `w_zero_point`: memref of any type values or tensor of any type values
+1. `x_zero_point`: memref of any type values or tensor of any type values or none type
+1. `w_zero_point`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -678,7 +678,7 @@ ONNX Conv operation
 
 1. `X`: memref of any type values or tensor of any type values
 1. `W`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -720,7 +720,7 @@ ONNX ConvTranspose operation
 
 1. `X`: memref of any type values or tensor of any type values
 1. `W`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -884,7 +884,7 @@ ONNX DequantizeLinear operation
 
 1. `x`: memref of any type values or tensor of any type values
 1. `x_scale`: memref of any type values or tensor of any type values
-1. `x_zero_point`: memref of any type values or tensor of any type values
+1. `x_zero_point`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -964,7 +964,7 @@ ONNX Dropout operation
 #### Results:
 
 1. `output`: memref of any type values or tensor of any type values
-1. `mask`: memref of any type values or tensor of any type values
+1. `mask`: memref of any type values or tensor of any type values or none type
 
 ### onnx.DynamicQuantizeLinear (ONNXDynamicQuantizeLinearOp)
 ONNX DynamicQuantizeLinear operation
@@ -1297,9 +1297,9 @@ ONNX GRU operation
 1. `X`: memref of any type values or tensor of any type values
 1. `W`: memref of any type values or tensor of any type values
 1. `R`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-1. `sequence_lens`: memref of any type values or tensor of any type values
-1. `initial_h`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
+1. `sequence_lens`: memref of any type values or tensor of any type values or none type
+1. `initial_h`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -1315,8 +1315,8 @@ ONNX GRU operation
 
 #### Results:
 
-1. `Y`: memref of any type values or tensor of any type values
-1. `Y_h`: memref of any type values or tensor of any type values
+1. `Y`: memref of any type values or tensor of any type values or none type
+1. `Y_h`: memref of any type values or tensor of any type values or none type
 
 ### onnx.GatherElements (ONNXGatherElementsOp)
 ONNX GatherElements operation
@@ -1609,7 +1609,7 @@ ONNX Gemm operation
 
 1. `A`: memref of any type values or tensor of any type values
 1. `B`: memref of any type values or tensor of any type values
-1. `C`: memref of any type values or tensor of any type values
+1. `C`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -2013,11 +2013,11 @@ ONNX LSTM operation
 1. `X`: memref of any type values or tensor of any type values
 1. `W`: memref of any type values or tensor of any type values
 1. `R`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-1. `sequence_lens`: memref of any type values or tensor of any type values
-1. `initial_h`: memref of any type values or tensor of any type values
-1. `initial_c`: memref of any type values or tensor of any type values
-1. `P`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
+1. `sequence_lens`: memref of any type values or tensor of any type values or none type
+1. `initial_h`: memref of any type values or tensor of any type values or none type
+1. `initial_c`: memref of any type values or tensor of any type values or none type
+1. `P`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -2033,9 +2033,9 @@ ONNX LSTM operation
 
 #### Results:
 
-1. `Y`: memref of any type values or tensor of any type values
-1. `Y_h`: memref of any type values or tensor of any type values
-1. `Y_c`: memref of any type values or tensor of any type values
+1. `Y`: memref of any type values or tensor of any type values or none type
+1. `Y_h`: memref of any type values or tensor of any type values or none type
+1. `Y_c`: memref of any type values or tensor of any type values or none type
 
 ### onnx.LeakyRelu (ONNXLeakyReluOp)
 ONNX LeakyRelu operation
@@ -2160,24 +2160,24 @@ ONNX Loop operation
 ""
 " Operator inputs defined as (max_trip_count, condition_var)."
 ""
-" input ("", ""):"
+" input (\"\", \"\"):"
 " for (int i=0; ; ++i) {"
 " cond = ... // Note this value is ignored, but is required in the body"
 " }"
 ""
-" input ("", cond) // Note this is analogous to a while loop"
+" input (\"\", cond) // Note this is analogous to a while loop"
 " bool cond = ...;"
 " for (int i=0; cond; ++i) {"
 " cond = ...;"
 " }"
 ""
-" input ("", 1) // Note this is analogous to a do-while loop"
+" input (\"\", 1) // Note this is analogous to a do-while loop"
 " bool cond = true"
 " for (int i=0; cond; ++i) {"
 " cond = ...;"
 " }"
 ""
-" input (trip_count, "") // Note this is analogous to a for loop"
+" input (trip_count, \"\") // Note this is analogous to a for loop"
 " int trip_count = ..."
 " for (int i=0; i < trip_count; ++i) {"
 " cond = ...; // ignored"
@@ -2203,15 +2203,15 @@ ONNX Loop operation
 " }"
 ""
 " graph body-net ("
-" %i[INT32, scalar] // iteration number"
-" %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used"
-" %b_in[INT32, scalar] // incoming value of loop-carried-dependency b"
+" %i[INT32, scalar]"
+" %keepgoing[BOOL, scalar]"
+" %b[INT32, scalar]"
 " ) {"
-" %my_local = Add(%a, %b_in)"
-" %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b"
-" %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition"
-" %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated"
-" return %keepgoing_out, %b_out, %user_defined_val"
+" %my_local = Add(%a, %b)"
+" %b_out = Sub(%a, %b)"
+" %keepgoing_out = Greater(%my_local, %b_out)"
+" %user_defined_vals = Add(%b, %b)"
+" return %keepgoing_out, %b_out, %user_defined_vals"
 " }"
 ""
 "*Sample equivalent C code*"
@@ -2226,51 +2226,31 @@ ONNX Loop operation
 " const int max_trip_count = 10; // Analogous to input M"
 " int user_defined_vals[]; // Imagine this is resizable"
 " /* End implicitly-defined code */"
-" /* initialize loop-carried variables and scan-output variables */"
-" bool keepgoing_out = keepgoing"
-" int b_out = b"
-""
-" for (int i=0; i < max_trip_count && keepgoing_out; ++i) {"
-" /* Implicitly-defined code: bind actual parameter values"
-" to formal parameter variables of loop-body */"
-" bool keepgoing_in = keepgoing_out; "
-" bool b_in = b_out;"
-""
+" for (int i=0; i < max_trip_count && keepgoing; ++i) {"
 " /* User-defined code (loop body) */"
-" int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine"
-" b_out = a - b_in;"
-" keepgoing_out = my_local > b_out; "
-" user_defined_val = b_in + b_in; // b_in and b_out are different variables"
+" int my_local = a + b; // Reading values in the enclosing scope is fine"
+" b = a - b; // writes fine if we specify b as a loop-carried dependency"
+" keepgoing = my_local > b; // keepgoing is a loop-carried dependency"
+" user_defined_vals[i] = b + b;"
 " /* End user-defined code */"
-""
-" /* Implicitly defined-code */"
-" user_defined_vals[i] = user_defined_val // accumulate scan-output values"
 " }"
-" // int t = my_local; // Can't do this. my_local is not accessible here."
+" // my_local = 123; // Can't do this. my_local was defined in the the body"
 ""
-" // The values below are bound to the output variables of the loop and therefore accessible"
-" // b_out; user_defined_vals; keepgoing_out;"
+" // These below values are live-out from the loop and therefore accessible"
+" b_out; user_defined_vals; keepgoing_out;"
 " }"
 ""
 "There are several things of note in this code snippet:"
 ""
-"1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can"
+"1) Values from the enclosing scope (i.e. variable a here) are in scope and can"
 " be referenced in the inputs of the loop."
-"2) Any values computed in the loop body that needs to be used in a subsequent"
-" iteration or after the loop are modelled using a pair of variables in the loop-body,"
-" consisting of an input variable (eg., b_in) and an output variable (eg., b_out)."
-" These are referred to as loop-carried dependences. The loop operation node"
-" supplies the input value of the input variable for the first iteration, and"
-" returns the output value of the output variable produced by the final"
-" iteration."
-"3) Scan_output variables are used to implicitly concatenate values computed across"
-" all the iterations. In the above example, the value of user_defined_val computed"
-" over all iterations are concatenated and returned as the value of user_defined_vals"
-" after the loop."
-"4) Values created in the body cannot be accessed in the enclosing scope,"
-" except using the mechanism described above."
+"2) Any variables which you wish to make available in the enclosing scope (i.e."
+" the variables b and keepgoing) must be declared as either loop-carried"
+" dependencies (both at the op inputs and output and at the body net input and"
+" output) or scan_outputs."
+"3) Values created in the body cannot be accessed in the enclosing scope."
 ""
-"Note that the semantics of this op support "diagonal" or "wavefront" execution."
+"Note that the semantics of this op support \"diagonal\" or \"wavefront\" execution."
 "(See Step 3 here for an example:"
 "https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/)."
 "Frontends should emit multi-layer RNNs as a series of While operators (with"
@@ -2280,8 +2260,8 @@ ONNX Loop operation
 
 #### Operands:
 
-1. `M`: memref of any type values or tensor of any type values
-1. `cond`: memref of any type values or tensor of any type values
+1. `M`: memref of any type values or tensor of any type values or none type
+1. `cond`: memref of any type values or tensor of any type values or none type
 1. `v_initial`: memref of any type values or tensor of any type values
 
 #### Attributes:
@@ -2360,8 +2340,8 @@ ONNX MatMulInteger operation
 
 1. `A`: memref of any type values or tensor of any type values
 1. `B`: memref of any type values or tensor of any type values
-1. `a_zero_point`: memref of any type values or tensor of any type values
-1. `b_zero_point`: memref of any type values or tensor of any type values
+1. `a_zero_point`: memref of any type values or tensor of any type values or none type
+1. `b_zero_point`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -2444,7 +2424,7 @@ ONNX MaxPool operation
 " ```"
 " pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i]"
 " ```"
-" The output of each pooling window is maximum number of elements exclude pad. "
+" The output of each pooling window is maximum number of elements exclude pad."
 " "
 
 #### Operands:
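A worked instance of the pad_shape formula above, with illustrative numbers (not taken from the patch):

```
// Illustrative values, assuming SAME_* auto padding:
// input_spatial_shape[i] = 5, kernel_spatial_shape[i] = 3,
// strides_spatial_shape[i] = 2, dilations[i] = 1
// output_spatial_shape[i] = ceil(5 / 2) = 3
// pad_shape[i] = (3 - 1) * 2 + ((3 - 1) * 1 + 1) - 5 = 4 + 3 - 5 = 2
```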
@@ -2466,7 +2446,7 @@ ONNX MaxPool operation
 #### Results:
 
 1. `Y`: memref of any type values or tensor of any type values
-1. `Indices`: memref of any type values or tensor of any type values
+1. `Indices`: memref of any type values or tensor of any type values or none type
 
 ### onnx.MaxPoolSingleOut (ONNXMaxPoolSingleOutOp)
 ONNX MaxPool operation with a single output.
@@ -2552,7 +2532,7 @@ ONNX MaxUnpool operation
 
 1. `X`: memref of any type values or tensor of any type values
 1. `I`: memref of any type values or tensor of any type values
-1. `output_shape`: memref of any type values or tensor of any type values
+1. `output_shape`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -2752,9 +2732,9 @@ ONNX NonMaxSuppression operation
 
 1. `boxes`: memref of any type values or tensor of any type values
 1. `scores`: memref of any type values or tensor of any type values
-1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values
-1. `iou_threshold`: memref of any type values or tensor of any type values
-1. `score_threshold`: memref of any type values or tensor of any type values
+1. `max_output_boxes_per_class`: memref of any type values or tensor of any type values or none type
+1. `iou_threshold`: memref of any type values or tensor of any type values or none type
+1. `score_threshold`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -3067,7 +3047,7 @@ ONNX Pad operation
 
 1. `data`: memref of any type values or tensor of any type values
 1. `pads`: memref of any type values or tensor of any type values
-1. `constant_value`: memref of any type values or tensor of any type values
+1. `constant_value`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -3124,7 +3104,7 @@ ONNX QLinearConv operation
 1. `w_zero_point`: memref of any type values or tensor of any type values
 1. `y_scale`: memref of any type values or tensor of any type values
 1. `y_zero_point`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -3188,7 +3168,7 @@ ONNX QuantizeLinear operation
 
 1. `x`: memref of any type values or tensor of any type values
 1. `y_scale`: memref of any type values or tensor of any type values
-1. `y_zero_point`: memref of any type values or tensor of any type values
+1. `y_zero_point`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -3270,9 +3250,9 @@ ONNX RNN operation
 1. `X`: memref of any type values or tensor of any type values
 1. `W`: memref of any type values or tensor of any type values
 1. `R`: memref of any type values or tensor of any type values
-1. `B`: memref of any type values or tensor of any type values
-1. `sequence_lens`: memref of any type values or tensor of any type values
-1. `initial_h`: memref of any type values or tensor of any type values
+1. `B`: memref of any type values or tensor of any type values or none type
+1. `sequence_lens`: memref of any type values or tensor of any type values or none type
+1. `initial_h`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -3287,8 +3267,8 @@ ONNX RNN operation
 
 #### Results:
 
-1. `Y`: memref of any type values or tensor of any type values
-1. `Y_h`: memref of any type values or tensor of any type values
+1. `Y`: memref of any type values or tensor of any type values or none type
+1. `Y_h`: memref of any type values or tensor of any type values or none type
 
 ### onnx.RandomNormalLike (ONNXRandomNormalLikeOp)
 ONNX RandomNormalLike operation
@@ -3813,14 +3793,14 @@ ONNX Resize operation
 
 "Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor."
 "Each dimension value of the output tensor is:"
-" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified."
+" output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified."
 
 #### Operands:
 
 1. `X`: memref of any type values or tensor of any type values
 1. `roi`: memref of any type values or tensor of any type values
 1. `scales`: memref of any type values or tensor of any type values
-1. `sizes`: memref of any type values or tensor of any type values
+1. `sizes`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
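A worked instance of the output_dimension formula above (illustrative numbers, not from the patch):

```
// input_dimension = 4, roi_start = 0.0, roi_end = 1.0, scale = 2.5
// output_dimension = floor(4 * (1.0 - 0.0) * 2.5) = floor(10.0) = 10
```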
@@ -4438,7 +4418,7 @@ ONNX SequenceErase operation
 #### Operands:
 
 1. `input_sequence`: memref of any type values or tensor of any type values
-1. `position`: memref of any type values or tensor of any type values
+1. `position`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -4463,7 +4443,7 @@ ONNX SequenceInsert operation
 
 1. `input_sequence`: memref of any type values or tensor of any type values
 1. `tensor`: memref of any type values or tensor of any type values
-1. `position`: memref of any type values or tensor of any type values
+1. `position`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -4680,8 +4660,8 @@ ONNX Slice operation
 1. `data`: memref of any type values or tensor of any type values
 1. `starts`: memref of any type values or tensor of any type values
 1. `ends`: memref of any type values or tensor of any type values
-1. `axes`: memref of any type values or tensor of any type values
-1. `steps`: memref of any type values or tensor of any type values
+1. `axes`: memref of any type values or tensor of any type values or none type
+1. `steps`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -4834,7 +4814,7 @@ ONNX SplitToSequence operation
 #### Operands:
 
 1. `input`: memref of any type values or tensor of any type values
-1. `split`: memref of any type values or tensor of any type values
+1. `split`: memref of any type values or tensor of any type values or none type
 
 #### Attributes:
 
@@ -4902,9 +4882,9 @@ ONNX StringNormalizer operation
 "StringNormalization performs string operations for basic cleaning."
 "This operator has only one input (denoted by X) and only one output"
 "(denoted by Y). This operator first examines the elements in the X,"
-"and removes elements specified in "stopwords" attribute."
+"and removes elements specified in \"stopwords\" attribute."
 "After removing stop words, the intermediate result can be further lowercased,"
-"uppercased, or just returned depending the "case_change_action" attribute."
+"uppercased, or just returned depending the \"case_change_action\" attribute."
 "This operator only accepts [C]- and [1, C]-tensor."
 "If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]"
 "if input shape is [C] and shape [1, 1] if input shape is [1, C]."
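An illustrative sketch of the behavior quoted above (not onnx-mlir code): drop the stopwords, then optionally upper-case the surviving elements.

```cpp
#include <algorithm>
#include <cctype>
#include <set>
#include <string>
#include <vector>

// Illustrative only; the names are hypothetical.
std::vector<std::string> stringNormalize(const std::vector<std::string> &X,
                                         const std::set<std::string> &stopwords,
                                         bool toUpper) {
  std::vector<std::string> Y;
  for (std::string s : X) {
    if (stopwords.count(s))
      continue; // removed per the "stopwords" attribute
    if (toUpper) // one of the "case_change_action" choices
      std::transform(s.begin(), s.end(), s.begin(),
                     [](unsigned char c) { return std::toupper(c); });
    Y.push_back(s);
  }
  return Y; // the [1]-shaped empty-string fallback is omitted for brevity
}
```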
@@ -5034,8 +5014,8 @@ ONNX TfIdfVectorizer operation
 "respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output."
 "Note that we may consider all skips up to S when generating the n-grams."
 ""
-"The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and"
-"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF","
+"The examples used above are true if mode is \"TF\". If mode is \"IDF\", all the counts larger than 1 would be truncated to 1 and"
+"the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is \"TFIDF\","
 "this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute."
 ""
 "Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor."
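A compact sketch of the three modes as described in the text above (illustrative, not onnx-mlir code):

```cpp
#include <algorithm>
#include <string>
#include <vector>

// counts[i] is the raw frequency of the i-th pool n-gram in the input.
std::vector<float> applyMode(std::vector<float> counts,
                             const std::vector<float> &weights,
                             const std::string &mode) {
  for (size_t i = 0; i < counts.size(); ++i) {
    if (mode == "IDF")
      counts[i] = std::min(counts[i], 1.0f) * weights[i]; // presence, scaled
    else if (mode == "TFIDF")
      counts[i] *= weights[i]; // raw count scaled by weight
    // mode == "TF": raw counts pass through unchanged
  }
  return counts;
}
```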
@@ -5123,9 +5103,9 @@ ONNX TopK operation
 " contains the indices of the top k elements (original indices from the input"
 " tensor)."
 ""
-"If "largest" is 1 (the default value) then the k largest elements are returned."
-"If "sorted" is 1 (the default value) then the resulting k elements will be sorted."
-"If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined."
+"If \"largest\" is 1 (the default value) then the k largest elements are returned."
+"If \"sorted\" is 1 (the default value) then the resulting k elements will be sorted."
+"If \"sorted\" is 0, order of returned 'Values' and 'Indices' are undefined."
 ""
 "Given two equivalent values, this operator uses the indices along the axis as"
 " a tiebreaker. That is, the element with the lower index will appear first."
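A worked instance of the rules above (illustrative values):

```
// X = [3, 1, 3, 5], k = 2, axis = 0, largest = 1, sorted = 1
// Values  = [5, 3]
// Indices = [3, 0]   // the 3 at index 0 beats the 3 at index 2 (lower index)
```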
@@ -5184,7 +5164,7 @@ ONNX Unique operation
 "This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. "
 "The first output tensor 'Y' contains all unique values or subtensors of the input. "
 "The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. "
-"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". "
+"The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. \". "
 "The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. "
 ""
 "Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. "
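A worked instance of the four outputs above (illustrative values, sorted = 0):

```
// X               = [2, 1, 1, 3, 4, 3]
// Y               = [2, 1, 3, 4]         // unique values, first-seen order
// indices         = [0, 1, 3, 4]         // first occurrence of each Y in X
// inverse_indices = [0, 1, 1, 2, 3, 2]   // Y index for each element of X
// counts          = [1, 2, 2, 1]         // occurrences of each Y in X
```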
@@ -5268,9 +5248,9 @@ ONNX Unique operation
 #### Results:
 
 1. `Y`: memref of any type values or tensor of any type values
-1. `indices`: memref of any type values or tensor of any type values
-1. `inverse_indices`: memref of any type values or tensor of any type values
-1. `counts`: memref of any type values or tensor of any type values
+1. `indices`: memref of any type values or tensor of any type values or none type
+1. `inverse_indices`: memref of any type values or tensor of any type values or none type
+1. `counts`: memref of any type values or tensor of any type values or none type
 
 ### onnx.Unsqueeze (ONNXUnsqueezeOp)
 ONNX Unsqueeze operation
doc/gen_doc.py (1041 lines changed; diff suppressed because it is too large)
@@ -121,6 +121,7 @@ private:
 mlir::MLIRContext &context_;
 mlir::ModuleOp module_;
 mlir::OpBuilder builder_;
+mlir::Value none_;
 // mapping between string name and symbol
 OnnxOnnfSymbolMapping frontend_symbols_;
 
@@ -287,8 +288,8 @@ private:
 }
 }
 
-std::vector<mlir::NamedAttribute> ImportNodeAttributes(
-const onnx::NodeProto &node) {
+std::vector<mlir::NamedAttribute>
+ImportNodeAttributes(const onnx::NodeProto &node) {
 std::vector<mlir::NamedAttribute> attributes;
 for (int i = 0; i < node.attribute_size(); ++i) {
 auto attr = node.attribute(i);
@@ -317,21 +318,11 @@ private:
 }
 }
 
-// if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be
-// combined with 'if constexpr' the issue is the type of the output is
-// different. alternative way to use variadic output for all the op
-
-/*!
-* Important onnx node which generates only one output
-* @param node onnx node
-* @param nIn number of expected inputs
-* @param nOut number of expected outputs
-* @param attrs list of desription for attributes with format {name, type,
-* default}
-*/
 template <typename T>
-void ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
-bool variadicIn = false, bool variadicOut = false) {
+void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
+int expectedNumResults = -1) {
+bool variadicIn = expectedNumOperands == -1;
+bool variadicOut = expectedNumResults == -1;
 std::vector<mlir::Value> inputs;
 for (const auto &item : node.input()) {
 if (frontend_symbols_.ContainKey(legalize_name(item))) {
@@ -339,6 +330,10 @@ private:
 }
 }
 
+if (!variadicIn)
+for (auto i = inputs.size(); i < expectedNumOperands; i++)
+inputs.emplace_back(none_);
+
 std::vector<mlir::Type> outputTypes;
 for (auto item : node.output()) {
 outputTypes.push_back(
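The effect of the `none_` padding added above, sketched on a concrete case (the Clip numbers are an assumption drawn from the dispatch table later in this diff):

```cpp
// Illustrative only: an ONNX Clip node that supplies {input, min} but omits
// the optional max, imported with expectedNumOperands == 3, ends up with
// operands {input, min, none_}. The op thus keeps a fixed operand count and
// models the missing operand as a value of none type, rather than needing a
// separate operation variant per optional-input combination.
```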
@@ -347,49 +342,11 @@ private:
 
 auto attributes = ImportNodeAttributes(node);
 
-llvm::StringRef OpName = node.op_type();
-if ((variadicIn || nIn == inputs.size()) &&
-(variadicOut || nOut == outputTypes.size())) {
-auto op =
-builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
-frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
-op.getResult());
-} else {
-ImportNodeGeneric(node);
-}
-}
-
-template <typename T>
-void ImportNodeMultipleOuts(const onnx::NodeProto &node, int nIn, int nOut,
-bool variadicIn = false,
-bool variadicOut = false) {
-std::vector<mlir::Value> inputs;
-for (const auto &item : node.input()) {
-if (frontend_symbols_.ContainKey(legalize_name(item))) {
-inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
-}
-}
-
-std::vector<mlir::Type> outputTypes;
-for (auto item : node.output()) {
-outputTypes.push_back(
-mlir::UnrankedTensorType::get(builder_.getF32Type()));
-}
-
-auto attributes = ImportNodeAttributes(node);
-
-llvm::StringRef OpName = node.op_type();
-
-if ((variadicIn || nIn == inputs.size()) &&
-(variadicOut || nOut == outputTypes.size())) {
-auto op =
-builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
-for (int i = 0; i < node.output().size(); i++) {
-frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
-op.getResult(i));
-}
-} else {
-ImportNodeGeneric(node);
-}
+// TODO: Handle optional inputs.
+auto op = builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
+for (int i = 0; i < node.output().size(); i++) {
+frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
+*(op.getODSResults(i).begin()));
+}
 }
 }
 
@@ -398,8 +355,7 @@ private:
 * c++ does not allow template specialization inside a class scope
 * a specialized function is used
 */
-void
-ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
+void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
 // Conv has attribute dilations, kernel_shape, pads, the default value of
 // which is determined by the shape of first argument. However, since the
 // shape is unknown now, these attributes can be not generated auto
@@ -413,24 +369,20 @@ private:
 int nOps = node.input().size();
 
 if (nOps == 2)
-ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(
-node, nOps, nOut);
+buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
 else
-ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut);
+buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
 }
 
 /*!
 * Special handle for MaxPool operations.
 */
-void ImportNodeMaxPool(
-onnx::NodeProto node, int nIn, int nOut) {
+void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) {
 int nOuts = node.output().size();
 if (nOuts == 1) {
-ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(
-node, nIn, nOuts);
+buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts);
 } else {
-ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(
-node, nIn, nOuts);
+buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts);
 }
 }
 
@@ -441,23 +393,10 @@ private:
 int nOuts = node.output().size();
 if (nOuts == 1) {
 // Test mode with one output.
-ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn,
-nOuts);
+buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts);
 } else {
 // Training mode with four trailing optional outputs. Not handled yet.
-ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
-}
-}
-
-/*!
-* Special handle for Gemm operations.
-*/
-void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) {
-int nOps = node.input().size();
-if (nOps == 2) {
-ImportNodeOneOut<mlir::ONNXGemmNoBiasOp>(node, 2, nOut);
-} else {
-ImportNodeOneOut<mlir::ONNXGemmOp>(node, nIn, nOut);
+buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
 }
 }
 
@@ -467,28 +406,14 @@ private:
 void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
 int nOps = node.input().size();
 if (nOps == 2) {
-ImportNodeOneOut<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
+buildOperation<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
 } else {
-ImportNodeOneOut<mlir::ONNXPadOp>(node, nIn, nOut);
+buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
 }
 }
 
 void ImportNode(const onnx::NodeProto &node) {
-std::vector<mlir::Value> inputs;
-for (const auto &item : node.input()) {
-if (frontend_symbols_.ContainKey(legalize_name(item))) {
-inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
-}
-}
-
-std::vector<mlir::Type> outputTypes;
-for (auto item : node.output()) {
-outputTypes.push_back(
-mlir::UnrankedTensorType::get(builder_.getF32Type()));
-}
-
-std::vector<mlir::NamedAttribute> attributes;
-llvm::StringRef OpName = node.op_type();
+llvm::StringRef opName = node.op_type();
 
 // the following code is generated by gen_doc.py
 // refer to dialect/onnx/onnx.td for details
@@ -555,9 +480,11 @@ private:
 ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
 }
 
-// import nodes in the graph
-auto node = graph.node();
-for (const auto &item : node) {
+// Create a NoneTyped constant.
+none_ =
+builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
+// Import nodes in the graph.
+for (const auto &item : graph.node()) {
 ImportNode(item);
 }
 
@@ -1,320 +1,319 @@
 //********************************************************
-// Warning: Do not modify this file directly
-// This file is automatically generated via script
-// Details can be found in doc/readonnxdefs.md
+// This file is generated on UTC-02/24/2020, 06:29:01.
+// Do not modify this file directly.
+// This file is automatically generated via script.
+// Details can be found in doc/readonnxdefs.md .
 //********************************************************
 
-if (OpName == "DUMMY") {
+if (opName == "Abs")
-}else if (OpName == "Abs") {
+return buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1);
+if (opName == "Acos")
-}else if (OpName == "Acos") {
+return buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1);
+if (opName == "Acosh")
-}else if (OpName == "Acosh") {
+return buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1);
+if (opName == "Add")
-}else if (OpName == "Add") {
+return buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1);
+if (opName == "And")
-}else if (OpName == "And") {
+return buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1);
+if (opName == "ArgMax")
-}else if (OpName == "ArgMax") {
+return buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1);
+if (opName == "ArgMin")
-}else if (OpName == "ArgMin") {
+return buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1);
+if (opName == "Asin")
-}else if (OpName == "Asin") {
+return buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1);
+if (opName == "Asinh")
-}else if (OpName == "Asinh") {
+return buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1);
+if (opName == "Atan")
-}else if (OpName == "Atan") {
+return buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1);
+if (opName == "Atanh")
-}else if (OpName == "Atanh") {
+return buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1);
+if (opName == "AveragePool")
-}else if (OpName == "AveragePool") {
+return buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1);
+if (opName == "BatchNormalization")
-}else if (OpName == "BatchNormalization") {
+return ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
-ImportNodeBatchNormalization(node, 5, 5);
+if (opName == "BitShift")
-}else if (OpName == "BitShift") {
+return buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1);
+if (opName == "Cast")
-}else if (OpName == "Cast") {
+return buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1);
+if (opName == "Ceil")
-}else if (OpName == "Ceil") {
+return buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1);
+if (opName == "Clip")
-}else if (OpName == "Clip") {
+return buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1);
+if (opName == "Compress")
-}else if (OpName == "Compress") {
+return buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1);
+if (opName == "Concat")
-}else if (OpName == "Concat") {
+return buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, true, false);
+if (opName == "ConcatFromSequence")
-}else if (OpName == "ConcatFromSequence") {
+return buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1);
+if (opName == "Constant")
-}else if (OpName == "Constant") {
+return buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1);
+if (opName == "ConstantOfShape")
-}else if (OpName == "ConstantOfShape") {
+return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1);
+if (opName == "Conv")
-}else if (OpName == "Conv") {
+return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
-ImportNodeConv(node, 3, 1);
+if (opName == "ConvInteger")
-}else if (OpName == "ConvInteger") {
+return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1);
+if (opName == "ConvTranspose")
-}else if (OpName == "ConvTranspose") {
+return buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1);
+if (opName == "Cos")
-}else if (OpName == "Cos") {
+return buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1);
+if (opName == "Cosh")
-}else if (OpName == "Cosh") {
+return buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1);
+if (opName == "CumSum")
-}else if (OpName == "CumSum") {
+return buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1);
+if (opName == "DepthToSpace")
-}else if (OpName == "DepthToSpace") {
+return buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1);
+if (opName == "DequantizeLinear")
-}else if (OpName == "DequantizeLinear") {
+return buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1);
+if (opName == "Det")
-}else if (OpName == "Det") {
+return buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1);
+if (opName == "Div")
-}else if (OpName == "Div") {
+return buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1);
+if (opName == "Dropout")
-}else if (OpName == "Dropout") {
+return buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
-ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2);
+if (opName == "DynamicQuantizeLinear")
-}else if (OpName == "DynamicQuantizeLinear") {
+return buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
-ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3);
+if (opName == "Elu")
-}else if (OpName == "Elu") {
+return buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1);
+if (opName == "Equal")
-}else if (OpName == "Equal") {
+return buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1);
+if (opName == "Erf")
-}else if (OpName == "Erf") {
+return buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1);
+if (opName == "Exp")
-}else if (OpName == "Exp") {
+return buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1);
+if (opName == "Expand")
-}else if (OpName == "Expand") {
+return buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
-ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1);
+if (opName == "EyeLike")
}else if (OpName == "EyeLike") {
|
return buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1);
|
if (opName == "Flatten")
|
||||||
}else if (OpName == "Flatten") {
|
return buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1);
|
if (opName == "Floor")
|
||||||
}else if (OpName == "Floor") {
|
return buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1);
|
if (opName == "GRU")
|
||||||
}else if (OpName == "GRU") {
|
return buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
|
||||||
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2);
|
if (opName == "Gather")
|
||||||
}else if (OpName == "Gather") {
|
return buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1);
|
if (opName == "GatherElements")
|
||||||
}else if (OpName == "GatherElements") {
|
return buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1);
|
if (opName == "GatherND")
|
||||||
}else if (OpName == "GatherND") {
|
return buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1);
|
if (opName == "Gemm")
|
||||||
}else if (OpName == "Gemm") {
|
return buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodeGemm(node, 3, 1);
|
if (opName == "GlobalAveragePool")
|
||||||
}else if (OpName == "GlobalAveragePool") {
|
return buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1);
|
if (opName == "GlobalLpPool")
|
||||||
}else if (OpName == "GlobalLpPool") {
|
return buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1);
|
if (opName == "GlobalMaxPool")
|
||||||
}else if (OpName == "GlobalMaxPool") {
|
return buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1);
|
if (opName == "Greater")
|
||||||
}else if (OpName == "Greater") {
|
return buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1);
|
if (opName == "HardSigmoid")
|
||||||
}else if (OpName == "HardSigmoid") {
|
return buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1);
|
if (opName == "Hardmax")
|
||||||
}else if (OpName == "Hardmax") {
|
return buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1);
|
if (opName == "Identity")
|
||||||
}else if (OpName == "Identity") {
|
return buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1);
|
if (opName == "If")
|
||||||
}else if (OpName == "If") {
|
return buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
|
||||||
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1);
|
if (opName == "InstanceNormalization")
|
||||||
}else if (OpName == "InstanceNormalization") {
|
return buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1);
|
if (opName == "IsInf")
|
||||||
}else if (OpName == "IsInf") {
|
return buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1);
|
if (opName == "IsNaN")
|
||||||
}else if (OpName == "IsNaN") {
|
return buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1);
|
if (opName == "LRN")
|
||||||
}else if (OpName == "LRN") {
|
return buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1);
|
if (opName == "LSTM")
|
||||||
}else if (OpName == "LSTM") {
|
return buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
|
||||||
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3);
|
if (opName == "LeakyRelu")
|
||||||
}else if (OpName == "LeakyRelu") {
|
return buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1);
|
if (opName == "Less")
|
||||||
}else if (OpName == "Less") {
|
return buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1);
|
if (opName == "Log")
|
||||||
}else if (OpName == "Log") {
|
return buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1);
|
if (opName == "LogSoftmax")
|
||||||
}else if (OpName == "LogSoftmax") {
|
return buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1);
|
if (opName == "Loop")
|
||||||
}else if (OpName == "Loop") {
|
return buildOperation<mlir::ONNXLoopOp>(node);
|
||||||
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1);
|
if (opName == "LpNormalization")
|
||||||
}else if (OpName == "LpNormalization") {
|
return buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1);
|
if (opName == "LpPool")
|
||||||
}else if (OpName == "LpPool") {
|
return buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1);
|
if (opName == "MatMul")
|
||||||
}else if (OpName == "MatMul") {
|
return buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1);
|
if (opName == "MatMulInteger")
|
||||||
}else if (OpName == "MatMulInteger") {
|
return buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1);
|
if (opName == "Max")
|
||||||
}else if (OpName == "Max") {
|
return buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, true, false);
|
if (opName == "MaxPool")
|
||||||
}else if (OpName == "MaxPool") {
|
return ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
|
||||||
ImportNodeMaxPool(node, 1, 2);
|
if (opName == "MaxRoiPool")
|
||||||
}else if (OpName == "MaxRoiPool") {
|
return buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1);
|
if (opName == "MaxUnpool")
|
||||||
}else if (OpName == "MaxUnpool") {
|
return buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1);
|
if (opName == "Mean")
|
||||||
}else if (OpName == "Mean") {
|
return buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, true, false);
|
if (opName == "MeanVarianceNormalization")
|
||||||
}else if (OpName == "MeanVarianceNormalization") {
|
return buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1);
|
if (opName == "Min")
|
||||||
}else if (OpName == "Min") {
|
return buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, true, false);
|
if (opName == "Mod")
|
||||||
}else if (OpName == "Mod") {
|
return buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1);
|
if (opName == "Mul")
|
||||||
}else if (OpName == "Mul") {
|
return buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1);
|
if (opName == "Multinomial")
|
||||||
}else if (OpName == "Multinomial") {
|
return buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1);
|
if (opName == "Neg")
|
||||||
}else if (OpName == "Neg") {
|
return buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1);
|
if (opName == "NonMaxSuppression")
|
||||||
}else if (OpName == "NonMaxSuppression") {
|
return buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1);
|
if (opName == "NonZero")
|
||||||
}else if (OpName == "NonZero") {
|
return buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1);
|
if (opName == "Not")
|
||||||
}else if (OpName == "Not") {
|
return buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1);
|
if (opName == "OneHot")
|
||||||
}else if (OpName == "OneHot") {
|
return buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1);
|
if (opName == "Or")
|
||||||
}else if (OpName == "Or") {
|
return buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1);
|
if (opName == "PRelu")
|
||||||
}else if (OpName == "PRelu") {
|
return buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1);
|
if (opName == "Pad")
|
||||||
}else if (OpName == "Pad") {
|
return ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodePad(node, 3, 1);
|
if (opName == "Pow")
|
||||||
}else if (OpName == "Pow") {
|
return buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1);
|
if (opName == "QLinearConv")
|
||||||
}else if (OpName == "QLinearConv") {
|
return buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1);
|
if (opName == "QLinearMatMul")
|
||||||
}else if (OpName == "QLinearMatMul") {
|
return buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1);
|
if (opName == "QuantizeLinear")
|
||||||
}else if (OpName == "QuantizeLinear") {
|
return buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1);
|
if (opName == "RNN")
|
||||||
}else if (OpName == "RNN") {
|
return buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
|
||||||
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2);
|
if (opName == "RandomNormal")
|
||||||
}else if (OpName == "RandomNormal") {
|
return buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1);
|
if (opName == "RandomNormalLike")
|
||||||
}else if (OpName == "RandomNormalLike") {
|
return buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1);
|
if (opName == "RandomUniform")
|
||||||
}else if (OpName == "RandomUniform") {
|
return buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1);
|
if (opName == "RandomUniformLike")
|
||||||
}else if (OpName == "RandomUniformLike") {
|
return buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1);
|
if (opName == "Range")
|
||||||
}else if (OpName == "Range") {
|
return buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1);
|
if (opName == "Reciprocal")
|
||||||
}else if (OpName == "Reciprocal") {
|
return buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1);
|
if (opName == "ReduceL1")
|
||||||
}else if (OpName == "ReduceL1") {
|
return buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1);
|
if (opName == "ReduceL2")
|
||||||
}else if (OpName == "ReduceL2") {
|
return buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1);
|
if (opName == "ReduceLogSum")
|
||||||
}else if (OpName == "ReduceLogSum") {
|
return buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1);
|
if (opName == "ReduceLogSumExp")
|
||||||
}else if (OpName == "ReduceLogSumExp") {
|
return buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1);
|
if (opName == "ReduceMax")
|
||||||
}else if (OpName == "ReduceMax") {
|
return buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1);
|
if (opName == "ReduceMean")
|
||||||
}else if (OpName == "ReduceMean") {
|
return buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1);
|
if (opName == "ReduceMin")
|
||||||
}else if (OpName == "ReduceMin") {
|
return buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1);
|
if (opName == "ReduceProd")
|
||||||
}else if (OpName == "ReduceProd") {
|
return buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1);
|
if (opName == "ReduceSum")
|
||||||
}else if (OpName == "ReduceSum") {
|
return buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1);
|
if (opName == "ReduceSumSquare")
|
||||||
}else if (OpName == "ReduceSumSquare") {
|
return buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1);
|
if (opName == "Relu")
|
||||||
}else if (OpName == "Relu") {
|
return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1);
|
if (opName == "Reshape")
|
||||||
}else if (OpName == "Reshape") {
|
return buildOperation<mlir::ONNXReshapeOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1);
|
if (opName == "Resize")
|
||||||
}else if (OpName == "Resize") {
|
return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1);
|
if (opName == "ReverseSequence")
|
||||||
}else if (OpName == "ReverseSequence") {
|
return buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1);
|
if (opName == "RoiAlign")
|
||||||
}else if (OpName == "RoiAlign") {
|
return buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1);
|
if (opName == "Round")
|
||||||
}else if (OpName == "Round") {
|
return buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1);
|
if (opName == "Scan")
|
||||||
}else if (OpName == "Scan") {
|
return buildOperation<mlir::ONNXScanOp>(node);
|
||||||
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1);
|
if (opName == "Scatter")
|
||||||
}else if (OpName == "Scatter") {
|
return buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1);
|
if (opName == "ScatterElements")
|
||||||
}else if (OpName == "ScatterElements") {
|
return buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1);
|
if (opName == "ScatterND")
|
||||||
}else if (OpName == "ScatterND") {
|
return buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1);
|
if (opName == "Selu")
|
||||||
}else if (OpName == "Selu") {
|
return buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1);
|
if (opName == "SequenceAt")
|
||||||
}else if (OpName == "SequenceAt") {
|
return buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1);
|
if (opName == "SequenceConstruct")
|
||||||
}else if (OpName == "SequenceConstruct") {
|
return buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, true, false);
|
if (opName == "SequenceEmpty")
|
||||||
}else if (OpName == "SequenceEmpty") {
|
return buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1);
|
if (opName == "SequenceErase")
|
||||||
}else if (OpName == "SequenceErase") {
|
return buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1);
|
if (opName == "SequenceInsert")
|
||||||
}else if (OpName == "SequenceInsert") {
|
return buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1);
|
if (opName == "SequenceLength")
|
||||||
}else if (OpName == "SequenceLength") {
|
return buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1);
|
if (opName == "Shape")
|
||||||
}else if (OpName == "Shape") {
|
return buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1);
|
if (opName == "Shrink")
|
||||||
}else if (OpName == "Shrink") {
|
return buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1);
|
if (opName == "Sigmoid")
|
||||||
}else if (OpName == "Sigmoid") {
|
return buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1);
|
if (opName == "Sign")
|
||||||
}else if (OpName == "Sign") {
|
return buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1);
|
if (opName == "Sin")
|
||||||
}else if (OpName == "Sin") {
|
return buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1);
|
if (opName == "Sinh")
|
||||||
}else if (OpName == "Sinh") {
|
return buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1);
|
if (opName == "Size")
|
||||||
}else if (OpName == "Size") {
|
return buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1);
|
if (opName == "Slice")
|
||||||
}else if (OpName == "Slice") {
|
return buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1);
|
if (opName == "Softmax")
|
||||||
}else if (OpName == "Softmax") {
|
return buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1);
|
if (opName == "Softplus")
|
||||||
}else if (OpName == "Softplus") {
|
return buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1);
|
if (opName == "Softsign")
|
||||||
}else if (OpName == "Softsign") {
|
return buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1);
|
if (opName == "SpaceToDepth")
|
||||||
}else if (OpName == "SpaceToDepth") {
|
return buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1);
|
if (opName == "Split")
|
||||||
}else if (OpName == "Split") {
|
return buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
|
||||||
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1);
|
if (opName == "SplitToSequence")
|
||||||
}else if (OpName == "SplitToSequence") {
|
return buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1);
|
if (opName == "Sqrt")
|
||||||
}else if (OpName == "Sqrt") {
|
return buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1);
|
if (opName == "Squeeze")
|
||||||
}else if (OpName == "Squeeze") {
|
return buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1);
|
if (opName == "StringNormalizer")
|
||||||
}else if (OpName == "StringNormalizer") {
|
return buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1);
|
if (opName == "Sub")
|
||||||
}else if (OpName == "Sub") {
|
return buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1);
|
if (opName == "Sum")
|
||||||
}else if (OpName == "Sum") {
|
return buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, true, false);
|
if (opName == "Tan")
|
||||||
}else if (OpName == "Tan") {
|
return buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1);
|
if (opName == "Tanh")
|
||||||
}else if (OpName == "Tanh") {
|
return buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1);
|
if (opName == "TfIdfVectorizer")
|
||||||
}else if (OpName == "TfIdfVectorizer") {
|
return buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1);
|
if (opName == "ThresholdedRelu")
|
||||||
}else if (OpName == "ThresholdedRelu") {
|
return buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1);
|
if (opName == "Tile")
|
||||||
}else if (OpName == "Tile") {
|
return buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1);
|
if (opName == "TopK")
|
||||||
}else if (OpName == "TopK") {
|
return buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
|
||||||
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2);
|
if (opName == "Transpose")
|
||||||
}else if (OpName == "Transpose") {
|
return buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1);
|
if (opName == "Unique")
|
||||||
}else if (OpName == "Unique") {
|
return buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
|
||||||
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4);
|
if (opName == "Unsqueeze")
|
||||||
}else if (OpName == "Unsqueeze") {
|
return buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1);
|
if (opName == "Upsample")
|
||||||
}else if (OpName == "Upsample") {
|
return buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1);
|
if (opName == "Where")
|
||||||
}else if (OpName == "Where") {
|
return buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1);
|
if (opName == "Xor")
|
||||||
}else if (OpName == "Xor") {
|
return buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1);
|
|
||||||
}
|
|
||||||
|
|
|
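The dispatch above replaces the per-arity ImportNodeOneOut/ImportNodeMultipleOuts helpers with one generic buildOperation template whose expected operand and result counts are spelled out at each call site, -1 marking variadic ops such as Concat, Max, Mean, Min, and Sum. A minimal standalone sketch of that contract follows; the Node struct and the function body are illustrative stand-ins, not the actual onnx-mlir importer types, which work on onnx::NodeProto and an MLIR OpBuilder.

#include <cassert>
#include <string>
#include <vector>

// Illustrative stand-in for the importer's node type (hypothetical).
struct Node {
  std::string opType;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// One generic builder replaces the per-arity helpers. A count of -1 means
// "variadic", so that arity is deliberately left unchecked.
template <typename OpT>
void buildOperation(
    const Node &node, int expectedNumOperands, int expectedNumResults) {
  if (expectedNumOperands != -1)
    assert(static_cast<int>(node.inputs.size()) == expectedNumOperands &&
           "operand count mismatch");
  if (expectedNumResults != -1)
    assert(static_cast<int>(node.outputs.size()) == expectedNumResults &&
           "result count mismatch");
  // ... construct an OpT from node.inputs and bind node.outputs ...
}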
@@ -17,20 +17,24 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
-    auto has_bias = (operands.size() == 3);
+    // The first predicate is unnecessary when we remove ONNXGemmNoBiasOp.
+    bool hasBias = (operands.size() == 3) &&
+                   (!op->getOperand(2).getType().isa<NoneType>());

     Value A, B, C;
     A = operands[0];
     B = operands[1];
-    if (has_bias)
+    if (hasBias)
       C = operands[2];

     auto memRefType = convertToMemRefType(*op->result_type_begin());

-    auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
-        llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
-    auto betaAttr = FloatAttr::get(memRefType.getElementType(),
-        llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
+    auto alphaAttr =
+        FloatAttr::get(memRefType.getElementType(),
+            llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
+    auto betaAttr =
+        FloatAttr::get(memRefType.getElementType(),
+            llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
     auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
     auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
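The new hasBias predicate is a two-step test: an omitted ONNX input can surface either as a missing operand or as a value whose type is MLIR's NoneType, so arity alone is not enough. A self-contained illustration of the same logic with a toy value type (the names here are stand-ins, not the MLIR API):

#include <vector>

// Toy stand-in for an SSA value whose type may be "none".
struct ToyValue {
  bool isNoneType;
};

// Mirrors the hasBias test above: the bias participates only when a third
// operand exists and its type is not none.
bool gemmHasBias(const std::vector<ToyValue> &operands) {
  return operands.size() == 3 && !operands[2].isNoneType;
}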
@@ -68,8 +72,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     // Define loops.
     std::vector<Value> originalLoops;
     std::vector<Value> optimizedLoops;
-    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
-        optimizedLoops, numLoops);
+    Block *optimizationBlock =
+        defineLoops(rewriter, loc, originalLoops, optimizedLoops, numLoops);

     // We have two Krnl loops:
     // - Outer loop iterates over the output matrix dimensions, and
@@ -83,8 +87,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
       outerLoops.push_back(originalLoops[i]);
       optimizedOuterLoops.push_back(optimizedLoops[i]);
     }
-    KrnlIterateOperandPack outerPack(rewriter, outerLoops,
-        optimizedOuterLoops);
+    KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
     // Induction variables for the outer loops
     for (int i = 0; i < 2; ++i)
       addDimensionToPack(rewriter, loc, outerPack, alloc, i);
@@ -106,20 +109,19 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     int64_t K_B_Idx = (isTransB) ? 1 : 0;
     reductionPack.pushConstantBound(0);
     if (ATy.getShape()[K_A_Idx] != -1)
       reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
+    else if (BTy.getShape()[K_B_Idx] != -1)
+      reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
     else
-      if (BTy.getShape()[K_B_Idx] != -1)
-        reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
-      else
-        reductionPack.pushOperandBound(
-            rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
+      reductionPack.pushOperandBound(
+          rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());

     // Get run-time dimension information for unknown dimensions used for
     // broadcasting.
     // GemmOp supports unidirectional broadcasting from C to A*B.
     // Hence, it must be enough to get broadcasting information for C only.
     std::map<int, Value> broadcastedDimInfo;
-    if (has_bias) {
+    if (hasBias) {
       auto shape = C.getType().cast<MemRefType>().getShape();
       for (int i = 0; i < shape.size(); ++i) {
         if (shape[i] < 0) {
@@ -162,7 +164,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
     auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
     auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
-    if (has_bias) {
+    if (hasBias) {
       auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
           broadcastedDimInfo);
       auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
@@ -210,8 +212,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   }
 };

-void populateLoweringONNXGemmOpPattern(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
+    MLIRContext *ctx) {
   patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
   patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
 }
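For reference, per the ONNX operator specification, the computation this pattern lowers is

Y = \alpha \, A' B' + \beta \, C, \qquad
A' = \begin{cases} A^{T} & \text{if transA} \neq 0 \\ A & \text{otherwise} \end{cases}, \qquad
B' = \begin{cases} B^{T} & \text{if transB} \neq 0 \\ B & \text{otherwise} \end{cases}

where C broadcasts unidirectionally onto the M x N product, which is why the lowering gathers broadcast information for C only.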
@@ -120,25 +120,19 @@ void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
 // Tanh
 /// Infer the output shape of the ONNXTanhOp. This method is required by the
 /// shape inference interface.
-void ONNXTanhOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Sinh
 /// Infer the output shape of the ONNXSinhOp. This method is required by the
 /// shape inference interface.
-void ONNXSinhOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Cosh
 /// Infer the output shape of the ONNXCoshOp. This method is required by the
 /// shape inference interface.
-void ONNXCoshOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Cos
@@ -178,9 +172,7 @@ void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
 // Relu
 /// Infer the output shape of the ONNXReluOp. This method is required by the
 /// shape inference interface.
-void ONNXReluOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // LeakyRelu
@@ -194,9 +186,7 @@ void ONNXLeakyReluOp::inferShapes() {
 // Selu
 /// Infer the output shape of the ONNXSeluOp. This method is required by
 /// the shape inference interface.
-void ONNXSeluOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Reciprocal
@@ -234,17 +224,13 @@ void ONNXSoftsignOp::inferShapes() {
 // Sqrt
 /// Infer the output shape of the ONNXSqrtOp. This method is required by
 /// the shape inference interface.
-void ONNXSqrtOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Sign
 /// Infer the output shape of the ONNXSignOp. This method is required by
 /// the shape inference interface.
-void ONNXSignOp::inferShapes() {
-  getResult().setType(getOperand().getType());
-}
+void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }

 //===----------------------------------------------------------------------===//
 // Add
@@ -423,8 +409,7 @@ void ONNXMatMulOp::inferShapes() {
     // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
     // need to be removed after the multiplication but cannot be removed if all
     // sizes are 1.
-    if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
-        lhsShape[0] != rhsShape[0])
+    if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
       emitError("Attempt to multiply incompatible matrices.");
     dims.emplace_back(1);
   } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
@@ -541,14 +526,14 @@ void ONNXMatMulOp::inferShapes() {
 // Gemm

 void ONNXGemmOp::inferShapes() {
+  bool hasBias = !getOperand(2).getType().isa<NoneType>();
   // Cannot infer shape if no shape exists.
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>() ||
-      !getOperand(2).getType().isa<RankedTensorType>())
+      (hasBias && !getOperand(2).getType().isa<RankedTensorType>()))
     return;
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
-  auto biasTy = getOperand(2).getType().cast<RankedTensorType>();

   int64_t M, N, K_A, K_B;
   M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
@@ -560,15 +545,18 @@ void ONNXGemmOp::inferShapes() {
     emitError("Tensor shapes mismatched.");
   }

-  // Check whether bias is unidirectional broadcasting or not.
-  auto shape = biasTy.getShape();
-  int rank = shape.size();
-  if ((rank > 2) ||
-      (rank >= 1 && shape[rank - 1] != -1 && N != -1 && N != shape[rank - 1] &&
-       shape[rank - 1] != 1) ||
-      (rank == 2 && shape[rank - 2] != -1 && M != -1 && M != shape[rank - 2] &&
-       shape[rank - 2] != 1)) {
-    emitError("Bias shape mismatched.");
+  if (hasBias) {
+    // Check whether bias is unidirectional broadcasting or not.
+    auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
+    auto shape = biasTy.getShape();
+    int rank = shape.size();
+    if ((rank > 2) ||
+        (rank >= 1 && shape[rank - 1] != -1 && N != -1 &&
+            N != shape[rank - 1] && shape[rank - 1] != 1) ||
+        (rank == 2 && shape[rank - 2] != -1 && M != -1 &&
+            M != shape[rank - 2] && shape[rank - 2] != 1)) {
+      emitError("Bias shape mismatched.");
+    }
   }

   SmallVector<int64_t, 2> dims;
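Spelled out, the unidirectional-broadcast test above accepts a bias of rank at most 2 whose trailing dimensions are 1, unknown (-1), or equal to the corresponding result dimension. A standalone restatement of the same predicate (the helper name is hypothetical):

#include <cstdint>
#include <vector>

// Returns true when `bias` can broadcast unidirectionally onto an MxN
// result; -1 entries mean unknown/dynamic and are never rejected.
bool biasBroadcastsTo(const std::vector<int64_t> &bias, int64_t M, int64_t N) {
  int rank = static_cast<int>(bias.size());
  if (rank > 2)
    return false;
  if (rank >= 1) {
    int64_t n = bias[rank - 1]; // trailing dim must match N, be 1, or be -1
    if (n != -1 && N != -1 && n != N && n != 1)
      return false;
  }
  if (rank == 2) {
    int64_t m = bias[rank - 2]; // leading dim must match M, be 1, or be -1
    if (m != -1 && M != -1 && m != M && m != 1)
      return false;
  }
  return true; // e.g. {}, {1}, {4}, {1, 4}, {5, 1}, {5, 4} for M=5, N=4
}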
@@ -713,7 +701,6 @@ void ONNXTransposeOp::inferShapes() {
   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
 }

-
 //===----------------------------------------------------------------------===//

 // ReduceMax
@@ -801,7 +788,8 @@ void ONNXConvNoBiasOp::inferShapes() {
   // Required attribute auto_pad defaults to NOTSET.
   auto autoPad = auto_pad();
   // Group is a required attribute and should have default value of 1.
-  int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
+  int64_t group =
+      ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
   // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
   if (dataShape[1] != (weightShape[1] * group))
     emitError("Channel dimension mismatch.");
@@ -859,8 +847,10 @@ void ONNXConvNoBiasOp::inferShapes() {
     if (dilations.getValue().size() != nDims)
       emitError("dilations length incompatible with spatial dimensions.");
     for (int i = 0; i < nDims; ++i)
-      kernelDims[i] = (kernelDims[i] + 1) *
-          (dilations.getValue()[i]).cast<IntegerAttr>().getInt() - 1;
+      kernelDims[i] =
+          (kernelDims[i] + 1) *
+              (dilations.getValue()[i]).cast<IntegerAttr>().getInt() -
+          1;
   }

   // Subtract kernel dimensions from input data dimensions.
@@ -906,8 +896,7 @@ void ONNXConvNoBiasOp::inferShapes() {
     if (strides.getValue().size() != nDims)
       emitError("strides length incompatible with spatial dimensions.");
     for (int i = 0; i < nDims; ++i) {
-      int64_t stride =
-          strides.getValue()[i].cast<IntegerAttr>().getInt();
+      int64_t stride = strides.getValue()[i].cast<IntegerAttr>().getInt();
       outSpatialDims[i] = floor(outSpatialDims[i] / stride);
     }
   }
@@ -937,7 +926,8 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
   // get kernel sizes from kernel_shape attribute
   auto kernelShape = kernel_shape();
   if (!kernelShape)
-    emitError("kernel_shape is a mandatory attribute for which there is no default.");
+    emitError(
+        "kernel_shape is a mandatory attribute for which there is no default.");
   auto kernelShapeArray = kernelShape.getValue();
   auto kernelRank = kernelShape.size();
   if (kernelRank > xRank)
@@ -951,9 +941,10 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
   SmallVector<int64_t, 4> actualDilations;
   auto dilationsOpt = dilations();
   if (dilationsOpt.hasValue()) {
-    auto dilationsArray = dilationsOpt.getValue().getValue(); // opt -> attr -> array
+    auto dilationsArray =
+        dilationsOpt.getValue().getValue(); // opt -> attr -> array
     if (dilationsArray.size() != kernelRank)
      emitError("dilation rank is not the same as the spatial rank.");
     // fill in the actual values
     for (int i = 0; i < kernelRank; ++i) {
       int64_t d = (dilationsArray[i]).cast<IntegerAttr>().getInt();
@@ -962,7 +953,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
       actualDilations.emplace_back(d);
     }
   } else {
-    for(int i=0; i < kernelRank; ++i) {
+    for (int i = 0; i < kernelRank; ++i) {
       actualDilations.emplace_back(1);
     }
   }
@@ -975,7 +966,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
   if (stridesOpt.hasValue()) {
     auto stridesArray = stridesOpt.getValue().getValue();
     if (stridesArray.size() != kernelRank)
       emitError("strides rank is not the same as the spatial rank.");
     // fill in the actual values
     for (int i = 0; i < kernelRank; ++i) {
       int64_t s = (stridesArray[i]).cast<IntegerAttr>().getInt();
@@ -984,7 +975,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
       actualStrides.emplace_back(s);
     }
   } else {
-    for(int i=0; i < kernelRank; ++i) {
+    for (int i = 0; i < kernelRank; ++i) {
       actualStrides.emplace_back(1);
     }
   }
@@ -1002,7 +993,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
     if (padsArray.size() != 2 * kernelRank)
       emitError("pads rank is not twice the spatial rank.");
     // fill in the actual values
-    for (int i = 0; i < 2*kernelRank; ++i) {
+    for (int i = 0; i < 2 * kernelRank; ++i) {
       int64_t p = (padsArray[i]).cast<IntegerAttr>().getInt();
       if (p < 0)
         emitError("pads value must be nonnegative.");
@@ -1016,18 +1007,20 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
       defaultPads = true;
   } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
     // init pad with zero
-    for(int i=0; i<2*kernelRank; ++i) {
+    for (int i = 0; i < 2 * kernelRank; ++i) {
       actualPads.emplace_back(0);
     }
-    for(int i=0; i<kernelRank; ++i) {
+    for (int i = 0; i < kernelRank; ++i) {
       auto inputSpatialShape = xShape[kernelOffset + i];
-      auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt();
+      auto kernelSpatialShape =
+          (kernelShapeArray[i]).cast<IntegerAttr>().getInt();
       auto dilations = actualDilations[i];
       auto strideSpatialShape = actualStrides[i];
-      int64_t outputSpatialShape = ceil((1.0 * inputSpatialShape) /
-          (1.0 * strideSpatialShape));
+      int64_t outputSpatialShape =
+          ceil((1.0 * inputSpatialShape) / (1.0 * strideSpatialShape));
       auto sumOfPad = (outputSpatialShape - 1) * strideSpatialShape +
-          ((kernelSpatialShape - 1) * dilations + 1) - inputSpatialShape;
+                      ((kernelSpatialShape - 1) * dilations + 1) -
+                      inputSpatialShape;
       actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2;
       if (sumOfPad % 2 != 0) {
         if (autoPad == "SAME_UPPER") {
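Working the SAME-padding arithmetic above through a small case, with input 5, kernel 3, stride 2, dilation 1:

\text{out} = \lceil 5/2 \rceil = 3, \qquad
\text{sumOfPad} = (3-1)\cdot 2 + ((3-1)\cdot 1 + 1) - 5 = 2
\;\Rightarrow\; \text{pads} = (1, 1).

With input 6 instead, sumOfPad = (3-1)·2 + 3 - 6 = 1 is odd, so the halves split unevenly: SAME_UPPER puts the extra pad at the end (0, 1), SAME_LOWER at the beginning (1, 0), which is what the sumOfPad % 2 branch decides.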
@@ -1042,7 +1035,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
   }
   // handle case where default pad values must be used
   if (defaultPads) {
-    for(int i=0; i<2*kernelRank; ++i) {
+    for (int i = 0; i < 2 * kernelRank; ++i) {
       actualPads.emplace_back(0);
     }
   }
@@ -1050,16 +1043,18 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
   // initialize output shape
   SmallVector<int64_t, 4> yShape(xShape.begin(), xShape.end());
   // for all kernel dimensions
-  for(int i=0; i<kernelRank; ++i) {
+  for (int i = 0; i < kernelRank; ++i) {
     auto inputSpatialShape = xShape[kernelOffset + i];
-    auto padShape = actualPads[i] + actualPads[kernelRank+i];
-    auto kernelSpatialShape = (kernelShapeArray[i]).cast<IntegerAttr>().getInt();
+    auto padShape = actualPads[i] + actualPads[kernelRank + i];
+    auto kernelSpatialShape =
+        (kernelShapeArray[i]).cast<IntegerAttr>().getInt();
     auto dilations = actualDilations[i];
     auto strideSpatialShape = actualStrides[i];
-    ///output_spatial_shape[i] = ceil( (input_spatial_shape[i] + pad_shape[i] -
-    // ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1)
+    /// output_spatial_shape[i] = ceil( (input_spatial_shape[i] + pad_shape[i] -
+    // ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) /
+    // strides_spatial_shape[i] + 1)
     double numerator = inputSpatialShape + padShape -
                        ((kernelSpatialShape - 1) * dilations + 1);
     double denominator = strideSpatialShape;
     int64_t res;
     if (ceilMode) {
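The loop above implements the ONNX pooling output-size formula quoted in the comment, rounding up when ceil_mode is set and down otherwise. A hedged sketch of the same computation for one dimension, with a worked example (the helper name is illustrative, not from this file; pulling the `+ 1` outside the rounding is equivalent since the added term is an integer):

```cpp
#include <cmath>
#include <cstdint>

// out = round((input + padBegin + padEnd - ((kernel - 1) * dilation + 1))
//             / stride + 1), rounding per ceil_mode.
int64_t pooledDim(int64_t input, int64_t padShape, int64_t kernel,
                  int64_t dilation, int64_t stride, bool ceilMode) {
  double numerator = input + padShape - ((kernel - 1) * dilation + 1);
  double denominator = stride;
  return ceilMode
             ? static_cast<int64_t>(std::ceil(numerator / denominator)) + 1
             : static_cast<int64_t>(std::floor(numerator / denominator)) + 1;
}
// e.g. input=32, pads 1+1, kernel=3, dilation=1, stride=2:
//   numerator = 32 + 2 - 3 = 31, so out = floor(31/2) + 1 = 16
//   (ceil_mode would give ceil(31/2) + 1 = 17)
```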
(File diff suppressed because it is too large.)
@@ -127,6 +127,10 @@ int main(int argc, char *argv[]) {
 
   if (emissionTarget >= EmitMLIR) {
     pm.addPass(mlir::createLowerToKrnlPass());
+    // An additional pass of canonicalization is helpful because lowering
+    // from ONNX dialect to Standard dialect exposes additional canonicalization
+    // opportunities.
+    pm.addPass(mlir::createCanonicalizerPass());
     pm.addPass(mlir::createLowerKrnlPass());
   }
 
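For context, a hedged sketch of how a pipeline with this extra canonicalization step is typically assembled. The two Krnl passes are the ones declared in this repo's diff; createCanonicalizerPass is stock MLIR; the surrounding boilerplate (function signature, include paths, which vary across MLIR versions) is an assumption and may differ from the project's actual driver:

```cpp
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Assumes createLowerToKrnlPass/createLowerKrnlPass declarations are in scope.
mlir::LogicalResult runLowering(mlir::ModuleOp module,
                                mlir::MLIRContext *context) {
  mlir::PassManager pm(context);
  pm.addPass(mlir::createLowerToKrnlPass());
  // Canonicalize again here: lowering from the ONNX dialect to the Standard
  // dialect exposes opportunities the earlier canonicalization could not see.
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createLowerKrnlPass());
  return pm.run(module);
}
```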
@@ -28,6 +28,11 @@ void ONNXAddOp::getCanonicalizationPatterns(
     OwningRewritePatternList& results, MLIRContext* context) {
   results.insert<MulAddToGemmOptPattern>(context);
 }
 
+void ONNXGemmOp::getCanonicalizationPatterns(
+    OwningRewritePatternList& results, MLIRContext* context) {
+  results.insert<FuseGemmFollowedByAddition>(context);
+}
+
 /// on the ONNXIdentityOp.
 void ONNXIdentityOp::getCanonicalizationPatterns(
     OwningRewritePatternList& results, MLIRContext* context) {
@@ -26,6 +26,7 @@ include "dialect/onnx/onnx.td"
 
 def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
 class HasRankOf<int rank> : Constraint<CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>>;
+def HasNoneType : Constraint<CPred<"$0.getType().isa<NoneType>()">>;
 
 //===----------------------------------------------------------------------===//
 // Pattern-Match and Rewrite
@@ -41,6 +42,11 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
   (ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
   [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2)]>;
 
+// onnx.add(onnx.Gemm(%X, %Y, None), %Z) = onnx.Gemm(%X, %Y, %Z)
+def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none, $alpha, $beta, $transA, $transB), $bias),
+  (ONNXGemmOp $m1, $m2, $bias, $alpha, $beta, $transA, $transB),
+  [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2), (HasNoneType $none)]>;
+
 // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
 def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
   (replaceWithValue $arg)>;
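The rewrite is justified by the Gemm semantics Y = alpha * A * B + beta * C: when C is None the beta term drops out, so adding a bias afterwards is the same as passing that bias as C with the default beta = 1. A scalar sanity check of that algebra (illustrative only, not repo code):

```cpp
#include <cassert>

// Scalar model of FuseGemmFollowedByAddition with the default
// alpha = beta = 1 (the pattern carries the attributes over unchanged).
int main() {
  double alpha = 1.0, beta = 1.0;
  double a = 3.0, b = 4.0, z = 5.0;
  double unfused = (alpha * a * b) + z;    // onnx.Add(onnx.Gemm(a, b, none), z)
  double fused = alpha * a * b + beta * z; // onnx.Gemm(a, b, z)
  assert(unfused == fused);
  return 0;
}
```

Note that because $beta is carried over unchanged, a non-default beta would scale the fused bias; the test below exercises the default attributes, where the rewrite is exact.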
@@ -101,3 +101,14 @@ func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>
 // CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
 // CHECK-NEXT: return %1 : tensor<*xf32>
 }
+
+//CHECK-LABEL: @test_gemm_add_fusion(%{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128xf32>) -> tensor<*xf32> {
+func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128xf32>) -> tensor<*xf32> {
+  %cst = constant unit
+  %0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128xf32>, tensor<128x128xf32>, none) -> tensor<*xf32>
+  %1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+
+  // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32>
+  // return [[GEMM]] : tensor<*xf32>
+}
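A FileCheck test like this is normally driven by the project's opt-style tool with canonicalization enabled, piping the output into FileCheck against the same file, along the lines of the hypothetical invocation `onnf-opt --canonicalize onnx_canonicalization.mlir | FileCheck onnx_canonicalization.mlir`; the binary and test file names here are assumptions, not taken from this diff.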