Fix reshape output shape inference when a single dynamic shape is given (#22)
* Fix reshape when a dynamic shape is given. * Fix default attributes for ConvNoBias. * Fix comment. * Resolve comment. * Improve checks. * Handle zero dim case. * Add helper to fetch constants. Add test for dynamic reshape. * Add test for zero. * Use shortcut method for size.
This commit is contained in:
parent
6137fc7c17
commit
c46880d5c6
|
@ -42,6 +42,11 @@ static int64_t ArrayAttrIntVal(Optional<ArrayAttr> a, int i) {
|
||||||
return (a.getValue().getValue()[i]).cast<IntegerAttr>().getInt();
|
return (a.getValue().getValue()[i]).cast<IntegerAttr>().getInt();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the ConstantOp which defines an MLIR Value or null.
|
||||||
|
static mlir::ONNXConstantOp getONNXConstantOp(Value value) {
|
||||||
|
return dyn_cast_or_null<mlir::ONNXConstantOp>(value.getDefiningOp());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Get reduction type
|
// Get reduction type
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -861,26 +866,48 @@ void ONNXReshapeOp::inferShapes() {
|
||||||
if (outputRank < 0)
|
if (outputRank < 0)
|
||||||
emitError("Shape tensor must have constant shape");
|
emitError("Shape tensor must have constant shape");
|
||||||
|
|
||||||
|
// Compute total number of elements.
|
||||||
|
int64_t totalInputSize = 1;
|
||||||
|
for(auto inputDim : inputTensorTy.getShape())
|
||||||
|
totalInputSize *= inputDim;
|
||||||
|
|
||||||
// Check if second argument of ReshapeOp is a constant.
|
// Check if second argument of ReshapeOp is a constant.
|
||||||
// Get operation that defines the second argument. If this operation is a
|
auto constantOp = getONNXConstantOp(shape());
|
||||||
// `ConstantTensor` operation, the shape of this `Reshape` operation
|
|
||||||
// resides in the `value` attribute of the `ConstantTensor` operation.
|
|
||||||
auto *secondArgDefiningOp = (*getODSOperands(1).begin()).getDefiningOp();
|
|
||||||
auto constantOp =
|
|
||||||
dyn_cast_or_null<mlir::ONNXConstantOp>(secondArgDefiningOp);
|
|
||||||
|
|
||||||
SmallVector<int64_t, 2> dims(outputRank, -1);
|
SmallVector<int64_t, 2> dims(outputRank, -1);
|
||||||
if (constantOp) {
|
if (constantOp) {
|
||||||
|
// Cast attribute to ArrayAttr.
|
||||||
ArrayAttr valueAttribute = constantOp.valueAttr().dyn_cast<ArrayAttr>();
|
ArrayAttr valueAttribute = constantOp.valueAttr().dyn_cast<ArrayAttr>();
|
||||||
|
|
||||||
if (!valueAttribute)
|
if (!valueAttribute)
|
||||||
emitError("ArrayAttr expected");
|
emitError("ArrayAttr expected");
|
||||||
|
|
||||||
if (valueAttribute.getValue().size() != outputRank)
|
if (ArrayAttrSize(valueAttribute) != outputRank)
|
||||||
emitError("Constant value must have same rank as output");
|
emitError("Constant value must have same rank as output");
|
||||||
|
|
||||||
for (int i=0; i<outputRank; ++i)
|
int64_t numberOfDynamicInputs = 0;
|
||||||
dims[i] = valueAttribute.getValue()[i].cast<IntegerAttr>().getInt();
|
int64_t totalKnownDimsSize = 1;
|
||||||
|
int64_t dynamicValueIndex = -1;
|
||||||
|
for (int i=0; i<outputRank; ++i) {
|
||||||
|
// Set output dimension.
|
||||||
|
dims[i] = ArrayAttrIntVal(valueAttribute, i);
|
||||||
|
if (dims[i] == 0)
|
||||||
|
dims[i] = inputTensorTy.getShape()[i];
|
||||||
|
|
||||||
|
if (dims[i] < 0) {
|
||||||
|
numberOfDynamicInputs++;
|
||||||
|
dynamicValueIndex = i;
|
||||||
|
} else {
|
||||||
|
totalKnownDimsSize *= dims[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the number of dynamic inputs is 1 then deduce the missing value
|
||||||
|
// based on the total input size. The total input size must be greater
|
||||||
|
// than 0 i.e. all constant dimensions.
|
||||||
|
// TODO: Support dynamic input dimensons.
|
||||||
|
if (numberOfDynamicInputs == 1 && totalKnownDimsSize > 0 &&
|
||||||
|
totalInputSize > 0)
|
||||||
|
dims[dynamicValueIndex] = totalInputSize / totalKnownDimsSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
getResult().setType(
|
getResult().setType(
|
||||||
|
@ -970,6 +997,8 @@ void ONNXReduceSumOp::inferShapes() {
|
||||||
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Conv
|
// Conv
|
||||||
|
|
||||||
// For this operation, we define the attributes once in the original Conv
|
// For this operation, we define the attributes once in the original Conv
|
||||||
|
@ -995,6 +1024,7 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
auto xShape = xTy.getShape();
|
auto xShape = xTy.getShape();
|
||||||
auto weightTy = W().getType().cast<RankedTensorType>();
|
auto weightTy = W().getType().cast<RankedTensorType>();
|
||||||
auto weightShape = weightTy.getShape();
|
auto weightShape = weightTy.getShape();
|
||||||
|
auto builder = mlir::Builder(this->getContext());
|
||||||
|
|
||||||
// Lowest supported convolution is a one dimensional convolution.
|
// Lowest supported convolution is a one dimensional convolution.
|
||||||
if (xShape.size() < 3)
|
if (xShape.size() < 3)
|
||||||
|
@ -1006,6 +1036,11 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
|
|
||||||
// Group is a required attribute and should have default value of 1.
|
// Group is a required attribute and should have default value of 1.
|
||||||
int64_t group = ONNXConvNoBiasOp::group().getSExtValue();
|
int64_t group = ONNXConvNoBiasOp::group().getSExtValue();
|
||||||
|
|
||||||
|
// Check if the attribute actually exists. If it does not then add it.
|
||||||
|
if (!groupAttr())
|
||||||
|
groupAttr(builder.getI64IntegerAttr(group));
|
||||||
|
|
||||||
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
||||||
if (xShape[1] != -1 && weightShape[1] != -1 &&
|
if (xShape[1] != -1 && weightShape[1] != -1 &&
|
||||||
xShape[1] != (weightShape[1] * group))
|
xShape[1] != (weightShape[1] * group))
|
||||||
|
|
|
@ -282,64 +282,74 @@ func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7
|
||||||
func @test_PadConstantValuePad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {
|
func @test_PadConstantValuePad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<*xf32>
|
%0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
|
||||||
// CHECK-LABEL: test_PadConstantValuePad_1
|
// CHECK-LABEL: test_PadConstantValuePad_1
|
||||||
// CHECK: [[RES:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<18x13xf32>
|
// CHECK: [[RES:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<18x13xf32>
|
||||||
// CHECK: return [[RES]] : tensor<18x13xf32>
|
// CHECK: return [[RES]] : tensor<18x13xf32>
|
||||||
|
}
|
||||||
|
|
||||||
/// Test PadConstantPad_1
|
/// Test PadConstantPad_1
|
||||||
func @test_PadConstantPad_1(%arg0 : tensor<16x13xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
|
func @test_PadConstantPad_1(%arg0 : tensor<16x13xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<*xf32>
|
%0 = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
|
||||||
// CHECK-LABEL: test_PadConstantPad_1
|
// CHECK-LABEL: test_PadConstantPad_1
|
||||||
// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<18x17xf32>
|
// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<18x17xf32>
|
||||||
// CHECK: return [[RES]] : tensor<18x17xf32>
|
// CHECK: return [[RES]] : tensor<18x17xf32>
|
||||||
|
}
|
||||||
|
|
||||||
/// Test PadConstantPad_2
|
/// Test PadConstantPad_2
|
||||||
func @test_PadConstantPad_2(%arg0 : tensor<16x?xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
|
func @test_PadConstantPad_2(%arg0 : tensor<16x?xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x?xf32>, tensor<*xf32>) -> tensor<*xf32>
|
%0 = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x?xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
|
||||||
// CHECK-LABEL: test_PadConstantPad_2
|
// CHECK-LABEL: test_PadConstantPad_2
|
||||||
// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x?xf32>, tensor<*xf32>) -> tensor<18x?xf32>
|
// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x?xf32>, tensor<*xf32>) -> tensor<18x?xf32>
|
||||||
// CHECK: return [[RES]] : tensor<18x?xf32>
|
// CHECK: return [[RES]] : tensor<18x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test for constant op.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Test ConstantOp shape inference for 1-D dense tensor.
|
/// Test ConstantOp shape inference for 1-D dense tensor.
|
||||||
func @test_constant_dense_1d_value() -> tensor<*xf32> {
|
func @test_constant_dense_1d_value() -> tensor<*xf32> {
|
||||||
%0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<*xf32>
|
%0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
|
||||||
// CHECK-LABEL: test_constant_dense_1d_value
|
// CHECK-LABEL: test_constant_dense_1d_value
|
||||||
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
|
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
// CHECK: return [[RES]] : tensor<3xf32>
|
// CHECK: return [[RES]] : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
/// Test ConstantOp shape inference for 2-D dense tensor.
|
/// Test ConstantOp shape inference for 2-D dense tensor.
|
||||||
func @test_constant_dense_2d_value() -> tensor<*xf32> {
|
func @test_constant_dense_2d_value() -> tensor<*xf32> {
|
||||||
%0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
|
%0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
|
||||||
// CHECK-LABEL: test_constant_dense_2d_value
|
// CHECK-LABEL: test_constant_dense_2d_value
|
||||||
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<{{\[}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00{{\]}}]> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
|
// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<{{\[}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00{{\]}}]> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
|
||||||
// CHECK: return [[RES]] : tensor<3x2xf32>
|
// CHECK: return [[RES]] : tensor<3x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
/// Test ConstantOp shape inference for 1-D sparse tensor.
|
/// Test ConstantOp shape inference for 1-D sparse tensor.
|
||||||
func @test_constant_sparse_1d_value() -> tensor<*xf32> {
|
func @test_constant_sparse_1d_value() -> tensor<*xf32> {
|
||||||
%0 = "onnx.Constant"() {sparse_value = sparse<[[0]], [1.0]> : tensor<3xf32>} : () -> tensor<*xf32>
|
%0 = "onnx.Constant"() {sparse_value = sparse<[[0]], [1.0]> : tensor<3xf32>} : () -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
|
||||||
// CHECK-LABEL: test_constant_sparse_1d_value
|
// CHECK-LABEL: test_constant_sparse_1d_value
|
||||||
// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<0, 1.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32>
|
// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<0, 1.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
// CHECK: return [[RES]] : tensor<3xf32>
|
// CHECK: return [[RES]] : tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
/// Test ConstantOp shape inference for 2-D sparse tensor.
|
/// Test ConstantOp shape inference for 2-D sparse tensor.
|
||||||
func @test_constant_sparse_2d_value() -> tensor<*xf32> {
|
func @test_constant_sparse_2d_value() -> tensor<*xf32> {
|
||||||
%0 = "onnx.Constant"() {sparse_value = sparse<[[0, 1]], [2.0]> : tensor<3x2xf32>} : () -> tensor<*xf32>
|
%0 = "onnx.Constant"() {sparse_value = sparse<[[0, 1]], [2.0]> : tensor<3x2xf32>} : () -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
|
||||||
// CHECK-LABEL: test_constant_sparse_2d_value
|
// CHECK-LABEL: test_constant_sparse_2d_value
|
||||||
// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<{{\[}}[0, 1{{\]}}], 2.000000e+00> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
|
// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<{{\[}}[0, 1{{\]}}], 2.000000e+00> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
|
||||||
// CHECK: return [[RES]] : tensor<3x2xf32>
|
// CHECK: return [[RES]] : tensor<3x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
/// Test the default behavior of Average Pool with no padding (pad are set but shoud be ignored)
|
/// Test the default behavior of Average Pool with no padding (pad are set but shoud be ignored)
|
||||||
func @test_default_averagepool(%arg0 : tensor<5x5x32x32xf32>) -> tensor<*xf32> {
|
func @test_default_averagepool(%arg0 : tensor<5x5x32x32xf32>) -> tensor<*xf32> {
|
||||||
|
@ -411,3 +421,45 @@ func @test_default_averagepool_strides_nonunifpad_ceil(%arg0 : tensor<5x5x30x32x
|
||||||
// CHECK: return [[RES]] : tensor<5x5x16x16xf32>
|
// CHECK: return [[RES]] : tensor<5x5x16x16xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Test the reshape op inference when constants are present.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
func @test_reshape_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_reshape_dynamic
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<?x?x?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_reshape_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Constant"() {sparse_value = [], value = [5, 5, 16, 2] } : () -> tensor<4xi32>
|
||||||
|
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_reshape_1
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<5x5x16x2xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<5x5x16x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_reshape_2(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Constant"() {sparse_value = [], value = [-1, 16, 2] } : () -> tensor<3xi32>
|
||||||
|
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_reshape_2
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<25x16x2xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<25x16x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Constant"() {sparse_value = [], value = [-1, 0, 2] } : () -> tensor<3xi32>
|
||||||
|
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_reshape_3
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<80x5x2xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<80x5x2xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue