Introduce krnl.shape operation and its lowering (#267)

* Reorganize main function.

* Follow review comments.

* Emit constants are globals in Krnl and LLVM dialects.

* Add krnl shape op and lowering pass.

* Add lowering function.

* Clean-up code.

* Remove duplicate entry.

* Add test.

* Update LowerKrnlShape.cpp

* Update KrnlToLLVM.cpp
This commit is contained in:
Gheorghe-Teodor Bercea 2020-08-19 12:57:40 -04:00 committed by GitHub
parent 13e8070708
commit f278f08120
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 174 additions and 3 deletions

View File

@ -143,6 +143,7 @@ find_mlir_lib(MLIRCopyOpInterface)
find_mlir_lib(MLIRDialect) find_mlir_lib(MLIRDialect)
find_mlir_lib(MLIREDSC) find_mlir_lib(MLIREDSC)
find_mlir_lib(MLIRExecutionEngine) find_mlir_lib(MLIRExecutionEngine)
find_mlir_lib(MLIRInferTypeOpInterface)
find_mlir_lib(MLIRIR) find_mlir_lib(MLIRIR)
find_mlir_lib(MLIRLLVMIR) find_mlir_lib(MLIRLLVMIR)
find_mlir_lib(MLIRLoopAnalysis) find_mlir_lib(MLIRLoopAnalysis)
@ -167,6 +168,10 @@ find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransforms) find_mlir_lib(MLIRTransforms)
find_mlir_lib(MLIRTransformUtils) find_mlir_lib(MLIRTransformUtils)
find_mlir_lib(MLIRSupport) find_mlir_lib(MLIRSupport)
find_mlir_lib(MLIRShape)
find_mlir_lib(MLIRShapeToStandard)
find_mlir_lib(MLIRShapeToSCF)
find_mlir_lib(MLIRSideEffectInterfaces)
find_mlir_lib(MLIROpenMP) find_mlir_lib(MLIROpenMP)
find_mlir_lib(MLIROptLib) find_mlir_lib(MLIROptLib)
find_mlir_lib(MLIRTableGen) find_mlir_lib(MLIRTableGen)
@ -259,6 +264,10 @@ set(MLIRLibs
${MLIRLinalgEDSC} ${MLIRLinalgEDSC}
${MLIRViewLikeInterface} ${MLIRViewLikeInterface}
${MLIRPresburger} ${MLIRPresburger}
${MLIRShape}
${MLIRShapeToStandard}
${MLIRShapeToSCF}
${MLIRInferTypeOpInterface}
# strict order verified # strict order verified
${LLVMBitWriter} ${LLVMBitWriter}
${LLVMObject} ${LLVMObject}

View File

@ -24,7 +24,8 @@ set(OMLibs
OMPackKrnlGlobalConstants OMPackKrnlGlobalConstants
OMEnableMemoryPool OMEnableMemoryPool
OMBundleMemoryPools OMBundleMemoryPools
OMDisconnectKrnlDimFromAlloc) OMDisconnectKrnlDimFromAlloc
OMLowerKrnlShape)
set(OMLibs ${OMLibs} PARENT_SCOPE) set(OMLibs ${OMLibs} PARENT_SCOPE)
add_subdirectory(Tool) add_subdirectory(Tool)

View File

