Add krnl.dim operation and lowering pass (#261)

* Reorganize main function.

* Follow review comments.

* Emit constants are globals in Krnl and LLVM dialects.

* Add krnl.dim op.

* Add test with alloc with map.

* Code clean-up.

* Code clean-up.

* Add comment for function.

* Update DisconnectKrnlDimFromAlloc.cpp
This commit is contained in:
Gheorghe-Teodor Bercea 2020-08-14 12:59:33 -04:00 committed by GitHub
parent 1b42d0b4eb
commit d3dcee7366
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 220 additions and 1 deletions

View File

@ -23,7 +23,8 @@ set(OMLibs
OMElideKrnlGlobalConstants OMElideKrnlGlobalConstants
OMPackKrnlGlobalConstants OMPackKrnlGlobalConstants
OMEnableMemoryPool OMEnableMemoryPool
OMBundleMemoryPools) OMBundleMemoryPools
OMDisconnectKrnlDimFromAlloc)
set(OMLibs ${OMLibs} PARENT_SCOPE) set(OMLibs ${OMLibs} PARENT_SCOPE)
add_subdirectory(Tool) add_subdirectory(Tool)

View File

@ -233,6 +233,8 @@ void ConvertKrnlToAffinePass::runOnFunction() {
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addIllegalOp<KrnlTerminatorOp>(); target.addIllegalOp<KrnlTerminatorOp>();
// krnl.dim operations must be lowered prior to this pass.
target.addIllegalOp<KrnlDimOp>();
target.addLegalOp<AffineYieldOp>(); target.addLegalOp<AffineYieldOp>();
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
patterns.insert<KrnlTerminatorLowering>(&getContext()); patterns.insert<KrnlTerminatorLowering>(&getContext());

View File

@ -525,3 +525,20 @@ Value getDynamicMemRefSizeInBytes(
return result; return result;
} }
int64_t getAllocArgIndex(AllocOp allocOp, int64_t index) {
auto memRefShape =
convertToMemRefType(allocOp.getResult().getType()).getShape();
auto rank = memRefShape.size();
int dynDimIdx = 0;
for (int idx = 0; idx < rank; ++idx) {
if (memRefShape[idx] < 0) {
if (idx == index)
return dynDimIdx;
dynDimIdx++;
}
}
return -1;
}

View File

@ -252,3 +252,14 @@ int64_t getMemRefSizeInBytes(Value val);
Value getDynamicMemRefSizeInBytes( Value getDynamicMemRefSizeInBytes(
MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp); MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp);
/// This function returns the index in the list of alloc arguments of the
/// dynamic dimension corresponding to `index` in the MemRef shape.
/// As an example:
///
/// alloc(%d0, %d1, %d2) : memref<10x?x?x20x?x30xf32>
///
/// In the above alloc the list of alloc arguments is being represented by
/// %d0, %d1 and %d2. Their indices 0, 1, 2 correspond to `index` values
/// 1, 2 and 4 in the MemRef shape respectively
int64_t getAllocArgIndex(AllocOp allocOp, int64_t index);

View File

@ -334,3 +334,24 @@ def KrnlUnrollOp : Op<Krnl_Dialect, "unroll"> {
$loop attr-dict `:` type($loop) $loop attr-dict `:` type($loop)
}]; }];
} }
def KrnlDimOp : Op<Krnl_Dialect, "dim"> {
let summary = "Krnl dimensions operation.";
let description = [{
Emits the dimension of a MemRef independent of the MemRef alloc:
"krnl.dim"(%memref, %index)
The index identifies the dimension within the shape which is going to be emitted.
Initially the krnl.dim operation depends on the alloc of the MemRef.
Unlike the std.dim operation which maintains a dependency on the alloc of the MemRef, the dimension emitted by krnl.dim will not depend on the alloc operation of the MemRef once the krnl.dim operation is lowered.
Any changes to the original MemRef size after the krnl.dim has been lowered will not be picked up by the emitted dimension. This allows the original MemRef to be safely modified via code transformations or affine map normalization without the risk of changing the value already emitted via krnl.dim.
}];
let arguments = (ins AnyTypeOf<[AnyMemRef]>:$alloc, Index:$index);
let results = (outs Index:$dimension);
let parser = ?;
let printer = ?;
}

