[MLIR] Add optimization based on TableGen pattern (#363)
* Define pattern. * Fix file names. * Add canonicalizer optimization based on TableGen pattern. * Remove custom builders. * Enable canonicalization in ONNF and ONNF-OPT.
This commit is contained in:
parent
780e6f0aa0
commit
63596e723f
27
MLIR.cmake
27
MLIR.cmake
|
@ -88,6 +88,11 @@ find_library(MLIR_LLVM_IR
|
||||||
PATHS ${LLVM_PROJECT_LIB}
|
PATHS ${LLVM_PROJECT_LIB}
|
||||||
NO_DEFAULT_PATH)
|
NO_DEFAULT_PATH)
|
||||||
|
|
||||||
|
find_library(MLIR_LIB_TRANSFORM_UTILS
|
||||||
|
NAMES MLIRTransformUtils
|
||||||
|
PATHS ${LLVM_PROJECT_LIB}
|
||||||
|
NO_DEFAULT_PATH)
|
||||||
|
|
||||||
find_library(LLVM_LIB_SUPPORT
|
find_library(LLVM_LIB_SUPPORT
|
||||||
NAMES LLVMSupport
|
NAMES LLVMSupport
|
||||||
PATHS ${LLVM_PROJECT_LIB}
|
PATHS ${LLVM_PROJECT_LIB}
|
||||||
|
@ -106,6 +111,7 @@ set(MLIRLIBS
|
||||||
${MLIR_LIB_STANDARD_OPS}
|
${MLIR_LIB_STANDARD_OPS}
|
||||||
${MLIR_LIB_OPT_MAIN}
|
${MLIR_LIB_OPT_MAIN}
|
||||||
${MLIR_LIB_SUPPORT}
|
${MLIR_LIB_SUPPORT}
|
||||||
|
${MLIR_LIB_TRANSFORM_UTILS}
|
||||||
|
|
||||||
${MLIR_LIB_ANALYSIS}
|
${MLIR_LIB_ANALYSIS}
|
||||||
${MLIR_LIB_IR}
|
${MLIR_LIB_IR}
|
||||||
|
@ -116,10 +122,31 @@ set(MLIRLIBS
|
||||||
${MLIR_LIB_STANDARD_OPS}
|
${MLIR_LIB_STANDARD_OPS}
|
||||||
${MLIR_LIB_OPT_MAIN}
|
${MLIR_LIB_OPT_MAIN}
|
||||||
${MLIR_LIB_SUPPORT}
|
${MLIR_LIB_SUPPORT}
|
||||||
|
${MLIR_LIB_TRANSFORM_UTILS}
|
||||||
|
|
||||||
${LLVM_LIB_SUPPORT}
|
${LLVM_LIB_SUPPORT}
|
||||||
Threads::Threads)
|
Threads::Threads)
|
||||||
|
|
||||||
|
function(whole_archive_link target)
|
||||||
|
if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
|
||||||
|
set(link_flags "-L${LLVM_BUILD}/lib ")
|
||||||
|
FOREACH(LIB ${ARGN})
|
||||||
|
string(CONCAT link_flags ${link_flags} "-Wl,-force_load ${LLVM_BUILD}/lib/lib${LIB}.a ")
|
||||||
|
ENDFOREACH(LIB)
|
||||||
|
elseif(MSVC)
|
||||||
|
FOREACH(LIB ${ARGN})
|
||||||
|
string(CONCAT link_flags ${link_flags} "/WHOLEARCHIVE:${LIB} ")
|
||||||
|
ENDFOREACH(LIB)
|
||||||
|
else()
|
||||||
|
set(link_flags "-L${LLVM_BUILD}/lib -Wl,--whole-archive,")
|
||||||
|
FOREACH(LIB ${ARGN})
|
||||||
|
string(CONCAT link_flags ${link_flags} "-l${LIB},")
|
||||||
|
ENDFOREACH(LIB)
|
||||||
|
string(CONCAT link_flags ${link_flags} "--no-whole-archive")
|
||||||
|
endif()
|
||||||
|
set_target_properties(${target} PROPERTIES LINK_FLAGS ${link_flags})
|
||||||
|
endfunction(whole_archive_link)
|
||||||
|
|
||||||
# Set up TableGen environment.
|
# Set up TableGen environment.
|
||||||
include(${LLVM_BUILD}/lib/cmake/llvm/TableGen.cmake)
|
include(${LLVM_BUILD}/lib/cmake/llvm/TableGen.cmake)
|
||||||
|
|
||||||
|
|
|
@ -193,11 +193,33 @@ class FrontendGenImpl {
|
||||||
// ONNX Dialect.
|
// ONNX Dialect.
|
||||||
llvm::StringRef OpName = node.op_type();
|
llvm::StringRef OpName = node.op_type();
|
||||||
if (OpName == "Add") {
|
if (OpName == "Add") {
|
||||||
auto op =
|
auto op = builder_.create<mlir::ONNXAddOp>(UnknownLoc(),
|
||||||
builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
|
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
|
||||||
|
inputs[1]);
|
||||||
frontend_symbols_.AddMapping(
|
frontend_symbols_.AddMapping(
|
||||||
legalize_name(node.output()[0]), op.getResult());
|
legalize_name(node.output()[0]), op.getResult());
|
||||||
return;
|
return;
|
||||||
|
} else if (OpName == "MatMul") {
|
||||||
|
auto op = builder_.create<mlir::ONNXMatMulOp>(UnknownLoc(),
|
||||||
|
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
|
||||||
|
inputs[1]);
|
||||||
|
frontend_symbols_.AddMapping(
|
||||||
|
legalize_name(node.output()[0]), op.getResult());
|
||||||
|
return;
|
||||||
|
} else if (OpName == "Gemm") {
|
||||||
|
if (inputs.size() == 3) {
|
||||||
|
auto op = builder_.create<mlir::ONNXFullGemmOp>(UnknownLoc(),
|
||||||
|
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
|
||||||
|
inputs[1], inputs[2]);
|
||||||
|
frontend_symbols_.AddMapping(
|
||||||
|
legalize_name(node.output()[0]), op.getResult());
|
||||||
|
} else {
|
||||||
|
auto op = builder_.create<mlir::ONNXGemmOp>(UnknownLoc(),
|
||||||
|
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs);
|
||||||
|
frontend_symbols_.AddMapping(
|
||||||
|
legalize_name(node.output()[0]), op.getResult());
|
||||||
|
}
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Old way of doing things.
|
// Old way of doing things.
|
||||||
|
|
|
@ -10,6 +10,7 @@ add_library(
|
||||||
dialect/krnl/parser_helper.hpp
|
dialect/krnl/parser_helper.hpp
|
||||||
pass/shape_inference_pass.cpp
|
pass/shape_inference_pass.cpp
|
||||||
pass/shape_inference_interface.hpp
|
pass/shape_inference_interface.hpp
|
||||||
|
pass/onnx_combine.cpp
|
||||||
pass/passes.hpp)
|
pass/passes.hpp)
|
||||||
|
|
||||||
# Include root src directory.
|
# Include root src directory.
|
||||||
|
@ -32,7 +33,10 @@ target_link_libraries(compiler
|
||||||
${MLIRLIBS} curses)
|
${MLIRLIBS} curses)
|
||||||
|
|
||||||
add_executable(onnf-opt
|
add_executable(onnf-opt
|
||||||
tool/onnf_opt/onnf_opt.cpp)
|
tool/onnf_opt/onnf_opt.cpp)
|
||||||
|
|
||||||
|
set(LIB_LIST MLIRAffineOps MLIRLoopOps MLIRTransformUtils MLIREDSC MLIRTransforms)
|
||||||
|
whole_archive_link(onnf-opt ${LIB_LIST})
|
||||||
|
|
||||||
target_link_libraries(onnf-opt ${Boost_LIBRARIES} ${MLIRLIBS} curses compiler)
|
target_link_libraries(onnf-opt ${Boost_LIBRARIES} ${MLIRLIBS} curses compiler)
|
||||||
target_include_directories(onnf-opt PRIVATE ../..)
|
target_include_directories(onnf-opt PRIVATE ../..)
|
||||||
|
@ -44,6 +48,11 @@ onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
|
||||||
add_public_tablegen_target(gen_shape_inference)
|
add_public_tablegen_target(gen_shape_inference)
|
||||||
add_dependencies(compiler gen_shape_inference)
|
add_dependencies(compiler gen_shape_inference)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS pass/onnx_combine.td)
|
||||||
|
onnf_tablegen(onnx_combine.inc -gen-rewriters)
|
||||||
|
add_public_tablegen_target(gen_onnx_combine)
|
||||||
|
add_dependencies(compiler gen_onnx_combine)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||||
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
|
|
|
@ -58,13 +58,47 @@ def ONNXAddOp: ONNX_Op<"add",
|
||||||
//
|
//
|
||||||
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in);
|
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in);
|
||||||
let results = (outs AnyTensor);
|
let results = (outs AnyTensor);
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Build an ONNX Add operation using two input operands.
|
def ONNXMatMulOp: ONNX_Op<"matmul",
|
||||||
let builders = [
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
let summary = "ONNX matrix multiply operation";
|
||||||
buildONNXAddOp(b, result, lhs, rhs);
|
let description = [{
|
||||||
}]
|
|
||||||
>];
|
The "onnx.mul" multiplies two matrices.
|
||||||
|
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins AnyTypeOf<[F32Tensor, F64Tensor]>:$lhs_in,
|
||||||
|
AnyTypeOf<[F32Tensor, F64Tensor]>:$rhs_in);
|
||||||
|
let results = (outs AnyTypeOf<[F32Tensor, F64Tensor]>);
|
||||||
|
}
|
||||||
|
|
||||||
|
def ONNXGemmOp: ONNX_Op<"gemm",
|
||||||
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
|
let summary = "ONNX general matrix multiply operation";
|
||||||
|
let description = [{
|
||||||
|
|
||||||
|
The "onnx.gemm" generic matrix multiplication with bias.
|
||||||
|
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Variadic<AnyTensor>:$inputs);
|
||||||
|
let results = (outs AnyTensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
def ONNXFullGemmOp: ONNX_Op<"full_gemm",
|
||||||
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
|
let summary = "ONNX general matrix multiply operation";
|
||||||
|
let description = [{
|
||||||
|
|
||||||
|
The "onnx.gemm" generic matrix multiplication with bias.
|
||||||
|
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in);
|
||||||
|
let results = (outs AnyTensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // ONNX_OPS
|
#endif // ONNX_OPS
|
||||||
|
|
|
@ -40,18 +40,54 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
|
||||||
// ONNX Operations
|
// ONNX Operations
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static void buildONNXAddOp(mlir::Builder* builder, mlir::OperationState& state,
|
// Add
|
||||||
mlir::Value* lhs, mlir::Value* rhs) {
|
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF32Type()));
|
|
||||||
state.addOperands({lhs, rhs});
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
|
||||||
/// shape inference interface.
|
|
||||||
void ONNXAddOp::inferShapes() {
|
void ONNXAddOp::inferShapes() {
|
||||||
getResult()->setType(getOperand(0)->getType());
|
getResult()->setType(getOperand(0)->getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// MatMul
|
||||||
|
|
||||||
|
void ONNXMatMulOp::inferShapes() {
|
||||||
|
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
||||||
|
SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
|
||||||
|
dims.emplace_back(rhsTy.getShape()[1]);
|
||||||
|
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO:
|
||||||
|
// Verify that matrix sizes are valid.
|
||||||
|
// Take into account the dimensionality of the matrix.
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Gemm
|
||||||
|
|
||||||
|
void ONNXGemmOp::inferShapes() {
|
||||||
|
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
||||||
|
SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
|
||||||
|
dims.emplace_back(rhsTy.getShape()[1]);
|
||||||
|
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// FullGemm
|
||||||
|
|
||||||
|
void ONNXFullGemmOp::inferShapes() {
|
||||||
|
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
||||||
|
SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
|
||||||
|
dims.emplace_back(rhsTy.getShape()[1]);
|
||||||
|
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO:
|
||||||
|
// Verify that matrix sizes are valid for multiplication and addition.
|
||||||
|
// Take into account the dimensionality of the matrix.
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
//===- ONNXCombine.cpp - ONNX High Level Optimizer ------------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file implements a set of simple combiners for optimizing operations in
|
||||||
|
// the ONNX dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||||
|
#include "src/compiler/onnx_combine.inc"
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
/// Register optimization patterns as "canonicalization" patterns
|
||||||
|
/// on the ONNXMatMultOp.
|
||||||
|
void ONNXAddOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList& results, MLIRContext* context) {
|
||||||
|
results.insert<MulAddToGemmOptPattern>(context);
|
||||||
|
}
|
|
@ -0,0 +1,35 @@
|
||||||
|
//=- ONNXCombine.td - Pattern Match Optimizations for ONNX -*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// Defines language-specific pattern match optimizations for ONNX using
|
||||||
|
// Declarative Rewrite Rules (DRR) specified using TableGen records.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef ONNX_COMBINE
|
||||||
|
#define ONNX_COMBINE
|
||||||
|
|
||||||
|
#ifndef OP_BASE
|
||||||
|
include "dialect/onnx/onnx.td"
|
||||||
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
/// Note: The DRR definition used for defining patterns is shown below:
|
||||||
|
///
|
||||||
|
/// class Pattern<
|
||||||
|
/// dag sourcePattern, list<dag> resultPatterns,
|
||||||
|
/// list<dag> additionalConstraints = [],
|
||||||
|
/// dag benefitsAdded = (addBenefit 0)
|
||||||
|
/// >;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pattern-Match and Rewrite
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// onnx.add(onnx.matmult(%X, %Y), %Z) = onnx.gemm(%X, %Y, %Z)
|
||||||
|
def MulAddToGemmOptPattern : Pat<(ONNXAddOp(ONNXMatMulOp $m1, $m2), $m3),
|
||||||
|
(ONNXFullGemmOp $m1, $m2, $m3)>;
|
||||||
|
|
||||||
|
#endif // ONNX_COMBINE
|
|
@ -10,14 +10,17 @@
|
||||||
#include <llvm/Support/InitLLVM.h>
|
#include <llvm/Support/InitLLVM.h>
|
||||||
#include <llvm/Support/MemoryBuffer.h>
|
#include <llvm/Support/MemoryBuffer.h>
|
||||||
#include <llvm/Support/ToolOutputFile.h>
|
#include <llvm/Support/ToolOutputFile.h>
|
||||||
|
#include <mlir/Dialect/StandardOps/Ops.h>
|
||||||
#include <mlir/Pass/Pass.h>
|
#include <mlir/Pass/Pass.h>
|
||||||
#include <mlir/Pass/PassManager.h>
|
#include <mlir/Pass/PassManager.h>
|
||||||
#include <mlir/Support/FileUtilities.h>
|
#include <mlir/Support/FileUtilities.h>
|
||||||
#include <mlir/Support/MlirOptMain.h>
|
#include <mlir/Support/MlirOptMain.h>
|
||||||
#include <mlir/Dialect/StandardOps/Ops.h>
|
#include "llvm/Support/SourceMgr.h"
|
||||||
|
|
||||||
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||||
|
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||||
#include "src/compiler/helper.hpp"
|
#include "src/compiler/helper.hpp"
|
||||||
|
#include "src/compiler/pass/passes.hpp"
|
||||||
|
|
||||||
using namespace onnf;
|
using namespace onnf;
|
||||||
|
|
||||||
|
@ -45,6 +48,8 @@ static llvm::cl::opt<bool> verify_passes("verify-each",
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
llvm::InitLLVM y(argc, argv);
|
llvm::InitLLVM y(argc, argv);
|
||||||
|
|
||||||
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||||
|
|
||||||
// Register any pass manager command line options.
|
// Register any pass manager command line options.
|
||||||
mlir::registerPassManagerCLOptions();
|
mlir::registerPassManagerCLOptions();
|
||||||
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
|
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
|
||||||
|
@ -57,9 +62,6 @@ int main(int argc, char** argv) {
|
||||||
|
|
||||||
auto output = mlir::openOutputFile(output_filename, &error_message);
|
auto output = mlir::openOutputFile(output_filename, &error_message);
|
||||||
|
|
||||||
mlir::registerDialect<mlir::KrnlOpsDialect>();
|
|
||||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
|
||||||
|
|
||||||
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
|
return failed(mlir::MlirOptMain(output->os(), std::move(file), passPipeline,
|
||||||
split_input_file, verify_diagnostics, verify_passes));
|
split_input_file, verify_diagnostics, verify_passes));
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,6 +76,7 @@ int main(int ac, char* av[]) {
|
||||||
|
|
||||||
mlir::PassManager pm(&context);
|
mlir::PassManager pm(&context);
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
pm.run(*module);
|
pm.run(*module);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
|
|
Loading…
Reference in New Issue