Lower SplitOp to Krnl dialect (#155)
* Fix importing variadic output
* Lower SplitOp
* Support unknown dimensions and add lit tests

Co-authored-by: Tian Jin <tjingrant@gmail.com>
parent 4ab96fbc6c
commit 8c4d527eea
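As a quick orientation before the hunks: ONNX Split carves the input into consecutive chunks along `axis`, so output i starts at offset sum(split[:i]) on that axis, which is exactly the running offset the new lowering accumulates. The Python/NumPy sketch below illustrates these semantics only; it is not part of the patch, and the function name, NumPy usage, and equal-parts default are illustrative assumptions.

# Illustrative sketch (not from the patch): reference semantics of ONNX Split.
import numpy as np

def split_reference(x, axis, split=None, num_outputs=2):
    """Split x along `axis` into chunks of sizes `split`; equal parts if `split` is None."""
    if split is None:
        assert x.shape[axis] % num_outputs == 0
        split = [x.shape[axis] // num_outputs] * num_outputs
    outputs, offset = [], 0
    for size in split:
        # Output i reads input[offset : offset + size] along `axis`.
        index = [slice(None)] * x.ndim
        index[axis] = slice(offset, offset + size)
        outputs.append(x[tuple(index)].copy())
        offset += size
    return outputs

# Mirrors the lit test below: axis = 1, split = [2, 30] on a 16x32x64 input.
x = np.random.rand(16, 32, 64).astype(np.float32)
a, b = split_reference(x, axis=1, split=[2, 30])
assert a.shape == (16, 2, 64) and b.shape == (16, 30, 64)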
@@ -20,6 +20,7 @@ add_library(OMONNXToKrnl
         Tensor/Unsqueeze.cpp
         Tensor/Constant.cpp
         Tensor/Concat.cpp
+        Tensor/Split.cpp
         ConvertONNXToKrnl.cpp)
 target_link_libraries(OMONNXToKrnl
         onnx)

@@ -99,6 +99,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   populateLoweringONNXIdentityOpPattern(patterns, &getContext());
   populateLoweringONNXConstantOpPattern(patterns, &getContext());
   populateLoweringONNXConcatOpPattern(patterns, &getContext());
+  populateLoweringONNXSplitOpPattern(patterns, &getContext());
   // Neural network
   populateLoweringONNXConvOpPattern(patterns, &getContext());
   populateLoweringONNXNormalizationOpPattern(patterns, &getContext());

@@ -250,3 +250,6 @@ void populateLoweringONNXConstantOpPattern(

 void populateLoweringONNXConcatOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXSplitOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);

@@ -0,0 +1,106 @@
+//===---------------- Split.cpp - Lowering Split Op ----------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX Split Operator to Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+
+using namespace mlir;
+
+struct ONNXSplitOpLowering : public ConversionPattern {
+  ONNXSplitOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXSplitOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    // Gather info.
+    auto loc = op->getLoc();
+    ONNXSplitOp splitOp = llvm::dyn_cast<ONNXSplitOp>(op);
+    auto axis = splitOp.axis().getSExtValue();
+    auto split = splitOp.split().getValue();
+    SmallVector<int64_t, 4> splitOffset;
+    int64_t offset = 0;
+    for (int i = 0; i < split.size(); ++i) {
+      splitOffset.emplace_back(offset);
+      offset += ArrayAttrIntVal(split, i);
+    }
+    auto rank = splitOp.input().getType().cast<ShapedType>().getRank();
+    auto outputNum = splitOp.getNumResults();
+
+    // Alloc and dealloc.
+    SmallVector<Value, 4> allocs;
+    for (int i = 0; i < outputNum; ++i) {
+      Value alloc;
+      bool insertDealloc = checkInsertDealloc(op, i);
+      auto memRefType = convertToMemRefType(splitOp.outputs()[i].getType());
+
+      if (hasAllConstantDimensions(memRefType))
+        alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+      else {
+        SmallVector<Value, 4> allocOperands;
+        auto shape = memRefType.getShape();
+        for (decltype(rank) r = 0; r < rank; ++r) {
+          if (shape[r] < 0) {
+            Value dim;
+            if (r != axis)
+              dim = rewriter.create<DimOp>(loc, operands[0], r);
+            else
+              dim = emitConstantOp(rewriter, loc, rewriter.getIndexType(),
+                  ArrayAttrIntVal(split, i));
+            allocOperands.push_back(dim);
+          }
+        }
+        alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
+        if (insertDealloc) {
+          auto *parentBlock = alloc.getDefiningOp()->getBlock();
+          auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+          dealloc.getOperation()->moveBefore(&parentBlock->back());
+        }
+      }
+      allocs.emplace_back(alloc);
+    }
+
+    // Create loops, one for each output.
+    for (int i = 0; i < outputNum; ++i) {
+      OpBuilder::InsertionGuard insertGuard(rewriter);
+      // Create loop.
+      BuildKrnlLoop outputLoops(rewriter, loc, rank);
+      outputLoops.createDefineOptimizeAndIterateOp(allocs[i]);
+      outputLoops.createIterateOp();
+      rewriter.setInsertionPointToStart(outputLoops.getIterateBlock());
+      // Indices for the read and write.
+      SmallVector<Value, 4> readIndices;
+      SmallVector<Value, 4> writeIndices;
+      for (int r = 0; r < rank; ++r) {
+        // Same index for read and write if the dimension is:
+        // - not the split axis, or
+        // - the split axis of the first output (whose offset is zero).
+        if (i == 0 || r != axis) {
+          readIndices.emplace_back(outputLoops.getInductionVar(r));
+        } else {
+          auto index = rewriter.getAffineDimExpr(0);
+          auto indexMap = AffineMap::get(1, 0, index + splitOffset[i]);
+          auto indexWithOffset = rewriter.create<AffineApplyOp>(loc, indexMap,
+              ArrayRef<Value>{/*index=*/outputLoops.getInductionVar(r)});
+          readIndices.emplace_back(indexWithOffset);
+        }
+        writeIndices.emplace_back(outputLoops.getInductionVar(r));
+      }
+      // Insert copy.
+      auto loadData = rewriter.create<LoadOp>(loc, operands[0], readIndices);
+      rewriter.create<StoreOp>(loc, loadData, allocs[i], writeIndices);
+    }
+    rewriter.replaceOp(op, allocs);
+    return success();
+  }
+};
+
+void populateLoweringONNXSplitOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXSplitOpLowering>(ctx);
+}

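For readers skimming the C++ above: for each output the pattern allocates a result buffer and emits one rank-deep Krnl loop nest that copies element by element, shifting only the read index on the split axis by that output's offset (the affine map d0 + offset). A rough Python sketch of that generated loop structure follows; it is illustrative only and not part of the patch (the function name and NumPy usage are assumptions).

# Illustrative sketch (not from the patch) of the per-output copy loops the lowering emits.
import numpy as np

def split_by_copy_loops(x, axis, split):
    offsets = np.cumsum([0] + list(split[:-1]))  # same values as splitOffset in Split.cpp
    outputs = []
    for i, size in enumerate(split):
        out_shape = list(x.shape)
        out_shape[axis] = size
        out = np.empty(out_shape, dtype=x.dtype)  # corresponds to the per-output AllocOp
        for idx in np.ndindex(*out_shape):        # one loop nest per output, rank induction vars
            read_idx = list(idx)
            read_idx[axis] += offsets[i]          # affine.apply: d0 + splitOffset[i]
            out[idx] = x[tuple(read_idx)]         # load from input, store to output
        outputs.append(out)
    return outputs

x = np.arange(16 * 32 * 64, dtype=np.float32).reshape(16, 32, 64)
y0, y1 = split_by_copy_loops(x, axis=1, split=[2, 30])
assert np.array_equal(y0, x[:, :2, :]) and np.array_equal(y1, x[:, 2:, :])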
@@ -352,8 +352,16 @@ test_to_enable = [
     "test_lstm_with_initial_bias_cpu",
     "test_lstm_with_peepholes_cpu",
+
+    # Split
+    "test_split_equal_parts_1d_cpu",
+    "test_split_equal_parts_2d_cpu",
+    "test_split_equal_parts_default_axis_cpu",
+    "test_split_variable_parts_1d_cpu",
+    "test_split_variable_parts_2d_cpu",
+    "test_split_variable_parts_default_axis_cpu",
 ]


 # Extract name of all test cases.
 import inspect
 all_tests = inspect.getmembers(

@@ -2134,3 +2134,100 @@ func @test_lstm_bidirectional_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x
   // CHECK: %[[REVERSE_SEQUENCE_IV:.+]] = affine.apply [[REVERSE_IV_MAP]](%arg3)[%[[SEQUENCE_LEN]]{{]}}
   // CHECK: [[Xt_LOAD:%.+]] = load %arg0[%[[REVERSE_SEQUENCE_IV]], {{.*}}, {{.*}}] : memref<4x3x2xf32>
 }
+
+// -----
+
+func @test_split_equal(%arg0 : tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
+  %0, %1 = "onnx.Split"(%arg0) { axis = 0} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+  "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
+
+  // CHECK: [[INDEX_MAP:#.+]] = affine_map<(d0) -> (d0 + 8)>
+  // CHECK-LABEL: @test_split_equal
+
+  // CHECK: [[RES_1:%.+]] = alloc() : memref<8x32x64xf32>
+  // CHECK: [[RES_0:%.+]] = alloc() : memref<8x32x64xf32>
+  // CHECK: [[DEF_LOOP_0:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOP_0:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOP_0]]#0, [[DEF_LOOP_0]]#1, [[DEF_LOOP_0]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOP_0]]#0, [[OPT_LOOP_0]]#1, [[OPT_LOOP_0]]#2) with ([[DEF_LOOP_0]]#0 -> %arg1 = 0 to 8, [[DEF_LOOP_0]]#1 -> %arg2 = 0 to 32, [[DEF_LOOP_0]]#2 -> %arg3 = 0 to 64) {
+  // CHECK: [[LOAD_0:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<16x32x64xf32>
+  // CHECK: store [[LOAD_0]], [[RES_0]][%arg1, %arg2, %arg3] : memref<8x32x64xf32>
+  // CHECK: }
+  // CHECK: [[DEF_LOOP_1:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOP_1:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOP_1]]#0, [[DEF_LOOP_1]]#1, [[DEF_LOOP_1]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOP_1]]#0, [[OPT_LOOP_1]]#1, [[OPT_LOOP_1]]#2) with ([[DEF_LOOP_1]]#0 -> %arg1 = 0 to 8, [[DEF_LOOP_1]]#1 -> %arg2 = 0 to 32, [[DEF_LOOP_1]]#2 -> %arg3 = 0 to 64) {
+  // CHECK: %[[INDEX:.+]] = affine.apply [[INDEX_MAP]](%arg1)
+  // CHECK: [[LOAD_1:%.+]] = load %arg0[%[[INDEX]], %arg2, %arg3] : memref<16x32x64xf32>
+  // CHECK: store [[LOAD_1]], [[RES_1]][%arg1, %arg2, %arg3] : memref<8x32x64xf32>
+  // CHECK: }
+  // CHECK: return [[RES_0]], [[RES_1]] : memref<8x32x64xf32>, memref<8x32x64xf32>
+}
+
+// -----
+
+func @test_split_variable(%arg0 : tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
+  %0, %1 = "onnx.Split"(%arg0) { axis = 1, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+  "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
+
+  // CHECK: [[INDEX_MAP:#.+]] = affine_map<(d0) -> (d0 + 2)>
+  // CHECK-LABEL: @test_split_variable
+
+  // CHECK: [[RES_1:%.+]] = alloc() : memref<16x30x64xf32>
+  // CHECK: [[RES_0:%.+]] = alloc() : memref<16x2x64xf32>
+  // CHECK: [[DEF_LOOP_0:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOP_0:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOP_0]]#0, [[DEF_LOOP_0]]#1, [[DEF_LOOP_0]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOP_0]]#0, [[OPT_LOOP_0]]#1, [[OPT_LOOP_0]]#2) with ([[DEF_LOOP_0]]#0 -> %arg1 = 0 to 16, [[DEF_LOOP_0]]#1 -> %arg2 = 0 to 2, [[DEF_LOOP_0]]#2 -> %arg3 = 0 to 64) {
+  // CHECK: [[LOAD_0:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<16x32x64xf32>
+  // CHECK: store [[LOAD_0]], [[RES_0]][%arg1, %arg2, %arg3] : memref<16x2x64xf32>
+  // CHECK: }
+  // CHECK: [[DEF_LOOP_1:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOP_1:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOP_1]]#0, [[DEF_LOOP_1]]#1, [[DEF_LOOP_1]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOP_1]]#0, [[OPT_LOOP_1]]#1, [[OPT_LOOP_1]]#2) with ([[DEF_LOOP_1]]#0 -> %arg1 = 0 to 16, [[DEF_LOOP_1]]#1 -> %arg2 = 0 to 30, [[DEF_LOOP_1]]#2 -> %arg3 = 0 to 64) {
+  // CHECK: %[[INDEX:.+]] = affine.apply [[INDEX_MAP]](%arg2)
+  // CHECK: [[LOAD_1:%.+]] = load %arg0[%arg1, %[[INDEX]], %arg3] : memref<16x32x64xf32>
+  // CHECK: store [[LOAD_1]], [[RES_1]][%arg1, %arg2, %arg3] : memref<16x30x64xf32>
+  // CHECK: }
+  // CHECK: return [[RES_0]], [[RES_1]] : memref<16x2x64xf32>, memref<16x30x64xf32>
+}
+
+// -----
+
+func @test_split_unknown_dimension(%arg0 : tensor<?x?x64xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
+  %0, %1 = "onnx.Split"(%arg0) { axis = 1, split = [2, 30]} : (tensor<?x?x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+  "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
+
+  // CHECK: [[INDEX_MAP:#.+]] = affine_map<(d0) -> (d0 + 2)>
+  // CHECK-LABEL: @test_split_unknown_dimension
+
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x?x64xf32>
+  // CHECK: [[RES_0:%.+]] = alloc([[DIM_0]]) : memref<?x2x64xf32>
+  // CHECK: [[DIM_1:%.+]] = dim %arg0, 0 : memref<?x?x64xf32>
+  // CHECK: [[RES_1:%.+]] = alloc([[DIM_1]]) : memref<?x30x64xf32>
+  // CHECK: [[DEF_LOOP_0:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOP_0:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOP_0]]#0, [[DEF_LOOP_0]]#1, [[DEF_LOOP_0]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_0:%.+]] = dim [[RES_0]], 0 : memref<?x2x64xf32>
+  // CHECK: krnl.iterate([[OPT_LOOP_0]]#0, [[OPT_LOOP_0]]#1, [[OPT_LOOP_0]]#2) with ([[DEF_LOOP_0]]#0 -> %arg1 = 0 to [[DIM_0]], [[DEF_LOOP_0]]#1 -> %arg2 = 0 to 2, [[DEF_LOOP_0]]#2 -> %arg3 = 0 to 64) {
+  // CHECK: [[LOAD_0:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<?x?x64xf32>
+  // CHECK: store [[LOAD_0]], [[RES_0]][%arg1, %arg2, %arg3] : memref<?x2x64xf32>
+  // CHECK: }
+  // CHECK: [[DEF_LOOP_1:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOP_1:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOP_1]]#0, [[DEF_LOOP_1]]#1, [[DEF_LOOP_1]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_1:%.+]] = dim [[RES_1]], 0 : memref<?x30x64xf32>
+  // CHECK: krnl.iterate([[OPT_LOOP_1]]#0, [[OPT_LOOP_1]]#1, [[OPT_LOOP_1]]#2) with ([[DEF_LOOP_1]]#0 -> %arg1 = 0 to [[DIM_1]], [[DEF_LOOP_1]]#1 -> %arg2 = 0 to 30, [[DEF_LOOP_1]]#2 -> %arg3 = 0 to 64) {
+  // CHECK: %[[INDEX:.+]] = affine.apply [[INDEX_MAP]](%arg2)
+  // CHECK: [[LOAD_1:%.+]] = load %arg0[%arg1, %[[INDEX]], %arg3] : memref<?x?x64xf32>
+  // CHECK: store [[LOAD_1]], [[RES_1]][%arg1, %arg2, %arg3] : memref<?x30x64xf32>
+  // CHECK: }
+  // CHECK: return [[RES_0]], [[RES_1]] : memref<?x2x64xf32>, memref<?x30x64xf32>
+}