Lower ReluGrad via chlo::BroadcastSelect.

This allows us to drop the constraint that the op needs to have a static shape.

PiperOrigin-RevId: 371862452
Adrian Kuegel, 2021-05-04 01:02:00 -07:00, committed by TensorFlow MLIR Team
parent 4bf1904c86
commit 384b87fad0
2 changed files with 38 additions and 10 deletions
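For context, a minimal sketch of the lowering this commit enables. The actual TF-to-CHLO pattern is not part of this diff, so the zero constant, the "GT" comparison, and the exact op shapes below are assumptions; only the use of chlo::BroadcastSelect is confirmed by the commit message. ReluGrad(gradients, features) keeps the incoming gradient wherever the feature is positive, which the CHLO broadcast ops express without any static-shape requirement:

    // Hypothetical lowering of tf.ReluGrad(%gradients, %features); all shape
    // handling is deferred to the chlo broadcast ops, so unranked operands work.
    %zero = mhlo.constant dense<0.0> : tensor<f32>
    %pred = "chlo.broadcast_compare"(%features, %zero) {comparison_direction = "GT"}
        : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
    %result = "chlo.broadcast_select"(%pred, %gradients, %zero)
        : (tensor<*xi1>, tensor<*xf32>, tensor<f32>) -> tensor<*xf32>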


@@ -81,8 +81,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    // Don't apply conversion unless all operands are unranked.
-    if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
+    // Only apply conversion if at least one operand is unranked.
+    if (!llvm::any_of(op.getOperation()->getOperands(), [&](Value operand) {
           return operand.getType().isa<UnrankedTensorType>();
         })) {
       return failure();
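With llvm::any_of, the elementwise conversion now also fires when only some operands are unranked. A hand-written example in the spirit of the new test in the second file (the @select_mixed test below is the authoritative check):

    // Previously rejected because %on_false has a static shape; now converted:
    %0 = "mhlo.select"(%pred, %on_true, %on_false)
        : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>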
@@ -227,15 +227,21 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
     typename ChloOpTy::Adaptor transformed(operands);
     ValueRange transformed_operands = transformed.getOperands();
     auto num_operands = transformed_operands.size();
-    llvm::SmallVector<UnrankedTensorType, 3> operand_types;
-    operand_types.reserve(num_operands);
+    llvm::SmallVector<Type, 3> operand_element_types;
+    operand_element_types.reserve(num_operands);
+    bool has_unranked_tensor_type = false;
     for (int i = 0; i < num_operands; ++i) {
-      auto type =
-          transformed_operands[i].getType().dyn_cast<UnrankedTensorType>();
-      // Only support unranked operands.
-      if (!type) return failure();
-      operand_types.push_back(type);
+      if (auto type =
+              transformed_operands[i].getType().dyn_cast<TensorType>()) {
+        if (type.isa<UnrankedTensorType>()) {
+          has_unranked_tensor_type = true;
+        }
+        operand_element_types.push_back(type.getElementType());
+      } else {
+        return failure();
+      }
     }
+    if (!has_unranked_tensor_type) return failure();
     auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
     llvm::SmallVector<Value> shapes;
@@ -278,7 +284,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
           if_at_most_one_non_scalar_builder.create<mhlo::DynamicReshapeOp>(
               loc,
               RankedTensorType::get({RankedTensorType::kDynamicSize},
-                                    operand_types[i].getElementType()),
+                                    operand_element_types[i]),
               transformed_operands[i], size_tensor);
       reshaped_operands.push_back(reshaped);
     }
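Taken together, the n-ary broadcast pattern now records only the element type per operand and requires just one operand to be unranked, so a mixed input like the following hypothetical one (assuming chlo.broadcast_select is among the ops registered with this pattern) is accepted and each operand is flattened to rank 1 with its own element type:

    %0 = "chlo.broadcast_select"(%pred, %on_true, %on_false)
        : (tensor<*xi1>, tensor<*xf32>, tensor<f32>) -> tensor<*xf32>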


@@ -64,6 +64,28 @@ func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
 
 // -----
 
+// Transformed if there is a mix of unranked/static shapes.
+// CHECK-LABEL: @select_mixed
+// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ON_TRUE:.*]]: tensor<*xf32>, %[[ON_FALSE:.*]]: tensor<2xf32>)
+func @select_mixed(%pred: tensor<*xi1>, %on_true: tensor<*xf32>, %on_false: tensor<2xf32>) -> tensor<*xf32> {
+  // CHECK: %[[SHAPE_PRED:.*]] = shape.shape_of %[[PRED]]
+  // CHECK: %[[SHAPE_ON_TRUE:.*]] = shape.shape_of %[[ON_TRUE]]
+  // CHECK: %[[SHAPE_ON_FALSE:.*]] = shape.shape_of %[[ON_FALSE]]
+  // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_PRED]], %[[SHAPE_ON_TRUE]], %[[SHAPE_ON_FALSE]]
+  // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
+  // CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
+  // CHECK: %[[FLAT_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[FLAT_SHAPE]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
+  // CHECK: %[[FLAT_ON_TRUE:.*]] = "mhlo.dynamic_reshape"(%[[ON_TRUE]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+  // CHECK: %[[FLAT_ON_FALSE:.*]] = "mhlo.dynamic_reshape"(%[[ON_FALSE]], %[[FLAT_SHAPE]]) : (tensor<2xf32>, tensor<1xindex>) -> tensor<?xf32>
+  // CHECK: %[[FLAT_RESULT:.*]] = "mhlo.select"(%[[FLAT_PRED]], %[[FLAT_ON_TRUE]], %[[FLAT_ON_FALSE]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // CHECK: return %[[RESULT]] : tensor<*xf32>
+  %b = "mhlo.select"(%pred, %on_true, %on_false) : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
+  return %b : tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @add_unranked
 // CHECK-SAME: (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32>
 func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {