From 63596e723f935586c643f31ad930b4ad74bdb015 Mon Sep 17 00:00:00 2001
From: GHEORGHE-TEOD BERCEA
Date: Tue, 12 Nov 2019 13:37:46 -0500
Subject: [PATCH] [MLIR] Add optimization based on TableGen pattern (#363)

* Define pattern.

* Fix file names.

* Add canonicalizer optimization based on TableGen pattern.

* Remove custom builders.

* Enable canonicalization in ONNF and ONNF-OPT.
---
 MLIR.cmake                                   | 27 +++++++++++
 src/builder/frontend_dialect_transformer.cpp | 26 +++++++++-
 src/compiler/CMakeLists.txt                  | 11 ++++-
 src/compiler/dialect/onnx/onnx.td            | 46 +++++++++++++++---
 src/compiler/dialect/onnx/onnx_ops.cpp       | 50 +++++++++++++++++---
 src/compiler/pass/onnx_combine.cpp           | 30 ++++++++++++
 src/compiler/pass/onnx_combine.td            | 35 ++++++++++++++
 src/compiler/tool/onnf_opt/onnf_opt.cpp      | 10 ++--
 src/main.cpp                                 |  1 +
 9 files changed, 216 insertions(+), 20 deletions(-)
 create mode 100644 src/compiler/pass/onnx_combine.cpp
 create mode 100644 src/compiler/pass/onnx_combine.td

diff --git a/MLIR.cmake b/MLIR.cmake
index 1ad10de..7aed344 100644
--- a/MLIR.cmake
+++ b/MLIR.cmake
@@ -88,6 +88,11 @@ find_library(MLIR_LLVM_IR
   PATHS ${LLVM_PROJECT_LIB}
   NO_DEFAULT_PATH)
 
+find_library(MLIR_LIB_TRANSFORM_UTILS
+  NAMES MLIRTransformUtils
+  PATHS ${LLVM_PROJECT_LIB}
+  NO_DEFAULT_PATH)
+
 find_library(LLVM_LIB_SUPPORT
   NAMES LLVMSupport
   PATHS ${LLVM_PROJECT_LIB}
@@ -106,6 +111,7 @@ set(MLIRLIBS ${MLIR_LIB_STANDARD_OPS}
   ${MLIR_LIB_OPT_MAIN}
   ${MLIR_LIB_SUPPORT}
+  ${MLIR_LIB_TRANSFORM_UTILS}
   ${MLIR_LIB_ANALYSIS}
   ${MLIR_LIB_IR}
@@ -116,10 +122,31 @@ set(MLIRLIBS ${MLIR_LIB_STANDARD_OPS}
   ${MLIR_LIB_OPT_MAIN}
   ${MLIR_LIB_SUPPORT}
+  ${MLIR_LIB_TRANSFORM_UTILS}
   ${LLVM_LIB_SUPPORT}
   Threads::Threads)
 
+function(whole_archive_link target)
+  if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
+    set(link_flags "-L${LLVM_BUILD}/lib ")
+    FOREACH(LIB ${ARGN})
+      string(CONCAT link_flags ${link_flags} "-Wl,-force_load ${LLVM_BUILD}/lib/lib${LIB}.a ")
+    ENDFOREACH(LIB)
+  elseif(MSVC)
+    FOREACH(LIB ${ARGN})
+      string(CONCAT link_flags ${link_flags} "/WHOLEARCHIVE:${LIB} ")
+    ENDFOREACH(LIB)
+  else()
+    set(link_flags "-L${LLVM_BUILD}/lib -Wl,--whole-archive,")
+    FOREACH(LIB ${ARGN})
+      string(CONCAT link_flags ${link_flags} "-l${LIB},")
+    ENDFOREACH(LIB)
+    string(CONCAT link_flags ${link_flags} "--no-whole-archive")
+  endif()
+  set_target_properties(${target} PROPERTIES LINK_FLAGS ${link_flags})
+endfunction(whole_archive_link)
+
 # Set up TableGen environment.
 include(${LLVM_BUILD}/lib/cmake/llvm/TableGen.cmake)

diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp
index 7413768..2b5810f 100644
--- a/src/builder/frontend_dialect_transformer.cpp
+++ b/src/builder/frontend_dialect_transformer.cpp
@@ -193,11 +193,33 @@ class FrontendGenImpl {
     // ONNX Dialect.
     llvm::StringRef OpName = node.op_type();
 
     if (OpName == "Add") {
-      auto op =
-          builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
+      auto op = builder_.create<mlir::ONNXAddOp>(UnknownLoc(),
+          mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
+          inputs[1]);
       frontend_symbols_.AddMapping(
           legalize_name(node.output()[0]), op.getResult());
       return;
+    } else if (OpName == "MatMul") {
+      auto op = builder_.create<mlir::ONNXMatMulOp>(UnknownLoc(),
+          mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
+          inputs[1]);
+      frontend_symbols_.AddMapping(
+          legalize_name(node.output()[0]), op.getResult());
+      return;
+    } else if (OpName == "Gemm") {
+      if (inputs.size() == 3) {
+        auto op = builder_.create<mlir::ONNXFullGemmOp>(UnknownLoc(),
+            mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
+            inputs[1], inputs[2]);
+        frontend_symbols_.AddMapping(
+            legalize_name(node.output()[0]), op.getResult());
+      } else {
+        auto op = builder_.create<mlir::ONNXGemmOp>(UnknownLoc(),
+            mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs);
+        frontend_symbols_.AddMapping(
+            legalize_name(node.output()[0]), op.getResult());
+      }
+      return;
     }
 
     // Old way of doing things.
diff --git a/src/compiler/CMakeLists.txt b/src/compiler/CMakeLists.txt
index 417b60f..340e247 100644
--- a/src/compiler/CMakeLists.txt
+++ b/src/compiler/CMakeLists.txt
@@ -10,6 +10,7 @@ add_library(
   dialect/krnl/parser_helper.hpp
   pass/shape_inference_pass.cpp
   pass/shape_inference_interface.hpp
+  pass/onnx_combine.cpp
   pass/passes.hpp)
 
 # Include root src directory.
@@ -32,7 +33,10 @@ target_link_libraries(compiler ${MLIRLIBS} curses)
 
 add_executable(onnf-opt
-  tool/onnf_opt/onnf_opt.cpp)
+  tool/onnf_opt/onnf_opt.cpp)
+
+set(LIB_LIST MLIRAffineOps MLIRLoopOps MLIRTransformUtils MLIREDSC MLIRTransforms)
+whole_archive_link(onnf-opt ${LIB_LIST})
 
 target_link_libraries(onnf-opt ${Boost_LIBRARIES} ${MLIRLIBS} curses compiler)
 target_include_directories(onnf-opt PRIVATE ../..)
@@ -44,6 +48,11 @@ onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
 add_public_tablegen_target(gen_shape_inference)
 add_dependencies(compiler gen_shape_inference)
 
+set(LLVM_TARGET_DEFINITIONS pass/onnx_combine.td)
+onnf_tablegen(onnx_combine.inc -gen-rewriters)
+add_public_tablegen_target(gen_onnx_combine)
+add_dependencies(compiler gen_onnx_combine)
+
 set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
 onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
 onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
diff --git a/src/compiler/dialect/onnx/onnx.td b/src/compiler/dialect/onnx/onnx.td
index c4ee81b..c4150b7 100644
--- a/src/compiler/dialect/onnx/onnx.td
+++ b/src/compiler/dialect/onnx/onnx.td
@@ -58,13 +58,47 @@ def ONNXAddOp: ONNX_Op<"add",
   // let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in);
   let results = (outs AnyTensor);
+  let hasCanonicalizer = 1;
+}
 
-  // Build an ONNX Add operation using two input operands.
-  let builders = [
-    OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
-      buildONNXAddOp(b, result, lhs, rhs);
-    }]
-  >];
+def ONNXMatMulOp: ONNX_Op<"matmul",
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  let summary = "ONNX matrix multiply operation";
+  let description = [{
+
+    The "onnx.matmul" operation multiplies two matrices.
+
+  }];
+
+  let arguments = (ins AnyTypeOf<[F32Tensor, F64Tensor]>:$lhs_in,
+                       AnyTypeOf<[F32Tensor, F64Tensor]>:$rhs_in);
+  let results = (outs AnyTypeOf<[F32Tensor, F64Tensor]>);
+}
+
+def ONNXGemmOp: ONNX_Op<"gemm",
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  let summary = "ONNX general matrix multiply operation";
+  let description = [{
+
+    The "onnx.gemm" operation computes a generic matrix multiplication
+    with bias.
+
+  }];
+
+  let arguments = (ins Variadic<AnyTensor>:$inputs);
+  let results = (outs AnyTensor);
+}
+
+def ONNXFullGemmOp: ONNX_Op<"full_gemm",
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  let summary = "ONNX general matrix multiply operation with bias";
+  let description = [{
+
+    The "onnx.full_gemm" operation computes a generic matrix multiplication
+    with bias, with all three operands present.
+
+  }];
+
+  let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in);
+  let results = (outs AnyTensor);
 }
 
 #endif // ONNX_OPS
diff --git a/src/compiler/dialect/onnx/onnx_ops.cpp b/src/compiler/dialect/onnx/onnx_ops.cpp
index 627f804..bca52f1 100644
--- a/src/compiler/dialect/onnx/onnx_ops.cpp
+++ b/src/compiler/dialect/onnx/onnx_ops.cpp
@@ -40,18 +40,54 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
 // ONNX Operations
 //===----------------------------------------------------------------------===//
 
-static void buildONNXAddOp(mlir::Builder* builder, mlir::OperationState& state,
-    mlir::Value* lhs, mlir::Value* rhs) {
-  state.addTypes(UnrankedTensorType::get(builder->getF32Type()));
-  state.addOperands({lhs, rhs});
-}
+// Add
 
-/// Infer the output shape of the ONNXAddOp. This method is required by the
-/// shape inference interface.
 void ONNXAddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
 
+//===----------------------------------------------------------------------===//
+
+// MatMul
+
+void ONNXMatMulOp::inferShapes() {
+  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  SmallVector<int64_t, 2> dims;
+  dims.emplace_back(lhsTy.getShape()[0]);
+  dims.emplace_back(rhsTy.getShape()[1]);
+  getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
+}
+
+// TODO:
+//   Verify that matrix sizes are valid.
+//   Take into account the dimensionality of the matrix.
+
+//===----------------------------------------------------------------------===//
+
+// Gemm
+
+void ONNXGemmOp::inferShapes() {
+  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  SmallVector<int64_t, 2> dims;
+  dims.emplace_back(lhsTy.getShape()[0]);
+  dims.emplace_back(rhsTy.getShape()[1]);
+  getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
+}
+
+// FullGemm
+
+void ONNXFullGemmOp::inferShapes() {
+  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  SmallVector<int64_t, 2> dims;
+  dims.emplace_back(lhsTy.getShape()[0]);
+  dims.emplace_back(rhsTy.getShape()[1]);
+  getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
+}
+
+// TODO:
+//   Verify that matrix sizes are valid for multiplication and addition.
+//   Take into account the dimensionality of the matrix.
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/src/compiler/pass/onnx_combine.cpp b/src/compiler/pass/onnx_combine.cpp
new file mode 100644
index 0000000..4709f8d
--- /dev/null
+++ b/src/compiler/pass/onnx_combine.cpp
@@ -0,0 +1,30 @@
+//===- ONNXCombine.cpp - ONNX High Level Optimizer ------------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file implements a set of simple combiners for optimizing operations in
+// the ONNX dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include 
+#include "src/compiler/dialect/onnx/onnx_ops.hpp"
+
+using namespace mlir;
+
+namespace {
+/// Include the patterns defined in the Declarative Rewrite framework.
+#include "src/compiler/onnx_combine.inc"
+} // end anonymous namespace
+
+/// Register the optimization patterns as "canonicalization" patterns
+/// on ONNXAddOp.
+void ONNXAddOp::getCanonicalizationPatterns(
+    OwningRewritePatternList& results, MLIRContext* context) {
+  results.insert<MulAddToGemmOptPattern>(context);
+}
diff --git a/src/compiler/pass/onnx_combine.td b/src/compiler/pass/onnx_combine.td
new file mode 100644
index 0000000..bfd7905
--- /dev/null
+++ b/src/compiler/pass/onnx_combine.td
@@ -0,0 +1,35 @@
+//=- ONNXCombine.td - Pattern Match Optimizations for ONNX -*- tablegen -*-===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// Defines language-specific pattern match optimizations for ONNX using
+// Declarative Rewrite Rules (DRR) specified using TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ONNX_COMBINE
+#define ONNX_COMBINE
+
+#ifndef OP_BASE
+include "dialect/onnx/onnx.td"
+#endif // OP_BASE
+
+/// Note: The DRR definition used for defining patterns is shown below:
+///
+/// class Pattern<
+///    dag sourcePattern, list<dag> resultPatterns,
+///    list<dag> additionalConstraints = [],
+///    dag benefitsAdded = (addBenefit 0)
+/// >;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite
+//===----------------------------------------------------------------------===//
+
+// onnx.add(onnx.matmul(%X, %Y), %Z) -> onnx.full_gemm(%X, %Y, %Z)
+def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp $m1, $m2), $m3),
+                                 (ONNXFullGemmOp $m1, $m2, $m3)>;
+
+#endif // ONNX_COMBINE
diff --git a/src/compiler/tool/onnf_opt/onnf_opt.cpp b/src/compiler/tool/onnf_opt/onnf_opt.cpp
index 081faa9..8d25f45 100644
--- a/src/compiler/tool/onnf_opt/onnf_opt.cpp
+++ b/src/compiler/tool/onnf_opt/onnf_opt.cpp
@@ -10,14 +10,17 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
 
-#include 
+#include "llvm/Support/SourceMgr.h"
 #include "src/compiler/dialect/krnl/krnl_ops.hpp"
+#include "src/compiler/dialect/onnx/onnx_ops.hpp"
 #include "src/compiler/helper.hpp"
+#include "src/compiler/pass/passes.hpp"
 
 using namespace onnf;
 
@@ -45,6 +48,8 @@ static llvm::cl::opt<bool> verify_passes("verify-each",
 
 int main(int argc, char** argv) {
   llvm::InitLLVM y(argc, argv);
 
+  mlir::registerDialect<mlir::ONNXOpsDialect>();
+
   // Register any pass manager command line options.
   mlir::registerPassManagerCLOptions();
   mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
@@ -57,9 +62,6 @@ int main(int argc, char** argv) {
   auto output = mlir::openOutputFile(output_filename, &error_message);
 
-  mlir::registerDialect();
-  mlir::registerDialect();
-
   return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
       split_input_file, verify_diagnostics, verify_passes));
 }
diff --git a/src/main.cpp b/src/main.cpp
index a7c05f8..0a8beb8 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -76,6 +76,7 @@ int main(int ac, char* av[]) {
   mlir::PassManager pm(&context);
   pm.addPass(mlir::createShapeInferencePass());
+  pm.addPass(mlir::createCanonicalizerPass());
   pm.run(*module);
 
   return 0;
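
A note on what the DRR record expands to: the `onnf_tablegen(onnx_combine.inc -gen-rewriters)` step turns `MulAddToGemmOptPattern` from onnx_combine.td into an ordinary C++ `RewritePattern` subclass in onnx_combine.inc, which `ONNXAddOp::getCanonicalizationPatterns` then hands to the canonicalizer. For readers who have not used DRR, a hand-written pattern with the same effect would look roughly like the sketch below. This is an illustration of the mechanism, not the actual generated code: the class name `MulAddToGemmPattern` is invented for the example, and it assumes the pre-1.0 MLIR API this repository builds against (`PatternMatchResult`, `matchSuccess`/`matchFailure`, `Value *` operands).

    // Sketch of a hand-written equivalent of MulAddToGemmOptPattern.
    // Illustrative only: the real pattern is generated into onnx_combine.inc,
    // and the class name here is invented for the example.
    struct MulAddToGemmPattern : public mlir::RewritePattern {
      MulAddToGemmPattern(mlir::MLIRContext *context)
          // Root the match at "onnx.add" with benefit 1.
          : RewritePattern(mlir::ONNXAddOp::getOperationName(), 1, context) {}

      mlir::PatternMatchResult matchAndRewrite(
          mlir::Operation *op, mlir::PatternRewriter &rewriter) const override {
        // Match only when the add's first operand is produced by onnx.matmul.
        mlir::Operation *producer = op->getOperand(0)->getDefiningOp();
        if (!producer || !llvm::isa<mlir::ONNXMatMulOp>(producer))
          return matchFailure();

        // add(matmul(X, Y), Z) ==> full_gemm(X, Y, Z), reusing the add's
        // result type.
        rewriter.replaceOpWithNewOp<mlir::ONNXFullGemmOp>(op,
            op->getResult(0)->getType(), producer->getOperand(0),
            producer->getOperand(1), op->getOperand(1));
        return matchSuccess();
      }
    };

Two details of the patch tie into this. First, src/main.cpp now runs mlir::createCanonicalizerPass() right after shape inference, so the rewrite fires during normal compilation. Second, MLIR passes and patterns register themselves through static initializers, which a plain static link is free to drop; the whole_archive_link(onnf-opt ${LIB_LIST}) call added to src/compiler/CMakeLists.txt forces those objects into onnf-opt, so an invocation along the lines of `onnf-opt -canonicalize <file>` (assuming the standard canonicalize pass registration, as in stock mlir-opt) can exercise the same rewrite.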