diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index f98c5f0..82c37da 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -42,6 +42,11 @@ static int64_t ArrayAttrIntVal(Optional a, int i) { return (a.getValue().getValue()[i]).cast().getInt(); } +// Returns the ConstantOp which defines an MLIR Value or null. +static mlir::ONNXConstantOp getONNXConstantOp(Value value) { + return dyn_cast_or_null(value.getDefiningOp()); +} + //===----------------------------------------------------------------------===// // Get reduction type //===----------------------------------------------------------------------===// @@ -861,26 +866,48 @@ void ONNXReshapeOp::inferShapes() { if (outputRank < 0) emitError("Shape tensor must have constant shape"); + // Compute total number of elements. + int64_t totalInputSize = 1; + for(auto inputDim : inputTensorTy.getShape()) + totalInputSize *= inputDim; + // Check if second argument of ReshapeOp is a constant. - // Get operation that defines the second argument. If this operation is a - // `ConstantTensor` operation, the shape of this `Reshape` operation - // resides in the `value` attribute of the `ConstantTensor` operation. - auto *secondArgDefiningOp = (*getODSOperands(1).begin()).getDefiningOp(); - auto constantOp = - dyn_cast_or_null(secondArgDefiningOp); + auto constantOp = getONNXConstantOp(shape()); SmallVector dims(outputRank, -1); if (constantOp) { + // Cast attribute to ArrayAttr. ArrayAttr valueAttribute = constantOp.valueAttr().dyn_cast(); - if (!valueAttribute) emitError("ArrayAttr expected"); - if (valueAttribute.getValue().size() != outputRank) + if (ArrayAttrSize(valueAttribute) != outputRank) emitError("Constant value must have same rank as output"); - for (int i=0; i().getInt(); + int64_t numberOfDynamicInputs = 0; + int64_t totalKnownDimsSize = 1; + int64_t dynamicValueIndex = -1; + for (int i=0; i 0 && + totalInputSize > 0) + dims[dynamicValueIndex] = totalInputSize / totalKnownDimsSize; } getResult().setType( @@ -970,6 +997,8 @@ void ONNXReduceSumOp::inferShapes() { getResult().setType(getReductionOutputType(operandTy, axes(), keepdims())); } +//===----------------------------------------------------------------------===// + // Conv // For this operation, we define the attributes once in the original Conv @@ -995,6 +1024,7 @@ void ONNXConvNoBiasOp::inferShapes() { auto xShape = xTy.getShape(); auto weightTy = W().getType().cast(); auto weightShape = weightTy.getShape(); + auto builder = mlir::Builder(this->getContext()); // Lowest supported convolution is a one dimensional convolution. if (xShape.size() < 3) @@ -1006,6 +1036,11 @@ void ONNXConvNoBiasOp::inferShapes() { // Group is a required attribute and should have default value of 1. int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); + + // Check if the attribute actually exists. If it does not then add it. + if (!groupAttr()) + groupAttr(builder.getI64IntegerAttr(group)); + // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds. if (xShape[1] != -1 && weightShape[1] != -1 && xShape[1] != (weightShape[1] * group)) diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 14e3a8c..ec2f63d 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -282,64 +282,74 @@ func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7 func @test_PadConstantValuePad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> { %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_PadConstantValuePad_1 + // CHECK: [[RES:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<18x13xf32> + // CHECK: return [[RES]] : tensor<18x13xf32> } -// CHECK-LABEL: test_PadConstantValuePad_1 -// CHECK: [[RES:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<18x13xf32> -// CHECK: return [[RES]] : tensor<18x13xf32> /// Test PadConstantPad_1 func @test_PadConstantPad_1(%arg0 : tensor<16x13xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { %0 = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () + // CHECK-LABEL: test_PadConstantPad_1 + // CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<18x17xf32> + // CHECK: return [[RES]] : tensor<18x17xf32> } -// CHECK-LABEL: test_PadConstantPad_1 -// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x13xf32>, tensor<*xf32>) -> tensor<18x17xf32> -// CHECK: return [[RES]] : tensor<18x17xf32> /// Test PadConstantPad_2 func @test_PadConstantPad_2(%arg0 : tensor<16x?xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { %0 = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x?xf32>, tensor<*xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_PadConstantPad_2 + // CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x?xf32>, tensor<*xf32>) -> tensor<18x?xf32> + // CHECK: return [[RES]] : tensor<18x?xf32> } -// CHECK-LABEL: test_PadConstantPad_2 -// CHECK: [[RES:%.+]] = "onnx.PadConstantPad"(%arg0, %arg1) {mode = "constant", pads = [0, 3, 2, 1]} : (tensor<16x?xf32>, tensor<*xf32>) -> tensor<18x?xf32> -// CHECK: return [[RES]] : tensor<18x?xf32> + +//===----------------------------------------------------------------------===// +/// Test for constant op. +//===----------------------------------------------------------------------===// /// Test ConstantOp shape inference for 1-D dense tensor. func @test_constant_dense_1d_value() -> tensor<*xf32> { %0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_constant_dense_1d_value + // CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32> + // CHECK: return [[RES]] : tensor<3xf32> } -// CHECK-LABEL: test_constant_dense_1d_value -// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32> -// CHECK: return [[RES]] : tensor<3xf32> /// Test ConstantOp shape inference for 2-D dense tensor. func @test_constant_dense_2d_value() -> tensor<*xf32> { %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_constant_dense_2d_value + // CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<{{\[}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00{{\]}}]> : tensor<3x2xf32>} : () -> tensor<3x2xf32> + // CHECK: return [[RES]] : tensor<3x2xf32> } -// CHECK-LABEL: test_constant_dense_2d_value -// CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<{{\[}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00{{\]}}]> : tensor<3x2xf32>} : () -> tensor<3x2xf32> -// CHECK: return [[RES]] : tensor<3x2xf32> /// Test ConstantOp shape inference for 1-D sparse tensor. func @test_constant_sparse_1d_value() -> tensor<*xf32> { %0 = "onnx.Constant"() {sparse_value = sparse<[[0]], [1.0]> : tensor<3xf32>} : () -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_constant_sparse_1d_value + // CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<0, 1.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + // CHECK: return [[RES]] : tensor<3xf32> } -// CHECK-LABEL: test_constant_sparse_1d_value -// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<0, 1.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> -// CHECK: return [[RES]] : tensor<3xf32> /// Test ConstantOp shape inference for 2-D sparse tensor. func @test_constant_sparse_2d_value() -> tensor<*xf32> { %0 = "onnx.Constant"() {sparse_value = sparse<[[0, 1]], [2.0]> : tensor<3x2xf32>} : () -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_constant_sparse_2d_value + // CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<{{\[}}[0, 1{{\]}}], 2.000000e+00> : tensor<3x2xf32>} : () -> tensor<3x2xf32> + // CHECK: return [[RES]] : tensor<3x2xf32> } -// CHECK-LABEL: test_constant_sparse_2d_value -// CHECK: [[RES:%.+]] = "onnx.Constant"() {sparse_value = sparse<{{\[}}[0, 1{{\]}}], 2.000000e+00> : tensor<3x2xf32>} : () -> tensor<3x2xf32> -// CHECK: return [[RES]] : tensor<3x2xf32> /// Test the default behavior of Average Pool with no padding (pad are set but shoud be ignored) func @test_default_averagepool(%arg0 : tensor<5x5x32x32xf32>) -> tensor<*xf32> { @@ -411,3 +421,45 @@ func @test_default_averagepool_strides_nonunifpad_ceil(%arg0 : tensor<5x5x30x32x // CHECK: return [[RES]] : tensor<5x5x16x16xf32> } +//===----------------------------------------------------------------------===// +/// Test the reshape op inference when constants are present. +//===----------------------------------------------------------------------===// + +func @test_reshape_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> { + %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_reshape_dynamic + // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor + // CHECK: return [[RES]] : tensor +} + +func @test_reshape_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { + %0 = "onnx.Constant"() {sparse_value = [], value = [5, 5, 16, 2] } : () -> tensor<4xi32> + %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_reshape_1 + // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<5x5x16x2xf32> + // CHECK: return [[RES]] : tensor<5x5x16x2xf32> +} + +func @test_reshape_2(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { + %0 = "onnx.Constant"() {sparse_value = [], value = [-1, 16, 2] } : () -> tensor<3xi32> + %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_reshape_2 + // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<25x16x2xf32> + // CHECK: return [[RES]] : tensor<25x16x2xf32> +} + +func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { + %0 = "onnx.Constant"() {sparse_value = [], value = [-1, 0, 2] } : () -> tensor<3xi32> + %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_reshape_3 + // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<80x5x2xf32> + // CHECK: return [[RES]] : tensor<80x5x2xf32> +}