From b579bd5d9eea23049afe1c03866be9f074e414a4 Mon Sep 17 00:00:00 2001 From: Richard Uhler Date: Wed, 17 Feb 2021 16:16:32 -0800 Subject: [PATCH] Support dynamic-shaped operand in verification of BroadcastInDim. Verification of HLO_BroadcastInDimOp was previously failing or crashing if the operand had a dynamic shape or was unranked. Update the verification code to allow the operand to be unranked or have dynamic shape. PiperOrigin-RevId: 358056793 --- lib/Dialect/mhlo/IR/hlo_ops.cc | 22 +++++++++++++++------- tests/ops.mlir | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index a6ca692..c6ad9e0 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -740,6 +740,12 @@ static LogicalResult Verify(BroadcastOp op) { static LogicalResult Verify(BroadcastInDimOp op) { auto operandType = op.operand().getType().dyn_cast(); + if (!operandType) { + // The following verification checks all depend on knowing the rank of + // the operand. Bail out now if we don't know the rank of the operand. + return success(); + } + auto operandRank = operandType.getRank(); if (!op.broadcast_dimensions()) { if (operandRank == 0) { @@ -783,13 +789,15 @@ static LogicalResult Verify(BroadcastInDimOp op) { dimIndex, resultRank)); } - auto dimSize = operandType.getDimSize(i); - auto resultDimSize = resultType.getDimSize(dimIndex); - if (dimSize != 1 && dimSize != resultDimSize) { - return op.emitOpError( - llvm::formatv("size of operand dimension {0} ({1}) is not equal to " - "1 or size of result dimension {2} ({3})", - i, dimSize, dimIndex, resultDimSize)); + if (!operandType.isDynamicDim(i)) { + auto dimSize = operandType.getDimSize(i); + auto resultDimSize = resultType.getDimSize(dimIndex); + if (dimSize != 1 && dimSize != resultDimSize) { + return op.emitOpError( + llvm::formatv("size of operand dimension {0} ({1}) is not equal to " + "1 or size of result dimension {2} ({3})", + i, dimSize, dimIndex, resultDimSize)); + } } } diff --git a/tests/ops.mlir b/tests/ops.mlir index 5651f04..93c4a76 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -180,6 +180,30 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x // ----- +// Regression test for b/180052624, where this was improperly marked as an +// invalid mhlo.broadcast_in_dim op. +// CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand +func @broadcast_in_dim_dynamic_shaped_operand(%arg0 : tensor) -> tensor<2xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = dense<0> : tensor<1xi64> + } : (tensor) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +// Regression test for b/180052624, where this crashed verification given the +// unranked operand. +// CHECK-LABEL: func @broadcast_in_dim_unranked_operand +func @broadcast_in_dim_unranked_operand(%arg0 : tensor<*xf32>) -> tensor<2xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = dense<0> : tensor<1xi64> + } : (tensor<*xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + func @case_mismatch_num_args(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}} %0 = "mhlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {