Fix rebase errors. (#378)
This commit is contained in:
parent
6c7ff180f9
commit
bee32e2041
|
@ -58,7 +58,7 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
|
|
||||||
include "dialect/onnx/onnxop.inc"
|
include "dialect/onnx/onnxop.inc"
|
||||||
|
|
||||||
def ONNXFullGemmOp: ONNX_Op<"full_gemm",
|
def ONNXFullGemmOp: ONNX_Op<"FullGemm",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX general matrix multiply operation";
|
let summary = "ONNX general matrix multiply operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -30,7 +30,7 @@ def HasOneUse : Constraint<CPred<"$0->hasOneUse()">>;
|
||||||
// Pattern-Match and Rewrite
|
// Pattern-Match and Rewrite
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.full_gemm(%X, %Y, %Z)
|
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z)
|
||||||
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
|
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
|
||||||
(ONNXFullGemmOp $m1, $m2, $m3),
|
(ONNXFullGemmOp $m1, $m2, $m3),
|
||||||
[(HasOneUse $res)]>;
|
[(HasOneUse $res)]>;
|
||||||
|
|
|
@ -82,10 +82,10 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
||||||
// All operations which do not return a ranked tensor type have dynamic
|
// All operations which do not return a ranked tensor type have dynamic
|
||||||
// shaped outputs. All those operation need to implement the inferShape()
|
// shaped outputs. All those operation need to implement the inferShape()
|
||||||
// method.
|
// method.
|
||||||
if (op->getName().getStringRef() != "onnx.add" &&
|
if (op->getName().getStringRef() != "onnx.Add" &&
|
||||||
op->getName().getStringRef() != "onnx.matmul" &&
|
op->getName().getStringRef() != "onnx.MatMul" &&
|
||||||
op->getName().getStringRef() != "onnx.gemm" &&
|
op->getName().getStringRef() != "onnx.Gemm" &&
|
||||||
op->getName().getStringRef() != "onnx.full_gemm")
|
op->getName().getStringRef() != "onnx.FullGemm")
|
||||||
return false;
|
return false;
|
||||||
return llvm::any_of(op->getResultTypes(),
|
return llvm::any_of(op->getResultTypes(),
|
||||||
[](Type result_type) { return !result_type.isa<RankedTensorType>(); });
|
[](Type result_type) { return !result_type.isa<RankedTensorType>(); });
|
||||||
|
|
|
@ -1,6 +1,14 @@
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
import lit.util
|
||||||
import lit.formats
|
import lit.formats
|
||||||
from lit.llvm import llvm_config
|
from lit.llvm import llvm_config
|
||||||
|
from lit.llvm.subst import FindTool
|
||||||
from lit.llvm.subst import ToolSubst
|
from lit.llvm.subst import ToolSubst
|
||||||
|
|
||||||
# name: The name of this test suite.
|
# name: The name of this test suite.
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
import lit.llvm
|
import lit.llvm
|
||||||
|
|
||||||
config.llvm_tools_dir = "@MLIR_TOOLS_DIR@"
|
config.llvm_tools_dir = "@MLIR_TOOLS_DIR@"
|
||||||
config.mlir_obj_root = "@MLIR_BUILD_DIR@"
|
config.mlir_obj_root = "@LLVM_BUILD@"
|
||||||
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
|
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
|
||||||
config.suffixes = ['.mlir']
|
config.suffixes = ['.mlir']
|
||||||
|
|
||||||
|
|
|
@ -2,13 +2,10 @@
|
||||||
|
|
||||||
//CHECK: module {
|
//CHECK: module {
|
||||||
module {
|
module {
|
||||||
func @test_sigmoid() {
|
func @test_sigmoid(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
|
||||||
%0 = "frontend.input t1"() : () -> tensor<10x10xf32>
|
// CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
%1 = "frontend.input t2"() : () -> tensor<10x10xf32>
|
%0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
%2 = "frontend.input t3"() : () -> tensor<10x10xf32>
|
%1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
// CHECK: %{{[0-9]+}} = "onnx.full_gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
"std.return"(%1) : (tensor<10x10xf32>) -> ()
|
||||||
%3 = "onnx.MatMul"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
|
||||||
%4 = "onnx.Add"(%3, %2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
|
||||||
%5 = "frontend.output t4"(%4) : (tensor<10x10xf32>) -> tensor<10x10xf32>
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue