Concat lower (#82)

* implement shape inference for concat

* better checking of axis being concatenated: constant values only

* lowering of Concat with lit and backend tests

* fixes

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Alexandre Eichenberger 2020-04-13 11:40:39 -04:00 committed by GitHub
parent 81d8c5e2ea
commit fa8962753c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 183 additions and 37 deletions

View File

@ -15,6 +15,7 @@ add_library(OMONNXToKrnl
Tensor/Transpose.cpp Tensor/Transpose.cpp
Tensor/Unsqueeze.cpp Tensor/Unsqueeze.cpp
Tensor/Constant.cpp Tensor/Constant.cpp
Tensor/Concat.cpp
ConvertONNXToKrnl.cpp) ConvertONNXToKrnl.cpp)
target_include_directories(OMONNXToKrnl target_include_directories(OMONNXToKrnl
PRIVATE PRIVATE

View File

@ -98,6 +98,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
populateLoweringONNXTransposeOpPattern(patterns, &getContext()); populateLoweringONNXTransposeOpPattern(patterns, &getContext());
populateLoweringONNXIdentityOpPattern(patterns, &getContext()); populateLoweringONNXIdentityOpPattern(patterns, &getContext());
populateLoweringONNXConstantOpPattern(patterns, &getContext()); populateLoweringONNXConstantOpPattern(patterns, &getContext());
populateLoweringONNXConcatOpPattern(patterns, &getContext());
// Neural network // Neural network
populateLoweringONNXConvOpPattern(patterns, &getContext()); populateLoweringONNXConvOpPattern(patterns, &getContext());
populateLoweringONNXNormalizationOpPattern(patterns, &getContext()); populateLoweringONNXNormalizationOpPattern(patterns, &getContext());

View File

@ -115,7 +115,7 @@ struct ONNXConvOpLowering : public ConversionPattern {
gIndex = outerLoops.pushBounds(0, group); gIndex = outerLoops.pushBounds(0, group);
// for m = 0 .. kernelsPerGroup: // for m = 0 .. kernelsPerGroup:
int mIndex = outerLoops.pushBounds(0, kernelsPerGroup); int mIndex = outerLoops.pushBounds(0, kernelsPerGroup);
// Outer loop iteration // Outer loop iterations.
outerLoops.createIterateOp(); outerLoops.createIterateOp();
rewriter.setInsertionPointToStart(outerLoops.getIterateBlock()); rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
{ {

View File

@ -15,11 +15,11 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h" #include "llvm/ADT/Sequence.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Dialect/Krnl/KrnlHelper.hpp" #include "src/Dialect/Krnl/KrnlHelper.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp"
@ -40,9 +40,8 @@ MemRefType convertToMemRefType(Type type);
/// Insert an allocation and deallocation for the given MemRefType. /// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc, Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter, PatternRewriter &rewriter, bool insertDealloc,
bool insertDealloc, ArrayRef<Value> operands = {});
ArrayRef<Value> operands = {});
// Determine if current function returns the result value of the // Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be // current op being lowered. If it does then dealloc should not be
@ -52,51 +51,46 @@ bool checkInsertDealloc(Operation *currentOp);
// Create a mapping from result type's dimensions to input type's dimensions, // Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input // given that the result type is the result of a reduction op over the input
// type. // type.
std::map<int64_t, int64_t> std::map<int64_t, int64_t> getReductionMapping(
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims); MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);
// Add bounds associated with the op operand to the KRNL iteration pack. // Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimenions are supported. // Dynamic dimenions are supported.
void addDimensionToPack(ConversionPatternRewriter &rewriter, void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc,
Location loc, KrnlIterateOperandPack &pack, KrnlIterateOperandPack &pack, Value operand, int index);
Value operand, int index);
// Function that defines the KRNL dialect loops and their respective // Function that defines the KRNL dialect loops and their respective
// optimized version. // optimized version.
KrnlOptimizeLoopsOp KrnlOptimizeLoopsOp emitOptimizedLoops(ConversionPatternRewriter &rewriter,
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc, Location loc, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
std::vector<Value> &loops, int64_t numLoops);
std::vector<Value> &optimizedLoops, int64_t numLoops);
// Function that emits the loops and their optimized version. // Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block. // The function returns a reference to the inner optimization block.
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc, Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
std::vector<Value> &loops, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
std::vector<Value> &optimizedLoops, int64_t numLoops);
int64_t numLoops);
// Function which emits a basic set of loops and optimized loops // Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization // for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function. // block is returned in the last argument of the function.
void emitKrnlLoopsAndIterationForOperand( void emitKrnlLoopsAndIterationForOperand(ConversionPatternRewriter &rewriter,
ConversionPatternRewriter &rewriter, Location loc, Value operand, Location loc, Value operand, std::vector<Value> &originalLoops,
std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp);
KrnlIterateOp &iterateOp);
unsigned getMemRefEltSizeInBytes(MemRefType memRefType); unsigned getMemRefEltSizeInBytes(MemRefType memRefType);
// Get run-time dimension information for unknown dimensions used for // Get run-time dimension information for unknown dimensions used for
// broadcasting. // broadcasting.
std::map<int, std::map<int, Value>> std::map<int, std::map<int, Value>> getBroadcastedDimInfo(Location loc,
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter, MemRefType memRefType,
MemRefType memRefType, ArrayRef<Value> operands); ArrayRef<Value> operands);
// Extract induction variables that are used for broadcasting values of a // Extract induction variables that are used for broadcasting values of a
// given operand. // given operand.
std::vector<Value> std::vector<Value> getLoopIVsForBroadcasting(Location loc,
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter, ArrayRef<Value> loopIVs, Value operand,
ArrayRef<Value> loopIVs, Value operand, std::map<int, Value> broadcastedDims);
std::map<int, Value> broadcastedDims);
// Emit a constant of a specific type. // Emit a constant of a specific type.
// Use this function for small values only to avoid unexpected loss in type // Use this function for small values only to avoid unexpected loss in type
@ -169,9 +163,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc,
struct TensorTypeConverter : public TypeConverter { struct TensorTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter; using TypeConverter::TypeConverter;
TensorTypeConverter() { TensorTypeConverter() { addConversion(convertType); }
addConversion(convertType);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
if (auto type = convertToMemRefType(t)) { if (auto type = convertToMemRefType(t)) {
@ -188,8 +180,8 @@ struct TensorTypeConverter : public TypeConverter {
/// inputs. Once unranked results can be handled gracefully this /// inputs. Once unranked results can be handled gracefully this
/// override needs to be removed in favour of the original MLIR one.] /// override needs to be removed in favour of the original MLIR one.]
bool isSignatureLegal(FunctionType funcType) { bool isSignatureLegal(FunctionType funcType) {
return llvm::all_of(funcType.getInputs(), return llvm::all_of(
[this](Type type) { return isLegal(type); }); funcType.getInputs(), [this](Type type) { return isLegal(type); });
} }
}; };
@ -202,8 +194,8 @@ struct TensorTypeConverter : public TypeConverter {
void populateLoweringONNXElementwiseOpPattern( void populateLoweringONNXElementwiseOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx); OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns, void populateLoweringONNXGemmOpPattern(
MLIRContext *ctx); OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXMatMulOpPattern( void populateLoweringONNXMatMulOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx); OwningRewritePatternList &patterns, MLIRContext *ctx);
@ -244,3 +236,6 @@ void populateLoweringONNXIdentityOpPattern(
void populateLoweringONNXConstantOpPattern( void populateLoweringONNXConstantOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx); OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXConcatOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);

View File

@ -0,0 +1,82 @@
//===---------------- Concat.cpp - Lowering Concat Op -------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Concat Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
using namespace mlir;
struct ONNXConcatOpLowering : public ConversionPattern {
ONNXConcatOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConcatOp::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Gather info.
auto loc = op->getLoc();
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
ONNXConcatOp concatOp = llvm::dyn_cast<ONNXConcatOp>(op);
auto axis = concatOp.axis().getSExtValue();
int inputNum = operands.size();
// Alloc and dealloc.
auto resultOperand = concatOp.concat_result();
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto resultShape = memRefType.getShape();
auto rank = resultShape.size();
assert((axis >=0 && axis < rank) && "Concat axis out of bounds");
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {resultOperand});
// Creates loops, one for each input.
int writeOffset = 0;
for (int i = 0; i < inputNum; ++i) {
OpBuilder::InsertionGuard insertGuard(rewriter);
// Operand info.
auto currShape = operands[i].getType().cast<MemRefType>().getShape();
// Create loop.
BuildKrnlLoop inputLoops(rewriter, loc, rank);
inputLoops.createDefineAndOptimizeOp();
for (int r = 0; r < rank; ++r)
inputLoops.pushBounds(0, operands[i], r);
inputLoops.createIterateOp();
rewriter.setInsertionPointToStart(inputLoops.getIterateBlock());
// Indices for the read and write.
SmallVector<Value, 4> readIndices;
SmallVector<Value, 4> writeIndices;
for (int r = 0; r < rank; ++r) {
readIndices.emplace_back(inputLoops.getInductionVar(r));
if (r != axis || writeOffset == 0) {
writeIndices.emplace_back(inputLoops.getInductionVar(r));
} else {
auto indexWithOffset = rewriter.create<AddIOp>(loc,
rewriter.create<ConstantIndexOp>(loc, writeOffset),
inputLoops.getInductionVar(r));
writeIndices.emplace_back(indexWithOffset);
}
}
// Insert copy.
auto loadData = rewriter.create<LoadOp>(loc, operands[i], readIndices);
rewriter.create<StoreOp>(loc, loadData, alloc, writeIndices);
// Increment offset
writeOffset += currShape[axis];
}
rewriter.replaceOp(op, alloc);
return success();
}
};
void populateLoweringONNXConcatOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXConcatOpLowering>(ctx);
}

