Merge pull request #33 from clang-ykt/fix-gemm
Fix Gemm translation to ONNX dialect.
This commit is contained in:
commit
2ea0724e4d
|
@ -90,17 +90,17 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
|
|||
// or outputs. This decision affects only ONNX operations with optional
|
||||
// arguments not ONNX operations with variadic operands.
|
||||
|
||||
def ONNXFullGemmOp: ONNX_Op<"FullGemm",
|
||||
def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX general matrix multiply operation";
|
||||
let summary = "ONNX general matrix multiply operation without bias.";
|
||||
let description = [{
|
||||
|
||||
The "onnx.gemm" generic matrix multiplication with bias.
|
||||
The "onnx.Gemm" generic matrix multiplication without bias.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in);
|
||||
let results = (outs AnyTensor);
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$lhs_in, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rhs_in);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||
}
|
||||
|
||||
def ONNXConv1Op:ONNX_Op<"Conv1",
|
||||
|
|
|
@ -347,9 +347,9 @@ void ONNXGemmOp::inferShapes() {
|
|||
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
}
|
||||
|
||||
// FullGemm
|
||||
// GemmNoBias
|
||||
|
||||
void ONNXFullGemmOp::inferShapes() {
|
||||
void ONNXGemmNoBiasOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
|
|
|
@ -30,9 +30,9 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
|||
// Pattern-Match and Rewrite
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z)
|
||||
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z)
|
||||
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
|
||||
(ONNXFullGemmOp $m1, $m2, $m3),
|
||||
(ONNXGemmOp $m1, $m2, $m3),
|
||||
[(HasOneUse $res)]>;
|
||||
|
||||
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
|
||||
|
|
|
@ -114,7 +114,7 @@ public:
|
|||
op->getName().getStringRef() != "onnx.Identity" &&
|
||||
op->getName().getStringRef() != "onnx.MatMul" &&
|
||||
op->getName().getStringRef() != "onnx.Gemm" &&
|
||||
op->getName().getStringRef() != "onnx.FullGemm" &&
|
||||
op->getName().getStringRef() != "onnx.GemmNoBias" &&
|
||||
op->getName().getStringRef() != "onnx.Reshape" &&
|
||||
op->getName().getStringRef() != "onnx.Transpose")
|
||||
return false;
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
|
||||
// CHECK-LABEL: test_matmul_add_simplification
|
||||
// CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
// CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
%0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
%1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
"std.return"(%1) : (tensor<10x10xf32>) -> ()
|
||||
|
|
Loading…
Reference in New Issue