Improve support for krnl.dim (#317)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Make krnl dim more robust. * Format. * Update comments. * Change pass name.
This commit is contained in:
parent
3491b90b1e
commit
4bbe12ff50
|
@ -77,7 +77,8 @@ void initOMPasses() {
|
||||||
return mlir::createPackKrnlGlobalConstantsPass();
|
return mlir::createPackKrnlGlobalConstantsPass();
|
||||||
});
|
});
|
||||||
|
|
||||||
mlir::registerPass("disconnect-dims", "Disconnect dims from allocs.",
|
mlir::registerPass("lower-krnl-shape-to-std",
|
||||||
|
"Lowers krnl shape-related operations.",
|
||||||
[]() -> std::unique_ptr<mlir::Pass> {
|
[]() -> std::unique_ptr<mlir::Pass> {
|
||||||
return mlir::createDisconnectKrnlDimFromAllocPass();
|
return mlir::createDisconnectKrnlDimFromAllocPass();
|
||||||
});
|
});
|
||||||
|
|
|
@ -27,13 +27,37 @@ namespace {
|
||||||
/*!
|
/*!
|
||||||
* RewritePattern that replaces:
|
* RewritePattern that replaces:
|
||||||
* %0 = alloc(%d) : memref<?x10x<type>, #map>
|
* %0 = alloc(%d) : memref<?x10x<type>, #map>
|
||||||
* %1 = krnl.dim(%0, 0) : memref<?x10x<type>>
|
* %1 = krnl.dim(%0, 0) : (memref<?x10x<type>, #map>, index) -> index
|
||||||
* %2 = krnl.dim(%0, 1) : memref<?x10x<type>>
|
* %2 = krnl.dim(%0, 1) : (memref<?x10x<type>, #map>, index) -> index
|
||||||
* %3 = add %1, %2
|
* %3 = add %1, %2
|
||||||
* with:
|
* with:
|
||||||
* %0 = alloc(%d) : memref<?x10x<type>, #map>
|
* %0 = alloc(%d) : memref<?x10x<type>, #map>
|
||||||
* %2 = constant 10 : index
|
* %2 = constant 10 : index
|
||||||
* %3 = add %d, %2
|
* %3 = add %d, %2
|
||||||
|
*
|
||||||
|
* When the first argument of the krnl.dim is an input argument
|
||||||
|
* i.e. it is not the output of an alloc operation, we emit either
|
||||||
|
* the constant or the strandard dim operation depending on whether
|
||||||
|
* the dimension is static or dynamic.
|
||||||
|
*
|
||||||
|
* function(%arg0 : memref<?x10x<type>>) {
|
||||||
|
* %0 = krnl.dim(%arg0, 0) : (memref<?x10x<type>>, index) -> index
|
||||||
|
* %1 = krnl.dim(%arg0, 1) : memref<?x10x<type>>
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* becomes:
|
||||||
|
*
|
||||||
|
* function(%arg0 : memref<?x10x<type>>) {
|
||||||
|
* %0 = dim %arg0, 0 : (memref<?x10x<type>>, index) -> index
|
||||||
|
* %1 = constant 10 : index
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* The following case is not supported:
|
||||||
|
*
|
||||||
|
* function(%arg0 : memref<?x10x<type>, #map>) {
|
||||||
|
* %0 = krnl.dim(%arg0, 0) : (memref<?x10x<type>, #map>, index) -> index
|
||||||
|
* }
|
||||||
*/
|
*/
|
||||||
|
|
||||||
class DisconnectKrnlDimFromAlloc : public OpRewritePattern<KrnlDimOp> {
|
class DisconnectKrnlDimFromAlloc : public OpRewritePattern<KrnlDimOp> {
|
||||||
|
@ -53,25 +77,40 @@ public:
|
||||||
// Get the integer value of the index.
|
// Get the integer value of the index.
|
||||||
int64_t index = indexOp.getAttrOfType<IntegerAttr>("value").getInt();
|
int64_t index = indexOp.getAttrOfType<IntegerAttr>("value").getInt();
|
||||||
|
|
||||||
// Get defining operation for the MemRef argument.
|
// Get the shape of the MemRef argument.
|
||||||
AllocOp allocOp = dyn_cast<AllocOp>(krnlDimOp.alloc().getDefiningOp());
|
auto memRefType = convertToMemRefType(krnlDimOp.alloc().getType());
|
||||||
auto memRefShape =
|
auto memRefShape = memRefType.getShape();
|
||||||
convertToMemRefType(allocOp.getResult().getType()).getShape();
|
|
||||||
auto rank = memRefShape.size();
|
auto rank = memRefShape.size();
|
||||||
assert(index >= 0 && index < rank && "Index must be in bounds");
|
assert(index >= 0 && index < rank && "Index must be in bounds");
|
||||||
|
|
||||||
|
// Get the defining operation of the first argument of krnl.dim.
|
||||||
|
// If this operation is not an alloc, and the value comes from the
|
||||||
|
// list of input arguments, the support is limited to MemRefs without
|
||||||
|
// maps.
|
||||||
|
auto firstArgDefOp = krnlDimOp.alloc().getDefiningOp();
|
||||||
|
|
||||||
Value result;
|
Value result;
|
||||||
if (memRefShape[index] > -1) {
|
if (memRefShape[index] > -1) {
|
||||||
// If dimension is static, then we can just emit the constant value.
|
// If dimension is static, then we can just emit the constant value.
|
||||||
result = rewriter.create<ConstantOp>(loc,
|
result = rewriter.create<ConstantOp>(loc,
|
||||||
rewriter.getIntegerAttr(rewriter.getIndexType(), memRefShape[index]));
|
rewriter.getIntegerAttr(rewriter.getIndexType(), memRefShape[index]));
|
||||||
} else {
|
} else if (firstArgDefOp && isa<AllocOp>(firstArgDefOp)) {
|
||||||
|
// Get defining operation for the MemRef argument.
|
||||||
|
AllocOp allocOp = dyn_cast<AllocOp>(krnlDimOp.alloc().getDefiningOp());
|
||||||
|
|
||||||
// If dimension is dynamic we need to return the input alloc Value which
|
// If dimension is dynamic we need to return the input alloc Value which
|
||||||
// corresponds to it.
|
// corresponds to it.
|
||||||
int64_t dynDimIdx = getAllocArgIndex(allocOp, index);
|
int64_t dynDimIdx = getAllocArgIndex(allocOp, index);
|
||||||
assert(dynDimIdx >= 0 && dynDimIdx < allocOp.getOperands().size() &&
|
assert(dynDimIdx >= 0 && dynDimIdx < allocOp.getOperands().size() &&
|
||||||
"Dynamic index outside range of alloc argument list.");
|
"Dynamic index outside range of alloc argument list.");
|
||||||
result = allocOp.getOperands()[dynDimIdx];
|
result = allocOp.getOperands()[dynDimIdx];
|
||||||
|
} else if (memRefType.getAffineMaps().empty()) {
|
||||||
|
// Use a standard DimOp since no map is present.
|
||||||
|
result =
|
||||||
|
rewriter.create<DimOp>(loc, krnlDimOp.alloc(), krnlDimOp.index());
|
||||||
|
} else {
|
||||||
|
llvm_unreachable(
|
||||||
|
"dynamic sized MemRef with map must be defined by an AllocOp");
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(krnlDimOp, result);
|
rewriter.replaceOp(krnlDimOp, result);
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
// RUN: onnx-mlir-opt --disconnect-dims %s -split-input-file | FileCheck %s
|
// RUN: onnx-mlir-opt --lower-krnl-shape-to-std %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
/// Lower krnl.dim when input MemRef does not have an affine map.
|
||||||
func @test_krnl_dim_lowering(%arg0: memref<?x?xf32>) -> index {
|
func @test_krnl_dim_lowering(%arg0: memref<?x?xf32>) -> index {
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
|
@ -21,6 +22,7 @@ func @test_krnl_dim_lowering(%arg0: memref<?x?xf32>) -> index {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
/// Lower krnl.dim when input MemRef has an affine map.
|
||||||
#map = affine_map<(d0, d1) -> (d1, d0)>
|
#map = affine_map<(d0, d1) -> (d1, d0)>
|
||||||
func @test_krnl_dim_lowering_with_map(%arg0: memref<?x?xf32>) -> index {
|
func @test_krnl_dim_lowering_with_map(%arg0: memref<?x?xf32>) -> index {
|
||||||
%c1 = constant 1 : index
|
%c1 = constant 1 : index
|
||||||
|
@ -41,3 +43,32 @@ func @test_krnl_dim_lowering_with_map(%arg0: memref<?x?xf32>) -> index {
|
||||||
// CHECK: [[SUM:%.+]] = addi [[DIM]], [[CONST10]] : index
|
// CHECK: [[SUM:%.+]] = addi [[DIM]], [[CONST10]] : index
|
||||||
// CHECK: return [[SUM]] : index
|
// CHECK: return [[SUM]] : index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
/// Lower krnl.dim to constant when first argument of krnl.dim is an input arg
|
||||||
|
/// and the dimensions is static.
|
||||||
|
func @test_krnl_dim_lowering_with_const_arg(%arg0: memref<10x20xf32>) -> index {
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%0 = "krnl.dim"(%arg0, %c0) : (memref<10x20xf32>, index) -> index
|
||||||
|
return %0 : index
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_krnl_dim_lowering_with_const_arg
|
||||||
|
// CHECK: [[CONST10:%.+]] = constant 10 : index
|
||||||
|
// CHECK: return [[CONST10]] : index
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
/// Lower krnl.dim to a standard dim operation when first argument of krnl.dim
|
||||||
|
/// is an input arg and the dimensions is dynamic.
|
||||||
|
func @test_krnl_dim_lowering_with_dynamic_arg(%arg0: memref<10x?xf32>) -> index {
|
||||||
|
%c0 = constant 1 : index
|
||||||
|
%0 = "krnl.dim"(%arg0, %c0) : (memref<10x?xf32>, index) -> index
|
||||||
|
return %0 : index
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_krnl_dim_lowering_with_dynamic_arg
|
||||||
|
// CHECK: [[CONST1:%.+]] = constant 1 : index
|
||||||
|
// CHECK: [[DIM:%.+]] = dim %arg0, [[CONST1]] : memref<10x?xf32>
|
||||||
|
// CHECK: return [[DIM]] : index
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue