From 4bbe12ff50f94500feb6e18e2822263a7fdf5b26 Mon Sep 17 00:00:00 2001 From: Gheorghe-Teodor Bercea Date: Wed, 23 Sep 2020 14:36:16 -0400 Subject: [PATCH] Improve support for krnl.dim (#317) * Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Make krnl dim more robust. * Format. * Update comments. * Change pass name. --- src/InitOMPasses.hpp | 3 +- src/Transform/DisconnectKrnlDimFromAlloc.cpp | 53 ++++++++++++++++--- .../krnl/krnl_disconnect_dim_from_alloc.mlir | 33 +++++++++++- 3 files changed, 80 insertions(+), 9 deletions(-) diff --git a/src/InitOMPasses.hpp b/src/InitOMPasses.hpp index 92500d6..204c36e 100644 --- a/src/InitOMPasses.hpp +++ b/src/InitOMPasses.hpp @@ -77,7 +77,8 @@ void initOMPasses() { return mlir::createPackKrnlGlobalConstantsPass(); }); - mlir::registerPass("disconnect-dims", "Disconnect dims from allocs.", + mlir::registerPass("lower-krnl-shape-to-std", + "Lowers krnl shape-related operations.", []() -> std::unique_ptr { return mlir::createDisconnectKrnlDimFromAllocPass(); }); diff --git a/src/Transform/DisconnectKrnlDimFromAlloc.cpp b/src/Transform/DisconnectKrnlDimFromAlloc.cpp index 58693ba..7c0df87 100644 --- a/src/Transform/DisconnectKrnlDimFromAlloc.cpp +++ b/src/Transform/DisconnectKrnlDimFromAlloc.cpp @@ -27,13 +27,37 @@ namespace { /*! * RewritePattern that replaces: * %0 = alloc(%d) : memref, #map> - * %1 = krnl.dim(%0, 0) : memref> - * %2 = krnl.dim(%0, 1) : memref> + * %1 = krnl.dim(%0, 0) : (memref, #map>, index) -> index + * %2 = krnl.dim(%0, 1) : (memref, #map>, index) -> index * %3 = add %1, %2 * with: * %0 = alloc(%d) : memref, #map> * %2 = constant 10 : index * %3 = add %d, %2 + * + * When the first argument of the krnl.dim is an input argument + * i.e. it is not the output of an alloc operation, we emit either + * the constant or the strandard dim operation depending on whether + * the dimension is static or dynamic. + * + * function(%arg0 : memref>) { + * %0 = krnl.dim(%arg0, 0) : (memref>, index) -> index + * %1 = krnl.dim(%arg0, 1) : memref> + * } + * + * + * becomes: + * + * function(%arg0 : memref>) { + * %0 = dim %arg0, 0 : (memref>, index) -> index + * %1 = constant 10 : index + * } + * + * The following case is not supported: + * + * function(%arg0 : memref, #map>) { + * %0 = krnl.dim(%arg0, 0) : (memref, #map>, index) -> index + * } */ class DisconnectKrnlDimFromAlloc : public OpRewritePattern { @@ -53,25 +77,40 @@ public: // Get the integer value of the index. int64_t index = indexOp.getAttrOfType("value").getInt(); - // Get defining operation for the MemRef argument. - AllocOp allocOp = dyn_cast(krnlDimOp.alloc().getDefiningOp()); - auto memRefShape = - convertToMemRefType(allocOp.getResult().getType()).getShape(); + // Get the shape of the MemRef argument. + auto memRefType = convertToMemRefType(krnlDimOp.alloc().getType()); + auto memRefShape = memRefType.getShape(); auto rank = memRefShape.size(); assert(index >= 0 && index < rank && "Index must be in bounds"); + // Get the defining operation of the first argument of krnl.dim. + // If this operation is not an alloc, and the value comes from the + // list of input arguments, the support is limited to MemRefs without + // maps. + auto firstArgDefOp = krnlDimOp.alloc().getDefiningOp(); + Value result; if (memRefShape[index] > -1) { // If dimension is static, then we can just emit the constant value. result = rewriter.create(loc, rewriter.getIntegerAttr(rewriter.getIndexType(), memRefShape[index])); - } else { + } else if (firstArgDefOp && isa(firstArgDefOp)) { + // Get defining operation for the MemRef argument. + AllocOp allocOp = dyn_cast(krnlDimOp.alloc().getDefiningOp()); + // If dimension is dynamic we need to return the input alloc Value which // corresponds to it. int64_t dynDimIdx = getAllocArgIndex(allocOp, index); assert(dynDimIdx >= 0 && dynDimIdx < allocOp.getOperands().size() && "Dynamic index outside range of alloc argument list."); result = allocOp.getOperands()[dynDimIdx]; + } else if (memRefType.getAffineMaps().empty()) { + // Use a standard DimOp since no map is present. + result = + rewriter.create(loc, krnlDimOp.alloc(), krnlDimOp.index()); + } else { + llvm_unreachable( + "dynamic sized MemRef with map must be defined by an AllocOp"); } rewriter.replaceOp(krnlDimOp, result); diff --git a/test/mlir/krnl/krnl_disconnect_dim_from_alloc.mlir b/test/mlir/krnl/krnl_disconnect_dim_from_alloc.mlir index 1d95db3..052c944 100644 --- a/test/mlir/krnl/krnl_disconnect_dim_from_alloc.mlir +++ b/test/mlir/krnl/krnl_disconnect_dim_from_alloc.mlir @@ -1,5 +1,6 @@ -// RUN: onnx-mlir-opt --disconnect-dims %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --lower-krnl-shape-to-std %s -split-input-file | FileCheck %s +/// Lower krnl.dim when input MemRef does not have an affine map. func @test_krnl_dim_lowering(%arg0: memref) -> index { %c1 = constant 1 : index %c0 = constant 0 : index @@ -21,6 +22,7 @@ func @test_krnl_dim_lowering(%arg0: memref) -> index { // ----- +/// Lower krnl.dim when input MemRef has an affine map. #map = affine_map<(d0, d1) -> (d1, d0)> func @test_krnl_dim_lowering_with_map(%arg0: memref) -> index { %c1 = constant 1 : index @@ -41,3 +43,32 @@ func @test_krnl_dim_lowering_with_map(%arg0: memref) -> index { // CHECK: [[SUM:%.+]] = addi [[DIM]], [[CONST10]] : index // CHECK: return [[SUM]] : index } + +// ----- + +/// Lower krnl.dim to constant when first argument of krnl.dim is an input arg +/// and the dimensions is static. +func @test_krnl_dim_lowering_with_const_arg(%arg0: memref<10x20xf32>) -> index { + %c0 = constant 0 : index + %0 = "krnl.dim"(%arg0, %c0) : (memref<10x20xf32>, index) -> index + return %0 : index + + // CHECK-LABEL: test_krnl_dim_lowering_with_const_arg + // CHECK: [[CONST10:%.+]] = constant 10 : index + // CHECK: return [[CONST10]] : index +} + +// ----- + +/// Lower krnl.dim to a standard dim operation when first argument of krnl.dim +/// is an input arg and the dimensions is dynamic. +func @test_krnl_dim_lowering_with_dynamic_arg(%arg0: memref<10x?xf32>) -> index { + %c0 = constant 1 : index + %0 = "krnl.dim"(%arg0, %c0) : (memref<10x?xf32>, index) -> index + return %0 : index + + // CHECK-LABEL: test_krnl_dim_lowering_with_dynamic_arg + // CHECK: [[CONST1:%.+]] = constant 1 : index + // CHECK: [[DIM:%.+]] = dim %arg0, [[CONST1]] : memref<10x?xf32> + // CHECK: return [[DIM]] : index +}