Fix reshape output shape inference when a single dynamic shape is given (#22)

* Fix reshape when a dynamic shape is given.

* Fix default attributes for ConvNoBias.

* Fix comment.

* Resolve comment.

* Improve checks.

* Handle zero dim case.

* Add helper to fetch constants. Add test for dynamic reshape.

* Add test for zero.

* Use shortcut method for size.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-03-13 17:18:46 -04:00 committed by GitHub
parent 6137fc7c17
commit c46880d5c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 118 additions and 31 deletions

View File

@ -42,6 +42,11 @@ static int64_t ArrayAttrIntVal(Optional<ArrayAttr> a, int i) {
return (a.getValue().getValue()[i]).cast<IntegerAttr>().getInt();
}
// Returns the ConstantOp which defines an MLIR Value or null.
static mlir::ONNXConstantOp getONNXConstantOp(Value value) {
return dyn_cast_or_null<mlir::ONNXConstantOp>(value.getDefiningOp());
}
//===----------------------------------------------------------------------===//
// Get reduction type
//===----------------------------------------------------------------------===//
@ -861,26 +866,48 @@ void ONNXReshapeOp::inferShapes() {
if (outputRank < 0)
emitError("Shape tensor must have constant shape");
// Compute total number of elements.
int64_t totalInputSize = 1;
for(auto inputDim : inputTensorTy.getShape())
totalInputSize *= inputDim;
// Check if second argument of ReshapeOp is a constant.
// Get operation that defines the second argument. If this operation is a
// `ConstantTensor` operation, the shape of this `Reshape` operation
// resides in the `value` attribute of the `ConstantTensor` operation.
auto *secondArgDefiningOp = (*getODSOperands(1).begin()).getDefiningOp();
auto constantOp =
dyn_cast_or_null<mlir::ONNXConstantOp>(secondArgDefiningOp);
auto constantOp = getONNXConstantOp(shape());
SmallVector<int64_t, 2> dims(outputRank, -1);
if (constantOp) {
// Cast attribute to ArrayAttr.
ArrayAttr valueAttribute = constantOp.valueAttr().dyn_cast<ArrayAttr>();
if (!valueAttribute)
emitError("ArrayAttr expected");
if (valueAttribute.getValue().size() != outputRank)
if (ArrayAttrSize(valueAttribute) != outputRank)
emitError("Constant value must have same rank as output");
for (int i=0; i<outputRank; ++i)
dims[i] = valueAttribute.getValue()[i].cast<IntegerAttr>().getInt();
int64_t numberOfDynamicInputs = 0;
int64_t totalKnownDimsSize = 1;
int64_t dynamicValueIndex = -1;
for (int i=0; i<outputRank; ++i) {
// Set output dimension.
dims[i] = ArrayAttrIntVal(valueAttribute, i);
if (dims[i] == 0)
dims[i] = inputTensorTy.getShape()[i];
if (dims[i] < 0) {
numberOfDynamicInputs++;
dynamicValueIndex = i;
} else {
totalKnownDimsSize *= dims[i];
}
}
// If the number of dynamic inputs is 1 then deduce the missing value
// based on the total input size. The total input size must be greater
// than 0 i.e. all constant dimensions.
// TODO: Support dynamic input dimensons.
if (numberOfDynamicInputs == 1 && totalKnownDimsSize > 0 &&
totalInputSize > 0)
dims[dynamicValueIndex] = totalInputSize / totalKnownDimsSize;
}
getResult().setType(
@ -970,6 +997,8 @@ void ONNXReduceSumOp::inferShapes() {
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
}
//===----------------------------------------------------------------------===//
// Conv
// For this operation, we define the attributes once in the original Conv
@ -995,6 +1024,7 @@ void ONNXConvNoBiasOp::inferShapes() {
auto xShape = xTy.getShape();
auto weightTy = W().getType().cast<RankedTensorType>();
auto weightShape = weightTy.getShape();
auto builder = mlir::Builder(this->getContext());
// Lowest supported convolution is a one dimensional convolution.
if (xShape.size() < 3)
@ -1006,6 +1036,11 @@ void ONNXConvNoBiasOp::inferShapes() {
// Group is a required attribute and should have default value of 1.
int64_t group = ONNXConvNoBiasOp::group().getSExtValue();
// Check if the attribute actually exists. If it does not then add it.
if (!groupAttr())
groupAttr(builder.getI64IntegerAttr(group));
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (xShape[1] != -1 && weightShape[1] != -1 &&
xShape[1] != (weightShape[1] * group))

View File

@ -282,64 +282,74 @@ func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7
func @test_PadConstantValuePad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {
%0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_PadConstantValuePad_1
// CHECK: [[RES:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<18x13xf32>
// CHECK: return [[RES]] : tensor<18x13xf32>
}
// CHECK-LABEL: test_PadConstantValuePad_1
// CHECK: [[RES:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<18x13xf32>
// CHECK: return [[RES]] : tensor<18x13xf32>
/// Test PadConstantPad_1
func @test_PadConstantPad_1(%arg0 : tensor<16x13xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
%0 = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_PadConstantPad_1
// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<18x17xf32>
// CHECK: return [[RES]] : tensor<18x17xf32>
}
// CHECK-LABEL: test_PadConstantPad_1
// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<18x17xf32>
// CHECK: return [[RES]] : tensor<18x17xf32>
/// Test PadConstantPad_2
func @test_PadConstantPad_2(%arg0 : tensor<16x?xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
%0 = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x?xf32>, tensor<*xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_PadConstantPad_2
// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x?xf32>, tensor<*xf32>) -> tensor<18x?xf32>
// CHECK: return [[RES]] : tensor<18x?xf32>
}
// CHECK-LABEL: test_PadConstantPad_2
// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x?xf32>, tensor<*xf32>) -> tensor<18x?xf32>
// CHECK: return [[RES]] : tensor<18x?xf32>
//===----------------------------------------------------------------------===//
/// Test for constant op.
//===----------------------------------------------------------------------===//
/// Test ConstantOp shape inference for 1-D dense tensor.
func @test_constant_dense_1d_value() -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_constant_dense_1d_value
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK: return [[RES]] : tensor<3xf32>
}
// CHECK-LABEL: test_constant_dense_1d_value
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK: return [[RES]] : tensor<3xf32>
/// Test ConstantOp shape inference for 2-D dense tensor.
func @test_constant_dense_2d_value() -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_constant_dense_2d_value
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<{{\[}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00{{\]}}]> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
// CHECK: return [[RES]] : tensor<3x2xf32>
}
// CHECK-LABEL: test_constant_dense_2d_value
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<{{\[}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00{{\]}}]> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
// CHECK: return [[RES]] : tensor<3x2xf32>
/// Test ConstantOp shape inference for 1-D sparse tensor.
func @test_constant_sparse_1d_value() -> tensor<*xf32> {
%0 = "onnx.Constant"() {sparse_value = sparse<[[0]], [1.0]> : tensor<3xf32>} : () -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_constant_sparse_1d_value
// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<0, 1.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK: return [[RES]] : tensor<3xf32>
}
// CHECK-LABEL: test_constant_sparse_1d_value
// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<0, 1.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK: return [[RES]] : tensor<3xf32>
/// Test ConstantOp shape inference for 2-D sparse tensor.
func @test_constant_sparse_2d_value() -> tensor<*xf32> {
%0 = "onnx.Constant"() {sparse_value = sparse<[[0, 1]], [2.0]> : tensor<3x2xf32>} : () -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_constant_sparse_2d_value
// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<{{\[}}[0, 1{{\]}}], 2.000000e+00> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
// CHECK: return [[RES]] : tensor<3x2xf32>
}
// CHECK-LABEL: test_constant_sparse_2d_value
// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<{{\[}}[0, 1{{\]}}], 2.000000e+00> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
// CHECK: return [[RES]] : tensor<3x2xf32>
/// Test the default behavior of Average Pool with no padding (pad are set but shoud be ignored)
func @test_default_averagepool(%arg0 : tensor<5x5x32x32xf32>) -> tensor<*xf32> {
@ -411,3 +421,45 @@ func @test_default_averagepool_strides_nonunifpad_ceil(%arg0 : tensor<5x5x30x32x
// CHECK: return [[RES]] : tensor<5x5x16x16xf32>
}
//===----------------------------------------------------------------------===//
/// Test the reshape op inference when constants are present.
//===----------------------------------------------------------------------===//
func @test_reshape_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reshape_dynamic
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
// CHECK: return [[RES]] : tensor<?x?x?x?xf32>
}
func @test_reshape_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
%0 = "onnx.Constant"() {sparse_value = [], value = [5, 5, 16, 2] } : () -> tensor<4xi32>
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reshape_1
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<5x5x16x2xf32>
// CHECK: return [[RES]] : tensor<5x5x16x2xf32>
}
func @test_reshape_2(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
%0 = "onnx.Constant"() {sparse_value = [], value = [-1, 16, 2] } : () -> tensor<3xi32>
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reshape_2
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<25x16x2xf32>
// CHECK: return [[RES]] : tensor<25x16x2xf32>
}
func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
%0 = "onnx.Constant"() {sparse_value = [], value = [-1, 0, 2] } : () -> tensor<3xi32>
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reshape_3
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<80x5x2xf32>
// CHECK: return [[RES]] : tensor<80x5x2xf32>
}