Express some basic features of an Operation in the TableGen file (#103)

* change operation definition

* change importer

* default type inference

* file format

* generate types for input/output

* generate the mapping for operation output type

* remove debug message for gen_doc.py

* update the dialect doc

* add support for Complex types

* format

* update documentation

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
chentong319 2020-05-21 22:03:16 -04:00 committed by GitHub
parent df18efcb48
commit 6099efd91b
11 changed files with 3100 additions and 985 deletions


@@ -35,13 +35,13 @@ ONNX Binarizer operation
| Operand | Description |
| :-----: | ----------- |
`X` | memref of any type values or tensor of any type values
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results:
| Result | Description |
| :----: | ----------- |
`Y` | memref of any type values or tensor of any type values
`Y` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
### `mlonnx.CastMap` (MLONNXCastMapOp)
@@ -160,7 +160,7 @@ ONNX FeatureVectorizer operation
| Operand | Description |
| :-----: | ----------- |
`X` | memref of any type values or tensor of any type values
`X` | tensor of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values
#### Results:
@@ -194,13 +194,13 @@ ONNX Imputer operation
| Operand | Description |
| :-----: | ----------- |
`X` | memref of any type values or tensor of any type values
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results:
| Result | Description |
| :----: | ----------- |
`Y` | memref of any type values or tensor of any type values
`Y` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
### `mlonnx.LabelEncoder` (MLONNXLabelEncoderOp)
@@ -271,7 +271,7 @@ ONNX LinearClassifier operation
| Operand | Description |
| :-----: | ----------- |
`X` | memref of any type values or tensor of any type values
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results:
@@ -304,7 +304,7 @@ ONNX LinearRegressor operation
| Operand | Description |
| :-----: | ----------- |
`X` | memref of any type values or tensor of any type values
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results:
@@ -337,7 +337,7 @@ ONNX Normalizer operation
| Operand | Description |
| :-----: | ----------- |
`X` | memref of any type values or tensor of any type values
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results:
@@ -404,7 +404,7 @@ ONNX SVMClassifier operation
| Operand | Description |
| :-----: | ----------- |
`X` | memref of any type values or tensor of any type values
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results:
@@ -436,7 +436,7 @@ ONNX SVMRegressor operation
| Operand | Description |
| :-----: | ----------- |
`X` | memref of any type values or tensor of any type values
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results:
@@ -461,7 +461,7 @@ ONNX Scaler operation
| Operand | Description |
| :-----: | ----------- |
`X` | memref of any type values or tensor of any type values
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results:
@@ -509,7 +509,7 @@ ONNX TreeEnsembleClassifier operation
| Operand | Description |
| :-----: | ----------- |
`X` | memref of any type values or tensor of any type values
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results:
@@ -559,7 +559,7 @@ ONNX TreeEnsembleRegressor operation
| Operand | Description |
| :-----: | ----------- |
`X` | memref of any type values or tensor of any type values
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
#### Results:

File diff suppressed because it is too large


@@ -197,6 +197,47 @@ private:
}
}
#define MAX_TYPE 20
// itblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32',
// 'F64', 'Complex<F32>', 'Complex<F64>' )
mlir::Type buildTypeFromIndex(int index) {
switch (index) {
case 0:
return builder_.getI1Type();
case 1:
return builder_.getIntegerType(8);
case 2:
return builder_.getIntegerType(16);
case 3:
return builder_.getIntegerType(32);
case 4:
return builder_.getIntegerType(64);
case 5:
return builder_.getBF16Type();
case 6:
return builder_.getF16Type();
case 7:
return builder_.getF32Type();
case 8:
return builder_.getF64Type();
case 9: {
std::vector<mlir::Type> typeTuple;
typeTuple.reserve(2); // constructing with size 2 would pre-fill two null types
typeTuple.push_back(builder_.getF32Type());
typeTuple.push_back(builder_.getF32Type());
return builder_.getTupleType(llvm::ArrayRef<mlir::Type>(typeTuple));
}
case 10: {
std::vector<mlir::Type> typeTuple;
typeTuple.reserve(2); // constructing with size 2 would pre-fill two null types
typeTuple.push_back(builder_.getF64Type());
typeTuple.push_back(builder_.getF64Type());
return builder_.getTupleType(llvm::ArrayRef<mlir::Type>(typeTuple));
}
default:
assert(false && "Unsupported type index encountered.");
return nullptr;
}
}
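The switch above encodes the fixed index table listed in the comment (itblgen_types). As a standalone sketch, assuming nothing beyond standard C++ and using illustrative names rather than MLIR types, the same table can be written as a name lookup; indices at or above MAX_TYPE are reserved for the "same type as an input" convention used further down:

```cpp
#include <cassert>
#include <string>

// Illustrative mirror of the itblgen_types index table used by
// buildTypeFromIndex (names are labels, not MLIR types). Indices
// >= kMaxType mean "same element type as input #(index - kMaxType)".
constexpr int kMaxType = 20; // mirrors MAX_TYPE

std::string typeNameFromIndex(int index) {
  static const char *const kTypes[] = {"i1", "i8", "i16", "i32", "i64",
                                       "bf16", "f16", "f32", "f64",
                                       "complex<f32>", "complex<f64>"};
  assert(index >= 0 && index < 11 && "unsupported type index");
  return kTypes[index];
}
```

Note that in the importer a complex element type is modeled as a two-element tuple of the underlying float type, since no native complex type is used here.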
template <typename T>
void buildOutputAndOperation(const onnx::NodeProto &node,
std::vector<mlir::Value> inputs, int expectedNumOperands,
@@ -217,13 +258,34 @@ private:
inputs.emplace_back(none_);
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
// Use the type map to determine the data type of output.
std::vector<int> outputMap = T::getTypeMap();
for (auto i = 0; i < node.output().size(); i++) {
// Optional outputs are indicated by an empty string.
if (item.empty())
if (node.output()[i].empty()) {
outputTypes.emplace_back(builder_.getNoneType());
else
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
} else {
if (i < outputMap.size() && outputMap[i] >= MAX_TYPE) {
// The mapping ties this output's type to one of the inputs.
mlir::Type inputType = inputs[outputMap[i] - MAX_TYPE].getType();
if (inputType.isa<mlir::TensorType>()) {
auto elementType =
inputType.cast<mlir::TensorType>().getElementType();
auto outType = mlir::UnrankedTensorType::get(elementType);
outputTypes.emplace_back(outType);
} else {
outputTypes.push_back(inputType);
}
} else if (i < outputMap.size() && outputMap[i] != -1) {
// The mapping gives a direct type index.
auto elementType = buildTypeFromIndex(outputMap[i]);
auto outType = mlir::UnrankedTensorType::get(elementType);
outputTypes.emplace_back(outType);
} else {
outputTypes.emplace_back(builder_.getNoneType());
}
}
}
// Trailing optional outputs.
if (!variadicOut)
@@ -241,9 +303,10 @@ private:
}
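The loop above interprets each getTypeMap() entry with a three-way convention: -1 means the output type is unknown (None for now), a value below MAX_TYPE is a direct type index, and MAX_TYPE + i means the output takes its element type from input i. A minimal sketch of that decoding, with hypothetical names and no MLIR dependency:

```cpp
#include <cassert>

constexpr int kMaxType = 20; // mirrors MAX_TYPE in the importer

enum class OutputTypeKind { Unknown, DirectTypeIndex, SameAsInput };

struct OutputTypeRule {
  OutputTypeKind kind;
  int value; // type index or input position, depending on kind
};

// Decode one entry of an operation's output type map.
OutputTypeRule decodeOutputMapEntry(int entry) {
  if (entry < 0)
    return {OutputTypeKind::Unknown, -1};
  if (entry >= kMaxType)
    return {OutputTypeKind::SameAsInput, entry - kMaxType};
  return {OutputTypeKind::DirectTypeIndex, entry};
}
```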
template <typename T>
void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
int expectedNumResults = -1) {
void buildOperation(const onnx::NodeProto &node) {
std::vector<mlir::Value> inputs;
int expectedNumOperands = T::getNumberOfOperands();
int expectedNumResults = T::getNumberOfResults();
for (const auto &item : node.input())
if (initializedTensors.ContainKey(legalize_name(item))) {
inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
@@ -256,7 +319,9 @@ private:
node, inputs, expectedNumOperands, expectedNumResults);
}
void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {
void ImportNodeReshape(onnx::NodeProto node) {
int expectedNumOperands = mlir::ONNXReshapeOp::getNumberOfOperands();
int expectedNumResults = mlir::ONNXReshapeOp::getNumberOfResults();
std::vector<mlir::Value> inputs;
std::string item;
for (int i = 0; i < node.input().size(); ++i) {
@@ -270,39 +335,40 @@ private:
}
}
buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut);
buildOutputAndOperation<mlir::ONNXReshapeOp>(
node, inputs, expectedNumOperands, expectedNumResults);
}
/*!
* Special handling of MaxPool operations.
*/
void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) {
void ImportNodeMaxPool(onnx::NodeProto node) {
int nOuts = node.output().size();
if (nOuts == 1) {
buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts);
buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node);
} else {
buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts);
buildOperation<mlir::ONNXMaxPoolOp>(node);
}
}
/*!
* Special handling of BatchNormalization operations.
*/
void ImportNodeBatchNormalization(onnx::NodeProto node, int nIn, int nOut) {
void ImportNodeBatchNormalization(onnx::NodeProto node) {
int nOuts = node.output().size();
if (nOuts == 1) {
// Test mode with one output.
buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts);
buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node);
} else {
// Training mode with four trailing optional outputs. Not handled yet.
buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
buildOperation<mlir::ONNXBatchNormalizationOp>(node);
}
}
/*!
* Special handling of Pad operations.
*/
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
void ImportNodePad(onnx::NodeProto node) {
int nOps = node.input().size();
if (nOps == 2) {
@@ -330,9 +396,11 @@ private:
}
inputs.push_back(constantResult);
int nIn = mlir::ONNXPadOp::getNumberOfOperands();
int nOut = mlir::ONNXPadOp::getNumberOfResults();
buildOutputAndOperation<mlir::ONNXPadOp>(node, inputs, nIn, nOut);
} else {
buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
buildOperation<mlir::ONNXPadOp>(node);
}
}
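ImportNodeMaxPool and ImportNodeBatchNormalization both select an op variant from the number of outputs actually present on the node, now that operand and result counts come from the op class itself. That selection logic can be sketched in isolation (the variant naming here is an illustrative label, not the actual MLIR op classes):

```cpp
#include <cassert>
#include <string>

// Pick the single-output variant of an operation when the ONNX node
// declares exactly one output, otherwise the full variant with the
// trailing optional outputs (cf. ONNXMaxPoolSingleOutOp vs ONNXMaxPoolOp).
std::string selectVariant(const std::string &opName, int numOutputs) {
  if (numOutputs == 1)
    return opName + "SingleOut";
  return opName;
}
```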


@@ -5,38 +5,38 @@
//********************************************************
if (opName == "ArrayFeatureExtractor")
return buildOperation<mlir::MLONNXArrayFeatureExtractorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXArrayFeatureExtractorOp>(node);
if (opName == "Binarizer")
return buildOperation<mlir::MLONNXBinarizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXBinarizerOp>(node);
if (opName == "CastMap")
return buildOperation<mlir::MLONNXCastMapOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXCastMapOp>(node);
if (opName == "CategoryMapper")
return buildOperation<mlir::MLONNXCategoryMapperOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXCategoryMapperOp>(node);
if (opName == "DictVectorizer")
return buildOperation<mlir::MLONNXDictVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXDictVectorizerOp>(node);
if (opName == "FeatureVectorizer")
return buildOperation<mlir::MLONNXFeatureVectorizerOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXFeatureVectorizerOp>(node);
if (opName == "Imputer")
return buildOperation<mlir::MLONNXImputerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXImputerOp>(node);
if (opName == "LabelEncoder")
return buildOperation<mlir::MLONNXLabelEncoderOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXLabelEncoderOp>(node);
if (opName == "LinearClassifier")
return buildOperation<mlir::MLONNXLinearClassifierOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
buildOperation<mlir::MLONNXLinearClassifierOp>(node);
if (opName == "LinearRegressor")
return buildOperation<mlir::MLONNXLinearRegressorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXLinearRegressorOp>(node);
if (opName == "Normalizer")
return buildOperation<mlir::MLONNXNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXNormalizerOp>(node);
if (opName == "OneHotEncoder")
return buildOperation<mlir::MLONNXOneHotEncoderOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXOneHotEncoderOp>(node);
if (opName == "SVMClassifier")
return buildOperation<mlir::MLONNXSVMClassifierOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
buildOperation<mlir::MLONNXSVMClassifierOp>(node);
if (opName == "SVMRegressor")
return buildOperation<mlir::MLONNXSVMRegressorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXSVMRegressorOp>(node);
if (opName == "Scaler")
return buildOperation<mlir::MLONNXScalerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXScalerOp>(node);
if (opName == "TreeEnsembleClassifier")
return buildOperation<mlir::MLONNXTreeEnsembleClassifierOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
buildOperation<mlir::MLONNXTreeEnsembleClassifierOp>(node);
if (opName == "TreeEnsembleRegressor")
return buildOperation<mlir::MLONNXTreeEnsembleRegressorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXTreeEnsembleRegressorOp>(node);
if (opName == "ZipMap")
return buildOperation<mlir::MLONNXZipMapOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::MLONNXZipMapOp>(node);


@@ -5,314 +5,314 @@
//********************************************************
if (opName == "Abs")
buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXAbsOp>(node);
if (opName == "Acos")
buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXAcosOp>(node);
if (opName == "Acosh")
buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXAcoshOp>(node);
if (opName == "Add")
buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXAddOp>(node);
if (opName == "And")
buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXAndOp>(node);
if (opName == "ArgMax")
buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXArgMaxOp>(node);
if (opName == "ArgMin")
buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXArgMinOp>(node);
if (opName == "Asin")
buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXAsinOp>(node);
if (opName == "Asinh")
buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXAsinhOp>(node);
if (opName == "Atan")
buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXAtanOp>(node);
if (opName == "Atanh")
buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXAtanhOp>(node);
if (opName == "AveragePool")
buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXAveragePoolOp>(node);
if (opName == "BatchNormalization")
ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
ImportNodeBatchNormalization(node);
if (opName == "BitShift")
buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXBitShiftOp>(node);
if (opName == "Cast")
buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXCastOp>(node);
if (opName == "Ceil")
buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXCeilOp>(node);
if (opName == "Clip")
buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXClipOp>(node);
if (opName == "Compress")
buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXCompressOp>(node);
if (opName == "Concat")
buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXConcatOp>(node);
if (opName == "ConcatFromSequence")
buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXConcatFromSequenceOp>(node);
if (opName == "Constant")
buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXConstantOp>(node);
if (opName == "ConstantOfShape")
buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXConstantOfShapeOp>(node);
if (opName == "Conv")
buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXConvOp>(node);
if (opName == "ConvInteger")
buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXConvIntegerOp>(node);
if (opName == "ConvTranspose")
buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXConvTransposeOp>(node);
if (opName == "Cos")
buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXCosOp>(node);
if (opName == "Cosh")
buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXCoshOp>(node);
if (opName == "CumSum")
buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXCumSumOp>(node);
if (opName == "DepthToSpace")
buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXDepthToSpaceOp>(node);
if (opName == "DequantizeLinear")
buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXDequantizeLinearOp>(node);
if (opName == "Det")
buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXDetOp>(node);
if (opName == "Div")
buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXDivOp>(node);
if (opName == "Dropout")
buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
buildOperation<mlir::ONNXDropoutOp>(node);
if (opName == "DynamicQuantizeLinear")
buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node);
if (opName == "Elu")
buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXEluOp>(node);
if (opName == "Equal")
buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXEqualOp>(node);
if (opName == "Erf")
buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXErfOp>(node);
if (opName == "Exp")
buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXExpOp>(node);
if (opName == "Expand")
buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXExpandOp>(node);
if (opName == "EyeLike")
buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXEyeLikeOp>(node);
if (opName == "Flatten")
buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXFlattenOp>(node);
if (opName == "Floor")
buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXFloorOp>(node);
if (opName == "GRU")
buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
buildOperation<mlir::ONNXGRUOp>(node);
if (opName == "Gather")
buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXGatherOp>(node);
if (opName == "GatherElements")
buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXGatherElementsOp>(node);
if (opName == "GatherND")
buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXGatherNDOp>(node);
if (opName == "Gemm")
buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXGemmOp>(node);
if (opName == "GlobalAveragePool")
buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXGlobalAveragePoolOp>(node);
if (opName == "GlobalLpPool")
buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXGlobalLpPoolOp>(node);
if (opName == "GlobalMaxPool")
buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXGlobalMaxPoolOp>(node);
if (opName == "Greater")
buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXGreaterOp>(node);
if (opName == "HardSigmoid")
buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXHardSigmoidOp>(node);
if (opName == "Hardmax")
buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXHardmaxOp>(node);
if (opName == "Identity")
buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXIdentityOp>(node);
if (opName == "If")
buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
buildOperation<mlir::ONNXIfOp>(node);
if (opName == "InstanceNormalization")
buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXInstanceNormalizationOp>(node);
if (opName == "IsInf")
buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXIsInfOp>(node);
if (opName == "IsNaN")
buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXIsNaNOp>(node);
if (opName == "LRN")
buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXLRNOp>(node);
if (opName == "LSTM")
buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
buildOperation<mlir::ONNXLSTMOp>(node);
if (opName == "LeakyRelu")
buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXLeakyReluOp>(node);
if (opName == "Less")
buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXLessOp>(node);
if (opName == "Log")
buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXLogOp>(node);
if (opName == "LogSoftmax")
buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXLogSoftmaxOp>(node);
if (opName == "Loop")
buildOperation<mlir::ONNXLoopOp>(node);
if (opName == "LpNormalization")
buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXLpNormalizationOp>(node);
if (opName == "LpPool")
buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXLpPoolOp>(node);
if (opName == "MatMul")
buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXMatMulOp>(node);
if (opName == "MatMulInteger")
buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXMatMulIntegerOp>(node);
if (opName == "Max")
buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXMaxOp>(node);
if (opName == "MaxPool")
ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
ImportNodeMaxPool(node);
if (opName == "MaxRoiPool")
buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXMaxRoiPoolOp>(node);
if (opName == "MaxUnpool")
buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXMaxUnpoolOp>(node);
if (opName == "Mean")
buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXMeanOp>(node);
if (opName == "MeanVarianceNormalization")
buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node);
if (opName == "Min")
buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXMinOp>(node);
if (opName == "Mod")
buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXModOp>(node);
if (opName == "Mul")
buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXMulOp>(node);
if (opName == "Multinomial")
buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXMultinomialOp>(node);
if (opName == "Neg")
buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXNegOp>(node);
if (opName == "NonMaxSuppression")
buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXNonMaxSuppressionOp>(node);
if (opName == "NonZero")
buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXNonZeroOp>(node);
if (opName == "Not")
buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXNotOp>(node);
if (opName == "OneHot")
buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXOneHotOp>(node);
if (opName == "Or")
buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXOrOp>(node);
if (opName == "PRelu")
buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXPReluOp>(node);
if (opName == "Pad")
ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
ImportNodePad(node);
if (opName == "Pow")
buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXPowOp>(node);
if (opName == "QLinearConv")
buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXQLinearConvOp>(node);
if (opName == "QLinearMatMul")
buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXQLinearMatMulOp>(node);
if (opName == "QuantizeLinear")
buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXQuantizeLinearOp>(node);
if (opName == "RNN")
buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
buildOperation<mlir::ONNXRNNOp>(node);
if (opName == "RandomNormal")
buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXRandomNormalOp>(node);
if (opName == "RandomNormalLike")
buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXRandomNormalLikeOp>(node);
if (opName == "RandomUniform")
buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXRandomUniformOp>(node);
if (opName == "RandomUniformLike")
buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXRandomUniformLikeOp>(node);
if (opName == "Range")
buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXRangeOp>(node);
if (opName == "Reciprocal")
buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReciprocalOp>(node);
if (opName == "ReduceL1")
buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReduceL1Op>(node);
if (opName == "ReduceL2")
buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReduceL2Op>(node);
if (opName == "ReduceLogSum")
buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReduceLogSumOp>(node);
if (opName == "ReduceLogSumExp")
buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReduceLogSumExpOp>(node);
if (opName == "ReduceMax")
buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReduceMaxOp>(node);
if (opName == "ReduceMean")
buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReduceMeanOp>(node);
if (opName == "ReduceMin")
buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReduceMinOp>(node);
if (opName == "ReduceProd")
buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReduceProdOp>(node);
if (opName == "ReduceSum")
buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReduceSumOp>(node);
if (opName == "ReduceSumSquare")
buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReduceSumSquareOp>(node);
if (opName == "Relu")
buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReluOp>(node);
if (opName == "Reshape")
ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
ImportNodeReshape(node);
if (opName == "Resize")
buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXResizeOp>(node);
if (opName == "ReverseSequence")
buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXReverseSequenceOp>(node);
if (opName == "RoiAlign")
buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXRoiAlignOp>(node);
if (opName == "Round")
buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXRoundOp>(node);
if (opName == "Scan")
buildOperation<mlir::ONNXScanOp>(node);
if (opName == "Scatter")
buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXScatterOp>(node);
if (opName == "ScatterElements")
buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXScatterElementsOp>(node);
if (opName == "ScatterND")
buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXScatterNDOp>(node);
if (opName == "Selu")
buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSeluOp>(node);
if (opName == "SequenceAt")
buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSequenceAtOp>(node);
if (opName == "SequenceConstruct")
buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSequenceConstructOp>(node);
if (opName == "SequenceEmpty")
buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSequenceEmptyOp>(node);
if (opName == "SequenceErase")
buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSequenceEraseOp>(node);
if (opName == "SequenceInsert")
buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSequenceInsertOp>(node);
if (opName == "SequenceLength")
buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSequenceLengthOp>(node);
if (opName == "Shape")
buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXShapeOp>(node);
if (opName == "Shrink")
buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXShrinkOp>(node);
if (opName == "Sigmoid")
buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSigmoidOp>(node);
if (opName == "Sign")
buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSignOp>(node);
if (opName == "Sin")
buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSinOp>(node);
if (opName == "Sinh")
buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSinhOp>(node);
if (opName == "Size")
buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSizeOp>(node);
if (opName == "Slice")
buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSliceOp>(node);
if (opName == "Softmax")
buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSoftmaxOp>(node);
if (opName == "Softplus")
buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSoftplusOp>(node);
if (opName == "Softsign")
buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSoftsignOp>(node);
if (opName == "SpaceToDepth")
buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSpaceToDepthOp>(node);
if (opName == "Split")
buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
buildOperation<mlir::ONNXSplitOp>(node);
if (opName == "SplitToSequence")
buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSplitToSequenceOp>(node);
if (opName == "Sqrt")
buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSqrtOp>(node);
if (opName == "Squeeze")
buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSqueezeOp>(node);
if (opName == "StringNormalizer")
buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXStringNormalizerOp>(node);
if (opName == "Sub")
buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSubOp>(node);
if (opName == "Sum")
buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXSumOp>(node);
if (opName == "Tan")
buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXTanOp>(node);
if (opName == "Tanh")
buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXTanhOp>(node);
if (opName == "TfIdfVectorizer")
buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXTfIdfVectorizerOp>(node);
if (opName == "ThresholdedRelu")
buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXThresholdedReluOp>(node);
if (opName == "Tile")
buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXTileOp>(node);
if (opName == "TopK")
buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
buildOperation<mlir::ONNXTopKOp>(node);
if (opName == "Transpose")
buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXTransposeOp>(node);
if (opName == "Unique")
buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
buildOperation<mlir::ONNXUniqueOp>(node);
if (opName == "Unsqueeze")
buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXUnsqueezeOp>(node);
if (opName == "Upsample")
buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXUpsampleOp>(node);
if (opName == "Where")
buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXWhereOp>(node);
if (opName == "Xor")
buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
buildOperation<mlir::ONNXXorOp>(node);


@@ -14,6 +14,17 @@ def MLONNXArrayFeatureExtractorOp:MLONNX_Op<"ArrayFeatureExtractor",
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {20};
}
}];
}
def MLONNXBinarizerOp:MLONNX_Op<"Binarizer",
@@ -22,9 +33,20 @@ def MLONNXBinarizerOp:MLONNX_Op<"Binarizer",
let description = [{
"Maps the values of the input tensor to either 0 or 1, element-wise, based on the outcome of a comparison against a threshold value."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
DefaultValuedAttr<F32Attr, "0.0">:$threshold);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let results = (outs AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {20};
}
}];
}
def MLONNXCastMapOp:MLONNX_Op<"CastMap",
@@ -40,6 +62,17 @@ def MLONNXCastMapOp:MLONNX_Op<"CastMap",
DefaultValuedAttr<StrAttr, "DENSE">:$map_form,
DefaultValuedAttr<I64Attr, "1">:$max_map);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
}
def MLONNXCategoryMapperOp:MLONNX_Op<"CategoryMapper",
@@ -61,6 +94,17 @@ def MLONNXCategoryMapperOp:MLONNX_Op<"CategoryMapper",
DefaultValuedAttr<I64Attr, "-1">:$default_int64,
DefaultValuedAttr<StrAttr, "_Unused">:$default_string);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
}
def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer",
@@ -84,6 +128,17 @@ def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer",
OptionalAttr<I64ArrayAttr>:$int64_vocabulary,
OptionalAttr<StrArrayAttr>:$string_vocabulary);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
}
def MLONNXFeatureVectorizerOp:MLONNX_Op<"FeatureVectorizer",
@@ -95,9 +150,20 @@ def MLONNXFeatureVectorizerOp:MLONNX_Op<"FeatureVectorizer",
" Inputs are copied to the output maintaining the order of the input arguments.<br>"
" All inputs must be integers or floats, while the output will be all floating point values."
}];
let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$X,
let arguments = (ins Variadic<AnyTypeOf<[TensorOf<[I32,I64,F32,F64]>, MemRefOf<[I32,I64,F32,F64]>]>>:$X,
OptionalAttr<I64ArrayAttr>:$inputdimensions);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return -1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
}
def MLONNXImputerOp:MLONNX_Op<"Imputer",
@@ -113,12 +179,23 @@ def MLONNXImputerOp:MLONNX_Op<"Imputer",
" which one depends on whether floats or integers are being processed.<br>"
" The imputed_value attribute length can be 1 element, or it can have one element per input feature.<br>In other words, if the input tensor has the shape [*,F], then the length of the attribute array may be 1 or F. If it is 1, then it is broadcast along the last dimension and applied to each feature."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<F32ArrayAttr>:$imputed_value_floats,
OptionalAttr<I64ArrayAttr>:$imputed_value_int64s,
DefaultValuedAttr<F32Attr, "0.0">:$replaced_value_float,
DefaultValuedAttr<I64Attr, "0">:$replaced_value_int64);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let results = (outs AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {20};
}
}];
}
def MLONNXLabelEncoderOp:MLONNX_Op<"LabelEncoder",
@@ -154,6 +231,17 @@ def MLONNXLabelEncoderOp:MLONNX_Op<"LabelEncoder",
OptionalAttr<I64ArrayAttr>:$values_int64s,
OptionalAttr<StrArrayAttr>:$values_strings);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
}
def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier",
@@ -162,7 +250,7 @@ def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier",
let description = [{
"Linear classifier"
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<I64ArrayAttr>:$classlabels_ints,
OptionalAttr<StrArrayAttr>:$classlabels_strings,
F32ArrayAttr:$coefficients,
@@ -171,6 +259,17 @@ def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier",
DefaultValuedAttr<StrAttr, "NONE">:$post_transform);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 2;
}
static std::vector<int> getTypeMap() {
return {-1,-1};
}
}];
}
def MLONNXLinearRegressorOp:MLONNX_Op<"LinearRegressor",
@@ -184,12 +283,23 @@ def MLONNXLinearRegressorOp:MLONNX_Op<"LinearRegressor",
" The coefficients array is of length n, and the coefficients for each target are contiguous."
" Intercepts are optional but if provided must match the number of targets."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<F32ArrayAttr>:$coefficients,
OptionalAttr<F32ArrayAttr>:$intercepts,
DefaultValuedAttr<StrAttr, "NONE">:$post_transform,
DefaultValuedAttr<I64Attr, "1">:$targets);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
}
def MLONNXNormalizerOp:MLONNX_Op<"Normalizer",
@@ -207,9 +317,20 @@ def MLONNXNormalizerOp:MLONNX_Op<"Normalizer",
" For batches, that is, [N,C] tensors, normalization is done along the C axis. In other words, each row"
" of the batch is normalized independently."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
DefaultValuedAttr<StrAttr, "MAX">:$norm);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
}
def MLONNXOneHotEncoderOp:MLONNX_Op<"OneHotEncoder",
@@ -230,6 +351,17 @@ def MLONNXOneHotEncoderOp:MLONNX_Op<"OneHotEncoder",
OptionalAttr<StrArrayAttr>:$cats_strings,
DefaultValuedAttr<I64Attr, "1">:$zeros);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
}
def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier",
@@ -238,7 +370,7 @@ def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier",
let description = [{
"Support Vector Machine classifier"
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<I64ArrayAttr>:$classlabels_ints,
OptionalAttr<StrArrayAttr>:$classlabels_strings,
OptionalAttr<F32ArrayAttr>:$coefficients,
@@ -252,6 +384,17 @@ def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier",
OptionalAttr<I64ArrayAttr>:$vectors_per_class);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 2;
}
static std::vector<int> getTypeMap() {
return {-1,-1};
}
}];
}
def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor",
@@ -260,7 +403,7 @@ def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor",
let description = [{
"Support Vector Machine regression prediction and one-class SVM anomaly detection."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<F32ArrayAttr>:$coefficients,
OptionalAttr<F32ArrayAttr>:$kernel_params,
DefaultValuedAttr<StrAttr, "LINEAR">:$kernel_type,
@@ -270,6 +413,17 @@ def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor",
OptionalAttr<F32ArrayAttr>:$rho,
OptionalAttr<F32ArrayAttr>:$support_vectors);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
}
def MLONNXScalerOp:MLONNX_Op<"Scaler",
@@ -278,10 +432,21 @@ def MLONNXScalerOp:MLONNX_Op<"Scaler",
let description = [{
"Rescale input data, for example to standardize features by removing the mean and scaling to unit variance."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<F32ArrayAttr>:$offset,
OptionalAttr<F32ArrayAttr>:$scale);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
}
def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier",
@@ -298,7 +463,7 @@ def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier",
" One and only one of classlabels_strings or classlabels_int64s"
" will be defined. The class_ids are indices into this list."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
OptionalAttr<F32ArrayAttr>:$base_values,
OptionalAttr<I64ArrayAttr>:$class_ids,
OptionalAttr<I64ArrayAttr>:$class_nodeids,
@@ -318,6 +483,17 @@ def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier",
DefaultValuedAttr<StrAttr, "NONE">:$post_transform);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 2;
}
static std::vector<int> getTypeMap() {
return {-1,-1};
}
}];
}
def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor",
@@ -335,7 +511,7 @@ def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor",
" All trees must have their node ids start at 0 and increment by 1.<br>"
" Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF"
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
DefaultValuedAttr<StrAttr, "SUM">:$aggregate_function,
OptionalAttr<F32ArrayAttr>:$base_values,
OptionalAttr<I64Attr>:$n_targets,
@@ -354,6 +530,17 @@ def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor",
OptionalAttr<I64ArrayAttr>:$target_treeids,
OptionalAttr<F32ArrayAttr>:$target_weights);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
}
def MLONNXZipMapOp:MLONNX_Op<"ZipMap",
@@ -369,5 +556,16 @@ def MLONNXZipMapOp:MLONNX_Op<"ZipMap",
OptionalAttr<I64ArrayAttr>:$classlabels_int64s,
OptionalAttr<StrArrayAttr>:$classlabels_strings);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {-1};
}
}];
}


@@ -38,7 +38,7 @@ def ONNX_Dialect : Dialect {
// * The mnemonic for the operation, or the name without the dialect prefix.
// * A list of traits for the operation.
class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
Op<ONNX_Dialect, mnemonic, traits>;
Op<ONNX_Dialect, mnemonic, traits> ;
//===----------------------------------------------------------------------===//
// ONNX Operations
@@ -112,6 +112,17 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
DefaultValuedAttr<I64Attr, "0">:$storage_order,
OptionalAttr<I64ArrayAttr>:$strides);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {0};
}
}];
}
def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
@@ -137,6 +148,17 @@ def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
DefaultValuedAttr<F32Attr, "0.9">:$momentum);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 5;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {0};
}
}];
}
def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
@@ -154,6 +176,17 @@ def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
DefaultValuedAttr<StrAttr, "constant">:$mode);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {0};
}
}];
}
def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad",
@@ -168,6 +201,17 @@ def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad",
I64ArrayAttr:$pads,
DefaultValuedAttr<StrAttr, "constant">:$mode);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {0};
}
}];
}
def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
@@ -186,6 +230,17 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
let builders = [OpBuilder<"OpBuilder &builder, OperationState &state, "
"Value data, ArrayAttr pads, "
"FloatAttr constant_value, StringAttr mode">];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {0};
}
}];
}
#endif // ONNX_OPS

File diff suppressed because it is too large


@@ -144,62 +144,62 @@ func @test_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*
// -----
func @test_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
%0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
func @test_and(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
%0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
"std.return"(%0) : (tensor<*xi1>) -> ()
// CHECK-LABEL: test_and
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi32>
// CHECK: return [[RES]] : memref<10x10xi32>
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
// CHECK: return [[RES]] : memref<10x10xi1>
}
// -----
func @test_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
func @test_or(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
"std.return"(%0) : (tensor<*xi1>) -> ()
// CHECK-LABEL: test_or
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi32>
// CHECK: return [[RES]] : memref<10x10xi32>
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
// CHECK: return [[RES]] : memref<10x10xi1>
}
// -----
func @test_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
func @test_xor(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
"std.return"(%0) : (tensor<*xi1>) -> ()
// CHECK-LABEL: test_xor
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi32>
// CHECK: return [[RES]] : memref<10x10xi32>
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
// CHECK: return [[RES]] : memref<10x10xi1>
}
// -----


@@ -158,24 +158,24 @@ func @test_sub_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tens
// -----
func @test_and_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
%0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
%1 = "onnx.And"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32>
"std.return"(%1) : (tensor<*xi32>) -> ()
func @test_and_and(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
%0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
%1 = "onnx.And"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1>
"std.return"(%1) : (tensor<*xi1>) -> ()
// CHECK-LABEL: test_and_and
/// First And
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32>
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
/// Second And
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
@@ -183,38 +183,38 @@ func @test_and_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tens
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<10x10xi32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32>
// CHECK: dealloc [[RES]] : memref<10x10xi1>
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1>
// CHECK: return [[RET_RES]] : memref<10x10xi32>
// CHECK: return [[RET_RES]] : memref<10x10xi1>
}
// -----
func @test_or_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
%1 = "onnx.Or"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32>
"std.return"(%1) : (tensor<*xi32>) -> ()
func @test_or_or(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
%1 = "onnx.Or"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1>
"std.return"(%1) : (tensor<*xi1>) -> ()
// CHECK-LABEL: test_or_or
/// First Or
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32>
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
/// Second Or
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
@@ -222,38 +222,38 @@ func @test_or_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<10x10xi32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32>
// CHECK: dealloc [[RES]] : memref<10x10xi1>
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1>
// CHECK: return [[RET_RES]] : memref<10x10xi32>
// CHECK: return [[RET_RES]] : memref<10x10xi1>
}
// -----
func @test_xor_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
%1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32>
"std.return"(%1) : (tensor<*xi32>) -> ()
func @test_xor_xor(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
%1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1>
"std.return"(%1) : (tensor<*xi1>) -> ()
// CHECK-LABEL: test_xor_xor
/// First Xor
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32>
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
/// Second Xor
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
@@ -261,16 +261,16 @@ func @test_xor_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tens
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32
// CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32>
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1
// CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<10x10xi32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32>
// CHECK: dealloc [[RES]] : memref<10x10xi1>
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1>
// CHECK: return [[RET_RES]] : memref<10x10xi32>
// CHECK: return [[RET_RES]] : memref<10x10xi1>
}
// -----


@@ -298,6 +298,16 @@ custom_definition_misc = dict([ ('Constant',
)])
onnx_types = (
'bool', 'int8', 'int16', 'int32', 'int64', 'unknown', 'float16',
'float', 'double', 'complex64', 'complex128'
)
tblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64',
'Complex<F32>', 'Complex<F64>'
)
MAX_NUM_TYPES=20
SNIPPETS = collect_snippets()
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
ONNX_ML = bool(args.domain == "ONNX_ML")
@@ -376,53 +386,55 @@ def tblgen_operand_type_to_cpp_type(op_type):
def np_type_to_tblgen_attr_type(tstr):
tfrom = np.array([
'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
'float', 'double'
])
tto = np.array(
['I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64'])
index = -1
for i in range(len(tfrom)):
if tfrom[i] in tstr:
for i in range(len(onnx_types)):
if onnx_types[i] in tstr:
index = i
break
if index == -1:
print("error", tstr)
return ''
return None
else:
return tto[i]
return tblgen_types[i]
def get_tblgen_type_index(type_str):
return tblgen_types.index(type_str)
#the possible data structures are tensor, map and seq(tensor())
#TOFIX: currently, only tensor structure is supported
def get_data_structure_element(allowed_type_str):
if allowed_type_str.startswith('tensor') :
element = allowed_type_str.replace('tensor(', '', 1).replace(')', '', 1)
return ('tensor', element)
else :
return (None, None)
def get_allowed_elem_types(schema, input):
allowed_types_str = None
return allowed_types_str
#allowed_types_str = None
# return allowed_types_str
# TODO: enable type constraints.
# if input.typeStr :
# tstr = input.typeStr
# else :
# return allwedTypeStr
# if schema.type_constraints:
# for type_constraint in schema.type_constraints:
# if type_constraint.type_param_str != tstr :
# continue
# allowedTypes = type_constraint.allowed_type_strs
# allowedTypeStr=''
# if (len(allowedTypes) > 0):
# t = convert_type(allowedTypes[0])
# if t == '' :
# return ''
# allowedTypeStr += t
# for allowedType in allowedTypes[1:]:
# t = convert_type(allowedType)
# if t == '' :
# return ''
# if not t in allowedTypeStr :
# allowedTypeStr += ', '+t
#
# return allowedTypeStr
#
# return allowedTypeStr
if input.typeStr :
tstr = input.typeStr
else :
return None
if schema.type_constraints:
for type_constraint in schema.type_constraints:
if type_constraint.type_param_str != tstr :
continue
allowed_type_list=[]
allowedTypes = type_constraint.allowed_type_strs
for allowedType in allowedTypes:
structure, element = get_data_structure_element(allowedType)
if structure == None or element == None:
return None
t = np_type_to_tblgen_attr_type(element)
if t == None :
return None
if t not in allowed_type_list :
allowed_type_list.append(t)
return allowed_type_list
return None
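The type-constraint handling above is split across three helpers; the mapping itself can be sketched standalone. The tuples below mirror `onnx_types`/`tblgen_types` from `gen_doc.py` (positions pair up), and the two functions approximate `get_data_structure_element` and `np_type_to_tblgen_attr_type`; this is an illustrative sketch, not the real module.

```python
# Positional mapping from ONNX element-type names to TableGen type names,
# mirroring onnx_types/tblgen_types in gen_doc.py.
onnx_types = ('bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
              'float', 'double', 'complex64', 'complex128')
tblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64',
                'Complex<F32>', 'Complex<F64>')

def parse_tensor(allowed_type_str):
    # Only tensor(...) is supported, as in get_data_structure_element;
    # map(...) and seq(...) yield None.
    if allowed_type_str.startswith('tensor('):
        return allowed_type_str[len('tensor('):].rstrip(')')
    return None

def to_tblgen(elem):
    # Substring match, as in np_type_to_tblgen_attr_type; 'float16' sits
    # before 'float' in the tuple so it wins for 'float16' inputs.
    for i, name in enumerate(onnx_types):
        if name in elem:
            return tblgen_types[i]
    return None

print(to_tblgen(parse_tensor('tensor(float)')))  # F32
```

So `tensor(float)` becomes the TableGen element type `F32`, which is how the operand tables in the dialect doc (e.g. "tensor of 32-bit float ... values") are produced.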
def inc_indent(indent=None):
@@ -436,7 +448,6 @@ def dec_indent(indent):
def join_args(args):
return ", ".join(args)
def get_operands_or_results(schema, is_input):
value_list = schema.inputs if is_input else schema.outputs
if not value_list:
@@ -456,8 +467,9 @@ def get_operands_or_results(schema, is_input):
if elem_types is None:
types = ["AnyMemRef", "AnyTensor"]
else:
elem_types_str = ','.join(elem_types)
types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
types = list(map(lambda x: x.format(elem_types), types))
types = list(map(lambda x: x.format(elem_types_str), types))
# If operand is promotable to an attribute, then it must be
# nullable in case it migrates to be an attribute.
@@ -545,6 +557,64 @@ def get_attrs(schema):
name_to_type[attr.name] = get_attr_type_optional(attr.type)
return name_to_type
def get_numberof_list(mylist):
expected_num = len(mylist)
for element in mylist :
if OpSchema.FormalParameterOption.Variadic == element.option:
expected_num = -1
return expected_num
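The convention in `get_numberof_list` is that any variadic formal parameter makes the operand/result count open-ended, encoded as -1. A minimal sketch, with plain strings standing in for `OpSchema.FormalParameterOption` values:

```python
def expected_num(options):
    # `options` is a list of per-parameter option strings (an assumption
    # standing in for OpSchema.FormalParameterOption); a single 'Variadic'
    # entry collapses the expected count to -1 (meaning "unchecked").
    num = len(options)
    for opt in options:
        if opt == 'Variadic':
            num = -1
    return num

print(expected_num(['Single', 'Single']))    # 2
print(expected_num(['Single', 'Variadic']))  # -1
```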
def get_output_type_mapping(schema):
mapping=[]
for output in schema.outputs :
# if only one type is allowed, just set that
allowed_elem_types = get_allowed_elem_types(schema, output)
if allowed_elem_types != None and len(allowed_elem_types) == 1 :
mapping.append(str(get_tblgen_type_index(allowed_elem_types[0])))
continue
#map the type string
if output.typeStr :
tstr = output.typeStr
found = False
for i, input in enumerate(schema.inputs):
if input.typeStr and input.typeStr == tstr:
mapping.append(str(i+MAX_NUM_TYPES))
found = True
break
if found:
continue
#unknown output type
mapping.append(str(-1))
return mapping
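`get_output_type_mapping` encodes each output's type as one integer: a value below `MAX_NUM_TYPES` is an index into `tblgen_types` (a fixed element type), a value of `MAX_NUM_TYPES + i` means "same type as operand i", and -1 means the type is not inferable. A sketch of a decoder for that convention (illustrative, not part of the generator):

```python
MAX_NUM_TYPES = 20  # mirrors the constant in gen_doc.py

def decode_type_map_entry(entry):
    # Decode one per-output entry of the generated getTypeMap() vector:
    #   -1                   -> output type unknown / not inferable
    #   0 .. MAX_NUM_TYPES-1 -> fixed element type, index into tblgen_types
    #   >= MAX_NUM_TYPES     -> same type as operand (entry - MAX_NUM_TYPES)
    if entry == -1:
        return ('unknown', None)
    if entry < MAX_NUM_TYPES:
        return ('fixed', entry)
    return ('same-as-operand', entry - MAX_NUM_TYPES)

print(decode_type_map_entry(20))  # ('same-as-operand', 0)
```

This is why most elementwise ops end up with a map of `[20]`: their single output simply follows the type of operand 0.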
def get_numberof_inout(s, indent, schema):
expected_num_operands = get_numberof_list(schema.inputs)
indent = inc_indent(indent)
s += indent + "static int getNumberOfOperands() {\n"
indent = inc_indent(indent)
s += indent + "return {};\n".format(expected_num_operands)
indent = dec_indent(indent)
s += indent + "}\n"
expected_num_results = get_numberof_list(schema.outputs)
s += indent + "static int getNumberOfResults() {\n"
indent = inc_indent(indent)
s += indent + "return {};\n".format(expected_num_results)
indent = dec_indent(indent)
s += indent + "}\n"
s += indent + "static std::vector<int> getTypeMap() {\n"
mapping = get_output_type_mapping(schema)
indent = inc_indent(indent)
s += indent + "return {" + ",".join(mapping) + "};\n"
indent = dec_indent(indent)
s += indent + "}\n"
return s
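`get_numberof_inout` emits three static C++ methods into the op's `extraClassDeclaration`. A condensed sketch of the string it builds, for a hypothetical op with two operands and one result whose type follows operand 0 (the real function threads `s` and `indent` through `inc_indent`/`dec_indent`):

```python
def emit_inout(num_operands, num_results, type_map, indent='  '):
    # Condensed sketch of get_numberof_inout's output: three static
    # methods giving the expected operand/result counts and the
    # per-output type map (see MAX_NUM_TYPES convention above -1 means
    # variadic/unknown counts).
    s = indent + 'static int getNumberOfOperands() {\n'
    s += indent * 2 + 'return {};\n'.format(num_operands)
    s += indent + '}\n'
    s += indent + 'static int getNumberOfResults() {\n'
    s += indent * 2 + 'return {};\n'.format(num_results)
    s += indent + '}\n'
    s += indent + 'static std::vector<int> getTypeMap() {\n'
    s += indent * 2 + 'return {' + ','.join(map(str, type_map)) + '};\n'
    s += indent + '}\n'
    return s

print(emit_inout(2, 1, [20]))
```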
def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
cpp_name_to_idx_literal = "{" + ", ".join([
@@ -552,15 +622,15 @@ def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
for name_to_idx in const_operands_name_to_idx
]) + "}"
s += indent + "let extraClassDeclaration = [{\n"
#s += indent + "let extraClassDeclaration = [{\n"
indent = inc_indent(indent)
s += indent + "std::map<std::string, size_t> promotableConstOperands() {\n"
indent = inc_indent(indent)
s += indent + "return {};\n".format(cpp_name_to_idx_literal)
indent = dec_indent(indent)
s += indent + "}\n"
indent = dec_indent(indent)
s += indent + "}];\n"
#indent = dec_indent(indent)
#s += indent + "}];\n"
return s
@@ -657,10 +727,20 @@ def gen_op_def(schema):
s += '\n' + indent + '];\n'
# generate extraClassDeclaration
s += indent + "let extraClassDeclaration = [{\n"
#indent = inc_indent(indent)
# generate input/output number
s = get_numberof_inout(s, indent, schema)
# generate PromotableConst
if schema.name in OpsWithPromotableConstOperands:
s = get_promotable_const_operands_func(
s, indent, OpsWithPromotableConstOperands[schema.name])
s += indent + '}];\n'
if ( schema.name in custom_definition_misc) :
s += custom_definition_misc[schema.name]
@@ -700,11 +780,13 @@ def gen_op_importer(schema, file):
# Special handlers currently require expected num operands/results to be specified.
# TODO: remove special handlers.
args = ["node"]
"""
if expected_num_operands != -1 or expected_num_results != -1 or "buildOperation" not in handler_func:
args.append(
"/* expected_num_operands = */ {}".format(expected_num_operands))
args.append(
'/* expected_num_results = */ {}'.format(expected_num_results))
"""
s += inc_indent(indent) + " {}({});\n".format(
handler_func, ", ".join(args))