Improve shape inference of flatten Op (#314)

* change code

Signed-off-by: chentong <chentong@us.ibm.com>

* add test

Signed-off-by: chentong <chentong@us.ibm.com>

* fix type error

Signed-off-by: chentong <chentong@us.ibm.com>

Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com>
This commit is contained in:
chentong319 2020-09-29 12:59:01 -04:00 committed by GitHub
parent aa2eed411f
commit 46306a8a26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 68 additions and 18 deletions

View File

@ -2153,32 +2153,48 @@ LogicalResult ONNXSplitOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Infer the output shape of ONNXFlattenOp.
///
/// Flatten reshapes the input tensor into a 2-D tensor: input dimensions
/// before `axis` are multiplied into the first output dimension, and the
/// dimensions from `axis` onward are multiplied into the second. A dynamic
/// input dimension (-1) makes the corresponding output dimension dynamic.
LogicalResult ONNXFlattenOp::inferShapes() {
  auto inTy = input().getType().dyn_cast<ShapedType>();
  if (!inTy) {
    return emitOpError("Input is a non-shaped type");
  }

  int64_t axisValue = axis();
  auto inputShape = inTy.getShape();
  // Keep the rank signed so all comparisons below are signed/signed.
  int64_t inputRank = static_cast<int64_t>(inputShape.size());

  // Valid axis values lie in [-rank, rank]; either collapse may be empty.
  if (axisValue < -inputRank || axisValue > inputRank) {
    return emitOpError("ONNXFlattenOP: axis() value is out of range");
  }

  // Negative axis counts dimensions from the back.
  // https://pytorch.org/docs/master/generated/torch.nn.Flatten.html
  // NOTE(review): the "+ 1" differs from the plain ONNX `rank + axis`
  // convention but is what the lit tests pin (axis = -1 on a rank-3 input
  // yields 24x1) — confirm against the intended spec before changing.
  if (axisValue < 0)
    axisValue = inputRank + axisValue + 1;

  SmallVector<int64_t, 2> dims;

  // First output dimension: product of input dims [0, axis). Any dynamic
  // input dim (-1) makes the whole product dynamic.
  int64_t firstDim = 1;
  for (int64_t i = 0; i < axisValue; i++) {
    if (inputShape[i] == -1) {
      firstDim = -1;
      break;
    }
    firstDim *= inputShape[i];
  }
  dims.emplace_back(firstDim);

  // Second output dimension: product of input dims [axis, rank), with the
  // same dynamic-dimension propagation.
  int64_t secondDim = 1;
  for (int64_t i = axisValue; i < inputRank; i++) {
    if (inputShape[i] == -1) {
      secondDim = -1;
      break;
    }
    secondDim *= inputShape[i];
  }
  dims.emplace_back(secondDim);

  // The output element type is taken from the input, not from any
  // pre-existing output type (the output may still be unranked here).
  getResult().setType(RankedTensorType::get(dims, inTy.getElementType()));
  return success();
}

View File

@ -639,6 +639,40 @@ func @test_flatten_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xf32> {
// CHECK: return [[RES]] : tensor<5x24xf32> // CHECK: return [[RES]] : tensor<5x24xf32>
} }
// -----
// Test when axis is 0
// axis = 0: no leading dims are collapsed into the first output dim, so it
// becomes 1 and the entire input (2*3*4 = 24) forms the second dim.
func @test_flatten_2(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
%1 = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<2x3x4xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_flatten_2
// CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<2x3x4xf32>) -> tensor<1x24xf32>
// CHECK: return [[RES]] : tensor<1x24xf32>
}
// -----
// Test when axis is negative
// axis = -1 on a rank-3 input normalizes to axis 3 (rank + axis + 1), so all
// input dims (2*3*4 = 24) collapse into the first output dim and the second
// is 1.
func @test_flatten_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
%1 = "onnx.Flatten"(%arg0) {axis = -1 : si64} : (tensor<2x3x4xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_flatten_3
// CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = -1 : si64} : (tensor<2x3x4xf32>) -> tensor<24x1xf32>
// CHECK: return [[RES]] : tensor<24x1xf32>
}
// -----
// Test when the input shape is not fully static
// Dynamic input dim: dims [0, 2) are static (2*4 = 8) but the trailing group
// [2, 4) contains a '?', so the second output dim stays dynamic.
func @test_flatten_4(%arg0 : tensor<2x4x5x?xf32>) -> tensor<*xf32> {
%1 = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<2x4x5x?xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_flatten_4
// CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<2x4x5x?xf32>) -> tensor<8x?xf32>
// CHECK: return [[RES]] : tensor<8x?xf32>
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Test the reshape op inference when concat are present. /// Test the reshape op inference when concat are present.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//