View File

@ -1479,7 +1479,13 @@ bool ONNXConcatOp::inferShapes() {
auto commonShape = commonType.getShape(); auto commonShape = commonType.getShape();
auto commonRank = commonShape.size(); auto commonRank = commonShape.size();
auto axisIndex = axis().getSExtValue(); auto axisIndex = axis().getSExtValue();
if (!(axisIndex >= 0 && axisIndex < commonRank)) { // Negative axis means values are counted from the opposite side.
if (axisIndex < 0) {
axisIndex = commonRank + axisIndex;
auto builder = mlir::Builder(getContext());
axisAttr(builder.getI64IntegerAttr(axisIndex));
}
if (axisIndex >= commonRank) {
emitError("Concat axis value out of bound"); emitError("Concat axis value out of bound");
return false; return false;
} }

View File

@ -82,6 +82,21 @@ test_to_enable = [
"test_cosh_cpu", "test_cosh_cpu",
"test_cosh_example_cpu", "test_cosh_example_cpu",
# Concat
"test_concat_1d_axis_0_cpu",
"test_concat_2d_axis_0_cpu",
"test_concat_2d_axis_1_cpu",
"test_concat_3d_axis_0_cpu",
"test_concat_3d_axis_1_cpu",
"test_concat_3d_axis_2_cpu",
"test_concat_1d_axis_negative_1_cpu",
"test_concat_2d_axis_negative_1_cpu",
"test_concat_2d_axis_negative_2_cpu",
"test_concat_3d_axis_negative_1_cpu",
"test_concat_3d_axis_negative_2_cpu",
"test_concat_3d_axis_negative_3_cpu",
# Tanh: # Tanh:
"test_tanh_cpu", "test_tanh_cpu",
"test_tanh_example_cpu", "test_tanh_example_cpu",

View File

@ -1688,3 +1688,40 @@ func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> {
// CHECK: return [[RES]] : memref<3x2xf32> // CHECK: return [[RES]] : memref<3x2xf32>
} }
func @test_concat_1(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>, %arg2 : tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32> {
%1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 2 } : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32>
"std.return"(%1) : (tensor<5x5x9x32xf32>) -> ()
// CHECK-LABEL: test_concat_1
// CHECK: [[RES:%.+]] = alloc() : memref<5x5x9x32xf32>
// CHECK: [[DEF_LOOPS0:%.+]]:4 = krnl.define_loops 4
// CHECK: [[OPT_LOOPS0:%.+]]:4 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS0]]#0, [[DEF_LOOPS0]]#1, [[DEF_LOOPS0]]#2, [[DEF_LOOPS0]]#3
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS0]]#0, [[OPT_LOOPS0]]#1, [[OPT_LOOPS0]]#2, [[OPT_LOOPS0]]#3) with ([[DEF_LOOPS0]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS0]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS0]]#2 -> %arg5 = 0 to 1, [[DEF_LOOPS0]]#3 -> %arg6 = 0 to 32) {
// CHECK: [[LOAD0:%.+]] = load %arg0[%arg3, %arg4, %arg5, %arg6] : memref<5x5x1x32xf32>
// CHECK: store [[LOAD0]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<5x5x9x32xf32>
// CHECK: [[DEF_LOOPS1:%.+]]:4 = krnl.define_loops 4
// CHECK: [[OPT_LOOPS1:%.+]]:4 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1, [[DEF_LOOPS1]]#2, [[DEF_LOOPS1]]#3
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1, [[OPT_LOOPS1]]#2, [[OPT_LOOPS1]]#3) with ([[DEF_LOOPS1]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS1]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS1]]#2 -> %arg5 = 0 to 3, [[DEF_LOOPS1]]#3 -> %arg6 = 0 to 32) {
// CHECK: [[OFF1:%.+]] = constant 1 : index
// CHECK: [[ADD1:%.+]] = addi [[OFF1]], %arg5 : index
// CHECK: [[LOAD1:%.+]] = load %arg1[%arg3, %arg4, %arg5, %arg6] : memref<5x5x3x32xf32>
// CHECK: store [[LOAD1]], [[RES]][%arg3, %arg4, [[ADD1]], %arg6] : memref<5x5x9x32xf32>
// CHECK: [[DEF_LOOPS2:%.+]]:4 = krnl.define_loops 4
// CHECK: [[OPT_LOOPS2:%.+]]:4 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2, [[DEF_LOOPS2]]#3
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2, [[OPT_LOOPS2]]#3) with ([[DEF_LOOPS2]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS2]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS2]]#2 -> %arg5 = 0 to 5, [[DEF_LOOPS2]]#3 -> %arg6 = 0 to 32) {
// CHECK: [[OFF2:%.+]] = constant 4 : index
// CHECK: [[ADD2:%.+]] = addi [[OFF2]], %arg5 : index
// CHECK: [[LOAD2:%.+]] = load %arg2[%arg3, %arg4, %arg5, %arg6] : memref<5x5x5x32xf32>
// CHECK: store [[LOAD2]], [[RES]][%arg3, %arg4, [[ADD2]], %arg6] : memref<5x5x9x32xf32>
// CHECK: return [[RES]] : memref<5x5x9x32xf32>
}

View File

@ -509,3 +509,12 @@ func @test_concat_2(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg
// CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32> // CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
// CHECK: return [[RES]] : tensor<5x9x32xf32> // CHECK: return [[RES]] : tensor<5x9x32xf32>
} }
func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg2 : tensor<5x5x32xf32>) -> tensor<*xf32> {
%1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = -2 } : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_concat_3
// CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
// CHECK: return [[RES]] : tensor<5x9x32xf32>
}