From fa8962753c841d9b0518af92054fb86a4b3f87b1 Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Mon, 13 Apr 2020 11:40:39 -0400 Subject: [PATCH] Concat lower (#82) * implement shape inference for concat * better checking of axis being concatenated: constant values only * lowering of Concat with lit and backend tests * fixes Co-authored-by: Gheorghe-Teodor Bercea --- src/Conversion/ONNXToKrnl/CMakeLists.txt | 3 +- .../ONNXToKrnl/ConvertONNXToKrnl.cpp | 1 + src/Conversion/ONNXToKrnl/NN/Conv.cpp | 2 +- .../ONNXToKrnl/ONNXToKrnlCommon.hpp | 63 +++++++------- src/Conversion/ONNXToKrnl/Tensor/Concat.cpp | 82 +++++++++++++++++++ src/Dialect/ONNX/ONNXOps.cpp | 8 +- test/backend/test.py | 15 ++++ test/mlir/onnx/onnx_lowering.mlir | 37 +++++++++ test/mlir/onnx/onnx_shape_inference.mlir | 9 ++ 9 files changed, 183 insertions(+), 37 deletions(-) create mode 100644 src/Conversion/ONNXToKrnl/Tensor/Concat.cpp diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt index 9f027cd..f75cf28 100644 --- a/src/Conversion/ONNXToKrnl/CMakeLists.txt +++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt @@ -15,6 +15,7 @@ add_library(OMONNXToKrnl Tensor/Transpose.cpp Tensor/Unsqueeze.cpp Tensor/Constant.cpp + Tensor/Concat.cpp ConvertONNXToKrnl.cpp) target_include_directories(OMONNXToKrnl PRIVATE @@ -23,4 +24,4 @@ target_include_directories(OMONNXToKrnl ${ONNX_MLIR_SRC_ROOT}) target_link_libraries(OMONNXToKrnl ${MLIRLibs} - OMKrnlOps) \ No newline at end of file + OMKrnlOps) diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index 662fc80..be67331 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -98,6 +98,7 @@ void FrontendToKrnlLoweringPass::runOnModule() { populateLoweringONNXTransposeOpPattern(patterns, &getContext()); populateLoweringONNXIdentityOpPattern(patterns, &getContext()); populateLoweringONNXConstantOpPattern(patterns, &getContext()); + populateLoweringONNXConcatOpPattern(patterns, &getContext()); // Neural network populateLoweringONNXConvOpPattern(patterns, &getContext()); populateLoweringONNXNormalizationOpPattern(patterns, &getContext()); diff --git a/src/Conversion/ONNXToKrnl/NN/Conv.cpp b/src/Conversion/ONNXToKrnl/NN/Conv.cpp index 070527d..aa1e076 100644 --- a/src/Conversion/ONNXToKrnl/NN/Conv.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Conv.cpp @@ -115,7 +115,7 @@ struct ONNXConvOpLowering : public ConversionPattern { gIndex = outerLoops.pushBounds(0, group); // for m = 0 .. kernelsPerGroup: int mIndex = outerLoops.pushBounds(0, kernelsPerGroup); - // Outer loop iteration + // Outer loop iterations. outerLoops.createIterateOp(); rewriter.setInsertionPointToStart(outerLoops.getIterateBlock()); { diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index 2c0c5f7..044755c 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -15,11 +15,11 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Sequence.h" -#include "mlir/IR/PatternMatch.h" #include "src/Dialect/Krnl/KrnlHelper.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp" @@ -40,9 +40,8 @@ MemRefType convertToMemRefType(Type type); /// Insert an allocation and deallocation for the given MemRefType. Value insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter, - bool insertDealloc, - ArrayRef operands = {}); + PatternRewriter &rewriter, bool insertDealloc, + ArrayRef operands = {}); // Determine if current function returns the result value of the // current op being lowered. If it does then dealloc should not be @@ -52,51 +51,46 @@ bool checkInsertDealloc(Operation *currentOp); // Create a mapping from result type's dimensions to input type's dimensions, // given that the result type is the result of a reduction op over the input // type. -std::map -getReductionMapping(MemRefType inputTy, ArrayRef axes, bool keepdims); +std::map getReductionMapping( + MemRefType inputTy, ArrayRef axes, bool keepdims); // Add bounds associated with the op operand to the KRNL iteration pack. // Dynamic dimenions are supported. -void addDimensionToPack(ConversionPatternRewriter &rewriter, - Location loc, KrnlIterateOperandPack &pack, - Value operand, int index); +void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc, + KrnlIterateOperandPack &pack, Value operand, int index); // Function that defines the KRNL dialect loops and their respective // optimized version. -KrnlOptimizeLoopsOp -emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc, - std::vector &loops, - std::vector &optimizedLoops, int64_t numLoops); +KrnlOptimizeLoopsOp emitOptimizedLoops(ConversionPatternRewriter &rewriter, + Location loc, std::vector &loops, std::vector &optimizedLoops, + int64_t numLoops); // Function that emits the loops and their optimized version. // The function returns a reference to the inner optimization block. Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc, - std::vector &loops, - std::vector &optimizedLoops, - int64_t numLoops); + std::vector &loops, std::vector &optimizedLoops, + int64_t numLoops); // Function which emits a basic set of loops and optimized loops // for a given operation argument. A reference to the loop optimization // block is returned in the last argument of the function. -void emitKrnlLoopsAndIterationForOperand( - ConversionPatternRewriter &rewriter, Location loc, Value operand, - std::vector &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp, - KrnlIterateOp &iterateOp); +void emitKrnlLoopsAndIterationForOperand(ConversionPatternRewriter &rewriter, + Location loc, Value operand, std::vector &originalLoops, + KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp); unsigned getMemRefEltSizeInBytes(MemRefType memRefType); // Get run-time dimension information for unknown dimensions used for // broadcasting. -std::map> -getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, - MemRefType memRefType, ArrayRef operands); +std::map> getBroadcastedDimInfo(Location loc, + ConversionPatternRewriter &rewriter, MemRefType memRefType, + ArrayRef operands); // Extract induction variables that are used for broadcasting values of a // given operand. -std::vector -getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, - ArrayRef loopIVs, Value operand, - std::map broadcastedDims); +std::vector getLoopIVsForBroadcasting(Location loc, + ConversionPatternRewriter &rewriter, ArrayRef loopIVs, Value operand, + std::map broadcastedDims); // Emit a constant of a specific type. // Use this function for small values only to avoid unexpected loss in type @@ -169,9 +163,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc, struct TensorTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; - TensorTypeConverter() { - addConversion(convertType); - } + TensorTypeConverter() { addConversion(convertType); } static LogicalResult convertType(Type t, SmallVectorImpl &results) { if (auto type = convertToMemRefType(t)) { @@ -188,8 +180,8 @@ struct TensorTypeConverter : public TypeConverter { /// inputs. Once unranked results can be handled gracefully this /// override needs to be removed in favour of the original MLIR one.] bool isSignatureLegal(FunctionType funcType) { - return llvm::all_of(funcType.getInputs(), - [this](Type type) { return isLegal(type); }); + return llvm::all_of( + funcType.getInputs(), [this](Type type) { return isLegal(type); }); } }; @@ -202,8 +194,8 @@ struct TensorTypeConverter : public TypeConverter { void populateLoweringONNXElementwiseOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx); -void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns, - MLIRContext *ctx); +void populateLoweringONNXGemmOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx); void populateLoweringONNXMatMulOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx); @@ -244,3 +236,6 @@ void populateLoweringONNXIdentityOpPattern( void populateLoweringONNXConstantOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx); + +void populateLoweringONNXConcatOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp b/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp new file mode 100644 index 0000000..4511fc0 --- /dev/null +++ b/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp @@ -0,0 +1,82 @@ +//===---------------- Concat.cpp - Lowering Concat Op -------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Concat Operator to Krnl dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" + +using namespace mlir; + +struct ONNXConcatOpLowering : public ConversionPattern { + ONNXConcatOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXConcatOp::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // Gather info. + auto loc = op->getLoc(); + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + ONNXConcatOp concatOp = llvm::dyn_cast(op); + auto axis = concatOp.axis().getSExtValue(); + int inputNum = operands.size(); + // Alloc and dealloc. + auto resultOperand = concatOp.concat_result(); + auto memRefType = convertToMemRefType(*op->result_type_begin()); + auto resultShape = memRefType.getShape(); + auto rank = resultShape.size(); + assert((axis >=0 && axis < rank) && "Concat axis out of bounds"); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc( + memRefType, loc, rewriter, insertDealloc, {resultOperand}); + + // Creates loops, one for each input. + int writeOffset = 0; + for (int i = 0; i < inputNum; ++i) { + OpBuilder::InsertionGuard insertGuard(rewriter); + // Operand info. + auto currShape = operands[i].getType().cast().getShape(); + // Create loop. + BuildKrnlLoop inputLoops(rewriter, loc, rank); + inputLoops.createDefineAndOptimizeOp(); + for (int r = 0; r < rank; ++r) + inputLoops.pushBounds(0, operands[i], r); + inputLoops.createIterateOp(); + rewriter.setInsertionPointToStart(inputLoops.getIterateBlock()); + // Indices for the read and write. + SmallVector readIndices; + SmallVector writeIndices; + for (int r = 0; r < rank; ++r) { + readIndices.emplace_back(inputLoops.getInductionVar(r)); + if (r != axis || writeOffset == 0) { + writeIndices.emplace_back(inputLoops.getInductionVar(r)); + } else { + auto indexWithOffset = rewriter.create(loc, + rewriter.create(loc, writeOffset), + inputLoops.getInductionVar(r)); + writeIndices.emplace_back(indexWithOffset); + } + } + // Insert copy. + auto loadData = rewriter.create(loc, operands[i], readIndices); + rewriter.create(loc, loadData, alloc, writeIndices); + // Increment offset + writeOffset += currShape[axis]; + } + rewriter.replaceOp(op, alloc); + return success(); + } +}; + +void populateLoweringONNXConcatOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index cb8c8d3..5da613f 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -1479,7 +1479,13 @@ bool ONNXConcatOp::inferShapes() { auto commonShape = commonType.getShape(); auto commonRank = commonShape.size(); auto axisIndex = axis().getSExtValue(); - if (!(axisIndex >= 0 && axisIndex < commonRank)) { + // Negative axis means values are counted from the opposite side. + if (axisIndex < 0) { + axisIndex = commonRank + axisIndex; + auto builder = mlir::Builder(getContext()); + axisAttr(builder.getI64IntegerAttr(axisIndex)); + } + if (axisIndex >= commonRank) { emitError("Concat axis value out of bound"); return false; } diff --git a/test/backend/test.py b/test/backend/test.py index f47ae77..02d1959 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -82,6 +82,21 @@ test_to_enable = [ "test_cosh_cpu", "test_cosh_example_cpu", + # Concat + "test_concat_1d_axis_0_cpu", + "test_concat_2d_axis_0_cpu", + "test_concat_2d_axis_1_cpu", + "test_concat_3d_axis_0_cpu", + "test_concat_3d_axis_1_cpu", + "test_concat_3d_axis_2_cpu", + + "test_concat_1d_axis_negative_1_cpu", + "test_concat_2d_axis_negative_1_cpu", + "test_concat_2d_axis_negative_2_cpu", + "test_concat_3d_axis_negative_1_cpu", + "test_concat_3d_axis_negative_2_cpu", + "test_concat_3d_axis_negative_3_cpu", + # Tanh: "test_tanh_cpu", "test_tanh_example_cpu", diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index db42e01..36a3e16 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1688,3 +1688,40 @@ func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> { // CHECK: return [[RES]] : memref<3x2xf32> } + +func @test_concat_1(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>, %arg2 : tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32> { + %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 2 } : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32> + "std.return"(%1) : (tensor<5x5x9x32xf32>) -> () + + // CHECK-LABEL: test_concat_1 + // CHECK: [[RES:%.+]] = alloc() : memref<5x5x9x32xf32> + // CHECK: [[DEF_LOOPS0:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_LOOPS0:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS0]]#0, [[DEF_LOOPS0]]#1, [[DEF_LOOPS0]]#2, [[DEF_LOOPS0]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS0]]#0, [[OPT_LOOPS0]]#1, [[OPT_LOOPS0]]#2, [[OPT_LOOPS0]]#3) with ([[DEF_LOOPS0]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS0]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS0]]#2 -> %arg5 = 0 to 1, [[DEF_LOOPS0]]#3 -> %arg6 = 0 to 32) { + // CHECK: [[LOAD0:%.+]] = load %arg0[%arg3, %arg4, %arg5, %arg6] : memref<5x5x1x32xf32> + // CHECK: store [[LOAD0]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<5x5x9x32xf32> + + // CHECK: [[DEF_LOOPS1:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_LOOPS1:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1, [[DEF_LOOPS1]]#2, [[DEF_LOOPS1]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1, [[OPT_LOOPS1]]#2, [[OPT_LOOPS1]]#3) with ([[DEF_LOOPS1]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS1]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS1]]#2 -> %arg5 = 0 to 3, [[DEF_LOOPS1]]#3 -> %arg6 = 0 to 32) { + // CHECK: [[OFF1:%.+]] = constant 1 : index + // CHECK: [[ADD1:%.+]] = addi [[OFF1]], %arg5 : index + // CHECK: [[LOAD1:%.+]] = load %arg1[%arg3, %arg4, %arg5, %arg6] : memref<5x5x3x32xf32> + // CHECK: store [[LOAD1]], [[RES]][%arg3, %arg4, [[ADD1]], %arg6] : memref<5x5x9x32xf32> + + // CHECK: [[DEF_LOOPS2:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_LOOPS2:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2, [[DEF_LOOPS2]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2, [[OPT_LOOPS2]]#3) with ([[DEF_LOOPS2]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS2]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS2]]#2 -> %arg5 = 0 to 5, [[DEF_LOOPS2]]#3 -> %arg6 = 0 to 32) { + // CHECK: [[OFF2:%.+]] = constant 4 : index + // CHECK: [[ADD2:%.+]] = addi [[OFF2]], %arg5 : index + // CHECK: [[LOAD2:%.+]] = load %arg2[%arg3, %arg4, %arg5, %arg6] : memref<5x5x5x32xf32> + // CHECK: store [[LOAD2]], [[RES]][%arg3, %arg4, [[ADD2]], %arg6] : memref<5x5x9x32xf32> + + // CHECK: return [[RES]] : memref<5x5x9x32xf32> +} diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index d44895c..5485105 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -509,3 +509,12 @@ func @test_concat_2(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg // CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32> // CHECK: return [[RES]] : tensor<5x9x32xf32> } + +func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg2 : tensor<5x5x32xf32>) -> tensor<*xf32> { + %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = -2 } : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_concat_3 + // CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32> + // CHECK: return [[RES]] : tensor<5x9x32xf32> +}