@ -10,6 +10,7 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@ -857,6 +858,7 @@ void mlir::populateAffineAndKrnlToLLVMConversion(
LLVMTypeConverter &typeConverter) { LLVMTypeConverter &typeConverter) {
populateAffineToStdConversionPatterns(patterns, ctx); populateAffineToStdConversionPatterns(patterns, ctx);
populateLoopToStdConversionPatterns(patterns, ctx); populateLoopToStdConversionPatterns(patterns, ctx);
populateShapeToStandardConversionPatterns(patterns, ctx);
populateStdToLLVMConversionPatterns(typeConverter, patterns); populateStdToLLVMConversionPatterns(typeConverter, patterns);
patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>( patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>(

View File

@ -54,7 +54,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
// We define the specific operations, or dialects, that are legal targets for // We define the specific operations, or dialects, that are legal targets for
// this lowering. // this lowering.
target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>(); target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect,
shape::ShapeDialect>();
// TODO: enable this once more ops are supported. // TODO: enable this once more ops are supported.
// We also define the ONNX dialect as Illegal so that the conversion will fail // We also define the ONNX dialect as Illegal so that the conversion will fail

View File

@ -10,6 +10,7 @@
#pragma once #pragma once
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"

View File

@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td"
def Krnl_Dialect : Dialect { def Krnl_Dialect : Dialect {
let name = "krnl"; let name = "krnl";
@ -355,3 +356,21 @@ def KrnlDimOp : Op<Krnl_Dialect, "dim"> {
let parser = ?; let parser = ?;
let printer = ?; let printer = ?;
} }
def KrnlShapeOp : Op<Krnl_Dialect, "shape"> {
let summary = "Krnl operation to retreieve the shape of a MemRef.";
let description = [{
Extracts the shape of a MemRef:
```
"krnl.shape"(%memref)
```
The return result is of `shape.type`.
}];
let arguments = (ins AnyTypeOf<[AnyMemRef]>:$alloc);
let results = (outs Shape_ShapeType:$shape);
let parser = ?;
let printer = ?;
}

View File

@ -81,5 +81,11 @@ void initOMPasses() {
[]() -> std::unique_ptr<mlir::Pass> { []() -> std::unique_ptr<mlir::Pass> {
return mlir::createDisconnectKrnlDimFromAllocPass(); return mlir::createDisconnectKrnlDimFromAllocPass();
}); });
mlir::registerPass("lower-krnl-shape",
"Lower krnl.shape operation to use Shape dialect operations.",
[]() -> std::unique_ptr<mlir::Pass> {
return mlir::createLowerKrnlShapePass();
});
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@ -381,6 +381,7 @@ void registerDialects() {
mlir::registerDialect<mlir::LLVM::LLVMDialect>(); mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::scf::SCFDialect>(); mlir::registerDialect<mlir::scf::SCFDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>(); mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::shape::ShapeDialect>();
mlir::registerDialect<mlir::ONNXOpsDialect>(); mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>(); mlir::registerDialect<mlir::KrnlOpsDialect>();
} }

View File

@ -26,11 +26,13 @@
#include "src/Pass/Passes.hpp" #include "src/Pass/Passes.hpp"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h" #include "mlir/IR/Module.h"
#include "mlir/InitAllDialects.h" #include "mlir/InitAllDialects.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser.h" #include "mlir/Parser.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"

View File

@ -43,6 +43,9 @@ std::unique_ptr<Pass> createConvertKrnlToAffinePass();
/// Pass for lowering krnl.dim operations to standard dialect. /// Pass for lowering krnl.dim operations to standard dialect.
std::unique_ptr<Pass> createDisconnectKrnlDimFromAllocPass(); std::unique_ptr<Pass> createDisconnectKrnlDimFromAllocPass();
/// Pass for lowering krnl.shape operation.
std::unique_ptr<Pass> createLowerKrnlShapePass();
/// Pass for eliding the values of global Krnl operations. /// Pass for eliding the values of global Krnl operations.
std::unique_ptr<Pass> createElideConstGlobalValuePass(); std::unique_ptr<Pass> createElideConstGlobalValuePass();

View File

@ -66,6 +66,7 @@ int main(int argc, char **argv) {
mlir::registerDialect<mlir::scf::SCFDialect>(); mlir::registerDialect<mlir::scf::SCFDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>(); mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::vector::VectorDialect>(); mlir::registerDialect<mlir::vector::VectorDialect>();
mlir::registerDialect<mlir::shape::ShapeDialect>();
registerTransformsPasses(); registerTransformsPasses();
registerAffinePasses(); registerAffinePasses();

View File

@ -60,4 +60,17 @@ add_dependencies(OMDisconnectKrnlDimFromAlloc
OMKrnlOps OMKrnlOps
OMONNXOps) OMONNXOps)
add_library(OMLowerKrnlShape
LowerKrnlShape.cpp)
target_include_directories(OMLowerKrnlShape
PRIVATE
${ONNX_MLIR_SRC_ROOT}
${ONNX_MLIR_BIN_ROOT}
${ONNX_MLIR_SRC_ROOT})
target_link_libraries(OMLowerKrnlShape
onnx)
add_dependencies(OMLowerKrnlShape
OMKrnlOps
OMONNXOps)
add_subdirectory(ONNX) add_subdirectory(ONNX)

View File

@ -0,0 +1,86 @@
//===-------- LowerKrnlShape.cpp ------------------------------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This pass enables the lowering of the krnl.shape operation to use Shape
// dialect operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Pass/Passes.hpp"
using namespace mlir;
namespace {
/*!
* RewritePattern that replaces:
* %0 = alloc(%d) : memref<?x10x<type>, #map>
* %1 = krnl.shape(%0) : memref<?x10x<type>> -> !shape.shape
* with:
* %0 = alloc(%d) : memref<?x10x<type>, #map>
* %c0 = constant 0 : index
* %1 = krnl.dim(%0, %c0) : memref<?x10x<type>, #map>, index
* %c1 = constant 1 : index
* %2 = krnl.dim(%0, %c1) : memref<?x10x<type>, #map>, index
* %shape = shape.from_extents %1, %2
*/
class LowerKrnlShape : public OpRewritePattern<KrnlShapeOp> {
public:
using OpRewritePattern<KrnlShapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
KrnlShapeOp krnlShapeOp, PatternRewriter &rewriter) const override {
auto loc = krnlShapeOp.getLoc();
auto rank =
convertToMemRefType(krnlShapeOp.alloc().getType()).getShape().size();
SmallVector<mlir::Value, 4> fromExtentsOpOperands;
for (int idx = 0; idx < rank; idx++) {
auto index = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), idx));
auto operand = rewriter.create<KrnlDimOp>(
loc, rewriter.getIndexType(), krnlShapeOp.alloc(), index);
fromExtentsOpOperands.emplace_back(operand);
}
auto fromExtentsOp = rewriter.create<mlir::shape::FromExtentsOp>(
loc, rewriter.getType<mlir::shape::ShapeType>(), fromExtentsOpOperands);
rewriter.replaceOp(krnlShapeOp, fromExtentsOp.getResult());
return success();
}
};
/*!
* Function pass that emits the shape of a MemRef.
*/
class LowerKrnlShapePass
: public PassWrapper<LowerKrnlShapePass, FunctionPass> {
public:
void runOnFunction() override {
auto function = getFunction();
ConversionTarget target(getContext());
OwningRewritePatternList patterns;
patterns.insert<LowerKrnlShape>(&getContext());
applyPatternsAndFoldGreedily(function, patterns);
}
};
} // namespace
// TODO: integrate with other passes if needed.
std::unique_ptr<Pass> mlir::createLowerKrnlShapePass() {
return std::make_unique<LowerKrnlShapePass>();
}

