Improve shape inference of flatten Op (#314)

* change code

Signed-off-by: chentong <chentong@us.ibm.com>

* add test

Signed-off-by: chentong <chentong@us.ibm.com>

* fix type error

Signed-off-by: chentong <chentong@us.ibm.com>

Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com>
This commit is contained in:
chentong319 2020-09-29 12:59:01 -04:00 committed by GitHub
parent aa2eed411f
commit 46306a8a26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 68 additions and 18 deletions

View File

@ -2153,32 +2153,48 @@ LogicalResult ONNXSplitOp::inferShapes() {
//===----------------------------------------------------------------------===//
LogicalResult ONNXFlattenOp::inferShapes() {
assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
auto inTy = input().getType().dyn_cast<ShapedType>();
if (!inTy) {
return emitOpError("Input is a non-shaped type");
}
auto outTy = output().getType().dyn_cast<ShapedType>();
if (!outTy) {
return emitOpError("Output is a non-shaped type");
}
// TODO(tjingrant): Seems like we can also fairly easily support the case
// where the batch dimension is dynamic
if (!outTy.hasStaticShape()) {
auto inShape = inTy.getShape();
assert(inShape.size() >= 1 && "ONNXFlattenOp inShape.size() should be > 0");
uint64_t outDim = 1;
for (auto it = inShape.begin() + 1; it < inShape.end(); it++) {
outDim *= *it;
auto axisValue = axis();
auto inputShape = inTy.getShape();
auto inputRank = inputShape.size();
if (axisValue < -1 * (int64_t)inputRank || axisValue > (int64_t)inputRank) {
return emitOpError("ONNXFlattenOP: axis() value is out of range");
}
SmallVector<int64_t, 2> dims;
// https://pytorch.org/docs/master/generated/torch.nn.Flatten.html
dims.emplace_back(inShape[0]);
dims.emplace_back(outDim);
getResult().setType(RankedTensorType::get(dims, outTy.getElementType()));
// Negative axis is counting dimension from back
if (axisValue < 0)
axisValue = inputRank + axisValue + 1;
// Determine the size of the first dimension of output
int64_t firstDim = 1;
for (auto i = 0; i < axisValue; i++) {
if (inputShape[i] == -1) {
firstDim = -1;
break;
}
firstDim *= inputShape[i];
}
dims.emplace_back(firstDim);
// Determine the size of the second dimension of output
int64_t secondDim = 1;
for (auto i = axisValue; i < inputRank; i++) {
if (inputShape[i] == -1) {
secondDim = -1;
break;
}
secondDim *= inputShape[i];
}
dims.emplace_back(secondDim);
// Set the type of output
getResult().setType(RankedTensorType::get(dims, inTy.getElementType()));
return success();
}

View File

@ -639,6 +639,40 @@ func @test_flatten_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xf32> {
// CHECK: return [[RES]] : tensor<5x24xf32>
}
// -----
// Test when axis is 0
func @test_flatten_2(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
%1 = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<2x3x4xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_flatten_2
// CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<2x3x4xf32>) -> tensor<1x24xf32>
// CHECK: return [[RES]] : tensor<1x24xf32>
}
// -----
// Test when axis is negative
func @test_flatten_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
%1 = "onnx.Flatten"(%arg0) {axis = -1 : si64} : (tensor<2x3x4xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_flatten_3
// CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = -1 : si64} : (tensor<2x3x4xf32>) -> tensor<24x1xf32>
// CHECK: return [[RES]] : tensor<24x1xf32>
}
// -----
// Test when input is not static shape
func @test_flatten_4(%arg0 : tensor<2x4x5x?xf32>) -> tensor<*xf32> {
%1 = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<2x4x5x?xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_flatten_4
// CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<2x4x5x?xf32>) -> tensor<8x?xf32>
// CHECK: return [[RES]] : tensor<8x?xf32>
}
//===----------------------------------------------------------------------===//
/// Test the reshape op inference when concat are present.
//===----------------------------------------------------------------------===//