Remove ImportNodeReshape (#208)
This commit is contained in:
parent
01a4977c74
commit
0a936edf79
|
@ -313,26 +313,6 @@ private:
|
||||||
node, inputs, expectedNumOperands, expectedNumResults);
|
node, inputs, expectedNumOperands, expectedNumResults);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ImportNodeReshape(onnx::NodeProto node) {
|
|
||||||
int expectedNumOperands = mlir::ONNXReshapeOp::getNumberOfOperands();
|
|
||||||
int expectedNumResults = mlir::ONNXReshapeOp::getNumberOfResults();
|
|
||||||
std::vector<mlir::Value> inputs;
|
|
||||||
std::string item;
|
|
||||||
for (int i = 0; i < node.input().size(); ++i) {
|
|
||||||
item = node.input()[i];
|
|
||||||
// For the second argument, check if there exists an initializer.
|
|
||||||
if (initializedTensors.ContainKey(legalize_name(item))) {
|
|
||||||
inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
|
|
||||||
UnknownLoc(), builder_, legalize_name(item)));
|
|
||||||
} else if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
|
||||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
buildOutputAndOperation<mlir::ONNXReshapeOp>(
|
|
||||||
node, inputs, expectedNumOperands, expectedNumResults);
|
|
||||||
}
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Special handle for MaxPool operations.
|
* Special handle for MaxPool operations.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -223,7 +223,7 @@ if (opName == "ReduceSumSquare")
|
||||||
if (opName == "Relu")
|
if (opName == "Relu")
|
||||||
buildOperation<mlir::ONNXReluOp>(node);
|
buildOperation<mlir::ONNXReluOp>(node);
|
||||||
if (opName == "Reshape")
|
if (opName == "Reshape")
|
||||||
ImportNodeReshape(node);
|
buildOperation<mlir::ONNXReshapeOp>(node);
|
||||||
if (opName == "Resize")
|
if (opName == "Resize")
|
||||||
buildOperation<mlir::ONNXResizeOp>(node);
|
buildOperation<mlir::ONNXResizeOp>(node);
|
||||||
if (opName == "ReverseSequence")
|
if (opName == "ReverseSequence")
|
||||||
|
|
|
@ -239,7 +239,6 @@ special_op_handler = dict([
|
||||||
("MaxPool", "ImportNodeMaxPool"),
|
("MaxPool", "ImportNodeMaxPool"),
|
||||||
("BatchNormalization", "ImportNodeBatchNormalization"),
|
("BatchNormalization", "ImportNodeBatchNormalization"),
|
||||||
("Pad", "ImportNodePad"),
|
("Pad", "ImportNodePad"),
|
||||||
("Reshape", "ImportNodeReshape"),
|
|
||||||
#("Transpose", "ImportNodeTranspose")
|
#("Transpose", "ImportNodeTranspose")
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue