From 384b87fad00a64430dbe3cf01eb9be07deea849b Mon Sep 17 00:00:00 2001
From: Adrian Kuegel
Date: Tue, 4 May 2021 01:02:00 -0700
Subject: [PATCH] Lower ReluGrad via chlo::BroadcastSelect.

This allows us to get rid of the constraint that it needs to have a static
shape.

PiperOrigin-RevId: 371862452
---
 .../mhlo/transforms/transform_unranked_hlo.cc | 26 ++++++++++++++++----------
 tests/hlo-transform-unranked.mlir             | 22 ++++++++++++++++++++++
 2 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index 097e029..08605e1 100644
--- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -81,8 +81,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    // Don't apply conversion unless all operands are unranked.
-    if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
+    // Only apply conversion if at least one operand is unranked.
+    if (!llvm::any_of(op.getOperation()->getOperands(), [&](Value operand) {
           return operand.getType().isa<UnrankedTensorType>();
         })) {
       return failure();
     }
@@ -227,15 +227,21 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
     typename ChloOpTy::Adaptor transformed(operands);
     ValueRange transformed_operands = transformed.getOperands();
     auto num_operands = transformed_operands.size();
-    llvm::SmallVector<TensorType> operand_types;
-    operand_types.reserve(num_operands);
+    llvm::SmallVector<Type> operand_element_types;
+    operand_element_types.reserve(num_operands);
+    bool has_unranked_tensor_type = false;
     for (int i = 0; i < num_operands; ++i) {
-      auto type =
-          transformed_operands[i].getType().dyn_cast<TensorType>();
-      // Only support unranked operands.
-      if (!type) return failure();
-      operand_types.push_back(type);
+      if (auto type =
+              transformed_operands[i].getType().dyn_cast<TensorType>()) {
+        if (type.isa<UnrankedTensorType>()) {
+          has_unranked_tensor_type = true;
+        }
+        operand_element_types.push_back(type.getElementType());
+      } else {
+        return failure();
+      }
     }
+    if (!has_unranked_tensor_type) return failure();
     auto result_type =
         op.getResult().getType().template dyn_cast<TensorType>();
     llvm::SmallVector<Value> shapes;
@@ -278,7 +284,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
           if_at_most_one_non_scalar_builder.create<mhlo::DynamicReshapeOp>(
               loc,
               RankedTensorType::get({RankedTensorType::kDynamicSize},
-                                    operand_types[i].getElementType()),
+                                    operand_element_types[i]),
               transformed_operands[i], size_tensor);
       reshaped_operands.push_back(reshaped);
     }
diff --git a/tests/hlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir
index 148ab28..52d6695 100644
--- a/tests/hlo-transform-unranked.mlir
+++ b/tests/hlo-transform-unranked.mlir
@@ -64,6 +64,28 @@ func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
 
 // -----
 
+// Transformed if there is a mix of unranked/static shapes.
+// CHECK-LABEL: @select_mixed
+// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ON_TRUE:.*]]: tensor<*xf32>, %[[ON_FALSE:.*]]: tensor<2xf32>)
+func @select_mixed(%pred: tensor<*xi1>, %on_true: tensor<*xf32>, %on_false: tensor<2xf32>) -> tensor<*xf32> {
+  // CHECK: %[[SHAPE_PRED:.*]] = shape.shape_of %[[PRED]]
+  // CHECK: %[[SHAPE_ON_TRUE:.*]] = shape.shape_of %[[ON_TRUE]]
+  // CHECK: %[[SHAPE_ON_FALSE:.*]] = shape.shape_of %[[ON_FALSE]]
+  // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_PRED]], %[[SHAPE_ON_TRUE]], %[[SHAPE_ON_FALSE]]
+  // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
+  // CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+  // CHECK: %[[FLAT_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[FLAT_SHAPE]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
+  // CHECK: %[[FLAT_ON_TRUE:.*]] = "mhlo.dynamic_reshape"(%[[ON_TRUE]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+  // CHECK: %[[FLAT_ON_FALSE:.*]] = "mhlo.dynamic_reshape"(%[[ON_FALSE]], %[[FLAT_SHAPE]]) : (tensor<2xf32>, tensor<1xindex>) -> tensor<?xf32>
+  // CHECK: %[[FLAT_RESULT:.*]] = "mhlo.select"(%[[FLAT_PRED]], %[[FLAT_ON_TRUE]], %[[FLAT_ON_FALSE]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // CHECK: return %[[RESULT]] : tensor<*xf32>
+  %b = "mhlo.select"(%pred, %on_true, %on_false) : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
+  return %b : tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @add_unranked
 // CHECK-SAME: (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32>
 func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
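
For context on the subject line: ReluGrad(gradients, features) computes
select(features > 0, gradients, 0), and the scalar zero is what previously
forced a static shape when lowered through plain mhlo.select. The sketch
below illustrates the kind of IR a chlo::BroadcastSelect-based lowering
produces; it is not code from this patch. The function name @relu_grad is
made up, and the string form of comparison_direction reflects the chlo
syntax of this era; treat both as assumptions.

  // Hypothetical ReluGrad lowering; the broadcast ops accept a scalar
  // zero next to unranked operands, so no static shape is required.
  func @relu_grad(%gradients: tensor<*xf32>, %features: tensor<*xf32>) -> tensor<*xf32> {
    %zero = mhlo.constant dense<0.000000e+00> : tensor<f32>
    // pred = features > 0, broadcasting the scalar zero.
    %pred = "chlo.broadcast_compare"(%features, %zero) {comparison_direction = "GT"}
        : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
    // result = pred ? gradients : 0, again with broadcast semantics.
    %result = "chlo.broadcast_select"(%pred, %gradients, %zero)
        : (tensor<*xi1>, tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
    return %result : tensor<*xf32>
  }

A chlo.broadcast_select with a mix of unranked and ranked operands is then
exactly the shape of op exercised by the @select_mixed test above, which the
relaxed llvm::any_of guard and the new has_unranked_tensor_type check in
this patch make legal to transform.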