diff --git a/MLIR.cmake b/MLIR.cmake index ba3fc12..bacb967 100644 --- a/MLIR.cmake +++ b/MLIR.cmake @@ -263,7 +263,9 @@ set(ONNXMLIRWholeArchiveLibs OMShapeInference OMShapeInferenceOpInterface OMAttributePromotion - OMPromotableConstOperandsOpInterface) + OMPromotableConstOperandsOpInterface + OMElideConstants + OMElideKrnlGlobalConstants) # Function to construct linkage option for the static libraries that must be # linked with --whole-archive (or equivalent). diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 37206c2..7f67912 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -34,6 +34,8 @@ target_link_libraries(onnx-mlir OMShapeInferenceOpInterface OMAttributePromotion OMPromotableConstOperandsOpInterface + OMElideConstants + OMElideKrnlGlobalConstants OMKrnlToAffine OMKrnlToLLVM OMONNXToKrnl diff --git a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp index 8389047..e5105d1 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp @@ -41,8 +41,8 @@ struct ONNXConstantOpLowering : public ConversionPattern { auto constantGlobal = rewriter.create(loc, memRefType, rewriter.getI64ArrayAttr(shape), - constantOp.value().getValue(), - rewriter.getStringAttr("constant_" + std::to_string(constantID))); + rewriter.getStringAttr("constant_" + std::to_string(constantID)), + constantOp.value().getValue()); // Increment constant ID: constantID++; diff --git a/src/Dialect/Krnl/KrnlOps.td b/src/Dialect/Krnl/KrnlOps.td index 0e85886..a2772f9 100644 --- a/src/Dialect/Krnl/KrnlOps.td +++ b/src/Dialect/Krnl/KrnlOps.td @@ -199,7 +199,7 @@ def KrnlGlobalOp : Op { Operation for holding global data values. }]; - let arguments = (ins AnyAttr:$shape, AnyAttr:$value, StrAttr:$name); + let arguments = (ins AnyAttr:$shape, StrAttr:$name, OptionalAttr:$value); let results = (outs AnyTypeOf<[AnyMemRef]>:$output); let parser = ?; diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp index 6ef8d62..8071908 100644 --- a/src/MainUtils.cpp +++ b/src/MainUtils.cpp @@ -9,6 +9,9 @@ //===----------------------------------------------------------------------===// #include "src/MainUtils.hpp" +#include +#include +#include using namespace std; using namespace onnx_mlir; @@ -35,9 +38,10 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context, } } -void EmitLLVMBitCode(const mlir::OwningModuleRef &module) { +void EmitLLVMBitCode( + const mlir::OwningModuleRef &module, string outputFilename) { error_code error; - llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error, + llvm::raw_fd_ostream moduleBitcodeStream(outputFilename, error, llvm::sys::fs::F_None); llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream); @@ -97,3 +101,67 @@ void processInputFile(string inputFilename, EmissionTargetType emissionTarget, LoadMLIR(inputFilename, context, module); } } + +void outputCode( + mlir::OwningModuleRef &module, string filename, string extension) { + // Start a separate process to redirect the model output. I/O redirection + // changes will not be visible to the parent process. + if (fork() == 0) { + const char * tempFilename = (filename + extension).c_str(); + freopen(tempFilename, "w", stderr); + module->dump(); + fclose(stderr); + exit(0); + } +} + +void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget, + mlir::MLIRContext &context, mlir::OwningModuleRef &module) { + // For EmitONNXIR and EmitMLIR the constant value are embedded in the code + // thus making the code hard to read. These values can be elided by emitting + // two versions of the same source code: + // (1) a version with all the constant values included meant for being passed + // back to onnx-mlir for further processing and stored in: + // + // .onnx.mlir + // + // (2) a version without constants meant for being inspected by users and + // stored in: + // + // .mlir + // + // In the case of the LLVM Dialect IR the constant values are grouped + // outside the function code at the beginning of the file in which case the + // elision of these constants is not strictly required. Elision is also not + // necessary when emitting the .bc file. + if (emissionTarget == EmitLLVMBC) { + // Write LLVM bitcode to disk. + string outputFilename = outputBaseName + ".bc"; + EmitLLVMBitCode(module, outputFilename); + printf("LLVM bitcode written to %s\n", outputFilename.c_str()); + } else { + // Emit the version with all constants included. + outputCode(module, outputBaseName, ".onnx.mlir"); + printf("Full MLIR code written to: \n\t%s\n\n", + (outputBaseName + ".onnx.mlir").c_str()); + + // Apply specific passes to clean up the code where necessary. + mlir::PassManager cleanSourcePM(&context); + if (emissionTarget == EmitONNXIR || emissionTarget == EmitONNXBasic) + cleanSourcePM.addPass(mlir::createElideConstantValuePass()); + if (emissionTarget == EmitMLIR) + cleanSourcePM.addPass(mlir::createElideConstGlobalValuePass()); + + if (emissionTarget == EmitONNXBasic || emissionTarget == EmitONNXIR || + emissionTarget == EmitMLIR) { + if (mlir::failed(cleanSourcePM.run(*module))) + llvm::errs() << "Could not apply simplification passes.\n"; + outputCode(module, outputBaseName, ".mlir"); + printf("Constant-free MLIR Code written to: \n\t%s\n\n", + (outputBaseName + ".mlir").c_str()); + + printf("Use:\n\t%s\nto continue lowering the code to other dialects.\n", + (outputBaseName + ".onnx.mlir").c_str()); + } + } +} \ No newline at end of file diff --git a/src/MainUtils.hpp b/src/MainUtils.hpp index 143f6fc..bf96cea 100644 --- a/src/MainUtils.hpp +++ b/src/MainUtils.hpp @@ -48,7 +48,8 @@ enum EmissionTargetType { void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context, mlir::OwningModuleRef &module); -void EmitLLVMBitCode(const mlir::OwningModuleRef &module); +void EmitLLVMBitCode( + const mlir::OwningModuleRef &module, std::string outputFilename); void registerDialects(); @@ -61,5 +62,13 @@ void addKrnlToAffinePasses(mlir::PassManager &pm); void addKrnlToLLVMPasses(mlir::PassManager &pm); void processInputFile(std::string inputFilename, - EmissionTargetType emissionTarget, mlir::MLIRContext &context, - mlir::OwningModuleRef &module); \ No newline at end of file + EmissionTargetType emissionTarget, mlir::MLIRContext &context, + mlir::OwningModuleRef &module); + +void outputCode( + mlir::OwningModuleRef &module, std::string filename, + std::string extension); + +void emitOutputFiles(std::string outputBaseName, + EmissionTargetType emissionTarget, mlir::MLIRContext &context, + mlir::OwningModuleRef &module); \ No newline at end of file diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index ced10ec..4048932 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -23,12 +23,18 @@ std::unique_ptr createShapeInferencePass(); /// Pass for promoting constant operands to attributes. std::unique_ptr createAttributePromotionPass(); +/// Pass for eliding the values of constant operations. +std::unique_ptr createElideConstantValuePass(); + /// Add pass for lowering to Krnl IR. std::unique_ptr createLowerToKrnlPass(); /// Pass for lowering frontend dialects to Krnl IR dialect. std::unique_ptr createLowerKrnlPass(); +/// Pass for eliding the values of global Krnl operations. +std::unique_ptr createElideConstGlobalValuePass(); + /// Pass for lowering Krnl dialect to LLVM dialect. std::unique_ptr createKrnlLowerToLLVMPass(); diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt index 8231004..c3d064b 100644 --- a/src/Transform/CMakeLists.txt +++ b/src/Transform/CMakeLists.txt @@ -24,4 +24,16 @@ add_dependencies(OMKrnlToLLVM OMKrnlOps OMONNXOps) +add_library(OMElideKrnlGlobalConstants + ElideKrnlGlobalConstants.cpp) +target_include_directories(OMElideKrnlGlobalConstants + PRIVATE + ${ONNX_MLIR_SRC_ROOT} + ${ONNX_MLIR_BIN_ROOT} + ${ONNX_MLIR_SRC_ROOT}) +target_link_libraries(OMElideKrnlGlobalConstants + ${MLIRLibs} + OMKrnlOps + OMONNXOps) + add_subdirectory(ONNX) diff --git a/src/Transform/ElideKrnlGlobalConstants.cpp b/src/Transform/ElideKrnlGlobalConstants.cpp new file mode 100644 index 0000000..3f61e35 --- /dev/null +++ b/src/Transform/ElideKrnlGlobalConstants.cpp @@ -0,0 +1,76 @@ +//===- ElideKrnlGlobalConstants.cpp - Krnl Constant lobal Value Elision ---===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// In practice, the constant values of Global Krnl operations may be large +// enough to hinder the readability of the MLIR intermediate representation. +// +// This file creates a pass which elides the explicit values of constant +// global operations. This pass has purely cosmetic purposes and should only be +// run to obtain a compact representation of the program when emitting Krnl +// dialect code. This pass should never be invoked on code meant to be run. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Pass/Passes.hpp" + +using namespace mlir; + +namespace { + +/*! + * RewritePattern that replaces existing constant Krnl global values + * with a similar operation which preserves all attributes except the value + * attribute. + */ + +class KrnlConstGlobalValueElision : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(KrnlGlobalOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + if (op.value().hasValue()) { + auto newGlobalOp = rewriter.create( + loc, op.getResult().getType(), op.shape(), op.name(), nullptr); + rewriter.replaceOp(op, newGlobalOp.getResult()); + } + + return success(); + } +}; + +/*! + * Function pass that performs constant value elision of Krnl globals. + */ +class ElideConstGlobalValuePass + : public mlir::FunctionPass { +public: + void runOnFunction() override { + auto function = getFunction(); + + ConversionTarget target(getContext()); + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + applyPatternsGreedily(function, patterns); + } +}; +} // namespace + +std::unique_ptr mlir::createElideConstGlobalValuePass() { + return std::make_unique(); +} + +static PassRegistration pass("elide-krnl-constants", + "Elide the constant values of the Global Krnl operations."); \ No newline at end of file diff --git a/src/Transform/LowerToLLVM.cpp b/src/Transform/LowerToLLVM.cpp index 468b89d..1f88bd2 100644 --- a/src/Transform/LowerToLLVM.cpp +++ b/src/Transform/LowerToLLVM.cpp @@ -143,9 +143,12 @@ public: OpBuilder::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); + assert(krnlGlobalOp.value().hasValue() && + "Krnl Global must always have a value"); + global = rewriter.create(loc, llvmGlobalType, /*isConstant=*/true, - LLVM::Linkage::Internal, name, krnlGlobalOp.value()); + LLVM::Linkage::Internal, name, krnlGlobalOp.value().getValue()); } // Some frequently used types. diff --git a/src/Transform/ONNX/AttributePromotion.cpp b/src/Transform/ONNX/AttributePromotion.cpp index ccc7f27..7abf2f8 100644 --- a/src/Transform/ONNX/AttributePromotion.cpp +++ b/src/Transform/ONNX/AttributePromotion.cpp @@ -1,5 +1,4 @@ -//===----- attribute_promotion.cpp - Attribute Promotion -//-------------------===// +//===----- attribute_promotion.cpp - Attribute Promotion ------------------===// // // Copyright 2020 The IBM Research Authors. // diff --git a/src/Transform/ONNX/CMakeLists.txt b/src/Transform/ONNX/CMakeLists.txt index 78ab2f7..d34127e 100644 --- a/src/Transform/ONNX/CMakeLists.txt +++ b/src/Transform/ONNX/CMakeLists.txt @@ -8,6 +8,12 @@ target_include_directories(OMAttributePromotion add_dependencies(OMAttributePromotion OMPromotableConstOperandsOpInterface) +add_library(OMElideConstants + ElideConstants.cpp) +target_include_directories(OMElideConstants + PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} + ${ONNF_MLIR_SRC_ROOT}) + set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td) onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters) add_public_tablegen_target(OMONNXRewriteIncGen) diff --git a/src/Transform/ONNX/ElideConstants.cpp b/src/Transform/ONNX/ElideConstants.cpp new file mode 100644 index 0000000..b8fc493 --- /dev/null +++ b/src/Transform/ONNX/ElideConstants.cpp @@ -0,0 +1,83 @@ +//===----- ElideConstants.cpp - Elide Constant Values ---------------------===// +// +// Copyright 2020 The IBM Research Authors. +// +// ============================================================================= +// +// In practice, the constant values of Constant operations may be large enough +// to hinder the readability of the MLIR intermediate representation. +// +// This file creates a pass which elides the explicit values of Constant +// operations. This pass has purely cosmetic purposes and should only be run to +// obtain a compact representation of the program when emitting ONNX and KRNL +// Dialect code. This pass should never be invoked on code meant to be run. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Pass/Passes.hpp" + +using namespace mlir; + +namespace { + +/*! + * RewritePattern that replaces existing Constant operations + * with Constant operations with the same shape information but + * no values. + */ + +class ConstantValueElision : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ONNXConstantOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto constOp = llvm::dyn_cast(&op); + + if (constOp->sparse_value().hasValue()) + emitError(loc, "Only support dense values at this time"); + + if (constOp->value().hasValue()) { + auto newConstOp = rewriter.create( + loc, constOp->getResult().getType(), nullptr, nullptr); + rewriter.replaceOp(op, newConstOp.getResult()); + } + return success(); + } +}; + +/*! + * Function pass that performs constant value elision. + */ +class ElideConstantValuePass + : public mlir::FunctionPass { +public: + void runOnFunction() override { + auto function = getFunction(); + + ConversionTarget target(getContext()); + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + applyPatternsGreedily(function, patterns); + } +}; +} // end anonymous namespace + +/*! + * Create a Constant Value Elision pass. + */ +std::unique_ptr mlir::createElideConstantValuePass() { + return std::make_unique(); +} + +static PassRegistration pass( + "elide-constants", "Elide values of constant operations."); diff --git a/src/main.cpp b/src/main.cpp index 8bbdef7..7cc57a1 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -43,9 +43,14 @@ int main(int argc, char *argv[]) { mlir::OwningModuleRef module; processInputFile(inputFilename, emissionTarget, context, module); + // Input file base name. + string outputBaseName = + inputFilename.substr(0, inputFilename.find_last_of(".")); + mlir::PassManager pm(&context); - if (emissionTarget >= EmitONNXIR) + if (emissionTarget >= EmitONNXIR) { addONNXToMLIRPasses(pm); + } if (emissionTarget >= EmitMLIR) { addONNXToKrnlPasses(pm); @@ -58,12 +63,7 @@ int main(int argc, char *argv[]) { if (mlir::failed(pm.run(*module))) return 4; - if (emissionTarget == EmitLLVMBC) { - // Write LLVM bitcode to disk. - EmitLLVMBitCode(module); - printf("LLVM bitcode written to ./model.bc"); - } else - module->dump(); + emitOutputFiles(outputBaseName, emissionTarget, context, module); return 0; } diff --git a/test/backend/test.py b/test/backend/test.py index c0cb3b3..36c37df 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -42,13 +42,13 @@ class DummyBackend(onnx.backend.base.Backend): execute_commands([ONNX_MLIR, "temp_model.onnx"]) # Call llc to generate object file from bitcode. execute_commands( - [LLC, "-filetype=obj", "-relocation-model=pic", "model.bc"]) + [LLC, "-filetype=obj", "-relocation-model=pic", "temp_model.bc"]) # Generate shared library from object file, linking with c runtime. execute_commands([ - CXX, "-shared", "-fPIC", "model.o", "-o", "model.so", + CXX, "-shared", "-fPIC", "temp_model.o", "-o", "temp_model.so", "-L" + RUNTIME_DIR, "-lcruntime" ]) - return ExecutionSession("./model.so", "_dyn_entry_point_main_graph") + return ExecutionSession("./temp_model.so", "_dyn_entry_point_main_graph") @classmethod def supports_device(cls, device): diff --git a/test/mlir/onnx/onnx_elide_constants.mlir b/test/mlir/onnx/onnx_elide_constants.mlir new file mode 100644 index 0000000..8455bbb --- /dev/null +++ b/test/mlir/onnx/onnx_elide_constants.mlir @@ -0,0 +1,10 @@ +// RUN: onnx-mlir-opt --elide-constants %s -split-input-file | FileCheck %s + +// CHECK-LABEL: func @test_elide_constant(%arg0: tensor<1xf32>) -> tensor<1x10xf32> +func @test_elide_constant(%arg0: tensor<1xf32>) -> tensor<1x10xf32> { + %0 = "onnx.Constant"() {value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> tensor<1x10xf32> + "std.return"(%0) : (tensor<1x10xf32>) -> () + + // CHECK: %0 = "onnx.Constant"() : () -> tensor<1x10xf32> + // CHECK: return %0 : tensor<1x10xf32> +} \ No newline at end of file diff --git a/test/mlir/onnx/onnx_krnl_global_elision.mlir b/test/mlir/onnx/onnx_krnl_global_elision.mlir new file mode 100644 index 0000000..9bce410 --- /dev/null +++ b/test/mlir/onnx/onnx_krnl_global_elision.mlir @@ -0,0 +1,10 @@ +// RUN: onnx-mlir-opt --elide-krnl-constants %s -split-input-file | FileCheck %s + +// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32> +func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32> { + %0 = "krnl.global"() {name = "constant_0", shape = [1, 10], value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> memref<1x10xf32> + return %0 : memref<1x10xf32> + + // CHECK: %0 = "krnl.global"() {name = "constant_0", shape = [1, 10]} : () -> memref<1x10xf32> + // CHECK: return %0 : memref<1x10xf32> +} \ No newline at end of file