diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index babdcb0..985f63d 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -347,7 +347,7 @@ void ONNXGemmOp::inferShapes() { getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); } -// FullGemm +// GemmNoBias void ONNXGemmNoBiasOp::inferShapes() { // Cannot infer shape if no shape exists. diff --git a/src/pass/onnx_combine.td b/src/pass/onnx_combine.td index bec67cd..25a4656 100644 --- a/src/pass/onnx_combine.td +++ b/src/pass/onnx_combine.td @@ -30,7 +30,7 @@ def HasOneUse : Constraint>; // Pattern-Match and Rewrite //===----------------------------------------------------------------------===// -// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z) +// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z) def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), (ONNXGemmOp $m1, $m2, $m3), [(HasOneUse $res)]>; diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp index cbdf04b..5ccb9a4 100644 --- a/src/pass/shape_inference_pass.cpp +++ b/src/pass/shape_inference_pass.cpp @@ -114,7 +114,7 @@ public: op->getName().getStringRef() != "onnx.Identity" && op->getName().getStringRef() != "onnx.MatMul" && op->getName().getStringRef() != "onnx.Gemm" && - op->getName().getStringRef() != "onnx.FullGemm" && + op->getName().getStringRef() != "onnx.GemmNoBias" && op->getName().getStringRef() != "onnx.Reshape" && op->getName().getStringRef() != "onnx.Transpose") return false;