Lower SqueezeOp to Krnl dialect (#164)

* Lower Squeeze op to Krnl dialect

* Emit tensor size as a single constant; add a lit test for unknown dimensions

* Code style

* Special case where the input is only used by this squeeze op

* Remove squeeze-in-place optimization

* Update ConvertONNXToKrnl.cpp

Tweak to re-run tests.

* Trigger buildbot re-run.

* Re-run CI

Co-authored-by: Tian Jin <tjingrant@gmail.com>
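
For a quick picture of what the new lowering produces (a minimal sketch that only restates the lit-test checks further down, with illustrative value names), a fully static Squeeze such as

    %0 = "onnx.Squeeze"(%arg0) {axes = [1, -2]} : (tensor<16x1x32x1x64xf32>) -> tensor<*xf32>

becomes, after shape inference and lowering to Krnl, an allocation plus a single memcpy whose byte count is emitted as one i64 constant:

    %res = alloc() : memref<16x32x64xf32>
    %size = constant 131072 : i64
    "krnl.memcpy"(%res, %arg0, %size) : (memref<16x32x64xf32>, memref<16x1x32x1x64xf32>, i64) -> ()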
Tung D. Le 2020-07-03 17:26:41 +09:00 committed by GitHub
parent 4d96247327
commit 2c8f5701bd
10 changed files with 216 additions and 2 deletions


@@ -17,6 +17,7 @@ add_library(OMONNXToKrnl
   Tensor/PadConstantValuePad.cpp
   Tensor/Pad.cpp
   Tensor/Transpose.cpp
+  Tensor/Squeeze.cpp
   Tensor/Unsqueeze.cpp
   Tensor/Constant.cpp
   Tensor/Concat.cpp


@@ -99,6 +99,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   populateLoweringONNXIdentityOpPattern(patterns, &getContext());
   populateLoweringONNXConstantOpPattern(patterns, &getContext());
   populateLoweringONNXConcatOpPattern(patterns, &getContext());
+  populateLoweringONNXSqueezeOpPattern(patterns, &getContext());
   populateLoweringONNXSplitOpPattern(patterns, &getContext());
   // Neural network
   populateLoweringONNXConvOpPattern(patterns, &getContext());
@@ -118,4 +119,4 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
 std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
   return std::make_unique<FrontendToKrnlLoweringPass>();
 }


@@ -251,5 +251,8 @@ void populateLoweringONNXConstantOpPattern(
 void populateLoweringONNXConcatOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
+void populateLoweringONNXSqueezeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
 void populateLoweringONNXSplitOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);


@@ -0,0 +1,97 @@
+//===--------------- Squeeze.cpp - Lowering Squeeze Op -------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX Squeeze Operator to Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+
+using namespace mlir;
+
+struct ONNXSqueezeOpLowering : public ConversionPattern {
+  ONNXSqueezeOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXSqueezeOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    ONNXSqueezeOpOperandAdaptor operandAdaptor(operands);
+    auto loc = op->getLoc();
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+    auto memRefShape = memRefType.getShape();
+    auto elementSizeInBytes = getMemRefEltSizeInBytes(memRefType);
+    Value data = operandAdaptor.data();
+
+    // Assume that `axes` has been validated by shape inference, so we can
+    // simply read it here.
+    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXSqueezeOp>(op).axesAttr();
+    SmallVector<int, 4> axes;
+    for (auto axisAttr : axisAttrs.getValue()) {
+      int axis = axisAttr.cast<IntegerAttr>().getInt();
+      axes.emplace_back(axis);
+    }
+
+    // Insert an allocation and deallocation for the result of this operation,
+    // and compute the output tensor's size in bytes.
+    Value alloc, tensorSize;
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(memRefType)) {
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+      auto tensorSizeInBytes = elementSizeInBytes;
+      for (int i = 0; i < memRefShape.size(); ++i) {
+        tensorSizeInBytes *= memRefShape[i];
+      }
+      tensorSize = emitConstantOp(
+          rewriter, loc, rewriter.getIntegerType(64), tensorSizeInBytes);
+    } else {
+      // Need to know the input dimension from which each unknown output
+      // dimension comes.
+      SmallVector<Value, 4> allocOperands;
+      auto tensorSizeConstant = elementSizeInBytes;
+      int64_t inRank = data.getType().cast<ShapedType>().getRank();
+      for (decltype(inRank) inIdx = 0, outIdx = 0; inIdx < inRank; ++inIdx) {
+        Value dimVal = nullptr;
+        // A squeezed dimension is not in the output; ignore it.
+        if (std::find(axes.begin(), axes.end(), inIdx) != axes.end())
+          continue;
+        // Found an effective input dimension.
+        if (memRefShape[outIdx] < 0) {
+          Value index = rewriter.create<DimOp>(loc, data, inIdx);
+          allocOperands.emplace_back(index);
+        } else {
+          // Collect constant dimensions for calculating the output tensor
+          // size.
+          tensorSizeConstant *= memRefShape[outIdx];
+        }
+        // Move to the next output dimension.
+        outIdx++;
+      }
+      // Allocate memory.
+      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
+      auto *parentBlock = alloc.getDefiningOp()->getBlock();
+      if (insertDealloc) {
+        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+        dealloc.getOperation()->moveBefore(&parentBlock->back());
+      }
+
+      // Compute the output tensor's size: multiply the constant part by each
+      // dynamic dimension, cast to i64.
+      tensorSize = emitConstantOp(
+          rewriter, loc, rewriter.getIntegerType(64), tensorSizeConstant);
+      for (Value dim : allocOperands) {
+        Value dimVal = rewriter.create<IndexCastOp>(
+            loc, dim, rewriter.getIntegerType(64));
+        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
+      }
+    }
+
+    rewriter.create<KrnlMemcpyOp>(loc, alloc, data, tensorSize);
+    rewriter.replaceOp(op, alloc);
+
+    return success();
+  }
+};
+
+void populateLoweringONNXSqueezeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXSqueezeOpLowering>(ctx);
+}


@@ -1804,6 +1804,51 @@ LogicalResult ONNXUnsqueezeOp::inferShapes() {
 }

 //===----------------------------------------------------------------------===//
+// Squeeze
+
+LogicalResult ONNXSqueezeOp::inferShapes() {
+  if (!data().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
+
+  auto operandTy = data().getType().cast<RankedTensorType>();
+  int64_t inRank = operandTy.getRank();
+
+  ArrayAttr axisAttrs = axesAttr();
+  if (!axisAttrs)
+    return emitError("Axes attribute is required");
+
+  SmallVector<int64_t, 4> axes;
+  bool hasNegativeAxis = false;
+  for (auto axisAttr : axisAttrs.getValue()) {
+    int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
+    if (axis < -inRank || axis >= inRank)
+      return emitError("Invalid axis value");
+    if (axis < 0) {
+      axis = inRank + axis;
+      hasNegativeAxis = true;
+    }
+    if (std::find(axes.begin(), axes.end(), axis) != axes.end())
+      return emitError("Duplicated axes");
+    axes.emplace_back(axis);
+  }
+  if (hasNegativeAxis) {
+    // Update axes attribute so that it contains only positive values.
+    auto builder = mlir::Builder(getContext());
+    ArrayRef<int64_t> defaultRefs(axes);
+    axesAttr(builder.getI64ArrayAttr(defaultRefs));
+  }
+
+  SmallVector<int64_t, 4> dims;
+  for (int i = 0; i < inRank; ++i) {
+    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
+      dims.emplace_back(operandTy.getShape()[i]);
+    }
+  }
+  getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
+  return success();
+}
+
 // Cast
 //===----------------------------------------------------------------------===//


@@ -5092,7 +5092,7 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt",
 }

 def ONNXSqueezeOp:ONNX_Op<"Squeeze",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Squeeze operation";
   let description = [{
   "Remove single-dimensional entries from the shape of a tensor."


@@ -403,6 +403,10 @@ test_to_enable = [
     "test_lstm_with_initial_bias_cpu",
     "test_lstm_with_peepholes_cpu",

+    # Squeeze
+    "test_squeeze_cpu",
+    "test_squeeze_negative_axes_cpu",
+
     # Split
     "test_split_equal_parts_1d_cpu",
     "test_split_equal_parts_2d_cpu",


@@ -2133,6 +2133,35 @@ func @test_lstm_bidirectional_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x

 // -----

+func @test_squeeze(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1, -2]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: @test_squeeze
+  // CHECK: [[RES:%.+]] = alloc() : memref<16x32x64xf32>
+  // CHECK: [[TENSOR_SIZE:%.+]] = constant 131072 : i64
+  // CHECK: "krnl.memcpy"([[RES]], %arg0, [[TENSOR_SIZE]]) : (memref<16x32x64xf32>, memref<16x1x32x1x64xf32>, i64) -> ()
+  // CHECK: return [[RES]] : memref<16x32x64xf32>
+}
+
+// -----
+
+func @test_squeeze_unknown_dimensions(%arg0 : tensor<?x1x32x?x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1,-2]} : (tensor<?x1x32x?x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: @test_squeeze_unknown_dimensions
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x1x32x?x64xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x32x64xf32>
+  // CHECK: [[TENSOR_SIZE_0:%.+]] = constant 8192 : i64
+  // CHECK: [[DIM_0_i64:%.+]] = index_cast [[DIM_0]] : index to i64
+  // CHECK: [[TENSOR_SIZE_1:%.+]] = muli [[TENSOR_SIZE_0]], [[DIM_0_i64]] : i64
+  // CHECK: "krnl.memcpy"([[RES]], %arg0, [[TENSOR_SIZE_1]]) : (memref<?x32x64xf32>, memref<?x1x32x?x64xf32>, i64) -> ()
+  // CHECK: return [[RES]] : memref<?x32x64xf32>
+}
+
+// -----
+
 func @test_split_equal(%arg0 : tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
   %0, %1 = "onnx.Split"(%arg0) { axis = 0} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
   "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()


@@ -886,6 +886,39 @@ func @test_split_3(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]]#0 : tensor<16x2x64xf32>
 }

+// -----
+
+func @test_squeeze(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_squeeze
+  // CHECK: [[RES:%.+]] = "onnx.Squeeze"(%arg0) {axes = [1]} : (tensor<16x1x32x1x64xf32>) -> tensor<16x32x1x64xf32>
+  // CHECK: return [[RES]] : tensor<16x32x1x64xf32>
+}
+
+// -----
+
+func @test_squeeze_negative_axis(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [-2]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_squeeze_negative_axis
+  // CHECK: [[RES:%.+]] = "onnx.Squeeze"(%arg0) {axes = [3]} : (tensor<16x1x32x1x64xf32>) -> tensor<16x1x32x64xf32>
+  // CHECK: return [[RES]] : tensor<16x1x32x64xf32>
+}
+
+// -----
+
+func @test_squeeze_mix(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1, -2]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_squeeze_mix
+  // CHECK: [[RES:%.+]] = "onnx.Squeeze"(%arg0) {axes = [1, 3]} : (tensor<16x1x32x1x64xf32>) -> tensor<16x32x64xf32>
+  // CHECK: return [[RES]] : tensor<16x32x64xf32>
+}
+
 //===----------------------------------------------------------------------===//
 /// Test the cast op inference.
 //===----------------------------------------------------------------------===//


@@ -253,6 +253,7 @@ OpsWithShapeInference = [
     'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
     'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
     'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
+    'Squeeze'
 ]

 # Operations supporting canonicalization.