Merge pull request #33 from clang-ykt/fix-gemm

Fix Gemm translation to ONNX dialect.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-01-15 17:56:43 -05:00 committed by GitHub
commit 2ea0724e4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 11 additions and 11 deletions

View File

@ -90,17 +90,17 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
// or outputs. This decision affects only ONNX operations with optional
// arguments not ONNX operations with variadic operands.
def ONNXFullGemmOp: ONNX_Op<"FullGemm",
def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation";
let summary = "ONNX general matrix multiply operation without bias.";
let description = [{
The "onnx.gemm" generic matrix multiplication with bias.
The "onnx.Gemm" generic matrix multiplication without bias.
}];
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in);
let results = (outs AnyTensor);
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$lhs_in, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rhs_in);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
}
def ONNXConv1Op:ONNX_Op<"Conv1",

View File

@ -347,9 +347,9 @@ void ONNXGemmOp::inferShapes() {
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
// FullGemm
// GemmNoBias
void ONNXFullGemmOp::inferShapes() {
void ONNXGemmNoBiasOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())

View File

@ -30,9 +30,9 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
// Pattern-Match and Rewrite
//===----------------------------------------------------------------------===//
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z)
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z)
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
(ONNXFullGemmOp $m1, $m2, $m3),
(ONNXGemmOp $m1, $m2, $m3),
[(HasOneUse $res)]>;
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)

View File

@ -114,7 +114,7 @@ public:
op->getName().getStringRef() != "onnx.Identity" &&
op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.FullGemm" &&
op->getName().getStringRef() != "onnx.GemmNoBias" &&
op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose")
return false;

View File

@ -2,7 +2,7 @@
func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK-LABEL: test_matmul_add_simplification
// CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
// CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
"std.return"(%1) : (tensor<10x10xf32>) -> ()