Lower ReluGrad via chlo::BroadcastSelect.
This removes the constraint that the operand needs to have a static shape. PiperOrigin-RevId: 371862452
This commit is contained in:
parent
4bf1904c86
commit
384b87fad0
|
@ -81,8 +81,8 @@ struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(OpTy op,
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
// Don't apply conversion unless all operands are unranked.
|
// Only apply conversion if at least one operand is unranked.
|
||||||
if (!llvm::all_of(op.getOperation()->getOperands(), [&](Value operand) {
|
if (!llvm::any_of(op.getOperation()->getOperands(), [&](Value operand) {
|
||||||
return operand.getType().isa<UnrankedTensorType>();
|
return operand.getType().isa<UnrankedTensorType>();
|
||||||
})) {
|
})) {
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -227,15 +227,21 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
|
||||||
typename ChloOpTy::Adaptor transformed(operands);
|
typename ChloOpTy::Adaptor transformed(operands);
|
||||||
ValueRange transformed_operands = transformed.getOperands();
|
ValueRange transformed_operands = transformed.getOperands();
|
||||||
auto num_operands = transformed_operands.size();
|
auto num_operands = transformed_operands.size();
|
||||||
llvm::SmallVector<UnrankedTensorType, 3> operand_types;
|
llvm::SmallVector<Type, 3> operand_element_types;
|
||||||
operand_types.reserve(num_operands);
|
operand_element_types.reserve(num_operands);
|
||||||
|
bool has_unranked_tensor_type = false;
|
||||||
for (int i = 0; i < num_operands; ++i) {
|
for (int i = 0; i < num_operands; ++i) {
|
||||||
auto type =
|
if (auto type =
|
||||||
transformed_operands[i].getType().dyn_cast<UnrankedTensorType>();
|
transformed_operands[i].getType().dyn_cast<TensorType>()) {
|
||||||
// Only support unranked operands.
|
if (type.isa<UnrankedTensorType>()) {
|
||||||
if (!type) return failure();
|
has_unranked_tensor_type = true;
|
||||||
operand_types.push_back(type);
|
|
||||||
}
|
}
|
||||||
|
operand_element_types.push_back(type.getElementType());
|
||||||
|
} else {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!has_unranked_tensor_type) return failure();
|
||||||
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
||||||
|
|
||||||
llvm::SmallVector<Value> shapes;
|
llvm::SmallVector<Value> shapes;
|
||||||
|
@ -278,7 +284,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
|
||||||
if_at_most_one_non_scalar_builder.create<mhlo::DynamicReshapeOp>(
|
if_at_most_one_non_scalar_builder.create<mhlo::DynamicReshapeOp>(
|
||||||
loc,
|
loc,
|
||||||
RankedTensorType::get({RankedTensorType::kDynamicSize},
|
RankedTensorType::get({RankedTensorType::kDynamicSize},
|
||||||
operand_types[i].getElementType()),
|
operand_element_types[i]),
|
||||||
transformed_operands[i], size_tensor);
|
transformed_operands[i], size_tensor);
|
||||||
reshaped_operands.push_back(reshaped);
|
reshaped_operands.push_back(reshaped);
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,6 +64,28 @@ func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Transformed if there is a mix of unranked/static shapes.
|
||||||
|
// CHECK-LABEL: @select_mixed
|
||||||
|
// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ON_TRUE:.*]]: tensor<*xf32>, %[[ON_FALSE:.*]]: tensor<2xf32>)
|
||||||
|
func @select_mixed(%pred: tensor<*xi1>, %on_true: tensor<*xf32>, %on_false: tensor<2xf32>) -> tensor<*xf32> {
|
||||||
|
// CHECK: %[[SHAPE_PRED:.*]] = shape.shape_of %[[PRED]]
|
||||||
|
// CHECK: %[[SHAPE_ON_TRUE:.*]] = shape.shape_of %[[ON_TRUE]]
|
||||||
|
// CHECK: %[[SHAPE_ON_FALSE:.*]] = shape.shape_of %[[ON_FALSE]]
|
||||||
|
// CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_PRED]], %[[SHAPE_ON_TRUE]], %[[SHAPE_ON_FALSE]]
|
||||||
|
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
|
||||||
|
// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
|
||||||
|
// CHECK: %[[FLAT_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[FLAT_SHAPE]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
|
||||||
|
// CHECK: %[[FLAT_ON_TRUE:.*]] = "mhlo.dynamic_reshape"(%[[ON_TRUE]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK: %[[FLAT_ON_FALSE:.*]] = "mhlo.dynamic_reshape"(%[[ON_FALSE]], %[[FLAT_SHAPE]]) : (tensor<2xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||||
|
// CHECK: %[[FLAT_RESULT:.*]] = "mhlo.select"(%[[FLAT_PRED]], %[[FLAT_ON_TRUE]], %[[FLAT_ON_FALSE]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
|
// CHECK: return %[[RESULT]] : tensor<*xf32>
|
||||||
|
%b = "mhlo.select"(%pred, %on_true, %on_false) : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
|
||||||
|
return %b : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @add_unranked
|
// CHECK-LABEL: @add_unranked
|
||||||
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32>
|
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32>
|
||||||
func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
|
func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
|
Loading…
Reference in New Issue