Lower SqueezeOp to Krnl dialect (#164)
* Lower Squeeze op to Krnl dialect
* Emit tensor size as a single constant; add a lit test for unknown dimensions
* Code style
* Special case where the input is only used by this squeeze op
* Remove squeeze-in-place optimization
* Update ConvertONNXToKrnl.cpp: tweak to re-run tests
* Trigger buildbot re-run
* Re-run CI

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent 4d96247327
commit 2c8f5701bd
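To make "emit tensor size as a single constant" concrete: when every output dimension is statically known, the lowering folds the whole byte count into one i64 constant that is handed to krnl.memcpy. A minimal standalone C++ sketch of that arithmetic (illustrative only; the helper name is hypothetical, not code from this patch):

    #include <cstdint>
    #include <vector>

    // Sketch of the static-shape size computation used by the lowering below.
    int64_t staticTensorSizeInBytes(const std::vector<int64_t> &outShape,
                                    int64_t elementSizeInBytes) {
      int64_t size = elementSizeInBytes;
      for (int64_t dim : outShape)
        size *= dim; // every dimension is a compile-time constant here
      return size;
    }

    // staticTensorSizeInBytes({16, 32, 64}, /*f32*/ 4) == 131072, the value
    // checked as "constant 131072 : i64" in the lowering lit test below.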
@@ -17,6 +17,7 @@ add_library(OMONNXToKrnl
   Tensor/PadConstantValuePad.cpp
   Tensor/Pad.cpp
   Tensor/Transpose.cpp
+  Tensor/Squeeze.cpp
   Tensor/Unsqueeze.cpp
   Tensor/Constant.cpp
   Tensor/Concat.cpp

@@ -99,6 +99,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   populateLoweringONNXIdentityOpPattern(patterns, &getContext());
   populateLoweringONNXConstantOpPattern(patterns, &getContext());
   populateLoweringONNXConcatOpPattern(patterns, &getContext());
+  populateLoweringONNXSqueezeOpPattern(patterns, &getContext());
   populateLoweringONNXSplitOpPattern(patterns, &getContext());
   // Neural network
   populateLoweringONNXConvOpPattern(patterns, &getContext());

@@ -251,5 +251,8 @@ void populateLoweringONNXConstantOpPattern(
 void populateLoweringONNXConcatOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);

+void populateLoweringONNXSqueezeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 void populateLoweringONNXSplitOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);

@@ -0,0 +1,97 @@
+//===--------------- Squeeze.cpp - Lowering Squeeze Op --------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX Squeeze Operator to Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+
+using namespace mlir;
+
+struct ONNXSqueezeOpLowering : public ConversionPattern {
+  ONNXSqueezeOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXSqueezeOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    ONNXSqueezeOpOperandAdaptor operandAdaptor(operands);
+    auto loc = op->getLoc();
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+    auto memRefShape = memRefType.getShape();
+    auto elementSizeInBytes = getMemRefEltSizeInBytes(memRefType);
+    Value data = operandAdaptor.data();
+
+    // Assume that `axes` has been validated by shape inference.
+    // So, here we just get it.
+    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXSqueezeOp>(op).axesAttr();
+    SmallVector<int, 4> axes;
+    for (auto axisAttr : axisAttrs.getValue()) {
+      int axis = axisAttr.cast<IntegerAttr>().getInt();
+      axes.emplace_back(axis);
+    }
+
+    // Insert an allocation and deallocation for the result of this operation,
+    // and compute the output tensor's size in bytes.
+    Value alloc, tensorSize;
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(memRefType)) {
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+      auto tensorSizeInBytes = elementSizeInBytes;
+      for (int i = 0; i < memRefShape.size(); ++i) {
+        tensorSizeInBytes *= memRefShape[i];
+      }
+      tensorSize = emitConstantOp(
+          rewriter, loc, rewriter.getIntegerType(64), tensorSizeInBytes);
+    } else {
+      // Need to know the input dimension from which the unknown output
+      // dimension comes from.
+      SmallVector<Value, 4> allocOperands;
+      auto tensorSizeConstant = elementSizeInBytes;
+      int64_t inRank = data.getType().cast<ShapedType>().getRank();
+      for (decltype(inRank) inIdx = 0, outIdx = 0; inIdx < inRank; ++inIdx) {
+        Value dimVal = nullptr;
+        // Squeeze dimension is not in the output, ignore it.
+        if (std::find(axes.begin(), axes.end(), inIdx) != axes.end())
+          continue;
+        // Found effective input dimension.
+        if (memRefShape[outIdx] < 0) {
+          Value index = rewriter.create<DimOp>(loc, data, inIdx);
+          allocOperands.emplace_back(index);
+        } else {
+          // Collect constant dimensions for calculating the output tensor size.
+          tensorSizeConstant *= memRefShape[outIdx];
+        }
+        // Move to the next output dimension.
+        outIdx++;
+      }
+      // Allocate memory.
+      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
+      auto *parentBlock = alloc.getDefiningOp()->getBlock();
+      if (insertDealloc) {
+        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+        dealloc.getOperation()->moveBefore(&parentBlock->back());
+      }
+
+      // Compute the output tensor's size.
+      tensorSize = emitConstantOp(
+          rewriter, loc, rewriter.getIntegerType(64), tensorSizeConstant);
+      for (Value dim : allocOperands) {
+        Value dimVal =
+            rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
+        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
+      }
+    }
+    rewriter.create<KrnlMemcpyOp>(loc, alloc, data, tensorSize);
+    rewriter.replaceOp(op, alloc);
+    return success();
+  }
+};
+
+void populateLoweringONNXSqueezeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXSqueezeOpLowering>(ctx);
+}

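The else-branch above (unknown output dimensions) walks the input and output shapes with two indices. A small self-contained C++ sketch of that bookkeeping, using hypothetical names and plain vectors in place of MLIR types (not onnx-mlir code): squeezed input axes are skipped, statically known output dimensions are folded into a constant byte-size factor, and unknown ones are recorded so the real lowering can emit a DimOp for both the alloc and the runtime size product.

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Returns the constant part of the output size in bytes; the input
    // dimensions listed in dynamicInputDims still need runtime DimOps.
    int64_t constantSizeFactor(const std::vector<int64_t> &outShape, // -1 = unknown
                               const std::vector<int64_t> &axes,     // squeezed input axes
                               int64_t inRank, int64_t elementSizeInBytes,
                               std::vector<int64_t> &dynamicInputDims) {
      int64_t sizeFactor = elementSizeInBytes;
      for (int64_t inIdx = 0, outIdx = 0; inIdx < inRank; ++inIdx) {
        // Squeezed dimension: it does not appear in the output, so skip it.
        if (std::find(axes.begin(), axes.end(), inIdx) != axes.end())
          continue;
        if (outShape[outIdx] < 0)
          dynamicInputDims.push_back(inIdx); // unknown: needs a runtime DimOp
        else
          sizeFactor *= outShape[outIdx];    // known: fold into the constant
        ++outIdx;
      }
      return sizeFactor;
    }

    // For input tensor<?x1x32x?x64xf32> squeezed along axes {1, 3}, the output
    // shape is {-1, 32, 64}: only input dim 0 needs a DimOp, and the constant
    // factor is 4 * 32 * 64 = 8192, the "constant 8192 : i64" in the
    // unknown-dimensions lit test below.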
@@ -1804,6 +1804,51 @@ LogicalResult ONNXUnsqueezeOp::inferShapes() {
   }
 }

 //===----------------------------------------------------------------------===//
+
+// Squeeze
+
+LogicalResult ONNXSqueezeOp::inferShapes() {
+  if (!data().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
+
+  auto operandTy = data().getType().cast<RankedTensorType>();
+  int64_t inRank = operandTy.getRank();
+
+  ArrayAttr axisAttrs = axesAttr();
+  if (!axisAttrs)
+    return emitError("Axes attribute is required");
+
+  SmallVector<int64_t, 4> axes;
+  bool hasNegativeAxis = false;
+  for (auto axisAttr : axisAttrs.getValue()) {
+    int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
+    if (axis < -inRank || axis >= inRank)
+      return emitError("Invalid axis value");
+    if (axis < 0) {
+      axis = inRank + axis;
+      hasNegativeAxis = true;
+    }
+    if (std::find(axes.begin(), axes.end(), axis) != axes.end())
+      return emitError("Duplicated axes");
+    axes.emplace_back(axis);
+  }
+  if (hasNegativeAxis) {
+    // Update axes attribute so that it contains only positive values.
+    auto builder = mlir::Builder(getContext());
+    ArrayRef<int64_t> defaultRefs(axes);
+    axesAttr(builder.getI64ArrayAttr(defaultRefs));
+  }
+
+  SmallVector<int64_t, 4> dims;
+  for (int i = 0; i < inRank; ++i) {
+    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
+      dims.emplace_back(operandTy.getShape()[i]);
+    }
+  }
+  getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
+  return success();
+}
+
 // Cast
 //===----------------------------------------------------------------------===//

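A compact plain-C++ restatement of what ONNXSqueezeOp::inferShapes() above computes, written against std::vector instead of MLIR attributes (illustrative sketch with assumed names): negative axes are normalized by adding the input rank, then every dimension whose index is not a squeeze axis is copied into the result type.

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> squeezedShape(const std::vector<int64_t> &inShape,
                                       std::vector<int64_t> axes) {
      const int64_t inRank = static_cast<int64_t>(inShape.size());
      for (int64_t &axis : axes)
        if (axis < 0)
          axis += inRank; // e.g. axis -2 on a rank-5 tensor becomes axis 3
      std::vector<int64_t> dims;
      for (int64_t i = 0; i < inRank; ++i)
        if (std::find(axes.begin(), axes.end(), i) == axes.end())
          dims.push_back(inShape[i]); // keep every non-squeezed dimension
      return dims;
    }

    // squeezedShape({16, 1, 32, 1, 64}, {1, -2}) == {16, 32, 64}, matching the
    // test_squeeze_mix check in the shape-inference lit test below.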
@@ -5092,7 +5092,7 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt",
 }

 def ONNXSqueezeOp:ONNX_Op<"Squeeze",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Squeeze operation";
   let description = [{
   "Remove single-dimensional entries from the shape of a tensor."

@@ -403,6 +403,10 @@ test_to_enable = [
     "test_lstm_with_initial_bias_cpu",
     "test_lstm_with_peepholes_cpu",

+    # Squeeze
+    "test_squeeze_cpu",
+    "test_squeeze_negative_axes_cpu",
+
     # Split
     "test_split_equal_parts_1d_cpu",
     "test_split_equal_parts_2d_cpu",

@@ -2133,6 +2133,35 @@ func @test_lstm_bidirectional_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x

 // -----

+func @test_squeeze(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1, -2]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: @test_squeeze
+  // CHECK: [[RES:%.+]] = alloc() : memref<16x32x64xf32>
+  // CHECK: [[TENSOR_SIZE:%.+]] = constant 131072 : i64
+  // CHECK: "krnl.memcpy"([[RES]], %arg0, [[TENSOR_SIZE]]) : (memref<16x32x64xf32>, memref<16x1x32x1x64xf32>, i64) -> ()
+  // CHECK: return [[RES]] : memref<16x32x64xf32>
+}
+
+// -----
+
+func @test_squeeze_unknown_dimensions(%arg0 : tensor<?x1x32x?x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1,-2]} : (tensor<?x1x32x?x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: @test_squeeze_unknown_dimensions
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x1x32x?x64xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x32x64xf32>
+  // CHECK: [[TENSOR_SIZE_0:%.+]] = constant 8192 : i64
+  // CHECK: [[DIM_0_i64:%.+]] = index_cast [[DIM_0]] : index to i64
+  // CHECK: [[TENSOR_SIZE_1:%.+]] = muli [[TENSOR_SIZE_0]], [[DIM_0_i64]] : i64
+  // CHECK: "krnl.memcpy"([[RES]], %arg0, [[TENSOR_SIZE_1]]) : (memref<?x32x64xf32>, memref<?x1x32x?x64xf32>, i64) -> ()
+  // CHECK: return [[RES]] : memref<?x32x64xf32>
+}
+
+// -----
+
 func @test_split_equal(%arg0 : tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
   %0, %1 = "onnx.Split"(%arg0) { axis = 0} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
   "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()

@@ -886,6 +886,39 @@ func @test_split_3(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]]#0 : tensor<16x2x64xf32>
 }

+// -----
+
+func @test_squeeze(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_squeeze
+  // CHECK: [[RES:%.+]] = "onnx.Squeeze"(%arg0) {axes = [1]} : (tensor<16x1x32x1x64xf32>) -> tensor<16x32x1x64xf32>
+  // CHECK: return [[RES]] : tensor<16x32x1x64xf32>
+}
+
+// -----
+
+func @test_squeeze_negative_axis(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [-2]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_squeeze_negative_axis
+  // CHECK: [[RES:%.+]] = "onnx.Squeeze"(%arg0) {axes = [3]} : (tensor<16x1x32x1x64xf32>) -> tensor<16x1x32x64xf32>
+  // CHECK: return [[RES]] : tensor<16x1x32x64xf32>
+}
+
+// -----
+
+func @test_squeeze_mix(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1, -2]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_squeeze_mix
+  // CHECK: [[RES:%.+]] = "onnx.Squeeze"(%arg0) {axes = [1, 3]} : (tensor<16x1x32x1x64xf32>) -> tensor<16x32x64xf32>
+  // CHECK: return [[RES]] : tensor<16x32x64xf32>
+}
+
 //===----------------------------------------------------------------------===//
 /// Test the cast op inference.
 //===----------------------------------------------------------------------===//

@@ -253,6 +253,7 @@ OpsWithShapeInference = [
     'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
     'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
     'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
+    'Squeeze'
 ]

 # Operations supporting canonicalization.