Concat lower (#82)
* implement shape inference for concat * better checking of axis being concatenated: constant values only * lowering of Concat with lit and backend tests * fixes Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
81d8c5e2ea
commit
fa8962753c
|
@ -15,6 +15,7 @@ add_library(OMONNXToKrnl
|
|||
Tensor/Transpose.cpp
|
||||
Tensor/Unsqueeze.cpp
|
||||
Tensor/Constant.cpp
|
||||
Tensor/Concat.cpp
|
||||
ConvertONNXToKrnl.cpp)
|
||||
target_include_directories(OMONNXToKrnl
|
||||
PRIVATE
|
||||
|
@ -23,4 +24,4 @@ target_include_directories(OMONNXToKrnl
|
|||
${ONNX_MLIR_SRC_ROOT})
|
||||
target_link_libraries(OMONNXToKrnl
|
||||
${MLIRLibs}
|
||||
OMKrnlOps)
|
||||
OMKrnlOps)
|
||||
|
|
|
@ -98,6 +98,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
|||
populateLoweringONNXTransposeOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXConstantOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXConcatOpPattern(patterns, &getContext());
|
||||
// Neural network
|
||||
populateLoweringONNXConvOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
|
||||
|
|
|
@ -115,7 +115,7 @@ struct ONNXConvOpLowering : public ConversionPattern {
|
|||
gIndex = outerLoops.pushBounds(0, group);
|
||||
// for m = 0 .. kernelsPerGroup:
|
||||
int mIndex = outerLoops.pushBounds(0, kernelsPerGroup);
|
||||
// Outer loop iteration
|
||||
// Outer loop iterations.
|
||||
outerLoops.createIterateOp();
|
||||
rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
|
||||
{
|
||||
|
|
|
@ -15,11 +15,11 @@
|
|||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Dialect/Krnl/KrnlHelper.hpp"
|
||||
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
||||
|
@ -40,9 +40,8 @@ MemRefType convertToMemRefType(Type type);
|
|||
|
||||
/// Insert an allocation and deallocation for the given MemRefType.
|
||||
Value insertAllocAndDealloc(MemRefType type, Location loc,
|
||||
PatternRewriter &rewriter,
|
||||
bool insertDealloc,
|
||||
ArrayRef<Value> operands = {});
|
||||
PatternRewriter &rewriter, bool insertDealloc,
|
||||
ArrayRef<Value> operands = {});
|
||||
|
||||
// Determine if current function returns the result value of the
|
||||
// current op being lowered. If it does then dealloc should not be
|
||||
|
@ -52,51 +51,46 @@ bool checkInsertDealloc(Operation *currentOp);
|
|||
// Create a mapping from result type's dimensions to input type's dimensions,
|
||||
// given that the result type is the result of a reduction op over the input
|
||||
// type.
|
||||
std::map<int64_t, int64_t>
|
||||
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);
|
||||
std::map<int64_t, int64_t> getReductionMapping(
|
||||
MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);
|
||||
|
||||
// Add bounds associated with the op operand to the KRNL iteration pack.
|
||||
// Dynamic dimenions are supported.
|
||||
void addDimensionToPack(ConversionPatternRewriter &rewriter,
|
||||
Location loc, KrnlIterateOperandPack &pack,
|
||||
Value operand, int index);
|
||||
void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc,
|
||||
KrnlIterateOperandPack &pack, Value operand, int index);
|
||||
|
||||
// Function that defines the KRNL dialect loops and their respective
|
||||
// optimized version.
|
||||
KrnlOptimizeLoopsOp
|
||||
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
|
||||
std::vector<Value> &loops,
|
||||
std::vector<Value> &optimizedLoops, int64_t numLoops);
|
||||
KrnlOptimizeLoopsOp emitOptimizedLoops(ConversionPatternRewriter &rewriter,
|
||||
Location loc, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
|
||||
int64_t numLoops);
|
||||
|
||||
// Function that emits the loops and their optimized version.
|
||||
// The function returns a reference to the inner optimization block.
|
||||
Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
|
||||
std::vector<Value> &loops,
|
||||
std::vector<Value> &optimizedLoops,
|
||||
int64_t numLoops);
|
||||
std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
|
||||
int64_t numLoops);
|
||||
|
||||
// Function which emits a basic set of loops and optimized loops
|
||||
// for a given operation argument. A reference to the loop optimization
|
||||
// block is returned in the last argument of the function.
|
||||
void emitKrnlLoopsAndIterationForOperand(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Value operand,
|
||||
std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
|
||||
KrnlIterateOp &iterateOp);
|
||||
void emitKrnlLoopsAndIterationForOperand(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value operand, std::vector<Value> &originalLoops,
|
||||
KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp);
|
||||
|
||||
unsigned getMemRefEltSizeInBytes(MemRefType memRefType);
|
||||
|
||||
// Get run-time dimension information for unknown dimensions used for
|
||||
// broadcasting.
|
||||
std::map<int, std::map<int, Value>>
|
||||
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
|
||||
MemRefType memRefType, ArrayRef<Value> operands);
|
||||
std::map<int, std::map<int, Value>> getBroadcastedDimInfo(Location loc,
|
||||
ConversionPatternRewriter &rewriter, MemRefType memRefType,
|
||||
ArrayRef<Value> operands);
|
||||
|
||||
// Extract induction variables that are used for broadcasting values of a
|
||||
// given operand.
|
||||
std::vector<Value>
|
||||
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
|
||||
ArrayRef<Value> loopIVs, Value operand,
|
||||
std::map<int, Value> broadcastedDims);
|
||||
std::vector<Value> getLoopIVsForBroadcasting(Location loc,
|
||||
ConversionPatternRewriter &rewriter, ArrayRef<Value> loopIVs, Value operand,
|
||||
std::map<int, Value> broadcastedDims);
|
||||
|
||||
// Emit a constant of a specific type.
|
||||
// Use this function for small values only to avoid unexpected loss in type
|
||||
|
@ -169,9 +163,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc,
|
|||
struct TensorTypeConverter : public TypeConverter {
|
||||
using TypeConverter::TypeConverter;
|
||||
|
||||
TensorTypeConverter() {
|
||||
addConversion(convertType);
|
||||
}
|
||||
TensorTypeConverter() { addConversion(convertType); }
|
||||
|
||||
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
|
||||
if (auto type = convertToMemRefType(t)) {
|
||||
|
@ -188,8 +180,8 @@ struct TensorTypeConverter : public TypeConverter {
|
|||
/// inputs. Once unranked results can be handled gracefully this
|
||||
/// override needs to be removed in favour of the original MLIR one.]
|
||||
bool isSignatureLegal(FunctionType funcType) {
|
||||
return llvm::all_of(funcType.getInputs(),
|
||||
[this](Type type) { return isLegal(type); });
|
||||
return llvm::all_of(
|
||||
funcType.getInputs(), [this](Type type) { return isLegal(type); });
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -202,8 +194,8 @@ struct TensorTypeConverter : public TypeConverter {
|
|||
void populateLoweringONNXElementwiseOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
||||
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
|
||||
MLIRContext *ctx);
|
||||
void populateLoweringONNXGemmOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
||||
void populateLoweringONNXMatMulOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
@ -244,3 +236,6 @@ void populateLoweringONNXIdentityOpPattern(
|
|||
|
||||
void populateLoweringONNXConstantOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
||||
void populateLoweringONNXConcatOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
//===---------------- Concat.cpp - Lowering Concat Op -------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file lowers the ONNX Concat Operator to Krnl dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
struct ONNXConcatOpLowering : public ConversionPattern {
|
||||
ONNXConcatOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXConcatOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// Gather info.
|
||||
auto loc = op->getLoc();
|
||||
Value alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
ONNXConcatOp concatOp = llvm::dyn_cast<ONNXConcatOp>(op);
|
||||
auto axis = concatOp.axis().getSExtValue();
|
||||
int inputNum = operands.size();
|
||||
// Alloc and dealloc.
|
||||
auto resultOperand = concatOp.concat_result();
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
auto resultShape = memRefType.getShape();
|
||||
auto rank = resultShape.size();
|
||||
assert((axis >=0 && axis < rank) && "Concat axis out of bounds");
|
||||
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
else
|
||||
alloc = insertAllocAndDealloc(
|
||||
memRefType, loc, rewriter, insertDealloc, {resultOperand});
|
||||
|
||||
// Creates loops, one for each input.
|
||||
int writeOffset = 0;
|
||||
for (int i = 0; i < inputNum; ++i) {
|
||||
OpBuilder::InsertionGuard insertGuard(rewriter);
|
||||
// Operand info.
|
||||
auto currShape = operands[i].getType().cast<MemRefType>().getShape();
|
||||
// Create loop.
|
||||
BuildKrnlLoop inputLoops(rewriter, loc, rank);
|
||||
inputLoops.createDefineAndOptimizeOp();
|
||||
for (int r = 0; r < rank; ++r)
|
||||
inputLoops.pushBounds(0, operands[i], r);
|
||||
inputLoops.createIterateOp();
|
||||
rewriter.setInsertionPointToStart(inputLoops.getIterateBlock());
|
||||
// Indices for the read and write.
|
||||
SmallVector<Value, 4> readIndices;
|
||||
SmallVector<Value, 4> writeIndices;
|
||||
for (int r = 0; r < rank; ++r) {
|
||||
readIndices.emplace_back(inputLoops.getInductionVar(r));
|
||||
if (r != axis || writeOffset == 0) {
|
||||
writeIndices.emplace_back(inputLoops.getInductionVar(r));
|
||||
} else {
|
||||
auto indexWithOffset = rewriter.create<AddIOp>(loc,
|
||||
rewriter.create<ConstantIndexOp>(loc, writeOffset),
|
||||
inputLoops.getInductionVar(r));
|
||||
writeIndices.emplace_back(indexWithOffset);
|
||||
}
|
||||
}
|
||||
// Insert copy.
|
||||
auto loadData = rewriter.create<LoadOp>(loc, operands[i], readIndices);
|
||||
rewriter.create<StoreOp>(loc, loadData, alloc, writeIndices);
|
||||
// Increment offset
|
||||
writeOffset += currShape[axis];
|
||||
}
|
||||
rewriter.replaceOp(op, alloc);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateLoweringONNXConcatOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
patterns.insert<ONNXConcatOpLowering>(ctx);
|
||||
}
|
|
@ -1479,7 +1479,13 @@ bool ONNXConcatOp::inferShapes() {
|
|||
auto commonShape = commonType.getShape();
|
||||
auto commonRank = commonShape.size();
|
||||
auto axisIndex = axis().getSExtValue();
|
||||
if (!(axisIndex >= 0 && axisIndex < commonRank)) {
|
||||
// Negative axis means values are counted from the opposite side.
|
||||
if (axisIndex < 0) {
|
||||
axisIndex = commonRank + axisIndex;
|
||||
auto builder = mlir::Builder(getContext());
|
||||
axisAttr(builder.getI64IntegerAttr(axisIndex));
|
||||
}
|
||||
if (axisIndex >= commonRank) {
|
||||
emitError("Concat axis value out of bound");
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -82,6 +82,21 @@ test_to_enable = [
|
|||
"test_cosh_cpu",
|
||||
"test_cosh_example_cpu",
|
||||
|
||||
# Concat
|
||||
"test_concat_1d_axis_0_cpu",
|
||||
"test_concat_2d_axis_0_cpu",
|
||||
"test_concat_2d_axis_1_cpu",
|
||||
"test_concat_3d_axis_0_cpu",
|
||||
"test_concat_3d_axis_1_cpu",
|
||||
"test_concat_3d_axis_2_cpu",
|
||||
|
||||
"test_concat_1d_axis_negative_1_cpu",
|
||||
"test_concat_2d_axis_negative_1_cpu",
|
||||
"test_concat_2d_axis_negative_2_cpu",
|
||||
"test_concat_3d_axis_negative_1_cpu",
|
||||
"test_concat_3d_axis_negative_2_cpu",
|
||||
"test_concat_3d_axis_negative_3_cpu",
|
||||
|
||||
# Tanh:
|
||||
"test_tanh_cpu",
|
||||
"test_tanh_example_cpu",
|
||||
|
|
|
@ -1688,3 +1688,40 @@ func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> {
|
|||
// CHECK: return [[RES]] : memref<3x2xf32>
|
||||
}
|
||||
|
||||
|
||||
func @test_concat_1(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>, %arg2 : tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32> {
|
||||
%1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 2 } : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32>
|
||||
"std.return"(%1) : (tensor<5x5x9x32xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_concat_1
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<5x5x9x32xf32>
|
||||
// CHECK: [[DEF_LOOPS0:%.+]]:4 = krnl.define_loops 4
|
||||
// CHECK: [[OPT_LOOPS0:%.+]]:4 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS0]]#0, [[DEF_LOOPS0]]#1, [[DEF_LOOPS0]]#2, [[DEF_LOOPS0]]#3
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS0]]#0, [[OPT_LOOPS0]]#1, [[OPT_LOOPS0]]#2, [[OPT_LOOPS0]]#3) with ([[DEF_LOOPS0]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS0]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS0]]#2 -> %arg5 = 0 to 1, [[DEF_LOOPS0]]#3 -> %arg6 = 0 to 32) {
|
||||
// CHECK: [[LOAD0:%.+]] = load %arg0[%arg3, %arg4, %arg5, %arg6] : memref<5x5x1x32xf32>
|
||||
// CHECK: store [[LOAD0]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<5x5x9x32xf32>
|
||||
|
||||
// CHECK: [[DEF_LOOPS1:%.+]]:4 = krnl.define_loops 4
|
||||
// CHECK: [[OPT_LOOPS1:%.+]]:4 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1, [[DEF_LOOPS1]]#2, [[DEF_LOOPS1]]#3
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1, [[OPT_LOOPS1]]#2, [[OPT_LOOPS1]]#3) with ([[DEF_LOOPS1]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS1]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS1]]#2 -> %arg5 = 0 to 3, [[DEF_LOOPS1]]#3 -> %arg6 = 0 to 32) {
|
||||
// CHECK: [[OFF1:%.+]] = constant 1 : index
|
||||
// CHECK: [[ADD1:%.+]] = addi [[OFF1]], %arg5 : index
|
||||
// CHECK: [[LOAD1:%.+]] = load %arg1[%arg3, %arg4, %arg5, %arg6] : memref<5x5x3x32xf32>
|
||||
// CHECK: store [[LOAD1]], [[RES]][%arg3, %arg4, [[ADD1]], %arg6] : memref<5x5x9x32xf32>
|
||||
|
||||
// CHECK: [[DEF_LOOPS2:%.+]]:4 = krnl.define_loops 4
|
||||
// CHECK: [[OPT_LOOPS2:%.+]]:4 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2, [[DEF_LOOPS2]]#3
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2, [[OPT_LOOPS2]]#3) with ([[DEF_LOOPS2]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS2]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS2]]#2 -> %arg5 = 0 to 5, [[DEF_LOOPS2]]#3 -> %arg6 = 0 to 32) {
|
||||
// CHECK: [[OFF2:%.+]] = constant 4 : index
|
||||
// CHECK: [[ADD2:%.+]] = addi [[OFF2]], %arg5 : index
|
||||
// CHECK: [[LOAD2:%.+]] = load %arg2[%arg3, %arg4, %arg5, %arg6] : memref<5x5x5x32xf32>
|
||||
// CHECK: store [[LOAD2]], [[RES]][%arg3, %arg4, [[ADD2]], %arg6] : memref<5x5x9x32xf32>
|
||||
|
||||
// CHECK: return [[RES]] : memref<5x5x9x32xf32>
|
||||
}
|
||||
|
|
|
@ -509,3 +509,12 @@ func @test_concat_2(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg
|
|||
// CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
|
||||
// CHECK: return [[RES]] : tensor<5x9x32xf32>
|
||||
}
|
||||
|
||||
func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg2 : tensor<5x5x32xf32>) -> tensor<*xf32> {
|
||||
%1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = -2 } : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_concat_3
|
||||
// CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
|
||||
// CHECK: return [[RES]] : tensor<5x9x32xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue