From 2c8f5701bd14e6e2be0c913c09269e0151fbd490 Mon Sep 17 00:00:00 2001
From: "Tung D. Le"
Date: Fri, 3 Jul 2020 17:26:41 +0900
Subject: [PATCH] Lower SqueezeOp to Krnl dialect (#164)

* Lower Squeeze op to Krnl dialect

* Emit tensor size as a single constant; add a lit test for unknown dimensions

* Code style

* Special case where the input is only used by this squeeze op

* Remove squeeze-in-place optimization

* Update ConvertONNXToKrnl.cpp

Tweak to re-run tests.

* Trigger buildbot re-run.

* Re-run CI

Co-authored-by: Tian Jin
---
 src/Conversion/ONNXToKrnl/CMakeLists.txt     |  1 +
 .../ONNXToKrnl/ConvertONNXToKrnl.cpp         |  3 +-
 .../ONNXToKrnl/ONNXToKrnlCommon.hpp          |  3 +
 src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp | 97 +++++++++++++++++++
 src/Dialect/ONNX/ONNXOps.cpp                 | 45 +++++++++
 src/Dialect/ONNX/ONNXOps.td.inc              |  2 +-
 test/backend/test.py                         |  4 +
 test/mlir/onnx/onnx_lowering.mlir            | 29 ++++++
 test/mlir/onnx/onnx_shape_inference.mlir     | 33 +++++++
 utils/gen_onnx_mlir.py                       |  1 +
 10 files changed, 216 insertions(+), 2 deletions(-)
 create mode 100644 src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp

diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt
index 7bdc1fb..9d44980 100644
--- a/src/Conversion/ONNXToKrnl/CMakeLists.txt
+++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -17,6 +17,7 @@ add_library(OMONNXToKrnl
   Tensor/PadConstantValuePad.cpp
   Tensor/Pad.cpp
   Tensor/Transpose.cpp
+  Tensor/Squeeze.cpp
   Tensor/Unsqueeze.cpp
   Tensor/Constant.cpp
   Tensor/Concat.cpp
diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
index b52a3dc..2571458 100644
--- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
+++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -99,6 +99,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   populateLoweringONNXIdentityOpPattern(patterns, &getContext());
   populateLoweringONNXConstantOpPattern(patterns, &getContext());
   populateLoweringONNXConcatOpPattern(patterns, &getContext());
+  populateLoweringONNXSqueezeOpPattern(patterns, &getContext());
   populateLoweringONNXSplitOpPattern(patterns, &getContext());
   // Neural network
   populateLoweringONNXConvOpPattern(patterns, &getContext());
@@ -118,4 +119,4 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
 
 std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
   return std::make_unique<FrontendToKrnlLoweringPass>();
-}
\ No newline at end of file
+}
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
index f03f8ba..5d6597b 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -251,5 +251,8 @@ void populateLoweringONNXConstantOpPattern(
 void populateLoweringONNXConcatOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 
+void populateLoweringONNXSqueezeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 void populateLoweringONNXSplitOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp b/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp
new file mode 100644
index 0000000..1e69f10
--- /dev/null
+++ b/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp
@@ -0,0 +1,97 @@
+//===--------------- Squeeze.cpp - Lowering Squeeze Op --------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX Squeeze Operator to Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+
+using namespace mlir;
+
+struct ONNXSqueezeOpLowering : public ConversionPattern {
+  ONNXSqueezeOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXSqueezeOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    ONNXSqueezeOpOperandAdaptor operandAdaptor(operands);
+    auto loc = op->getLoc();
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+    auto memRefShape = memRefType.getShape();
+    auto elementSizeInBytes = getMemRefEltSizeInBytes(memRefType);
+    Value data = operandAdaptor.data();
+
+    // Assume that `axes` has been validated by shape inference.
+    // So, here we just get it.
+    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXSqueezeOp>(op).axesAttr();
+    SmallVector<int, 4> axes;
+    for (auto axisAttr : axisAttrs.getValue()) {
+      int axis = axisAttr.cast<IntegerAttr>().getInt();
+      axes.emplace_back(axis);
+    }
+
+    // Insert an allocation and deallocation for the result of this operation,
+    // and compute the output tensor's size in bytes.
+    Value alloc, tensorSize;
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(memRefType)) {
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+      auto tensorSizeInBytes = elementSizeInBytes;
+      for (int i = 0; i < memRefShape.size(); ++i) {
+        tensorSizeInBytes *= memRefShape[i];
+      }
+      tensorSize = emitConstantOp(
+          rewriter, loc, rewriter.getIntegerType(64), tensorSizeInBytes);
+    } else {
+      // Need to know the input dimension from which the unknown output
+      // dimension comes.
+      SmallVector<Value, 4> allocOperands;
+      auto tensorSizeConstant = elementSizeInBytes;
+      int64_t inRank = data.getType().cast<MemRefType>().getRank();
+      for (decltype(inRank) inIdx = 0, outIdx = 0; inIdx < inRank; ++inIdx) {
+        Value dimVal = nullptr;
+        // Squeeze dimension is not in the output, ignore it.
+        if (std::find(axes.begin(), axes.end(), inIdx) != axes.end())
+          continue;
+        // Found effective input dimension.
+        if (memRefShape[outIdx] < 0) {
+          Value index = rewriter.create<DimOp>(loc, data, inIdx);
+          allocOperands.emplace_back(index);
+        } else {
+          // Collect constant dimensions for calculating the output tensor size.
+          tensorSizeConstant *= memRefShape[outIdx];
+        }
+        // Move to the next output dimension.
+        outIdx++;
+      }
+      // Allocate memory.
+      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
+      auto *parentBlock = alloc.getDefiningOp()->getBlock();
+      if (insertDealloc) {
+        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+        dealloc.getOperation()->moveBefore(&parentBlock->back());
+      }
+
+      // Compute the output tensor's size.
+      tensorSize = emitConstantOp(
+          rewriter, loc, rewriter.getIntegerType(64), tensorSizeConstant);
+      for (Value dim : allocOperands) {
+        Value dimVal =
+            rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
+        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
+      }
+    }
+    rewriter.create<KrnlMemcpyOp>(loc, alloc, data, tensorSize);
+    rewriter.replaceOp(op, alloc);
+    return success();
+  }
+};
+
+void populateLoweringONNXSqueezeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXSqueezeOpLowering>(ctx);
+}
diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index fb33cd6..f9ee375 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -1804,6 +1804,51 @@ LogicalResult ONNXUnsqueezeOp::inferShapes() {
 }
 
 //===----------------------------------------------------------------------===//
