diff --git a/src/compiler/dialect/onnx/onnx.td b/src/compiler/dialect/onnx/onnx.td index e8478a6..f50cdbd 100644 --- a/src/compiler/dialect/onnx/onnx.td +++ b/src/compiler/dialect/onnx/onnx.td @@ -58,17 +58,17 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> : include "dialect/onnx/onnxop.inc" -def ONNXFullGemmOp: ONNX_Op<"full_gemm", - [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { - let summary = "ONNX general matrix multiply operation"; - let description = [{ +def ONNXFullGemmOp: ONNX_Op<"FullGemm", + [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { + let summary = "ONNX general matrix multiply operation"; + let description = [{ - The "onnx.gemm" generic matrix multiplication with bias. + The "onnx.gemm" generic matrix multiplication with bias. - }]; + }]; - let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in); - let results = (outs AnyTensor); + let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in); + let results = (outs AnyTensor); } #endif // ONNX_OPS diff --git a/src/compiler/pass/onnx_combine.td b/src/compiler/pass/onnx_combine.td index ceb3c7e..946991d 100644 --- a/src/compiler/pass/onnx_combine.td +++ b/src/compiler/pass/onnx_combine.td @@ -30,7 +30,7 @@ def HasOneUse : Constraint<CPred<"$_self->hasOneUse()">>; // Pattern-Match and Rewrite //===----------------------------------------------------------------------===// -// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.full_gemm(%X, %Y, %Z) +// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z) def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), (ONNXFullGemmOp $m1, $m2, $m3), [(HasOneUse $res)]>; diff --git a/src/compiler/pass/shape_inference_pass.cpp b/src/compiler/pass/shape_inference_pass.cpp index 78311bf..1548369 100644 --- a/src/compiler/pass/shape_inference_pass.cpp +++ b/src/compiler/pass/shape_inference_pass.cpp @@ -82,10 +82,10 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { // All operations which do not return a ranked tensor type have 
dynamic // shaped outputs. All those operation need to implement the inferShape() // method. - if (op->getName().getStringRef() != "onnx.add" && - op->getName().getStringRef() != "onnx.matmul" && - op->getName().getStringRef() != "onnx.gemm" && - op->getName().getStringRef() != "onnx.full_gemm") + if (op->getName().getStringRef() != "onnx.Add" && + op->getName().getStringRef() != "onnx.MatMul" && + op->getName().getStringRef() != "onnx.Gemm" && + op->getName().getStringRef() != "onnx.FullGemm") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { return !result_type.isa<RankedTensorType>(); }); diff --git a/test/mlir/lit.cfg.py b/test/mlir/lit.cfg.py index e6b2993..86aab83 100644 --- a/test/mlir/lit.cfg.py +++ b/test/mlir/lit.cfg.py @@ -1,6 +1,14 @@ +import os +import sys +import re +import platform +import subprocess + +import lit.util import lit.formats from lit.llvm import llvm_config +from lit.llvm.subst import FindTool from lit.llvm.subst import ToolSubst # name: The name of this test suite. 
diff --git a/test/mlir/lit.site.cfg.py.in b/test/mlir/lit.site.cfg.py.in index 975cb3b..b4e876d 100644 --- a/test/mlir/lit.site.cfg.py.in +++ b/test/mlir/lit.site.cfg.py.in @@ -2,7 +2,7 @@ import lit.llvm config.llvm_tools_dir = "@MLIR_TOOLS_DIR@" -config.mlir_obj_root = "@MLIR_BUILD_DIR@" +config.mlir_obj_root = "@LLVM_BUILD@" config.mlir_tools_dir = "@MLIR_TOOLS_DIR@" config.suffixes = ['.mlir'] diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index 7b7a79c..1cf1a89 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -2,13 +2,10 @@ //CHECK: module { module { - func @test_sigmoid() { - %0 = "frontend.input t1"() : () -> tensor<10x10xf32> - %1 = "frontend.input t2"() : () -> tensor<10x10xf32> - %2 = "frontend.input t3"() : () -> tensor<10x10xf32> - // CHECK: %{{[0-9]+}} = "onnx.full_gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> - %3 = "onnx.MatMul"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> - %4 = "onnx.Add"(%3, %2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> - %5 = "frontend.output t4"(%4) : (tensor<10x10xf32>) -> tensor<10x10xf32> + func @test_sigmoid(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> { + // CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + "std.return"(%1) : (tensor<10x10xf32>) -> () } }