Import 2-argument Gemm as GemmNoBias (#68)

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tung D. Le 2020-02-08 03:45:37 +09:00 committed by GitHub
parent 60ac8f081f
commit 0bfb660d02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 1 deletions

View File

@ -36,6 +36,7 @@ special_attr_defaults = dict([
special_op_handler = dict([
("Conv", "ImportNodeConv"),
("MaxPool", "ImportNodeMaxPool"),
("Gemm", "ImportNodeGemm"),
#("Transpose", "ImportNodeTranspose")
])

View File

@ -434,6 +434,18 @@ private:
}
}
/*!
* Special handle for Gemm operations.
*/
void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size();
if (nOps == 2) {
ImportNodeOneOut<mlir::ONNXGemmNoBiasOp>(node, 2, nOut);
} else {
ImportNodeOneOut<mlir::ONNXGemmOp>(node, nIn, nOut);
}
}
void ImportNode(const onnx::NodeProto &node) {
std::vector<mlir::Value> inputs;
for (const auto &item : node.input()) {

View File

@ -98,7 +98,7 @@
}else if (OpName == "GatherND") {
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1);
}else if (OpName == "Gemm") {
ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1);
ImportNodeGemm(node, 3, 1);
}else if (OpName == "GlobalAveragePool") {
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1);
}else if (OpName == "GlobalLpPool") {