diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index de3f950..a1f0480 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -748,10 +748,12 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) { auto dimSize = operandType.getDimSize(i); auto resultDimSize = resultType.getDimSize(dimIndex); - if (dimSize != 1 && dimSize != resultDimSize) { + // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we + // add a manual check for this. + if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) { return op.emitOpError( - llvm::formatv("size of operand dimension {0} ({1}) is not equal to " - "1 or size of result dimension {2} ({3})", + llvm::formatv("size of operand dimension {0} ({1}) is not compatible " + "with size of result dimension {2} ({3})", i, dimSize, dimIndex, resultDimSize)); } } diff --git a/tests/ops.mlir b/tests/ops.mlir index 212e794..3443f21 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -116,6 +116,30 @@ func @dynamic_broadcast_in_dim(%arg0: tensor, %shape: tensor<3xi64>) -> // ----- +// CHECK-LABEL: func @dynamic_broadcast_in_dim_unknown_dim +func @dynamic_broadcast_in_dim_unknown_dim(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<32xf32>, tensor<3xi64>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @dynamic_broadcast_in_dim_ok_dim +func @dynamic_broadcast_in_dim_ok_dim(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> + return %0 : tensor<7x8x9xf32> +} + +// ----- + +func @dynamic_broadcast_in_dim_shape_mismatch(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> { + // expected-error@+1 {{size of operand dimension 0 (32) is not compatible with size of result dimension 2 (9)}} + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<32xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> + return %0 : tensor<7x8x9xf32> +} + +// ----- + func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}} %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>