[MLIR] Add optimization based on TableGen pattern (#363)

* Define pattern.

* Fix file names.

* Add canonicalizer optimization based on TableGen pattern.

* Remove custom builders.

* Enable canonicalization in ONNF and ONNF-OPT (see the overview sketch below).
GHEORGHE-TEOD BERCEA 2019-11-12 13:37:46 -05:00 committed by Tian Jin
parent 780e6f0aa0
commit 63596e723f
9 changed files with 216 additions and 20 deletions
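For orientation, the pieces added in this commit fit together roughly as follows. This is a condensed sketch assembled from the hunks below, not extra code in the commit; MulAddToGemmOptPattern is the class TableGen generates from onnx_combine.td into onnx_combine.inc.

// onnx.td: mark onnx.add as having canonicalization patterns.
//   let hasCanonicalizer = 1;
// onnx_combine.td: declare the DRR rewrite MatMul + Add -> FullGemm,
//   expanded by onnf_tablegen(onnx_combine.inc -gen-rewriters).
// onnx_combine.cpp: hook the generated pattern into canonicalization.
void ONNXAddOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<MulAddToGemmOptPattern>(context);
}
// Drivers (onnf and onnf-opt): run the canonicalizer so the pattern fires.
//   pm.addPass(mlir::createCanonicalizerPass());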

View File

@@ -88,6 +88,11 @@ find_library(MLIR_LLVM_IR
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(MLIR_LIB_TRANSFORM_UTILS
NAMES MLIRTransformUtils
PATHS ${LLVM_PROJECT_LIB}
NO_DEFAULT_PATH)
find_library(LLVM_LIB_SUPPORT
NAMES LLVMSupport
PATHS ${LLVM_PROJECT_LIB}
@@ -106,6 +111,7 @@ set(MLIRLIBS
${MLIR_LIB_STANDARD_OPS}
${MLIR_LIB_OPT_MAIN}
${MLIR_LIB_SUPPORT}
${MLIR_LIB_TRANSFORM_UTILS}
${MLIR_LIB_ANALYSIS}
${MLIR_LIB_IR}
@@ -116,10 +122,31 @@ set(MLIRLIBS
${MLIR_LIB_STANDARD_OPS}
${MLIR_LIB_OPT_MAIN}
${MLIR_LIB_SUPPORT}
${MLIR_LIB_TRANSFORM_UTILS}
${LLVM_LIB_SUPPORT}
Threads::Threads)
function(whole_archive_link target)
if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
set(link_flags "-L${LLVM_BUILD}/lib ")
FOREACH(LIB ${ARGN})
string(CONCAT link_flags ${link_flags} "-Wl,-force_load ${LLVM_BUILD}/lib/lib${LIB}.a ")
ENDFOREACH(LIB)
elseif(MSVC)
FOREACH(LIB ${ARGN})
string(CONCAT link_flags ${link_flags} "/WHOLEARCHIVE:${LIB} ")
ENDFOREACH(LIB)
else()
set(link_flags "-L${LLVM_BUILD}/lib -Wl,--whole-archive,")
FOREACH(LIB ${ARGN})
string(CONCAT link_flags ${link_flags} "-l${LIB},")
ENDFOREACH(LIB)
string(CONCAT link_flags ${link_flags} "--no-whole-archive")
endif()
set_target_properties(${target} PROPERTIES LINK_FLAGS ${link_flags})
endfunction(whole_archive_link)
# Set up TableGen environment.
include(${LLVM_BUILD}/lib/cmake/llvm/TableGen.cmake)

View File

@@ -193,11 +193,33 @@ class FrontendGenImpl {
// ONNX Dialect.
llvm::StringRef OpName = node.op_type();
if (OpName == "Add") {
auto op =
builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
auto op = builder_.create<mlir::ONNXAddOp>(UnknownLoc(),
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
inputs[1]);
frontend_symbols_.AddMapping(
legalize_name(node.output()[0]), op.getResult());
return;
} else if (OpName == "MatMul") {
auto op = builder_.create<mlir::ONNXMatMulOp>(UnknownLoc(),
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
inputs[1]);
frontend_symbols_.AddMapping(
legalize_name(node.output()[0]), op.getResult());
return;
} else if (OpName == "Gemm") {
if (inputs.size() == 3) {
auto op = builder_.create<mlir::ONNXFullGemmOp>(UnknownLoc(),
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
inputs[1], inputs[2]);
frontend_symbols_.AddMapping(
legalize_name(node.output()[0]), op.getResult());
} else {
auto op = builder_.create<mlir::ONNXGemmOp>(UnknownLoc(),
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs);
frontend_symbols_.AddMapping(
legalize_name(node.output()[0]), op.getResult());
}
return;
}
// Old way of doing things.

View File

@@ -10,6 +10,7 @@ add_library(
dialect/krnl/parser_helper.hpp
pass/shape_inference_pass.cpp
pass/shape_inference_interface.hpp
pass/onnx_combine.cpp
pass/passes.hpp)
# Include root src directory.
@@ -32,7 +33,10 @@ target_link_libraries(compiler
${MLIRLIBS} curses)
add_executable(onnf-opt
tool/onnf_opt/onnf_opt.cpp)
set(LIB_LIST MLIRAffineOps MLIRLoopOps MLIRTransformUtils MLIREDSC MLIRTransforms)
whole_archive_link(onnf-opt ${LIB_LIST})
target_link_libraries(onnf-opt ${Boost_LIBRARIES} ${MLIRLIBS} curses compiler)
target_include_directories(onnf-opt PRIVATE ../..)
@@ -44,6 +48,11 @@ onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(gen_shape_inference)
add_dependencies(compiler gen_shape_inference)
set(LLVM_TARGET_DEFINITIONS pass/onnx_combine.td)
onnf_tablegen(onnx_combine.inc -gen-rewriters)
add_public_tablegen_target(gen_onnx_combine)
add_dependencies(compiler gen_onnx_combine)
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")

View File

@@ -58,13 +58,47 @@ def ONNXAddOp: ONNX_Op<"add",
//
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in);
let results = (outs AnyTensor);
let hasCanonicalizer = 1;
}
// Build an ONNX Add operation using two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildONNXAddOp(b, result, lhs, rhs);
}]
>];
def ONNXMatMulOp: ONNX_Op<"matmul",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX matrix multiply operation";
let description = [{
The "onnx.mul" multiplies two matrices.
}];
let arguments = (ins AnyTypeOf<[F32Tensor, F64Tensor]>:$lhs_in,
AnyTypeOf<[F32Tensor, F64Tensor]>:$rhs_in);
let results = (outs AnyTypeOf<[F32Tensor, F64Tensor]>);
}
def ONNXGemmOp: ONNX_Op<"gemm",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation";
let description = [{
The "onnx.gemm" generic matrix multiplication with bias.
}];
let arguments = (ins Variadic<AnyTensor>:$inputs);
let results = (outs AnyTensor);
}
def ONNXFullGemmOp: ONNX_Op<"full_gemm",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX general matrix multiply operation";
let description = [{
The "onnx.gemm" generic matrix multiplication with bias.
}];
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in);
let results = (outs AnyTensor);
}
#endif // ONNX_OPS
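The op descriptions above are terse. In ONNX terms, Gemm computes Y = A*B + C; the full ONNX operator also carries alpha, beta and transpose attributes, which these simplified ops do not model. As a plain C++ reference of the arithmetic (an editor's illustration only, ignoring bias broadcasting; not part of the dialect code):

#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Y[i][j] = sum_k A[i][k] * B[k][j] + C[i][j]
// A is MxK, B is KxN, C and Y are MxN.
Matrix fullGemm(const Matrix& A, const Matrix& B, const Matrix& C) {
  size_t M = A.size(), K = B.size(), N = B[0].size();
  Matrix Y(M, std::vector<float>(N, 0.0f));
  for (size_t i = 0; i < M; ++i)
    for (size_t j = 0; j < N; ++j) {
      for (size_t k = 0; k < K; ++k)
        Y[i][j] += A[i][k] * B[k][j];
      Y[i][j] += C[i][j];  // the bias_in operand of onnx.full_gemm
    }
  return Y;
}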

View File

@@ -40,18 +40,54 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
// ONNX Operations
//===----------------------------------------------------------------------===//
static void buildONNXAddOp(mlir::Builder* builder, mlir::OperationState& state,
mlir::Value* lhs, mlir::Value* rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF32Type()));
state.addOperands({lhs, rhs});
}
// Add
/// Infer the output shape of the ONNXAddOp. This method is required by the
/// shape inference interface.
void ONNXAddOp::inferShapes() {
getResult()->setType(getOperand(0)->getType());
}
//===----------------------------------------------------------------------===//
// MatMul
void ONNXMatMulOp::inferShapes() {
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
dims.emplace_back(lhsTy.getShape()[0]);
dims.emplace_back(rhsTy.getShape()[1]);
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
// TODO:
// Verify that matrix sizes are valid.
// Take into account the dimensionality of the matrix.
//===----------------------------------------------------------------------===//
// Gemm
void ONNXGemmOp::inferShapes() {
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
dims.emplace_back(lhsTy.getShape()[0]);
dims.emplace_back(rhsTy.getShape()[1]);
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
// FullGemm
void ONNXFullGemmOp::inferShapes() {
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
dims.emplace_back(lhsTy.getShape()[0]);
dims.emplace_back(rhsTy.getShape()[1]);
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
// TODO:
// Verify that matrix sizes are valid for multiplication and addition.
// Take into account the dimensionality of the matrix.
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
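The TODOs in the shape inference methods above ask for operand validation. As one possible sketch of that check for onnx.matmul (an editor's illustration, not part of this commit; it assumes static 2-D operands and the Value-pointer MLIR API used elsewhere in this file):

void ONNXMatMulOp::inferShapes() {
  // Cannot infer a result shape until both operands are ranked tensors.
  if (!getOperand(0)->getType().isa<RankedTensorType>() ||
      !getOperand(1)->getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
  // A is MxK and B is KxN; reject mismatched inner dimensions.
  if (lhsTy.getRank() != 2 || rhsTy.getRank() != 2 ||
      lhsTy.getShape()[1] != rhsTy.getShape()[0]) {
    emitError("onnx.matmul: incompatible operand shapes");
    return;
  }
  SmallVector<int64_t, 2> dims;
  dims.emplace_back(lhsTy.getShape()[0]);
  dims.emplace_back(rhsTy.getShape()[1]);
  getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}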

View File

@@ -0,0 +1,30 @@
//===- ONNXCombine.cpp - ONNX High Level Optimizer ------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a set of simple combiners for optimizing operations in
// the ONNX dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include <numeric>
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
using namespace mlir;
namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/compiler/onnx_combine.inc"
} // end anonymous namespace
/// Register optimization patterns as "canonicalization" patterns
/// on the ONNXAddOp.
void ONNXAddOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<MulAddToGemmOptPattern>(context);
}
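As an illustration of what the generated pattern does, here is a rough hand-written equivalent (an editor's sketch, assuming the MLIR C++ API of this commit's era with Value pointers and PatternMatchResult; the real class and its name MulAddToGemmOptPattern are emitted by TableGen into onnx_combine.inc and may differ in details):

struct MulAddToGemmByHand : public OpRewritePattern<ONNXAddOp> {
  using OpRewritePattern<ONNXAddOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(
      ONNXAddOp addOp, PatternRewriter& rewriter) const override {
    // Match: the left operand of onnx.add must come from an onnx.matmul.
    auto matmulOp = llvm::dyn_cast_or_null<ONNXMatMulOp>(
        addOp.getOperand(0)->getDefiningOp());
    if (!matmulOp)
      return matchFailure();
    // Rewrite: fold MatMul + Add into a single onnx.full_gemm, reusing the
    // add's result type and taking the add's right operand as the bias.
    rewriter.replaceOpWithNewOp<ONNXFullGemmOp>(addOp, addOp.getType(),
        matmulOp.getOperand(0), matmulOp.getOperand(1), addOp.getOperand(1));
    return matchSuccess();
  }
};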

View File

@@ -0,0 +1,35 @@
//=- ONNXCombine.td - Pattern Match Optimizations for ONNX -*- tablegen -*-===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// Defines language-specific pattern match optimizations for ONNX using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//
//===----------------------------------------------------------------------===//
#ifndef ONNX_COMBINE
#define ONNX_COMBINE
#ifndef OP_BASE
include "dialect/onnx/onnx.td"
#endif // OP_BASE
/// Note: The DRR definition used for defining patterns is shown below:
///
/// class Pattern<
/// dag sourcePattern, list<dag> resultPatterns,
/// list<dag> additionalConstraints = [],
/// dag benefitsAdded = (addBenefit 0)
/// >;
//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite
//===----------------------------------------------------------------------===//
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.full_gemm(%X, %Y, %Z)
def MulAddToGemmOptPattern : Pat<(ONNXAddOp(ONNXMatMulOp $m1, $m2), $m3),
(ONNXFullGemmOp $m1, $m2, $m3)>;
#endif // ONNX_COMBINE

View File

@@ -10,14 +10,17 @@
#include <llvm/Support/InitLLVM.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/ToolOutputFile.h>
#include <mlir/Dialect/StandardOps/Ops.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/MlirOptMain.h>
#include <mlir/Dialect/StandardOps/Ops.h>
#include "llvm/Support/SourceMgr.h"
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
#include "src/compiler/helper.hpp"
#include "src/compiler/pass/passes.hpp"
using namespace onnf;
@@ -45,6 +48,8 @@ int main(int argc, char** argv) {
int main(int argc, char** argv) {
llvm::InitLLVM y(argc, argv);
mlir::registerDialect<mlir::ONNXOpsDialect>();
// Register any pass manager command line options.
mlir::registerPassManagerCLOptions();
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
@@ -57,9 +62,6 @@ int main(int argc, char** argv) {
auto output = mlir::openOutputFile(output_filename, &error_message);
mlir::registerDialect<mlir::KrnlOpsDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
split_input_file, verify_diagnostics, verify_passes));
}

View File

@@ -76,6 +76,7 @@ int main(int ac, char* av[]) {
mlir::PassManager pm(&context);
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.run(*module);
return 0;
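
A minimal sketch of how this pipeline is driven end to end, given the MLIRContext and module already set up in the ONNF driver above (the failure check is added here for illustration; the diff returns 0 unconditionally):

mlir::PassManager pm(&context);
pm.addPass(mlir::createShapeInferencePass());  // refine unranked tensor types
pm.addPass(mlir::createCanonicalizerPass());   // applies MulAddToGemmOptPattern
if (mlir::failed(pm.run(*module)))
  return 1;  // sketch: report pipeline failure instead of returning 0 blindly
return 0;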