Lower ReluGrad via chlo::BroadcastSelect.
This makes it possible to remove the constraint that the operand needs to have a static shape. PiperOrigin-RevId: 371862452
This commit is contained in:
parent
4bf1904c86
commit
384b87fad0
|
@ -81,8 +81,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
|||
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Don't apply conversion unless all operands are unranked.
|
||||
if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
|
||||
// Only apply conversion if at least one operand is unranked.
|
||||
if (!llvm::any_of(op.getOperation()->getOperands(), [&](Value operand) {
|
||||
return operand.getType().isa<UnrankedTensorType>();
|
||||
})) {
|
||||
return failure();
|
||||
|
@ -227,15 +227,21 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
|
|||
typename ChloOpTy::Adaptor transformed(operands);
|
||||
ValueRange transformed_operands = transformed.getOperands();
|
||||
auto num_operands = transformed_operands.size();
|
||||
llvm::SmallVector<UnrankedTensorType, 3> operand_types;
|
||||
operand_types.reserve(num_operands);
|
||||
llvm::SmallVector<Type, 3> operand_element_types;
|
||||
operand_element_types.reserve(num_operands);
|
||||
bool has_unranked_tensor_type = false;
|
||||
for (int i = 0; i < num_operands; ++i) {
|
||||
auto type =
|
||||
transformed_operands[i].getType().dyn_cast<UnrankedTensorType>();
|
||||
// Only support unranked operands.
|
||||
if (!type) return failure();
|
||||
operand_types.push_back(type);
|
||||
if (auto type =
|
||||
transformed_operands[i].getType().dyn_cast<TensorType>()) {
|
||||
if (type.isa<UnrankedTensorType>()) {
|
||||
has_unranked_tensor_type = true;
|
||||
}
|
||||
operand_element_types.push_back(type.getElementType());
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
if (!has_unranked_tensor_type) return failure();
|
||||
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
||||
|
||||
llvm::SmallVector<Value> shapes;
|
||||
|
@ -278,7 +284,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
|
|||
if_at_most_one_non_scalar_builder.create<mhlo::DynamicReshapeOp>(
|
||||
loc,
|
||||
RankedTensorType::get({RankedTensorType::kDynamicSize},
|
||||
operand_types[i].getElementType()),
|
||||
operand_element_types[i]),
|
||||
transformed_operands[i], size_tensor);
|
||||
reshaped_operands.push_back(reshaped);
|
||||
}
|
||||
|
|
|
@ -64,6 +64,28 @@ func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// Transformed if there is a mix of unranked/static shapes.
|
||||
// CHECK-LABEL: @select_mixed
|
||||
// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ON_TRUE:.*]]: tensor<*xf32>, %[[ON_FALSE:.*]]: tensor<2xf32>)
|
||||
func @select_mixed(%pred: tensor<*xi1>, %on_true: tensor<*xf32>, %on_false: tensor<2xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[SHAPE_PRED:.*]] = shape.shape_of %[[PRED]]
|
||||
// CHECK: %[[SHAPE_ON_TRUE:.*]] = shape.shape_of %[[ON_TRUE]]
|
||||
// CHECK: %[[SHAPE_ON_FALSE:.*]] = shape.shape_of %[[ON_FALSE]]
|
||||
// CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_PRED]], %[[SHAPE_ON_TRUE]], %[[SHAPE_ON_FALSE]]
|
||||
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
||||
// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
||||
// CHECK: %[[FLAT_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[FLAT_SHAPE]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
||||
// CHECK: %[[FLAT_ON_TRUE:.*]] = "mhlo.dynamic_reshape"(%[[ON_TRUE]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[FLAT_ON_FALSE:.*]] = "mhlo.dynamic_reshape"(%[[ON_FALSE]], %[[FLAT_SHAPE]]) : (tensor<2xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[FLAT_RESULT:.*]] = "mhlo.select"(%[[FLAT_PRED]], %[[FLAT_ON_TRUE]], %[[FLAT_ON_FALSE]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK: return %[[RESULT]] : tensor<*xf32>
|
||||
%b = "mhlo.select"(%pred, %on_true, %on_false) : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
|
||||
return %b : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @add_unranked
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32>
|
||||
func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
|
||||
|
|
Loading…
Reference in New Issue