Express some basic features of an Operation in TableGen file (#103)
* change operation definition * change importer * default type inference * file format * generate types for input/output * generate the mapping for operation output type * remove debug message for gen_doc.py * update the dialect doc * add support Complex * format * update document Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
df18efcb48
commit
6099efd91b
|
@ -35,13 +35,13 @@ ONNX Binarizer operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
| Result | Description |
|
| Result | Description |
|
||||||
| :----: | ----------- |
|
| :----: | ----------- |
|
||||||
`Y` | memref of any type values or tensor of any type values
|
`Y` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
|
||||||
|
|
||||||
### `mlonnx.CastMap` (MLONNXCastMapOp)
|
### `mlonnx.CastMap` (MLONNXCastMapOp)
|
||||||
|
|
||||||
|
@ -160,7 +160,7 @@ ONNX FeatureVectorizer operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tensor of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values or memref of 32-bit signless integer or 64-bit signless integer or 32-bit float or 64-bit float values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
@ -194,13 +194,13 @@ ONNX Imputer operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
| Result | Description |
|
| Result | Description |
|
||||||
| :----: | ----------- |
|
| :----: | ----------- |
|
||||||
`Y` | memref of any type values or tensor of any type values
|
`Y` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
|
||||||
|
|
||||||
### `mlonnx.LabelEncoder` (MLONNXLabelEncoderOp)
|
### `mlonnx.LabelEncoder` (MLONNXLabelEncoderOp)
|
||||||
|
|
||||||
|
@ -271,7 +271,7 @@ ONNX LinearClassifier operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
@ -304,7 +304,7 @@ ONNX LinearRegressor operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
@ -337,7 +337,7 @@ ONNX Normalizer operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
@ -404,7 +404,7 @@ ONNX SVMClassifier operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
@ -436,7 +436,7 @@ ONNX SVMRegressor operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
@ -461,7 +461,7 @@ ONNX Scaler operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
@ -509,7 +509,7 @@ ONNX TreeEnsembleClassifier operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
@ -559,7 +559,7 @@ ONNX TreeEnsembleRegressor operation
|
||||||
|
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`X` | memref of any type values or tensor of any type values
|
`X` | tensor of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values or memref of 32-bit float or 64-bit float or 64-bit signless integer or 32-bit signless integer values
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -197,6 +197,47 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MAX_TYPE 20
|
||||||
|
// itblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32',
|
||||||
|
// 'F64', 'Complex<F32>', 'Complex<F64>' )
|
||||||
|
mlir::Type buildTypeFromIndex(int index) {
|
||||||
|
switch (index) {
|
||||||
|
case 0:
|
||||||
|
return builder_.getI1Type();
|
||||||
|
case 1:
|
||||||
|
return builder_.getIntegerType(8);
|
||||||
|
case 2:
|
||||||
|
return builder_.getIntegerType(16);
|
||||||
|
case 3:
|
||||||
|
return builder_.getIntegerType(32);
|
||||||
|
case 4:
|
||||||
|
return builder_.getIntegerType(64);
|
||||||
|
case 5:
|
||||||
|
return builder_.getBF16Type();
|
||||||
|
case 6:
|
||||||
|
return builder_.getF16Type();
|
||||||
|
case 7:
|
||||||
|
return builder_.getF32Type();
|
||||||
|
case 8:
|
||||||
|
return builder_.getF64Type();
|
||||||
|
case 9: {
|
||||||
|
std::vector<mlir::Type> typeTuple(2);
|
||||||
|
typeTuple.push_back(builder_.getF32Type());
|
||||||
|
typeTuple.push_back(builder_.getF32Type());
|
||||||
|
return builder_.getTupleType(llvm::ArrayRef<mlir::Type>(typeTuple));
|
||||||
|
}
|
||||||
|
case 10: {
|
||||||
|
std::vector<mlir::Type> typeTuple(2);
|
||||||
|
typeTuple.push_back(builder_.getF64Type());
|
||||||
|
typeTuple.push_back(builder_.getF64Type());
|
||||||
|
return builder_.getTupleType(llvm::ArrayRef<mlir::Type>(typeTuple));
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
assert(false && "Unsupported type index encountered.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void buildOutputAndOperation(const onnx::NodeProto &node,
|
void buildOutputAndOperation(const onnx::NodeProto &node,
|
||||||
std::vector<mlir::Value> inputs, int expectedNumOperands,
|
std::vector<mlir::Value> inputs, int expectedNumOperands,
|
||||||
|
@ -217,13 +258,34 @@ private:
|
||||||
inputs.emplace_back(none_);
|
inputs.emplace_back(none_);
|
||||||
|
|
||||||
std::vector<mlir::Type> outputTypes;
|
std::vector<mlir::Type> outputTypes;
|
||||||
for (auto item : node.output()) {
|
|
||||||
|
// Use the type map to determine the data type of output.
|
||||||
|
std::vector<int> outputMap = T::getTypeMap();
|
||||||
|
for (auto i = 0; i < node.output().size(); i++) {
|
||||||
// Optional outputs using empty string.
|
// Optional outputs using empty string.
|
||||||
if (item.empty())
|
if (node.output()[i].empty()) {
|
||||||
outputTypes.emplace_back(builder_.getNoneType());
|
outputTypes.emplace_back(builder_.getNoneType());
|
||||||
else
|
} else {
|
||||||
outputTypes.push_back(
|
if (i < outputMap.size() && outputMap[i] >= MAX_TYPE) {
|
||||||
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
// Mapping gives a connection with an input.
|
||||||
|
mlir::Type inputType = inputs[outputMap[i] - MAX_TYPE].getType();
|
||||||
|
if (inputType.isa<mlir::TensorType>()) {
|
||||||
|
auto elementType =
|
||||||
|
inputType.cast<mlir::TensorType>().getElementType();
|
||||||
|
auto outType = mlir::UnrankedTensorType::get(elementType);
|
||||||
|
outputTypes.emplace_back(outType);
|
||||||
|
} else {
|
||||||
|
outputTypes.push_back(inputType);
|
||||||
|
}
|
||||||
|
} else if (i < outputMap.size() && outputMap[i] != -1) {
|
||||||
|
// Mapping gives a direct type.
|
||||||
|
auto elementType = buildTypeFromIndex(outputMap[i]);
|
||||||
|
auto outType = mlir::UnrankedTensorType::get(elementType);
|
||||||
|
outputTypes.emplace_back(outType);
|
||||||
|
} else {
|
||||||
|
outputTypes.emplace_back(builder_.getNoneType());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Trailing optional outputs.
|
// Trailing optional outputs.
|
||||||
if (!variadicOut)
|
if (!variadicOut)
|
||||||
|
@ -241,9 +303,10 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
|
void buildOperation(const onnx::NodeProto &node) {
|
||||||
int expectedNumResults = -1) {
|
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
|
int expectedNumOperands = T::getNumberOfOperands();
|
||||||
|
int expectedNumResults = T::getNumberOfResults();
|
||||||
for (const auto &item : node.input())
|
for (const auto &item : node.input())
|
||||||
if (initializedTensors.ContainKey(legalize_name(item))) {
|
if (initializedTensors.ContainKey(legalize_name(item))) {
|
||||||
inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
|
inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
|
||||||
|
@ -256,7 +319,9 @@ private:
|
||||||
node, inputs, expectedNumOperands, expectedNumResults);
|
node, inputs, expectedNumOperands, expectedNumResults);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {
|
void ImportNodeReshape(onnx::NodeProto node) {
|
||||||
|
int expectedNumOperands = mlir::ONNXReshapeOp::getNumberOfOperands();
|
||||||
|
int expectedNumResults = mlir::ONNXReshapeOp::getNumberOfResults();
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
std::string item;
|
std::string item;
|
||||||
for (int i = 0; i < node.input().size(); ++i) {
|
for (int i = 0; i < node.input().size(); ++i) {
|
||||||
|
@ -270,39 +335,40 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut);
|
buildOutputAndOperation<mlir::ONNXReshapeOp>(
|
||||||
|
node, inputs, expectedNumOperands, expectedNumResults);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Special handle for MaxPool operations.
|
* Special handle for MaxPool operations.
|
||||||
*/
|
*/
|
||||||
void ImportNodeMaxPool(onnx::NodeProto node, int nIn, int nOut) {
|
void ImportNodeMaxPool(onnx::NodeProto node) {
|
||||||
int nOuts = node.output().size();
|
int nOuts = node.output().size();
|
||||||
if (nOuts == 1) {
|
if (nOuts == 1) {
|
||||||
buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts);
|
buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node);
|
||||||
} else {
|
} else {
|
||||||
buildOperation<mlir::ONNXMaxPoolOp>(node, nIn, nOuts);
|
buildOperation<mlir::ONNXMaxPoolOp>(node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Special handle for BatchNormalization operations.
|
* Special handle for BatchNormalization operations.
|
||||||
*/
|
*/
|
||||||
void ImportNodeBatchNormalization(onnx::NodeProto node, int nIn, int nOut) {
|
void ImportNodeBatchNormalization(onnx::NodeProto node) {
|
||||||
int nOuts = node.output().size();
|
int nOuts = node.output().size();
|
||||||
if (nOuts == 1) {
|
if (nOuts == 1) {
|
||||||
// Test mode with one output.
|
// Test mode with one output.
|
||||||
buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn, nOuts);
|
buildOperation<mlir::ONNXBatchNormalizationTestModeOp>(node);
|
||||||
} else {
|
} else {
|
||||||
// Training mode with four trailing optional outputs. Not handled yet.
|
// Training mode with four trailing optional outputs. Not handled yet.
|
||||||
buildOperation<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
|
buildOperation<mlir::ONNXBatchNormalizationOp>(node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Special handle for Pad operations.
|
* Special handle for Pad operations.
|
||||||
*/
|
*/
|
||||||
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
|
void ImportNodePad(onnx::NodeProto node) {
|
||||||
|
|
||||||
int nOps = node.input().size();
|
int nOps = node.input().size();
|
||||||
if (nOps == 2) {
|
if (nOps == 2) {
|
||||||
|
@ -330,9 +396,11 @@ private:
|
||||||
}
|
}
|
||||||
inputs.push_back(constantResult);
|
inputs.push_back(constantResult);
|
||||||
|
|
||||||
|
int nIn = mlir::ONNXPadOp::getNumberOfOperands();
|
||||||
|
int nOut = mlir::ONNXPadOp::getNumberOfResults();
|
||||||
buildOutputAndOperation<mlir::ONNXPadOp>(node, inputs, nIn, nOut);
|
buildOutputAndOperation<mlir::ONNXPadOp>(node, inputs, nIn, nOut);
|
||||||
} else {
|
} else {
|
||||||
buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
|
buildOperation<mlir::ONNXPadOp>(node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,38 +5,38 @@
|
||||||
//********************************************************
|
//********************************************************
|
||||||
|
|
||||||
if (opName == "ArrayFeatureExtractor")
|
if (opName == "ArrayFeatureExtractor")
|
||||||
return buildOperation<mlir::MLONNXArrayFeatureExtractorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXArrayFeatureExtractorOp>(node);
|
||||||
if (opName == "Binarizer")
|
if (opName == "Binarizer")
|
||||||
return buildOperation<mlir::MLONNXBinarizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXBinarizerOp>(node);
|
||||||
if (opName == "CastMap")
|
if (opName == "CastMap")
|
||||||
return buildOperation<mlir::MLONNXCastMapOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXCastMapOp>(node);
|
||||||
if (opName == "CategoryMapper")
|
if (opName == "CategoryMapper")
|
||||||
return buildOperation<mlir::MLONNXCategoryMapperOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXCategoryMapperOp>(node);
|
||||||
if (opName == "DictVectorizer")
|
if (opName == "DictVectorizer")
|
||||||
return buildOperation<mlir::MLONNXDictVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXDictVectorizerOp>(node);
|
||||||
if (opName == "FeatureVectorizer")
|
if (opName == "FeatureVectorizer")
|
||||||
return buildOperation<mlir::MLONNXFeatureVectorizerOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXFeatureVectorizerOp>(node);
|
||||||
if (opName == "Imputer")
|
if (opName == "Imputer")
|
||||||
return buildOperation<mlir::MLONNXImputerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXImputerOp>(node);
|
||||||
if (opName == "LabelEncoder")
|
if (opName == "LabelEncoder")
|
||||||
return buildOperation<mlir::MLONNXLabelEncoderOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXLabelEncoderOp>(node);
|
||||||
if (opName == "LinearClassifier")
|
if (opName == "LinearClassifier")
|
||||||
return buildOperation<mlir::MLONNXLinearClassifierOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
|
buildOperation<mlir::MLONNXLinearClassifierOp>(node);
|
||||||
if (opName == "LinearRegressor")
|
if (opName == "LinearRegressor")
|
||||||
return buildOperation<mlir::MLONNXLinearRegressorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXLinearRegressorOp>(node);
|
||||||
if (opName == "Normalizer")
|
if (opName == "Normalizer")
|
||||||
return buildOperation<mlir::MLONNXNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXNormalizerOp>(node);
|
||||||
if (opName == "OneHotEncoder")
|
if (opName == "OneHotEncoder")
|
||||||
return buildOperation<mlir::MLONNXOneHotEncoderOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXOneHotEncoderOp>(node);
|
||||||
if (opName == "SVMClassifier")
|
if (opName == "SVMClassifier")
|
||||||
return buildOperation<mlir::MLONNXSVMClassifierOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
|
buildOperation<mlir::MLONNXSVMClassifierOp>(node);
|
||||||
if (opName == "SVMRegressor")
|
if (opName == "SVMRegressor")
|
||||||
return buildOperation<mlir::MLONNXSVMRegressorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXSVMRegressorOp>(node);
|
||||||
if (opName == "Scaler")
|
if (opName == "Scaler")
|
||||||
return buildOperation<mlir::MLONNXScalerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXScalerOp>(node);
|
||||||
if (opName == "TreeEnsembleClassifier")
|
if (opName == "TreeEnsembleClassifier")
|
||||||
return buildOperation<mlir::MLONNXTreeEnsembleClassifierOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
|
buildOperation<mlir::MLONNXTreeEnsembleClassifierOp>(node);
|
||||||
if (opName == "TreeEnsembleRegressor")
|
if (opName == "TreeEnsembleRegressor")
|
||||||
return buildOperation<mlir::MLONNXTreeEnsembleRegressorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXTreeEnsembleRegressorOp>(node);
|
||||||
if (opName == "ZipMap")
|
if (opName == "ZipMap")
|
||||||
return buildOperation<mlir::MLONNXZipMapOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::MLONNXZipMapOp>(node);
|
||||||
|
|
|
@ -5,314 +5,314 @@
|
||||||
//********************************************************
|
//********************************************************
|
||||||
|
|
||||||
if (opName == "Abs")
|
if (opName == "Abs")
|
||||||
buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAbsOp>(node);
|
||||||
if (opName == "Acos")
|
if (opName == "Acos")
|
||||||
buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAcosOp>(node);
|
||||||
if (opName == "Acosh")
|
if (opName == "Acosh")
|
||||||
buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAcoshOp>(node);
|
||||||
if (opName == "Add")
|
if (opName == "Add")
|
||||||
buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAddOp>(node);
|
||||||
if (opName == "And")
|
if (opName == "And")
|
||||||
buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAndOp>(node);
|
||||||
if (opName == "ArgMax")
|
if (opName == "ArgMax")
|
||||||
buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXArgMaxOp>(node);
|
||||||
if (opName == "ArgMin")
|
if (opName == "ArgMin")
|
||||||
buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXArgMinOp>(node);
|
||||||
if (opName == "Asin")
|
if (opName == "Asin")
|
||||||
buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAsinOp>(node);
|
||||||
if (opName == "Asinh")
|
if (opName == "Asinh")
|
||||||
buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAsinhOp>(node);
|
||||||
if (opName == "Atan")
|
if (opName == "Atan")
|
||||||
buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAtanOp>(node);
|
||||||
if (opName == "Atanh")
|
if (opName == "Atanh")
|
||||||
buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAtanhOp>(node);
|
||||||
if (opName == "AveragePool")
|
if (opName == "AveragePool")
|
||||||
buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAveragePoolOp>(node);
|
||||||
if (opName == "BatchNormalization")
|
if (opName == "BatchNormalization")
|
||||||
ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
|
ImportNodeBatchNormalization(node);
|
||||||
if (opName == "BitShift")
|
if (opName == "BitShift")
|
||||||
buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXBitShiftOp>(node);
|
||||||
if (opName == "Cast")
|
if (opName == "Cast")
|
||||||
buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXCastOp>(node);
|
||||||
if (opName == "Ceil")
|
if (opName == "Ceil")
|
||||||
buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXCeilOp>(node);
|
||||||
if (opName == "Clip")
|
if (opName == "Clip")
|
||||||
buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXClipOp>(node);
|
||||||
if (opName == "Compress")
|
if (opName == "Compress")
|
||||||
buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXCompressOp>(node);
|
||||||
if (opName == "Concat")
|
if (opName == "Concat")
|
||||||
buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConcatOp>(node);
|
||||||
if (opName == "ConcatFromSequence")
|
if (opName == "ConcatFromSequence")
|
||||||
buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConcatFromSequenceOp>(node);
|
||||||
if (opName == "Constant")
|
if (opName == "Constant")
|
||||||
buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConstantOp>(node);
|
||||||
if (opName == "ConstantOfShape")
|
if (opName == "ConstantOfShape")
|
||||||
buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConstantOfShapeOp>(node);
|
||||||
if (opName == "Conv")
|
if (opName == "Conv")
|
||||||
buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConvOp>(node);
|
||||||
if (opName == "ConvInteger")
|
if (opName == "ConvInteger")
|
||||||
buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConvIntegerOp>(node);
|
||||||
if (opName == "ConvTranspose")
|
if (opName == "ConvTranspose")
|
||||||
buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConvTransposeOp>(node);
|
||||||
if (opName == "Cos")
|
if (opName == "Cos")
|
||||||
buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXCosOp>(node);
|
||||||
if (opName == "Cosh")
|
if (opName == "Cosh")
|
||||||
buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXCoshOp>(node);
|
||||||
if (opName == "CumSum")
|
if (opName == "CumSum")
|
||||||
buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXCumSumOp>(node);
|
||||||
if (opName == "DepthToSpace")
|
if (opName == "DepthToSpace")
|
||||||
buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXDepthToSpaceOp>(node);
|
||||||
if (opName == "DequantizeLinear")
|
if (opName == "DequantizeLinear")
|
||||||
buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXDequantizeLinearOp>(node);
|
||||||
if (opName == "Det")
|
if (opName == "Det")
|
||||||
buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXDetOp>(node);
|
||||||
if (opName == "Div")
|
if (opName == "Div")
|
||||||
buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXDivOp>(node);
|
||||||
if (opName == "Dropout")
|
if (opName == "Dropout")
|
||||||
buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
|
buildOperation<mlir::ONNXDropoutOp>(node);
|
||||||
if (opName == "DynamicQuantizeLinear")
|
if (opName == "DynamicQuantizeLinear")
|
||||||
buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
|
buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node);
|
||||||
if (opName == "Elu")
|
if (opName == "Elu")
|
||||||
buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXEluOp>(node);
|
||||||
if (opName == "Equal")
|
if (opName == "Equal")
|
||||||
buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXEqualOp>(node);
|
||||||
if (opName == "Erf")
|
if (opName == "Erf")
|
||||||
buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXErfOp>(node);
|
||||||
if (opName == "Exp")
|
if (opName == "Exp")
|
||||||
buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXExpOp>(node);
|
||||||
if (opName == "Expand")
|
if (opName == "Expand")
|
||||||
buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXExpandOp>(node);
|
||||||
if (opName == "EyeLike")
|
if (opName == "EyeLike")
|
||||||
buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXEyeLikeOp>(node);
|
||||||
if (opName == "Flatten")
|
if (opName == "Flatten")
|
||||||
buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXFlattenOp>(node);
|
||||||
if (opName == "Floor")
|
if (opName == "Floor")
|
||||||
buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXFloorOp>(node);
|
||||||
if (opName == "GRU")
|
if (opName == "GRU")
|
||||||
buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
|
buildOperation<mlir::ONNXGRUOp>(node);
|
||||||
if (opName == "Gather")
|
if (opName == "Gather")
|
||||||
buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGatherOp>(node);
|
||||||
if (opName == "GatherElements")
|
if (opName == "GatherElements")
|
||||||
buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGatherElementsOp>(node);
|
||||||
if (opName == "GatherND")
|
if (opName == "GatherND")
|
||||||
buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGatherNDOp>(node);
|
||||||
if (opName == "Gemm")
|
if (opName == "Gemm")
|
||||||
buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGemmOp>(node);
|
||||||
if (opName == "GlobalAveragePool")
|
if (opName == "GlobalAveragePool")
|
||||||
buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGlobalAveragePoolOp>(node);
|
||||||
if (opName == "GlobalLpPool")
|
if (opName == "GlobalLpPool")
|
||||||
buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGlobalLpPoolOp>(node);
|
||||||
if (opName == "GlobalMaxPool")
|
if (opName == "GlobalMaxPool")
|
||||||
buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGlobalMaxPoolOp>(node);
|
||||||
if (opName == "Greater")
|
if (opName == "Greater")
|
||||||
buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGreaterOp>(node);
|
||||||
if (opName == "HardSigmoid")
|
if (opName == "HardSigmoid")
|
||||||
buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXHardSigmoidOp>(node);
|
||||||
if (opName == "Hardmax")
|
if (opName == "Hardmax")
|
||||||
buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXHardmaxOp>(node);
|
||||||
if (opName == "Identity")
|
if (opName == "Identity")
|
||||||
buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXIdentityOp>(node);
|
||||||
if (opName == "If")
|
if (opName == "If")
|
||||||
buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
|
buildOperation<mlir::ONNXIfOp>(node);
|
||||||
if (opName == "InstanceNormalization")
|
if (opName == "InstanceNormalization")
|
||||||
buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXInstanceNormalizationOp>(node);
|
||||||
if (opName == "IsInf")
|
if (opName == "IsInf")
|
||||||
buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXIsInfOp>(node);
|
||||||
if (opName == "IsNaN")
|
if (opName == "IsNaN")
|
||||||
buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXIsNaNOp>(node);
|
||||||
if (opName == "LRN")
|
if (opName == "LRN")
|
||||||
buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLRNOp>(node);
|
||||||
if (opName == "LSTM")
|
if (opName == "LSTM")
|
||||||
buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
|
buildOperation<mlir::ONNXLSTMOp>(node);
|
||||||
if (opName == "LeakyRelu")
|
if (opName == "LeakyRelu")
|
||||||
buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLeakyReluOp>(node);
|
||||||
if (opName == "Less")
|
if (opName == "Less")
|
||||||
buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLessOp>(node);
|
||||||
if (opName == "Log")
|
if (opName == "Log")
|
||||||
buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLogOp>(node);
|
||||||
if (opName == "LogSoftmax")
|
if (opName == "LogSoftmax")
|
||||||
buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLogSoftmaxOp>(node);
|
||||||
if (opName == "Loop")
|
if (opName == "Loop")
|
||||||
buildOperation<mlir::ONNXLoopOp>(node);
|
buildOperation<mlir::ONNXLoopOp>(node);
|
||||||
if (opName == "LpNormalization")
|
if (opName == "LpNormalization")
|
||||||
buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLpNormalizationOp>(node);
|
||||||
if (opName == "LpPool")
|
if (opName == "LpPool")
|
||||||
buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLpPoolOp>(node);
|
||||||
if (opName == "MatMul")
|
if (opName == "MatMul")
|
||||||
buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMatMulOp>(node);
|
||||||
if (opName == "MatMulInteger")
|
if (opName == "MatMulInteger")
|
||||||
buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMatMulIntegerOp>(node);
|
||||||
if (opName == "Max")
|
if (opName == "Max")
|
||||||
buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMaxOp>(node);
|
||||||
if (opName == "MaxPool")
|
if (opName == "MaxPool")
|
||||||
ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
|
ImportNodeMaxPool(node);
|
||||||
if (opName == "MaxRoiPool")
|
if (opName == "MaxRoiPool")
|
||||||
buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMaxRoiPoolOp>(node);
|
||||||
if (opName == "MaxUnpool")
|
if (opName == "MaxUnpool")
|
||||||
buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMaxUnpoolOp>(node);
|
||||||
if (opName == "Mean")
|
if (opName == "Mean")
|
||||||
buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMeanOp>(node);
|
||||||
if (opName == "MeanVarianceNormalization")
|
if (opName == "MeanVarianceNormalization")
|
||||||
buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node);
|
||||||
if (opName == "Min")
|
if (opName == "Min")
|
||||||
buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMinOp>(node);
|
||||||
if (opName == "Mod")
|
if (opName == "Mod")
|
||||||
buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXModOp>(node);
|
||||||
if (opName == "Mul")
|
if (opName == "Mul")
|
||||||
buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMulOp>(node);
|
||||||
if (opName == "Multinomial")
|
if (opName == "Multinomial")
|
||||||
buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMultinomialOp>(node);
|
||||||
if (opName == "Neg")
|
if (opName == "Neg")
|
||||||
buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXNegOp>(node);
|
||||||
if (opName == "NonMaxSuppression")
|
if (opName == "NonMaxSuppression")
|
||||||
buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXNonMaxSuppressionOp>(node);
|
||||||
if (opName == "NonZero")
|
if (opName == "NonZero")
|
||||||
buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXNonZeroOp>(node);
|
||||||
if (opName == "Not")
|
if (opName == "Not")
|
||||||
buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXNotOp>(node);
|
||||||
if (opName == "OneHot")
|
if (opName == "OneHot")
|
||||||
buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXOneHotOp>(node);
|
||||||
if (opName == "Or")
|
if (opName == "Or")
|
||||||
buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXOrOp>(node);
|
||||||
if (opName == "PRelu")
|
if (opName == "PRelu")
|
||||||
buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXPReluOp>(node);
|
||||||
if (opName == "Pad")
|
if (opName == "Pad")
|
||||||
ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
ImportNodePad(node);
|
||||||
if (opName == "Pow")
|
if (opName == "Pow")
|
||||||
buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXPowOp>(node);
|
||||||
if (opName == "QLinearConv")
|
if (opName == "QLinearConv")
|
||||||
buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXQLinearConvOp>(node);
|
||||||
if (opName == "QLinearMatMul")
|
if (opName == "QLinearMatMul")
|
||||||
buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXQLinearMatMulOp>(node);
|
||||||
if (opName == "QuantizeLinear")
|
if (opName == "QuantizeLinear")
|
||||||
buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXQuantizeLinearOp>(node);
|
||||||
if (opName == "RNN")
|
if (opName == "RNN")
|
||||||
buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
|
buildOperation<mlir::ONNXRNNOp>(node);
|
||||||
if (opName == "RandomNormal")
|
if (opName == "RandomNormal")
|
||||||
buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRandomNormalOp>(node);
|
||||||
if (opName == "RandomNormalLike")
|
if (opName == "RandomNormalLike")
|
||||||
buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRandomNormalLikeOp>(node);
|
||||||
if (opName == "RandomUniform")
|
if (opName == "RandomUniform")
|
||||||
buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRandomUniformOp>(node);
|
||||||
if (opName == "RandomUniformLike")
|
if (opName == "RandomUniformLike")
|
||||||
buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRandomUniformLikeOp>(node);
|
||||||
if (opName == "Range")
|
if (opName == "Range")
|
||||||
buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRangeOp>(node);
|
||||||
if (opName == "Reciprocal")
|
if (opName == "Reciprocal")
|
||||||
buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReciprocalOp>(node);
|
||||||
if (opName == "ReduceL1")
|
if (opName == "ReduceL1")
|
||||||
buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceL1Op>(node);
|
||||||
if (opName == "ReduceL2")
|
if (opName == "ReduceL2")
|
||||||
buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceL2Op>(node);
|
||||||
if (opName == "ReduceLogSum")
|
if (opName == "ReduceLogSum")
|
||||||
buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceLogSumOp>(node);
|
||||||
if (opName == "ReduceLogSumExp")
|
if (opName == "ReduceLogSumExp")
|
||||||
buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceLogSumExpOp>(node);
|
||||||
if (opName == "ReduceMax")
|
if (opName == "ReduceMax")
|
||||||
buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceMaxOp>(node);
|
||||||
if (opName == "ReduceMean")
|
if (opName == "ReduceMean")
|
||||||
buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceMeanOp>(node);
|
||||||
if (opName == "ReduceMin")
|
if (opName == "ReduceMin")
|
||||||
buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceMinOp>(node);
|
||||||
if (opName == "ReduceProd")
|
if (opName == "ReduceProd")
|
||||||
buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceProdOp>(node);
|
||||||
if (opName == "ReduceSum")
|
if (opName == "ReduceSum")
|
||||||
buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceSumOp>(node);
|
||||||
if (opName == "ReduceSumSquare")
|
if (opName == "ReduceSumSquare")
|
||||||
buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceSumSquareOp>(node);
|
||||||
if (opName == "Relu")
|
if (opName == "Relu")
|
||||||
buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReluOp>(node);
|
||||||
if (opName == "Reshape")
|
if (opName == "Reshape")
|
||||||
ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
ImportNodeReshape(node);
|
||||||
if (opName == "Resize")
|
if (opName == "Resize")
|
||||||
buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXResizeOp>(node);
|
||||||
if (opName == "ReverseSequence")
|
if (opName == "ReverseSequence")
|
||||||
buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReverseSequenceOp>(node);
|
||||||
if (opName == "RoiAlign")
|
if (opName == "RoiAlign")
|
||||||
buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRoiAlignOp>(node);
|
||||||
if (opName == "Round")
|
if (opName == "Round")
|
||||||
buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRoundOp>(node);
|
||||||
if (opName == "Scan")
|
if (opName == "Scan")
|
||||||
buildOperation<mlir::ONNXScanOp>(node);
|
buildOperation<mlir::ONNXScanOp>(node);
|
||||||
if (opName == "Scatter")
|
if (opName == "Scatter")
|
||||||
buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXScatterOp>(node);
|
||||||
if (opName == "ScatterElements")
|
if (opName == "ScatterElements")
|
||||||
buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXScatterElementsOp>(node);
|
||||||
if (opName == "ScatterND")
|
if (opName == "ScatterND")
|
||||||
buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXScatterNDOp>(node);
|
||||||
if (opName == "Selu")
|
if (opName == "Selu")
|
||||||
buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSeluOp>(node);
|
||||||
if (opName == "SequenceAt")
|
if (opName == "SequenceAt")
|
||||||
buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSequenceAtOp>(node);
|
||||||
if (opName == "SequenceConstruct")
|
if (opName == "SequenceConstruct")
|
||||||
buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSequenceConstructOp>(node);
|
||||||
if (opName == "SequenceEmpty")
|
if (opName == "SequenceEmpty")
|
||||||
buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSequenceEmptyOp>(node);
|
||||||
if (opName == "SequenceErase")
|
if (opName == "SequenceErase")
|
||||||
buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSequenceEraseOp>(node);
|
||||||
if (opName == "SequenceInsert")
|
if (opName == "SequenceInsert")
|
||||||
buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSequenceInsertOp>(node);
|
||||||
if (opName == "SequenceLength")
|
if (opName == "SequenceLength")
|
||||||
buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSequenceLengthOp>(node);
|
||||||
if (opName == "Shape")
|
if (opName == "Shape")
|
||||||
buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXShapeOp>(node);
|
||||||
if (opName == "Shrink")
|
if (opName == "Shrink")
|
||||||
buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXShrinkOp>(node);
|
||||||
if (opName == "Sigmoid")
|
if (opName == "Sigmoid")
|
||||||
buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSigmoidOp>(node);
|
||||||
if (opName == "Sign")
|
if (opName == "Sign")
|
||||||
buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSignOp>(node);
|
||||||
if (opName == "Sin")
|
if (opName == "Sin")
|
||||||
buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSinOp>(node);
|
||||||
if (opName == "Sinh")
|
if (opName == "Sinh")
|
||||||
buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSinhOp>(node);
|
||||||
if (opName == "Size")
|
if (opName == "Size")
|
||||||
buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSizeOp>(node);
|
||||||
if (opName == "Slice")
|
if (opName == "Slice")
|
||||||
buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSliceOp>(node);
|
||||||
if (opName == "Softmax")
|
if (opName == "Softmax")
|
||||||
buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSoftmaxOp>(node);
|
||||||
if (opName == "Softplus")
|
if (opName == "Softplus")
|
||||||
buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSoftplusOp>(node);
|
||||||
if (opName == "Softsign")
|
if (opName == "Softsign")
|
||||||
buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSoftsignOp>(node);
|
||||||
if (opName == "SpaceToDepth")
|
if (opName == "SpaceToDepth")
|
||||||
buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSpaceToDepthOp>(node);
|
||||||
if (opName == "Split")
|
if (opName == "Split")
|
||||||
buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
|
buildOperation<mlir::ONNXSplitOp>(node);
|
||||||
if (opName == "SplitToSequence")
|
if (opName == "SplitToSequence")
|
||||||
buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSplitToSequenceOp>(node);
|
||||||
if (opName == "Sqrt")
|
if (opName == "Sqrt")
|
||||||
buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSqrtOp>(node);
|
||||||
if (opName == "Squeeze")
|
if (opName == "Squeeze")
|
||||||
buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSqueezeOp>(node);
|
||||||
if (opName == "StringNormalizer")
|
if (opName == "StringNormalizer")
|
||||||
buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXStringNormalizerOp>(node);
|
||||||
if (opName == "Sub")
|
if (opName == "Sub")
|
||||||
buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSubOp>(node);
|
||||||
if (opName == "Sum")
|
if (opName == "Sum")
|
||||||
buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSumOp>(node);
|
||||||
if (opName == "Tan")
|
if (opName == "Tan")
|
||||||
buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXTanOp>(node);
|
||||||
if (opName == "Tanh")
|
if (opName == "Tanh")
|
||||||
buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXTanhOp>(node);
|
||||||
if (opName == "TfIdfVectorizer")
|
if (opName == "TfIdfVectorizer")
|
||||||
buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXTfIdfVectorizerOp>(node);
|
||||||
if (opName == "ThresholdedRelu")
|
if (opName == "ThresholdedRelu")
|
||||||
buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXThresholdedReluOp>(node);
|
||||||
if (opName == "Tile")
|
if (opName == "Tile")
|
||||||
buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXTileOp>(node);
|
||||||
if (opName == "TopK")
|
if (opName == "TopK")
|
||||||
buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
|
buildOperation<mlir::ONNXTopKOp>(node);
|
||||||
if (opName == "Transpose")
|
if (opName == "Transpose")
|
||||||
buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXTransposeOp>(node);
|
||||||
if (opName == "Unique")
|
if (opName == "Unique")
|
||||||
buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
|
buildOperation<mlir::ONNXUniqueOp>(node);
|
||||||
if (opName == "Unsqueeze")
|
if (opName == "Unsqueeze")
|
||||||
buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXUnsqueezeOp>(node);
|
||||||
if (opName == "Upsample")
|
if (opName == "Upsample")
|
||||||
buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXUpsampleOp>(node);
|
||||||
if (opName == "Where")
|
if (opName == "Where")
|
||||||
buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXWhereOp>(node);
|
||||||
if (opName == "Xor")
|
if (opName == "Xor")
|
||||||
buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXXorOp>(node);
|
||||||
|
|
|
@ -14,6 +14,17 @@ def MLONNXArrayFeatureExtractorOp:MLONNX_Op<"ArrayFeatureExtractor",
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {20};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXBinarizerOp:MLONNX_Op<"Binarizer",
|
def MLONNXBinarizerOp:MLONNX_Op<"Binarizer",
|
||||||
|
@ -22,9 +33,20 @@ def MLONNXBinarizerOp:MLONNX_Op<"Binarizer",
|
||||||
let description = [{
|
let description = [{
|
||||||
"Maps the values of the input tensor to either 0 or 1, element-wise, based on the outcome of a comparison against a threshold value."
|
"Maps the values of the input tensor to either 0 or 1, element-wise, based on the outcome of a comparison against a threshold value."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
|
||||||
DefaultValuedAttr<F32Attr, "0.0">:$threshold);
|
DefaultValuedAttr<F32Attr, "0.0">:$threshold);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {20};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXCastMapOp:MLONNX_Op<"CastMap",
|
def MLONNXCastMapOp:MLONNX_Op<"CastMap",
|
||||||
|
@ -40,6 +62,17 @@ def MLONNXCastMapOp:MLONNX_Op<"CastMap",
|
||||||
DefaultValuedAttr<StrAttr, "DENSE">:$map_form,
|
DefaultValuedAttr<StrAttr, "DENSE">:$map_form,
|
||||||
DefaultValuedAttr<I64Attr, "1">:$max_map);
|
DefaultValuedAttr<I64Attr, "1">:$max_map);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXCategoryMapperOp:MLONNX_Op<"CategoryMapper",
|
def MLONNXCategoryMapperOp:MLONNX_Op<"CategoryMapper",
|
||||||
|
@ -61,6 +94,17 @@ def MLONNXCategoryMapperOp:MLONNX_Op<"CategoryMapper",
|
||||||
DefaultValuedAttr<I64Attr, "-1">:$default_int64,
|
DefaultValuedAttr<I64Attr, "-1">:$default_int64,
|
||||||
DefaultValuedAttr<StrAttr, "_Unused">:$default_string);
|
DefaultValuedAttr<StrAttr, "_Unused">:$default_string);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer",
|
def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer",
|
||||||
|
@ -84,6 +128,17 @@ def MLONNXDictVectorizerOp:MLONNX_Op<"DictVectorizer",
|
||||||
OptionalAttr<I64ArrayAttr>:$int64_vocabulary,
|
OptionalAttr<I64ArrayAttr>:$int64_vocabulary,
|
||||||
OptionalAttr<StrArrayAttr>:$string_vocabulary);
|
OptionalAttr<StrArrayAttr>:$string_vocabulary);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXFeatureVectorizerOp:MLONNX_Op<"FeatureVectorizer",
|
def MLONNXFeatureVectorizerOp:MLONNX_Op<"FeatureVectorizer",
|
||||||
|
@ -95,9 +150,20 @@ def MLONNXFeatureVectorizerOp:MLONNX_Op<"FeatureVectorizer",
|
||||||
" Inputs are copied to the output maintaining the order of the input arguments.<br>"
|
" Inputs are copied to the output maintaining the order of the input arguments.<br>"
|
||||||
" All inputs must be integers or floats, while the output will be all floating point values."
|
" All inputs must be integers or floats, while the output will be all floating point values."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$X,
|
let arguments = (ins Variadic<AnyTypeOf<[TensorOf<[I32,I64,F32,F64]>, MemRefOf<[I32,I64,F32,F64]>]>>:$X,
|
||||||
OptionalAttr<I64ArrayAttr>:$inputdimensions);
|
OptionalAttr<I64ArrayAttr>:$inputdimensions);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXImputerOp:MLONNX_Op<"Imputer",
|
def MLONNXImputerOp:MLONNX_Op<"Imputer",
|
||||||
|
@ -113,12 +179,23 @@ def MLONNXImputerOp:MLONNX_Op<"Imputer",
|
||||||
" which one depends on whether floats or integers are being processed.<br>"
|
" which one depends on whether floats or integers are being processed.<br>"
|
||||||
" The imputed_value attribute length can be 1 element, or it can have one element per input feature.<br>In other words, if the input tensor has the shape [*,F], then the length of the attribute array may be 1 or F. If it is 1, then it is broadcast along the last dimension and applied to each feature."
|
" The imputed_value attribute length can be 1 element, or it can have one element per input feature.<br>In other words, if the input tensor has the shape [*,F], then the length of the attribute array may be 1 or F. If it is 1, then it is broadcast along the last dimension and applied to each feature."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
|
||||||
OptionalAttr<F32ArrayAttr>:$imputed_value_floats,
|
OptionalAttr<F32ArrayAttr>:$imputed_value_floats,
|
||||||
OptionalAttr<I64ArrayAttr>:$imputed_value_int64s,
|
OptionalAttr<I64ArrayAttr>:$imputed_value_int64s,
|
||||||
DefaultValuedAttr<F32Attr, "0.0">:$replaced_value_float,
|
DefaultValuedAttr<F32Attr, "0.0">:$replaced_value_float,
|
||||||
DefaultValuedAttr<I64Attr, "0">:$replaced_value_int64);
|
DefaultValuedAttr<I64Attr, "0">:$replaced_value_int64);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {20};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXLabelEncoderOp:MLONNX_Op<"LabelEncoder",
|
def MLONNXLabelEncoderOp:MLONNX_Op<"LabelEncoder",
|
||||||
|
@ -154,6 +231,17 @@ def MLONNXLabelEncoderOp:MLONNX_Op<"LabelEncoder",
|
||||||
OptionalAttr<I64ArrayAttr>:$values_int64s,
|
OptionalAttr<I64ArrayAttr>:$values_int64s,
|
||||||
OptionalAttr<StrArrayAttr>:$values_strings);
|
OptionalAttr<StrArrayAttr>:$values_strings);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier",
|
def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier",
|
||||||
|
@ -162,7 +250,7 @@ def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier",
|
||||||
let description = [{
|
let description = [{
|
||||||
"Linear classifier"
|
"Linear classifier"
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
|
||||||
OptionalAttr<I64ArrayAttr>:$classlabels_ints,
|
OptionalAttr<I64ArrayAttr>:$classlabels_ints,
|
||||||
OptionalAttr<StrArrayAttr>:$classlabels_strings,
|
OptionalAttr<StrArrayAttr>:$classlabels_strings,
|
||||||
F32ArrayAttr:$coefficients,
|
F32ArrayAttr:$coefficients,
|
||||||
|
@ -171,6 +259,17 @@ def MLONNXLinearClassifierOp:MLONNX_Op<"LinearClassifier",
|
||||||
DefaultValuedAttr<StrAttr, "NONE">:$post_transform);
|
DefaultValuedAttr<StrAttr, "NONE">:$post_transform);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1,-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXLinearRegressorOp:MLONNX_Op<"LinearRegressor",
|
def MLONNXLinearRegressorOp:MLONNX_Op<"LinearRegressor",
|
||||||
|
@ -184,12 +283,23 @@ def MLONNXLinearRegressorOp:MLONNX_Op<"LinearRegressor",
|
||||||
" The coefficients array is of length n, and the coefficients for each target are contiguous."
|
" The coefficients array is of length n, and the coefficients for each target are contiguous."
|
||||||
" Intercepts are optional but if provided must match the number of targets."
|
" Intercepts are optional but if provided must match the number of targets."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
|
||||||
OptionalAttr<F32ArrayAttr>:$coefficients,
|
OptionalAttr<F32ArrayAttr>:$coefficients,
|
||||||
OptionalAttr<F32ArrayAttr>:$intercepts,
|
OptionalAttr<F32ArrayAttr>:$intercepts,
|
||||||
DefaultValuedAttr<StrAttr, "NONE">:$post_transform,
|
DefaultValuedAttr<StrAttr, "NONE">:$post_transform,
|
||||||
DefaultValuedAttr<I64Attr, "1">:$targets);
|
DefaultValuedAttr<I64Attr, "1">:$targets);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXNormalizerOp:MLONNX_Op<"Normalizer",
|
def MLONNXNormalizerOp:MLONNX_Op<"Normalizer",
|
||||||
|
@ -207,9 +317,20 @@ def MLONNXNormalizerOp:MLONNX_Op<"Normalizer",
|
||||||
" For batches, that is, [N,C] tensors, normalization is done along the C axis. In other words, each row"
|
" For batches, that is, [N,C] tensors, normalization is done along the C axis. In other words, each row"
|
||||||
" of the batch is normalized independently."
|
" of the batch is normalized independently."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
|
||||||
DefaultValuedAttr<StrAttr, "MAX">:$norm);
|
DefaultValuedAttr<StrAttr, "MAX">:$norm);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXOneHotEncoderOp:MLONNX_Op<"OneHotEncoder",
|
def MLONNXOneHotEncoderOp:MLONNX_Op<"OneHotEncoder",
|
||||||
|
@ -230,6 +351,17 @@ def MLONNXOneHotEncoderOp:MLONNX_Op<"OneHotEncoder",
|
||||||
OptionalAttr<StrArrayAttr>:$cats_strings,
|
OptionalAttr<StrArrayAttr>:$cats_strings,
|
||||||
DefaultValuedAttr<I64Attr, "1">:$zeros);
|
DefaultValuedAttr<I64Attr, "1">:$zeros);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier",
|
def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier",
|
||||||
|
@ -238,7 +370,7 @@ def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier",
|
||||||
let description = [{
|
let description = [{
|
||||||
"Support Vector Machine classifier"
|
"Support Vector Machine classifier"
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
|
||||||
OptionalAttr<I64ArrayAttr>:$classlabels_ints,
|
OptionalAttr<I64ArrayAttr>:$classlabels_ints,
|
||||||
OptionalAttr<StrArrayAttr>:$classlabels_strings,
|
OptionalAttr<StrArrayAttr>:$classlabels_strings,
|
||||||
OptionalAttr<F32ArrayAttr>:$coefficients,
|
OptionalAttr<F32ArrayAttr>:$coefficients,
|
||||||
|
@ -252,6 +384,17 @@ def MLONNXSVMClassifierOp:MLONNX_Op<"SVMClassifier",
|
||||||
OptionalAttr<I64ArrayAttr>:$vectors_per_class);
|
OptionalAttr<I64ArrayAttr>:$vectors_per_class);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1,-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor",
|
def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor",
|
||||||
|
@ -260,7 +403,7 @@ def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor",
|
||||||
let description = [{
|
let description = [{
|
||||||
"Support Vector Machine regression prediction and one-class SVM anomaly detection."
|
"Support Vector Machine regression prediction and one-class SVM anomaly detection."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
|
||||||
OptionalAttr<F32ArrayAttr>:$coefficients,
|
OptionalAttr<F32ArrayAttr>:$coefficients,
|
||||||
OptionalAttr<F32ArrayAttr>:$kernel_params,
|
OptionalAttr<F32ArrayAttr>:$kernel_params,
|
||||||
DefaultValuedAttr<StrAttr, "LINEAR">:$kernel_type,
|
DefaultValuedAttr<StrAttr, "LINEAR">:$kernel_type,
|
||||||
|
@ -270,6 +413,17 @@ def MLONNXSVMRegressorOp:MLONNX_Op<"SVMRegressor",
|
||||||
OptionalAttr<F32ArrayAttr>:$rho,
|
OptionalAttr<F32ArrayAttr>:$rho,
|
||||||
OptionalAttr<F32ArrayAttr>:$support_vectors);
|
OptionalAttr<F32ArrayAttr>:$support_vectors);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXScalerOp:MLONNX_Op<"Scaler",
|
def MLONNXScalerOp:MLONNX_Op<"Scaler",
|
||||||
|
@ -278,10 +432,21 @@ def MLONNXScalerOp:MLONNX_Op<"Scaler",
|
||||||
let description = [{
|
let description = [{
|
||||||
"Rescale input data, for example to standardize features by removing the mean and scaling to unit variance."
|
"Rescale input data, for example to standardize features by removing the mean and scaling to unit variance."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
|
||||||
OptionalAttr<F32ArrayAttr>:$offset,
|
OptionalAttr<F32ArrayAttr>:$offset,
|
||||||
OptionalAttr<F32ArrayAttr>:$scale);
|
OptionalAttr<F32ArrayAttr>:$scale);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier",
|
def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier",
|
||||||
|
@ -298,7 +463,7 @@ def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier",
|
||||||
" One and only one of classlabels_strings or classlabels_int64s"
|
" One and only one of classlabels_strings or classlabels_int64s"
|
||||||
" will be defined. The class_ids are indices into this list."
|
" will be defined. The class_ids are indices into this list."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
|
||||||
OptionalAttr<F32ArrayAttr>:$base_values,
|
OptionalAttr<F32ArrayAttr>:$base_values,
|
||||||
OptionalAttr<I64ArrayAttr>:$class_ids,
|
OptionalAttr<I64ArrayAttr>:$class_ids,
|
||||||
OptionalAttr<I64ArrayAttr>:$class_nodeids,
|
OptionalAttr<I64ArrayAttr>:$class_nodeids,
|
||||||
|
@ -318,6 +483,17 @@ def MLONNXTreeEnsembleClassifierOp:MLONNX_Op<"TreeEnsembleClassifier",
|
||||||
DefaultValuedAttr<StrAttr, "NONE">:$post_transform);
|
DefaultValuedAttr<StrAttr, "NONE">:$post_transform);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1,-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor",
|
def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor",
|
||||||
|
@ -335,7 +511,7 @@ def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor",
|
||||||
" All trees must have their node ids start at 0 and increment by 1.<br>"
|
" All trees must have their node ids start at 0 and increment by 1.<br>"
|
||||||
" Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF"
|
" Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF"
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
let arguments = (ins AnyTypeOf<[TensorOf<[F32,F64,I64,I32]>, MemRefOf<[F32,F64,I64,I32]>]>:$X,
|
||||||
DefaultValuedAttr<StrAttr, "SUM">:$aggregate_function,
|
DefaultValuedAttr<StrAttr, "SUM">:$aggregate_function,
|
||||||
OptionalAttr<F32ArrayAttr>:$base_values,
|
OptionalAttr<F32ArrayAttr>:$base_values,
|
||||||
OptionalAttr<I64Attr>:$n_targets,
|
OptionalAttr<I64Attr>:$n_targets,
|
||||||
|
@ -354,6 +530,17 @@ def MLONNXTreeEnsembleRegressorOp:MLONNX_Op<"TreeEnsembleRegressor",
|
||||||
OptionalAttr<I64ArrayAttr>:$target_treeids,
|
OptionalAttr<I64ArrayAttr>:$target_treeids,
|
||||||
OptionalAttr<F32ArrayAttr>:$target_weights);
|
OptionalAttr<F32ArrayAttr>:$target_weights);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def MLONNXZipMapOp:MLONNX_Op<"ZipMap",
|
def MLONNXZipMapOp:MLONNX_Op<"ZipMap",
|
||||||
|
@ -369,5 +556,16 @@ def MLONNXZipMapOp:MLONNX_Op<"ZipMap",
|
||||||
OptionalAttr<I64ArrayAttr>:$classlabels_int64s,
|
OptionalAttr<I64ArrayAttr>:$classlabels_int64s,
|
||||||
OptionalAttr<StrArrayAttr>:$classlabels_strings);
|
OptionalAttr<StrArrayAttr>:$classlabels_strings);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$Z);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {-1};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ def ONNX_Dialect : Dialect {
|
||||||
// * The mnemonic for the operation, or the name without the dialect prefix.
|
// * The mnemonic for the operation, or the name without the dialect prefix.
|
||||||
// * A list of traits for the operation.
|
// * A list of traits for the operation.
|
||||||
class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
Op<ONNX_Dialect, mnemonic, traits>;
|
Op<ONNX_Dialect, mnemonic, traits> ;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ONNX Operations
|
// ONNX Operations
|
||||||
|
@ -112,6 +112,17 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||||
DefaultValuedAttr<I64Attr, "0">:$storage_order,
|
DefaultValuedAttr<I64Attr, "0">:$storage_order,
|
||||||
OptionalAttr<I64ArrayAttr>:$strides);
|
OptionalAttr<I64ArrayAttr>:$strides);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {0};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
|
def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
|
||||||
|
@ -137,6 +148,17 @@ def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
|
||||||
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
|
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
|
||||||
DefaultValuedAttr<F32Attr, "0.9">:$momentum);
|
DefaultValuedAttr<F32Attr, "0.9">:$momentum);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 5;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {0};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
|
def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
|
||||||
|
@ -154,6 +176,17 @@ def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
|
||||||
DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
|
DefaultValuedAttr<F32Attr, "0.0">:$constant_value,
|
||||||
DefaultValuedAttr<StrAttr, "constant">:$mode);
|
DefaultValuedAttr<StrAttr, "constant">:$mode);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {0};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad",
|
def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad",
|
||||||
|
@ -168,6 +201,17 @@ def ONNXPadConstantPadOp : ONNX_Op<"PadConstantPad",
|
||||||
I64ArrayAttr:$pads,
|
I64ArrayAttr:$pads,
|
||||||
DefaultValuedAttr<StrAttr, "constant">:$mode);
|
DefaultValuedAttr<StrAttr, "constant">:$mode);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {0};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
|
def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
|
||||||
|
@ -186,6 +230,17 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
|
||||||
let builders = [OpBuilder<"OpBuilder &builder, OperationState &state, "
|
let builders = [OpBuilder<"OpBuilder &builder, OperationState &state, "
|
||||||
"Value data, ArrayAttr pads, "
|
"Value data, ArrayAttr pads, "
|
||||||
"FloatAttr constant_value, StringAttr mode">];
|
"FloatAttr constant_value, StringAttr mode">];
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static int getNumberOfOperands() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static int getNumberOfResults() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
static std::vector<int> getTypeMap() {
|
||||||
|
return {0};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // ONNX_OPS
|
#endif // ONNX_OPS
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -144,62 +144,62 @@ func @test_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @test_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
|
func @test_and(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
|
||||||
%0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
|
%0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
|
||||||
"std.return"(%0) : (tensor<*xi32>) -> ()
|
"std.return"(%0) : (tensor<*xi1>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_and
|
// CHECK-LABEL: test_and
|
||||||
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32>
|
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
|
||||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||||
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32
|
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1
|
||||||
// CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: return [[RES]] : memref<10x10xi32>
|
// CHECK: return [[RES]] : memref<10x10xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @test_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
|
func @test_or(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
|
||||||
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
|
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
|
||||||
"std.return"(%0) : (tensor<*xi32>) -> ()
|
"std.return"(%0) : (tensor<*xi1>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_or
|
// CHECK-LABEL: test_or
|
||||||
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32>
|
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
|
||||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||||
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32
|
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1
|
||||||
// CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: return [[RES]] : memref<10x10xi32>
|
// CHECK: return [[RES]] : memref<10x10xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @test_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
|
func @test_xor(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
|
||||||
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
|
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
|
||||||
"std.return"(%0) : (tensor<*xi32>) -> ()
|
"std.return"(%0) : (tensor<*xi1>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_xor
|
// CHECK-LABEL: test_xor
|
||||||
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32>
|
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
|
||||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||||
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32
|
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1
|
||||||
// CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: return [[RES]] : memref<10x10xi32>
|
// CHECK: return [[RES]] : memref<10x10xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
|
@ -158,24 +158,24 @@ func @test_sub_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tens
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @test_and_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
|
func @test_and_and(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
|
||||||
%0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
|
%0 = "onnx.And"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
|
||||||
%1 = "onnx.And"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32>
|
%1 = "onnx.And"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1>
|
||||||
"std.return"(%1) : (tensor<*xi32>) -> ()
|
"std.return"(%1) : (tensor<*xi1>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_and_and
|
// CHECK-LABEL: test_and_and
|
||||||
/// First And
|
/// First And
|
||||||
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32>
|
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1>
|
||||||
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32>
|
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
|
||||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||||
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32
|
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1
|
||||||
// CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: store [[AND]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
|
||||||
|
|
||||||
/// Second And
|
/// Second And
|
||||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
@ -183,38 +183,38 @@ func @test_and_and(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tens
|
||||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||||
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i32
|
// CHECK: [[AND:%.+]] = and [[LOAD1]], [[LOAD2]] : i1
|
||||||
// CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: store [[AND]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1>
|
||||||
|
|
||||||
/// Dealloc of first result.
|
/// Dealloc of first result.
|
||||||
// CHECK: dealloc [[RES]] : memref<10x10xi32>
|
// CHECK: dealloc [[RES]] : memref<10x10xi1>
|
||||||
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32>
|
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1>
|
||||||
|
|
||||||
// CHECK: return [[RET_RES]] : memref<10x10xi32>
|
// CHECK: return [[RET_RES]] : memref<10x10xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @test_or_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
|
func @test_or_or(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
|
||||||
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
|
%0 = "onnx.Or"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
|
||||||
%1 = "onnx.Or"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32>
|
%1 = "onnx.Or"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1>
|
||||||
"std.return"(%1) : (tensor<*xi32>) -> ()
|
"std.return"(%1) : (tensor<*xi1>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_or_or
|
// CHECK-LABEL: test_or_or
|
||||||
/// First Or
|
/// First Or
|
||||||
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32>
|
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1>
|
||||||
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32>
|
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
|
||||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||||
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32
|
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1
|
||||||
// CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: store [[OR]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
|
||||||
|
|
||||||
/// Second Or
|
/// Second Or
|
||||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
@ -222,38 +222,38 @@ func @test_or_or(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor
|
||||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||||
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i32
|
// CHECK: [[OR:%.+]] = or [[LOAD1]], [[LOAD2]] : i1
|
||||||
// CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: store [[OR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1>
|
||||||
|
|
||||||
/// Dealloc of first result.
|
/// Dealloc of first result.
|
||||||
// CHECK: dealloc [[RES]] : memref<10x10xi32>
|
// CHECK: dealloc [[RES]] : memref<10x10xi1>
|
||||||
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32>
|
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1>
|
||||||
|
|
||||||
// CHECK: return [[RET_RES]] : memref<10x10xi32>
|
// CHECK: return [[RET_RES]] : memref<10x10xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @test_xor_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
|
func @test_xor_xor(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>) -> tensor<*xi1> {
|
||||||
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
|
%0 = "onnx.Xor"(%arg0, %arg1) : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<*xi1>
|
||||||
%1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi32>, tensor<10x10xi32>) -> tensor<*xi32>
|
%1 = "onnx.Xor"(%0, %arg1) : (tensor<*xi1>, tensor<10x10xi1>) -> tensor<*xi1>
|
||||||
"std.return"(%1) : (tensor<*xi32>) -> ()
|
"std.return"(%1) : (tensor<*xi1>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_xor_xor
|
// CHECK-LABEL: test_xor_xor
|
||||||
/// First Xor
|
/// First Xor
|
||||||
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi32>
|
// CHECK: [[RET_RES:%.+]] = alloc() : memref<10x10xi1>
|
||||||
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi32>
|
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xi1>
|
||||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||||
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32
|
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1
|
||||||
// CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<10x10xi1>
|
||||||
|
|
||||||
/// Second Xor
|
/// Second Xor
|
||||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
@ -261,16 +261,16 @@ func @test_xor_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tens
|
||||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||||
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<10x10xi1>
|
||||||
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i32
|
// CHECK: [[XOR:%.+]] = xor [[LOAD1]], [[LOAD2]] : i1
|
||||||
// CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi32>
|
// CHECK: store [[XOR]], [[RET_RES]][%arg2, %arg3] : memref<10x10xi1>
|
||||||
|
|
||||||
/// Dealloc of first result.
|
/// Dealloc of first result.
|
||||||
// CHECK: dealloc [[RES]] : memref<10x10xi32>
|
// CHECK: dealloc [[RES]] : memref<10x10xi1>
|
||||||
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi32>
|
// CHECK-NOT: dealloc [[RET_RES]] : memref<10x10xi1>
|
||||||
|
|
||||||
// CHECK: return [[RET_RES]] : memref<10x10xi32>
|
// CHECK: return [[RET_RES]] : memref<10x10xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
|
@ -298,6 +298,16 @@ custom_definition_misc = dict([ ('Constant',
|
||||||
)])
|
)])
|
||||||
|
|
||||||
|
|
||||||
|
onnx_types = (
|
||||||
|
'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
|
||||||
|
'float', 'double', 'complex64', 'complex128'
|
||||||
|
)
|
||||||
|
tblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64',
|
||||||
|
'Complex<F32>', 'Complex<F64>'
|
||||||
|
)
|
||||||
|
|
||||||
|
MAX_NUM_TYPES=20
|
||||||
|
|
||||||
SNIPPETS = collect_snippets()
|
SNIPPETS = collect_snippets()
|
||||||
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
|
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
|
||||||
ONNX_ML = bool(args.domain == "ONNX_ML")
|
ONNX_ML = bool(args.domain == "ONNX_ML")
|
||||||
|
@ -376,53 +386,55 @@ def tblgen_operand_type_to_cpp_type(op_type):
|
||||||
|
|
||||||
|
|
||||||
def np_type_to_tblgen_attr_type(tstr):
|
def np_type_to_tblgen_attr_type(tstr):
|
||||||
tfrom = np.array([
|
|
||||||
'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
|
|
||||||
'float', 'double'
|
|
||||||
])
|
|
||||||
tto = np.array(
|
|
||||||
['I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64'])
|
|
||||||
index = -1
|
index = -1
|
||||||
for i in range(len(tfrom)):
|
for i in range(len(onnx_types)):
|
||||||
if tfrom[i] in tstr:
|
if onnx_types[i] in tstr:
|
||||||
index = i
|
index = i
|
||||||
break
|
break
|
||||||
if index == -1:
|
if index == -1:
|
||||||
print("error", tstr)
|
return None
|
||||||
return ''
|
|
||||||
else:
|
else:
|
||||||
return tto[i]
|
return tblgen_types[i]
|
||||||
|
|
||||||
|
def get_tblgen_type_index(type_str):
|
||||||
|
return tblgen_types.index(type_str)
|
||||||
|
|
||||||
|
#the possible data structures are tensor, map and seq(tensor())
|
||||||
|
#TOFIX: currently, only tensor structure is supported
|
||||||
|
def get_data_structure_element(allowed_type_str):
|
||||||
|
if allowed_type_str.startswith('tensor') :
|
||||||
|
element = allowed_type_str.replace('tensor(', '', 1).replace(')', '', 1)
|
||||||
|
return ('tensor', element)
|
||||||
|
else :
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
def get_allowed_elem_types(schema, input):
|
def get_allowed_elem_types(schema, input):
|
||||||
allowed_types_str = None
|
#allowed_types_str = None
|
||||||
return allowed_types_str
|
# return allowed_types_str
|
||||||
# TODO: enable type constraints.
|
# TODO: enable type constraints.
|
||||||
# if input.typeStr :
|
if input.typeStr :
|
||||||
# tstr = input.typeStr
|
tstr = input.typeStr
|
||||||
# else :
|
else :
|
||||||
# return allwedTypeStr
|
return None
|
||||||
# if schema.type_constraints:
|
if schema.type_constraints:
|
||||||
# for type_constraint in schema.type_constraints:
|
for type_constraint in schema.type_constraints:
|
||||||
# if type_constraint.type_param_str != tstr :
|
if type_constraint.type_param_str != tstr :
|
||||||
# continue
|
continue
|
||||||
# allowedTypes = type_constraint.allowed_type_strs
|
allowed_type_list=[]
|
||||||
# allowedTypeStr=''
|
allowedTypes = type_constraint.allowed_type_strs
|
||||||
# if (len(allowedTypes) > 0):
|
for allowedType in allowedTypes:
|
||||||
# t = convert_type(allowedTypes[0])
|
structure, element = get_data_structure_element(allowedType);
|
||||||
# if t == '' :
|
if structure == None or element == None:
|
||||||
# return ''
|
return None
|
||||||
# allowedTypeStr += t
|
t = np_type_to_tblgen_attr_type(element)
|
||||||
# for allowedType in allowedTypes[1:]:
|
if t == None :
|
||||||
# t = convert_type(allowedType)
|
return None
|
||||||
# if t == '' :
|
if not t in allowed_type_list :
|
||||||
# return ''
|
allowed_tyoe_list = allowed_type_list.append(t)
|
||||||
# if not t in allowedTypeStr :
|
|
||||||
# allowedTypeStr += ', '+t
|
return allowed_type_list
|
||||||
#
|
|
||||||
# return allowedTypeStr
|
return None
|
||||||
#
|
|
||||||
# return allowedTypeStr
|
|
||||||
|
|
||||||
|
|
||||||
def inc_indent(indent=None):
|
def inc_indent(indent=None):
|
||||||
|
@ -436,7 +448,6 @@ def dec_indent(indent):
|
||||||
def join_args(args):
|
def join_args(args):
|
||||||
return ", ".join(args)
|
return ", ".join(args)
|
||||||
|
|
||||||
|
|
||||||
def get_operands_or_results(schema, is_input):
|
def get_operands_or_results(schema, is_input):
|
||||||
value_list = schema.inputs if is_input else schema.outputs
|
value_list = schema.inputs if is_input else schema.outputs
|
||||||
if not value_list:
|
if not value_list:
|
||||||
|
@ -456,8 +467,9 @@ def get_operands_or_results(schema, is_input):
|
||||||
if elem_types is None:
|
if elem_types is None:
|
||||||
types = ["AnyMemRef", "AnyTensor"]
|
types = ["AnyMemRef", "AnyTensor"]
|
||||||
else:
|
else:
|
||||||
|
elem_types_str = ','.join(elem_types)
|
||||||
types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
|
types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"]
|
||||||
types = list(map(lambda x: x.format(elem_types), types))
|
types = list(map(lambda x: x.format(elem_types_str), types))
|
||||||
|
|
||||||
# If operand is promotable to an attribute, then it must be
|
# If operand is promotable to an attribute, then it must be
|
||||||
# nullable in case it migrates to be an attribute.
|
# nullable in case it migrates to be an attribute.
|
||||||
|
@ -545,6 +557,64 @@ def get_attrs(schema):
|
||||||
name_to_type[attr.name] = get_attr_type_optional(attr.type)
|
name_to_type[attr.name] = get_attr_type_optional(attr.type)
|
||||||
return name_to_type
|
return name_to_type
|
||||||
|
|
||||||
|
def get_numberof_list(mylist):
|
||||||
|
expected_num = len(mylist)
|
||||||
|
for element in mylist :
|
||||||
|
if OpSchema.FormalParameterOption.Variadic == element.option:
|
||||||
|
expected_num = -1
|
||||||
|
return expected_num
|
||||||
|
|
||||||
|
def get_output_type_mapping(schema):
|
||||||
|
mapping=[]
|
||||||
|
for output in schema.outputs :
|
||||||
|
#if only one type is allowed, just set that
|
||||||
|
allowed_elem_types = get_allowed_elem_types(schema, output)
|
||||||
|
if allowed_elem_types != None and len(allowed_elem_types) == 1 :
|
||||||
|
mapping.append(str(get_tblgen_type_index(allowed_elem_types[0])))
|
||||||
|
continue
|
||||||
|
|
||||||
|
#map the type string
|
||||||
|
if output.typeStr :
|
||||||
|
tstr = output.typeStr
|
||||||
|
found = False
|
||||||
|
for i, input in enumerate(schema.inputs):
|
||||||
|
if input.typeStr and input.typeStr == tstr:
|
||||||
|
mapping.append(str(i+MAX_NUM_TYPES))
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
if found:
|
||||||
|
continue
|
||||||
|
|
||||||
|
#unknown output type
|
||||||
|
mapping.append(str(-1))
|
||||||
|
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
def get_numberof_inout(s, indent, schema):
|
||||||
|
expected_num_operands = get_numberof_list(schema.inputs)
|
||||||
|
indent = inc_indent(indent)
|
||||||
|
s += indent + "static int getNumberOfOperands() {\n"
|
||||||
|
indent = inc_indent(indent)
|
||||||
|
s += indent + "return {};\n".format(expected_num_operands)
|
||||||
|
indent = dec_indent(indent)
|
||||||
|
s += indent + "}\n"
|
||||||
|
|
||||||
|
expected_num_results = get_numberof_list(schema.outputs)
|
||||||
|
s += indent + "static int getNumberOfResults() {\n"
|
||||||
|
indent = inc_indent(indent)
|
||||||
|
s += indent + "return {};\n".format(expected_num_results)
|
||||||
|
indent = dec_indent(indent)
|
||||||
|
s += indent + "}\n"
|
||||||
|
|
||||||
|
s += indent + "static std::vector<int> getTypeMap() {\n"
|
||||||
|
mapping = get_output_type_mapping(schema)
|
||||||
|
indent = inc_indent(indent)
|
||||||
|
s += indent + "return {" + ",".join(mapping) + "};\n"
|
||||||
|
indent = dec_indent(indent)
|
||||||
|
s += indent + "}\n"
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
|
def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
|
||||||
cpp_name_to_idx_literal = "{" + ", ".join([
|
cpp_name_to_idx_literal = "{" + ", ".join([
|
||||||
|
@ -552,15 +622,15 @@ def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx):
|
||||||
for name_to_idx in const_operands_name_to_idx
|
for name_to_idx in const_operands_name_to_idx
|
||||||
]) + "}"
|
]) + "}"
|
||||||
|
|
||||||
s += indent + "let extraClassDeclaration = [{\n"
|
#s += indent + "let extraClassDeclaration = [{\n"
|
||||||
indent = inc_indent(indent)
|
indent = inc_indent(indent)
|
||||||
s += indent + "std::map<std::string, size_t> promotableConstOperands() {\n"
|
s += indent + "std::map<std::string, size_t> promotableConstOperands() {\n"
|
||||||
indent = inc_indent(indent)
|
indent = inc_indent(indent)
|
||||||
s += indent + "return {};\n".format(cpp_name_to_idx_literal)
|
s += indent + "return {};\n".format(cpp_name_to_idx_literal)
|
||||||
indent = dec_indent(indent)
|
indent = dec_indent(indent)
|
||||||
s += indent + "}\n"
|
s += indent + "}\n"
|
||||||
indent = dec_indent(indent)
|
#indent = dec_indent(indent)
|
||||||
s += indent + "}];\n"
|
#s += indent + "}];\n"
|
||||||
|
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
@ -657,10 +727,20 @@ def gen_op_def(schema):
|
||||||
|
|
||||||
s += '\n' + indent + '];\n'
|
s += '\n' + indent + '];\n'
|
||||||
|
|
||||||
|
# generate extracClassDeclaration
|
||||||
|
s += indent + "let extraClassDeclaration = [{\n"
|
||||||
|
#indent = inc_indent(indent)
|
||||||
|
|
||||||
|
# generate input/output number
|
||||||
|
s = get_numberof_inout(s, indent, schema)
|
||||||
|
|
||||||
|
# generate ProtableConst
|
||||||
if schema.name in OpsWithPromotableConstOperands:
|
if schema.name in OpsWithPromotableConstOperands:
|
||||||
s = get_promotable_const_operands_func(
|
s = get_promotable_const_operands_func(
|
||||||
s, indent, OpsWithPromotableConstOperands[schema.name])
|
s, indent, OpsWithPromotableConstOperands[schema.name])
|
||||||
|
|
||||||
|
s += indent + '}];\n'
|
||||||
|
|
||||||
if ( schema.name in custom_definition_misc) :
|
if ( schema.name in custom_definition_misc) :
|
||||||
s += custom_definition_misc[schema.name]
|
s += custom_definition_misc[schema.name]
|
||||||
|
|
||||||
|
@ -700,11 +780,13 @@ def gen_op_importer(schema, file):
|
||||||
# Special handlers currently require expected num operands/results to be specified.
|
# Special handlers currently require expected num operands/results to be specified.
|
||||||
# TODO: remove special handlers.
|
# TODO: remove special handlers.
|
||||||
args = ["node"]
|
args = ["node"]
|
||||||
|
"""
|
||||||
if expected_num_operands != -1 or expected_num_results != -1 or "buildOperation" not in handler_func:
|
if expected_num_operands != -1 or expected_num_results != -1 or "buildOperation" not in handler_func:
|
||||||
args.append(
|
args.append(
|
||||||
"/* expected_num_operands = */ {}".format(expected_num_operands))
|
"/* expected_num_operands = */ {}".format(expected_num_operands))
|
||||||
args.append(
|
args.append(
|
||||||
'/* expected_num_results = */ {}'.format(expected_num_results))
|
'/* expected_num_results = */ {}'.format(expected_num_results))
|
||||||
|
"""
|
||||||
s += inc_indent(indent) + " {}({});\n".format(
|
s += inc_indent(indent) + " {}({});\n".format(
|
||||||
handler_func, ", ".join(args))
|
handler_func, ", ".join(args))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue