From f278f081200f77cfc9ba0d9a4e52254486754db5 Mon Sep 17 00:00:00 2001 From: Gheorghe-Teodor Bercea Date: Wed, 19 Aug 2020 12:57:40 -0400 Subject: [PATCH] Introduce krnl.shape operation and its lowering (#267) * Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Add krnl shape op and lowering pass. * Add lowering function. * Clean-up code. * Remove duplicate entry. * Add test. * Update LowerKrnlShape.cpp * Update KrnlToLLVM.cpp --- MLIR.cmake | 9 ++ src/CMakeLists.txt | 3 +- src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp | 2 + .../ONNXToKrnl/ConvertONNXToKrnl.cpp | 3 +- src/Dialect/Krnl/KrnlOps.hpp | 1 + src/Dialect/Krnl/KrnlOps.td | 19 ++++ src/InitOMPasses.hpp | 8 +- src/MainUtils.cpp | 1 + src/MainUtils.hpp | 2 + src/Pass/Passes.hpp | 3 + src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp | 1 + src/Transform/CMakeLists.txt | 13 +++ src/Transform/LowerKrnlShape.cpp | 86 +++++++++++++++++++ test/mlir/krnl/krnl_shape_lowering.mlir | 26 ++++++ 14 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 src/Transform/LowerKrnlShape.cpp create mode 100644 test/mlir/krnl/krnl_shape_lowering.mlir diff --git a/MLIR.cmake b/MLIR.cmake index ab8a1a5..c8a7761 100644 --- a/MLIR.cmake +++ b/MLIR.cmake @@ -143,6 +143,7 @@ find_mlir_lib(MLIRCopyOpInterface) find_mlir_lib(MLIRDialect) find_mlir_lib(MLIREDSC) find_mlir_lib(MLIRExecutionEngine) +find_mlir_lib(MLIRInferTypeOpInterface) find_mlir_lib(MLIRIR) find_mlir_lib(MLIRLLVMIR) find_mlir_lib(MLIRLoopAnalysis) @@ -167,6 +168,10 @@ find_mlir_lib(MLIRTargetLLVMIR) find_mlir_lib(MLIRTransforms) find_mlir_lib(MLIRTransformUtils) find_mlir_lib(MLIRSupport) +find_mlir_lib(MLIRShape) +find_mlir_lib(MLIRShapeToStandard) +find_mlir_lib(MLIRShapeToSCF) +find_mlir_lib(MLIRSideEffectInterfaces) find_mlir_lib(MLIROpenMP) find_mlir_lib(MLIROptLib) find_mlir_lib(MLIRTableGen) @@ -259,6 +264,10 @@ set(MLIRLibs ${MLIRLinalgEDSC} ${MLIRViewLikeInterface} ${MLIRPresburger} + ${MLIRShape} + ${MLIRShapeToStandard} + ${MLIRShapeToSCF} + ${MLIRInferTypeOpInterface} # strict order verified ${LLVMBitWriter} ${LLVMObject} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d834ac1..68de557 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,7 +24,8 @@ set(OMLibs OMPackKrnlGlobalConstants OMEnableMemoryPool OMBundleMemoryPools - OMDisconnectKrnlDimFromAlloc) + OMDisconnectKrnlDimFromAlloc + OMLowerKrnlShape) set(OMLibs ${OMLibs} PARENT_SCOPE) add_subdirectory(Tool) diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp index 55564cb..7fec511 100644 --- a/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -857,6 +858,7 @@ void mlir::populateAffineAndKrnlToLLVMConversion( LLVMTypeConverter &typeConverter) { populateAffineToStdConversionPatterns(patterns, ctx); populateLoopToStdConversionPatterns(patterns, ctx); + populateShapeToStandardConversionPatterns(patterns, ctx); populateStdToLLVMConversionPatterns(typeConverter, patterns); patterns.insert( diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index 0639363..883e888 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -54,7 +54,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() { // We define the specific operations, or dialects, that are legal targets for // this lowering. - target.addLegalDialect(); + target.addLegalDialect(); // TODO: enable this once more ops are supported. // We also define the ONNX dialect as Illegal so that the conversion will fail diff --git a/src/Dialect/Krnl/KrnlOps.hpp b/src/Dialect/Krnl/KrnlOps.hpp index 1d4cd81..9460f00 100644 --- a/src/Dialect/Krnl/KrnlOps.hpp +++ b/src/Dialect/Krnl/KrnlOps.hpp @@ -10,6 +10,7 @@ #pragma once +#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" diff --git a/src/Dialect/Krnl/KrnlOps.td b/src/Dialect/Krnl/KrnlOps.td index 0f44171..1199ec1 100644 --- a/src/Dialect/Krnl/KrnlOps.td +++ b/src/Dialect/Krnl/KrnlOps.td @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// include "mlir/IR/OpBase.td" +include "mlir/Dialect/Shape/IR/ShapeBase.td" def Krnl_Dialect : Dialect { let name = "krnl"; @@ -355,3 +356,21 @@ def KrnlDimOp : Op { let parser = ?; let printer = ?; } + +def KrnlShapeOp : Op { + let summary = "Krnl operation to retreieve the shape of a MemRef."; + let description = [{ + Extracts the shape of a MemRef: + ``` + "krnl.shape"(%memref) + ``` + The return result is of `shape.type`. + }]; + + let arguments = (ins AnyTypeOf<[AnyMemRef]>:$alloc); + let results = (outs Shape_ShapeType:$shape); + + let parser = ?; + let printer = ?; +} + diff --git a/src/InitOMPasses.hpp b/src/InitOMPasses.hpp index 4521ad8..92500d6 100644 --- a/src/InitOMPasses.hpp +++ b/src/InitOMPasses.hpp @@ -81,5 +81,11 @@ void initOMPasses() { []() -> std::unique_ptr { return mlir::createDisconnectKrnlDimFromAllocPass(); }); + + mlir::registerPass("lower-krnl-shape", + "Lower krnl.shape operation to use Shape dialect operations.", + []() -> std::unique_ptr { + return mlir::createLowerKrnlShapePass(); + }); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp index 01793c0..4dfe1de 100644 --- a/src/MainUtils.cpp +++ b/src/MainUtils.cpp @@ -381,6 +381,7 @@ void registerDialects() { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); + mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); } diff --git a/src/MainUtils.hpp b/src/MainUtils.hpp index b680bc7..1155d08 100644 --- a/src/MainUtils.hpp +++ b/src/MainUtils.hpp @@ -26,11 +26,13 @@ #include "src/Pass/Passes.hpp" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/InitAllDialects.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 72c9857..024fdaa 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -43,6 +43,9 @@ std::unique_ptr createConvertKrnlToAffinePass(); /// Pass for lowering krnl.dim operations to standard dialect. std::unique_ptr createDisconnectKrnlDimFromAllocPass(); +/// Pass for lowering krnl.shape operation. +std::unique_ptr createLowerKrnlShapePass(); + /// Pass for eliding the values of global Krnl operations. std::unique_ptr createElideConstGlobalValuePass(); diff --git a/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp b/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp index 65560e0..2f78bfd 100644 --- a/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp +++ b/src/Tool/ONNXMLIROpt/ONNXMLIROpt.cpp @@ -66,6 +66,7 @@ int main(int argc, char **argv) { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); + mlir::registerDialect(); registerTransformsPasses(); registerAffinePasses(); diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt index dda900d..ef8c9e9 100644 --- a/src/Transform/CMakeLists.txt +++ b/src/Transform/CMakeLists.txt @@ -60,4 +60,17 @@ add_dependencies(OMDisconnectKrnlDimFromAlloc OMKrnlOps OMONNXOps) +add_library(OMLowerKrnlShape + LowerKrnlShape.cpp) +target_include_directories(OMLowerKrnlShape + PRIVATE + ${ONNX_MLIR_SRC_ROOT} + ${ONNX_MLIR_BIN_ROOT} + ${ONNX_MLIR_SRC_ROOT}) +target_link_libraries(OMLowerKrnlShape + onnx) +add_dependencies(OMLowerKrnlShape + OMKrnlOps + OMONNXOps) + add_subdirectory(ONNX) diff --git a/src/Transform/LowerKrnlShape.cpp b/src/Transform/LowerKrnlShape.cpp new file mode 100644 index 0000000..1452ad2 --- /dev/null +++ b/src/Transform/LowerKrnlShape.cpp @@ -0,0 +1,86 @@ +//===-------- LowerKrnlShape.cpp ------------------------------------------===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// This pass enables the lowering of the krnl.shape operation to use Shape +// dialect operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Pass/Passes.hpp" + +using namespace mlir; + +namespace { + +/*! + * RewritePattern that replaces: + * %0 = alloc(%d) : memref, #map> + * %1 = krnl.shape(%0) : memref> -> !shape.shape + * with: + * %0 = alloc(%d) : memref, #map> + * %c0 = constant 0 : index + * %1 = krnl.dim(%0, %c0) : memref, #map>, index + * %c1 = constant 1 : index + * %2 = krnl.dim(%0, %c1) : memref, #map>, index + * %shape = shape.from_extents %1, %2 + */ + +class LowerKrnlShape : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + KrnlShapeOp krnlShapeOp, PatternRewriter &rewriter) const override { + auto loc = krnlShapeOp.getLoc(); + auto rank = + convertToMemRefType(krnlShapeOp.alloc().getType()).getShape().size(); + + SmallVector fromExtentsOpOperands; + for (int idx = 0; idx < rank; idx++) { + auto index = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), idx)); + auto operand = rewriter.create( + loc, rewriter.getIndexType(), krnlShapeOp.alloc(), index); + fromExtentsOpOperands.emplace_back(operand); + } + + auto fromExtentsOp = rewriter.create( + loc, rewriter.getType(), fromExtentsOpOperands); + rewriter.replaceOp(krnlShapeOp, fromExtentsOp.getResult()); + + return success(); + } +}; + +/*! + * Function pass that emits the shape of a MemRef. + */ +class LowerKrnlShapePass + : public PassWrapper { +public: + void runOnFunction() override { + auto function = getFunction(); + + ConversionTarget target(getContext()); + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + applyPatternsAndFoldGreedily(function, patterns); + } +}; +} // namespace + +// TODO: integrate with other passes if needed. +std::unique_ptr mlir::createLowerKrnlShapePass() { + return std::make_unique(); +} diff --git a/test/mlir/krnl/krnl_shape_lowering.mlir b/test/mlir/krnl/krnl_shape_lowering.mlir new file mode 100644 index 0000000..20831a6 --- /dev/null +++ b/test/mlir/krnl/krnl_shape_lowering.mlir @@ -0,0 +1,26 @@ +// RUN: onnx-mlir-opt --lower-krnl-shape %s -split-input-file | FileCheck %s + +func @test_krnl_shape_lowering(%arg0: memref) -> index { + %c1 = constant 1 : index + %c0 = constant 0 : index + %0 = dim %arg0, %c0 : memref + %1 = alloc(%0) : memref + %shape = "krnl.shape"(%1) : (memref) -> !shape.shape + %d1 = "shape.get_extent"(%shape, %c1) : (!shape.shape, index) -> !shape.size + %ind = "shape.size_to_index"(%d1) : (!shape.size) -> index + %e = addi %ind, %ind : index + return %e : index + + // CHECK-LABEL: test_krnl_shape_lowering + // CHECK: [[CONST0:%.+]] = constant 0 : index + // CHECK: [[CONST1:%.+]] = constant 1 : index + // CHECK: [[DIM:%.+]] = dim %arg0, [[CONST0]] : memref + // CHECK: [[ALLOC:%.+]] = alloc([[DIM]]) : memref + // CHECK: [[DIM0:%.+]] = "krnl.dim"([[ALLOC]], [[CONST0]]) : (memref, index) -> index + // CHECK: [[DIM1:%.+]] = "krnl.dim"([[ALLOC]], [[CONST1]]) : (memref, index) -> index + // CHECK: [[SHAPE:%.+]] = shape.from_extents [[DIM0]], [[DIM1]] + // CHECK: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], [[CONST1]] : !shape.shape, index -> !shape.size + // CHECK: [[EXTENT_AS_INDEX:%.+]] = shape.size_to_index [[EXTENT]] : !shape.size + // CHECK: [[RES:%.+]] = addi [[EXTENT_AS_INDEX]], [[EXTENT_AS_INDEX]] : index + // CHECK: return [[RES]] : index +}