Improve support for krnl.dim (#317)

* Reorganize main function.

* Follow review comments.

* Emit constants are globals in Krnl and LLVM dialects.

* Make krnl dim more robust.

* Format.

* Update comments.

* Change pass name.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-09-23 14:36:16 -04:00 committed by GitHub
parent 3491b90b1e
commit 4bbe12ff50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 80 additions and 9 deletions

View File

@ -77,7 +77,8 @@ void initOMPasses() {
return mlir::createPackKrnlGlobalConstantsPass();
});
mlir::registerPass("disconnect-dims", "Disconnect dims from allocs.",
mlir::registerPass("lower-krnl-shape-to-std",
"Lowers krnl shape-related operations.",
[]() -> std::unique_ptr<mlir::Pass> {
return mlir::createDisconnectKrnlDimFromAllocPass();
});

View File

@ -27,13 +27,37 @@ namespace {
/*!
* RewritePattern that replaces:
* %0 = alloc(%d) : memref<?x10x<type>, #map>
* %1 = krnl.dim(%0, 0) : memref<?x10x<type>>
* %2 = krnl.dim(%0, 1) : memref<?x10x<type>>
* %1 = krnl.dim(%0, 0) : (memref<?x10x<type>, #map>, index) -> index
* %2 = krnl.dim(%0, 1) : (memref<?x10x<type>, #map>, index) -> index
* %3 = add %1, %2
* with:
* %0 = alloc(%d) : memref<?x10x<type>, #map>
* %2 = constant 10 : index
* %3 = add %d, %2
*
* When the first argument of the krnl.dim is an input argument
* i.e. it is not the output of an alloc operation, we emit either
* the constant or the strandard dim operation depending on whether
* the dimension is static or dynamic.
*
* function(%arg0 : memref<?x10x<type>>) {
* %0 = krnl.dim(%arg0, 0) : (memref<?x10x<type>>, index) -> index
* %1 = krnl.dim(%arg0, 1) : memref<?x10x<type>>
* }
*
*
* becomes:
*
* function(%arg0 : memref<?x10x<type>>) {
* %0 = dim %arg0, 0 : (memref<?x10x<type>>, index) -> index
* %1 = constant 10 : index
* }
*
* The following case is not supported:
*
* function(%arg0 : memref<?x10x<type>, #map>) {
* %0 = krnl.dim(%arg0, 0) : (memref<?x10x<type>, #map>, index) -> index
* }
*/
class DisconnectKrnlDimFromAlloc : public OpRewritePattern<KrnlDimOp> {
@ -53,25 +77,40 @@ public:
// Get the integer value of the index.
int64_t index = indexOp.getAttrOfType<IntegerAttr>("value").getInt();
// Get defining operation for the MemRef argument.
AllocOp allocOp = dyn_cast<AllocOp>(krnlDimOp.alloc().getDefiningOp());
auto memRefShape =
convertToMemRefType(allocOp.getResult().getType()).getShape();
// Get the shape of the MemRef argument.
auto memRefType = convertToMemRefType(krnlDimOp.alloc().getType());
auto memRefShape = memRefType.getShape();
auto rank = memRefShape.size();
assert(index >= 0 && index < rank && "Index must be in bounds");
// Get the defining operation of the first argument of krnl.dim.
// If this operation is not an alloc, and the value comes from the
// list of input arguments, the support is limited to MemRefs without
// maps.
auto firstArgDefOp = krnlDimOp.alloc().getDefiningOp();
Value result;
if (memRefShape[index] > -1) {
// If dimension is static, then we can just emit the constant value.
result = rewriter.create<ConstantOp>(loc,
rewriter.getIntegerAttr(rewriter.getIndexType(), memRefShape[index]));
} else {
} else if (firstArgDefOp && isa<AllocOp>(firstArgDefOp)) {
// Get defining operation for the MemRef argument.
AllocOp allocOp = dyn_cast<AllocOp>(krnlDimOp.alloc().getDefiningOp());
// If dimension is dynamic we need to return the input alloc Value which
// corresponds to it.
int64_t dynDimIdx = getAllocArgIndex(allocOp, index);
assert(dynDimIdx >= 0 && dynDimIdx < allocOp.getOperands().size() &&
"Dynamic index outside range of alloc argument list.");
result = allocOp.getOperands()[dynDimIdx];
} else if (memRefType.getAffineMaps().empty()) {
// Use a standard DimOp since no map is present.
result =
rewriter.create<DimOp>(loc, krnlDimOp.alloc(), krnlDimOp.index());
} else {
llvm_unreachable(
"dynamic sized MemRef with map must be defined by an AllocOp");
}
rewriter.replaceOp(krnlDimOp, result);

View File

@ -1,5 +1,6 @@
// RUN: onnx-mlir-opt --disconnect-dims %s -split-input-file | FileCheck %s
// RUN: onnx-mlir-opt --lower-krnl-shape-to-std %s -split-input-file | FileCheck %s
/// Lower krnl.dim when input MemRef does not have an affine map.
func @test_krnl_dim_lowering(%arg0: memref<?x?xf32>) -> index {
%c1 = constant 1 : index
%c0 = constant 0 : index
@ -21,6 +22,7 @@ func @test_krnl_dim_lowering(%arg0: memref<?x?xf32>) -> index {
// -----
/// Lower krnl.dim when input MemRef has an affine map.
#map = affine_map<(d0, d1) -> (d1, d0)>
func @test_krnl_dim_lowering_with_map(%arg0: memref<?x?xf32>) -> index {
%c1 = constant 1 : index
@ -41,3 +43,32 @@ func @test_krnl_dim_lowering_with_map(%arg0: memref<?x?xf32>) -> index {
// CHECK: [[SUM:%.+]] = addi [[DIM]], [[CONST10]] : index
// CHECK: return [[SUM]] : index
}
// -----
/// Lower krnl.dim to constant when first argument of krnl.dim is an input arg
/// and the dimensions is static.
func @test_krnl_dim_lowering_with_const_arg(%arg0: memref<10x20xf32>) -> index {
%c0 = constant 0 : index
%0 = "krnl.dim"(%arg0, %c0) : (memref<10x20xf32>, index) -> index
return %0 : index
// CHECK-LABEL: test_krnl_dim_lowering_with_const_arg
// CHECK: [[CONST10:%.+]] = constant 10 : index
// CHECK: return [[CONST10]] : index
}
// -----
/// Lower krnl.dim to a standard dim operation when first argument of krnl.dim
/// is an input arg and the dimensions is dynamic.
func @test_krnl_dim_lowering_with_dynamic_arg(%arg0: memref<10x?xf32>) -> index {
%c0 = constant 1 : index
%0 = "krnl.dim"(%arg0, %c0) : (memref<10x?xf32>, index) -> index
return %0 : index
// CHECK-LABEL: test_krnl_dim_lowering_with_dynamic_arg
// CHECK: [[CONST1:%.+]] = constant 1 : index
// CHECK: [[DIM:%.+]] = dim %arg0, [[CONST1]] : memref<10x?xf32>
// CHECK: return [[DIM]] : index
}