+
+// Squeeze
+
+LogicalResult ONNXSqueezeOp::inferShapes() {
+  if (!data().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
+
+  auto operandTy = data().getType().cast<RankedTensorType>();
+  int64_t inRank = operandTy.getRank();
+
+  ArrayAttr axisAttrs = axesAttr();
+  if (!axisAttrs)
+    return emitError("Axes attribute is required");
+
+  SmallVector<int64_t, 4> axes;
+  bool hasNegativeAxis = false;
+  for (auto axisAttr : axisAttrs.getValue()) {
+    int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
+    if (axis < -inRank || axis >= inRank)
+      return emitError("Invalid axis value");
+    if (axis < 0) {
+      axis = inRank + axis;
+      hasNegativeAxis = true;
+    }
+    if (std::find(axes.begin(), axes.end(), axis) != axes.end())
+      return emitError("Duplicated axes");
+    axes.emplace_back(axis);
+  }
+  if (hasNegativeAxis) {
+    // Update axes attribute so that it contains only positive values.
+    auto builder = mlir::Builder(getContext());
+    ArrayRef<int64_t> defaultRefs(axes);
+    axesAttr(builder.getI64ArrayAttr(defaultRefs));
+  }
+
+  SmallVector<int64_t, 4> dims;
+  for (int i = 0; i < inRank; ++i) {
+    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
+      dims.emplace_back(operandTy.getShape()[i]);
+    }
+  }
+  getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
+  return success();
+}
+
 // Cast
 //===----------------------------------------------------------------------===//
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index 1cbd6c4..89d969a 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -5092,7 +5092,7 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt",
 
 def ONNXSqueezeOp:ONNX_Op<"Squeeze",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Squeeze operation";
   let description = [{
   "Remove single-dimensional entries from the shape of a tensor."
diff --git a/test/backend/test.py b/test/backend/test.py
index d46a27d..1886d39 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -403,6 +403,10 @@ test_to_enable = [
     "test_lstm_with_initial_bias_cpu",
     "test_lstm_with_peepholes_cpu",
 
+    # Squeeze
+    "test_squeeze_cpu",
+    "test_squeeze_negative_axes_cpu",
+
     # Split
     "test_split_equal_parts_1d_cpu",
     "test_split_equal_parts_2d_cpu",
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index f32aa5b..ed46b60 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -2133,6 +2133,35 @@ func @test_lstm_bidirectional_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x
 
 // -----
 
+func @test_squeeze(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1, -2]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: @test_squeeze
+  // CHECK: [[RES:%.+]] = alloc() : memref<16x32x64xf32>
+  // CHECK: [[TENSOR_SIZE:%.+]] = constant 131072 : i64
+  // CHECK: "krnl.memcpy"([[RES]], %arg0, [[TENSOR_SIZE]]) : (memref<16x32x64xf32>, memref<16x1x32x1x64xf32>, i64) -> ()
+  // CHECK: return [[RES]] : memref<16x32x64xf32>
+}
+
+// -----
+
+func @test_squeeze_unknown_dimensions(%arg0 : tensor<?x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1,-2]} : (tensor<?x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: @test_squeeze_unknown_dimensions
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x1x32x1x64xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x32x64xf32>
+  // CHECK: [[TENSOR_SIZE_0:%.+]] = constant 8192 : i64
+  // CHECK: [[DIM_0_i64:%.+]] = index_cast [[DIM_0]] : index to i64
+  // CHECK: [[TENSOR_SIZE_1:%.+]] = muli [[TENSOR_SIZE_0]], [[DIM_0_i64]] : i64
+  // CHECK: "krnl.memcpy"([[RES]], %arg0, [[TENSOR_SIZE_1]]) : (memref<?x32x64xf32>, memref<?x1x32x1x64xf32>, i64) -> ()
+  // CHECK: return [[RES]] : memref<?x32x64xf32>
+}
+
+// -----
+
 func @test_split_equal(%arg0 : tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
   %0, %1 = "onnx.Split"(%arg0) { axis = 0} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
   "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index f38fbd5..a5008bd 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -886,6 +886,39 @@ func @test_split_3(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]]#0 : tensor<16x2x64xf32>
 }
 
+// -----
+
+func @test_squeeze(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_squeeze
+  // CHECK: [[RES:%.+]] = "onnx.Squeeze"(%arg0) {axes = [1]} : (tensor<16x1x32x1x64xf32>) -> tensor<16x32x1x64xf32>
+  // CHECK: return [[RES]] : tensor<16x32x1x64xf32>
+}
+
+// -----
+
+func @test_squeeze_negative_axis(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [-2]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_squeeze_negative_axis
+  // CHECK: [[RES:%.+]] = "onnx.Squeeze"(%arg0) {axes = [3]} : (tensor<16x1x32x1x64xf32>) -> tensor<16x1x32x64xf32>
+  // CHECK: return [[RES]] : tensor<16x1x32x64xf32>
+}
+
+// -----
+
+func @test_squeeze_mix(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Squeeze"(%arg0) { axes = [1, -2]} : (tensor<16x1x32x1x64xf32>) -> (tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_squeeze_mix
+  // CHECK: [[RES:%.+]] = "onnx.Squeeze"(%arg0) {axes = [1, 3]} : (tensor<16x1x32x1x64xf32>) -> tensor<16x32x64xf32>
+  // CHECK: return [[RES]] : tensor<16x32x64xf32>
+}
+
 //===----------------------------------------------------------------------===//
 /// Test the cast op inference.
 //===----------------------------------------------------------------------===//
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index c89f10c..e81ba66 100644
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -253,6 +253,7 @@ OpsWithShapeInference = [
     'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
     'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten',
     'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger',
+    'Squeeze'
 ]
 
 # Operations supporting canonicalization.