Lower SqueezeOp to Krnl dialect (#164)
* Lower Squeeze op to Krnl dialect
* Emit tensor size as a single constant; add a lit test for unknown dimensions
* Code style
* Special case where the input is only used by this squeeze op
* Remove squeeze-in-place optimization
* Update ConvertONNXToKrnl.cpp; tweak to re-run tests
* Trigger buildbot re-run
* Re-run CI

Co-authored-by: Tian Jin <tjingrant@gmail.com>
parent 4d96247327
commit 2c8f5701bd

@@ -17,6 +17,7 @@ add_library(OMONNXToKrnl
     Tensor/PadConstantValuePad.cpp
     Tensor/Pad.cpp
     Tensor/Transpose.cpp
+    Tensor/Squeeze.cpp
     Tensor/Unsqueeze.cpp
     Tensor/Constant.cpp
     Tensor/Concat.cpp

@@ -99,6 +99,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   populateLoweringONNXIdentityOpPattern(patterns, &getContext());
   populateLoweringONNXConstantOpPattern(patterns, &getContext());
   populateLoweringONNXConcatOpPattern(patterns, &getContext());
+  populateLoweringONNXSqueezeOpPattern(patterns, &getContext());
   populateLoweringONNXSplitOpPattern(patterns, &getContext());
   // Neural network
   populateLoweringONNXConvOpPattern(patterns, &getContext());

@@ -251,5 +251,8 @@ void populateLoweringONNXConstantOpPattern(
 void populateLoweringONNXConcatOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 
+void populateLoweringONNXSqueezeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 void populateLoweringONNXSplitOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);

@@ -0,0 +1,97 @@
+//===--------------- Squeeze.cpp - Lowering Squeeze Op --------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX Squeeze operator to the Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+
+using namespace mlir;
+
+struct ONNXSqueezeOpLowering : public ConversionPattern {
+  ONNXSqueezeOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXSqueezeOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    ONNXSqueezeOpOperandAdaptor operandAdaptor(operands);
+    auto loc = op->getLoc();
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+    auto memRefShape = memRefType.getShape();
+    auto elementSizeInBytes = getMemRefEltSizeInBytes(memRefType);
+    Value data = operandAdaptor.data();
+
+    // Assume that `axes` has already been validated by shape inference,
+    // so we simply read it here.
+    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXSqueezeOp>(op).axesAttr();
+    SmallVector<int, 4> axes;
+    for (auto axisAttr : axisAttrs.getValue()) {
+      int axis = axisAttr.cast<IntegerAttr>().getInt();
+      axes.emplace_back(axis);
+    }
+
+    // Insert an allocation and deallocation for the result of this operation,
+    // and compute the output tensor's size in bytes.
+    Value alloc, tensorSize;
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(memRefType)) {
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+      auto tensorSizeInBytes = elementSizeInBytes;
+      for (int i = 0; i < memRefShape.size(); ++i) {
+        tensorSizeInBytes *= memRefShape[i];
+      }
+      tensorSize = emitConstantOp(
+          rewriter, loc, rewriter.getIntegerType(64), tensorSizeInBytes);
+    } else {
+      // We need to know from which input dimension each unknown output
+      // dimension comes.
+      SmallVector<Value, 4> allocOperands;
+      auto tensorSizeConstant = elementSizeInBytes;
+      int64_t inRank = data.getType().cast<ShapedType>().getRank();
+      for (decltype(inRank) inIdx = 0, outIdx = 0; inIdx < inRank; ++inIdx) {
+        Value dimVal = nullptr;
+        // A squeezed dimension does not appear in the output; skip it.
+        if (std::find(axes.begin(), axes.end(), inIdx) != axes.end())
+          continue;
+        // Found an effective input dimension.
+        if (memRefShape[outIdx] < 0) {
+          Value index = rewriter.create<DimOp>(loc, data, inIdx);
+          allocOperands.emplace_back(index);
+        } else {
+          // Collect constant dimensions for calculating the output tensor size.
+          tensorSizeConstant *= memRefShape[outIdx];
+        }
+        // Move to the next output dimension.
+        outIdx++;
+      }
+      // Allocate memory.
+      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
+      auto *parentBlock = alloc.getDefiningOp()->getBlock();
+      if (insertDealloc) {
+        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+        dealloc.getOperation()->moveBefore(&parentBlock->back());
+      }
+
+      // Compute the output tensor's size.
+      tensorSize = emitConstantOp(
+          rewriter, loc, rewriter.getIntegerType(64), tensorSizeConstant);
+      for (Value dim : allocOperands) {
+        Value dimVal =
+            rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
+        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
+      }
+    }
+    rewriter.create<KrnlMemcpyOp>(loc, alloc, data, tensorSize);
+    rewriter.replaceOp(op, alloc);
+    return success();
+  }
+};
+
+void populateLoweringONNXSqueezeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXSqueezeOpLowering>(ctx);
+}
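
The heart of the lowering above is plain index bookkeeping: every axis listed in axes is dropped from the output shape, the remaining static dimensions are folded into one byte-size constant, and dynamic dimensions are multiplied in at run time. Below is a minimal standalone sketch of that bookkeeping with no MLIR dependencies; the helper names and the main() driver are illustrative only.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Drop every axis listed in `axes` (assumed already non-negative and
// duplicate-free, as guaranteed by shape inference) from the input shape.
std::vector<int64_t> squeezeShape(const std::vector<int64_t> &inShape,
                                  const std::vector<int64_t> &axes) {
  std::vector<int64_t> outShape;
  for (int64_t i = 0; i < (int64_t)inShape.size(); ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
      assert(inShape[i] == 1 && "a squeezed dimension must be 1");
      continue; // A squeezed dimension does not appear in the output.
    }
    outShape.push_back(inShape[i]);
  }
  return outShape;
}

// Fold the dimensions into a byte count; only meaningful when every
// dimension is static (a dynamic dimension would be represented as -1).
int64_t sizeInBytes(const std::vector<int64_t> &shape, int64_t eltSizeInBytes) {
  int64_t bytes = eltSizeInBytes;
  for (int64_t d : shape)
    bytes *= d;
  return bytes;
}

int main() {
  // tensor<16x1x32x1x64xf32> squeezed on (already normalized) axes {1, 3}.
  std::vector<int64_t> out = squeezeShape({16, 1, 32, 1, 64}, {1, 3});
  std::cout << sizeInBytes(out, /*f32=*/4) << "\n"; // Prints 131072.
}

The printed value matches the constant 131072 : i64 checked by the lit test further down.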

@@ -1804,6 +1804,51 @@ LogicalResult ONNXUnsqueezeOp::inferShapes() {
 }
 
 //===----------------------------------------------------------------------===//
+
+// Squeeze
+
+LogicalResult ONNXSqueezeOp::inferShapes() {
+  if (!data().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
+
+  auto operandTy = data().getType().cast<RankedTensorType>();
+  int64_t inRank = operandTy.getRank();
+
+  ArrayAttr axisAttrs = axesAttr();
+  if (!axisAttrs)
+    return emitError("Axes attribute is required");
+
+  SmallVector<int64_t, 4> axes;
+  bool hasNegativeAxis = false;
+  for (auto axisAttr : axisAttrs.getValue()) {
+    int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
+    if (axis < -inRank || axis >= inRank)
+      return emitError("Invalid axis value");
+    if (axis < 0) {
+      axis = inRank + axis;
+      hasNegativeAxis = true;
+    }
+    if (std::find(axes.begin(), axes.end(), axis) != axes.end())
+      return emitError("Duplicated axes");
+    axes.emplace_back(axis);
+  }
+  if (hasNegativeAxis) {
+    // Update axes attribute so that it contains only positive values.
+    auto builder = mlir::Builder(getContext());
+    ArrayRef<int64_t> defaultRefs(axes);
+    axesAttr(builder.getI64ArrayAttr(defaultRefs));
+  }
+
+  SmallVector<int64_t, 4> dims;
+  for (int i = 0; i < inRank; ++i) {
+    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
+      dims.emplace_back(operandTy.getShape()[i]);
+    }
+  }
+  getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
+  return success();
+}
+
 // Cast
 //===----------------------------------------------------------------------===//
 
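
The validation in inferShapes() is simple enough to restate without MLIR types: an axis must lie in [-inRank, inRank), negative axes count from the back, and duplicates are rejected. A minimal sketch under those same rules (normalizeAxes is an illustrative name, not a function in this codebase):

#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Returns the normalized (non-negative, duplicate-free) axes, or
// std::nullopt on the same error conditions inferShapes() reports.
std::optional<std::vector<int64_t>> normalizeAxes(
    const std::vector<int64_t> &rawAxes, int64_t inRank) {
  std::vector<int64_t> axes;
  for (int64_t axis : rawAxes) {
    if (axis < -inRank || axis >= inRank)
      return std::nullopt; // "Invalid axis value"
    if (axis < 0)
      axis += inRank; // Negative axes count from the back.
    if (std::find(axes.begin(), axes.end(), axis) != axes.end())
      return std::nullopt; // "Duplicated axes"
    axes.push_back(axis);
  }
  return axes;
}

For a rank-5 input, normalizeAxes({1, -2}, 5) yields {1, 3}, which is exactly the attribute rewrite the shape-inference tests below check for.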

@@ -5092,7 +5092,7 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt",
 }
 
 def ONNXSqueezeOp:ONNX_Op<"Squeeze",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Squeeze operation";
   let description = [{
   "Remove single-dimensional entries from the shape of a tensor."
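
Note that DeclareOpInterfaceMethods<ShapeInferenceOpInterface> is what declares the inferShapes() hook implemented above; the matching 'Squeeze' entry added to OpsWithShapeInference at the bottom of this diff keeps the generated op definition in sync.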

@@ -403,6 +403,10 @@ test_to_enable = [
     "test_lstm_with_initial_bias_cpu",
     "test_lstm_with_peepholes_cpu",
 
+    # Squeeze
+    "test_squeeze_cpu",
+    "test_squeeze_negative_axes_cpu",
+
     # Split
     "test_split_equal_parts_1d_cpu",
     "test_split_equal_parts_2d_cpu",

@@ -2133,6 +2133,35 @@ func @test_lstm_bidirectional_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x
 
 // -----
 
+func @test_squeeze(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1, -2]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: @test_squeeze
+  // CHECK: [[RES:%.+]] = alloc() : memref<16x32x64xf32>
+  // CHECK: [[TENSOR_SIZE:%.+]] = constant 131072 : i64
+  // CHECK: "krnl.memcpy"([[RES]], %arg0, [[TENSOR_SIZE]]) : (memref<16x32x64xf32>, memref<16x1x32x1x64xf32>, i64) -> ()
+  // CHECK: return [[RES]] : memref<16x32x64xf32>
+}
+
+// -----
+
+func @test_squeeze_unknown_dimensions(%arg0 : tensor<?x1x32x?x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1,-2]} : (tensor<?x1x32x?x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: @test_squeeze_unknown_dimensions
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x1x32x?x64xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x32x64xf32>
+  // CHECK: [[TENSOR_SIZE_0:%.+]] = constant 8192 : i64
+  // CHECK: [[DIM_0_i64:%.+]] = index_cast [[DIM_0]] : index to i64
+  // CHECK: [[TENSOR_SIZE_1:%.+]] = muli [[TENSOR_SIZE_0]], [[DIM_0_i64]] : i64
+  // CHECK: "krnl.memcpy"([[RES]], %arg0, [[TENSOR_SIZE_1]]) : (memref<?x32x64xf32>, memref<?x1x32x?x64xf32>, i64) -> ()
+  // CHECK: return [[RES]] : memref<?x32x64xf32>
+}
+
+// -----
+
 func @test_split_equal(%arg0 : tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
   %0, %1 = "onnx.Split"(%arg0) { axis = 0} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
   "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
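
The size constants in these tests are easy to verify by hand: the static case copies 16 x 32 x 64 elements of 4-byte f32, i.e. 16 * 32 * 64 * 4 = 131072 bytes; in the dynamic case only the known dimensions are folded into the constant (32 * 64 * 4 = 8192 bytes), and the leading ? dimension is multiplied in at run time via index_cast and muli.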

@@ -886,6 +886,39 @@ func @test_split_3(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]]#0 : tensor<16x2x64xf32>
 }
 
+// -----
+
+func @test_squeeze(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_squeeze
+  // CHECK: [[RES:%.+]] = "onnx.Squeeze"(%arg0) {axes = [1]} : (tensor<16x1x32x1x64xf32>) -> tensor<16x32x1x64xf32>
+  // CHECK: return [[RES]] : tensor<16x32x1x64xf32>
+}
+
+// -----
+
+func @test_squeeze_negative_axis(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [-2]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_squeeze_negative_axis
+  // CHECK: [[RES:%.+]] = "onnx.Squeeze"(%arg0) {axes = [3]} : (tensor<16x1x32x1x64xf32>) -> tensor<16x1x32x64xf32>
+  // CHECK: return [[RES]] : tensor<16x1x32x64xf32>
+}
+
+// -----
+
+func @test_squeeze_mix(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1, -2]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_squeeze_mix
+  // CHECK: [[RES:%.+]] = "onnx.Squeeze"(%arg0) {axes = [1, 3]} : (tensor<16x1x32x1x64xf32>) -> tensor<16x32x64xf32>
+  // CHECK: return [[RES]] : tensor<16x32x64xf32>
+}
+
 //===----------------------------------------------------------------------===//
 /// Test the cast op inference.
 //===----------------------------------------------------------------------===//
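
As the CHECK lines show, inferShapes() also rewrites the attribute itself: on a rank-5 tensor, axes = [-2] becomes axes = [3] and axes = [1, -2] becomes axes = [1, 3], so later passes, including the Krnl lowering above, only ever see non-negative axes.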

@@ -253,6 +253,7 @@ OpsWithShapeInference = [
     'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
     'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
     'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
+    'Squeeze'
 ]
 
 # Operations supporting canonicalization.