View File

@ -76,5 +76,10 @@ void initOMPasses() {
[]() -> std::unique_ptr<mlir::Pass> { []() -> std::unique_ptr<mlir::Pass> {
return mlir::createPackKrnlGlobalConstantsPass(); return mlir::createPackKrnlGlobalConstantsPass();
}); });
mlir::registerPass("disconnect-dims", "Disconnect dims from allocs.",
[]() -> std::unique_ptr<mlir::Pass> {
return mlir::createDisconnectKrnlDimFromAllocPass();
});
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@ -40,6 +40,9 @@ std::unique_ptr<Pass> createLowerToKrnlPass();
/// Pass for lowering frontend dialects to Krnl IR dialect. /// Pass for lowering frontend dialects to Krnl IR dialect.
std::unique_ptr<Pass> createConvertKrnlToAffinePass(); std::unique_ptr<Pass> createConvertKrnlToAffinePass();
/// Pass for lowering krnl.dim operations to standard dialect.
std::unique_ptr<Pass> createDisconnectKrnlDimFromAllocPass();
/// Pass for eliding the values of global Krnl operations. /// Pass for eliding the values of global Krnl operations.
std::unique_ptr<Pass> createElideConstGlobalValuePass(); std::unique_ptr<Pass> createElideConstGlobalValuePass();

View File

@ -47,4 +47,17 @@ add_dependencies(OMBundleMemoryPools
OMKrnlOps OMKrnlOps
OMONNXOps) OMONNXOps)
add_library(OMDisconnectKrnlDimFromAlloc
DisconnectKrnlDimFromAlloc.cpp)
target_include_directories(OMDisconnectKrnlDimFromAlloc
PRIVATE
${ONNX_MLIR_SRC_ROOT}
${ONNX_MLIR_BIN_ROOT}
${ONNX_MLIR_SRC_ROOT})
target_link_libraries(OMDisconnectKrnlDimFromAlloc
onnx)
add_dependencies(OMDisconnectKrnlDimFromAlloc
OMKrnlOps
OMONNXOps)
add_subdirectory(ONNX) add_subdirectory(ONNX)

View File

@ -0,0 +1,103 @@
//===-------- DisconnectKrnlDimFromAlloc.cpp ------------------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This pass enables the lowering of the krnl.dim operation to a series of
// instruction which do not depend on the alloc of the MemRef whose dim is
// being taken. The krnl.dim operation works in the presence of MemRefs
// which contain affine maps by ignoring the map if present.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Pass/Passes.hpp"
using namespace mlir;
namespace {
/*!
* RewritePattern that replaces:
* %0 = alloc(%d) : memref<?x10x<type>, #map>
* %1 = krnl.dim(%0, 0) : memref<?x10x<type>>
* %2 = krnl.dim(%0, 1) : memref<?x10x<type>>
* %3 = add %1, %2
* with:
* %0 = alloc(%d) : memref<?x10x<type>, #map>
* %2 = constant 10 : index
* %3 = add %d, %2
*/
class DisconnectKrnlDimFromAlloc : public OpRewritePattern<KrnlDimOp> {
public:
using OpRewritePattern<KrnlDimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
KrnlDimOp krnlDimOp, PatternRewriter &rewriter) const override {
auto loc = krnlDimOp.getLoc();
// If index is not constant, return failure.
ConstantOp indexOp =
dyn_cast<ConstantOp>(krnlDimOp.index().getDefiningOp());
if (!indexOp)
return failure();
// Get the integer value of the index.
int64_t index = indexOp.getAttrOfType<IntegerAttr>("value").getInt();
// Get defining operation for the MemRef argument.
AllocOp allocOp = dyn_cast<AllocOp>(krnlDimOp.alloc().getDefiningOp());
auto memRefShape =
convertToMemRefType(allocOp.getResult().getType()).getShape();
auto rank = memRefShape.size();
assert(index >= 0 && index < rank && "Index must be in bounds");
Value result;
if (memRefShape[index] > -1) {
// If dimension is static, then we can just emit the constant value.
result = rewriter.create<ConstantOp>(loc,
rewriter.getIntegerAttr(rewriter.getIndexType(), memRefShape[index]));
} else {
// If dimension is dynamic we need to return the input alloc Value which
// corresponds to it.
int64_t dynDimIdx = getAllocArgIndex(allocOp, index);
assert(dynDimIdx >= 0 && dynDimIdx < allocOp.getOperands().size() &&
"Dynamic index outside range of alloc argument list.");
result = allocOp.getOperands()[dynDimIdx];
}
rewriter.replaceOp(krnlDimOp, result);
return success();
}
};
/*!
* Function pass that disconnects krnl.dim emission from its MemRef alloc.
*/
class DisconnectKrnlDimFromAllocPass
: public PassWrapper<DisconnectKrnlDimFromAllocPass, FunctionPass> {
public:
void runOnFunction() override {
auto function = getFunction();
ConversionTarget target(getContext());
OwningRewritePatternList patterns;
patterns.insert<DisconnectKrnlDimFromAlloc>(&getContext());
applyPatternsAndFoldGreedily(function, patterns);
}
};
} // namespace
std::unique_ptr<Pass> mlir::createDisconnectKrnlDimFromAllocPass() {
return std::make_unique<DisconnectKrnlDimFromAllocPass>();
}

