Import 2-argument Gemm as GemmNoBias (#68)
Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
60ac8f081f
commit
0bfb660d02
|
@ -36,6 +36,7 @@ special_attr_defaults = dict([
|
|||
special_op_handler = dict([
|
||||
("Conv", "ImportNodeConv"),
|
||||
("MaxPool", "ImportNodeMaxPool"),
|
||||
("Gemm", "ImportNodeGemm"),
|
||||
#("Transpose", "ImportNodeTranspose")
|
||||
])
|
||||
|
||||
|
|
|
@ -434,6 +434,18 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* Special handle for Gemm operations.
|
||||
*/
|
||||
void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) {
|
||||
int nOps = node.input().size();
|
||||
if (nOps == 2) {
|
||||
ImportNodeOneOut<mlir::ONNXGemmNoBiasOp>(node, 2, nOut);
|
||||
} else {
|
||||
ImportNodeOneOut<mlir::ONNXGemmOp>(node, nIn, nOut);
|
||||
}
|
||||
}
|
||||
|
||||
void ImportNode(const onnx::NodeProto &node) {
|
||||
std::vector<mlir::Value> inputs;
|
||||
for (const auto &item : node.input()) {
|
||||
|
|
|
@ -98,7 +98,7 @@
|
|||
}else if (OpName == "GatherND") {
|
||||
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1);
|
||||
}else if (OpName == "Gemm") {
|
||||
ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1);
|
||||
ImportNodeGemm(node, 3, 1);
|
||||
}else if (OpName == "GlobalAveragePool") {
|
||||
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1);
|
||||
}else if (OpName == "GlobalLpPool") {
|
||||
|
|
Loading…
Reference in New Issue