View File

@ -0,0 +1,26 @@
// RUN: onnx-mlir-opt --lower-krnl-shape %s -split-input-file | FileCheck %s
func @test_krnl_shape_lowering(%arg0: memref<?x?xf32>) -> index {
%c1 = constant 1 : index
%c0 = constant 0 : index
%0 = dim %arg0, %c0 : memref<?x?xf32>
%1 = alloc(%0) : memref<?x10xf32>
%shape = "krnl.shape"(%1) : (memref<?x10xf32>) -> !shape.shape
%d1 = "shape.get_extent"(%shape, %c1) : (!shape.shape, index) -> !shape.size
%ind = "shape.size_to_index"(%d1) : (!shape.size) -> index
%e = addi %ind, %ind : index
return %e : index
// CHECK-LABEL: test_krnl_shape_lowering
// CHECK: [[CONST0:%.+]] = constant 0 : index
// CHECK: [[CONST1:%.+]] = constant 1 : index
// CHECK: [[DIM:%.+]] = dim %arg0, [[CONST0]] : memref<?x?xf32>
// CHECK: [[ALLOC:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
// CHECK: [[DIM0:%.+]] = "krnl.dim"([[ALLOC]], [[CONST0]]) : (memref<?x10xf32>, index) -> index
// CHECK: [[DIM1:%.+]] = "krnl.dim"([[ALLOC]], [[CONST1]]) : (memref<?x10xf32>, index) -> index
// CHECK: [[SHAPE:%.+]] = shape.from_extents [[DIM0]], [[DIM1]]
// CHECK: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], [[CONST1]] : !shape.shape, index -> !shape.size
// CHECK: [[EXTENT_AS_INDEX:%.+]] = shape.size_to_index [[EXTENT]] : !shape.size
// CHECK: [[RES:%.+]] = addi [[EXTENT_AS_INDEX]], [[EXTENT_AS_INDEX]] : index
// CHECK: return [[RES]] : index
}