Concat lower (#82)
* implement shape inference for Concat
* better checking of axis being concatenated: constant values only
* lowering of Concat with lit and backend tests
* fixes

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
parent 81d8c5e2ea
commit fa8962753c
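For reference, the shape rule the new inference implements: all inputs must agree on rank and on every dimension except the concat axis, and the output sums the input sizes along that axis. A minimal stand-alone C++ sketch of the rule (hypothetical helper, not code from this patch):

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper illustrating ONNX Concat shape inference: inputs
// must agree on rank and on every dimension except `axis`; the result
// sums the input sizes along `axis`.
std::vector<int64_t> concatResultShape(
    const std::vector<std::vector<int64_t>> &inputShapes, int64_t axis) {
  assert(!inputShapes.empty());
  std::vector<int64_t> result = inputShapes[0];
  int64_t rank = result.size();
  assert(axis >= 0 && axis < rank && "Concat axis value out of bound");
  for (size_t i = 1; i < inputShapes.size(); ++i) {
    assert((int64_t)inputShapes[i].size() == rank && "rank mismatch");
    for (int64_t r = 0; r < rank; ++r) {
      if (r == axis)
        result[r] += inputShapes[i][r];
      else
        assert(inputShapes[i][r] == result[r] && "dimension mismatch");
    }
  }
  return result;
}

// E.g. {5,5,1,32}, {5,5,3,32}, {5,5,5,32} with axis = 2 -> {5,5,9,32},
// matching the test_concat_1 lit test below.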
@@ -15,6 +15,7 @@ add_library(OMONNXToKrnl
         Tensor/Transpose.cpp
         Tensor/Unsqueeze.cpp
         Tensor/Constant.cpp
+        Tensor/Concat.cpp
         ConvertONNXToKrnl.cpp)
 target_include_directories(OMONNXToKrnl
         PRIVATE
@@ -98,6 +98,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
   populateLoweringONNXTransposeOpPattern(patterns, &getContext());
   populateLoweringONNXIdentityOpPattern(patterns, &getContext());
   populateLoweringONNXConstantOpPattern(patterns, &getContext());
+  populateLoweringONNXConcatOpPattern(patterns, &getContext());
   // Neural network
   populateLoweringONNXConvOpPattern(patterns, &getContext());
   populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
@@ -115,7 +115,7 @@ struct ONNXConvOpLowering : public ConversionPattern {
     gIndex = outerLoops.pushBounds(0, group);
     // for m = 0 .. kernelsPerGroup:
     int mIndex = outerLoops.pushBounds(0, kernelsPerGroup);
-    // Outer loop iteration
+    // Outer loop iterations.
     outerLoops.createIterateOp();
     rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
     {
@@ -15,11 +15,11 @@

 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Sequence.h"
-#include "mlir/IR/PatternMatch.h"

 #include "src/Dialect/Krnl/KrnlHelper.hpp"
 #include "src/Dialect/Krnl/KrnlOps.hpp"
@@ -40,9 +40,8 @@ MemRefType convertToMemRefType(Type type);

 /// Insert an allocation and deallocation for the given MemRefType.
 Value insertAllocAndDealloc(MemRefType type, Location loc,
-    PatternRewriter &rewriter,
-    bool insertDealloc,
-    ArrayRef<Value> operands = {});
+    PatternRewriter &rewriter, bool insertDealloc,
+    ArrayRef<Value> operands = {});

 // Determine if current function returns the result value of the
 // current op being lowered. If it does then dealloc should not be
@@ -52,51 +51,46 @@ bool checkInsertDealloc(Operation *currentOp);

 // Create a mapping from result type's dimensions to input type's dimensions,
 // given that the result type is the result of a reduction op over the input
 // type.
-std::map<int64_t, int64_t>
-getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);
+std::map<int64_t, int64_t> getReductionMapping(
+    MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);

 // Add bounds associated with the op operand to the KRNL iteration pack.
 // Dynamic dimensions are supported.
-void addDimensionToPack(ConversionPatternRewriter &rewriter,
-    Location loc, KrnlIterateOperandPack &pack,
-    Value operand, int index);
+void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc,
+    KrnlIterateOperandPack &pack, Value operand, int index);

 // Function that defines the KRNL dialect loops and their respective
 // optimized version.
-KrnlOptimizeLoopsOp
-emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
-    std::vector<Value> &loops,
-    std::vector<Value> &optimizedLoops, int64_t numLoops);
+KrnlOptimizeLoopsOp emitOptimizedLoops(ConversionPatternRewriter &rewriter,
+    Location loc, std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
+    int64_t numLoops);

 // Function that emits the loops and their optimized version.
 // The function returns a reference to the inner optimization block.
 Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
-    std::vector<Value> &loops,
-    std::vector<Value> &optimizedLoops,
-    int64_t numLoops);
+    std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
+    int64_t numLoops);

 // Function which emits a basic set of loops and optimized loops
 // for a given operation argument. A reference to the loop optimization
 // block is returned in the last argument of the function.
-void emitKrnlLoopsAndIterationForOperand(
-    ConversionPatternRewriter &rewriter, Location loc, Value operand,
-    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
-    KrnlIterateOp &iterateOp);
+void emitKrnlLoopsAndIterationForOperand(ConversionPatternRewriter &rewriter,
+    Location loc, Value operand, std::vector<Value> &originalLoops,
+    KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp);

 unsigned getMemRefEltSizeInBytes(MemRefType memRefType);

 // Get run-time dimension information for unknown dimensions used for
 // broadcasting.
-std::map<int, std::map<int, Value>>
-getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
-    MemRefType memRefType, ArrayRef<Value> operands);
+std::map<int, std::map<int, Value>> getBroadcastedDimInfo(Location loc,
+    ConversionPatternRewriter &rewriter, MemRefType memRefType,
+    ArrayRef<Value> operands);

 // Extract induction variables that are used for broadcasting values of a
 // given operand.
-std::vector<Value>
-getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
-    ArrayRef<Value> loopIVs, Value operand,
-    std::map<int, Value> broadcastedDims);
+std::vector<Value> getLoopIVsForBroadcasting(Location loc,
+    ConversionPatternRewriter &rewriter, ArrayRef<Value> loopIVs, Value operand,
+    std::map<int, Value> broadcastedDims);

 // Emit a constant of a specific type.
 // Use this function for small values only to avoid unexpected loss in type
@@ -169,9 +163,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc,
 struct TensorTypeConverter : public TypeConverter {
   using TypeConverter::TypeConverter;

-  TensorTypeConverter() {
-    addConversion(convertType);
-  }
+  TensorTypeConverter() { addConversion(convertType); }

   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
     if (auto type = convertToMemRefType(t)) {
@@ -188,8 +180,8 @@ struct TensorTypeConverter : public TypeConverter {
   /// inputs. Once unranked results can be handled gracefully this
   /// override needs to be removed in favour of the original MLIR one.]
   bool isSignatureLegal(FunctionType funcType) {
-    return llvm::all_of(funcType.getInputs(),
-                        [this](Type type) { return isLegal(type); });
+    return llvm::all_of(
+        funcType.getInputs(), [this](Type type) { return isLegal(type); });
   }
 };

@@ -202,8 +194,8 @@ struct TensorTypeConverter : public TypeConverter {
 void populateLoweringONNXElementwiseOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);

-void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
-    MLIRContext *ctx);
+void populateLoweringONNXGemmOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);

 void populateLoweringONNXMatMulOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
@@ -244,3 +236,6 @@ void populateLoweringONNXIdentityOpPattern(

 void populateLoweringONNXConstantOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXConcatOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
@@ -0,0 +1,82 @@
+//===---------------- Concat.cpp - Lowering Concat Op -------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX Concat Operator to Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+
+using namespace mlir;
+
+struct ONNXConcatOpLowering : public ConversionPattern {
+  ONNXConcatOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXConcatOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    // Gather info.
+    auto loc = op->getLoc();
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+    ONNXConcatOp concatOp = llvm::dyn_cast<ONNXConcatOp>(op);
+    auto axis = concatOp.axis().getSExtValue();
+    int inputNum = operands.size();
+    // Alloc and dealloc.
+    auto resultOperand = concatOp.concat_result();
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+    auto resultShape = memRefType.getShape();
+    auto rank = resultShape.size();
+    assert((axis >= 0 && axis < rank) && "Concat axis out of bounds");
+
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else
+      alloc = insertAllocAndDealloc(
+          memRefType, loc, rewriter, insertDealloc, {resultOperand});
+
+    // Create loops, one for each input.
+    int writeOffset = 0;
+    for (int i = 0; i < inputNum; ++i) {
+      OpBuilder::InsertionGuard insertGuard(rewriter);
+      // Operand info.
+      auto currShape = operands[i].getType().cast<MemRefType>().getShape();
+      // Create loop.
+      BuildKrnlLoop inputLoops(rewriter, loc, rank);
+      inputLoops.createDefineAndOptimizeOp();
+      for (int r = 0; r < rank; ++r)
+        inputLoops.pushBounds(0, operands[i], r);
+      inputLoops.createIterateOp();
+      rewriter.setInsertionPointToStart(inputLoops.getIterateBlock());
+      // Indices for the read and write.
+      SmallVector<Value, 4> readIndices;
+      SmallVector<Value, 4> writeIndices;
+      for (int r = 0; r < rank; ++r) {
+        readIndices.emplace_back(inputLoops.getInductionVar(r));
+        if (r != axis || writeOffset == 0) {
+          writeIndices.emplace_back(inputLoops.getInductionVar(r));
+        } else {
+          auto indexWithOffset = rewriter.create<AddIOp>(loc,
+              rewriter.create<ConstantIndexOp>(loc, writeOffset),
+              inputLoops.getInductionVar(r));
+          writeIndices.emplace_back(indexWithOffset);
+        }
+      }
+      // Insert copy.
+      auto loadData = rewriter.create<LoadOp>(loc, operands[i], readIndices);
+      rewriter.create<StoreOp>(loc, loadData, alloc, writeIndices);
+      // Increment offset.
+      writeOffset += currShape[axis];
+    }
+    rewriter.replaceOp(op, alloc);
+    return success();
+  }
+};
+
+void populateLoweringONNXConcatOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXConcatOpLowering>(ctx);
+}
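To make the writeOffset arithmetic concrete: restricted to rank 1, where the concat axis is the only dimension, the lowering above amounts to the following scalar C++ sketch (a hypothetical stand-alone helper, not part of the patch):

#include <cstdint>
#include <vector>

// Scalar model of the lowering: concatenate 1-D buffers by copying each
// input into the output at a running write offset, exactly as the Krnl
// loops do along the concat axis.
std::vector<float> concat1D(const std::vector<std::vector<float>> &inputs) {
  // Pre-size the output, as insertAllocAndDealloc does for the MemRef.
  int64_t total = 0;
  for (auto &in : inputs)
    total += in.size();
  std::vector<float> result(total);
  int64_t writeOffset = 0;
  for (auto &in : inputs) {
    // One loop nest per input: read at i, write at writeOffset + i.
    for (size_t i = 0; i < in.size(); ++i)
      result[writeOffset + i] = in[i];
    writeOffset += in.size();
  }
  return result;
}

In the real lowering the offset is materialized as a ConstantIndexOp folded into an AddIOp on the axis induction variable, and the first input skips the add entirely since its offset is zero.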
@@ -1479,7 +1479,13 @@ bool ONNXConcatOp::inferShapes() {
   auto commonShape = commonType.getShape();
   auto commonRank = commonShape.size();
   auto axisIndex = axis().getSExtValue();
-  if (!(axisIndex >= 0 && axisIndex < commonRank)) {
+  // Negative axis means values are counted from the opposite side.
+  if (axisIndex < 0) {
+    axisIndex = commonRank + axisIndex;
+    auto builder = mlir::Builder(getContext());
+    axisAttr(builder.getI64IntegerAttr(axisIndex));
+  }
+  if (axisIndex >= commonRank) {
     emitError("Concat axis value out of bound");
     return false;
   }
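The new branch normalizes a negative axis in place and rewrites the axis attribute, so everything downstream (including the lowering's axis >= 0 assert) only ever sees a non-negative value. The arithmetic, as a minimal C++ sketch:

#include <cassert>
#include <cstdint>

// Sketch of the normalization only; the real code also updates the op's
// axisAttr via builder.getI64IntegerAttr so the IR records the result.
int64_t normalizeConcatAxis(int64_t axis, int64_t rank) {
  if (axis < 0)
    axis += rank; // e.g. rank 3, axis -2 -> axis 1 (see test_concat_3 below).
  assert(axis >= 0 && axis < rank && "Concat axis value out of bound");
  return axis;
}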
@@ -82,6 +82,21 @@ test_to_enable = [
     "test_cosh_cpu",
     "test_cosh_example_cpu",

+    # Concat
+    "test_concat_1d_axis_0_cpu",
+    "test_concat_2d_axis_0_cpu",
+    "test_concat_2d_axis_1_cpu",
+    "test_concat_3d_axis_0_cpu",
+    "test_concat_3d_axis_1_cpu",
+    "test_concat_3d_axis_2_cpu",
+
+    "test_concat_1d_axis_negative_1_cpu",
+    "test_concat_2d_axis_negative_1_cpu",
+    "test_concat_2d_axis_negative_2_cpu",
+    "test_concat_3d_axis_negative_1_cpu",
+    "test_concat_3d_axis_negative_2_cpu",
+    "test_concat_3d_axis_negative_3_cpu",
+
     # Tanh:
     "test_tanh_cpu",
     "test_tanh_example_cpu",
@@ -1688,3 +1688,40 @@ func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]] : memref<3x2xf32>
 }

+
+func @test_concat_1(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>, %arg2 : tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32> {
+  %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 2 } : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32>
+  "std.return"(%1) : (tensor<5x5x9x32xf32>) -> ()
+
+  // CHECK-LABEL: test_concat_1
+  // CHECK: [[RES:%.+]] = alloc() : memref<5x5x9x32xf32>
+  // CHECK: [[DEF_LOOPS0:%.+]]:4 = krnl.define_loops 4
+  // CHECK: [[OPT_LOOPS0:%.+]]:4 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS0]]#0, [[DEF_LOOPS0]]#1, [[DEF_LOOPS0]]#2, [[DEF_LOOPS0]]#3
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS0]]#0, [[OPT_LOOPS0]]#1, [[OPT_LOOPS0]]#2, [[OPT_LOOPS0]]#3) with ([[DEF_LOOPS0]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS0]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS0]]#2 -> %arg5 = 0 to 1, [[DEF_LOOPS0]]#3 -> %arg6 = 0 to 32) {
+  // CHECK: [[LOAD0:%.+]] = load %arg0[%arg3, %arg4, %arg5, %arg6] : memref<5x5x1x32xf32>
+  // CHECK: store [[LOAD0]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<5x5x9x32xf32>
+
+  // CHECK: [[DEF_LOOPS1:%.+]]:4 = krnl.define_loops 4
+  // CHECK: [[OPT_LOOPS1:%.+]]:4 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1, [[DEF_LOOPS1]]#2, [[DEF_LOOPS1]]#3
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1, [[OPT_LOOPS1]]#2, [[OPT_LOOPS1]]#3) with ([[DEF_LOOPS1]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS1]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS1]]#2 -> %arg5 = 0 to 3, [[DEF_LOOPS1]]#3 -> %arg6 = 0 to 32) {
+  // CHECK: [[OFF1:%.+]] = constant 1 : index
+  // CHECK: [[ADD1:%.+]] = addi [[OFF1]], %arg5 : index
+  // CHECK: [[LOAD1:%.+]] = load %arg1[%arg3, %arg4, %arg5, %arg6] : memref<5x5x3x32xf32>
+  // CHECK: store [[LOAD1]], [[RES]][%arg3, %arg4, [[ADD1]], %arg6] : memref<5x5x9x32xf32>
+
+  // CHECK: [[DEF_LOOPS2:%.+]]:4 = krnl.define_loops 4
+  // CHECK: [[OPT_LOOPS2:%.+]]:4 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2, [[DEF_LOOPS2]]#3
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1, [[OPT_LOOPS2]]#2, [[OPT_LOOPS2]]#3) with ([[DEF_LOOPS2]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS2]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS2]]#2 -> %arg5 = 0 to 5, [[DEF_LOOPS2]]#3 -> %arg6 = 0 to 32) {
+  // CHECK: [[OFF2:%.+]] = constant 4 : index
+  // CHECK: [[ADD2:%.+]] = addi [[OFF2]], %arg5 : index
+  // CHECK: [[LOAD2:%.+]] = load %arg2[%arg3, %arg4, %arg5, %arg6] : memref<5x5x5x32xf32>
+  // CHECK: store [[LOAD2]], [[RES]][%arg3, %arg4, [[ADD2]], %arg6] : memref<5x5x9x32xf32>
+
+  // CHECK: return [[RES]] : memref<5x5x9x32xf32>
+}
@@ -509,3 +509,12 @@ func @test_concat_2(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg
   // CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
   // CHECK: return [[RES]] : tensor<5x9x32xf32>
 }
+
+func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg2 : tensor<5x5x32xf32>) -> tensor<*xf32> {
+  %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = -2 } : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_concat_3
+  // CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
+  // CHECK: return [[RES]] : tensor<5x9x32xf32>
+}