Fix rebase errors. (#378)

This commit is contained in:
GHEORGHE-TEOD BERCEA 2019-11-20 10:59:48 -05:00 committed by Tian Jin
parent 6c7ff180f9
commit bee32e2041
6 changed files with 27 additions and 22 deletions

View File

@ -58,7 +58,7 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
include "dialect/onnx/onnxop.inc"
def ONNXFullGemmOp: ONNX_Op<"full_gemm",
def ONNXFullGemmOp: ONNX_Op<"FullGemm",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation";
let description = [{

View File

@ -30,7 +30,7 @@ def HasOneUse : Constraint<CPred<"$0->hasOneUse()">>;
// Pattern-Match and Rewrite
//===----------------------------------------------------------------------===//
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.full_gemm(%X, %Y, %Z)
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z)
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
(ONNXFullGemmOp $m1, $m2, $m3),
[(HasOneUse $res)]>;

View File

@ -82,10 +82,10 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
// All operations which do not return a ranked tensor type have dynamic
// shaped outputs. All those operation need to implement the inferShape()
// method.
if (op->getName().getStringRef() != "onnx.add" &&
op->getName().getStringRef() != "onnx.matmul" &&
op->getName().getStringRef() != "onnx.gemm" &&
op->getName().getStringRef() != "onnx.full_gemm")
if (op->getName().getStringRef() != "onnx.Add" &&
op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.FullGemm")
return false;
return llvm::any_of(op->getResultTypes(),
[](Type result_type) { return !result_type.isa<RankedTensorType>(); });

View File

@ -1,6 +1,14 @@
import os
import sys
import re
import platform
import subprocess
import lit.util
import lit.formats
from lit.llvm import llvm_config
from lit.llvm.subst import FindTool
from lit.llvm.subst import ToolSubst
# name: The name of this test suite.

View File

@ -2,7 +2,7 @@
import lit.llvm
config.llvm_tools_dir = "@MLIR_TOOLS_DIR@"
config.mlir_obj_root = "@MLIR_BUILD_DIR@"
config.mlir_obj_root = "@LLVM_BUILD@"
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
config.suffixes = ['.mlir']

View File

@ -2,13 +2,10 @@
//CHECK: module {
module {
func @test_sigmoid() {
%0 = "frontend.input t1"() : () -> tensor<10x10xf32>
%1 = "frontend.input t2"() : () -> tensor<10x10xf32>
%2 = "frontend.input t3"() : () -> tensor<10x10xf32>
// CHECK: %{{[0-9]+}} = "onnx.full_gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%3 = "onnx.MatMul"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%4 = "onnx.Add"(%3, %2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%5 = "frontend.output t4"(%4) : (tensor<10x10xf32>) -> tensor<10x10xf32>
func @test_sigmoid(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
// CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
"std.return"(%1) : (tensor<10x10xf32>) -> ()
}
}