View File

@ -0,0 +1,43 @@
// RUN: onnx-mlir-opt --disconnect-dims %s -split-input-file | FileCheck %s
func @test_krnl_dim_lowering(%arg0: memref<?x?xf32>) -> index {
%c1 = constant 1 : index
%c0 = constant 0 : index
%0 = dim %arg0, %c0 : memref<?x?xf32>
%1 = alloc(%0) : memref<?x10xf32>
%d0 = "krnl.dim"(%1, %c0) : (memref<?x10xf32>, index) -> index
%d1 = "krnl.dim"(%1, %c1) : (memref<?x10xf32>, index) -> index
%e = addi %d0, %d1 : index
return %e : index
// CHECK-LABEL: test_krnl_dim_lowering
// CHECK: [[CONST0:%.+]] = constant 0 : index
// CHECK: [[CONST10:%.+]] = constant 10 : index
// CHECK: [[DIM:%.+]] = dim %arg0, [[CONST0]] : memref<?x?xf32>
// CHECK: [[ALLOC:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
// CHECK: [[SUM:%.+]] = addi [[DIM]], [[CONST10]] : index
// CHECK: return [[SUM]] : index
}
// -----
#map = affine_map<(d0, d1) -> (d1, d0)>
func @test_krnl_dim_lowering_with_map(%arg0: memref<?x?xf32>) -> index {
%c1 = constant 1 : index
%c0 = constant 0 : index
%0 = dim %arg0, %c0 : memref<?x?xf32>
%1 = alloc(%0) : memref<?x10xf32, #map>
%d0 = "krnl.dim"(%1, %c0) : (memref<?x10xf32, #map>, index) -> index
%d1 = "krnl.dim"(%1, %c1) : (memref<?x10xf32, #map>, index) -> index
%e = addi %d0, %d1 : index
return %e : index
// CHECK: [[MAP:#.+]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-LABEL: test_krnl_dim_lowering_with_map
// CHECK: [[CONST0:%.+]] = constant 0 : index
// CHECK: [[CONST10:%.+]] = constant 10 : index
// CHECK: [[DIM:%.+]] = dim %arg0, [[CONST0]] : memref<?x?xf32>
// CHECK: [[ALLOC:%.+]] = alloc([[DIM]]) : memref<?x10xf32, [[MAP]]>
// CHECK: [[SUM:%.+]] = addi [[DIM]], [[CONST10]] : index
// CHECK: return [[SUM]] : index
}