diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index a6127ef..0954fde 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -2153,32 +2153,48 @@ LogicalResult ONNXSplitOp::inferShapes() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ONNXFlattenOp::inferShapes() {
-  assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
   auto inTy = input().getType().dyn_cast<ShapedType>();
   if (!inTy) {
     return emitOpError("Input is a non-shaped type");
   }
-  auto outTy = output().getType().dyn_cast<ShapedType>();
-  if (!outTy) {
-    return emitOpError("Output is a non-shaped type");
+
+  auto axisValue = axis();
+  auto inputShape = inTy.getShape();
+  auto inputRank = inputShape.size();
+  if (axisValue < -1 * (int64_t)inputRank || axisValue > (int64_t)inputRank) {
+    return emitOpError("ONNXFlattenOp: axis() value is out of range");
   }
-
-  // TODO(tjingrant): Seems like we can also fairly easily support the case
-  // where the batch dimension is dynamic
-  if (!outTy.hasStaticShape()) {
-    auto inShape = inTy.getShape();
-    assert(inShape.size() >= 1 && "ONNXFlattenOp inShape.size() should be > 0");
-    uint64_t outDim = 1;
-    for (auto it = inShape.begin() + 1; it < inShape.end(); it++) {
-      outDim *= *it;
+  SmallVector<int64_t, 2> dims;
+
+  // A negative axis means counting dimensions from the back.
+  if (axisValue < 0)
+    axisValue = inputRank + axisValue + 1;
+
+  // Determine the size of the first output dimension: the product of the
+  // input dimensions before the axis; dynamic if any of them is dynamic.
+  int64_t firstDim = 1;
+  for (auto i = 0; i < axisValue; i++) {
+    if (inputShape[i] == -1) {
+      firstDim = -1;
+      break;
     }
-
-    SmallVector<int64_t, 2> dims;
-    // https://pytorch.org/docs/master/generated/torch.nn.Flatten.html
-    dims.emplace_back(inShape[0]);
-    dims.emplace_back(outDim);
-    getResult().setType(RankedTensorType::get(dims, outTy.getElementType()));
+    firstDim *= inputShape[i];
   }
+  dims.emplace_back(firstDim);
+
+  // Determine the size of the second output dimension: the product of the
+  // input dimensions from the axis onward; dynamic if any of them is dynamic.
+  int64_t secondDim = 1;
+  for (auto i = axisValue; i < (int64_t)inputRank; i++) {
+    if (inputShape[i] == -1) {
+      secondDim = -1;
+      break;
+    }
+    secondDim *= inputShape[i];
+  }
+  dims.emplace_back(secondDim);
+
+  // Set the output type.
+  getResult().setType(RankedTensorType::get(dims, inTy.getElementType()));
   return success();
 }
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index b7e2c7d..b6c6b84 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -639,6 +639,40 @@ func @test_flatten_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]] : tensor<5x24xf32>
 }
 
+// -----
+
+// Test when axis is 0
+func @test_flatten_2(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
+  %1 = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<2x3x4xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+  // CHECK-LABEL: test_flatten_2
+  // CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<2x3x4xf32>) -> tensor<1x24xf32>
+  // CHECK: return [[RES]] : tensor<1x24xf32>
+}
+
+// -----
+
+// Test when axis is negative
+func @test_flatten_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
+  %1 = "onnx.Flatten"(%arg0) {axis = -1 : si64} : (tensor<2x3x4xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+  // CHECK-LABEL: test_flatten_3
+  // CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = -1 : si64} : (tensor<2x3x4xf32>) -> tensor<24x1xf32>
+  // CHECK: return [[RES]] : tensor<24x1xf32>
+}
+
+// -----
+
+// Test when the input shape is not static
+func @test_flatten_4(%arg0 : tensor<2x4x5x?xf32>) -> tensor<*xf32> {
+  %1 = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<2x4x5x?xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+  // CHECK-LABEL: test_flatten_4
+  // CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<2x4x5x?xf32>) -> tensor<8x?xf32>
+  // CHECK: return [[RES]] : tensor<8x?xf32>
+}
+
 
 //===----------------------------------------------------------------------===//
 /// Test the reshape op inference when concat are present.
 //===----------------------------------------------------------------------===//
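For reviewers who want to sanity-check the new shape arithmetic outside of MLIR, here is a minimal standalone C++ sketch that mirrors the firstDim/secondDim computation in the patch and reproduces the four lit test cases above. The helper `flattenShape` is hypothetical and not part of the patch; `-1` stands in for a dynamic dimension, as in MLIR shapes.

```cpp
// Standalone sketch, not part of the patch: mirrors the ONNXFlattenOp
// shape-inference arithmetic. A dimension of -1 is dynamic, as in MLIR
// (tensor<2x4x5x?xf32> corresponds to {2, 4, 5, -1}).
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical helper: returns the 2-D output shape for a given input shape
// and an axis already known to be in range.
static std::vector<int64_t> flattenShape(
    const std::vector<int64_t> &inputShape, int64_t axis) {
  int64_t rank = (int64_t)inputShape.size();
  // Negative axis counts from the back; note the +1, which matches the patch
  // and its tests (axis = -1 on tensor<2x3x4xf32> yields tensor<24x1xf32>).
  if (axis < 0)
    axis = rank + axis + 1;

  // First output dim: product of dims before the axis; any dynamic input
  // dim makes it dynamic.
  int64_t firstDim = 1;
  for (int64_t i = 0; i < axis; i++)
    firstDim =
        (firstDim == -1 || inputShape[i] == -1) ? -1 : firstDim * inputShape[i];

  // Second output dim: product of dims from the axis onward.
  int64_t secondDim = 1;
  for (int64_t i = axis; i < rank; i++)
    secondDim = (secondDim == -1 || inputShape[i] == -1)
                    ? -1
                    : secondDim * inputShape[i];

  return {firstDim, secondDim};
}

int main() {
  // The four lit tests above, in order; prints 5x24, 1x24, 24x1, 8x-1
  // (-1 is what MLIR renders as ?).
  std::vector<std::pair<std::vector<int64_t>, int64_t>> cases = {
      {{5, 2, 3, 4}, 1}, {{2, 3, 4}, 0}, {{2, 3, 4}, -1}, {{2, 4, 5, -1}, 2}};
  for (const auto &c : cases) {
    auto out = flattenShape(c.first, c.second);
    std::cout << out[0] << "x" << out[1] << "\n";
  }
  return 0;
}
```

Note that with the `inputRank + axisValue + 1` normalization, axis = -1 flattens the entire input into the first output dimension (24x1 above); this is the behavior the new tests lock in.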