Improve shape inference of flatten Op (#314)
* change code  Signed-off-by: chentong <chentong@us.ibm.com>
* add test  Signed-off-by: chentong <chentong@us.ibm.com>
* fix type error  Signed-off-by: chentong <chentong@us.ibm.com>

Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com>
This commit is contained in:
parent
aa2eed411f
commit
46306a8a26
|
@ -2153,32 +2153,48 @@ LogicalResult ONNXSplitOp::inferShapes() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Infer the output shape of ONNXFlattenOp: the input tensor is collapsed
/// into a rank-2 tensor. Input dimensions before `axis` are multiplied into
/// the first output dimension and the remaining dimensions into the second.
/// A dynamic (-1) input dimension makes the corresponding output dimension
/// dynamic. Returns failure() (with a diagnostic) for non-shaped inputs or
/// an out-of-range axis.
LogicalResult ONNXFlattenOp::inferShapes() {
  auto inTy = input().getType().dyn_cast<ShapedType>();
  if (!inTy) {
    return emitOpError("Input is a non-shaped type");
  }

  auto axisValue = axis();
  auto inputShape = inTy.getShape();
  // Keep the rank signed so every comparison against the signed axis
  // attribute (and the loop bounds below) avoids mixed signed/unsigned
  // operands (-Wsign-compare).
  const int64_t inputRank = static_cast<int64_t>(inputShape.size());
  if (axisValue < -inputRank || axisValue > inputRank)
    return emitOpError("ONNXFlattenOp: axis() value is out of range");

  // https://pytorch.org/docs/master/generated/torch.nn.Flatten.html
  // Negative axis counts dimensions from the back.
  // NOTE(review): the ONNX spec normalizes a negative axis as `rank + axis`;
  // the extra `+ 1` here matches the lit tests accompanying this change
  // (axis = -1 folds the whole input into the first output dim) -- confirm
  // this convention is intended.
  if (axisValue < 0)
    axisValue = inputRank + axisValue + 1;

  // First output dimension: product of input dims [0, axisValue). Any
  // dynamic input dim makes the result dynamic (-1). An empty range yields
  // the empty product, 1.
  int64_t firstDim = 1;
  for (int64_t i = 0; i < axisValue; i++) {
    if (inputShape[i] == -1) {
      firstDim = -1;
      break;
    }
    firstDim *= inputShape[i];
  }

  // Second output dimension: product of input dims [axisValue, rank),
  // with the same dynamic-dim handling.
  int64_t secondDim = 1;
  for (int64_t i = axisValue; i < inputRank; i++) {
    if (inputShape[i] == -1) {
      secondDim = -1;
      break;
    }
    secondDim *= inputShape[i];
  }

  SmallVector<int64_t, 2> dims = {firstDim, secondDim};
  getResult().setType(RankedTensorType::get(dims, inTy.getElementType()));
  return success();
}
|
||||||
|
|
|
@ -639,6 +639,40 @@ func @test_flatten_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: return [[RES]] : tensor<5x24xf32>
|
// CHECK: return [[RES]] : tensor<5x24xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test when axis is 0
// With axis = 0 there are no leading dims, so the first output dim is the
// empty product (1) and all input dims (2*3*4 = 24) fold into the second.
func @test_flatten_2(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
  %1 = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<2x3x4xf32>) -> tensor<*xf32>
  "std.return"(%1) : (tensor<*xf32>) -> ()
  // CHECK-LABEL: test_flatten_2
  // CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<2x3x4xf32>) -> tensor<1x24xf32>
  // CHECK: return [[RES]] : tensor<1x24xf32>
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test when axis is negative
// NOTE(review): this implementation normalizes a negative axis as
// rank + axis + 1, so axis = -1 on a rank-3 input becomes axis = 3 and the
// whole input (2*3*4 = 24) folds into the first output dim. The ONNX spec's
// `rank + axis` normalization would instead yield tensor<6x4xf32> -- confirm
// the intended convention.
func @test_flatten_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
  %1 = "onnx.Flatten"(%arg0) {axis = -1 : si64} : (tensor<2x3x4xf32>) -> tensor<*xf32>
  "std.return"(%1) : (tensor<*xf32>) -> ()
  // CHECK-LABEL: test_flatten_3
  // CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = -1 : si64} : (tensor<2x3x4xf32>) -> tensor<24x1xf32>
  // CHECK: return [[RES]] : tensor<24x1xf32>
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test when input is not static shape
// The dynamic (?) dim sits at index 3, at/after axis = 2, so the first
// output dim stays static (2*4 = 8) while the second becomes dynamic.
func @test_flatten_4(%arg0 : tensor<2x4x5x?xf32>) -> tensor<*xf32> {
  %1 = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<2x4x5x?xf32>) -> tensor<*xf32>
  "std.return"(%1) : (tensor<*xf32>) -> ()
  // CHECK-LABEL: test_flatten_4
  // CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<2x4x5x?xf32>) -> tensor<8x?xf32>
  // CHECK: return [[RES]] : tensor<8x?xf32>
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
/// Test the reshape op inference when concat are present.
|
/// Test the reshape op inference when concat are present.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue