Express some basic features of an Operation in TableGen file (#103)
* change operation definition * change importer * default type inference * file format * generate types for input/output * generate the mapping for operation output type * remove debug message for gen_doc.py * update the dialect doc * add support Complex * format * update document Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
		
							parent
							
								
									df18efcb48
								
							
						
					
					
						commit
						6099efd91b
					
				|  | @ -35,13 +35,13 @@ ONNX Binarizer operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `Y` | memref of any type values or tensor of any type values | ||||
| `Y` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values | ||||
| 
 | ||||
| ### `mlonnx.CastMap` (MLONNXCastMapOp) | ||||
| 
 | ||||
|  | @ -160,7 +160,7 @@ ONNX FeatureVectorizer operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tensor of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -194,13 +194,13 @@ ONNX Imputer operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `Y` | memref of any type values or tensor of any type values | ||||
| `Y` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values | ||||
| 
 | ||||
| ### `mlonnx.LabelEncoder` (MLONNXLabelEncoderOp) | ||||
| 
 | ||||
|  | @ -271,7 +271,7 @@ ONNX LinearClassifier operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -304,7 +304,7 @@ ONNX LinearRegressor operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -337,7 +337,7 @@ ONNX Normalizer operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -404,7 +404,7 @@ ONNX SVMClassifier operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -436,7 +436,7 @@ ONNX SVMRegressor operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -461,7 +461,7 @@ ONNX Scaler operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -509,7 +509,7 @@ ONNX TreeEnsembleClassifier operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -559,7 +559,7 @@ ONNX TreeEnsembleRegressor operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `X` | memref of any type values or tensor of any type values | ||||
| `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  |  | |||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -197,6 +197,47 @@ private: | |||
|     } | ||||
|   } | ||||
| 
 | ||||
| #define MAX_TYPE 20 | ||||
|   // itblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32',
 | ||||
|   // 'F64', 'Complex<F32>', 'Complex<F64>' )
 | ||||
|   mlir::Type buildTypeFromIndex(int index) { | ||||
|     switch (index) { | ||||
|     case 0: | ||||
|       return builder_.getI1Type(); | ||||
|     case 1: | ||||
|       return builder_.getIntegerType(8); | ||||
|     case 2: | ||||
|       return builder_.getIntegerType(16); | ||||
|     case 3: | ||||
|       return builder_.getIntegerType(32); | ||||
|     case 4: | ||||
|       return builder_.getIntegerType(64); | ||||
|     case 5: | ||||
|       return builder_.getBF16Type(); | ||||
|     case 6: | ||||
|       return builder_.getF16Type(); | ||||
|     case 7: | ||||
|       return builder_.getF32Type(); | ||||
|     case 8: | ||||
|       return builder_.getF64Type(); | ||||
|     case 9: { | ||||
|       std::vector<mlir::Type> typeTuple(2); | ||||
|       typeTuple.push_back(builder_.getF32Type()); | ||||
|       typeTuple.push_back(builder_.getF32Type()); | ||||
|       return builder_.getTupleType(llvm::ArrayRef<mlir::Type>(typeTuple)); | ||||
|     } | ||||
|     case 10: { | ||||
|       std::vector<mlir::Type> typeTuple(2); | ||||
|       typeTuple.push_back(builder_.getF64Type()); | ||||
|       typeTuple.push_back(builder_.getF64Type()); | ||||
|       return builder_.getTupleType(llvm::ArrayRef<mlir::Type>(typeTuple)); | ||||
|     } | ||||
|     default: | ||||
|       assert(false && "Unsupported type index encountered."); | ||||
|       return nullptr; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   template <typename T> | ||||
|   void buildOutputAndOperation(const onnx::NodeProto &node, | ||||
|       std::vector<mlir::Value> inputs, int expectedNumOperands, | ||||
|  | @ -217,13 +258,34 @@ private: | |||
|         inputs.emplace_back(none_); | ||||
| 
 | ||||
|     std::vector<mlir::Type> outputTypes; | ||||
|     for (auto item : node.output()) { | ||||
| 
 | ||||
|     // Use the type map to determine the data type of output.
 | ||||
|     std::vector<int> outputMap = T::getTypeMap(); | ||||
|     for (auto i = 0; i < node.output().size(); i++) { | ||||
|       // Optional outputs using empty string.
 | ||||
|       if (item.empty()) | ||||
|       if (node.output()[i].empty()) { | ||||
|         outputTypes.emplace_back(builder_.getNoneType()); | ||||
|       else | ||||
|         outputTypes.push_back( | ||||
|             mlir::UnrankedTensorType::get(builder_.getF32Type())); | ||||
|       } else { | ||||
|         if (i < outputMap.size() && outputMap[i] >= MAX_TYPE) { | ||||
|           // Mapping gives a connection with an input.
 | ||||
|           mlir::Type inputType = inputs[outputMap[i] - MAX_TYPE].getType(); | ||||
|           if (inputType.isa<mlir::TensorType>()) { | ||||
|             auto elementType = | ||||
|                 inputType.cast<mlir::TensorType>().getElementType(); | ||||
|             auto outType = mlir::UnrankedTensorType::get(elementType); | ||||
|             outputTypes.emplace_back(outType); | ||||
|           } else { | ||||
|             outputTypes.push_back(inputType); | ||||
|           } | ||||
|         } else if (i < outputMap.size() && outputMap[i] != -1) { | ||||
|           // Mapping gives a direct type.
 | ||||
|           auto elementType = buildTypeFromIndex(outputMap[i]); | ||||
|           auto outType = mlir::UnrankedTensorType::get(elementType); | ||||
|           outputTypes.emplace_back(outType); | ||||
|         } else { | ||||
|           outputTypes.emplace_back(builder_.getNoneType()); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     // Trailing optional outputs.
 | ||||
|     if (!variadicOut) | ||||
|  | @ -241,9 +303,10 @@ private: | |||
|   } | ||||
| 
 | ||||
|   template <typename T> | ||||
|   void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1, | ||||
|       int expectedNumResults = -1) { | ||||
|   void buildOperation(const onnx::NodeProto &node) { | ||||
|     std::vector<mlir::Value> inputs; | ||||
|     int expectedNumOperands = T::getNumberOfOperands(); | ||||
|     int expectedNumResults = T::getNumberOfResults(); | ||||
|     for (const auto &item : node.input()) | ||||
|       if (initializedTensors.ContainKey(legalize_name(item))) { | ||||
|         inputs.push_back(initializedTensors.EmitInitializerForInputTensor( | ||||
|  | @ -256,7 +319,9 @@ private: | |||
|         node, inputs, expectedNumOperands, expectedNumResults); | ||||
|   } | ||||
| 
 | ||||
|   void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) { | ||||
|   void ImportNodeReshape(onnx::NodeProto node) { | ||||
|     int expectedNumOperands = mlir::ONNXReshapeOp::getNumberOfOperands(); | ||||
|     int expectedNumResults = mlir::ONNXReshapeOp::getNumberOfResults(); | ||||
|     std::vector<mlir::Value> inputs; | ||||
|     std::string item; | ||||
|     for (int i = 0; i < node.input().size(); ++i) { | ||||
|  | @ -270,39 +335,40 @@ private: | |||
|       } | ||||
|     } | ||||
| 
 | ||||
|     buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut); | ||||
|     buildOutputAndOperation<mlir::ONNXReshapeOp>( | ||||
|         node, inputs, expectedNumOperands, expectedNumResults); | ||||
|   } | ||||
| 
 | ||||
|   /*!
 | ||||
|    * Special handle for MaxPool operations. | ||||
|    */ | ||||
|   void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) { | ||||
|   void ImportNodeMaxPool(onnx::NodeProto node) { | ||||
|     int nOuts = node.output().size(); | ||||
|     if (nOuts == 1) { | ||||
|       buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts); | ||||
|       buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node); | ||||
|     } else { | ||||
|       buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts); | ||||
|       buildOperation<mlir::ONNXMaxPoolOp>(node); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /*!
 | ||||
|    * Special handle for BatchNormalization operations. | ||||
|    */ | ||||
|   void ImportNodeBatchNormalization(onnx::NodeProto node, int nIn, int nOut) { | ||||
|   void ImportNodeBatchNormalization(onnx::NodeProto node) { | ||||
|     int nOuts = node.output().size(); | ||||
|     if (nOuts == 1) { | ||||
|       // Test mode with one output.
 | ||||
|       buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts); | ||||
|       buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node); | ||||
|     } else { | ||||
|       // Training mode with four trailing optional outputs. Not handled yet.
 | ||||
|       buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts); | ||||
|       buildOperation<mlir::ONNXBatchNormalizationOp>(node); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /*!
 | ||||
|    * Special handle for Pad operations. | ||||
|    */ | ||||
|   void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) { | ||||
|   void ImportNodePad(onnx::NodeProto node) { | ||||
| 
 | ||||
|     int nOps = node.input().size(); | ||||
|     if (nOps == 2) { | ||||
|  | @ -330,9 +396,11 @@ private: | |||
|         } | ||||
|       inputs.push_back(constantResult); | ||||
| 
 | ||||
|       int nIn = mlir::ONNXPadOp::getNumberOfOperands(); | ||||
|       int nOut = mlir::ONNXPadOp::getNumberOfResults(); | ||||
|       buildOutputAndOperation<mlir::ONNXPadOp>(node, inputs, nIn, nOut); | ||||
|     } else { | ||||
|       buildOperation<mlir::ONNXPadOp>(node, nIn, nOut); | ||||
|       buildOperation<mlir::ONNXPadOp>(node); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|  |  | |||
|  | @ -5,38 +5,38 @@ | |||
| //********************************************************
 | ||||
| 
 | ||||
| if (opName == "ArrayFeatureExtractor") | ||||
|   return buildOperation<mlir::MLONNXArrayFeatureExtractorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXArrayFeatureExtractorOp>(node); | ||||
| if (opName == "Binarizer") | ||||
|   return buildOperation<mlir::MLONNXBinarizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXBinarizerOp>(node); | ||||
| if (opName == "CastMap") | ||||
|   return buildOperation<mlir::MLONNXCastMapOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXCastMapOp>(node); | ||||
| if (opName == "CategoryMapper") | ||||
|   return buildOperation<mlir::MLONNXCategoryMapperOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXCategoryMapperOp>(node); | ||||
| if (opName == "DictVectorizer") | ||||
|   return buildOperation<mlir::MLONNXDictVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXDictVectorizerOp>(node); | ||||
| if (opName == "FeatureVectorizer") | ||||
|   return buildOperation<mlir::MLONNXFeatureVectorizerOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXFeatureVectorizerOp>(node); | ||||
| if (opName == "Imputer") | ||||
|   return buildOperation<mlir::MLONNXImputerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXImputerOp>(node); | ||||
| if (opName == "LabelEncoder") | ||||
|   return buildOperation<mlir::MLONNXLabelEncoderOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXLabelEncoderOp>(node); | ||||
| if (opName == "LinearClassifier") | ||||
|   return buildOperation<mlir::MLONNXLinearClassifierOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); | ||||
|    buildOperation<mlir::MLONNXLinearClassifierOp>(node); | ||||
| if (opName == "LinearRegressor") | ||||
|   return buildOperation<mlir::MLONNXLinearRegressorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXLinearRegressorOp>(node); | ||||
| if (opName == "Normalizer") | ||||
|   return buildOperation<mlir::MLONNXNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXNormalizerOp>(node); | ||||
| if (opName == "OneHotEncoder") | ||||
|   return buildOperation<mlir::MLONNXOneHotEncoderOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXOneHotEncoderOp>(node); | ||||
| if (opName == "SVMClassifier") | ||||
|   return buildOperation<mlir::MLONNXSVMClassifierOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); | ||||
|    buildOperation<mlir::MLONNXSVMClassifierOp>(node); | ||||
| if (opName == "SVMRegressor") | ||||
|   return buildOperation<mlir::MLONNXSVMRegressorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXSVMRegressorOp>(node); | ||||
| if (opName == "Scaler") | ||||
|   return buildOperation<mlir::MLONNXScalerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXScalerOp>(node); | ||||
| if (opName == "TreeEnsembleClassifier") | ||||
|   return buildOperation<mlir::MLONNXTreeEnsembleClassifierOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); | ||||
|    buildOperation<mlir::MLONNXTreeEnsembleClassifierOp>(node); | ||||
| if (opName == "TreeEnsembleRegressor") | ||||
|   return buildOperation<mlir::MLONNXTreeEnsembleRegressorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXTreeEnsembleRegressorOp>(node); | ||||
| if (opName == "ZipMap") | ||||
|   return buildOperation<mlir::MLONNXZipMapOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::MLONNXZipMapOp>(node); | ||||
|  |  | |||
|  | @ -5,314 +5,314 @@ | |||
| //********************************************************
 | ||||
| 
 | ||||
| if (opName == "Abs") | ||||
|    buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXAbsOp>(node); | ||||
| if (opName == "Acos") | ||||
|    buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXAcosOp>(node); | ||||
| if (opName == "Acosh") | ||||
|    buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXAcoshOp>(node); | ||||
| if (opName == "Add") | ||||
|    buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXAddOp>(node); | ||||
| if (opName == "And") | ||||
|    buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXAndOp>(node); | ||||
| if (opName == "ArgMax") | ||||
|    buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXArgMaxOp>(node); | ||||
| if (opName == "ArgMin") | ||||
|    buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXArgMinOp>(node); | ||||
| if (opName == "Asin") | ||||
|    buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXAsinOp>(node); | ||||
| if (opName == "Asinh") | ||||
|    buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXAsinhOp>(node); | ||||
| if (opName == "Atan") | ||||
|    buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXAtanOp>(node); | ||||
| if (opName == "Atanh") | ||||
|    buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXAtanhOp>(node); | ||||
| if (opName == "AveragePool") | ||||
|    buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXAveragePoolOp>(node); | ||||
| if (opName == "BatchNormalization") | ||||
|    ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5); | ||||
|    ImportNodeBatchNormalization(node); | ||||
| if (opName == "BitShift") | ||||
|    buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXBitShiftOp>(node); | ||||
| if (opName == "Cast") | ||||
|    buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXCastOp>(node); | ||||
| if (opName == "Ceil") | ||||
|    buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXCeilOp>(node); | ||||
| if (opName == "Clip") | ||||
|    buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXClipOp>(node); | ||||
| if (opName == "Compress") | ||||
|    buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXCompressOp>(node); | ||||
| if (opName == "Concat") | ||||
|    buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXConcatOp>(node); | ||||
| if (opName == "ConcatFromSequence") | ||||
|    buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXConcatFromSequenceOp>(node); | ||||
| if (opName == "Constant") | ||||
|    buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXConstantOp>(node); | ||||
| if (opName == "ConstantOfShape") | ||||
|    buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXConstantOfShapeOp>(node); | ||||
| if (opName == "Conv") | ||||
|    buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXConvOp>(node); | ||||
| if (opName == "ConvInteger") | ||||
|    buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXConvIntegerOp>(node); | ||||
| if (opName == "ConvTranspose") | ||||
|    buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXConvTransposeOp>(node); | ||||
| if (opName == "Cos") | ||||
|    buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXCosOp>(node); | ||||
| if (opName == "Cosh") | ||||
|    buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXCoshOp>(node); | ||||
| if (opName == "CumSum") | ||||
|    buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXCumSumOp>(node); | ||||
| if (opName == "DepthToSpace") | ||||
|    buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXDepthToSpaceOp>(node); | ||||
| if (opName == "DequantizeLinear") | ||||
|    buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXDequantizeLinearOp>(node); | ||||
| if (opName == "Det") | ||||
|    buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXDetOp>(node); | ||||
| if (opName == "Div") | ||||
|    buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXDivOp>(node); | ||||
| if (opName == "Dropout") | ||||
|    buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); | ||||
|    buildOperation<mlir::ONNXDropoutOp>(node); | ||||
| if (opName == "DynamicQuantizeLinear") | ||||
|    buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3); | ||||
|    buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node); | ||||
| if (opName == "Elu") | ||||
|    buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXEluOp>(node); | ||||
| if (opName == "Equal") | ||||
|    buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXEqualOp>(node); | ||||
| if (opName == "Erf") | ||||
|    buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXErfOp>(node); | ||||
| if (opName == "Exp") | ||||
|    buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXExpOp>(node); | ||||
| if (opName == "Expand") | ||||
|    buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXExpandOp>(node); | ||||
| if (opName == "EyeLike") | ||||
|    buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXEyeLikeOp>(node); | ||||
| if (opName == "Flatten") | ||||
|    buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXFlattenOp>(node); | ||||
| if (opName == "Floor") | ||||
|    buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXFloorOp>(node); | ||||
| if (opName == "GRU") | ||||
|    buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2); | ||||
|    buildOperation<mlir::ONNXGRUOp>(node); | ||||
| if (opName == "Gather") | ||||
|    buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXGatherOp>(node); | ||||
| if (opName == "GatherElements") | ||||
|    buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXGatherElementsOp>(node); | ||||
| if (opName == "GatherND") | ||||
|    buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXGatherNDOp>(node); | ||||
| if (opName == "Gemm") | ||||
|    buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXGemmOp>(node); | ||||
| if (opName == "GlobalAveragePool") | ||||
|    buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXGlobalAveragePoolOp>(node); | ||||
| if (opName == "GlobalLpPool") | ||||
|    buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXGlobalLpPoolOp>(node); | ||||
| if (opName == "GlobalMaxPool") | ||||
|    buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXGlobalMaxPoolOp>(node); | ||||
| if (opName == "Greater") | ||||
|    buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXGreaterOp>(node); | ||||
| if (opName == "HardSigmoid") | ||||
|    buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXHardSigmoidOp>(node); | ||||
| if (opName == "Hardmax") | ||||
|    buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXHardmaxOp>(node); | ||||
| if (opName == "Identity") | ||||
|    buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXIdentityOp>(node); | ||||
| if (opName == "If") | ||||
|    buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1); | ||||
|    buildOperation<mlir::ONNXIfOp>(node); | ||||
| if (opName == "InstanceNormalization") | ||||
|    buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXInstanceNormalizationOp>(node); | ||||
| if (opName == "IsInf") | ||||
|    buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXIsInfOp>(node); | ||||
| if (opName == "IsNaN") | ||||
|    buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXIsNaNOp>(node); | ||||
| if (opName == "LRN") | ||||
|    buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXLRNOp>(node); | ||||
| if (opName == "LSTM") | ||||
|    buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3); | ||||
|    buildOperation<mlir::ONNXLSTMOp>(node); | ||||
| if (opName == "LeakyRelu") | ||||
|    buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXLeakyReluOp>(node); | ||||
| if (opName == "Less") | ||||
|    buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXLessOp>(node); | ||||
| if (opName == "Log") | ||||
|    buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXLogOp>(node); | ||||
| if (opName == "LogSoftmax") | ||||
|    buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXLogSoftmaxOp>(node); | ||||
| if (opName == "Loop") | ||||
|    buildOperation<mlir::ONNXLoopOp>(node); | ||||
| if (opName == "LpNormalization") | ||||
|    buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXLpNormalizationOp>(node); | ||||
| if (opName == "LpPool") | ||||
|    buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXLpPoolOp>(node); | ||||
| if (opName == "MatMul") | ||||
|    buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXMatMulOp>(node); | ||||
| if (opName == "MatMulInteger") | ||||
|    buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXMatMulIntegerOp>(node); | ||||
| if (opName == "Max") | ||||
|    buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXMaxOp>(node); | ||||
| if (opName == "MaxPool") | ||||
|    ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); | ||||
|    ImportNodeMaxPool(node); | ||||
| if (opName == "MaxRoiPool") | ||||
|    buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXMaxRoiPoolOp>(node); | ||||
| if (opName == "MaxUnpool") | ||||
|    buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXMaxUnpoolOp>(node); | ||||
| if (opName == "Mean") | ||||
|    buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXMeanOp>(node); | ||||
| if (opName == "MeanVarianceNormalization") | ||||
|    buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node); | ||||
| if (opName == "Min") | ||||
|    buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXMinOp>(node); | ||||
| if (opName == "Mod") | ||||
|    buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXModOp>(node); | ||||
| if (opName == "Mul") | ||||
|    buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXMulOp>(node); | ||||
| if (opName == "Multinomial") | ||||
|    buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXMultinomialOp>(node); | ||||
| if (opName == "Neg") | ||||
|    buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXNegOp>(node); | ||||
| if (opName == "NonMaxSuppression") | ||||
|    buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXNonMaxSuppressionOp>(node); | ||||
| if (opName == "NonZero") | ||||
|    buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXNonZeroOp>(node); | ||||
| if (opName == "Not") | ||||
|    buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXNotOp>(node); | ||||
| if (opName == "OneHot") | ||||
|    buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXOneHotOp>(node); | ||||
| if (opName == "Or") | ||||
|    buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXOrOp>(node); | ||||
| if (opName == "PRelu") | ||||
|    buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXPReluOp>(node); | ||||
| if (opName == "Pad") | ||||
|    ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    ImportNodePad(node); | ||||
| if (opName == "Pow") | ||||
|    buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXPowOp>(node); | ||||
| if (opName == "QLinearConv") | ||||
|    buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXQLinearConvOp>(node); | ||||
| if (opName == "QLinearMatMul") | ||||
|    buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXQLinearMatMulOp>(node); | ||||
| if (opName == "QuantizeLinear") | ||||
|    buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXQuantizeLinearOp>(node); | ||||
| if (opName == "RNN") | ||||
|    buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2); | ||||
|    buildOperation<mlir::ONNXRNNOp>(node); | ||||
| if (opName == "RandomNormal") | ||||
|    buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXRandomNormalOp>(node); | ||||
| if (opName == "RandomNormalLike") | ||||
|    buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXRandomNormalLikeOp>(node); | ||||
| if (opName == "RandomUniform") | ||||
|    buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXRandomUniformOp>(node); | ||||
| if (opName == "RandomUniformLike") | ||||
|    buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXRandomUniformLikeOp>(node); | ||||
| if (opName == "Range") | ||||
|    buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXRangeOp>(node); | ||||
| if (opName == "Reciprocal") | ||||
|    buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReciprocalOp>(node); | ||||
| if (opName == "ReduceL1") | ||||
|    buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReduceL1Op>(node); | ||||
| if (opName == "ReduceL2") | ||||
|    buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReduceL2Op>(node); | ||||
| if (opName == "ReduceLogSum") | ||||
|    buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReduceLogSumOp>(node); | ||||
| if (opName == "ReduceLogSumExp") | ||||
|    buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReduceLogSumExpOp>(node); | ||||
| if (opName == "ReduceMax") | ||||
|    buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReduceMaxOp>(node); | ||||
| if (opName == "ReduceMean") | ||||
|    buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReduceMeanOp>(node); | ||||
| if (opName == "ReduceMin") | ||||
|    buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReduceMinOp>(node); | ||||
| if (opName == "ReduceProd") | ||||
|    buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReduceProdOp>(node); | ||||
| if (opName == "ReduceSum") | ||||
|    buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReduceSumOp>(node); | ||||
| if (opName == "ReduceSumSquare") | ||||
|    buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReduceSumSquareOp>(node); | ||||
| if (opName == "Relu") | ||||
|    buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReluOp>(node); | ||||
| if (opName == "Reshape") | ||||
|    ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    ImportNodeReshape(node); | ||||
| if (opName == "Resize") | ||||
|    buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXResizeOp>(node); | ||||
| if (opName == "ReverseSequence") | ||||
|    buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXReverseSequenceOp>(node); | ||||
| if (opName == "RoiAlign") | ||||
|    buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXRoiAlignOp>(node); | ||||
| if (opName == "Round") | ||||
|    buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXRoundOp>(node); | ||||
| if (opName == "Scan") | ||||
|    buildOperation<mlir::ONNXScanOp>(node); | ||||
| if (opName == "Scatter") | ||||
|    buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXScatterOp>(node); | ||||
| if (opName == "ScatterElements") | ||||
|    buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXScatterElementsOp>(node); | ||||
| if (opName == "ScatterND") | ||||
|    buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXScatterNDOp>(node); | ||||
| if (opName == "Selu") | ||||
|    buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSeluOp>(node); | ||||
| if (opName == "SequenceAt") | ||||
|    buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSequenceAtOp>(node); | ||||
| if (opName == "SequenceConstruct") | ||||
|    buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSequenceConstructOp>(node); | ||||
| if (opName == "SequenceEmpty") | ||||
|    buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSequenceEmptyOp>(node); | ||||
| if (opName == "SequenceErase") | ||||
|    buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSequenceEraseOp>(node); | ||||
| if (opName == "SequenceInsert") | ||||
|    buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSequenceInsertOp>(node); | ||||
| if (opName == "SequenceLength") | ||||
|    buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSequenceLengthOp>(node); | ||||
| if (opName == "Shape") | ||||
|    buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXShapeOp>(node); | ||||
| if (opName == "Shrink") | ||||
|    buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXShrinkOp>(node); | ||||
| if (opName == "Sigmoid") | ||||
|    buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSigmoidOp>(node); | ||||
| if (opName == "Sign") | ||||
|    buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSignOp>(node); | ||||
| if (opName == "Sin") | ||||
|    buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSinOp>(node); | ||||
| if (opName == "Sinh") | ||||
|    buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSinhOp>(node); | ||||
| if (opName == "Size") | ||||
|    buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSizeOp>(node); | ||||
| if (opName == "Slice") | ||||
|    buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSliceOp>(node); | ||||
| if (opName == "Softmax") | ||||
|    buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSoftmaxOp>(node); | ||||
| if (opName == "Softplus") | ||||
|    buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSoftplusOp>(node); | ||||
| if (opName == "Softsign") | ||||
|    buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSoftsignOp>(node); | ||||
| if (opName == "SpaceToDepth") | ||||
|    buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSpaceToDepthOp>(node); | ||||
| if (opName == "Split") | ||||
|    buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1); | ||||
|    buildOperation<mlir::ONNXSplitOp>(node); | ||||
| if (opName == "SplitToSequence") | ||||
|    buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSplitToSequenceOp>(node); | ||||
| if (opName == "Sqrt") | ||||
|    buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSqrtOp>(node); | ||||
| if (opName == "Squeeze") | ||||
|    buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSqueezeOp>(node); | ||||
| if (opName == "StringNormalizer") | ||||
|    buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXStringNormalizerOp>(node); | ||||
| if (opName == "Sub") | ||||
|    buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSubOp>(node); | ||||
| if (opName == "Sum") | ||||
|    buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXSumOp>(node); | ||||
| if (opName == "Tan") | ||||
|    buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXTanOp>(node); | ||||
| if (opName == "Tanh") | ||||
|    buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXTanhOp>(node); | ||||
| if (opName == "TfIdfVectorizer") | ||||
|    buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXTfIdfVectorizerOp>(node); | ||||
| if (opName == "ThresholdedRelu") | ||||
|    buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXThresholdedReluOp>(node); | ||||
| if (opName == "Tile") | ||||
|    buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXTileOp>(node); | ||||
| if (opName == "TopK") | ||||
|    buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2); | ||||
|    buildOperation<mlir::ONNXTopKOp>(node); | ||||
| if (opName == "Transpose") | ||||
|    buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXTransposeOp>(node); | ||||
| if (opName == "Unique") | ||||
|    buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4); | ||||
|    buildOperation<mlir::ONNXUniqueOp>(node); | ||||
| if (opName == "Unsqueeze") | ||||
|    buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXUnsqueezeOp>(node); | ||||
| if (opName == "Upsample") | ||||
|    buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXUpsampleOp>(node); | ||||
| if (opName == "Where") | ||||
|    buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXWhereOp>(node); | ||||
| if (opName == "Xor") | ||||
|    buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); | ||||
|    buildOperation<mlir::ONNXXorOp>(node); | ||||
|  |  | |||
|  | @ -14,6 +14,17 @@ def MLONNXArrayFeatureExtractorOp:MLONNX_Op<"ArrayFeatureExtractor", | |||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|     AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 2; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {20}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXBinarizerOp:MLONNX_Op<"Binarizer", | ||||
|  | @ -22,9 +33,20 @@ def MLONNXBinarizerOp:MLONNX_Op<"Binarizer", | |||
|   let description = [{ | ||||
|   "Maps the values of the input tensor to either 0 or 1, element-wise, based on the outcome of a comparison against a threshold value." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, | ||||
|     DefaultValuedAttr<F32Attr, "0.0">:$threshold); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {20}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXCastMapOp:MLONNX_Op<"CastMap", | ||||
|  | @ -40,6 +62,17 @@ def MLONNXCastMapOp:MLONNX_Op<"CastMap", | |||
|     DefaultValuedAttr<StrAttr, "DENSE">:$map_form, | ||||
|     DefaultValuedAttr<I64Attr, "1">:$max_map); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXCategoryMapperOp:MLONNX_Op<"CategoryMapper", | ||||
|  | @ -61,6 +94,17 @@ def MLONNXCategoryMapperOp:MLONNX_Op<"CategoryMapper", | |||
|     DefaultValuedAttr<I64Attr, "-1">:$default_int64, | ||||
|     DefaultValuedAttr<StrAttr, "_Unused">:$default_string); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer", | ||||
|  | @ -84,6 +128,17 @@ def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer", | |||
|     OptionalAttr<I64ArrayAttr>:$int64_vocabulary, | ||||
|     OptionalAttr<StrArrayAttr>:$string_vocabulary); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXFeatureVectorizerOp:MLONNX_Op<"FeatureVectorizer", | ||||
|  | @ -95,9 +150,20 @@ def MLONNXFeatureVectorizerOp:MLONNX_Op<"FeatureVectorizer", | |||
|   "    Inputs are copied to the output maintaining the order of the input arguments.<br>" | ||||
|   "    All inputs must be integers or floats, while the output will be all floating point values." | ||||
|   }]; | ||||
|   let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$X, | ||||
|   let arguments = (ins Variadic<AnyTypeOf<[TensorOf<[I32,I64,F32,F64]>, MemRefOf<[I32,I64,F32,F64]>]>>:$X, | ||||
|     OptionalAttr<I64ArrayAttr>:$inputdimensions); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return -1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXImputerOp:MLONNX_Op<"Imputer", | ||||
|  | @ -113,12 +179,23 @@ def MLONNXImputerOp:MLONNX_Op<"Imputer", | |||
|   "    which one depends on whether floats or integers are being processed.<br>" | ||||
|   "    The imputed_value attribute length can be 1 element, or it can have one element per input feature.<br>In other words, if the input tensor has the shape [*,F], then the length of the attribute array may be 1 or F. If it is 1, then it is broadcast along the last dimension and applied to each feature." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, | ||||
|     OptionalAttr<F32ArrayAttr>:$imputed_value_floats, | ||||
|     OptionalAttr<I64ArrayAttr>:$imputed_value_int64s, | ||||
|     DefaultValuedAttr<F32Attr, "0.0">:$replaced_value_float, | ||||
|     DefaultValuedAttr<I64Attr, "0">:$replaced_value_int64); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {20}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXLabelEncoderOp:MLONNX_Op<"LabelEncoder", | ||||
|  | @ -154,6 +231,17 @@ def MLONNXLabelEncoderOp:MLONNX_Op<"LabelEncoder", | |||
|     OptionalAttr<I64ArrayAttr>:$values_int64s, | ||||
|     OptionalAttr<StrArrayAttr>:$values_strings); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier", | ||||
|  | @ -162,7 +250,7 @@ def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier", | |||
|   let description = [{ | ||||
|   "Linear classifier" | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, | ||||
|     OptionalAttr<I64ArrayAttr>:$classlabels_ints, | ||||
|     OptionalAttr<StrArrayAttr>:$classlabels_strings, | ||||
|     F32ArrayAttr:$coefficients, | ||||
|  | @ -171,6 +259,17 @@ def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier", | |||
|     DefaultValuedAttr<StrAttr, "NONE">:$post_transform); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y, | ||||
|     AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 2; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1,-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXLinearRegressorOp:MLONNX_Op<"LinearRegressor", | ||||
|  | @ -184,12 +283,23 @@ def MLONNXLinearRegressorOp:MLONNX_Op<"LinearRegressor", | |||
|   "    The coefficients array is of length n, and the coefficients for each target are contiguous." | ||||
|   "    Intercepts are optional but if provided must match the number of targets." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, | ||||
|     OptionalAttr<F32ArrayAttr>:$coefficients, | ||||
|     OptionalAttr<F32ArrayAttr>:$intercepts, | ||||
|     DefaultValuedAttr<StrAttr, "NONE">:$post_transform, | ||||
|     DefaultValuedAttr<I64Attr, "1">:$targets); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXNormalizerOp:MLONNX_Op<"Normalizer", | ||||
|  | @ -207,9 +317,20 @@ def MLONNXNormalizerOp:MLONNX_Op<"Normalizer", | |||
|   "    For batches, that is, [N,C] tensors, normalization is done along the C axis. In other words, each row" | ||||
|   "    of the batch is normalized independently." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, | ||||
|     DefaultValuedAttr<StrAttr, "MAX">:$norm); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXOneHotEncoderOp:MLONNX_Op<"OneHotEncoder", | ||||
|  | @ -230,6 +351,17 @@ def MLONNXOneHotEncoderOp:MLONNX_Op<"OneHotEncoder", | |||
|     OptionalAttr<StrArrayAttr>:$cats_strings, | ||||
|     DefaultValuedAttr<I64Attr, "1">:$zeros); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier", | ||||
|  | @ -238,7 +370,7 @@ def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier", | |||
|   let description = [{ | ||||
|   "Support Vector Machine classifier" | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, | ||||
|     OptionalAttr<I64ArrayAttr>:$classlabels_ints, | ||||
|     OptionalAttr<StrArrayAttr>:$classlabels_strings, | ||||
|     OptionalAttr<F32ArrayAttr>:$coefficients, | ||||
|  | @ -252,6 +384,17 @@ def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier", | |||
|     OptionalAttr<I64ArrayAttr>:$vectors_per_class); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y, | ||||
|     AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 2; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1,-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor", | ||||
|  | @ -260,7 +403,7 @@ def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor", | |||
|   let description = [{ | ||||
|   "Support Vector Machine regression prediction and one-class SVM anomaly detection." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, | ||||
|     OptionalAttr<F32ArrayAttr>:$coefficients, | ||||
|     OptionalAttr<F32ArrayAttr>:$kernel_params, | ||||
|     DefaultValuedAttr<StrAttr, "LINEAR">:$kernel_type, | ||||
|  | @ -270,6 +413,17 @@ def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor", | |||
|     OptionalAttr<F32ArrayAttr>:$rho, | ||||
|     OptionalAttr<F32ArrayAttr>:$support_vectors); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXScalerOp:MLONNX_Op<"Scaler", | ||||
|  | @ -278,10 +432,21 @@ def MLONNXScalerOp:MLONNX_Op<"Scaler", | |||
|   let description = [{ | ||||
|   "Rescale input data, for example to standardize features by removing the mean and scaling to unit variance." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, | ||||
|     OptionalAttr<F32ArrayAttr>:$offset, | ||||
|     OptionalAttr<F32ArrayAttr>:$scale); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier", | ||||
|  | @ -298,7 +463,7 @@ def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier", | |||
|   "    One and only one of classlabels_strings or classlabels_int64s" | ||||
|   "    will be defined. The class_ids are indices into this list." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, | ||||
|     OptionalAttr<F32ArrayAttr>:$base_values, | ||||
|     OptionalAttr<I64ArrayAttr>:$class_ids, | ||||
|     OptionalAttr<I64ArrayAttr>:$class_nodeids, | ||||
|  | @ -318,6 +483,17 @@ def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier", | |||
|     DefaultValuedAttr<StrAttr, "NONE">:$post_transform); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y, | ||||
|     AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 2; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1,-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor", | ||||
|  | @ -335,7 +511,7 @@ def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor", | |||
|   "    All trees must have their node ids start at 0 and increment by 1.<br>" | ||||
|   "    Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF" | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X, | ||||
|     DefaultValuedAttr<StrAttr, "SUM">:$aggregate_function, | ||||
|     OptionalAttr<F32ArrayAttr>:$base_values, | ||||
|     OptionalAttr<I64Attr>:$n_targets, | ||||
|  | @ -354,6 +530,17 @@ def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor", | |||
|     OptionalAttr<I64ArrayAttr>:$target_treeids, | ||||
|     OptionalAttr<F32ArrayAttr>:$target_weights); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def MLONNXZipMapOp:MLONNX_Op<"ZipMap", | ||||
|  | @ -369,5 +556,16 @@ def MLONNXZipMapOp:MLONNX_Op<"ZipMap", | |||
|     OptionalAttr<I64ArrayAttr>:$classlabels_int64s, | ||||
|     OptionalAttr<StrArrayAttr>:$classlabels_strings); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {-1}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -38,7 +38,7 @@ def ONNX_Dialect : Dialect { | |||
| //   * The mnemonic for the operation, or the name without the dialect prefix. | ||||
| //   * A list of traits for the operation. | ||||
| class ONNX_Op<string mnemonic, list<OpTrait> traits = []> : | ||||
|     Op<ONNX_Dialect, mnemonic, traits>; | ||||
|     Op<ONNX_Dialect, mnemonic, traits> ; | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===// | ||||
| // ONNX Operations | ||||
|  | @ -112,6 +112,17 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", | |||
|            DefaultValuedAttr<I64Attr, "0">:$storage_order, | ||||
|            OptionalAttr<I64ArrayAttr>:$strides); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {0}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode", | ||||
|  | @ -137,6 +148,17 @@ def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode", | |||
|            DefaultValuedAttr<F32Attr, "1e-05">:$epsilon, | ||||
|            DefaultValuedAttr<F32Attr, "0.9">:$momentum); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 5; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {0}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue", | ||||
|  | @ -154,6 +176,17 @@ def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue", | |||
|            DefaultValuedAttr<F32Attr, "0.0">:$constant_value, | ||||
|            DefaultValuedAttr<StrAttr, "constant">:$mode); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {0}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad", | ||||
|  | @ -168,6 +201,17 @@ def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad", | |||
|            I64ArrayAttr:$pads, | ||||
|            DefaultValuedAttr<StrAttr, "constant">:$mode); | ||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {0}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad", | ||||
|  | @ -186,6 +230,17 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad", | |||
|   let builders = [OpBuilder<"OpBuilder &builder, OperationState &state, " | ||||
|                             "Value data, ArrayAttr pads, " | ||||
|                             "FloatAttr constant_value, StringAttr mode">]; | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|     } | ||||
|     static int getNumberOfResults() { | ||||
|       return 1; | ||||
|     } | ||||
|     static std::vector<int> getTypeMap() { | ||||
|       return {0}; | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| #endif // ONNX_OPS | ||||
|  |  | |||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -144,62 +144,62 @@ func @test_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<* | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { | ||||
|   %0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> | ||||
|   "std.return"(%0) : (tensor<*xi32>) -> () | ||||
| func @test_and(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> { | ||||
|   %0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1> | ||||
|   "std.return"(%0) : (tensor<*xi1>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_and | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1> | ||||
|   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 | ||||
|   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops  { | ||||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { | ||||
|   // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32 | ||||
|   // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: return [[RES]] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1 | ||||
|   // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: return [[RES]] : memref<10x10xi1> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { | ||||
|   %0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> | ||||
|   "std.return"(%0) : (tensor<*xi32>) -> () | ||||
| func @test_or(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> { | ||||
|   %0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1> | ||||
|   "std.return"(%0) : (tensor<*xi1>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_or | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1> | ||||
|   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 | ||||
|   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops  { | ||||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { | ||||
|   // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32 | ||||
|   // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: return [[RES]] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1 | ||||
|   // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: return [[RES]] : memref<10x10xi1> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { | ||||
|   %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> | ||||
|   "std.return"(%0) : (tensor<*xi32>) -> () | ||||
| func @test_xor(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> { | ||||
|   %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1> | ||||
|   "std.return"(%0) : (tensor<*xi1>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_xor | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1> | ||||
|   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 | ||||
|   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops  { | ||||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { | ||||
|   // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32 | ||||
|   // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: return [[RES]] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1 | ||||
|   // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: return [[RES]] : memref<10x10xi1> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
|  |  | |||
|  | @ -158,24 +158,24 @@ func @test_sub_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tens | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_and_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { | ||||
|   %0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> | ||||
|   %1 = "onnx.And"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32> | ||||
|   "std.return"(%1) : (tensor<*xi32>) -> () | ||||
| func @test_and_and(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> { | ||||
|   %0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1> | ||||
|   %1 = "onnx.And"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1> | ||||
|   "std.return"(%1) : (tensor<*xi1>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_and_and | ||||
|   /// First And | ||||
|   // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32> | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> | ||||
|   // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1> | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1> | ||||
|   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 | ||||
|   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops  { | ||||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { | ||||
|   // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32 | ||||
|   // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1 | ||||
|   // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi1> | ||||
| 
 | ||||
|   /// Second And | ||||
|   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 | ||||
|  | @ -183,38 +183,38 @@ func @test_and_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tens | |||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { | ||||
|   // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32 | ||||
|   // CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1 | ||||
|   // CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1> | ||||
| 
 | ||||
|   /// Dealloc of first result. | ||||
|   // CHECK: dealloc [[RES]] : memref<10x10xi32> | ||||
|   // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32> | ||||
|   // CHECK: dealloc [[RES]] : memref<10x10xi1> | ||||
|   // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1> | ||||
| 
 | ||||
|   // CHECK: return [[RET_RES]] : memref<10x10xi32> | ||||
|   // CHECK: return [[RET_RES]] : memref<10x10xi1> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_or_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { | ||||
|   %0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> | ||||
|   %1 = "onnx.Or"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32> | ||||
|   "std.return"(%1) : (tensor<*xi32>) -> () | ||||
| func @test_or_or(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> { | ||||
|   %0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1> | ||||
|   %1 = "onnx.Or"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1> | ||||
|   "std.return"(%1) : (tensor<*xi1>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_or_or | ||||
|   /// First Or | ||||
|   // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32> | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> | ||||
|   // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1> | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1> | ||||
|   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 | ||||
|   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops  { | ||||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { | ||||
|   // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32 | ||||
|   // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1 | ||||
|   // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi1> | ||||
| 
 | ||||
|   /// Second Or | ||||
|   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 | ||||
|  | @ -222,38 +222,38 @@ func @test_or_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor | |||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { | ||||
|   // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32 | ||||
|   // CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1 | ||||
|   // CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1> | ||||
| 
 | ||||
|   /// Dealloc of first result. | ||||
|   // CHECK: dealloc [[RES]] : memref<10x10xi32> | ||||
|   // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32> | ||||
|   // CHECK: dealloc [[RES]] : memref<10x10xi1> | ||||
|   // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1> | ||||
| 
 | ||||
|   // CHECK: return [[RET_RES]] : memref<10x10xi32> | ||||
|   // CHECK: return [[RET_RES]] : memref<10x10xi1> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| func @test_xor_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { | ||||
|   %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> | ||||
|   %1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32> | ||||
|   "std.return"(%1) : (tensor<*xi32>) -> () | ||||
| func @test_xor_xor(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> { | ||||
|   %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1> | ||||
|   %1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1> | ||||
|   "std.return"(%1) : (tensor<*xi1>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_xor_xor | ||||
|   /// First Xor | ||||
|   // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32> | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> | ||||
|   // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1> | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1> | ||||
|   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 | ||||
|   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops  { | ||||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { | ||||
|   // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32 | ||||
|   // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1 | ||||
|   // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi1> | ||||
| 
 | ||||
|   /// Second Xor | ||||
|   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 | ||||
|  | @ -261,16 +261,16 @@ func @test_xor_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tens | |||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { | ||||
|   // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32 | ||||
|   // CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32> | ||||
|   // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1> | ||||
|   // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1 | ||||
|   // CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1> | ||||
| 
 | ||||
|   /// Dealloc of first result. | ||||
|   // CHECK: dealloc [[RES]] : memref<10x10xi32> | ||||
|   // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32> | ||||
|   // CHECK: dealloc [[RES]] : memref<10x10xi1> | ||||
|   // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1> | ||||
| 
 | ||||
|   // CHECK: return [[RET_RES]] : memref<10x10xi32> | ||||
|   // CHECK: return [[RET_RES]] : memref<10x10xi1> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
|  |  | |||
|  | @ -298,6 +298,16 @@ custom_definition_misc = dict([ ('Constant', | |||
|   )]) | ||||
| 
 | ||||
| 
 | ||||
| onnx_types = ( | ||||
|     'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16', | ||||
|     'float', 'double', 'complex64', 'complex128' | ||||
| ) | ||||
| tblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64',  | ||||
|     'Complex<F32>', 'Complex<F64>' | ||||
| ) | ||||
| 
 | ||||
| MAX_NUM_TYPES=20 | ||||
| 
 | ||||
| SNIPPETS = collect_snippets() | ||||
| SAMPLE_IMPLEMENTATIONS = collect_sample_implementations() | ||||
| ONNX_ML = bool(args.domain == "ONNX_ML") | ||||
|  | @ -376,53 +386,55 @@ def tblgen_operand_type_to_cpp_type(op_type): | |||
| 
 | ||||
| 
 | ||||
| def np_type_to_tblgen_attr_type(tstr): | ||||
|     tfrom = np.array([ | ||||
|         'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16', | ||||
|         'float', 'double' | ||||
|     ]) | ||||
|     tto = np.array( | ||||
|         ['I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64']) | ||||
|     index = -1 | ||||
|     for i in range(len(tfrom)): | ||||
|         if tfrom[i] in tstr: | ||||
|     for i in range(len(onnx_types)): | ||||
|         if onnx_types[i] in tstr: | ||||
|             index = i | ||||
|             break | ||||
|     if index == -1: | ||||
|         print("error", tstr) | ||||
|         return '' | ||||
|         return None | ||||
|     else: | ||||
|         return tto[i] | ||||
|         return tblgen_types[i] | ||||
| 
 | ||||
| def get_tblgen_type_index(type_str): | ||||
|     return tblgen_types.index(type_str) | ||||
| 
 | ||||
| #the possible data structures are tensor, map and seq(tensor()) | ||||
| #TOFIX: currently, only tensor structure is supported | ||||
| def get_data_structure_element(allowed_type_str):  | ||||
|     if allowed_type_str.startswith('tensor') : | ||||
|         element = allowed_type_str.replace('tensor(', '', 1).replace(')', '', 1) | ||||
|         return ('tensor', element) | ||||
|     else : | ||||
|         return (None, None) | ||||
| 
 | ||||
| def get_allowed_elem_types(schema, input): | ||||
|     allowed_types_str = None | ||||
|     return allowed_types_str | ||||
|     #allowed_types_str = None | ||||
|     # return allowed_types_str | ||||
|     # TODO: enable type constraints. | ||||
|     # if input.typeStr : | ||||
|     #     tstr = input.typeStr | ||||
|     # else : | ||||
|     #     return allwedTypeStr | ||||
|     # if schema.type_constraints: | ||||
|     #     for type_constraint in schema.type_constraints: | ||||
|     #         if type_constraint.type_param_str != tstr : | ||||
|     #             continue | ||||
|     #         allowedTypes = type_constraint.allowed_type_strs | ||||
|     #         allowedTypeStr='' | ||||
|     #         if (len(allowedTypes) > 0): | ||||
|     #             t = convert_type(allowedTypes[0]) | ||||
|     #             if t == '' : | ||||
|     #                 return '' | ||||
|     #             allowedTypeStr += t | ||||
|     #         for allowedType in allowedTypes[1:]: | ||||
|     #             t = convert_type(allowedType) | ||||
|     #             if t == '' : | ||||
|     #                 return '' | ||||
|     #             if  not t in allowedTypeStr : | ||||
|     #                 allowedTypeStr += ', '+t | ||||
|     # | ||||
|     #         return allowedTypeStr | ||||
|     # | ||||
|     # return allowedTypeStr | ||||
|     if input.typeStr : | ||||
|          tstr = input.typeStr | ||||
|     else : | ||||
|         return None | ||||
|     if schema.type_constraints: | ||||
|         for type_constraint in schema.type_constraints: | ||||
|             if type_constraint.type_param_str != tstr : | ||||
|                 continue | ||||
|             allowed_type_list=[] | ||||
|             allowedTypes = type_constraint.allowed_type_strs | ||||
|             for allowedType in allowedTypes: | ||||
|                 structure, element = get_data_structure_element(allowedType); | ||||
|                 if structure == None or element == None: | ||||
|                     return None | ||||
|                 t = np_type_to_tblgen_attr_type(element) | ||||
|                 if t == None : | ||||
|                     return None | ||||
|                 if  not t in allowed_type_list : | ||||
|                     allowed_tyoe_list = allowed_type_list.append(t) | ||||
|      | ||||
|             return allowed_type_list | ||||
|      | ||||
|     return None | ||||
| 
 | ||||
| 
 | ||||
| def inc_indent(indent=None): | ||||
|  | @ -436,7 +448,6 @@ def dec_indent(indent): | |||
| def join_args(args): | ||||
|     return ", ".join(args) | ||||
| 
 | ||||
| 
 | ||||
| def get_operands_or_results(schema, is_input): | ||||
|     value_list = schema.inputs if is_input else schema.outputs | ||||
|     if not value_list: | ||||
|  | @ -456,8 +467,9 @@ def get_operands_or_results(schema, is_input): | |||
|         if elem_types is None: | ||||
|             types = ["AnyMemRef", "AnyTensor"] | ||||
|         else: | ||||
|             elem_types_str = ','.join(elem_types) | ||||
|             types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"] | ||||
|             types = list(map(lambda x: x.format(elem_types), types)) | ||||
|             types = list(map(lambda x: x.format(elem_types_str), types)) | ||||
| 
 | ||||
|         # If operand is promotable to an attribute, then it must be | ||||
|         # nullable in case it migrates to be an attribute. | ||||
|  | @ -545,6 +557,64 @@ def get_attrs(schema): | |||
|             name_to_type[attr.name] = get_attr_type_optional(attr.type) | ||||
|     return name_to_type | ||||
| 
 | ||||
| def get_numberof_list(mylist): | ||||
|     expected_num = len(mylist) | ||||
|     for element in mylist : | ||||
|         if OpSchema.FormalParameterOption.Variadic == element.option: | ||||
|             expected_num = -1 | ||||
|     return expected_num | ||||
| 
 | ||||
| def get_output_type_mapping(schema): | ||||
|     mapping=[] | ||||
|     for output in schema.outputs : | ||||
|         #if only one type is allowed, just set that | ||||
|         allowed_elem_types = get_allowed_elem_types(schema, output) | ||||
|         if allowed_elem_types != None and len(allowed_elem_types) == 1 : | ||||
|             mapping.append(str(get_tblgen_type_index(allowed_elem_types[0]))) | ||||
|             continue | ||||
| 
 | ||||
|         #map the type string | ||||
|         if output.typeStr : | ||||
|             tstr = output.typeStr | ||||
|             found = False | ||||
|             for i, input in enumerate(schema.inputs): | ||||
|                 if input.typeStr and input.typeStr == tstr: | ||||
|                     mapping.append(str(i+MAX_NUM_TYPES)) | ||||
|                     found = True | ||||
|                     break | ||||
|             if found: | ||||
|                 continue | ||||
| 
 | ||||
|         #unknown output type | ||||
|         mapping.append(str(-1)) | ||||
|          | ||||
|     return mapping | ||||
|      | ||||
| def get_numberof_inout(s, indent, schema): | ||||
|     expected_num_operands = get_numberof_list(schema.inputs) | ||||
|     indent = inc_indent(indent) | ||||
|     s += indent + "static int getNumberOfOperands() {\n" | ||||
|     indent = inc_indent(indent) | ||||
|     s += indent + "return {};\n".format(expected_num_operands) | ||||
|     indent = dec_indent(indent) | ||||
|     s += indent + "}\n" | ||||
| 
 | ||||
|     expected_num_results = get_numberof_list(schema.outputs) | ||||
|     s += indent + "static int getNumberOfResults() {\n" | ||||
|     indent = inc_indent(indent) | ||||
|     s += indent + "return {};\n".format(expected_num_results) | ||||
|     indent = dec_indent(indent) | ||||
|     s += indent + "}\n" | ||||
| 
 | ||||
|     s += indent + "static std::vector<int> getTypeMap() {\n" | ||||
|     mapping = get_output_type_mapping(schema) | ||||
|     indent = inc_indent(indent) | ||||
|     s += indent + "return {" + ",".join(mapping) + "};\n" | ||||
|     indent = dec_indent(indent) | ||||
|     s += indent + "}\n" | ||||
| 
 | ||||
|     return s | ||||
| 
 | ||||
| 
 | ||||
| def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx): | ||||
|     cpp_name_to_idx_literal = "{" + ", ".join([ | ||||
|  | @ -552,15 +622,15 @@ def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx): | |||
|         for name_to_idx in const_operands_name_to_idx | ||||
|     ]) + "}" | ||||
| 
 | ||||
|     s += indent + "let extraClassDeclaration = [{\n" | ||||
|     #s += indent + "let extraClassDeclaration = [{\n" | ||||
|     indent = inc_indent(indent) | ||||
|     s += indent + "std::map<std::string, size_t> promotableConstOperands() {\n" | ||||
|     indent = inc_indent(indent) | ||||
|     s += indent + "return {};\n".format(cpp_name_to_idx_literal) | ||||
|     indent = dec_indent(indent) | ||||
|     s += indent + "}\n" | ||||
|     indent = dec_indent(indent) | ||||
|     s += indent + "}];\n" | ||||
|     #indent = dec_indent(indent) | ||||
|     #s += indent + "}];\n" | ||||
| 
 | ||||
|     return s | ||||
| 
 | ||||
|  | @ -657,10 +727,20 @@ def gen_op_def(schema): | |||
| 
 | ||||
|             s += '\n' + indent + '];\n' | ||||
| 
 | ||||
|     # generate extracClassDeclaration | ||||
|     s += indent + "let extraClassDeclaration = [{\n" | ||||
|     #indent = inc_indent(indent) | ||||
| 
 | ||||
|     # generate input/output number | ||||
|     s = get_numberof_inout(s, indent, schema) | ||||
| 
 | ||||
|     # generate ProtableConst  | ||||
|     if schema.name in OpsWithPromotableConstOperands: | ||||
|         s = get_promotable_const_operands_func( | ||||
|             s, indent, OpsWithPromotableConstOperands[schema.name]) | ||||
| 
 | ||||
|     s += indent + '}];\n' | ||||
| 
 | ||||
|     if ( schema.name in custom_definition_misc) : | ||||
|         s += custom_definition_misc[schema.name] | ||||
| 
 | ||||
|  | @ -700,11 +780,13 @@ def gen_op_importer(schema, file): | |||
|     # Special handlers currently require expected num operands/results to be specified. | ||||
|     # TODO: remove special handlers. | ||||
|     args = ["node"] | ||||
|     """ | ||||
|     if expected_num_operands != -1 or expected_num_results != -1 or "buildOperation" not in handler_func: | ||||
|         args.append( | ||||
|             "/* expected_num_operands = */ {}".format(expected_num_operands)) | ||||
|         args.append( | ||||
|             '/* expected_num_results = */ {}'.format(expected_num_results)) | ||||
|     """ | ||||
|     s += inc_indent(indent) + " {}({});\n".format( | ||||
|         handler_func, ", ".join(args)) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue