[MLIR] Add optimization based on TableGen pattern (#363)
* Define pattern. * Fix file names. * Add canonicalizer optimization based on TableGen pattern. * Remove custom builders. * Enable canonicalization in ONNF and ONNF-OPT.
This commit is contained in:
parent
780e6f0aa0
commit
63596e723f
27
MLIR.cmake
27
MLIR.cmake
|
@ -88,6 +88,11 @@ find_library(MLIR_LLVM_IR
|
|||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_library(MLIR_LIB_TRANSFORM_UTILS
|
||||
NAMES MLIRTransformUtils
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
find_library(LLVM_LIB_SUPPORT
|
||||
NAMES LLVMSupport
|
||||
PATHS ${LLVM_PROJECT_LIB}
|
||||
|
@ -106,6 +111,7 @@ set(MLIRLIBS
|
|||
${MLIR_LIB_STANDARD_OPS}
|
||||
${MLIR_LIB_OPT_MAIN}
|
||||
${MLIR_LIB_SUPPORT}
|
||||
${MLIR_LIB_TRANSFORM_UTILS}
|
||||
|
||||
${MLIR_LIB_ANALYSIS}
|
||||
${MLIR_LIB_IR}
|
||||
|
@ -116,10 +122,31 @@ set(MLIRLIBS
|
|||
${MLIR_LIB_STANDARD_OPS}
|
||||
${MLIR_LIB_OPT_MAIN}
|
||||
${MLIR_LIB_SUPPORT}
|
||||
${MLIR_LIB_TRANSFORM_UTILS}
|
||||
|
||||
${LLVM_LIB_SUPPORT}
|
||||
Threads::Threads)
|
||||
|
||||
function(whole_archive_link target)
|
||||
if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
|
||||
set(link_flags "-L${LLVM_BUILD}/lib ")
|
||||
FOREACH(LIB ${ARGN})
|
||||
string(CONCAT link_flags ${link_flags} "-Wl,-force_load ${LLVM_BUILD}/lib/lib${LIB}.a ")
|
||||
ENDFOREACH(LIB)
|
||||
elseif(MSVC)
|
||||
FOREACH(LIB ${ARGN})
|
||||
string(CONCAT link_flags ${link_flags} "/WHOLEARCHIVE:${LIB} ")
|
||||
ENDFOREACH(LIB)
|
||||
else()
|
||||
set(link_flags "-L${LLVM_BUILD}/lib -Wl,--whole-archive,")
|
||||
FOREACH(LIB ${ARGN})
|
||||
string(CONCAT link_flags ${link_flags} "-l${LIB},")
|
||||
ENDFOREACH(LIB)
|
||||
string(CONCAT link_flags ${link_flags} "--no-whole-archive")
|
||||
endif()
|
||||
set_target_properties(${target} PROPERTIES LINK_FLAGS ${link_flags})
|
||||
endfunction(whole_archive_link)
|
||||
|
||||
# Set up TableGen environment.
|
||||
include(${LLVM_BUILD}/lib/cmake/llvm/TableGen.cmake)
|
||||
|
||||
|
|
|
@ -193,11 +193,33 @@ class FrontendGenImpl {
|
|||
// ONNX Dialect.
|
||||
llvm::StringRef OpName = node.op_type();
|
||||
if (OpName == "Add") {
|
||||
auto op =
|
||||
builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
|
||||
auto op = builder_.create<mlir::ONNXAddOp>(UnknownLoc(),
|
||||
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
|
||||
inputs[1]);
|
||||
frontend_symbols_.AddMapping(
|
||||
legalize_name(node.output()[0]), op.getResult());
|
||||
return;
|
||||
} else if (OpName == "MatMul") {
|
||||
auto op = builder_.create<mlir::ONNXMatMulOp>(UnknownLoc(),
|
||||
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
|
||||
inputs[1]);
|
||||
frontend_symbols_.AddMapping(
|
||||
legalize_name(node.output()[0]), op.getResult());
|
||||
return;
|
||||
} else if (OpName == "Gemm") {
|
||||
if (inputs.size() == 3) {
|
||||
auto op = builder_.create<mlir::ONNXFullGemmOp>(UnknownLoc(),
|
||||
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
|
||||
inputs[1], inputs[2]);
|
||||
frontend_symbols_.AddMapping(
|
||||
legalize_name(node.output()[0]), op.getResult());
|
||||
} else {
|
||||
auto op = builder_.create<mlir::ONNXGemmOp>(UnknownLoc(),
|
||||
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs);
|
||||
frontend_symbols_.AddMapping(
|
||||
legalize_name(node.output()[0]), op.getResult());
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Old way of doing things.
|
||||
|
|
|
@ -10,6 +10,7 @@ add_library(
|
|||
dialect/krnl/parser_helper.hpp
|
||||
pass/shape_inference_pass.cpp
|
||||
pass/shape_inference_interface.hpp
|
||||
pass/onnx_combine.cpp
|
||||
pass/passes.hpp)
|
||||
|
||||
# Include root src directory.
|
||||
|
@ -32,7 +33,10 @@ target_link_libraries(compiler
|
|||
${MLIRLIBS} curses)
|
||||
|
||||
add_executable(onnf-opt
|
||||
tool/onnf_opt/onnf_opt.cpp)
|
||||
tool/onnf_opt/onnf_opt.cpp)
|
||||
|
||||
set(LIB_LIST MLIRAffineOps MLIRLoopOps MLIRTransformUtils MLIREDSC MLIRTransforms)
|
||||
whole_archive_link(onnf-opt ${LIB_LIST})
|
||||
|
||||
target_link_libraries(onnf-opt ${Boost_LIBRARIES} ${MLIRLIBS} curses compiler)
|
||||
target_include_directories(onnf-opt PRIVATE ../..)
|
||||
|
@ -44,6 +48,11 @@ onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
|
|||
add_public_tablegen_target(gen_shape_inference)
|
||||
add_dependencies(compiler gen_shape_inference)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS pass/onnx_combine.td)
|
||||
onnf_tablegen(onnx_combine.inc -gen-rewriters)
|
||||
add_public_tablegen_target(gen_onnx_combine)
|
||||
add_dependencies(compiler gen_onnx_combine)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||
|
|
|
@ -58,13 +58,47 @@ def ONNXAddOp: ONNX_Op<"add",
|
|||
//
|
||||
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in);
|
||||
let results = (outs AnyTensor);
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
// Build an ONNX Add operation using two input operands.
|
||||
let builders = [
|
||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
||||
buildONNXAddOp(b, result, lhs, rhs);
|
||||
}]
|
||||
>];
|
||||
def ONNXMatMulOp: ONNX_Op<"matmul",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX matrix multiply operation";
|
||||
let description = [{
|
||||
|
||||
The "onnx.mul" multiplies two matrices.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTypeOf<[F32Tensor, F64Tensor]>:$lhs_in,
|
||||
AnyTypeOf<[F32Tensor, F64Tensor]>:$rhs_in);
|
||||
let results = (outs AnyTypeOf<[F32Tensor, F64Tensor]>);
|
||||
}
|
||||
|
||||
def ONNXGemmOp: ONNX_Op<"gemm",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX general matrix multiply operation";
|
||||
let description = [{
|
||||
|
||||
The "onnx.gemm" generic matrix multiplication with bias.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyTensor>:$inputs);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
def ONNXFullGemmOp: ONNX_Op<"full_gemm",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX general matrix multiply operation";
|
||||
let description = [{
|
||||
|
||||
The "onnx.gemm" generic matrix multiplication with bias.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
#endif // ONNX_OPS
|
||||
|
|
|
@ -40,18 +40,54 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
|
|||
// ONNX Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void buildONNXAddOp(mlir::Builder* builder, mlir::OperationState& state,
|
||||
mlir::Value* lhs, mlir::Value* rhs) {
|
||||
state.addTypes(UnrankedTensorType::get(builder->getF32Type()));
|
||||
state.addOperands({lhs, rhs});
|
||||
}
|
||||
// Add
|
||||
|
||||
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXAddOp::inferShapes() {
|
||||
getResult()->setType(getOperand(0)->getType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// MatMul
|
||||
|
||||
void ONNXMatMulOp::inferShapes() {
|
||||
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
|
||||
dims.emplace_back(rhsTy.getShape()[1]);
|
||||
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// Verify that matrix sizes are valid.
|
||||
// Take into account the dimensionality of the matrix.
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Gemm
|
||||
|
||||
void ONNXGemmOp::inferShapes() {
|
||||
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
|
||||
dims.emplace_back(rhsTy.getShape()[1]);
|
||||
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
}
|
||||
|
||||
// FullGemm
|
||||
|
||||
void ONNXFullGemmOp::inferShapes() {
|
||||
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
|
||||
dims.emplace_back(rhsTy.getShape()[1]);
|
||||
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// Verify that matrix sizes are valid for multiplication and addition.
|
||||
// Take into account the dimensionality of the matrix.
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
//===- ONNXCombine.cpp - ONNX High Level Optimizer ------------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a set of simple combiners for optimizing operations in
|
||||
// the ONNX dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include <numeric>
|
||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||
#include "src/compiler/onnx_combine.inc"
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Register optimization patterns as "canonicalization" patterns
|
||||
/// on the ONNXMatMultOp.
|
||||
void ONNXAddOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<MulAddToGemmOptPattern>(context);
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
//=- ONNXCombine.td - Pattern Match Optimizations for ONNX -*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines language-specific pattern match optimizations for ONNX using
|
||||
// Declarative Rewrite Rules (DRR) specified using TableGen records.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ONNX_COMBINE
|
||||
#define ONNX_COMBINE
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "dialect/onnx/onnx.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
/// Note: The DRR definition used for defining patterns is shown below:
|
||||
///
|
||||
/// class Pattern<
|
||||
/// dag sourcePattern, list<dag> resultPatterns,
|
||||
/// list<dag> additionalConstraints = [],
|
||||
/// dag benefitsAdded = (addBenefit 0)
|
||||
/// >;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern-Match and Rewrite
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// onnx.add(onnx.matmult(%X, %Y), %Z) = onnx.gemm(%X, %Y, %Z)
|
||||
def MulAddToGemmOptPattern : Pat<(ONNXAddOp(ONNXMatMulOp $m1, $m2), $m3),
|
||||
(ONNXFullGemmOp $m1, $m2, $m3)>;
|
||||
|
||||
#endif // ONNX_COMBINE
|
|
@ -10,14 +10,17 @@
|
|||
#include <llvm/Support/InitLLVM.h>
|
||||
#include <llvm/Support/MemoryBuffer.h>
|
||||
#include <llvm/Support/ToolOutputFile.h>
|
||||
#include <mlir/Dialect/StandardOps/Ops.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
#include <mlir/Support/FileUtilities.h>
|
||||
#include <mlir/Support/MlirOptMain.h>
|
||||
#include <mlir/Dialect/StandardOps/Ops.h>
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
|
||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
#include "src/compiler/helper.hpp"
|
||||
#include "src/compiler/pass/passes.hpp"
|
||||
|
||||
using namespace onnf;
|
||||
|
||||
|
@ -45,6 +48,8 @@ static llvm::cl::opt<bool> verify_passes("verify-each",
|
|||
int main(int argc, char** argv) {
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
|
||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||
|
||||
// Register any pass manager command line options.
|
||||
mlir::registerPassManagerCLOptions();
|
||||
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
|
||||
|
@ -57,9 +62,6 @@ int main(int argc, char** argv) {
|
|||
|
||||
auto output = mlir::openOutputFile(output_filename, &error_message);
|
||||
|
||||
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
|
||||
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
|
||||
split_input_file, verify_diagnostics, verify_passes));
|
||||
}
|
||||
|
|
|
@ -76,6 +76,7 @@ int main(int ac, char* av[]) {
|
|||
|
||||
mlir::PassManager pm(&context);
|
||||
pm.addPass(mlir::createShapeInferencePass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.run(*module);
|
||||
|
||||
return 0;
|
||||
|
|
Loading…
Reference in New Issue