Improve shape inference of flatten Op (#314)
* change code

Signed-off-by: chentong <chentong@us.ibm.com>

* add test

Signed-off-by: chentong <chentong@us.ibm.com>

* fix type error

Signed-off-by: chentong <chentong@us.ibm.com>

Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com>
This commit is contained in:
		
							parent
							
								
									aa2eed411f
								
							
						
					
					
						commit
						46306a8a26
					
				| 
						 | 
					@ -2153,32 +2153,48 @@ LogicalResult ONNXSplitOp::inferShapes() {
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LogicalResult ONNXFlattenOp::inferShapes() {
 | 
					LogicalResult ONNXFlattenOp::inferShapes() {
 | 
				
			||||||
  assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
 | 
					 | 
				
			||||||
  auto inTy = input().getType().dyn_cast<ShapedType>();
 | 
					  auto inTy = input().getType().dyn_cast<ShapedType>();
 | 
				
			||||||
  if (!inTy) {
 | 
					  if (!inTy) {
 | 
				
			||||||
    return emitOpError("Input is a non-shaped type");
 | 
					    return emitOpError("Input is a non-shaped type");
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  auto outTy = output().getType().dyn_cast<ShapedType>();
 | 
					 | 
				
			||||||
  if (!outTy) {
 | 
					 | 
				
			||||||
    return emitOpError("Output is a non-shaped type");
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // TODO(tjingrant): Seems like we can also fairly easily support the case
 | 
					  auto axisValue = axis();
 | 
				
			||||||
  //                  where the batch dimension is dynamic
 | 
					  auto inputShape = inTy.getShape();
 | 
				
			||||||
  if (!outTy.hasStaticShape()) {
 | 
					  auto inputRank = inputShape.size();
 | 
				
			||||||
    auto inShape = inTy.getShape();
 | 
					  if (axisValue < -1 * (int64_t)inputRank || axisValue > (int64_t)inputRank) {
 | 
				
			||||||
    assert(inShape.size() >= 1 && "ONNXFlattenOp inShape.size() should be > 0");
 | 
					    return emitOpError("ONNXFlattenOP: axis() value is out of range");
 | 
				
			||||||
    uint64_t outDim = 1;
 | 
					 | 
				
			||||||
    for (auto it = inShape.begin() + 1; it < inShape.end(); it++) {
 | 
					 | 
				
			||||||
      outDim *= *it;
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  SmallVector<int64_t, 2> dims;
 | 
					  SmallVector<int64_t, 2> dims;
 | 
				
			||||||
    // https://pytorch.org/docs/master/generated/torch.nn.Flatten.html
 | 
					
 | 
				
			||||||
    dims.emplace_back(inShape[0]);
 | 
					  // Negative axis is counting dimension from back
 | 
				
			||||||
    dims.emplace_back(outDim);
 | 
					  if (axisValue < 0)
 | 
				
			||||||
    getResult().setType(RankedTensorType::get(dims, outTy.getElementType()));
 | 
					    axisValue = inputRank + axisValue + 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Determine the size of the first dimension of output
 | 
				
			||||||
 | 
					  int64_t firstDim = 1;
 | 
				
			||||||
 | 
					  for (auto i = 0; i < axisValue; i++) {
 | 
				
			||||||
 | 
					    if (inputShape[i] == -1) {
 | 
				
			||||||
 | 
					      firstDim = -1;
 | 
				
			||||||
 | 
					      break;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    firstDim *= inputShape[i];
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  dims.emplace_back(firstDim);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Determine the size of the second dimension of output
 | 
				
			||||||
 | 
					  int64_t secondDim = 1;
 | 
				
			||||||
 | 
					  for (auto i = axisValue; i < inputRank; i++) {
 | 
				
			||||||
 | 
					    if (inputShape[i] == -1) {
 | 
				
			||||||
 | 
					      secondDim = -1;
 | 
				
			||||||
 | 
					      break;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    secondDim *= inputShape[i];
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  dims.emplace_back(secondDim);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Set the type of output
 | 
				
			||||||
 | 
					  getResult().setType(RankedTensorType::get(dims, inTy.getElementType()));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return success();
 | 
					  return success();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -639,6 +639,40 @@ func @test_flatten_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xf32> {
 | 
				
			||||||
  // CHECK: return [[RES]] : tensor<5x24xf32>
 | 
					  // CHECK: return [[RES]] : tensor<5x24xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Test when axis is 0
 | 
				
			||||||
 | 
					func @test_flatten_2(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %1 = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<2x3x4xf32>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_flatten_2
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<2x3x4xf32>) -> tensor<1x24xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<1x24xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Test when axis is negative
 | 
				
			||||||
 | 
					func @test_flatten_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %1 = "onnx.Flatten"(%arg0) {axis = -1 : si64} : (tensor<2x3x4xf32>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_flatten_3
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = -1 : si64} : (tensor<2x3x4xf32>) -> tensor<24x1xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<24x1xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Test when input is not static shape
 | 
				
			||||||
 | 
					func @test_flatten_4(%arg0 : tensor<2x4x5x?xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %1 = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<2x4x5x?xf32>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_flatten_4
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<2x4x5x?xf32>) -> tensor<8x?xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : tensor<8x?xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
/// Test the reshape op inference when concat are present.
 | 
					/// Test the reshape op inference when concat are present.
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue