diff --git a/docs/Dialects/mlonnx.md b/docs/Dialects/mlonnx.md index d7ab1e0..485fc42 100644 --- a/docs/Dialects/mlonnx.md +++ b/docs/Dialects/mlonnx.md @@ -35,13 +35,13 @@ ONNX Binarizer operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values ### `mlonnx.CastMap` (MLONNXCastMapOp) @@ -160,7 +160,7 @@ ONNX FeatureVectorizer operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values #### Results: @@ -194,13 +194,13 @@ ONNX Imputer operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values ### `mlonnx.LabelEncoder` (MLONNXLabelEncoderOp) @@ -271,7 +271,7 @@ ONNX LinearClassifier operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values #### Results: @@ -304,7 +304,7 @@ ONNX LinearRegressor operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values #### Results: @@ -337,7 +337,7 @@ ONNX Normalizer operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values #### Results: @@ -404,7 +404,7 @@ ONNX SVMClassifier operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values #### Results: @@ -436,7 +436,7 @@ ONNX SVMRegressor operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values #### Results: @@ -461,7 +461,7 @@ ONNX Scaler operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values #### Results: @@ -509,7 +509,7 @@ ONNX TreeEnsembleClassifier operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values #### Results: @@ -559,7 +559,7 @@ ONNX TreeEnsembleRegressor operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values #### Results: diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md index c742c69..a447655 100644 --- a/docs/Dialects/onnx.md +++ b/docs/Dialects/onnx.md @@ -11,13 +11,13 @@ ONNX Abs operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.Acos` (ONNXAcosOp) @@ -29,13 +29,13 @@ ONNX Acos operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Acosh` (ONNXAcoshOp) @@ -47,13 +47,13 @@ ONNX Acosh operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Add` (ONNXAddOp) @@ -67,14 +67,14 @@ ONNX Add operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`A` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`C` | memref of any type values or tensor of any type values +`C` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.And` (ONNXAndOp) @@ -89,14 +89,14 @@ ONNX And operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`A` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values +`B` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`C` | memref of any type values or tensor of any type values +`C` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values ### `onnx.ArgMax` (ONNXArgMaxOp) @@ -118,7 +118,7 @@ ONNX ArgMax operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: @@ -146,7 +146,7 @@ ONNX ArgMin operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: @@ -164,13 +164,13 @@ ONNX Asin operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Asinh` (ONNXAsinhOp) @@ -182,13 +182,13 @@ ONNX Asinh operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Atan` (ONNXAtanOp) @@ -200,13 +200,13 @@ ONNX Atan operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Atanh` (ONNXAtanhOp) @@ -218,13 +218,13 @@ ONNX Atanh operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.AveragePool` (ONNXAveragePoolOp) @@ -275,13 +275,13 @@ ONNX AveragePool operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.BatchNormalization` (ONNXBatchNormalizationOp) @@ -309,21 +309,21 @@ ONNX BatchNormalization operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values -`scale` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values -`mean` | memref of any type values or tensor of any type values -`var` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`scale` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`mean` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`var` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values -`out_mean` | memref of any type values or tensor of any type values or none type -`out_var` | memref of any type values or tensor of any type values or none type -`saved_mean` | memref of any type values or tensor of any type values or none type -`saved_var` | memref of any type values or tensor of any type values or none type +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`out_mean` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`out_var` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`saved_mean` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`saved_var` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type ### `onnx.BatchNormalizationTestMode` (ONNXBatchNormalizationTestModeOp) @@ -390,14 +390,14 @@ ONNX BitShift operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values -`Y` | memref of any type values or tensor of any type values +`X` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values +`Y` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`Z` | memref of any type values or tensor of any type values +`Z` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values ### `onnx.Cast` (ONNXCastOp) @@ -453,13 +453,13 @@ ONNX Ceil operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Clip` (ONNXClipOp) @@ -473,15 +473,15 @@ ONNX Clip operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values -`min` | memref of any type values or tensor of any type values or none type -`max` | memref of any type values or tensor of any type values or none type +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`min` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`max` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Compress` (ONNXCompressOp) @@ -503,7 +503,7 @@ ONNX Compress operation | Operand | Description | | :-----: | ----------- | `input` | memref of any type values or tensor of any type values -`condition` | memref of any type values or tensor of any type values +`condition` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values #### Results: @@ -579,13 +579,13 @@ ONNX ConstantOfShape operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 64-bit signless integer values or memref of 64-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values ### `onnx.Constant` (ONNXConstantOp) @@ -629,16 +629,16 @@ ONNX ConvInteger operation | Operand | Description | | :-----: | ----------- | -`x` | memref of any type values or tensor of any type values -`w` | memref of any type values or tensor of any type values -`x_zero_point` | memref of any type values or tensor of any type values or none type -`w_zero_point` | memref of any type values or tensor of any type values or none type +`x` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values +`w` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values +`x_zero_point` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values or none type +`w_zero_point` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values or none type #### Results: | Result | Description | | :----: | ----------- | -`y` | memref of any type values or tensor of any type values +`y` | tensor of 32-bit signless integer values or memref of 32-bit signless integer values ### `onnx.Conv` (ONNXConvOp) @@ -662,15 +662,15 @@ ONNX Conv operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values -`W` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values or none type +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`W` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.ConvTranspose` (ONNXConvTransposeOp) @@ -708,15 +708,15 @@ ONNX ConvTranspose operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values -`W` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values or none type +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`W` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Cos` (ONNXCosOp) @@ -728,13 +728,13 @@ ONNX Cos operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Cosh` (ONNXCoshOp) @@ -746,13 +746,13 @@ ONNX Cosh operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.CumSum` (ONNXCumSumOp) @@ -790,14 +790,14 @@ ONNX CumSum operation | Operand | Description | | :-----: | ----------- | -`x` | memref of any type values or tensor of any type values -`axis` | memref of any type values or tensor of any type values +`x` | tensor of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values +`axis` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`y` | memref of any type values or tensor of any type values +`y` | tensor of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values ### `onnx.DepthToSpace` (ONNXDepthToSpaceOp) @@ -863,9 +863,9 @@ ONNX DequantizeLinear operation | Operand | Description | | :-----: | ----------- | -`x` | memref of any type values or tensor of any type values +`x` | tensor of 8-bit signless integer or 32-bit signless integer values or memref of 8-bit signless integer or 32-bit signless integer values `x_scale` | memref of any type values or tensor of any type values -`x_zero_point` | memref of any type values or tensor of any type values or none type +`x_zero_point` | tensor of 8-bit signless integer or 32-bit signless integer values or memref of 8-bit signless integer or 32-bit signless integer values or none type #### Results: @@ -887,13 +887,13 @@ ONNX Det operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Div` (ONNXDivOp) @@ -907,14 +907,14 @@ ONNX Div operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`A` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`C` | memref of any type values or tensor of any type values +`C` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.Dropout` (ONNXDropoutOp) @@ -937,14 +937,14 @@ ONNX Dropout operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values -`mask` | memref of any type values or tensor of any type values or none type +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`mask` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values or none type ### `onnx.DynamicQuantizeLinear` (ONNXDynamicQuantizeLinearOp) @@ -977,15 +977,15 @@ ONNX DynamicQuantizeLinear operation | Operand | Description | | :-----: | ----------- | -`x` | memref of any type values or tensor of any type values +`x` | tensor of 32-bit float values or memref of 32-bit float values #### Results: | Result | Description | | :----: | ----------- | -`y` | memref of any type values or tensor of any type values +`y` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values `y_scale` | memref of any type values or tensor of any type values -`y_zero_point` | memref of any type values or tensor of any type values +`y_zero_point` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values ### `onnx.Elu` (ONNXEluOp) @@ -1006,13 +1006,13 @@ ONNX Elu operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.EntryPoint` (ONNXEntryPointOp) @@ -1033,14 +1033,14 @@ ONNX Equal operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`A` | tensor of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`C` | memref of any type values or tensor of any type values +`C` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values ### `onnx.Erf` (ONNXErfOp) @@ -1052,13 +1052,13 @@ ONNX Erf operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.Exp` (ONNXExpOp) @@ -1070,13 +1070,13 @@ ONNX Exp operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Expand` (ONNXExpandOp) @@ -1127,13 +1127,13 @@ ONNX EyeLike operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values ### `onnx.Flatten` (ONNXFlattenOp) @@ -1173,13 +1173,13 @@ ONNX Floor operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.GRU` (ONNXGRUOp) @@ -1275,19 +1275,19 @@ ONNX GRU operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values -`W` | memref of any type values or tensor of any type values -`R` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values or none type -`sequence_lens` | memref of any type values or tensor of any type values or none type -`initial_h` | memref of any type values or tensor of any type values or none type +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`W` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`R` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`sequence_lens` | tensor of 32-bit signless integer values or memref of 32-bit signless integer values or none type +`initial_h` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values or none type -`Y_h` | memref of any type values or tensor of any type values or none type +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`Y_h` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type ### `onnx.GatherElements` (ONNXGatherElementsOp) @@ -1360,7 +1360,7 @@ ONNX GatherElements operation | Operand | Description | | :-----: | ----------- | `data` | memref of any type values or tensor of any type values -`indices` | memref of any type values or tensor of any type values +`indices` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values #### Results: @@ -1524,7 +1524,7 @@ ONNX Gather operation | Operand | Description | | :-----: | ----------- | `data` | memref of any type values or tensor of any type values -`indices` | memref of any type values or tensor of any type values +`indices` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values #### Results: @@ -1563,15 +1563,15 @@ ONNX Gemm operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values -`C` | memref of any type values or tensor of any type values or none type +`A` | tensor of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values +`B` | tensor of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values +`C` | tensor of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values or none type #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values ### `onnx.GlobalAveragePool` (ONNXGlobalAveragePoolOp) @@ -1585,13 +1585,13 @@ ONNX GlobalAveragePool operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.GlobalLpPool` (ONNXGlobalLpPoolOp) @@ -1611,13 +1611,13 @@ ONNX GlobalLpPool operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.GlobalMaxPool` (ONNXGlobalMaxPoolOp) @@ -1631,13 +1631,13 @@ ONNX GlobalMaxPool operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Greater` (ONNXGreaterOp) @@ -1652,14 +1652,14 @@ ONNX Greater operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`A` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`C` | memref of any type values or tensor of any type values +`C` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values ### `onnx.HardSigmoid` (ONNXHardSigmoidOp) @@ -1680,13 +1680,13 @@ ONNX HardSigmoid operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Hardmax` (ONNXHardmaxOp) @@ -1717,13 +1717,13 @@ ONNX Hardmax operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Identity` (ONNXIdentityOp) @@ -1760,7 +1760,7 @@ ONNX If operation | Operand | Description | | :-----: | ----------- | -`cond` | memref of any type values or tensor of any type values +`cond` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values #### Results: @@ -1789,15 +1789,15 @@ ONNX InstanceNormalization operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values -`scale` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`scale` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.IsInf` (ONNXIsInfOp) @@ -1816,13 +1816,13 @@ ONNX IsInf operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit float or 64-bit float values or memref of 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values ### `onnx.IsNaN` (ONNXIsNaNOp) @@ -1834,13 +1834,13 @@ ONNX IsNaN operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values ### `onnx.LRN` (ONNXLRNOp) @@ -1870,13 +1870,13 @@ ONNX LRN operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.LSTM` (ONNXLSTMOp) @@ -1980,22 +1980,22 @@ ONNX LSTM operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values -`W` | memref of any type values or tensor of any type values -`R` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values or none type -`sequence_lens` | memref of any type values or tensor of any type values or none type -`initial_h` | memref of any type values or tensor of any type values or none type -`initial_c` | memref of any type values or tensor of any type values or none type -`P` | memref of any type values or tensor of any type values or none type +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`W` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`R` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`sequence_lens` | tensor of 32-bit signless integer values or memref of 32-bit signless integer values or none type +`initial_h` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`initial_c` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`P` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values or none type -`Y_h` | memref of any type values or tensor of any type values or none type -`Y_c` | memref of any type values or tensor of any type values or none type +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`Y_h` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`Y_c` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type ### `onnx.LeakyRelu` (ONNXLeakyReluOp) @@ -2015,13 +2015,13 @@ ONNX LeakyRelu operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Less` (ONNXLessOp) @@ -2036,14 +2036,14 @@ ONNX Less operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`A` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`C` | memref of any type values or tensor of any type values +`C` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values ### `onnx.Log` (ONNXLogOp) @@ -2055,13 +2055,13 @@ ONNX Log operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.LogSoftmax` (ONNXLogSoftmaxOp) @@ -2092,13 +2092,13 @@ ONNX LogSoftmax operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Loop` (ONNXLoopOp) @@ -2228,8 +2228,8 @@ ONNX Loop operation | Operand | Description | | :-----: | ----------- | -`M` | memref of any type values or tensor of any type values or none type -`cond` | memref of any type values or tensor of any type values or none type +`M` | tensor of 64-bit signless integer values or memref of 64-bit signless integer values or none type +`cond` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values or none type `v_initial` | memref of any type values or tensor of any type values #### Results: @@ -2255,13 +2255,13 @@ ONNX LpNormalization operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.LpPool` (ONNXLpPoolOp) @@ -2287,13 +2287,13 @@ ONNX LpPool operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.MatMulInteger` (ONNXMatMulIntegerOp) @@ -2306,16 +2306,16 @@ ONNX MatMulInteger operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values -`a_zero_point` | memref of any type values or tensor of any type values or none type -`b_zero_point` | memref of any type values or tensor of any type values or none type +`A` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values +`B` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values +`a_zero_point` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values or none type +`b_zero_point` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values or none type #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 32-bit signless integer values or memref of 32-bit signless integer values ### `onnx.MatMul` (ONNXMatMulOp) @@ -2327,14 +2327,14 @@ ONNX MatMul operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`A` | tensor of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values +`B` | tensor of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values ### `onnx.Max` (ONNXMaxOp) @@ -2348,13 +2348,13 @@ ONNX Max operation | Operand | Description | | :-----: | ----------- | -`data_0` | memref of any type values or tensor of any type values +`data_0` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`max` | memref of any type values or tensor of any type values +`max` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.MaxPool` (ONNXMaxPoolOp) @@ -2406,14 +2406,14 @@ ONNX MaxPool operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values -`Indices` | memref of any type values or tensor of any type values or none type +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`Indices` | tensor of 64-bit signless integer values or memref of 64-bit signless integer values or none type ### `onnx.MaxPoolSingleOut` (ONNXMaxPoolSingleOutOp) @@ -2465,14 +2465,14 @@ ONNX MaxRoiPool operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values -`rois` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`rois` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.MaxUnpool` (ONNXMaxUnpoolOp) @@ -2509,15 +2509,15 @@ ONNX MaxUnpool operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values -`I` | memref of any type values or tensor of any type values -`output_shape` | memref of any type values or tensor of any type values or none type +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`I` | tensor of 64-bit signless integer values or memref of 64-bit signless integer values +`output_shape` | tensor of 64-bit signless integer values or memref of 64-bit signless integer values or none type #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Mean` (ONNXMeanOp) @@ -2531,13 +2531,13 @@ ONNX Mean operation | Operand | Description | | :-----: | ----------- | -`data_0` | memref of any type values or tensor of any type values +`data_0` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`mean` | memref of any type values or tensor of any type values +`mean` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.MeanVarianceNormalization` (ONNXMeanVarianceNormalizationOp) @@ -2556,13 +2556,13 @@ ONNX MeanVarianceNormalization operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Min` (ONNXMinOp) @@ -2576,13 +2576,13 @@ ONNX Min operation | Operand | Description | | :-----: | ----------- | -`data_0` | memref of any type values or tensor of any type values +`data_0` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`min` | memref of any type values or tensor of any type values +`min` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Mod` (ONNXModOp) @@ -2612,14 +2612,14 @@ ONNX Mod operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`A` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`C` | memref of any type values or tensor of any type values +`C` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.Mul` (ONNXMulOp) @@ -2633,14 +2633,14 @@ ONNX Mul operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`A` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`C` | memref of any type values or tensor of any type values +`C` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.Multinomial` (ONNXMultinomialOp) @@ -2661,13 +2661,13 @@ ONNX Multinomial operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values ### `onnx.Neg` (ONNXNegOp) @@ -2681,13 +2681,13 @@ ONNX Neg operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 32-bit float or 32-bit signless integer or 8-bit signless integer or 16-bit signless integer or 64-bit signless integer or 16-bit float or 64-bit float values or memref of 32-bit float or 32-bit signless integer or 8-bit signless integer or 16-bit signless integer or 64-bit signless integer or 16-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 32-bit float or 32-bit signless integer or 8-bit signless integer or 16-bit signless integer or 64-bit signless integer or 16-bit float or 64-bit float values or memref of 32-bit float or 32-bit signless integer or 8-bit signless integer or 16-bit signless integer or 64-bit signless integer or 16-bit float or 64-bit float values ### `onnx.NonMaxSuppression` (ONNXNonMaxSuppressionOp) @@ -2754,13 +2754,13 @@ ONNX Not operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values ### `onnx.OneHot` (ONNXOneHotOp) @@ -2796,8 +2796,8 @@ ONNX OneHot operation | Operand | Description | | :-----: | ----------- | -`indices` | memref of any type values or tensor of any type values -`depth` | memref of any type values or tensor of any type values +`indices` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values +`depth` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values `values` | memref of any type values or tensor of any type values #### Results: @@ -2819,14 +2819,14 @@ ONNX Or operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`A` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values +`B` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`C` | memref of any type values or tensor of any type values +`C` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values ### `onnx.PRelu` (ONNXPReluOp) @@ -2841,14 +2841,14 @@ ONNX PRelu operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values -`slope` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values +`slope` | tensor of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values or memref of 16-bit float or 32-bit float or 64-bit float or 32-bit signless integer or 64-bit signless integer values ### `onnx.PadConstantPad` (ONNXPadConstantPadOp) @@ -3032,15 +3032,15 @@ ONNX Pad operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values `pads` | memref of any type values or tensor of any type values or none type -`constant_value` | memref of any type values or tensor of any type values or none type +`constant_value` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.Pow` (ONNXPowOp) @@ -3055,14 +3055,14 @@ ONNX Pow operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values -`Y` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Z` | memref of any type values or tensor of any type values +`Z` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.QLinearConv` (ONNXQLinearConvOp) @@ -3089,21 +3089,21 @@ ONNX QLinearConv operation | Operand | Description | | :-----: | ----------- | -`x` | memref of any type values or tensor of any type values +`x` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values `x_scale` | memref of any type values or tensor of any type values -`x_zero_point` | memref of any type values or tensor of any type values -`w` | memref of any type values or tensor of any type values +`x_zero_point` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values +`w` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values `w_scale` | memref of any type values or tensor of any type values -`w_zero_point` | memref of any type values or tensor of any type values +`w_zero_point` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values `y_scale` | memref of any type values or tensor of any type values -`y_zero_point` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values or none type +`y_zero_point` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values +`B` | tensor of 32-bit signless integer values or memref of 32-bit signless integer values or none type #### Results: | Result | Description | | :----: | ----------- | -`y` | memref of any type values or tensor of any type values +`y` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values ### `onnx.QLinearMatMul` (ONNXQLinearMatMulOp) @@ -3122,20 +3122,20 @@ ONNX QLinearMatMul operation | Operand | Description | | :-----: | ----------- | -`a` | memref of any type values or tensor of any type values +`a` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values `a_scale` | memref of any type values or tensor of any type values -`a_zero_point` | memref of any type values or tensor of any type values -`b` | memref of any type values or tensor of any type values +`a_zero_point` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values +`b` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values `b_scale` | memref of any type values or tensor of any type values -`b_zero_point` | memref of any type values or tensor of any type values +`b_zero_point` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values `y_scale` | memref of any type values or tensor of any type values -`y_zero_point` | memref of any type values or tensor of any type values +`y_zero_point` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`y` | memref of any type values or tensor of any type values +`y` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values ### `onnx.QuantizeLinear` (ONNXQuantizeLinearOp) @@ -3149,15 +3149,15 @@ ONNX QuantizeLinear operation | Operand | Description | | :-----: | ----------- | -`x` | memref of any type values or tensor of any type values +`x` | tensor of 32-bit float or 32-bit signless integer values or memref of 32-bit float or 32-bit signless integer values `y_scale` | memref of any type values or tensor of any type values -`y_zero_point` | memref of any type values or tensor of any type values or none type +`y_zero_point` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values or none type #### Results: | Result | Description | | :----: | ----------- | -`y` | memref of any type values or tensor of any type values +`y` | tensor of 8-bit signless integer values or memref of 8-bit signless integer values ### `onnx.RNN` (ONNXRNNOp) @@ -3240,19 +3240,19 @@ ONNX RNN operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values -`W` | memref of any type values or tensor of any type values -`R` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values or none type -`sequence_lens` | memref of any type values or tensor of any type values or none type -`initial_h` | memref of any type values or tensor of any type values or none type +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`W` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`R` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`sequence_lens` | tensor of 32-bit signless integer values or memref of 32-bit signless integer values or none type +`initial_h` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values or none type -`Y_h` | memref of any type values or tensor of any type values or none type +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type +`Y_h` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values or none type ### `onnx.RandomNormalLike` (ONNXRandomNormalLikeOp) @@ -3285,7 +3285,7 @@ ONNX RandomNormalLike operation | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.RandomNormal` (ONNXRandomNormalOp) @@ -3313,7 +3313,7 @@ ONNX RandomNormal operation | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.RandomUniformLike` (ONNXRandomUniformLikeOp) @@ -3346,7 +3346,7 @@ ONNX RandomUniformLike operation | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.RandomUniform` (ONNXRandomUniformOp) @@ -3373,7 +3373,7 @@ ONNX RandomUniform operation | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Range` (ONNXRangeOp) @@ -3409,15 +3409,15 @@ ONNX Range operation | Operand | Description | | :-----: | ----------- | -`start` | memref of any type values or tensor of any type values -`limit` | memref of any type values or tensor of any type values -`delta` | memref of any type values or tensor of any type values +`start` | tensor of 32-bit float or 64-bit float or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or memref of 32-bit float or 64-bit float or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values +`limit` | tensor of 32-bit float or 64-bit float or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or memref of 32-bit float or 64-bit float or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values +`delta` | tensor of 32-bit float or 64-bit float or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or memref of 32-bit float or 64-bit float or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 32-bit float or 64-bit float or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values or memref of 32-bit float or 64-bit float or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer values ### `onnx.Reciprocal` (ONNXReciprocalOp) @@ -3431,13 +3431,13 @@ ONNX Reciprocal operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.ReduceL1` (ONNXReduceL1Op) @@ -3461,13 +3461,13 @@ ONNX ReduceL1 operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`reduced` | memref of any type values or tensor of any type values +`reduced` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.ReduceL2` (ONNXReduceL2Op) @@ -3491,13 +3491,13 @@ ONNX ReduceL2 operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`reduced` | memref of any type values or tensor of any type values +`reduced` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.ReduceLogSumExp` (ONNXReduceLogSumExpOp) @@ -3521,13 +3521,13 @@ ONNX ReduceLogSumExp operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`reduced` | memref of any type values or tensor of any type values +`reduced` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.ReduceLogSum` (ONNXReduceLogSumOp) @@ -3551,13 +3551,13 @@ ONNX ReduceLogSum operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`reduced` | memref of any type values or tensor of any type values +`reduced` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.ReduceMax` (ONNXReduceMaxOp) @@ -3581,13 +3581,13 @@ ONNX ReduceMax operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`reduced` | memref of any type values or tensor of any type values +`reduced` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.ReduceMean` (ONNXReduceMeanOp) @@ -3611,13 +3611,13 @@ ONNX ReduceMean operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`reduced` | memref of any type values or tensor of any type values +`reduced` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.ReduceMin` (ONNXReduceMinOp) @@ -3641,13 +3641,13 @@ ONNX ReduceMin operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`reduced` | memref of any type values or tensor of any type values +`reduced` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.ReduceProd` (ONNXReduceProdOp) @@ -3671,13 +3671,13 @@ ONNX ReduceProd operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`reduced` | memref of any type values or tensor of any type values +`reduced` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.ReduceSum` (ONNXReduceSumOp) @@ -3701,13 +3701,13 @@ ONNX ReduceSum operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`reduced` | memref of any type values or tensor of any type values +`reduced` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.ReduceSumSquare` (ONNXReduceSumSquareOp) @@ -3731,13 +3731,13 @@ ONNX ReduceSumSquare operation | Operand | Description | | :-----: | ----------- | -`data` | memref of any type values or tensor of any type values +`data` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`reduced` | memref of any type values or tensor of any type values +`reduced` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.Relu` (ONNXReluOp) @@ -3751,13 +3751,13 @@ ONNX Relu operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Reshape` (ONNXReshapeOp) @@ -3807,7 +3807,7 @@ ONNX Resize operation | Operand | Description | | :-----: | ----------- | `X` | memref of any type values or tensor of any type values -`roi` | memref of any type values or tensor of any type values +`roi` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values `scales` | memref of any type values or tensor of any type values `sizes` | memref of any type values or tensor of any type values or none type @@ -3905,15 +3905,15 @@ ONNX RoiAlign operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values -`rois` | memref of any type values or tensor of any type values -`batch_indices` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`rois` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values +`batch_indices` | tensor of 64-bit signless integer values or memref of 64-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Round` (ONNXRoundOp) @@ -3937,13 +3937,13 @@ ONNX Round operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Scan` (ONNXScanOp) @@ -4161,7 +4161,7 @@ ONNX ScatterElements operation | Operand | Description | | :-----: | ----------- | `data` | memref of any type values or tensor of any type values -`indices` | memref of any type values or tensor of any type values +`indices` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values `updates` | memref of any type values or tensor of any type values #### Results: @@ -4314,7 +4314,7 @@ ONNX Scatter operation | Operand | Description | | :-----: | ----------- | `data` | memref of any type values or tensor of any type values -`indices` | memref of any type values or tensor of any type values +`indices` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values `updates` | memref of any type values or tensor of any type values #### Results: @@ -4343,13 +4343,13 @@ ONNX Selu operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.SequenceAt` (ONNXSequenceAtOp) @@ -4364,7 +4364,7 @@ ONNX SequenceAt operation | Operand | Description | | :-----: | ----------- | `input_sequence` | memref of any type values or tensor of any type values -`position` | memref of any type values or tensor of any type values +`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values #### Results: @@ -4423,7 +4423,7 @@ ONNX SequenceErase operation | Operand | Description | | :-----: | ----------- | `input_sequence` | memref of any type values or tensor of any type values -`position` | memref of any type values or tensor of any type values or none type +`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type #### Results: @@ -4447,7 +4447,7 @@ ONNX SequenceInsert operation | :-----: | ----------- | `input_sequence` | memref of any type values or tensor of any type values `tensor` | memref of any type values or tensor of any type values -`position` | memref of any type values or tensor of any type values or none type +`position` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type #### Results: @@ -4471,7 +4471,7 @@ ONNX SequenceLength operation | Result | Description | | :----: | ----------- | -`length` | memref of any type values or tensor of any type values +`length` | tensor of 64-bit signless integer values or memref of 64-bit signless integer values ### `onnx.Shape` (ONNXShapeOp) @@ -4489,7 +4489,7 @@ ONNX Shape operation | Result | Description | | :----: | ----------- | -`shape` | memref of any type values or tensor of any type values +`shape` | tensor of 64-bit signless integer values or memref of 64-bit signless integer values ### `onnx.Shrink` (ONNXShrinkOp) @@ -4511,13 +4511,13 @@ ONNX Shrink operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.Sigmoid` (ONNXSigmoidOp) @@ -4531,13 +4531,13 @@ ONNX Sigmoid operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Sign` (ONNXSignOp) @@ -4550,13 +4550,13 @@ ONNX Sign operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.Sin` (ONNXSinOp) @@ -4568,13 +4568,13 @@ ONNX Sin operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Sinh` (ONNXSinhOp) @@ -4586,13 +4586,13 @@ ONNX Sinh operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Size` (ONNXSizeOp) @@ -4610,7 +4610,7 @@ ONNX Size operation | Result | Description | | :----: | ----------- | -`size` | memref of any type values or tensor of any type values +`size` | tensor of 64-bit signless integer values or memref of 64-bit signless integer values ### `onnx.Slice` (ONNXSliceOp) @@ -4656,10 +4656,10 @@ ONNX Slice operation | Operand | Description | | :-----: | ----------- | `data` | memref of any type values or tensor of any type values -`starts` | memref of any type values or tensor of any type values -`ends` | memref of any type values or tensor of any type values -`axes` | memref of any type values or tensor of any type values or none type -`steps` | memref of any type values or tensor of any type values or none type +`starts` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values +`ends` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values +`axes` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type +`steps` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type #### Results: @@ -4696,13 +4696,13 @@ ONNX Softmax operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Softplus` (ONNXSoftplusOp) @@ -4716,13 +4716,13 @@ ONNX Softplus operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Softsign` (ONNXSoftsignOp) @@ -4734,13 +4734,13 @@ ONNX Softsign operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.SpaceToDepth` (ONNXSpaceToDepthOp) @@ -4822,7 +4822,7 @@ ONNX SplitToSequence operation | Operand | Description | | :-----: | ----------- | `input` | memref of any type values or tensor of any type values -`split` | memref of any type values or tensor of any type values or none type +`split` | tensor of 32-bit signless integer or 64-bit signless integer values or memref of 32-bit signless integer or 64-bit signless integer values or none type #### Results: @@ -4842,13 +4842,13 @@ ONNX Sqrt operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Squeeze` (ONNXSqueezeOp) @@ -4924,14 +4924,14 @@ ONNX Sub operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`A` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values +`B` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`C` | memref of any type values or tensor of any type values +`C` | tensor of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values ### `onnx.Sum` (ONNXSumOp) @@ -4945,13 +4945,13 @@ ONNX Sum operation | Operand | Description | | :-----: | ----------- | -`data_0` | memref of any type values or tensor of any type values +`data_0` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`sum` | memref of any type values or tensor of any type values +`sum` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Tan` (ONNXTanOp) @@ -4963,13 +4963,13 @@ ONNX Tan operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Tanh` (ONNXTanhOp) @@ -4981,13 +4981,13 @@ ONNX Tanh operation | Operand | Description | | :-----: | ----------- | -`input` | memref of any type values or tensor of any type values +`input` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`output` | memref of any type values or tensor of any type values +`output` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.TfIdfVectorizer` (ONNXTfIdfVectorizerOp) @@ -5045,7 +5045,7 @@ ONNX TfIdfVectorizer operation | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 32-bit float values or memref of 32-bit float values ### `onnx.ThresholdedRelu` (ONNXThresholdedReluOp) @@ -5065,13 +5065,13 @@ ONNX ThresholdedRelu operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -`Y` | memref of any type values or tensor of any type values +`Y` | tensor of 16-bit float or 32-bit float or 64-bit float values or memref of 16-bit float or 32-bit float or 64-bit float values ### `onnx.Tile` (ONNXTileOp) @@ -5086,7 +5086,7 @@ ONNX Tile operation | Operand | Description | | :-----: | ----------- | `input` | memref of any type values or tensor of any type values -`repeats` | memref of any type values or tensor of any type values +`repeats` | tensor of 64-bit signless integer values or memref of 64-bit signless integer values #### Results: @@ -5125,15 +5125,15 @@ ONNX TopK operation | Operand | Description | | :-----: | ----------- | -`X` | memref of any type values or tensor of any type values +`X` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values `K` | memref of any type values or tensor of any type values #### Results: | Result | Description | | :----: | ----------- | -`Values` | memref of any type values or tensor of any type values -`Indices` | memref of any type values or tensor of any type values +`Values` | tensor of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values or memref of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 16-bit float or 32-bit float or 64-bit float values +`Indices` | tensor of 64-bit signless integer values or memref of 64-bit signless integer values ### `onnx.Transpose` (ONNXTransposeOp) @@ -5338,7 +5338,7 @@ ONNX Where operation | Operand | Description | | :-----: | ----------- | -`condition` | memref of any type values or tensor of any type values +`condition` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values `X` | memref of any type values or tensor of any type values `Y` | memref of any type values or tensor of any type values @@ -5361,12 +5361,12 @@ ONNX Xor operation | Operand | Description | | :-----: | ----------- | -`A` | memref of any type values or tensor of any type values -`B` | memref of any type values or tensor of any type values +`A` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values +`B` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -`C` | memref of any type values or tensor of any type values +`C` | tensor of 1-bit signless integer values or memref of 1-bit signless integer values diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index aecbbdb..932d59d 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -197,6 +197,47 @@ private: } } +#define MAX_TYPE 20 + // itblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', + // 'F64', 'Complex', 'Complex' ) + mlir::Type buildTypeFromIndex(int index) { + switch (index) { + case 0: + return builder_.getI1Type(); + case 1: + return builder_.getIntegerType(8); + case 2: + return builder_.getIntegerType(16); + case 3: + return builder_.getIntegerType(32); + case 4: + return builder_.getIntegerType(64); + case 5: + return builder_.getBF16Type(); + case 6: + return builder_.getF16Type(); + case 7: + return builder_.getF32Type(); + case 8: + return builder_.getF64Type(); + case 9: { + std::vector typeTuple(2); + typeTuple.push_back(builder_.getF32Type()); + typeTuple.push_back(builder_.getF32Type()); + return builder_.getTupleType(llvm::ArrayRef(typeTuple)); + } + case 10: { + std::vector typeTuple(2); + typeTuple.push_back(builder_.getF64Type()); + typeTuple.push_back(builder_.getF64Type()); + return builder_.getTupleType(llvm::ArrayRef(typeTuple)); + } + default: + assert(false && "Unsupported type index encountered."); + return nullptr; + } + } + template void buildOutputAndOperation(const onnx::NodeProto &node, std::vector inputs, int expectedNumOperands, @@ -217,13 +258,34 @@ private: inputs.emplace_back(none_); std::vector outputTypes; - for (auto item : node.output()) { + + // Use the type map to determine the data type of output. + std::vector outputMap = T::getTypeMap(); + for (auto i = 0; i < node.output().size(); i++) { // Optional outputs using empty string. - if (item.empty()) + if (node.output()[i].empty()) { outputTypes.emplace_back(builder_.getNoneType()); - else - outputTypes.push_back( - mlir::UnrankedTensorType::get(builder_.getF32Type())); + } else { + if (i < outputMap.size() && outputMap[i] >= MAX_TYPE) { + // Mapping gives a connection with an input. + mlir::Type inputType = inputs[outputMap[i] - MAX_TYPE].getType(); + if (inputType.isa()) { + auto elementType = + inputType.cast().getElementType(); + auto outType = mlir::UnrankedTensorType::get(elementType); + outputTypes.emplace_back(outType); + } else { + outputTypes.push_back(inputType); + } + } else if (i < outputMap.size() && outputMap[i] != -1) { + // Mapping gives a direct type. + auto elementType = buildTypeFromIndex(outputMap[i]); + auto outType = mlir::UnrankedTensorType::get(elementType); + outputTypes.emplace_back(outType); + } else { + outputTypes.emplace_back(builder_.getNoneType()); + } + } } // Trailing optional outputs. if (!variadicOut) @@ -241,9 +303,10 @@ private: } template - void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1, - int expectedNumResults = -1) { + void buildOperation(const onnx::NodeProto &node) { std::vector inputs; + int expectedNumOperands = T::getNumberOfOperands(); + int expectedNumResults = T::getNumberOfResults(); for (const auto &item : node.input()) if (initializedTensors.ContainKey(legalize_name(item))) { inputs.push_back(initializedTensors.EmitInitializerForInputTensor( @@ -256,7 +319,9 @@ private: node, inputs, expectedNumOperands, expectedNumResults); } - void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) { + void ImportNodeReshape(onnx::NodeProto node) { + int expectedNumOperands = mlir::ONNXReshapeOp::getNumberOfOperands(); + int expectedNumResults = mlir::ONNXReshapeOp::getNumberOfResults(); std::vector inputs; std::string item; for (int i = 0; i < node.input().size(); ++i) { @@ -270,39 +335,40 @@ private: } } - buildOutputAndOperation(node, inputs, nIn, nOut); + buildOutputAndOperation( + node, inputs, expectedNumOperands, expectedNumResults); } /*! * Special handle for MaxPool operations. */ - void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) { + void ImportNodeMaxPool(onnx::NodeProto node) { int nOuts = node.output().size(); if (nOuts == 1) { - buildOperation(node, nIn, nOuts); + buildOperation(node); } else { - buildOperation(node, nIn, nOuts); + buildOperation(node); } } /*! * Special handle for BatchNormalization operations. */ - void ImportNodeBatchNormalization(onnx::NodeProto node, int nIn, int nOut) { + void ImportNodeBatchNormalization(onnx::NodeProto node) { int nOuts = node.output().size(); if (nOuts == 1) { // Test mode with one output. - buildOperation(node, nIn, nOuts); + buildOperation(node); } else { // Training mode with four trailing optional outputs. Not handled yet. - buildOperation(node, nIn, nOuts); + buildOperation(node); } } /*! * Special handle for Pad operations. */ - void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) { + void ImportNodePad(onnx::NodeProto node) { int nOps = node.input().size(); if (nOps == 2) { @@ -330,9 +396,11 @@ private: } inputs.push_back(constantResult); + int nIn = mlir::ONNXPadOp::getNumberOfOperands(); + int nOut = mlir::ONNXPadOp::getNumberOfResults(); buildOutputAndOperation(node, inputs, nIn, nOut); } else { - buildOperation(node, nIn, nOut); + buildOperation(node); } } diff --git a/src/Builder/MLOpBuildTable.inc b/src/Builder/MLOpBuildTable.inc index 582497b..58d3d3b 100644 --- a/src/Builder/MLOpBuildTable.inc +++ b/src/Builder/MLOpBuildTable.inc @@ -5,38 +5,38 @@ //******************************************************** if (opName == "ArrayFeatureExtractor") - return buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Binarizer") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "CastMap") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "CategoryMapper") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "DictVectorizer") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "FeatureVectorizer") - return buildOperation(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Imputer") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "LabelEncoder") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "LinearClassifier") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); + buildOperation(node); if (opName == "LinearRegressor") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Normalizer") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "OneHotEncoder") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "SVMClassifier") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); + buildOperation(node); if (opName == "SVMRegressor") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Scaler") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "TreeEnsembleClassifier") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); + buildOperation(node); if (opName == "TreeEnsembleRegressor") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ZipMap") - return buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc index 94e8b54..2b3f7bf 100644 --- a/src/Builder/OpBuildTable.inc +++ b/src/Builder/OpBuildTable.inc @@ -5,314 +5,314 @@ //******************************************************** if (opName == "Abs") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Acos") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Acosh") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Add") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "And") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ArgMax") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ArgMin") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Asin") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Asinh") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Atan") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Atanh") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "AveragePool") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "BatchNormalization") - ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5); + ImportNodeBatchNormalization(node); if (opName == "BitShift") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Cast") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Ceil") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Clip") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Compress") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Concat") - buildOperation(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ConcatFromSequence") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Constant") - buildOperation(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ConstantOfShape") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Conv") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ConvInteger") - buildOperation(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ConvTranspose") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Cos") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Cosh") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "CumSum") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "DepthToSpace") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "DequantizeLinear") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Det") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Div") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Dropout") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); + buildOperation(node); if (opName == "DynamicQuantizeLinear") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3); + buildOperation(node); if (opName == "Elu") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Equal") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Erf") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Exp") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Expand") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "EyeLike") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Flatten") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Floor") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "GRU") - buildOperation(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2); + buildOperation(node); if (opName == "Gather") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "GatherElements") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "GatherND") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Gemm") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "GlobalAveragePool") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "GlobalLpPool") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "GlobalMaxPool") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Greater") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "HardSigmoid") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Hardmax") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Identity") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "If") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1); + buildOperation(node); if (opName == "InstanceNormalization") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "IsInf") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "IsNaN") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "LRN") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "LSTM") - buildOperation(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3); + buildOperation(node); if (opName == "LeakyRelu") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Less") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Log") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "LogSoftmax") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Loop") buildOperation(node); if (opName == "LpNormalization") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "LpPool") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "MatMul") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "MatMulInteger") - buildOperation(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Max") - buildOperation(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "MaxPool") - ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); + ImportNodeMaxPool(node); if (opName == "MaxRoiPool") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "MaxUnpool") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Mean") - buildOperation(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "MeanVarianceNormalization") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Min") - buildOperation(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Mod") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Mul") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Multinomial") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Neg") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "NonMaxSuppression") - buildOperation(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "NonZero") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Not") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "OneHot") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Or") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "PRelu") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Pad") - ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + ImportNodePad(node); if (opName == "Pow") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "QLinearConv") - buildOperation(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "QLinearMatMul") - buildOperation(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "QuantizeLinear") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "RNN") - buildOperation(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2); + buildOperation(node); if (opName == "RandomNormal") - buildOperation(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "RandomNormalLike") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "RandomUniform") - buildOperation(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "RandomUniformLike") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Range") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Reciprocal") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ReduceL1") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ReduceL2") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ReduceLogSum") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ReduceLogSumExp") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ReduceMax") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ReduceMean") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ReduceMin") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ReduceProd") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ReduceSum") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ReduceSumSquare") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Relu") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Reshape") - ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + ImportNodeReshape(node); if (opName == "Resize") - buildOperation(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ReverseSequence") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "RoiAlign") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Round") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Scan") buildOperation(node); if (opName == "Scatter") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ScatterElements") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ScatterND") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Selu") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "SequenceAt") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "SequenceConstruct") - buildOperation(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "SequenceEmpty") - buildOperation(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "SequenceErase") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "SequenceInsert") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "SequenceLength") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Shape") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Shrink") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Sigmoid") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Sign") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Sin") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Sinh") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Size") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Slice") - buildOperation(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Softmax") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Softplus") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Softsign") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "SpaceToDepth") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Split") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1); + buildOperation(node); if (opName == "SplitToSequence") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Sqrt") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Squeeze") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "StringNormalizer") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Sub") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Sum") - buildOperation(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Tan") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Tanh") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "TfIdfVectorizer") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "ThresholdedRelu") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Tile") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "TopK") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2); + buildOperation(node); if (opName == "Transpose") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Unique") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4); + buildOperation(node); if (opName == "Unsqueeze") - buildOperation(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Upsample") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Where") - buildOperation(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); + buildOperation(node); if (opName == "Xor") - buildOperation(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); + buildOperation(node); diff --git a/src/Dialect/MLONNX/MLONNXOps.td.inc b/src/Dialect/MLONNX/MLONNXOps.td.inc index 88adf54..f56efa6 100644 --- a/src/Dialect/MLONNX/MLONNXOps.td.inc +++ b/src/Dialect/MLONNX/MLONNXOps.td.inc @@ -14,6 +14,17 @@ def MLONNXArrayFeatureExtractorOp:MLONNX_Op<"ArrayFeatureExtractor", let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def MLONNXBinarizerOp:MLONNX_Op<"Binarizer", @@ -22,9 +33,20 @@ def MLONNXBinarizerOp:MLONNX_Op<"Binarizer", let description = [{ "Maps the values of the input tensor to either 0 or 1, element-wise, based on the outcome of a comparison against a threshold value." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, DefaultValuedAttr:$threshold); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def MLONNXCastMapOp:MLONNX_Op<"CastMap", @@ -40,6 +62,17 @@ def MLONNXCastMapOp:MLONNX_Op<"CastMap", DefaultValuedAttr:$map_form, DefaultValuedAttr:$max_map); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def MLONNXCategoryMapperOp:MLONNX_Op<"CategoryMapper", @@ -61,6 +94,17 @@ def MLONNXCategoryMapperOp:MLONNX_Op<"CategoryMapper", DefaultValuedAttr:$default_int64, DefaultValuedAttr:$default_string); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer", @@ -84,6 +128,17 @@ def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer", OptionalAttr:$int64_vocabulary, OptionalAttr:$string_vocabulary); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def MLONNXFeatureVectorizerOp:MLONNX_Op<"FeatureVectorizer", @@ -95,9 +150,20 @@ def MLONNXFeatureVectorizerOp:MLONNX_Op<"FeatureVectorizer", " Inputs are copied to the output maintaining the order of the input arguments.
" " All inputs must be integers or floats, while the output will be all floating point values." }]; - let arguments = (ins Variadic>:$X, + let arguments = (ins Variadic, MemRefOf<[I32,I64,F32,F64]>]>>:$X, OptionalAttr:$inputdimensions); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return -1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def MLONNXImputerOp:MLONNX_Op<"Imputer", @@ -113,12 +179,23 @@ def MLONNXImputerOp:MLONNX_Op<"Imputer", " which one depends on whether floats or integers are being processed.
" " The imputed_value attribute length can be 1 element, or it can have one element per input feature.
In other words, if the input tensor has the shape [*,F], then the length of the attribute array may be 1 or F. If it is 1, then it is broadcast along the last dimension and applied to each feature." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, OptionalAttr:$imputed_value_floats, OptionalAttr:$imputed_value_int64s, DefaultValuedAttr:$replaced_value_float, DefaultValuedAttr:$replaced_value_int64); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def MLONNXLabelEncoderOp:MLONNX_Op<"LabelEncoder", @@ -154,6 +231,17 @@ def MLONNXLabelEncoderOp:MLONNX_Op<"LabelEncoder", OptionalAttr:$values_int64s, OptionalAttr:$values_strings); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier", @@ -162,7 +250,7 @@ def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier", let description = [{ "Linear classifier" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, OptionalAttr:$classlabels_ints, OptionalAttr:$classlabels_strings, F32ArrayAttr:$coefficients, @@ -171,6 +259,17 @@ def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier", DefaultValuedAttr:$post_transform); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y, AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 2; + } + static std::vector getTypeMap() { + return {-1,-1}; + } + }]; } def MLONNXLinearRegressorOp:MLONNX_Op<"LinearRegressor", @@ -184,12 +283,23 @@ def MLONNXLinearRegressorOp:MLONNX_Op<"LinearRegressor", " The coefficients array is of length n, and the coefficients for each target are contiguous." " Intercepts are optional but if provided must match the number of targets." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, OptionalAttr:$coefficients, OptionalAttr:$intercepts, DefaultValuedAttr:$post_transform, DefaultValuedAttr:$targets); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def MLONNXNormalizerOp:MLONNX_Op<"Normalizer", @@ -207,9 +317,20 @@ def MLONNXNormalizerOp:MLONNX_Op<"Normalizer", " For batches, that is, [N,C] tensors, normalization is done along the C axis. In other words, each row" " of the batch is normalized independently." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, DefaultValuedAttr:$norm); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def MLONNXOneHotEncoderOp:MLONNX_Op<"OneHotEncoder", @@ -230,6 +351,17 @@ def MLONNXOneHotEncoderOp:MLONNX_Op<"OneHotEncoder", OptionalAttr:$cats_strings, DefaultValuedAttr:$zeros); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier", @@ -238,7 +370,7 @@ def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier", let description = [{ "Support Vector Machine classifier" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, OptionalAttr:$classlabels_ints, OptionalAttr:$classlabels_strings, OptionalAttr:$coefficients, @@ -252,6 +384,17 @@ def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier", OptionalAttr:$vectors_per_class); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y, AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 2; + } + static std::vector getTypeMap() { + return {-1,-1}; + } + }]; } def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor", @@ -260,7 +403,7 @@ def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor", let description = [{ "Support Vector Machine regression prediction and one-class SVM anomaly detection." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, OptionalAttr:$coefficients, OptionalAttr:$kernel_params, DefaultValuedAttr:$kernel_type, @@ -270,6 +413,17 @@ def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor", OptionalAttr:$rho, OptionalAttr:$support_vectors); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def MLONNXScalerOp:MLONNX_Op<"Scaler", @@ -278,10 +432,21 @@ def MLONNXScalerOp:MLONNX_Op<"Scaler", let description = [{ "Rescale input data, for example to standardize features by removing the mean and scaling to unit variance." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, OptionalAttr:$offset, OptionalAttr:$scale); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier", @@ -298,7 +463,7 @@ def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier", " One and only one of classlabels_strings or classlabels_int64s" " will be defined. The class_ids are indices into this list." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, OptionalAttr:$base_values, OptionalAttr:$class_ids, OptionalAttr:$class_nodeids, @@ -318,6 +483,17 @@ def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier", DefaultValuedAttr:$post_transform); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y, AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 2; + } + static std::vector getTypeMap() { + return {-1,-1}; + } + }]; } def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor", @@ -335,7 +511,7 @@ def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor", " All trees must have their node ids start at 0 and increment by 1.
" " Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, DefaultValuedAttr:$aggregate_function, OptionalAttr:$base_values, OptionalAttr:$n_targets, @@ -354,6 +530,17 @@ def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor", OptionalAttr:$target_treeids, OptionalAttr:$target_weights); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def MLONNXZipMapOp:MLONNX_Op<"ZipMap", @@ -369,5 +556,16 @@ def MLONNXZipMapOp:MLONNX_Op<"ZipMap", OptionalAttr:$classlabels_int64s, OptionalAttr:$classlabels_strings); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } diff --git a/src/Dialect/ONNX/ONNXOps.td b/src/Dialect/ONNX/ONNXOps.td index 1ae7791..8187ddb 100644 --- a/src/Dialect/ONNX/ONNXOps.td +++ b/src/Dialect/ONNX/ONNXOps.td @@ -38,7 +38,7 @@ def ONNX_Dialect : Dialect { // * The mnemonic for the operation, or the name without the dialect prefix. // * A list of traits for the operation. class ONNX_Op traits = []> : - Op; + Op ; //===----------------------------------------------------------------------===// // ONNX Operations @@ -112,6 +112,17 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", DefaultValuedAttr:$storage_order, OptionalAttr:$strides); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode", @@ -137,6 +148,17 @@ def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode", DefaultValuedAttr:$epsilon, DefaultValuedAttr:$momentum); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 5; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue", @@ -154,6 +176,17 @@ def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue", DefaultValuedAttr:$constant_value, DefaultValuedAttr:$mode); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad", @@ -168,6 +201,17 @@ def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad", I64ArrayAttr:$pads, DefaultValuedAttr:$mode); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad", @@ -186,6 +230,17 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad", let builders = [OpBuilder<"OpBuilder &builder, OperationState &state, " "Value data, ArrayAttr pads, " "FloatAttr constant_value, StringAttr mode">]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } #endif // ONNX_OPS diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 7993e89..e49955c 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -12,8 +12,8 @@ def ONNXAbsOp:ONNX_Op<"Abs", "(Tensor) where the absolute is, y = abs(x), is applied to" "the tensor elementwise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$Y); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value X", [{ auto elementType = X.getType().cast().getElementType(); @@ -26,6 +26,17 @@ def ONNXAbsOp:ONNX_Op<"Abs", build(builder, state, outputTypes, operands, attributes); }]> ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXAcosOp:ONNX_Op<"Acos", @@ -34,8 +45,19 @@ def ONNXAcosOp:ONNX_Op<"Acos", let description = [{ "Calculates the arccosine (inverse of cosine) of the given input tensor, element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXAcoshOp:ONNX_Op<"Acosh", @@ -44,8 +66,19 @@ def ONNXAcoshOp:ONNX_Op<"Acosh", let description = [{ "Calculates the hyperbolic arccosine of the given input tensor element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXAddOp:ONNX_Op<"Add", @@ -57,9 +90,20 @@ def ONNXAddOp:ONNX_Op<"Add", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A, + AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXAndOp:ONNX_Op<"And", @@ -71,9 +115,20 @@ def ONNXAndOp:ONNX_Op<"And", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A, + AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B); + let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXArgMaxOp:ONNX_Op<"ArgMax", @@ -85,10 +140,21 @@ def ONNXArgMaxOp:ONNX_Op<"ArgMax", "If keepdims equal 0, then the resulted tensor have the reduced dimension pruned. " "The type of the output tensor is integer." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$data, DefaultValuedAttr:$axis, DefaultValuedAttr:$keepdims); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXArgMinOp:ONNX_Op<"ArgMin", @@ -100,10 +166,21 @@ def ONNXArgMinOp:ONNX_Op<"ArgMin", "If keepdims equal 0, then the resulted tensor have the reduced dimension pruned. " "The type of the output tensor is integer." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$data, DefaultValuedAttr:$axis, DefaultValuedAttr:$keepdims); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXAsinOp:ONNX_Op<"Asin", @@ -112,8 +189,19 @@ def ONNXAsinOp:ONNX_Op<"Asin", let description = [{ "Calculates the arcsine (inverse of sine) of the given input tensor, element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXAsinhOp:ONNX_Op<"Asinh", @@ -122,8 +210,19 @@ def ONNXAsinhOp:ONNX_Op<"Asinh", let description = [{ "Calculates the hyperbolic arcsine of the given input tensor element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXAtanOp:ONNX_Op<"Atan", @@ -132,8 +231,19 @@ def ONNXAtanOp:ONNX_Op<"Atan", let description = [{ "Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXAtanhOp:ONNX_Op<"Atanh", @@ -142,8 +252,19 @@ def ONNXAtanhOp:ONNX_Op<"Atanh", let description = [{ "Calculates the hyperbolic arctangent of the given input tensor element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXAveragePoolOp:ONNX_Op<"AveragePool", @@ -180,14 +301,25 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool", " The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero)." " " }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, DefaultValuedAttr:$auto_pad, DefaultValuedAttr:$ceil_mode, DefaultValuedAttr:$count_include_pad, I64ArrayAttr:$kernel_shape, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXBatchNormalizationOp:ONNX_Op<"BatchNormalization", @@ -205,18 +337,29 @@ def ONNXBatchNormalizationOp:ONNX_Op<"BatchNormalization", "to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op." "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$scale, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$mean, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$var, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$scale, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$B, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$mean, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$var, DefaultValuedAttr:$epsilon, DefaultValuedAttr:$momentum); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$out_mean, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$out_var, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$saved_mean, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$saved_var); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$out_mean, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$out_var, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$saved_mean, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$saved_var); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 5; + } + static int getNumberOfResults() { + return 5; + } + static std::vector getTypeMap() { + return {20,20,20,20,20}; + } + }]; } def ONNXBitShiftOp:ONNX_Op<"BitShift", @@ -236,10 +379,21 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift", " not necessarily identical." "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y, + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64]>, MemRefOf<[I8,I16,I32,I64]>]>:$X, + AnyTypeOf<[TensorOf<[I8,I16,I32,I64]>, MemRefOf<[I8,I16,I32,I64]>]>:$Y, StrAttr:$direction); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); + let results = (outs AnyTypeOf<[TensorOf<[I8,I16,I32,I64]>, MemRefOf<[I8,I16,I32,I64]>]>:$Z); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXCastOp:ONNX_Op<"Cast", @@ -269,6 +423,17 @@ def ONNXCastOp:ONNX_Op<"Cast", let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, I64Attr:$to); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXCeilOp:ONNX_Op<"Ceil", @@ -279,8 +444,19 @@ def ONNXCeilOp:ONNX_Op<"Ceil", "(Tensor) where the ceil is, y = ceil(x), is applied to" "the tensor elementwise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXClipOp:ONNX_Op<"Clip", @@ -291,10 +467,21 @@ def ONNXClipOp:ONNX_Op<"Clip", "specified by the inputs 'min' and 'max'. They default to" "numeric_limits::lowest() and numeric_limits::max(), respectively." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$min, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$max); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$min, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$max); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXCompressOp:ONNX_Op<"Compress", @@ -307,9 +494,20 @@ def ONNXCompressOp:ONNX_Op<"Compress", " " }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$condition, + AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$condition, OptionalAttr:$axis); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXConcatOp:ONNX_Op<"Concat", @@ -321,6 +519,17 @@ def ONNXConcatOp:ONNX_Op<"Concat", let arguments = (ins Variadic>:$inputs, I64Attr:$axis); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$concat_result); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return -1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", @@ -336,6 +545,17 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", I64Attr:$axis, DefaultValuedAttr:$new_axis); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$concat_result); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXConstantOp:ONNX_Op<"Constant", @@ -348,6 +568,17 @@ def ONNXConstantOp:ONNX_Op<"Constant", let arguments = (ins OptionalAttr:$sparse_value, OptionalAttr:$value); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 0; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{ if (value) { @@ -366,9 +597,20 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", let description = [{ "Generate a tensor with given value and shape." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$input, OptionalAttr:$value); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64,I8,I16,I32,I64,I1]>, MemRefOf<[F16,F32,F64,I8,I16,I32,I64,I1]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXConvOp:ONNX_Op<"Conv", @@ -379,16 +621,27 @@ def ONNXConvOp:ONNX_Op<"Conv", "The convolution operator consumes an input tensor and a filter, and" "computes the output." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$W, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$B, DefaultValuedAttr:$auto_pad, OptionalAttr:$dilations, DefaultValuedAttr:$group, OptionalAttr:$kernel_shape, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXConvIntegerOp:ONNX_Op<"ConvInteger", @@ -398,17 +651,28 @@ def ONNXConvIntegerOp:ONNX_Op<"ConvInteger", "The integer convolution operator consumes an input tensor, its zero-point, a filter, and its zero-point," "and computes the output. The production MUST never overflow. The accumulation may overflow if and only if in 32 bits." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$w, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$x_zero_point, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$w_zero_point, + let arguments = (ins AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$x, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$w, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>, NoneType]>:$x_zero_point, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>, NoneType]>:$w_zero_point, DefaultValuedAttr:$auto_pad, OptionalAttr:$dilations, DefaultValuedAttr:$group, OptionalAttr:$kernel_shape, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$y); + let results = (outs AnyTypeOf<[TensorOf<[I32]>, MemRefOf<[I32]>]>:$y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 4; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {3}; + } + }]; } def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", @@ -430,9 +694,9 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", "" " " }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$W, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$B, DefaultValuedAttr:$auto_pad, OptionalAttr:$dilations, DefaultValuedAttr:$group, @@ -441,7 +705,18 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", OptionalAttr:$output_shape, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXCosOp:ONNX_Op<"Cos", @@ -450,8 +725,19 @@ def ONNXCosOp:ONNX_Op<"Cos", let description = [{ "Calculates the cosine of the given input tensor, element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXCoshOp:ONNX_Op<"Cosh", @@ -460,8 +746,19 @@ def ONNXCoshOp:ONNX_Op<"Cosh", let description = [{ "Calculates the hyperbolic cosine of the given input tensor element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXCumSumOp:ONNX_Op<"CumSum", @@ -489,11 +786,22 @@ def ONNXCumSumOp:ONNX_Op<"CumSum", "```" " " }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$axis, + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F32,F64]>, MemRefOf<[I32,I64,F32,F64]>]>:$x, + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$axis, DefaultValuedAttr:$exclusive, DefaultValuedAttr:$reverse); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$y); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F32,F64]>, MemRefOf<[I32,I64,F32,F64]>]>:$y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace", @@ -532,6 +840,17 @@ def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace", I64Attr:$blocksize, DefaultValuedAttr:$mode); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear", @@ -543,10 +862,21 @@ def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear", "'x_zero_point' and 'x' must have same type. 'x' and 'y' must have same shape. In the case of dequantizing int32," "there's no zero point (zero point is supposed to be 0)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x, + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I32]>, MemRefOf<[I8,I32]>]>:$x, AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_scale, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$x_zero_point); + AnyTypeOf<[TensorOf<[I8,I32]>, MemRefOf<[I8,I32]>, NoneType]>:$x_zero_point); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {21}; + } + }]; } def ONNXDetOp:ONNX_Op<"Det", @@ -559,8 +889,19 @@ def ONNXDetOp:ONNX_Op<"Det", "The output is a tensor of shape `[*]`, containing the determinants of all input submatrices." "e.g., When the input is 2-D, the output is a scalar(shape is empty: `[]`)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXDivOp:ONNX_Op<"Div", @@ -571,9 +912,20 @@ def ONNXDivOp:ONNX_Op<"Div", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A, + AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXDropoutOp:ONNX_Op<"Dropout", @@ -587,10 +939,21 @@ def ONNXDropoutOp:ONNX_Op<"Dropout", "the training phase, so during testing nothing needs to be done." "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$data, DefaultValuedAttr:$ratio); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$mask); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output, + AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>, NoneType]>:$mask); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 2; + } + static std::vector getTypeMap() { + return {20,0}; + } + }]; } def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear", @@ -620,10 +983,21 @@ def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear", "* rounding to nearest ties to even." "```" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$y, + let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, MemRefOf<[F32]>]>:$x); + let results = (outs AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$y, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_zero_point); + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$y_zero_point); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 3; + } + static std::vector getTypeMap() { + return {1,-1,1}; + } + }]; } def ONNXEluOp:ONNX_Op<"Elu", @@ -635,9 +1009,20 @@ def ONNXEluOp:ONNX_Op<"Elu", "0`, `f(x) = x for x >= 0`., is applied to the tensor elementwise." "" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, DefaultValuedAttr:$alpha); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXEqualOp:ONNX_Op<"Equal", @@ -649,9 +1034,20 @@ def ONNXEqualOp:ONNX_Op<"Equal", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let arguments = (ins AnyTypeOf<[TensorOf<[I1,I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I1,I8,I16,I32,I64,F16,F32,F64]>]>:$A, + AnyTypeOf<[TensorOf<[I1,I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I1,I8,I16,I32,I64,F16,F32,F64]>]>:$B); + let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXErfOp:ONNX_Op<"Erf", @@ -660,8 +1056,19 @@ def ONNXErfOp:ONNX_Op<"Erf", let description = [{ "Computes the error function of the given input tensor element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXExpOp:ONNX_Op<"Exp", @@ -670,8 +1077,8 @@ def ONNXExpOp:ONNX_Op<"Exp", let description = [{ "Calculates the exponential of the given input tensor, element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value input", [{ auto elementType = input.getType().cast().getElementType(); @@ -684,6 +1091,17 @@ def ONNXExpOp:ONNX_Op<"Exp", build(builder, state, outputTypes, operands, attributes); }]> ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXExpandOp:ONNX_Op<"Expand", @@ -702,6 +1120,17 @@ def ONNXExpandOp:ONNX_Op<"Expand", let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXEyeLikeOp:ONNX_Op<"EyeLike", @@ -716,10 +1145,21 @@ def ONNXEyeLikeOp:ONNX_Op<"EyeLike", "The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the" "TensorProto message and be valid as an output type." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64,I8,I16,I32,I64,I1]>, MemRefOf<[F16,F32,F64,I8,I16,I32,I64,I1]>]>:$input, OptionalAttr:$dtype, DefaultValuedAttr:$k); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64,I8,I16,I32,I64,I1]>, MemRefOf<[F16,F32,F64,I8,I16,I32,I64,I1]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXFlattenOp:ONNX_Op<"Flatten", @@ -733,6 +1173,17 @@ def ONNXFlattenOp:ONNX_Op<"Flatten", let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, DefaultValuedAttr:$axis); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXFloorOp:ONNX_Op<"Floor", @@ -743,8 +1194,19 @@ def ONNXFloorOp:ONNX_Op<"Floor", "(Tensor) where the floor is, y = floor(x), is applied to" "the tensor elementwise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXGRUOp:ONNX_Op<"GRU", @@ -825,12 +1287,12 @@ def ONNXGRUOp:ONNX_Op<"GRU", " - Ht = (1 - zt) (.) ht + zt (.) Ht-1" "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$R, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$sequence_lens, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$initial_h, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$W, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$R, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$B, + AnyTypeOf<[TensorOf<[I32]>, MemRefOf<[I32]>, NoneType]>:$sequence_lens, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$initial_h, OptionalAttr:$activation_alpha, OptionalAttr:$activation_beta, OptionalAttr:$activations, @@ -838,8 +1300,19 @@ def ONNXGRUOp:ONNX_Op<"GRU", DefaultValuedAttr:$direction, OptionalAttr:$hidden_size, DefaultValuedAttr:$linear_before_reset); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y_h); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$Y, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$Y_h); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 6; + } + static int getNumberOfResults() { + return 2; + } + static std::vector getTypeMap() { + return {20,20}; + } + }]; } def ONNXGatherOp:ONNX_Op<"Gather", @@ -905,9 +1378,20 @@ def ONNXGatherOp:ONNX_Op<"Gather", "```" }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices, + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$indices, DefaultValuedAttr:$axis); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXGatherElementsOp:ONNX_Op<"GatherElements", @@ -971,9 +1455,20 @@ def ONNXGatherElementsOp:ONNX_Op<"GatherElements", "```" }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices, + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$indices, DefaultValuedAttr:$axis); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXGatherNDOp:ONNX_Op<"GatherND", @@ -1049,6 +1544,17 @@ def ONNXGatherNDOp:ONNX_Op<"GatherND", let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXGemmOp:ONNX_Op<"Gemm", @@ -1070,14 +1576,25 @@ def ONNXGemmOp:ONNX_Op<"Gemm", "This operator supports **unidirectional broadcasting** (tensor C should be unidirectional broadcastable to tensor A * B); for more details please check [the doc](Broadcasting.md)." "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$C, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64,I32,I64]>, MemRefOf<[F16,F32,F64,I32,I64]>]>:$A, + AnyTypeOf<[TensorOf<[F16,F32,F64,I32,I64]>, MemRefOf<[F16,F32,F64,I32,I64]>]>:$B, + AnyTypeOf<[TensorOf<[F16,F32,F64,I32,I64]>, MemRefOf<[F16,F32,F64,I32,I64]>, NoneType]>:$C, DefaultValuedAttr:$alpha, DefaultValuedAttr:$beta, DefaultValuedAttr:$transA, DefaultValuedAttr:$transB); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64,I32,I64]>, MemRefOf<[F16,F32,F64,I32,I64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXGlobalAveragePoolOp:ONNX_Op<"GlobalAveragePool", @@ -1088,8 +1605,19 @@ def ONNXGlobalAveragePoolOp:ONNX_Op<"GlobalAveragePool", " the values in the same channel. This is equivalent to AveragePool with kernel size" " equal to the spatial dimension of input tensor." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXGlobalLpPoolOp:ONNX_Op<"GlobalLpPool", @@ -1100,9 +1628,20 @@ def ONNXGlobalLpPoolOp:ONNX_Op<"GlobalLpPool", " the values in the same channel. This is equivalent to LpPool with kernel size" " equal to the spatial dimension of input tensor." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, DefaultValuedAttr:$p); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXGlobalMaxPoolOp:ONNX_Op<"GlobalMaxPool", @@ -1113,8 +1652,19 @@ def ONNXGlobalMaxPoolOp:ONNX_Op<"GlobalMaxPool", " the values in the same channel. This is equivalent to MaxPool with kernel size" " equal to the spatial dimension of input tensor." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXGreaterOp:ONNX_Op<"Greater", @@ -1126,9 +1676,20 @@ def ONNXGreaterOp:ONNX_Op<"Greater", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A, + AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B); + let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid", @@ -1139,10 +1700,21 @@ def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid", "(Tensor) where the HardSigmoid function, y = max(0, min(1, alpha * x + beta))," "is applied to the tensor elementwise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, DefaultValuedAttr:$alpha, DefaultValuedAttr:$beta); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXHardmaxOp:ONNX_Op<"Hardmax", @@ -1164,9 +1736,20 @@ def ONNXHardmaxOp:ONNX_Op<"Hardmax", "will throw errors. The output tensor has the same shape" "and contains the hardmax values of the corresponding input." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input, DefaultValuedAttr:$axis); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXIdentityOp:ONNX_Op<"Identity", @@ -1178,6 +1761,17 @@ def ONNXIdentityOp:ONNX_Op<"Identity", }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXIfOp:ONNX_Op<"If", @@ -1186,10 +1780,21 @@ def ONNXIfOp:ONNX_Op<"If", let description = [{ "If conditional" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$cond, + let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$cond, AnyAttr:$else_branch, AnyAttr:$then_branch); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$outputs); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return -1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXInstanceNormalizationOp:ONNX_Op<"InstanceNormalization", @@ -1203,11 +1808,22 @@ def ONNXInstanceNormalizationOp:ONNX_Op<"InstanceNormalization", "where mean and variance are computed per instance per channel." "" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$scale, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$scale, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$B, DefaultValuedAttr:$epsilon); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXIsInfOp:ONNX_Op<"IsInf", @@ -1216,10 +1832,21 @@ def ONNXIsInfOp:ONNX_Op<"IsInf", let description = [{ "Map infinity to true and other values to false." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64]>, MemRefOf<[F32,F64]>]>:$X, DefaultValuedAttr:$detect_negative, DefaultValuedAttr:$detect_positive); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXIsNaNOp:ONNX_Op<"IsNaN", @@ -1228,8 +1855,19 @@ def ONNXIsNaNOp:ONNX_Op<"IsNaN", let description = [{ "Returns which elements of the input are NaN." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXLRNOp:ONNX_Op<"LRN", @@ -1247,12 +1885,23 @@ def ONNXLRNOp:ONNX_Op<"LRN", "" "Y[n, c, d1, ..., dk] = X[n, c, d1, ..., dk] / (bias + alpha / size * square_sum[n, c, d1, ..., dk] ) ^ beta" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, DefaultValuedAttr:$alpha, DefaultValuedAttr:$beta, DefaultValuedAttr:$bias, I64Attr:$size); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXLSTMOp:ONNX_Op<"LSTM", @@ -1341,14 +1990,14 @@ def ONNXLSTMOp:ONNX_Op<"LSTM", " - Ht = ot (.) h(Ct)" "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$R, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$sequence_lens, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$initial_h, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$initial_c, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$P, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$W, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$R, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$B, + AnyTypeOf<[TensorOf<[I32]>, MemRefOf<[I32]>, NoneType]>:$sequence_lens, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$initial_h, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$initial_c, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$P, OptionalAttr:$activation_alpha, OptionalAttr:$activation_beta, OptionalAttr:$activations, @@ -1356,9 +2005,20 @@ def ONNXLSTMOp:ONNX_Op<"LSTM", DefaultValuedAttr:$direction, OptionalAttr:$hidden_size, DefaultValuedAttr:$input_forget); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y_h, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y_c); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$Y, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$Y_h, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$Y_c); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 8; + } + static int getNumberOfResults() { + return 3; + } + static std::vector getTypeMap() { + return {20,20,20}; + } + }]; } def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu", @@ -1369,9 +2029,20 @@ def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu", "output data (Tensor) where the function `f(x) = alpha * x for x < 0`," "`f(x) = x for x >= 0`, is applied to the data tensor elementwise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, DefaultValuedAttr:$alpha); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXLessOp:ONNX_Op<"Less", @@ -1383,9 +2054,20 @@ def ONNXLessOp:ONNX_Op<"Less", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A, + AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B); + let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXLogOp:ONNX_Op<"Log", @@ -1394,8 +2076,19 @@ def ONNXLogOp:ONNX_Op<"Log", let description = [{ "Calculates the natural log of the given input tensor, element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXLogSoftmaxOp:ONNX_Op<"LogSoftmax", @@ -1417,9 +2110,20 @@ def ONNXLogSoftmaxOp:ONNX_Op<"LogSoftmax", "will throw errors. The output tensor has the same shape" "and contains the logsoftmax values of the corresponding input." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input, DefaultValuedAttr:$axis); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXLoopOp:ONNX_Op<"Loop", @@ -1540,11 +2244,22 @@ def ONNXLoopOp:ONNX_Op<"Loop", "the scan_outputs from the previous layer, possibly going through several" "point-wise operators (e.g. dropout, residual connections, linear layer)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$M, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$cond, + let arguments = (ins AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>, NoneType]>:$M, + AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>, NoneType]>:$cond, AnyTypeOf<[AnyMemRef, AnyTensor]>:$v_initial, AnyAttr:$body); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$v_final_and_scan_outputs); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return -1; + } + static int getNumberOfResults() { + return -1; + } + static std::vector getTypeMap() { + return {22}; + } + }]; } def ONNXLpNormalizationOp:ONNX_Op<"LpNormalization", @@ -1553,10 +2268,21 @@ def ONNXLpNormalizationOp:ONNX_Op<"LpNormalization", let description = [{ "Given a matrix, apply Lp-normalization along the provided axis." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input, DefaultValuedAttr:$axis, DefaultValuedAttr:$p); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXLpPoolOp:ONNX_Op<"LpPool", @@ -1569,13 +2295,24 @@ def ONNXLpPoolOp:ONNX_Op<"LpPool", " of the input tensor according to the kernel size and downsampling the" " data into the output tensor Y for further processing." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, DefaultValuedAttr:$auto_pad, I64ArrayAttr:$kernel_shape, DefaultValuedAttr:$p, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXMatMulOp:ONNX_Op<"MatMul", @@ -1584,9 +2321,20 @@ def ONNXMatMulOp:ONNX_Op<"MatMul", let description = [{ "Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64,I32,I64]>, MemRefOf<[F16,F32,F64,I32,I64]>]>:$A, + AnyTypeOf<[TensorOf<[F16,F32,F64,I32,I64]>, MemRefOf<[F16,F32,F64,I32,I64]>]>:$B); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64,I32,I64]>, MemRefOf<[F16,F32,F64,I32,I64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger", @@ -1596,11 +2344,22 @@ def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger", "Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html." "The production MUST never overflow. The accumulation may overflow if and only if in 32 bits." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$a_zero_point, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$b_zero_point); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$A, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$B, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>, NoneType]>:$a_zero_point, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>, NoneType]>:$b_zero_point); + let results = (outs AnyTypeOf<[TensorOf<[I32]>, MemRefOf<[I32]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 4; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {3}; + } + }]; } def ONNXMaxOp:ONNX_Op<"Max", @@ -1611,8 +2370,19 @@ def ONNXMaxOp:ONNX_Op<"Max", "All inputs and outputs must have the same data type." "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins Variadic>:$data_0); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$max); + let arguments = (ins Variadic, MemRefOf<[F16,F32,F64]>]>>:$data_0); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$max); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return -1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXMaxPoolOp:ONNX_Op<"MaxPool", @@ -1649,7 +2419,7 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool", " The output of each pooling window is maximum number of elements exclude pad." " " }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, DefaultValuedAttr:$auto_pad, DefaultValuedAttr:$ceil_mode, OptionalAttr:$dilations, @@ -1657,8 +2427,19 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool", OptionalAttr:$pads, DefaultValuedAttr:$storage_order, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Indices); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y, + AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>, NoneType]>:$Indices); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 2; + } + static std::vector getTypeMap() { + return {20,4}; + } + }]; } def ONNXMaxRoiPoolOp:ONNX_Op<"MaxRoiPool", @@ -1669,11 +2450,22 @@ def ONNXMaxRoiPoolOp:ONNX_Op<"MaxRoiPool", " apply max pooling across each RoI, to produce output 4-D tensor of shape" " (num_rois, channels, pooled_shape[0], pooled_shape[1])." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$rois, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$rois, I64ArrayAttr:$pooled_shape, DefaultValuedAttr:$spatial_scale); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXMaxUnpoolOp:ONNX_Op<"MaxUnpool", @@ -1699,13 +2491,24 @@ def ONNXMaxUnpoolOp:ONNX_Op<"MaxUnpool", " which define the exact unpooling op. The attributes typically have the same values as the corrsponding" " pooling op that the unpooling op is trying to invert." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$I, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$output_shape, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, + AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$I, + AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>, NoneType]>:$output_shape, I64ArrayAttr:$kernel_shape, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXMeanOp:ONNX_Op<"Mean", @@ -1716,8 +2519,19 @@ def ONNXMeanOp:ONNX_Op<"Mean", "All inputs and outputs must have the same data type." "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins Variadic>:$data_0); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$mean); + let arguments = (ins Variadic, MemRefOf<[F16,F32,F64]>]>>:$data_0); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$mean); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return -1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXMeanVarianceNormalizationOp:ONNX_Op<"MeanVarianceNormalization", @@ -1727,9 +2541,20 @@ def ONNXMeanVarianceNormalizationOp:ONNX_Op<"MeanVarianceNormalization", "A MeanVarianceNormalization Function: Perform mean variance normalization" " on the input tensor X using formula:
``` (X-EX)/sqrt(E(X-EX)^2) ```" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, DefaultValuedAttr:$axes); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXMinOp:ONNX_Op<"Min", @@ -1740,8 +2565,19 @@ def ONNXMinOp:ONNX_Op<"Min", "All inputs and outputs must have the same data type." "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins Variadic>:$data_0); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$min); + let arguments = (ins Variadic, MemRefOf<[F16,F32,F64]>]>>:$data_0); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$min); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return -1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXModOp:ONNX_Op<"Mod", @@ -1762,10 +2598,21 @@ def ONNXModOp:ONNX_Op<"Mod", "" " This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B, + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A, + AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B, DefaultValuedAttr:$fmod); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let results = (outs AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$C); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXMulOp:ONNX_Op<"Mul", @@ -1776,9 +2623,9 @@ def ONNXMulOp:ONNX_Op<"Mul", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A, + AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ auto elementType = A.getType().cast().getElementType(); @@ -1791,6 +2638,17 @@ def ONNXMulOp:ONNX_Op<"Mul", build(builder, state, outputTypes, operands, attributes); }]> ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXMultinomialOp:ONNX_Op<"Multinomial", @@ -1800,11 +2658,22 @@ def ONNXMultinomialOp:ONNX_Op<"Multinomial", "Generate a tensor of samples from a multinomial distribution according to the probabilities" "of each of the possible outcomes." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input, DefaultValuedAttr:$dtype, DefaultValuedAttr:$sample_size, OptionalAttr:$seed); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXNegOp:ONNX_Op<"Neg", @@ -1815,8 +2684,19 @@ def ONNXNegOp:ONNX_Op<"Neg", "(Tensor) where each element flipped sign, y = -x, is applied to" "the tensor elementwise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F32,I32,I8,I16,I64,F16,F64]>, MemRefOf<[F32,I32,I8,I16,I64,F16,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[F32,I32,I8,I16,I64,F16,F64]>, MemRefOf<[F32,I32,I8,I16,I64,F16,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression", @@ -1838,6 +2718,17 @@ def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression", AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$score_threshold, DefaultValuedAttr:$center_point_box); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$selected_indices); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 5; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {22}; + } + }]; } def ONNXNonZeroOp:ONNX_Op<"NonZero", @@ -1851,6 +2742,17 @@ def ONNXNonZeroOp:ONNX_Op<"NonZero", }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXNotOp:ONNX_Op<"Not", @@ -1859,8 +2761,19 @@ def ONNXNotOp:ONNX_Op<"Not", let description = [{ "Returns the negation of the input tensor element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXOneHotOp:ONNX_Op<"OneHot", @@ -1887,11 +2800,22 @@ def ONNXOneHotOp:ONNX_Op<"OneHot", " output[i, j, k, input[i, j, k]] = 1 for all i, j, k and 0 otherwise." "" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$depth, + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$indices, + AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$depth, AnyTypeOf<[AnyMemRef, AnyTensor]>:$values, DefaultValuedAttr:$axis); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {22}; + } + }]; } def ONNXOrOp:ONNX_Op<"Or", @@ -1903,9 +2827,20 @@ def ONNXOrOp:ONNX_Op<"Or", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A, + AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B); + let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXPReluOp:ONNX_Op<"PRelu", @@ -1917,9 +2852,20 @@ def ONNXPReluOp:ONNX_Op<"PRelu", "`f(x) = x for x >= 0`., is applied to the data tensor elementwise." "This operator supports **unidirectional broadcasting** (tensor slope should be unidirectional broadcastable to input tensor X); for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$slope); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64,I32,I64]>, MemRefOf<[F16,F32,F64,I32,I64]>]>:$X, + AnyTypeOf<[TensorOf<[F16,F32,F64,I32,I64]>, MemRefOf<[F16,F32,F64,I32,I64]>]>:$slope); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64,I32,I64]>, MemRefOf<[F16,F32,F64,I32,I64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXPadOp:ONNX_Op<"Pad", @@ -2008,11 +2954,11 @@ def ONNXPadOp:ONNX_Op<"Pad", " ]" "" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$data, AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$pads, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$constant_value, + AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>, NoneType]>:$constant_value, DefaultValuedAttr:$mode); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$output); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value data, Value pads, Value constant_value, StringAttr mode", [{ auto elementType = data.getType().cast().getElementType(); @@ -2026,6 +2972,15 @@ def ONNXPadOp:ONNX_Op<"Pad", }]> ]; let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } std::map promotableConstOperands() { return {{"pads", 1}, {"constant_value", 2}}; } @@ -2041,9 +2996,20 @@ def ONNXPowOp:ONNX_Op<"Pow", "is applied to the data tensor elementwise." "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Z); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXQLinearConvOp:ONNX_Op<"QLinearConv", @@ -2056,22 +3022,33 @@ def ONNXQLinearConvOp:ONNX_Op<"QLinearConv", "It means they must be either scalars (per tensor) or 1-D tensors (per output channel)." "Each input or output and its related zero point must have same type." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x, + let arguments = (ins AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$x, AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_scale, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_zero_point, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$w, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$x_zero_point, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$w, AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_scale, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_zero_point, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$w_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_zero_point, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$y_zero_point, + AnyTypeOf<[TensorOf<[I32]>, MemRefOf<[I32]>, NoneType]>:$B, DefaultValuedAttr:$auto_pad, OptionalAttr:$dilations, DefaultValuedAttr:$group, OptionalAttr:$kernel_shape, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$y); + let results = (outs AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 9; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {1}; + } + }]; } def ONNXQLinearMatMulOp:ONNX_Op<"QLinearMatMul", @@ -2087,15 +3064,26 @@ def ONNXQLinearMatMulOp:ONNX_Op<"QLinearMatMul", "and the number of elements of scale and zero point tensor of input 'b' should be equal to the number of columns of input 'b'." "Production must never overflow, and accumulation may overflow if and only if in 32 bits." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$a, + let arguments = (ins AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$a, AnyTypeOf<[AnyMemRef, AnyTensor]>:$a_scale, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$a_zero_point, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$b, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$a_zero_point, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$b, AnyTypeOf<[AnyMemRef, AnyTensor]>:$b_scale, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$b_zero_point, + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$b_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_zero_point); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$y); + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$y_zero_point); + let results = (outs AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 8; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {1}; + } + }]; } def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear", @@ -2106,10 +3094,21 @@ def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear", "The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8." "For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details. 'y_zero_point' and 'y' must have same type." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x, + let arguments = (ins AnyTypeOf<[TensorOf<[F32,I32]>, MemRefOf<[F32,I32]>]>:$x, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$y_zero_point); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$y); + AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>, NoneType]>:$y_zero_point); + let results = (outs AnyTypeOf<[TensorOf<[I8]>, MemRefOf<[I8]>]>:$y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {1}; + } + }]; } def ONNXRNNOp:ONNX_Op<"RNN", @@ -2178,20 +3177,31 @@ def ONNXRNNOp:ONNX_Op<"RNN", " - Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)" "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$R, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$B, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$sequence_lens, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$initial_h, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$W, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$R, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$B, + AnyTypeOf<[TensorOf<[I32]>, MemRefOf<[I32]>, NoneType]>:$sequence_lens, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$initial_h, OptionalAttr:$activation_alpha, OptionalAttr:$activation_beta, DefaultValuedAttr:$activations, OptionalAttr:$clip, DefaultValuedAttr:$direction, OptionalAttr:$hidden_size); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$Y_h); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$Y, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>, NoneType]>:$Y_h); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 6; + } + static int getNumberOfResults() { + return 2; + } + static std::vector getTypeMap() { + return {20,20}; + } + }]; } def ONNXRandomNormalOp:ONNX_Op<"RandomNormal", @@ -2211,7 +3221,18 @@ def ONNXRandomNormalOp:ONNX_Op<"RandomNormal", DefaultValuedAttr:$scale, OptionalAttr:$seed, I64ArrayAttr:$shape); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 0; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXRandomNormalLikeOp:ONNX_Op<"RandomNormalLike", @@ -2231,7 +3252,18 @@ def ONNXRandomNormalLikeOp:ONNX_Op<"RandomNormalLike", DefaultValuedAttr:$mean, DefaultValuedAttr:$scale, OptionalAttr:$seed); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXRandomUniformOp:ONNX_Op<"RandomUniform", @@ -2250,7 +3282,18 @@ def ONNXRandomUniformOp:ONNX_Op<"RandomUniform", DefaultValuedAttr:$low, OptionalAttr:$seed, I64ArrayAttr:$shape); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 0; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXRandomUniformLikeOp:ONNX_Op<"RandomUniformLike", @@ -2270,7 +3313,18 @@ def ONNXRandomUniformLikeOp:ONNX_Op<"RandomUniformLike", DefaultValuedAttr:$high, DefaultValuedAttr:$low, OptionalAttr:$seed); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXRangeOp:ONNX_Op<"Range", @@ -2303,10 +3357,21 @@ def ONNXRangeOp:ONNX_Op<"Range", "Output: [10, 8, 6]" "" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$start, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$limit, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$delta); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I16,I32,I64]>, MemRefOf<[F32,F64,I16,I32,I64]>]>:$start, + AnyTypeOf<[TensorOf<[F32,F64,I16,I32,I64]>, MemRefOf<[F32,F64,I16,I32,I64]>]>:$limit, + AnyTypeOf<[TensorOf<[F32,F64,I16,I32,I64]>, MemRefOf<[F32,F64,I16,I32,I64]>]>:$delta); + let results = (outs AnyTypeOf<[TensorOf<[F32,F64,I16,I32,I64]>, MemRefOf<[F32,F64,I16,I32,I64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReciprocalOp:ONNX_Op<"Reciprocal", @@ -2317,8 +3382,19 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal", "(Tensor) where the reciprocal is, y = 1/x, is applied to" "the tensor elementwise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReduceL1Op:ONNX_Op<"ReduceL1", @@ -2332,10 +3408,21 @@ def ONNXReduceL1Op:ONNX_Op<"ReduceL1", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$data, OptionalAttr:$axes, DefaultValuedAttr:$keepdims); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$reduced); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReduceL2Op:ONNX_Op<"ReduceL2", @@ -2349,10 +3436,21 @@ def ONNXReduceL2Op:ONNX_Op<"ReduceL2", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$data, OptionalAttr:$axes, DefaultValuedAttr:$keepdims); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$reduced); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", @@ -2366,10 +3464,21 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$data, OptionalAttr:$axes, DefaultValuedAttr:$keepdims); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$reduced); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp", @@ -2383,10 +3492,21 @@ def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$data, OptionalAttr:$axes, DefaultValuedAttr:$keepdims); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$reduced); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReduceMaxOp:ONNX_Op<"ReduceMax", @@ -2400,10 +3520,21 @@ def ONNXReduceMaxOp:ONNX_Op<"ReduceMax", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$data, OptionalAttr:$axes, DefaultValuedAttr:$keepdims); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$reduced); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReduceMeanOp:ONNX_Op<"ReduceMean", @@ -2417,10 +3548,21 @@ def ONNXReduceMeanOp:ONNX_Op<"ReduceMean", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$data, OptionalAttr:$axes, DefaultValuedAttr:$keepdims); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$reduced); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReduceMinOp:ONNX_Op<"ReduceMin", @@ -2434,10 +3576,21 @@ def ONNXReduceMinOp:ONNX_Op<"ReduceMin", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$data, OptionalAttr:$axes, DefaultValuedAttr:$keepdims); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$reduced); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReduceProdOp:ONNX_Op<"ReduceProd", @@ -2451,10 +3604,21 @@ def ONNXReduceProdOp:ONNX_Op<"ReduceProd", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$data, OptionalAttr:$axes, DefaultValuedAttr:$keepdims); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$reduced); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReduceSumOp:ONNX_Op<"ReduceSum", @@ -2468,10 +3632,10 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$data, OptionalAttr:$axes, DefaultValuedAttr:$keepdims); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$reduced); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value data, ArrayAttr axes, IntegerAttr keepdims", [{ auto elementType = data.getType().cast().getElementType(); @@ -2484,6 +3648,17 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum", build(builder, state, outputTypes, operands, attributes); }]> ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", @@ -2497,10 +3672,10 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$data, OptionalAttr:$axes, DefaultValuedAttr:$keepdims); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reduced); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$reduced); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value data, ArrayAttr axes, IntegerAttr keepdims", [{ auto elementType = data.getType().cast().getElementType(); @@ -2513,6 +3688,17 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", build(builder, state, outputTypes, operands, attributes); }]> ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReluOp:ONNX_Op<"Relu", @@ -2523,8 +3709,19 @@ def ONNXReluOp:ONNX_Op<"Relu", "(Tensor) where the rectified linear function, y = max(0, x), is applied to" "the tensor elementwise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReshapeOp:ONNX_Op<"Reshape", @@ -2542,6 +3739,15 @@ def ONNXReshapeOp:ONNX_Op<"Reshape", AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped); let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } std::map promotableConstOperands() { return {{"shape", 1}}; } @@ -2557,7 +3763,7 @@ def ONNXResizeOp:ONNX_Op<"Resize", " output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified." }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$roi, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$roi, AnyTypeOf<[AnyMemRef, AnyTensor]>:$scales, AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$sizes, DefaultValuedAttr:$coordinate_transformation_mode, @@ -2567,6 +3773,17 @@ def ONNXResizeOp:ONNX_Op<"Resize", DefaultValuedAttr:$mode, DefaultValuedAttr:$nearest_mode); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 4; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXReverseSequenceOp:ONNX_Op<"ReverseSequence", @@ -2612,6 +3829,17 @@ def ONNXReverseSequenceOp:ONNX_Op<"ReverseSequence", DefaultValuedAttr:$batch_axis, DefaultValuedAttr:$time_axis); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXRoiAlignOp:ONNX_Op<"RoiAlign", @@ -2630,15 +3858,26 @@ def ONNXRoiAlignOp:ONNX_Op<"RoiAlign", "the value of the sampled locations are computed directly" "through bilinear interpolation." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$rois, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$batch_indices, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, + AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$rois, + AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$batch_indices, DefaultValuedAttr:$mode, DefaultValuedAttr:$output_height, DefaultValuedAttr:$output_width, DefaultValuedAttr:$sampling_ratio, DefaultValuedAttr:$spatial_scale); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXRoundOp:ONNX_Op<"Round", @@ -2659,8 +3898,19 @@ def ONNXRoundOp:ONNX_Op<"Round", "round([-4.5]) = [-4.0]" "```" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXScanOp:ONNX_Op<"Scan", @@ -2797,6 +4047,17 @@ def ONNXScanOp:ONNX_Op<"Scan", OptionalAttr:$scan_output_axes, OptionalAttr:$scan_output_directions); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$final_state_and_scan_outputs); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return -1; + } + static int getNumberOfResults() { + return -1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXScatterOp:ONNX_Op<"Scatter", @@ -2858,10 +4119,21 @@ def ONNXScatterOp:ONNX_Op<"Scatter", "```" }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices, + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$indices, AnyTypeOf<[AnyMemRef, AnyTensor]>:$updates, DefaultValuedAttr:$axis); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXScatterElementsOp:ONNX_Op<"ScatterElements", @@ -2921,10 +4193,21 @@ def ONNXScatterElementsOp:ONNX_Op<"ScatterElements", "```" }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices, + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$indices, AnyTypeOf<[AnyMemRef, AnyTensor]>:$updates, DefaultValuedAttr:$axis); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXScatterNDOp:ONNX_Op<"ScatterND", @@ -2992,6 +4275,17 @@ def ONNXScatterNDOp:ONNX_Op<"ScatterND", AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices, AnyTypeOf<[AnyMemRef, AnyTensor]>:$updates); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSeluOp:ONNX_Op<"Selu", @@ -3003,10 +4297,21 @@ def ONNXSeluOp:ONNX_Op<"Selu", "`y = gamma * (alpha * e^x - alpha) for x <= 0`, `y = gamma * x for x > 0`," "is applied to the tensor elementwise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, DefaultValuedAttr:$alpha, DefaultValuedAttr:$gamma); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSequenceAtOp:ONNX_Op<"SequenceAt", @@ -3018,8 +4323,19 @@ def ONNXSequenceAtOp:ONNX_Op<"SequenceAt", "Negative value means counting positions from the back." }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$position); + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$position); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct", @@ -3031,6 +4347,17 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct", }]; let arguments = (ins Variadic>:$inputs); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return -1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty", @@ -3041,6 +4368,17 @@ def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty", }]; let arguments = (ins OptionalAttr:$dtype); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 0; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase", @@ -3053,8 +4391,19 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase", "'position' is optional, by default it erases the last tensor from 'input_sequence'." }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$position); + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert", @@ -3069,8 +4418,19 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert", }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$position); + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$position); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", @@ -3080,7 +4440,18 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", "Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'." }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$length); + let results = (outs AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$length); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {4}; + } + }]; } def ONNXShapeOp:ONNX_Op<"Shape", @@ -3090,7 +4461,18 @@ def ONNXShapeOp:ONNX_Op<"Shape", "Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor." }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape); + let results = (outs AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$shape); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {4}; + } + }]; } def ONNXShrinkOp:ONNX_Op<"Shrink", @@ -3102,10 +4484,21 @@ def ONNXShrinkOp:ONNX_Op<"Shrink", "bias. The formula of this operator is: If x < -lambd, y = x + bias;" "If x > lambd, y = x - bias; Otherwise, y = 0." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$input, DefaultValuedAttr:$bias, DefaultValuedAttr:$lambd); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSigmoidOp:ONNX_Op<"Sigmoid", @@ -3116,8 +4509,19 @@ def ONNXSigmoidOp:ONNX_Op<"Sigmoid", "(Tensor) where the sigmoid function, y = 1 / (1 + exp(-x)), is applied to the" "tensor elementwise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSignOp:ONNX_Op<"Sign", @@ -3127,8 +4531,19 @@ def ONNXSignOp:ONNX_Op<"Sign", "Calculate the sign of the given input tensor element-wise." "If input > 0, output 1. if input < 0, output -1. if input == 0, output 0." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSinOp:ONNX_Op<"Sin", @@ -3137,8 +4552,19 @@ def ONNXSinOp:ONNX_Op<"Sin", let description = [{ "Calculates the sine of the given input tensor, element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSinhOp:ONNX_Op<"Sinh", @@ -3147,8 +4573,19 @@ def ONNXSinhOp:ONNX_Op<"Sinh", let description = [{ "Calculates the hyperbolic sine of the given input tensor element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSizeOp:ONNX_Op<"Size", @@ -3158,7 +4595,18 @@ def ONNXSizeOp:ONNX_Op<"Size", "Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor." }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$size); + let results = (outs AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$size); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {4}; + } + }]; } def ONNXSliceOp:ONNX_Op<"Slice", @@ -3201,11 +4649,22 @@ def ONNXSliceOp:ONNX_Op<"Slice", " ]" }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$starts, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$ends, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$axes, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$steps); + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$starts, + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>]>:$ends, + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$axes, + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$steps); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 5; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSoftmaxOp:ONNX_Op<"Softmax", @@ -3227,9 +4686,20 @@ def ONNXSoftmaxOp:ONNX_Op<"Softmax", "will throw errors. The output tensor has the same shape" "and contains the softmax values of the corresponding input." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input, DefaultValuedAttr:$axis); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSoftplusOp:ONNX_Op<"Softplus", @@ -3240,8 +4710,19 @@ def ONNXSoftplusOp:ONNX_Op<"Softplus", "(Tensor) where the softplus function, y = ln(exp(x) + 1), is applied to" "the tensor elementwise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSoftsignOp:ONNX_Op<"Softsign", @@ -3250,8 +4731,19 @@ def ONNXSoftsignOp:ONNX_Op<"Softsign", let description = [{ "Calculates the softsign (x/(1+|x|)) of the given input tensor element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth", @@ -3265,6 +4757,17 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth", let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, I64Attr:$blocksize); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSplitOp:ONNX_Op<"Split", @@ -3279,6 +4782,17 @@ def ONNXSplitOp:ONNX_Op<"Split", DefaultValuedAttr:$axis, OptionalAttr:$split); let results = (outs Variadic>:$outputs); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return -1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence", @@ -3297,10 +4811,21 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence", "dimension size of input tensor on 'axis'." }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$split, + AnyTypeOf<[TensorOf<[I32,I64]>, MemRefOf<[I32,I64]>, NoneType]>:$split, DefaultValuedAttr:$axis, DefaultValuedAttr:$keepdims); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_sequence); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {-1}; + } + }]; } def ONNXSqrtOp:ONNX_Op<"Sqrt", @@ -3311,8 +4836,19 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt", "(Tensor) where the square root is, y = x^0.5, is applied to" "the tensor elementwise. If x is negative, then it will return NaN." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSqueezeOp:ONNX_Op<"Squeeze", @@ -3327,6 +4863,17 @@ def ONNXSqueezeOp:ONNX_Op<"Squeeze", let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, OptionalAttr:$axes); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$squeezed); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer", @@ -3349,6 +4896,17 @@ def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer", OptionalAttr:$locale, OptionalAttr:$stopwords); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSubOp:ONNX_Op<"Sub", @@ -3359,9 +4917,20 @@ def ONNXSubOp:ONNX_Op<"Sub", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A, + AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B); + let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSumOp:ONNX_Op<"Sum", @@ -3372,8 +4941,19 @@ def ONNXSumOp:ONNX_Op<"Sum", "All inputs and outputs must have the same data type." "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins Variadic>:$data_0); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$sum); + let arguments = (ins Variadic, MemRefOf<[F16,F32,F64]>]>>:$data_0); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$sum); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return -1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXTanOp:ONNX_Op<"Tan", @@ -3382,8 +4962,19 @@ def ONNXTanOp:ONNX_Op<"Tan", let description = [{ "Calculates the tangent of the given input tensor, element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXTanhOp:ONNX_Op<"Tanh", @@ -3392,8 +4983,19 @@ def ONNXTanhOp:ONNX_Op<"Tanh", let description = [{ "Calculates the hyperbolic tangent of the given input tensor element-wise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXTfIdfVectorizerOp:ONNX_Op<"TfIdfVectorizer", @@ -3438,7 +5040,18 @@ def ONNXTfIdfVectorizerOp:ONNX_Op<"TfIdfVectorizer", OptionalAttr:$pool_int64s, OptionalAttr:$pool_strings, OptionalAttr:$weights); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F32]>, MemRefOf<[F32]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {7}; + } + }]; } def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu", @@ -3449,9 +5062,20 @@ def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu", "(Tensor) where the rectified linear function, y = x for x > alpha, y = 0 otherwise," "is applied to the tensor elementwise." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, DefaultValuedAttr:$alpha); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXTileOp:ONNX_Op<"Tile", @@ -3463,8 +5087,19 @@ def ONNXTileOp:ONNX_Op<"Tile", "For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]]" }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$repeats); + AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$repeats); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXTopKOp:ONNX_Op<"TopK", @@ -3486,13 +5121,24 @@ def ONNXTopKOp:ONNX_Op<"TopK", "Given two equivalent values, this operator uses the indices along the axis as" " a tiebreaker. That is, the element with the lower index will appear first." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$K, DefaultValuedAttr:$axis, DefaultValuedAttr:$largest, DefaultValuedAttr:$sorted); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Values, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$Indices); + let results = (outs AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$Values, + AnyTypeOf<[TensorOf<[I64]>, MemRefOf<[I64]>]>:$Indices); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 2; + } + static std::vector getTypeMap() { + return {20,4}; + } + }]; } def ONNXTransposeOp:ONNX_Op<"Transpose", @@ -3506,6 +5152,17 @@ def ONNXTransposeOp:ONNX_Op<"Transpose", let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, OptionalAttr:$perm); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$transposed); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXUniqueOp:ONNX_Op<"Unique", @@ -3595,6 +5252,17 @@ def ONNXUniqueOp:ONNX_Op<"Unique", AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$indices, AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$inverse_indices, AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$counts); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 4; + } + static std::vector getTypeMap() { + return {20,-1,-1,-1}; + } + }]; } def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze", @@ -3617,6 +5285,17 @@ def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze", let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, I64ArrayAttr:$axes); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$expanded); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 1; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXUpsampleOp:ONNX_Op<"Upsample", @@ -3631,6 +5310,17 @@ def ONNXUpsampleOp:ONNX_Op<"Upsample", AnyTypeOf<[AnyMemRef, AnyTensor]>:$scales, DefaultValuedAttr:$mode); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXWhereOp:ONNX_Op<"Where", @@ -3642,10 +5332,21 @@ def ONNXWhereOp:ONNX_Op<"Where", " Where behaves like numpy.where with three parameters:" " https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$condition, + let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$condition, AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {21}; + } + }]; } def ONNXXorOp:ONNX_Op<"Xor", @@ -3657,8 +5358,19 @@ def ONNXXorOp:ONNX_Op<"Xor", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); - let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A, + AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B); + let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 10ff7cd..340f9a5 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -144,62 +144,62 @@ func @test_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<* // ----- -func @test_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { - %0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> - "std.return"(%0) : (tensor<*xi32>) -> () +func @test_and(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> { + %0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1> + "std.return"(%0) : (tensor<*xi1>) -> () // CHECK-LABEL: test_and - // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi32> - // CHECK: return [[RES]] : memref<10x10xi32> + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1 + // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi1> + // CHECK: return [[RES]] : memref<10x10xi1> } // ----- -func @test_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { - %0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> - "std.return"(%0) : (tensor<*xi32>) -> () +func @test_or(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> { + %0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1> + "std.return"(%0) : (tensor<*xi1>) -> () // CHECK-LABEL: test_or - // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> - // CHECK: return [[RES]] : memref<10x10xi32> + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1 + // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi1> + // CHECK: return [[RES]] : memref<10x10xi1> } // ----- -func @test_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { - %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> - "std.return"(%0) : (tensor<*xi32>) -> () +func @test_xor(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> { + %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1> + "std.return"(%0) : (tensor<*xi1>) -> () // CHECK-LABEL: test_xor - // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> - // CHECK: return [[RES]] : memref<10x10xi32> + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1 + // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi1> + // CHECK: return [[RES]] : memref<10x10xi1> } // ----- diff --git a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir index 831cac9..4c86da8 100644 --- a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir +++ b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir @@ -158,24 +158,24 @@ func @test_sub_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tens // ----- -func @test_and_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { - %0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> - %1 = "onnx.And"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32> - "std.return"(%1) : (tensor<*xi32>) -> () +func @test_and_and(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> { + %0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1> + %1 = "onnx.And"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1> + "std.return"(%1) : (tensor<*xi1>) -> () // CHECK-LABEL: test_and_and /// First And - // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32> - // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> + // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1> + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1 + // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi1> /// Second And // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 @@ -183,38 +183,38 @@ func @test_and_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tens // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1 + // CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1> /// Dealloc of first result. - // CHECK: dealloc [[RES]] : memref<10x10xi32> - // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32> + // CHECK: dealloc [[RES]] : memref<10x10xi1> + // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1> - // CHECK: return [[RET_RES]] : memref<10x10xi32> + // CHECK: return [[RET_RES]] : memref<10x10xi1> } // ----- -func @test_or_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { - %0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> - %1 = "onnx.Or"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32> - "std.return"(%1) : (tensor<*xi32>) -> () +func @test_or_or(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> { + %0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1> + %1 = "onnx.Or"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1> + "std.return"(%1) : (tensor<*xi1>) -> () // CHECK-LABEL: test_or_or /// First Or - // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32> - // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> + // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1> + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1 + // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi1> /// Second Or // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 @@ -222,38 +222,38 @@ func @test_or_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1 + // CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1> /// Dealloc of first result. - // CHECK: dealloc [[RES]] : memref<10x10xi32> - // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32> + // CHECK: dealloc [[RES]] : memref<10x10xi1> + // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1> - // CHECK: return [[RET_RES]] : memref<10x10xi32> + // CHECK: return [[RET_RES]] : memref<10x10xi1> } // ----- -func @test_xor_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { - %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> - %1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32> - "std.return"(%1) : (tensor<*xi32>) -> () +func @test_xor_xor(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> { + %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1> + %1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1> + "std.return"(%1) : (tensor<*xi1>) -> () // CHECK-LABEL: test_xor_xor /// First Xor - // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32> - // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> + // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1> + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1> // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1 + // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi1> /// Second Xor // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 @@ -261,16 +261,16 @@ func @test_xor_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tens // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { - // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> - // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32 - // CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32> + // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> + // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1 + // CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1> /// Dealloc of first result. - // CHECK: dealloc [[RES]] : memref<10x10xi32> - // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32> + // CHECK: dealloc [[RES]] : memref<10x10xi1> + // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1> - // CHECK: return [[RET_RES]] : memref<10x10xi32> + // CHECK: return [[RET_RES]] : memref<10x10xi1> } // ----- diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 96690fb..76b1cbf 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -298,6 +298,16 @@ custom_definition_misc = dict([ ('Constant', )]) +onnx_types = ( + 'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16', + 'float', 'double', 'complex64', 'complex128' +) +tblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64', + 'Complex', 'Complex' +) + +MAX_NUM_TYPES=20 + SNIPPETS = collect_snippets() SAMPLE_IMPLEMENTATIONS = collect_sample_implementations() ONNX_ML = bool(args.domain == "ONNX_ML") @@ -376,53 +386,55 @@ def tblgen_operand_type_to_cpp_type(op_type): def np_type_to_tblgen_attr_type(tstr): - tfrom = np.array([ - 'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16', - 'float', 'double' - ]) - tto = np.array( - ['I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64']) index = -1 - for i in range(len(tfrom)): - if tfrom[i] in tstr: + for i in range(len(onnx_types)): + if onnx_types[i] in tstr: index = i break if index == -1: - print("error", tstr) - return '' + return None else: - return tto[i] + return tblgen_types[i] +def get_tblgen_type_index(type_str): + return tblgen_types.index(type_str) + +#the possible data structures are tensor, map and seq(tensor()) +#TOFIX: currently, only tensor structure is supported +def get_data_structure_element(allowed_type_str): + if allowed_type_str.startswith('tensor') : + element = allowed_type_str.replace('tensor(', '', 1).replace(')', '', 1) + return ('tensor', element) + else : + return (None, None) def get_allowed_elem_types(schema, input): - allowed_types_str = None - return allowed_types_str + #allowed_types_str = None + # return allowed_types_str # TODO: enable type constraints. - # if input.typeStr : - # tstr = input.typeStr - # else : - # return allwedTypeStr - # if schema.type_constraints: - # for type_constraint in schema.type_constraints: - # if type_constraint.type_param_str != tstr : - # continue - # allowedTypes = type_constraint.allowed_type_strs - # allowedTypeStr='' - # if (len(allowedTypes) > 0): - # t = convert_type(allowedTypes[0]) - # if t == '' : - # return '' - # allowedTypeStr += t - # for allowedType in allowedTypes[1:]: - # t = convert_type(allowedType) - # if t == '' : - # return '' - # if not t in allowedTypeStr : - # allowedTypeStr += ', '+t - # - # return allowedTypeStr - # - # return allowedTypeStr + if input.typeStr : + tstr = input.typeStr + else : + return None + if schema.type_constraints: + for type_constraint in schema.type_constraints: + if type_constraint.type_param_str != tstr : + continue + allowed_type_list=[] + allowedTypes = type_constraint.allowed_type_strs + for allowedType in allowedTypes: + structure, element = get_data_structure_element(allowedType); + if structure == None or element == None: + return None + t = np_type_to_tblgen_attr_type(element) + if t == None : + return None + if not t in allowed_type_list : + allowed_tyoe_list = allowed_type_list.append(t) + + return allowed_type_list + + return None def inc_indent(indent=None): @@ -436,7 +448,6 @@ def dec_indent(indent): def join_args(args): return ", ".join(args) - def get_operands_or_results(schema, is_input): value_list = schema.inputs if is_input else schema.outputs if not value_list: @@ -456,8 +467,9 @@ def get_operands_or_results(schema, is_input): if elem_types is None: types = ["AnyMemRef", "AnyTensor"] else: + elem_types_str = ','.join(elem_types) types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"] - types = list(map(lambda x: x.format(elem_types), types)) + types = list(map(lambda x: x.format(elem_types_str), types)) # If operand is promotable to an attribute, then it must be # nullable in case it migrates to be an attribute. @@ -545,6 +557,64 @@ def get_attrs(schema): name_to_type[attr.name] = get_attr_type_optional(attr.type) return name_to_type +def get_numberof_list(mylist): + expected_num = len(mylist) + for element in mylist : + if OpSchema.FormalParameterOption.Variadic == element.option: + expected_num = -1 + return expected_num + +def get_output_type_mapping(schema): + mapping=[] + for output in schema.outputs : + #if only one type is allowed, just set that + allowed_elem_types = get_allowed_elem_types(schema, output) + if allowed_elem_types != None and len(allowed_elem_types) == 1 : + mapping.append(str(get_tblgen_type_index(allowed_elem_types[0]))) + continue + + #map the type string + if output.typeStr : + tstr = output.typeStr + found = False + for i, input in enumerate(schema.inputs): + if input.typeStr and input.typeStr == tstr: + mapping.append(str(i+MAX_NUM_TYPES)) + found = True + break + if found: + continue + + #unknown output type + mapping.append(str(-1)) + + return mapping + +def get_numberof_inout(s, indent, schema): + expected_num_operands = get_numberof_list(schema.inputs) + indent = inc_indent(indent) + s += indent + "static int getNumberOfOperands() {\n" + indent = inc_indent(indent) + s += indent + "return {};\n".format(expected_num_operands) + indent = dec_indent(indent) + s += indent + "}\n" + + expected_num_results = get_numberof_list(schema.outputs) + s += indent + "static int getNumberOfResults() {\n" + indent = inc_indent(indent) + s += indent + "return {};\n".format(expected_num_results) + indent = dec_indent(indent) + s += indent + "}\n" + + s += indent + "static std::vector getTypeMap() {\n" + mapping = get_output_type_mapping(schema) + indent = inc_indent(indent) + s += indent + "return {" + ",".join(mapping) + "};\n" + indent = dec_indent(indent) + s += indent + "}\n" + + return s + def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx): cpp_name_to_idx_literal = "{" + ", ".join([ @@ -552,15 +622,15 @@ def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx): for name_to_idx in const_operands_name_to_idx ]) + "}" - s += indent + "let extraClassDeclaration = [{\n" + #s += indent + "let extraClassDeclaration = [{\n" indent = inc_indent(indent) s += indent + "std::map promotableConstOperands() {\n" indent = inc_indent(indent) s += indent + "return {};\n".format(cpp_name_to_idx_literal) indent = dec_indent(indent) s += indent + "}\n" - indent = dec_indent(indent) - s += indent + "}];\n" + #indent = dec_indent(indent) + #s += indent + "}];\n" return s @@ -657,10 +727,20 @@ def gen_op_def(schema): s += '\n' + indent + '];\n' + # generate extracClassDeclaration + s += indent + "let extraClassDeclaration = [{\n" + #indent = inc_indent(indent) + + # generate input/output number + s = get_numberof_inout(s, indent, schema) + + # generate ProtableConst if schema.name in OpsWithPromotableConstOperands: s = get_promotable_const_operands_func( s, indent, OpsWithPromotableConstOperands[schema.name]) + s += indent + '}];\n' + if ( schema.name in custom_definition_misc) : s += custom_definition_misc[schema.name] @@ -700,11 +780,13 @@ def gen_op_importer(schema, file): # Special handlers currently require expected num operands/results to be specified. # TODO: remove special handlers. args = ["node"] + """ if expected_num_operands != -1 or expected_num_results != -1 or "buildOperation" not in handler_func: args.append( "/* expected_num_operands = */ {}".format(expected_num_operands)) args.append( '/* expected_num_results = */ {}'.format(expected_num_results)) + """ s += inc_indent(indent) + " {}({});\n".format( handler_func, ", ".join(args))