diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt
index 27d6580..7bdc1fb 100644
--- a/src/Conversion/ONNXToKrnl/CMakeLists.txt
+++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -20,6 +20,7 @@ add_library(OMONNXToKrnl
         Tensor/Unsqueeze.cpp
         Tensor/Constant.cpp
         Tensor/Concat.cpp
+        Tensor/Split.cpp
         ConvertONNXToKrnl.cpp)
 
 target_link_libraries(OMONNXToKrnl onnx)
diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
index c3de275..fac281a 100644
--- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
+++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -99,6 +99,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   populateLoweringONNXIdentityOpPattern(patterns, &getContext());
   populateLoweringONNXConstantOpPattern(patterns, &getContext());
   populateLoweringONNXConcatOpPattern(patterns, &getContext());
+  populateLoweringONNXSplitOpPattern(patterns, &getContext());
   // Neural network
   populateLoweringONNXConvOpPattern(patterns, &getContext());
   populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
index 69b7f17..2d12677 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -250,3 +250,6 @@ void populateLoweringONNXConstantOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 void populateLoweringONNXConcatOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXSplitOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Split.cpp b/src/Conversion/ONNXToKrnl/Tensor/Split.cpp
new file mode 100644
index 0000000..8b5e31c
--- /dev/null
+++ b/src/Conversion/ONNXToKrnl/Tensor/Split.cpp
@@ -0,0 +1,106 @@
+//===---------------- Split.cpp - Lowering Split Op ----------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX Split Operator to the Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+
+using namespace mlir;
+
+struct ONNXSplitOpLowering : public ConversionPattern {
+  ONNXSplitOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXSplitOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    // Gather info.
+    auto loc = op->getLoc();
+    ONNXSplitOp splitOp = llvm::dyn_cast<ONNXSplitOp>(op);
+    auto axis = splitOp.axis().getSExtValue();
+    auto split = splitOp.split().getValue();
+    // Offset of each output along the split axis: the sum of the sizes of
+    // the preceding outputs.
+    SmallVector<int64_t, 4> splitOffset;
+    int64_t offset = 0;
+    for (int i = 0; i < split.size(); ++i) {
+      splitOffset.emplace_back(offset);
+      offset += ArrayAttrIntVal(split, i);
+    }
+    auto rank = splitOp.input().getType().cast<ShapedType>().getRank();
+    auto outputNum = splitOp.getNumResults();
+
+    // Alloc and dealloc.
+    SmallVector<Value, 4> allocs;
+    for (int i = 0; i < outputNum; ++i) {
+      Value alloc;
+      bool insertDealloc = checkInsertDealloc(op, i);
+      auto memRefType = convertToMemRefType(splitOp.outputs()[i].getType());
+
+      if (hasAllConstantDimensions(memRefType))
+        alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+      else {
+        // For unknown dimensions, query the input for every axis except the
+        // split axis, whose size is the static split value for this output.
+        SmallVector<Value, 4> allocOperands;
+        auto shape = memRefType.getShape();
+        for (decltype(rank) r = 0; r < rank; ++r) {
+          if (shape[r] < 0) {
+            Value dim;
+            if (r != axis)
+              dim = rewriter.create<DimOp>(loc, operands[0], r);
+            else
+              dim = emitConstantOp(rewriter, loc, rewriter.getIndexType(),
+                  ArrayAttrIntVal(split, i));
+            allocOperands.push_back(dim);
+          }
+        }
+        alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
+        if (insertDealloc) {
+          auto *parentBlock = alloc.getDefiningOp()->getBlock();
+          auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+          dealloc.getOperation()->moveBefore(&parentBlock->back());
+        }
+      }
+      allocs.emplace_back(alloc);
+    }
+
+    // Create one loop nest per output and copy the corresponding slice of
+    // the input into it.
+    for (int i = 0; i < outputNum; ++i) {
+      OpBuilder::InsertionGuard insertGuard(rewriter);
+      // Create loop.
+      BuildKrnlLoop outputLoops(rewriter, loc, rank);
+      outputLoops.createDefineOptimizeAndIterateOp(allocs[i]);
+      rewriter.setInsertionPointToStart(outputLoops.getIterateBlock());
+      // Indices for the read and write.
+      SmallVector<Value, 4> readIndices;
+      SmallVector<Value, 4> writeIndices;
+      for (int r = 0; r < rank; ++r) {
+        // The read index equals the write index if:
+        //  - this is the first output (its offset is zero), or
+        //  - the dimension is not the split axis.
+        if (i == 0 || r != axis) {
+          readIndices.emplace_back(outputLoops.getInductionVar(r));
+        } else {
+          // Otherwise, shift the read index along the split axis by this
+          // output's offset.
+          auto index = rewriter.getAffineDimExpr(0);
+          auto indexMap = AffineMap::get(1, 0, index + splitOffset[i]);
+          auto indexWithOffset = rewriter.create<AffineApplyOp>(loc, indexMap,
+              ArrayRef<Value>{/*index=*/outputLoops.getInductionVar(r)});
+          readIndices.emplace_back(indexWithOffset);
+        }
+        writeIndices.emplace_back(outputLoops.getInductionVar(r));
+      }
+      // Insert copy.
+      auto loadData = rewriter.create<LoadOp>(loc, operands[0], readIndices);
+      rewriter.create<StoreOp>(loc, loadData, allocs[i], writeIndices);
+    }
+    rewriter.replaceOp(op, allocs);
+    return success();
+  }
+};
+
+void populateLoweringONNXSplitOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXSplitOpLowering>(ctx);
+}
diff --git a/test/backend/test.py b/test/backend/test.py
index 62f6a7d..5792967 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -352,8 +352,16 @@ test_to_enable = [
     "test_lstm_with_initial_bias_cpu",
     "test_lstm_with_peepholes_cpu",
 
+    # Split
+    "test_split_equal_parts_1d_cpu",
+    "test_split_equal_parts_2d_cpu",
+    "test_split_equal_parts_default_axis_cpu",
+    "test_split_variable_parts_1d_cpu",
+    "test_split_variable_parts_2d_cpu",
+    "test_split_variable_parts_default_axis_cpu",
 ]
+
 # Extract name of all test cases.
 import inspect
 all_tests = inspect.getmembers(
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 6832699..3e3ed4a 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -2134,3 +2134,100 @@ func @test_lstm_bidirectional_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x
   // CHECK: %[[REVERSE_SEQUENCE_IV:.+]] = affine.apply [[REVERSE_IV_MAP]](%arg3)[%[[SEQUENCE_LEN]]{{]}}
   // CHECK: [[Xt_LOAD:%.+]] = load %arg0[%[[REVERSE_SEQUENCE_IV]], {{.*}}, {{.*}}] : memref<4x3x2xf32>
 }
+
+// -----
+
+func @test_split_equal(%arg0 : tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
+  %0, %1 = "onnx.Split"(%arg0) { axis = 0} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+  "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
+
+  // CHECK: [[INDEX_MAP:#.+]] = affine_map<(d0) -> (d0 + 8)>
+  // CHECK-LABEL: @test_split_equal
+
+  // CHECK: [[RES_1:%.+]] = alloc() : memref<8x32x64xf32>
+  // CHECK: [[RES_0:%.+]] = alloc() : memref<8x32x64xf32>
+  // CHECK: [[DEF_LOOP_0:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOP_0:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOP_0]]#0, [[DEF_LOOP_0]]#1, [[DEF_LOOP_0]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOP_0]]#0, [[OPT_LOOP_0]]#1, [[OPT_LOOP_0]]#2) with ([[DEF_LOOP_0]]#0 -> %arg1 = 0 to 8, [[DEF_LOOP_0]]#1 -> %arg2 = 0 to 32, [[DEF_LOOP_0]]#2 -> %arg3 = 0 to 64) {
+  // CHECK: [[LOAD_0:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<16x32x64xf32>
+  // CHECK: store [[LOAD_0]], [[RES_0]][%arg1, %arg2, %arg3] : memref<8x32x64xf32>
+  // CHECK: }
+  // CHECK: [[DEF_LOOP_1:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOP_1:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOP_1]]#0, [[DEF_LOOP_1]]#1, [[DEF_LOOP_1]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOP_1]]#0, [[OPT_LOOP_1]]#1, [[OPT_LOOP_1]]#2) with ([[DEF_LOOP_1]]#0 -> %arg1 = 0 to 8, [[DEF_LOOP_1]]#1 -> %arg2 = 0 to 32, [[DEF_LOOP_1]]#2 -> %arg3 = 0 to 64) {
+  // CHECK: %[[INDEX:.+]] = affine.apply [[INDEX_MAP]](%arg1)
+  // CHECK: [[LOAD_1:%.+]] = load %arg0[%[[INDEX]], %arg2, %arg3] : memref<16x32x64xf32>
+  // CHECK: store [[LOAD_1]], [[RES_1]][%arg1, %arg2, %arg3] : memref<8x32x64xf32>
+  // CHECK: }
+  // CHECK: return [[RES_0]], [[RES_1]] : memref<8x32x64xf32>, memref<8x32x64xf32>
+}
+
+// -----
+
+func @test_split_variable(%arg0 : tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
+  %0, %1 = "onnx.Split"(%arg0) { axis = 1, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+  "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
+
+  // CHECK: [[INDEX_MAP:#.+]] = affine_map<(d0) -> (d0 + 2)>
+  // CHECK-LABEL: @test_split_variable
+
+  // CHECK: [[RES_1:%.+]] = alloc() : memref<16x30x64xf32>
+  // CHECK: [[RES_0:%.+]] = alloc() : memref<16x2x64xf32>
+  // CHECK: [[DEF_LOOP_0:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOP_0:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOP_0]]#0, [[DEF_LOOP_0]]#1, [[DEF_LOOP_0]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOP_0]]#0, [[OPT_LOOP_0]]#1, [[OPT_LOOP_0]]#2) with ([[DEF_LOOP_0]]#0 -> %arg1 = 0 to 16, [[DEF_LOOP_0]]#1 -> %arg2 = 0 to 2, [[DEF_LOOP_0]]#2 -> %arg3 = 0 to 64) {
+  // CHECK: [[LOAD_0:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<16x32x64xf32>
+  // CHECK: store [[LOAD_0]], [[RES_0]][%arg1, %arg2, %arg3] : memref<16x2x64xf32>
+  // CHECK: }
+  // CHECK: [[DEF_LOOP_1:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOP_1:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOP_1]]#0, [[DEF_LOOP_1]]#1, [[DEF_LOOP_1]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOP_1]]#0, [[OPT_LOOP_1]]#1, [[OPT_LOOP_1]]#2) with ([[DEF_LOOP_1]]#0 -> %arg1 = 0 to 16, [[DEF_LOOP_1]]#1 -> %arg2 = 0 to 30, [[DEF_LOOP_1]]#2 -> %arg3 = 0 to 64) {
+  // CHECK: %[[INDEX:.+]] = affine.apply [[INDEX_MAP]](%arg2)
+  // CHECK: [[LOAD_1:%.+]] = load %arg0[%arg1, %[[INDEX]], %arg3] : memref<16x32x64xf32>
+  // CHECK: store [[LOAD_1]], [[RES_1]][%arg1, %arg2, %arg3] : memref<16x30x64xf32>
+  // CHECK: }
+  // CHECK: return [[RES_0]], [[RES_1]] : memref<16x2x64xf32>, memref<16x30x64xf32>
+}
+
+// -----
+
+func @test_split_unknown_dimension(%arg0 : tensor<?x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
+  %0, %1 = "onnx.Split"(%arg0) { axis = 1, split = [2, 30]} : (tensor<?x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+  "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
+
+  // CHECK: [[INDEX_MAP:#.+]] = affine_map<(d0) -> (d0 + 2)>
+  // CHECK-LABEL: @test_split_unknown_dimension
+
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x32x64xf32>
+  // CHECK: [[RES_0:%.+]] = alloc([[DIM_0]]) : memref<?x2x64xf32>
+  // CHECK: [[DIM_1:%.+]] = dim %arg0, 0 : memref<?x32x64xf32>
+  // CHECK: [[RES_1:%.+]] = alloc([[DIM_1]]) : memref<?x30x64xf32>
+  // CHECK: [[DEF_LOOP_0:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOP_0:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOP_0]]#0, [[DEF_LOOP_0]]#1, [[DEF_LOOP_0]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_0:%.+]] = dim [[RES_0]], 0 : memref<?x2x64xf32>
+  // CHECK: krnl.iterate([[OPT_LOOP_0]]#0, [[OPT_LOOP_0]]#1, [[OPT_LOOP_0]]#2) with ([[DEF_LOOP_0]]#0 -> %arg1 = 0 to [[DIM_0]], [[DEF_LOOP_0]]#1 -> %arg2 = 0 to 2, [[DEF_LOOP_0]]#2 -> %arg3 = 0 to 64) {
+  // CHECK: [[LOAD_0:%.+]] = load %arg0[%arg1, %arg2, %arg3] : memref<?x32x64xf32>
+  // CHECK: store [[LOAD_0]], [[RES_0]][%arg1, %arg2, %arg3] : memref<?x2x64xf32>
+  // CHECK: }
+  // CHECK: [[DEF_LOOP_1:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOP_1:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOP_1]]#0, [[DEF_LOOP_1]]#1, [[DEF_LOOP_1]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_1:%.+]] = dim [[RES_1]], 0 : memref<?x30x64xf32>
+  // CHECK: krnl.iterate([[OPT_LOOP_1]]#0, [[OPT_LOOP_1]]#1, [[OPT_LOOP_1]]#2) with ([[DEF_LOOP_1]]#0 -> %arg1 = 0 to [[DIM_1]], [[DEF_LOOP_1]]#1 -> %arg2 = 0 to 30, [[DEF_LOOP_1]]#2 -> %arg3 = 0 to 64) {
+  // CHECK: %[[INDEX:.+]] = affine.apply [[INDEX_MAP]](%arg2)
+  // CHECK: [[LOAD_1:%.+]] = load %arg0[%arg1, %[[INDEX]], %arg3] : memref<?x32x64xf32>
+  // CHECK: store [[LOAD_1]], [[RES_1]][%arg1, %arg2, %arg3] : memref<?x30x64xf32>
+  // CHECK: }
+  // CHECK: return [[RES_0]], [[RES_1]] : memref<?x2x64xf32>, memref<?x30x64xf32>
+}
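
Reviewer note (not part of the patch): the index arithmetic exercised by the affine_map<(d0) -> (d0 + offset)> checks above is a plain slice copy along the split axis, where each output's read offset is the sum of the preceding split sizes. A minimal NumPy sketch of those semantics, with made-up names, is:

    import numpy as np

    def split_reference(x, split, axis):
        """Reference semantics for onnx.Split: output i copies the slice
        x[..., offset_i : offset_i + split[i], ...] along `axis`, where
        offset_i is the sum of the preceding split sizes."""
        outputs, offset = [], 0
        for size in split:
            index = [slice(None)] * x.ndim
            # Read index = write index + offset along the split axis.
            index[axis] = slice(offset, offset + size)
            outputs.append(x[tuple(index)].copy())
            offset += size
        return outputs

    # Mirrors test_split_variable: a 16x32x64 input split into [2, 30] on axis 1.
    x = np.random.rand(16, 32, 64).astype(np.float32)
    y0, y1 = split_reference(x, split=[2, 30], axis=1)
    assert y0.shape == (16, 2, 64) and y1.shape == (16, 30, 64)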