Express some basic features of an Operation in TableGen file (#103)

* change operation definition

* change importer

* default type inference

* file format

* generate types for input/output

* generate the mapping for operation output type

* remove debug message for gen_doc.py

* update the dialect doc

* add support Complex

* format

* update document

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
chentong319 2020-05-21 22:03:16 -04:00 committed by GitHub
parent df18efcb48
commit 6099efd91b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 3100 additions and 985 deletions

View File

@ -35,13 +35,13 @@ ONNX Binarizer operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results: #### Results:
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`Y` | memref of any type values or tensor of any type values `Y` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
### `mlonnx.CastMap` (MLONNXCastMapOp) ### `mlonnx.CastMap` (MLONNXCastMapOp)
@ -160,7 +160,7 @@ ONNX FeatureVectorizer operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tensor of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values
#### Results: #### Results:
@ -194,13 +194,13 @@ ONNX Imputer operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results: #### Results:
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`Y` | memref of any type values or tensor of any type values `Y` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
### `mlonnx.LabelEncoder` (MLONNXLabelEncoderOp) ### `mlonnx.LabelEncoder` (MLONNXLabelEncoderOp)
@ -271,7 +271,7 @@ ONNX LinearClassifier operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results: #### Results:
@ -304,7 +304,7 @@ ONNX LinearRegressor operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results: #### Results:
@ -337,7 +337,7 @@ ONNX Normalizer operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results: #### Results:
@ -404,7 +404,7 @@ ONNX SVMClassifier operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results: #### Results:
@ -436,7 +436,7 @@ ONNX SVMRegressor operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results: #### Results:
@ -461,7 +461,7 @@ ONNX Scaler operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results: #### Results:
@ -509,7 +509,7 @@ ONNX TreeEnsembleClassifier operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results: #### Results:
@ -559,7 +559,7 @@ ONNX TreeEnsembleRegressor operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`X` | memref of any type values or tensor of any type values `X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results: #### Results:

File diff suppressed because it is too large Load Diff

View File

@ -197,6 +197,47 @@ private:
} }
} }
#define MAX_TYPE 20
// itblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32',
// 'F64', 'Complex<F32>', 'Complex<F64>' )
mlir::Type buildTypeFromIndex(int index) {
switch (index) {
case 0:
return builder_.getI1Type();
case 1:
return builder_.getIntegerType(8);
case 2:
return builder_.getIntegerType(16);
case 3:
return builder_.getIntegerType(32);
case 4:
return builder_.getIntegerType(64);
case 5:
return builder_.getBF16Type();
case 6:
return builder_.getF16Type();
case 7:
return builder_.getF32Type();
case 8:
return builder_.getF64Type();
case 9: {
std::vector<mlir::Type> typeTuple(2);
typeTuple.push_back(builder_.getF32Type());
typeTuple.push_back(builder_.getF32Type());
return builder_.getTupleType(llvm::ArrayRef<mlir::Type>(typeTuple));
}
case 10: {
std::vector<mlir::Type> typeTuple(2);
typeTuple.push_back(builder_.getF64Type());
typeTuple.push_back(builder_.getF64Type());
return builder_.getTupleType(llvm::ArrayRef<mlir::Type>(typeTuple));
}
default:
assert(false && "Unsupported type index encountered.");
return nullptr;
}
}
template <typename T> template <typename T>
void buildOutputAndOperation(const onnx::NodeProto &node, void buildOutputAndOperation(const onnx::NodeProto &node,
std::vector<mlir::Value> inputs, int expectedNumOperands, std::vector<mlir::Value> inputs, int expectedNumOperands,
@ -217,13 +258,34 @@ private:
inputs.emplace_back(none_); inputs.emplace_back(none_);
std::vector<mlir::Type> outputTypes; std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
// Use the type map to determine the data type of output.
std::vector<int> outputMap = T::getTypeMap();
for (auto i = 0; i < node.output().size(); i++) {
// Optional outputs using empty string. // Optional outputs using empty string.
if (item.empty()) if (node.output()[i].empty()) {
outputTypes.emplace_back(builder_.getNoneType()); outputTypes.emplace_back(builder_.getNoneType());
else } else {
outputTypes.push_back( if (i < outputMap.size() && outputMap[i] >= MAX_TYPE) {
mlir::UnrankedTensorType::get(builder_.getF32Type())); // Mapping gives a connection with an input.
mlir::Type inputType = inputs[outputMap[i] - MAX_TYPE].getType();
if (inputType.isa<mlir::TensorType>()) {
auto elementType =
inputType.cast<mlir::TensorType>().getElementType();
auto outType = mlir::UnrankedTensorType::get(elementType);
outputTypes.emplace_back(outType);
} else {
outputTypes.push_back(inputType);
}
} else if (i < outputMap.size() && outputMap[i] != -1) {
// Mapping gives a direct type.
auto elementType = buildTypeFromIndex(outputMap[i]);
auto outType = mlir::UnrankedTensorType::get(elementType);
outputTypes.emplace_back(outType);
} else {
outputTypes.emplace_back(builder_.getNoneType());
}
}
} }
// Trailing optional outputs. // Trailing optional outputs.
if (!variadicOut) if (!variadicOut)
@ -241,9 +303,10 @@ private:
} }
template <typename T> template <typename T>
void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1, void buildOperation(const onnx::NodeProto &node) {
int expectedNumResults = -1) {
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
int expectedNumOperands = T::getNumberOfOperands();
int expectedNumResults = T::getNumberOfResults();
for (const auto &item : node.input()) for (const auto &item : node.input())
if (initializedTensors.ContainKey(legalize_name(item))) { if (initializedTensors.ContainKey(legalize_name(item))) {
inputs.push_back(initializedTensors.EmitInitializerForInputTensor( inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
@ -256,7 +319,9 @@ private:
node, inputs, expectedNumOperands, expectedNumResults); node, inputs, expectedNumOperands, expectedNumResults);
} }
void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) { void ImportNodeReshape(onnx::NodeProto node) {
int expectedNumOperands = mlir::ONNXReshapeOp::getNumberOfOperands();
int expectedNumResults = mlir::ONNXReshapeOp::getNumberOfResults();
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
std::string item; std::string item;
for (int i = 0; i < node.input().size(); ++i) { for (int i = 0; i < node.input().size(); ++i) {
@ -270,39 +335,40 @@ private:
} }
} }
buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut); buildOutputAndOperation<mlir::ONNXReshapeOp>(
node, inputs, expectedNumOperands, expectedNumResults);
} }
/*! /*!
* Special handle for MaxPool operations. * Special handle for MaxPool operations.
*/ */
void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) { void ImportNodeMaxPool(onnx::NodeProto node) {
int nOuts = node.output().size(); int nOuts = node.output().size();
if (nOuts == 1) { if (nOuts == 1) {
buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts); buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node);
} else { } else {
buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts); buildOperation<mlir::ONNXMaxPoolOp>(node);
} }
} }
/*! /*!
* Special handle for BatchNormalization operations. * Special handle for BatchNormalization operations.
*/ */
void ImportNodeBatchNormalization(onnx::NodeProto node, int nIn, int nOut) { void ImportNodeBatchNormalization(onnx::NodeProto node) {
int nOuts = node.output().size(); int nOuts = node.output().size();
if (nOuts == 1) { if (nOuts == 1) {
// Test mode with one output. // Test mode with one output.
buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts); buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node);
} else { } else {
// Training mode with four trailing optional outputs. Not handled yet. // Training mode with four trailing optional outputs. Not handled yet.
buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts); buildOperation<mlir::ONNXBatchNormalizationOp>(node);
} }
} }
/*! /*!
* Special handle for Pad operations. * Special handle for Pad operations.
*/ */
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) { void ImportNodePad(onnx::NodeProto node) {
int nOps = node.input().size(); int nOps = node.input().size();
if (nOps == 2) { if (nOps == 2) {
@ -330,9 +396,11 @@ private:
} }
inputs.push_back(constantResult); inputs.push_back(constantResult);
int nIn = mlir::ONNXPadOp::getNumberOfOperands();
int nOut = mlir::ONNXPadOp::getNumberOfResults();
buildOutputAndOperation<mlir::ONNXPadOp>(node, inputs, nIn, nOut); buildOutputAndOperation<mlir::ONNXPadOp>(node, inputs, nIn, nOut);
} else { } else {
buildOperation<mlir::ONNXPadOp>(node, nIn, nOut); buildOperation<mlir::ONNXPadOp>(node);
} }
} }

View File

@ -5,38 +5,38 @@
//******************************************************** //********************************************************
if (opName == "ArrayFeatureExtractor") if (opName == "ArrayFeatureExtractor")
return buildOperation<mlir::MLONNXArrayFeatureExtractorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXArrayFeatureExtractorOp>(node);
if (opName == "Binarizer") if (opName == "Binarizer")
return buildOperation<mlir::MLONNXBinarizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXBinarizerOp>(node);
if (opName == "CastMap") if (opName == "CastMap")
return buildOperation<mlir::MLONNXCastMapOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXCastMapOp>(node);
if (opName == "CategoryMapper") if (opName == "CategoryMapper")
return buildOperation<mlir::MLONNXCategoryMapperOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXCategoryMapperOp>(node);
if (opName == "DictVectorizer") if (opName == "DictVectorizer")
return buildOperation<mlir::MLONNXDictVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXDictVectorizerOp>(node);
if (opName == "FeatureVectorizer") if (opName == "FeatureVectorizer")
return buildOperation<mlir::MLONNXFeatureVectorizerOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXFeatureVectorizerOp>(node);
if (opName == "Imputer") if (opName == "Imputer")
return buildOperation<mlir::MLONNXImputerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXImputerOp>(node);
if (opName == "LabelEncoder") if (opName == "LabelEncoder")
return buildOperation<mlir::MLONNXLabelEncoderOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXLabelEncoderOp>(node);
if (opName == "LinearClassifier") if (opName == "LinearClassifier")
return buildOperation<mlir::MLONNXLinearClassifierOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); buildOperation<mlir::MLONNXLinearClassifierOp>(node);
if (opName == "LinearRegressor") if (opName == "LinearRegressor")
return buildOperation<mlir::MLONNXLinearRegressorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXLinearRegressorOp>(node);
if (opName == "Normalizer") if (opName == "Normalizer")
return buildOperation<mlir::MLONNXNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXNormalizerOp>(node);
if (opName == "OneHotEncoder") if (opName == "OneHotEncoder")
return buildOperation<mlir::MLONNXOneHotEncoderOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXOneHotEncoderOp>(node);
if (opName == "SVMClassifier") if (opName == "SVMClassifier")
return buildOperation<mlir::MLONNXSVMClassifierOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); buildOperation<mlir::MLONNXSVMClassifierOp>(node);
if (opName == "SVMRegressor") if (opName == "SVMRegressor")
return buildOperation<mlir::MLONNXSVMRegressorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXSVMRegressorOp>(node);
if (opName == "Scaler") if (opName == "Scaler")
return buildOperation<mlir::MLONNXScalerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXScalerOp>(node);
if (opName == "TreeEnsembleClassifier") if (opName == "TreeEnsembleClassifier")
return buildOperation<mlir::MLONNXTreeEnsembleClassifierOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); buildOperation<mlir::MLONNXTreeEnsembleClassifierOp>(node);
if (opName == "TreeEnsembleRegressor") if (opName == "TreeEnsembleRegressor")
return buildOperation<mlir::MLONNXTreeEnsembleRegressorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXTreeEnsembleRegressorOp>(node);
if (opName == "ZipMap") if (opName == "ZipMap")
return buildOperation<mlir::MLONNXZipMapOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::MLONNXZipMapOp>(node);

View File

@ -5,314 +5,314 @@
//******************************************************** //********************************************************
if (opName == "Abs") if (opName == "Abs")
buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAbsOp>(node);
if (opName == "Acos") if (opName == "Acos")
buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAcosOp>(node);
if (opName == "Acosh") if (opName == "Acosh")
buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAcoshOp>(node);
if (opName == "Add") if (opName == "Add")
buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAddOp>(node);
if (opName == "And") if (opName == "And")
buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAndOp>(node);
if (opName == "ArgMax") if (opName == "ArgMax")
buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXArgMaxOp>(node);
if (opName == "ArgMin") if (opName == "ArgMin")
buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXArgMinOp>(node);
if (opName == "Asin") if (opName == "Asin")
buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAsinOp>(node);
if (opName == "Asinh") if (opName == "Asinh")
buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAsinhOp>(node);
if (opName == "Atan") if (opName == "Atan")
buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAtanOp>(node);
if (opName == "Atanh") if (opName == "Atanh")
buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAtanhOp>(node);
if (opName == "AveragePool") if (opName == "AveragePool")
buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAveragePoolOp>(node);
if (opName == "BatchNormalization") if (opName == "BatchNormalization")
ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5); ImportNodeBatchNormalization(node);
if (opName == "BitShift") if (opName == "BitShift")
buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXBitShiftOp>(node);
if (opName == "Cast") if (opName == "Cast")
buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXCastOp>(node);
if (opName == "Ceil") if (opName == "Ceil")
buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXCeilOp>(node);
if (opName == "Clip") if (opName == "Clip")
buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXClipOp>(node);
if (opName == "Compress") if (opName == "Compress")
buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXCompressOp>(node);
if (opName == "Concat") if (opName == "Concat")
buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConcatOp>(node);
if (opName == "ConcatFromSequence") if (opName == "ConcatFromSequence")
buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConcatFromSequenceOp>(node);
if (opName == "Constant") if (opName == "Constant")
buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConstantOp>(node);
if (opName == "ConstantOfShape") if (opName == "ConstantOfShape")
buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConstantOfShapeOp>(node);
if (opName == "Conv") if (opName == "Conv")
buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConvOp>(node);
if (opName == "ConvInteger") if (opName == "ConvInteger")
buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConvIntegerOp>(node);
if (opName == "ConvTranspose") if (opName == "ConvTranspose")
buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConvTransposeOp>(node);
if (opName == "Cos") if (opName == "Cos")
buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXCosOp>(node);
if (opName == "Cosh") if (opName == "Cosh")
buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXCoshOp>(node);
if (opName == "CumSum") if (opName == "CumSum")
buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXCumSumOp>(node);
if (opName == "DepthToSpace") if (opName == "DepthToSpace")
buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXDepthToSpaceOp>(node);
if (opName == "DequantizeLinear") if (opName == "DequantizeLinear")
buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXDequantizeLinearOp>(node);
if (opName == "Det") if (opName == "Det")
buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXDetOp>(node);
if (opName == "Div") if (opName == "Div")
buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXDivOp>(node);
if (opName == "Dropout") if (opName == "Dropout")
buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); buildOperation<mlir::ONNXDropoutOp>(node);
if (opName == "DynamicQuantizeLinear") if (opName == "DynamicQuantizeLinear")
buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3); buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node);
if (opName == "Elu") if (opName == "Elu")
buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXEluOp>(node);
if (opName == "Equal") if (opName == "Equal")
buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXEqualOp>(node);
if (opName == "Erf") if (opName == "Erf")
buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXErfOp>(node);
if (opName == "Exp") if (opName == "Exp")
buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXExpOp>(node);
if (opName == "Expand") if (opName == "Expand")
buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXExpandOp>(node);
if (opName == "EyeLike") if (opName == "EyeLike")
buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXEyeLikeOp>(node);
if (opName == "Flatten") if (opName == "Flatten")
buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXFlattenOp>(node);
if (opName == "Floor") if (opName == "Floor")
buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXFloorOp>(node);
if (opName == "GRU") if (opName == "GRU")
buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2); buildOperation<mlir::ONNXGRUOp>(node);
if (opName == "Gather") if (opName == "Gather")
buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGatherOp>(node);
if (opName == "GatherElements") if (opName == "GatherElements")
buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGatherElementsOp>(node);
if (opName == "GatherND") if (opName == "GatherND")
buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGatherNDOp>(node);
if (opName == "Gemm") if (opName == "Gemm")
buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGemmOp>(node);
if (opName == "GlobalAveragePool") if (opName == "GlobalAveragePool")
buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGlobalAveragePoolOp>(node);
if (opName == "GlobalLpPool") if (opName == "GlobalLpPool")
buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGlobalLpPoolOp>(node);
if (opName == "GlobalMaxPool") if (opName == "GlobalMaxPool")
buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGlobalMaxPoolOp>(node);
if (opName == "Greater") if (opName == "Greater")
buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGreaterOp>(node);
if (opName == "HardSigmoid") if (opName == "HardSigmoid")
buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXHardSigmoidOp>(node);
if (opName == "Hardmax") if (opName == "Hardmax")
buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXHardmaxOp>(node);
if (opName == "Identity") if (opName == "Identity")
buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXIdentityOp>(node);
if (opName == "If") if (opName == "If")
buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1); buildOperation<mlir::ONNXIfOp>(node);
if (opName == "InstanceNormalization") if (opName == "InstanceNormalization")
buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXInstanceNormalizationOp>(node);
if (opName == "IsInf") if (opName == "IsInf")
buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXIsInfOp>(node);
if (opName == "IsNaN") if (opName == "IsNaN")
buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXIsNaNOp>(node);
if (opName == "LRN") if (opName == "LRN")
buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLRNOp>(node);
if (opName == "LSTM") if (opName == "LSTM")
buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3); buildOperation<mlir::ONNXLSTMOp>(node);
if (opName == "LeakyRelu") if (opName == "LeakyRelu")
buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLeakyReluOp>(node);
if (opName == "Less") if (opName == "Less")
buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLessOp>(node);
if (opName == "Log") if (opName == "Log")
buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLogOp>(node);
if (opName == "LogSoftmax") if (opName == "LogSoftmax")
buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLogSoftmaxOp>(node);
if (opName == "Loop") if (opName == "Loop")
buildOperation<mlir::ONNXLoopOp>(node); buildOperation<mlir::ONNXLoopOp>(node);
if (opName == "LpNormalization") if (opName == "LpNormalization")
buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLpNormalizationOp>(node);
if (opName == "LpPool") if (opName == "LpPool")
buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLpPoolOp>(node);
if (opName == "MatMul") if (opName == "MatMul")
buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMatMulOp>(node);
if (opName == "MatMulInteger") if (opName == "MatMulInteger")
buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMatMulIntegerOp>(node);
if (opName == "Max") if (opName == "Max")
buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMaxOp>(node);
if (opName == "MaxPool") if (opName == "MaxPool")
ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); ImportNodeMaxPool(node);
if (opName == "MaxRoiPool") if (opName == "MaxRoiPool")
buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMaxRoiPoolOp>(node);
if (opName == "MaxUnpool") if (opName == "MaxUnpool")
buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMaxUnpoolOp>(node);
if (opName == "Mean") if (opName == "Mean")
buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMeanOp>(node);
if (opName == "MeanVarianceNormalization") if (opName == "MeanVarianceNormalization")
buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node);
if (opName == "Min") if (opName == "Min")
buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMinOp>(node);
if (opName == "Mod") if (opName == "Mod")
buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXModOp>(node);
if (opName == "Mul") if (opName == "Mul")
buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMulOp>(node);
if (opName == "Multinomial") if (opName == "Multinomial")
buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMultinomialOp>(node);
if (opName == "Neg") if (opName == "Neg")
buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXNegOp>(node);
if (opName == "NonMaxSuppression") if (opName == "NonMaxSuppression")
buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1); buildOperation<mlir::ONNXNonMaxSuppressionOp>(node);
if (opName == "NonZero") if (opName == "NonZero")
buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXNonZeroOp>(node);
if (opName == "Not") if (opName == "Not")
buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXNotOp>(node);
if (opName == "OneHot") if (opName == "OneHot")
buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXOneHotOp>(node);
if (opName == "Or") if (opName == "Or")
buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXOrOp>(node);
if (opName == "PRelu") if (opName == "PRelu")
buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXPReluOp>(node);
if (opName == "Pad") if (opName == "Pad")
ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); ImportNodePad(node);
if (opName == "Pow") if (opName == "Pow")
buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXPowOp>(node);
if (opName == "QLinearConv") if (opName == "QLinearConv")
buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1); buildOperation<mlir::ONNXQLinearConvOp>(node);
if (opName == "QLinearMatMul") if (opName == "QLinearMatMul")
buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1); buildOperation<mlir::ONNXQLinearMatMulOp>(node);
if (opName == "QuantizeLinear") if (opName == "QuantizeLinear")
buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXQuantizeLinearOp>(node);
if (opName == "RNN") if (opName == "RNN")
buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2); buildOperation<mlir::ONNXRNNOp>(node);
if (opName == "RandomNormal") if (opName == "RandomNormal")
buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRandomNormalOp>(node);
if (opName == "RandomNormalLike") if (opName == "RandomNormalLike")
buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRandomNormalLikeOp>(node);
if (opName == "RandomUniform") if (opName == "RandomUniform")
buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRandomUniformOp>(node);
if (opName == "RandomUniformLike") if (opName == "RandomUniformLike")
buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRandomUniformLikeOp>(node);
if (opName == "Range") if (opName == "Range")
buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRangeOp>(node);
if (opName == "Reciprocal") if (opName == "Reciprocal")
buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReciprocalOp>(node);
if (opName == "ReduceL1") if (opName == "ReduceL1")
buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceL1Op>(node);
if (opName == "ReduceL2") if (opName == "ReduceL2")
buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceL2Op>(node);
if (opName == "ReduceLogSum") if (opName == "ReduceLogSum")
buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceLogSumOp>(node);
if (opName == "ReduceLogSumExp") if (opName == "ReduceLogSumExp")
buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceLogSumExpOp>(node);
if (opName == "ReduceMax") if (opName == "ReduceMax")
buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceMaxOp>(node);
if (opName == "ReduceMean") if (opName == "ReduceMean")
buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceMeanOp>(node);
if (opName == "ReduceMin") if (opName == "ReduceMin")
buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceMinOp>(node);
if (opName == "ReduceProd") if (opName == "ReduceProd")
buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceProdOp>(node);
if (opName == "ReduceSum") if (opName == "ReduceSum")
buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceSumOp>(node);
if (opName == "ReduceSumSquare") if (opName == "ReduceSumSquare")
buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceSumSquareOp>(node);
if (opName == "Relu") if (opName == "Relu")
buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReluOp>(node);
if (opName == "Reshape") if (opName == "Reshape")
ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); ImportNodeReshape(node);
if (opName == "Resize") if (opName == "Resize")
buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); buildOperation<mlir::ONNXResizeOp>(node);
if (opName == "ReverseSequence") if (opName == "ReverseSequence")
buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReverseSequenceOp>(node);
if (opName == "RoiAlign") if (opName == "RoiAlign")
buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRoiAlignOp>(node);
if (opName == "Round") if (opName == "Round")
buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRoundOp>(node);
if (opName == "Scan") if (opName == "Scan")
buildOperation<mlir::ONNXScanOp>(node); buildOperation<mlir::ONNXScanOp>(node);
if (opName == "Scatter") if (opName == "Scatter")
buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXScatterOp>(node);
if (opName == "ScatterElements") if (opName == "ScatterElements")
buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXScatterElementsOp>(node);
if (opName == "ScatterND") if (opName == "ScatterND")
buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXScatterNDOp>(node);
if (opName == "Selu") if (opName == "Selu")
buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSeluOp>(node);
if (opName == "SequenceAt") if (opName == "SequenceAt")
buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSequenceAtOp>(node);
if (opName == "SequenceConstruct") if (opName == "SequenceConstruct")
buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSequenceConstructOp>(node);
if (opName == "SequenceEmpty") if (opName == "SequenceEmpty")
buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSequenceEmptyOp>(node);
if (opName == "SequenceErase") if (opName == "SequenceErase")
buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSequenceEraseOp>(node);
if (opName == "SequenceInsert") if (opName == "SequenceInsert")
buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSequenceInsertOp>(node);
if (opName == "SequenceLength") if (opName == "SequenceLength")
buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSequenceLengthOp>(node);
if (opName == "Shape") if (opName == "Shape")
buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXShapeOp>(node);
if (opName == "Shrink") if (opName == "Shrink")
buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXShrinkOp>(node);
if (opName == "Sigmoid") if (opName == "Sigmoid")
buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSigmoidOp>(node);
if (opName == "Sign") if (opName == "Sign")
buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSignOp>(node);
if (opName == "Sin") if (opName == "Sin")
buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSinOp>(node);
if (opName == "Sinh") if (opName == "Sinh")
buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSinhOp>(node);
if (opName == "Size") if (opName == "Size")
buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSizeOp>(node);
if (opName == "Slice") if (opName == "Slice")
buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSliceOp>(node);
if (opName == "Softmax") if (opName == "Softmax")
buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSoftmaxOp>(node);
if (opName == "Softplus") if (opName == "Softplus")
buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSoftplusOp>(node);
if (opName == "Softsign") if (opName == "Softsign")
buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSoftsignOp>(node);
if (opName == "SpaceToDepth") if (opName == "SpaceToDepth")
buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSpaceToDepthOp>(node);
if (opName == "Split") if (opName == "Split")
buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1); buildOperation<mlir::ONNXSplitOp>(node);
if (opName == "SplitToSequence") if (opName == "SplitToSequence")
buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSplitToSequenceOp>(node);
if (opName == "Sqrt") if (opName == "Sqrt")
buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSqrtOp>(node);
if (opName == "Squeeze") if (opName == "Squeeze")
buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSqueezeOp>(node);
if (opName == "StringNormalizer") if (opName == "StringNormalizer")
buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXStringNormalizerOp>(node);
if (opName == "Sub") if (opName == "Sub")
buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSubOp>(node);
if (opName == "Sum") if (opName == "Sum")
buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSumOp>(node);
if (opName == "Tan") if (opName == "Tan")
buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXTanOp>(node);
if (opName == "Tanh") if (opName == "Tanh")
buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXTanhOp>(node);
if (opName == "TfIdfVectorizer") if (opName == "TfIdfVectorizer")
buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXTfIdfVectorizerOp>(node);
if (opName == "ThresholdedRelu") if (opName == "ThresholdedRelu")
buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXThresholdedReluOp>(node);
if (opName == "Tile") if (opName == "Tile")
buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXTileOp>(node);
if (opName == "TopK") if (opName == "TopK")
buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2); buildOperation<mlir::ONNXTopKOp>(node);
if (opName == "Transpose") if (opName == "Transpose")
buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXTransposeOp>(node);
if (opName == "Unique") if (opName == "Unique")
buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4); buildOperation<mlir::ONNXUniqueOp>(node);
if (opName == "Unsqueeze") if (opName == "Unsqueeze")
buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXUnsqueezeOp>(node);
if (opName == "Upsample") if (opName == "Upsample")
buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXUpsampleOp>(node);
if (opName == "Where") if (opName == "Where")
buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXWhereOp>(node);
if (opName == "Xor") if (opName == "Xor")
buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXXorOp>(node);

View File

@ -14,6 +14,17 @@ def MLONNXArrayFeatureExtractorOp:MLONNX_Op<"ArrayFeatureExtractor",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {20};
}
}];
} }
def MLONNXBinarizerOp:MLONNX_Op<"Binarizer", def MLONNXBinarizerOp:MLONNX_Op<"Binarizer",
@ -22,9 +33,20 @@ def MLONNXBinarizerOp:MLONNX_Op<"Binarizer",
let description = [{ let description = [{
"Maps the values of the input tensor to either 0 or 1, element-wise, based on the outcome of a comparison against a threshold value." "Maps the values of the input tensor to either 0 or 1, element-wise, based on the outcome of a comparison against a threshold value."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
DefaultValuedAttr<F32Attr, "0.0">:$threshold); DefaultValuedAttr<F32Attr, "0.0">:$threshold);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {20};
}
}];
} }
def MLONNXCastMapOp:MLONNX_Op<"CastMap", def MLONNXCastMapOp:MLONNX_Op<"CastMap",
@ -40,6 +62,17 @@ def MLONNXCastMapOp:MLONNX_Op<"CastMap",
DefaultValuedAttr<StrAttr, "DENSE">:$map_form, DefaultValuedAttr<StrAttr, "DENSE">:$map_form,
DefaultValuedAttr<I64Attr, "1">:$max_map); DefaultValuedAttr<I64Attr, "1">:$max_map);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
} }
def MLONNXCategoryMapperOp:MLONNX_Op<"CategoryMapper", def MLONNXCategoryMapperOp:MLONNX_Op<"CategoryMapper",
@ -61,6 +94,17 @@ def MLONNXCategoryMapperOp:MLONNX_Op<"CategoryMapper",
DefaultValuedAttr<I64Attr, "-1">:$default_int64, DefaultValuedAttr<I64Attr, "-1">:$default_int64,
DefaultValuedAttr<StrAttr, "_Unused">:$default_string); DefaultValuedAttr<StrAttr, "_Unused">:$default_string);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
} }
def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer", def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer",
@ -84,6 +128,17 @@ def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer",
OptionalAttr<I64ArrayAttr>:$int64_vocabulary, OptionalAttr<I64ArrayAttr>:$int64_vocabulary,
OptionalAttr<StrArrayAttr>:$string_vocabulary); OptionalAttr<StrArrayAttr>:$string_vocabulary);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
} }
def MLONNXFeatureVectorizerOp:MLONNX_Op<"FeatureVectorizer", def MLONNXFeatureVectorizerOp:MLONNX_Op<"FeatureVectorizer",
@ -95,9 +150,20 @@ def MLONNXFeatureVectorizerOp:MLONNX_Op<"FeatureVectorizer",
" Inputs are copied to the output maintaining the order of the input arguments.<br>" " Inputs are copied to the output maintaining the order of the input arguments.<br>"
" All inputs must be integers or floats, while the output will be all floating point values." " All inputs must be integers or floats, while the output will be all floating point values."
}]; }];
let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$X, let arguments = (ins Variadic<AnyTypeOf<[TensorOf<[I32,I64,F32,F64]>, MemRefOf<[I32,I64,F32,F64]>]>>:$X,
OptionalAttr<I64ArrayAttr>:$inputdimensions); OptionalAttr<I64ArrayAttr>:$inputdimensions);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return -1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
} }
def MLONNXImputerOp:MLONNX_Op<"Imputer", def MLONNXImputerOp:MLONNX_Op<"Imputer",
@ -113,12 +179,23 @@ def MLONNXImputerOp:MLONNX_Op<"Imputer",
" which one depends on whether floats or integers are being processed.<br>" " which one depends on whether floats or integers are being processed.<br>"
" The imputed_value attribute length can be 1 element, or it can have one element per input feature.<br>In other words, if the input tensor has the shape [*,F], then the length of the attribute array may be 1 or F. If it is 1, then it is broadcast along the last dimension and applied to each feature." " The imputed_value attribute length can be 1 element, or it can have one element per input feature.<br>In other words, if the input tensor has the shape [*,F], then the length of the attribute array may be 1 or F. If it is 1, then it is broadcast along the last dimension and applied to each feature."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<F32ArrayAttr>:$imputed_value_floats, OptionalAttr<F32ArrayAttr>:$imputed_value_floats,
OptionalAttr<I64ArrayAttr>:$imputed_value_int64s, OptionalAttr<I64ArrayAttr>:$imputed_value_int64s,
DefaultValuedAttr<F32Attr, "0.0">:$replaced_value_float, DefaultValuedAttr<F32Attr, "0.0">:$replaced_value_float,
DefaultValuedAttr<I64Attr, "0">:$replaced_value_int64); DefaultValuedAttr<I64Attr, "0">:$replaced_value_int64);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {20};
}
}];
} }
def MLONNXLabelEncoderOp:MLONNX_Op<"LabelEncoder", def MLONNXLabelEncoderOp:MLONNX_Op<"LabelEncoder",
@ -154,6 +231,17 @@ def MLONNXLabelEncoderOp:MLONNX_Op<"LabelEncoder",
OptionalAttr<I64ArrayAttr>:$values_int64s, OptionalAttr<I64ArrayAttr>:$values_int64s,
OptionalAttr<StrArrayAttr>:$values_strings); OptionalAttr<StrArrayAttr>:$values_strings);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
} }
def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier", def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier",
@ -162,7 +250,7 @@ def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier",
let description = [{ let description = [{
"Linear classifier" "Linear classifier"
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<I64ArrayAttr>:$classlabels_ints, OptionalAttr<I64ArrayAttr>:$classlabels_ints,
OptionalAttr<StrArrayAttr>:$classlabels_strings, OptionalAttr<StrArrayAttr>:$classlabels_strings,
F32ArrayAttr:$coefficients, F32ArrayAttr:$coefficients,
@ -171,6 +259,17 @@ def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier",
DefaultValuedAttr<StrAttr, "NONE">:$post_transform); DefaultValuedAttr<StrAttr, "NONE">:$post_transform);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y, let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 2;
}
static std::vector<int> getTypeMap() {
return {-1,-1};
}
}];
} }
def MLONNXLinearRegressorOp:MLONNX_Op<"LinearRegressor", def MLONNXLinearRegressorOp:MLONNX_Op<"LinearRegressor",
@ -184,12 +283,23 @@ def MLONNXLinearRegressorOp:MLONNX_Op<"LinearRegressor",
" The coefficients array is of length n, and the coefficients for each target are contiguous." " The coefficients array is of length n, and the coefficients for each target are contiguous."
" Intercepts are optional but if provided must match the number of targets." " Intercepts are optional but if provided must match the number of targets."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<F32ArrayAttr>:$coefficients, OptionalAttr<F32ArrayAttr>:$coefficients,
OptionalAttr<F32ArrayAttr>:$intercepts, OptionalAttr<F32ArrayAttr>:$intercepts,
DefaultValuedAttr<StrAttr, "NONE">:$post_transform, DefaultValuedAttr<StrAttr, "NONE">:$post_transform,
DefaultValuedAttr<I64Attr, "1">:$targets); DefaultValuedAttr<I64Attr, "1">:$targets);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
} }
def MLONNXNormalizerOp:MLONNX_Op<"Normalizer", def MLONNXNormalizerOp:MLONNX_Op<"Normalizer",
@ -207,9 +317,20 @@ def MLONNXNormalizerOp:MLONNX_Op<"Normalizer",
" For batches, that is, [N,C] tensors, normalization is done along the C axis. In other words, each row" " For batches, that is, [N,C] tensors, normalization is done along the C axis. In other words, each row"
" of the batch is normalized independently." " of the batch is normalized independently."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
DefaultValuedAttr<StrAttr, "MAX">:$norm); DefaultValuedAttr<StrAttr, "MAX">:$norm);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
} }
def MLONNXOneHotEncoderOp:MLONNX_Op<"OneHotEncoder", def MLONNXOneHotEncoderOp:MLONNX_Op<"OneHotEncoder",
@ -230,6 +351,17 @@ def MLONNXOneHotEncoderOp:MLONNX_Op<"OneHotEncoder",
OptionalAttr<StrArrayAttr>:$cats_strings, OptionalAttr<StrArrayAttr>:$cats_strings,
DefaultValuedAttr<I64Attr, "1">:$zeros); DefaultValuedAttr<I64Attr, "1">:$zeros);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
} }
def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier", def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier",
@ -238,7 +370,7 @@ def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier",
let description = [{ let description = [{
"Support Vector Machine classifier" "Support Vector Machine classifier"
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<I64ArrayAttr>:$classlabels_ints, OptionalAttr<I64ArrayAttr>:$classlabels_ints,
OptionalAttr<StrArrayAttr>:$classlabels_strings, OptionalAttr<StrArrayAttr>:$classlabels_strings,
OptionalAttr<F32ArrayAttr>:$coefficients, OptionalAttr<F32ArrayAttr>:$coefficients,
@ -252,6 +384,17 @@ def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier",
OptionalAttr<I64ArrayAttr>:$vectors_per_class); OptionalAttr<I64ArrayAttr>:$vectors_per_class);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y, let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 2;
}
static std::vector<int> getTypeMap() {
return {-1,-1};
}
}];
} }
def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor", def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor",
@ -260,7 +403,7 @@ def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor",
let description = [{ let description = [{
"Support Vector Machine regression prediction and one-class SVM anomaly detection." "Support Vector Machine regression prediction and one-class SVM anomaly detection."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<F32ArrayAttr>:$coefficients, OptionalAttr<F32ArrayAttr>:$coefficients,
OptionalAttr<F32ArrayAttr>:$kernel_params, OptionalAttr<F32ArrayAttr>:$kernel_params,
DefaultValuedAttr<StrAttr, "LINEAR">:$kernel_type, DefaultValuedAttr<StrAttr, "LINEAR">:$kernel_type,
@ -270,6 +413,17 @@ def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor",
OptionalAttr<F32ArrayAttr>:$rho, OptionalAttr<F32ArrayAttr>:$rho,
OptionalAttr<F32ArrayAttr>:$support_vectors); OptionalAttr<F32ArrayAttr>:$support_vectors);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
} }
def MLONNXScalerOp:MLONNX_Op<"Scaler", def MLONNXScalerOp:MLONNX_Op<"Scaler",
@ -278,10 +432,21 @@ def MLONNXScalerOp:MLONNX_Op<"Scaler",
let description = [{ let description = [{
"Rescale input data, for example to standardize features by removing the mean and scaling to unit variance." "Rescale input data, for example to standardize features by removing the mean and scaling to unit variance."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<F32ArrayAttr>:$offset, OptionalAttr<F32ArrayAttr>:$offset,
OptionalAttr<F32ArrayAttr>:$scale); OptionalAttr<F32ArrayAttr>:$scale);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
} }
def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier", def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier",
@ -298,7 +463,7 @@ def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier",
" One and only one of classlabels_strings or classlabels_int64s" " One and only one of classlabels_strings or classlabels_int64s"
" will be defined. The class_ids are indices into this list." " will be defined. The class_ids are indices into this list."
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<F32ArrayAttr>:$base_values, OptionalAttr<F32ArrayAttr>:$base_values,
OptionalAttr<I64ArrayAttr>:$class_ids, OptionalAttr<I64ArrayAttr>:$class_ids,
OptionalAttr<I64ArrayAttr>:$class_nodeids, OptionalAttr<I64ArrayAttr>:$class_nodeids,
@ -318,6 +483,17 @@ def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier",
DefaultValuedAttr<StrAttr, "NONE">:$post_transform); DefaultValuedAttr<StrAttr, "NONE">:$post_transform);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y, let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 2;
}
static std::vector<int> getTypeMap() {
return {-1,-1};
}
}];
} }
def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor", def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor",
@ -335,7 +511,7 @@ def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor",
" All trees must have their node ids start at 0 and increment by 1.<br>" " All trees must have their node ids start at 0 and increment by 1.<br>"
" Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF" " Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF"
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
DefaultValuedAttr<StrAttr, "SUM">:$aggregate_function, DefaultValuedAttr<StrAttr, "SUM">:$aggregate_function,
OptionalAttr<F32ArrayAttr>:$base_values, OptionalAttr<F32ArrayAttr>:$base_values,
OptionalAttr<I64Attr>:$n_targets, OptionalAttr<I64Attr>:$n_targets,
@ -354,6 +530,17 @@ def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor",
OptionalAttr<I64ArrayAttr>:$target_treeids, OptionalAttr<I64ArrayAttr>:$target_treeids,
OptionalAttr<F32ArrayAttr>:$target_weights); OptionalAttr<F32ArrayAttr>:$target_weights);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
} }
def MLONNXZipMapOp:MLONNX_Op<"ZipMap", def MLONNXZipMapOp:MLONNX_Op<"ZipMap",
@ -369,5 +556,16 @@ def MLONNXZipMapOp:MLONNX_Op<"ZipMap",
OptionalAttr<I64ArrayAttr>:$classlabels_int64s, OptionalAttr<I64ArrayAttr>:$classlabels_int64s,
OptionalAttr<StrArrayAttr>:$classlabels_strings); OptionalAttr<StrArrayAttr>:$classlabels_strings);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
} }

View File

@ -38,7 +38,7 @@ def ONNX_Dialect : Dialect {
// * The mnemonic for the operation, or the name without the dialect prefix. // * The mnemonic for the operation, or the name without the dialect prefix.
// * A list of traits for the operation. // * A list of traits for the operation.
class ONNX_Op<string mnemonic, list<OpTrait> traits = []> : class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
Op<ONNX_Dialect, mnemonic, traits>; Op<ONNX_Dialect, mnemonic, traits> ;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ONNX Operations // ONNX Operations
@ -112,6 +112,17 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
DefaultValuedAttr<I64Attr, "0">:$storage_order, DefaultValuedAttr<I64Attr, "0">:$storage_order,
OptionalAttr<I64ArrayAttr>:$strides); OptionalAttr<I64ArrayAttr>:$strides);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {0};
}
}];
} }
def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode", def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
@ -137,6 +148,17 @@ def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon, DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
DefaultValuedAttr<F32Attr, "0.9">:$momentum); DefaultValuedAttr<F32Attr, "0.9">:$momentum);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 5;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {0};
}
}];
} }
def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue", def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
@ -154,6 +176,17 @@ def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
DefaultValuedAttr<F32Attr, "0.0">:$constant_value, DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
DefaultValuedAttr<StrAttr, "constant">:$mode); DefaultValuedAttr<StrAttr, "constant">:$mode);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {0};
}
}];
} }
def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad", def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad",
@ -168,6 +201,17 @@ def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad",
I64ArrayAttr:$pads, I64ArrayAttr:$pads,
DefaultValuedAttr<StrAttr, "constant">:$mode); DefaultValuedAttr<StrAttr, "constant">:$mode);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {0};
}
}];
} }
def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad", def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
@ -186,6 +230,17 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
let builders = [OpBuilder<"OpBuilder &builder, OperationState &state, " let builders = [OpBuilder<"OpBuilder &builder, OperationState &state, "
"Value data, ArrayAttr pads, " "Value data, ArrayAttr pads, "
"FloatAttr constant_value, StringAttr mode">]; "FloatAttr constant_value, StringAttr mode">];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {0};
}
}];
} }
#endif // ONNX_OPS #endif // ONNX_OPS

File diff suppressed because it is too large Load Diff

View File

@ -144,62 +144,62 @@ func @test_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*
// ----- // -----
func @test_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { func @test_and(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
%0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> %0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
"std.return"(%0) : (tensor<*xi32>) -> () "std.return"(%0) : (tensor<*xi1>) -> ()
// CHECK-LABEL: test_and // CHECK-LABEL: test_and
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32 // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi32> // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
// CHECK: return [[RES]] : memref<10x10xi32> // CHECK: return [[RES]] : memref<10x10xi1>
} }
// ----- // -----
func @test_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { func @test_or(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> %0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
"std.return"(%0) : (tensor<*xi32>) -> () "std.return"(%0) : (tensor<*xi1>) -> ()
// CHECK-LABEL: test_or // CHECK-LABEL: test_or
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32 // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
// CHECK: return [[RES]] : memref<10x10xi32> // CHECK: return [[RES]] : memref<10x10xi1>
} }
// ----- // -----
func @test_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { func @test_xor(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
"std.return"(%0) : (tensor<*xi32>) -> () "std.return"(%0) : (tensor<*xi1>) -> ()
// CHECK-LABEL: test_xor // CHECK-LABEL: test_xor
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32 // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
// CHECK: return [[RES]] : memref<10x10xi32> // CHECK: return [[RES]] : memref<10x10xi1>
} }
// ----- // -----

View File

@ -158,24 +158,24 @@ func @test_sub_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tens
// ----- // -----
func @test_and_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { func @test_and_and(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
%0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> %0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
%1 = "onnx.And"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32> %1 = "onnx.And"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1>
"std.return"(%1) : (tensor<*xi32>) -> () "std.return"(%1) : (tensor<*xi1>) -> ()
// CHECK-LABEL: test_and_and // CHECK-LABEL: test_and_and
/// First And /// First And
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32 // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi32> // CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
/// Second And /// Second And
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
@ -183,38 +183,38 @@ func @test_and_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tens
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32 // CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32> // CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1>
/// Dealloc of first result. /// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<10x10xi32> // CHECK: dealloc [[RES]] : memref<10x10xi1>
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32> // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1>
// CHECK: return [[RET_RES]] : memref<10x10xi32> // CHECK: return [[RET_RES]] : memref<10x10xi1>
} }
// ----- // -----
func @test_or_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { func @test_or_or(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> %0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
%1 = "onnx.Or"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32> %1 = "onnx.Or"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1>
"std.return"(%1) : (tensor<*xi32>) -> () "std.return"(%1) : (tensor<*xi1>) -> ()
// CHECK-LABEL: test_or_or // CHECK-LABEL: test_or_or
/// First Or /// First Or
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32 // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> // CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
/// Second Or /// Second Or
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
@ -222,38 +222,38 @@ func @test_or_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32 // CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32> // CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1>
/// Dealloc of first result. /// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<10x10xi32> // CHECK: dealloc [[RES]] : memref<10x10xi1>
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32> // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1>
// CHECK: return [[RET_RES]] : memref<10x10xi32> // CHECK: return [[RET_RES]] : memref<10x10xi1>
} }
// ----- // -----
func @test_xor_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> { func @test_xor_xor(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32> %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
%1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32> %1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1>
"std.return"(%1) : (tensor<*xi32>) -> () "std.return"(%1) : (tensor<*xi1>) -> ()
// CHECK-LABEL: test_xor_xor // CHECK-LABEL: test_xor_xor
/// First Xor /// First Xor
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32> // CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32 // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi32> // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
/// Second Xor /// Second Xor
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
@ -261,16 +261,16 @@ func @test_xor_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tens
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop) // CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) { // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32> // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32 // CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32> // CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1>
/// Dealloc of first result. /// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<10x10xi32> // CHECK: dealloc [[RES]] : memref<10x10xi1>
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32> // CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1>
// CHECK: return [[RET_RES]] : memref<10x10xi32> // CHECK: return [[RET_RES]] : memref<10x10xi1>
} }
// ----- // -----

View File

@ -298,6 +298,16 @@ custom_definition_misc = dict([ ('Constant',
)]) )])
onnx_types = (
'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
'float', 'double', 'complex64', 'complex128'
)
tblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64',
'Complex<F32>', 'Complex<F64>'
)
MAX_NUM_TYPES=20
SNIPPETS = collect_snippets() SNIPPETS = collect_snippets()
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations() SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
ONNX_ML = bool(args.domain == "ONNX_ML") ONNX_ML = bool(args.domain == "ONNX_ML")
@ -376,53 +386,55 @@ def tblgen_operand_type_to_cpp_type(op_type):
def np_type_to_tblgen_attr_type(tstr): def np_type_to_tblgen_attr_type(tstr):
tfrom = np.array([
'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
'float', 'double'
])
tto = np.array(
['I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64'])
index = -1 index = -1
for i in range(len(tfrom)): for i in range(len(onnx_types)):
if tfrom[i] in tstr: if onnx_types[i] in tstr:
index = i index = i
break break
if index == -1: if index == -1:
print("error", tstr) return None
return ''
else: else:
return tto[i] return tblgen_types[i]
def get_tblgen_type_index(type_str):
return tblgen_types.index(type_str)
#the possible data structures are tensor, map and seq(tensor())
#TOFIX: currently, only tensor structure is supported
def get_data_structure_element(allowed_type_str):
if allowed_type_str.startswith('tensor') :
element = allowed_type_str.replace('tensor(', '', 1).replace(')', '', 1)
return ('tensor', element)
else :
return (None, None)
def get_allowed_elem_types(schema, input): def get_allowed_elem_types(schema, input):
allowed_types_str = None #allowed_types_str = None
return allowed_types_str # return allowed_types_str
# TODO: enable type constraints. # TODO: enable type constraints.
# if input.typeStr : if input.typeStr :
# tstr = input.typeStr tstr = input.typeStr
# else : else :
# return allwedTypeStr return None
# if schema.type_constraints: if schema.type_constraints:
# for type_constraint in schema.type_constraints: for type_constraint in schema.type_constraints:
# if type_constraint.type_param_str != tstr : if type_constraint.type_param_str != tstr :
# continue continue
# allowedTypes = type_constraint.allowed_type_strs allowed_type_list=[]
# allowedTypeStr='' allowedTypes = type_constraint.allowed_type_strs
# if (len(allowedTypes) > 0): for allowedType in allowedTypes:
# t = convert_type(allowedTypes[0]) structure, element = get_data_structure_element(allowedType);
# if t == '' : if structure == None or element == None:
# return '' return None
# allowedTypeStr += t t = np_type_to_tblgen_attr_type(element)
# for allowedType in allowedTypes[1:]: if t == None :
# t = convert_type(allowedType) return None
# if t == '' : if not t in allowed_type_list :
# return '' allowed_tyoe_list = allowed_type_list.append(t)
# if not t in allowedTypeStr :
# allowedTypeStr += ', '+t return allowed_type_list
#
# return allowedTypeStr return None
#
# return allowedTypeStr
def inc_indent(indent=None): def inc_indent(indent=None):
@ -436,7 +448,6 @@ def dec_indent(indent):
def join_args(args): def join_args(args):
return ", ".join(args) return ", ".join(args)
def get_operands_or_results(schema, is_input): def get_operands_or_results(schema, is_input):
value_list = schema.inputs if is_input else schema.outputs value_list = schema.inputs if is_input else schema.outputs
if not value_list: if not value_list:
@ -456,8 +467,9 @@ def get_operands_or_results(schema, is_input):
if elem_types is None: if elem_types is None:
types = ["AnyMemRef", "AnyTensor"] types = ["AnyMemRef", "AnyTensor"]
else: else:
elem_types_str = ','.join(elem_types)
types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"] types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
types = list(map(lambda x: x.format(elem_types), types)) types = list(map(lambda x: x.format(elem_types_str), types))
# If operand is promotable to an attribute, then it must be # If operand is promotable to an attribute, then it must be
# nullable in case it migrates to be an attribute. # nullable in case it migrates to be an attribute.
@ -545,6 +557,64 @@ def get_attrs(schema):
name_to_type[attr.name] = get_attr_type_optional(attr.type) name_to_type[attr.name] = get_attr_type_optional(attr.type)
return name_to_type return name_to_type
def get_numberof_list(mylist):
expected_num = len(mylist)
for element in mylist :
if OpSchema.FormalParameterOption.Variadic == element.option:
expected_num = -1
return expected_num
def get_output_type_mapping(schema):
mapping=[]
for output in schema.outputs :
#if only one type is allowed, just set that
allowed_elem_types = get_allowed_elem_types(schema, output)
if allowed_elem_types != None and len(allowed_elem_types) == 1 :
mapping.append(str(get_tblgen_type_index(allowed_elem_types[0])))
continue
#map the type string
if output.typeStr :
tstr = output.typeStr
found = False
for i, input in enumerate(schema.inputs):
if input.typeStr and input.typeStr == tstr:
mapping.append(str(i+MAX_NUM_TYPES))
found = True
break
if found:
continue
#unknown output type
mapping.append(str(-1))
return mapping
def get_numberof_inout(s, indent, schema):
expected_num_operands = get_numberof_list(schema.inputs)
indent = inc_indent(indent)
s += indent + "static int getNumberOfOperands() {\n"
indent = inc_indent(indent)
s += indent + "return {};\n".format(expected_num_operands)
indent = dec_indent(indent)
s += indent + "}\n"
expected_num_results = get_numberof_list(schema.outputs)
s += indent + "static int getNumberOfResults() {\n"
indent = inc_indent(indent)
s += indent + "return {};\n".format(expected_num_results)
indent = dec_indent(indent)
s += indent + "}\n"
s += indent + "static std::vector<int> getTypeMap() {\n"
mapping = get_output_type_mapping(schema)
indent = inc_indent(indent)
s += indent + "return {" + ",".join(mapping) + "};\n"
indent = dec_indent(indent)
s += indent + "}\n"
return s
def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx): def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
cpp_name_to_idx_literal = "{" + ", ".join([ cpp_name_to_idx_literal = "{" + ", ".join([
@ -552,15 +622,15 @@ def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
for name_to_idx in const_operands_name_to_idx for name_to_idx in const_operands_name_to_idx
]) + "}" ]) + "}"
s += indent + "let extraClassDeclaration = [{\n" #s += indent + "let extraClassDeclaration = [{\n"
indent = inc_indent(indent) indent = inc_indent(indent)
s += indent + "std::map<std::string, size_t> promotableConstOperands() {\n" s += indent + "std::map<std::string, size_t> promotableConstOperands() {\n"
indent = inc_indent(indent) indent = inc_indent(indent)
s += indent + "return {};\n".format(cpp_name_to_idx_literal) s += indent + "return {};\n".format(cpp_name_to_idx_literal)
indent = dec_indent(indent) indent = dec_indent(indent)
s += indent + "}\n" s += indent + "}\n"
indent = dec_indent(indent) #indent = dec_indent(indent)
s += indent + "}];\n" #s += indent + "}];\n"
return s return s
@ -657,10 +727,20 @@ def gen_op_def(schema):
s += '\n' + indent + '];\n' s += '\n' + indent + '];\n'
# generate extracClassDeclaration
s += indent + "let extraClassDeclaration = [{\n"
#indent = inc_indent(indent)
# generate input/output number
s = get_numberof_inout(s, indent, schema)
# generate ProtableConst
if schema.name in OpsWithPromotableConstOperands: if schema.name in OpsWithPromotableConstOperands:
s = get_promotable_const_operands_func( s = get_promotable_const_operands_func(
s, indent, OpsWithPromotableConstOperands[schema.name]) s, indent, OpsWithPromotableConstOperands[schema.name])
s += indent + '}];\n'
if ( schema.name in custom_definition_misc) : if ( schema.name in custom_definition_misc) :
s += custom_definition_misc[schema.name] s += custom_definition_misc[schema.name]
@ -700,11 +780,13 @@ def gen_op_importer(schema, file):
# Special handlers currently require expected num operands/results to be specified. # Special handlers currently require expected num operands/results to be specified.
# TODO: remove special handlers. # TODO: remove special handlers.
args = ["node"] args = ["node"]
"""
if expected_num_operands != -1 or expected_num_results != -1 or "buildOperation" not in handler_func: if expected_num_operands != -1 or expected_num_results != -1 or "buildOperation" not in handler_func:
args.append( args.append(
"/* expected_num_operands = */ {}".format(expected_num_operands)) "/* expected_num_operands = */ {}".format(expected_num_operands))
args.append( args.append(
'/* expected_num_results = */ {}'.format(expected_num_results)) '/* expected_num_results = */ {}'.format(expected_num_results))
"""
s += inc_indent(indent) + " {}({});\n".format( s += inc_indent(indent) + " {}({});\n".format(
handler_func, ", ".join(args)) handler_func, ", ".join(args))