Improve support for krnl.dim (#317)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Make krnl dim more robust. * Format. * Update comments. * Change pass name.
This commit is contained in:
parent
3491b90b1e
commit
4bbe12ff50
|
@ -77,7 +77,8 @@ void initOMPasses() {
|
|||
return mlir::createPackKrnlGlobalConstantsPass();
|
||||
});
|
||||
|
||||
mlir::registerPass("disconnect-dims", "Disconnect dims from allocs.",
|
||||
mlir::registerPass("lower-krnl-shape-to-std",
|
||||
"Lowers krnl shape-related operations.",
|
||||
[]() -> std::unique_ptr<mlir::Pass> {
|
||||
return mlir::createDisconnectKrnlDimFromAllocPass();
|
||||
});
|
||||
|
|
|
@ -27,13 +27,37 @@ namespace {
|
|||
/*!
|
||||
* RewritePattern that replaces:
|
||||
* %0 = alloc(%d) : memref<?x10x<type>, #map>
|
||||
* %1 = krnl.dim(%0, 0) : memref<?x10x<type>>
|
||||
* %2 = krnl.dim(%0, 1) : memref<?x10x<type>>
|
||||
* %1 = krnl.dim(%0, 0) : (memref<?x10x<type>, #map>, index) -> index
|
||||
* %2 = krnl.dim(%0, 1) : (memref<?x10x<type>, #map>, index) -> index
|
||||
* %3 = add %1, %2
|
||||
* with:
|
||||
* %0 = alloc(%d) : memref<?x10x<type>, #map>
|
||||
* %2 = constant 10 : index
|
||||
* %3 = add %d, %2
|
||||
*
|
||||
* When the first argument of the krnl.dim is an input argument
|
||||
* i.e. it is not the output of an alloc operation, we emit either
|
||||
* the constant or the strandard dim operation depending on whether
|
||||
* the dimension is static or dynamic.
|
||||
*
|
||||
* function(%arg0 : memref<?x10x<type>>) {
|
||||
* %0 = krnl.dim(%arg0, 0) : (memref<?x10x<type>>, index) -> index
|
||||
* %1 = krnl.dim(%arg0, 1) : memref<?x10x<type>>
|
||||
* }
|
||||
*
|
||||
*
|
||||
* becomes:
|
||||
*
|
||||
* function(%arg0 : memref<?x10x<type>>) {
|
||||
* %0 = dim %arg0, 0 : (memref<?x10x<type>>, index) -> index
|
||||
* %1 = constant 10 : index
|
||||
* }
|
||||
*
|
||||
* The following case is not supported:
|
||||
*
|
||||
* function(%arg0 : memref<?x10x<type>, #map>) {
|
||||
* %0 = krnl.dim(%arg0, 0) : (memref<?x10x<type>, #map>, index) -> index
|
||||
* }
|
||||
*/
|
||||
|
||||
class DisconnectKrnlDimFromAlloc : public OpRewritePattern<KrnlDimOp> {
|
||||
|
@ -53,25 +77,40 @@ public:
|
|||
// Get the integer value of the index.
|
||||
int64_t index = indexOp.getAttrOfType<IntegerAttr>("value").getInt();
|
||||
|
||||
// Get defining operation for the MemRef argument.
|
||||
AllocOp allocOp = dyn_cast<AllocOp>(krnlDimOp.alloc().getDefiningOp());
|
||||
auto memRefShape =
|
||||
convertToMemRefType(allocOp.getResult().getType()).getShape();
|
||||
// Get the shape of the MemRef argument.
|
||||
auto memRefType = convertToMemRefType(krnlDimOp.alloc().getType());
|
||||
auto memRefShape = memRefType.getShape();
|
||||
auto rank = memRefShape.size();
|
||||
assert(index >= 0 && index < rank && "Index must be in bounds");
|
||||
|
||||
// Get the defining operation of the first argument of krnl.dim.
|
||||
// If this operation is not an alloc, and the value comes from the
|
||||
// list of input arguments, the support is limited to MemRefs without
|
||||
// maps.
|
||||
auto firstArgDefOp = krnlDimOp.alloc().getDefiningOp();
|
||||
|
||||
Value result;
|
||||
if (memRefShape[index] > -1) {
|
||||
// If dimension is static, then we can just emit the constant value.
|
||||
result = rewriter.create<ConstantOp>(loc,
|
||||
rewriter.getIntegerAttr(rewriter.getIndexType(), memRefShape[index]));
|
||||
} else {
|
||||
} else if (firstArgDefOp && isa<AllocOp>(firstArgDefOp)) {
|
||||
// Get defining operation for the MemRef argument.
|
||||
AllocOp allocOp = dyn_cast<AllocOp>(krnlDimOp.alloc().getDefiningOp());
|
||||
|
||||
// If dimension is dynamic we need to return the input alloc Value which
|
||||
// corresponds to it.
|
||||
int64_t dynDimIdx = getAllocArgIndex(allocOp, index);
|
||||
assert(dynDimIdx >= 0 && dynDimIdx < allocOp.getOperands().size() &&
|
||||
"Dynamic index outside range of alloc argument list.");
|
||||
result = allocOp.getOperands()[dynDimIdx];
|
||||
} else if (memRefType.getAffineMaps().empty()) {
|
||||
// Use a standard DimOp since no map is present.
|
||||
result =
|
||||
rewriter.create<DimOp>(loc, krnlDimOp.alloc(), krnlDimOp.index());
|
||||
} else {
|
||||
llvm_unreachable(
|
||||
"dynamic sized MemRef with map must be defined by an AllocOp");
|
||||
}
|
||||
|
||||
rewriter.replaceOp(krnlDimOp, result);
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// RUN: onnx-mlir-opt --disconnect-dims %s -split-input-file | FileCheck %s
|
||||
// RUN: onnx-mlir-opt --lower-krnl-shape-to-std %s -split-input-file | FileCheck %s
|
||||
|
||||
/// Lower krnl.dim when input MemRef does not have an affine map.
|
||||
func @test_krnl_dim_lowering(%arg0: memref<?x?xf32>) -> index {
|
||||
%c1 = constant 1 : index
|
||||
%c0 = constant 0 : index
|
||||
|
@ -21,6 +22,7 @@ func @test_krnl_dim_lowering(%arg0: memref<?x?xf32>) -> index {
|
|||
|
||||
// -----
|
||||
|
||||
/// Lower krnl.dim when input MemRef has an affine map.
|
||||
#map = affine_map<(d0, d1) -> (d1, d0)>
|
||||
func @test_krnl_dim_lowering_with_map(%arg0: memref<?x?xf32>) -> index {
|
||||
%c1 = constant 1 : index
|
||||
|
@ -41,3 +43,32 @@ func @test_krnl_dim_lowering_with_map(%arg0: memref<?x?xf32>) -> index {
|
|||
// CHECK: [[SUM:%.+]] = addi [[DIM]], [[CONST10]] : index
|
||||
// CHECK: return [[SUM]] : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Lower krnl.dim to constant when first argument of krnl.dim is an input arg
|
||||
/// and the dimensions is static.
|
||||
func @test_krnl_dim_lowering_with_const_arg(%arg0: memref<10x20xf32>) -> index {
|
||||
%c0 = constant 0 : index
|
||||
%0 = "krnl.dim"(%arg0, %c0) : (memref<10x20xf32>, index) -> index
|
||||
return %0 : index
|
||||
|
||||
// CHECK-LABEL: test_krnl_dim_lowering_with_const_arg
|
||||
// CHECK: [[CONST10:%.+]] = constant 10 : index
|
||||
// CHECK: return [[CONST10]] : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Lower krnl.dim to a standard dim operation when first argument of krnl.dim
|
||||
/// is an input arg and the dimensions is dynamic.
|
||||
func @test_krnl_dim_lowering_with_dynamic_arg(%arg0: memref<10x?xf32>) -> index {
|
||||
%c0 = constant 1 : index
|
||||
%0 = "krnl.dim"(%arg0, %c0) : (memref<10x?xf32>, index) -> index
|
||||
return %0 : index
|
||||
|
||||
// CHECK-LABEL: test_krnl_dim_lowering_with_dynamic_arg
|
||||
// CHECK: [[CONST1:%.+]] = constant 1 : index
|
||||
// CHECK: [[DIM:%.+]] = dim %arg0, [[CONST1]] : memref<10x?xf32>
|
||||
// CHECK: return [[DIM]] : index
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue