Improve shape inference of flatten Op (#314)
* change code Signed-off-by: chentong <chentong@us.ibm.com> * add test Signed-off-by: chentong <chentong@us.ibm.com> * fix type error Signed-off-by: chentong <chentong@us.ibm.com> Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com>
This commit is contained in:
parent
aa2eed411f
commit
46306a8a26
|
@ -2153,32 +2153,48 @@ LogicalResult ONNXSplitOp::inferShapes() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXFlattenOp::inferShapes() {
|
||||
assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
|
||||
auto inTy = input().getType().dyn_cast<ShapedType>();
|
||||
if (!inTy) {
|
||||
return emitOpError("Input is a non-shaped type");
|
||||
}
|
||||
auto outTy = output().getType().dyn_cast<ShapedType>();
|
||||
if (!outTy) {
|
||||
return emitOpError("Output is a non-shaped type");
|
||||
|
||||
auto axisValue = axis();
|
||||
auto inputShape = inTy.getShape();
|
||||
auto inputRank = inputShape.size();
|
||||
if (axisValue < -1 * (int64_t)inputRank || axisValue > (int64_t)inputRank) {
|
||||
return emitOpError("ONNXFlattenOP: axis() value is out of range");
|
||||
}
|
||||
|
||||
// TODO(tjingrant): Seems like we can also fairly easily support the case
|
||||
// where the batch dimension is dynamic
|
||||
if (!outTy.hasStaticShape()) {
|
||||
auto inShape = inTy.getShape();
|
||||
assert(inShape.size() >= 1 && "ONNXFlattenOp inShape.size() should be > 0");
|
||||
uint64_t outDim = 1;
|
||||
for (auto it = inShape.begin() + 1; it < inShape.end(); it++) {
|
||||
outDim *= *it;
|
||||
SmallVector<int64_t, 2> dims;
|
||||
|
||||
// Negative axis is counting dimension from back
|
||||
if (axisValue < 0)
|
||||
axisValue = inputRank + axisValue + 1;
|
||||
|
||||
// Determine the size of the first dimension of output
|
||||
int64_t firstDim = 1;
|
||||
for (auto i = 0; i < axisValue; i++) {
|
||||
if (inputShape[i] == -1) {
|
||||
firstDim = -1;
|
||||
break;
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 2> dims;
|
||||
// https://pytorch.org/docs/master/generated/torch.nn.Flatten.html
|
||||
dims.emplace_back(inShape[0]);
|
||||
dims.emplace_back(outDim);
|
||||
getResult().setType(RankedTensorType::get(dims, outTy.getElementType()));
|
||||
firstDim *= inputShape[i];
|
||||
}
|
||||
dims.emplace_back(firstDim);
|
||||
|
||||
// Determine the size of the second dimension of output
|
||||
int64_t secondDim = 1;
|
||||
for (auto i = axisValue; i < inputRank; i++) {
|
||||
if (inputShape[i] == -1) {
|
||||
secondDim = -1;
|
||||
break;
|
||||
}
|
||||
secondDim *= inputShape[i];
|
||||
}
|
||||
dims.emplace_back(secondDim);
|
||||
|
||||
// Set the type of output
|
||||
getResult().setType(RankedTensorType::get(dims, inTy.getElementType()));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -639,6 +639,40 @@ func @test_flatten_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xf32> {
|
|||
// CHECK: return [[RES]] : tensor<5x24xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test when axis is 0
|
||||
func @test_flatten_2(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
|
||||
%1 = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<2x3x4xf32>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
// CHECK-LABEL: test_flatten_2
|
||||
// CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<2x3x4xf32>) -> tensor<1x24xf32>
|
||||
// CHECK: return [[RES]] : tensor<1x24xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test when axis is negative
|
||||
func @test_flatten_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
|
||||
%1 = "onnx.Flatten"(%arg0) {axis = -1 : si64} : (tensor<2x3x4xf32>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
// CHECK-LABEL: test_flatten_3
|
||||
// CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = -1 : si64} : (tensor<2x3x4xf32>) -> tensor<24x1xf32>
|
||||
// CHECK: return [[RES]] : tensor<24x1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test when input is not static shape
|
||||
func @test_flatten_4(%arg0 : tensor<2x4x5x?xf32>) -> tensor<*xf32> {
|
||||
%1 = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<2x4x5x?xf32>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
// CHECK-LABEL: test_flatten_4
|
||||
// CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<2x4x5x?xf32>) -> tensor<8x?xf32>
|
||||
// CHECK: return [[RES]] : tensor<8x?xf32>
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Test the reshape op inference when concat are present.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue