Import 2-argument Gemm as GemmNoBias (#68)
Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
60ac8f081f
commit
0bfb660d02
|
@ -36,6 +36,7 @@ special_attr_defaults = dict([
|
||||||
special_op_handler = dict([
|
special_op_handler = dict([
|
||||||
("Conv", "ImportNodeConv"),
|
("Conv", "ImportNodeConv"),
|
||||||
("MaxPool", "ImportNodeMaxPool"),
|
("MaxPool", "ImportNodeMaxPool"),
|
||||||
|
("Gemm", "ImportNodeGemm"),
|
||||||
#("Transpose", "ImportNodeTranspose")
|
#("Transpose", "ImportNodeTranspose")
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
|
@ -434,6 +434,18 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Special handle for Gemm operations.
|
||||||
|
*/
|
||||||
|
void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) {
|
||||||
|
int nOps = node.input().size();
|
||||||
|
if (nOps == 2) {
|
||||||
|
ImportNodeOneOut<mlir::ONNXGemmNoBiasOp>(node, 2, nOut);
|
||||||
|
} else {
|
||||||
|
ImportNodeOneOut<mlir::ONNXGemmOp>(node, nIn, nOut);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ImportNode(const onnx::NodeProto &node) {
|
void ImportNode(const onnx::NodeProto &node) {
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (const auto &item : node.input()) {
|
for (const auto &item : node.input()) {
|
||||||
|
|
|
@ -98,7 +98,7 @@
|
||||||
}else if (OpName == "GatherND") {
|
}else if (OpName == "GatherND") {
|
||||||
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1);
|
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1);
|
||||||
}else if (OpName == "Gemm") {
|
}else if (OpName == "Gemm") {
|
||||||
ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1);
|
ImportNodeGemm(node, 3, 1);
|
||||||
}else if (OpName == "GlobalAveragePool") {
|
}else if (OpName == "GlobalAveragePool") {
|
||||||
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1);
|
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1);
|
||||||
}else if (OpName == "GlobalLpPool") {
|
}else if (OpName == "GlobalLpPool") {
|
||||||
|
|
Loading…
Reference in New Issue