Support different input/output type for TransformUnrankedHlo.
Also generate the tf.Equal kernel, now that it works. PiperOrigin-RevId: 344402014
This commit is contained in:
parent
1b98bf5fab
commit
6a71a84302
|
@ -164,7 +164,10 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||||
// the more generic case of both inputs being unranked.
|
// the more generic case of both inputs being unranked.
|
||||||
if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
|
if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
|
||||||
|
|
||||||
|
auto scalar_element_type = lhs_is_scalar ? lhs_ranked_type.getElementType()
|
||||||
|
: rhs_ranked_type.getElementType();
|
||||||
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
|
||||||
|
auto result_element_type = result_type.getElementType();
|
||||||
|
|
||||||
// Reshape the non-scalar value into a dynamically sized, rank-1 tensor
|
// Reshape the non-scalar value into a dynamically sized, rank-1 tensor
|
||||||
Value shape =
|
Value shape =
|
||||||
|
@ -173,15 +176,15 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
||||||
Value size_tensor =
|
Value size_tensor =
|
||||||
rewriter.create<TensorFromElementsOp>(loc, num_elements);
|
rewriter.create<TensorFromElementsOp>(loc, num_elements);
|
||||||
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
|
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||||
loc, RankedTensorType::get({-1}, result_type.getElementType()),
|
loc, RankedTensorType::get({-1}, scalar_element_type),
|
||||||
lhs_is_scalar ? rhs : lhs, size_tensor);
|
lhs_is_scalar ? rhs : lhs, size_tensor);
|
||||||
|
|
||||||
// Create a new ranked Chlo op that will be further lowered by other
|
// Create a new ranked Chlo op that will be further lowered by other
|
||||||
// patterns into Mhlo.
|
// patterns into Mhlo.
|
||||||
SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
|
SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
|
||||||
rhs_is_scalar ? rhs : reshaped};
|
rhs_is_scalar ? rhs : reshaped};
|
||||||
Value computed =
|
Value computed = rewriter.create<ChloOpTy>(
|
||||||
rewriter.create<ChloOpTy>(loc, SmallVector<Type, 1>{reshaped.getType()},
|
loc, TypeRange{RankedTensorType::get({-1}, result_element_type)},
|
||||||
new_operands, op.getAttrs());
|
new_operands, op.getAttrs());
|
||||||
|
|
||||||
// Reshape the result back into an unranked tensor.
|
// Reshape the result back into an unranked tensor.
|
||||||
|
@ -287,8 +290,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Returns the dyanamic result of checking the given value is a scalar
|
// Returns the dynamic result of checking the given value is a scalar tensor.
|
||||||
// tensor.
|
|
||||||
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
|
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
|
@ -300,30 +302,38 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
rewriter.create<ConstantIndexOp>(loc, 0));
|
rewriter.create<ConstantIndexOp>(loc, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the if statement and code for a broadcasting op with a result of a
|
Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
|
||||||
// given rank.
|
|
||||||
scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
|
|
||||||
Value lhs, Value rhs,
|
|
||||||
Value actual_rank,
|
|
||||||
int targeted_rank) const {
|
int targeted_rank) const {
|
||||||
auto loc = op.getLoc();
|
return builder.create<CmpIOp>(
|
||||||
|
|
||||||
// Create the if block to place the current specialized logic in.
|
|
||||||
Value greater_rank_is_n = builder.create<CmpIOp>(
|
|
||||||
loc, CmpIPredicate::eq, actual_rank,
|
loc, CmpIPredicate::eq, actual_rank,
|
||||||
builder.create<ConstantIndexOp>(loc, targeted_rank));
|
builder.create<ConstantIndexOp>(loc, targeted_rank));
|
||||||
auto if_op =
|
}
|
||||||
builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
|
|
||||||
OpBuilder if_builder = if_op.getThenBodyBuilder(builder.getListener());
|
scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
|
||||||
|
OpBuilder &builder, ChloOpTy op, Value actual_rank,
|
||||||
|
int targeted_rank) const {
|
||||||
|
// Create the if block to place the current specialized logic in.
|
||||||
|
Value greater_rank_is_n =
|
||||||
|
GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
|
||||||
|
return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
|
||||||
|
greater_rank_is_n, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the if statement and code for a broadcasting op with a result of a
|
||||||
|
// given rank.
|
||||||
|
void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
|
||||||
|
Value lhs, Value rhs,
|
||||||
|
int targeted_rank) const {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
// Handle shape broadcasting and inferrence.
|
// Handle shape broadcasting and inferrence.
|
||||||
Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
|
Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
|
||||||
Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
|
Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
|
||||||
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
||||||
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
||||||
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
{RankedTensorType::kDynamicSize}, if_builder.getIndexType());
|
||||||
auto known_rank_extent_tensor_type =
|
auto known_rank_extent_tensor_type =
|
||||||
RankedTensorType::get({targeted_rank}, builder.getIndexType());
|
RankedTensorType::get({targeted_rank}, if_builder.getIndexType());
|
||||||
auto reshaped_type = RankedTensorType::get(
|
auto reshaped_type = RankedTensorType::get(
|
||||||
llvm::SmallVector<int64_t, 6>(targeted_rank,
|
llvm::SmallVector<int64_t, 6>(targeted_rank,
|
||||||
RankedTensorType::kDynamicSize),
|
RankedTensorType::kDynamicSize),
|
||||||
|
@ -351,23 +361,26 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
loc, reshaped_type, lhs, extended_lhs_casted);
|
loc, reshaped_type, lhs, extended_lhs_casted);
|
||||||
Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
|
||||||
loc, reshaped_type, rhs, extended_rhs_casted);
|
loc, reshaped_type, rhs, extended_rhs_casted);
|
||||||
|
auto result_element_type = op.getResult()
|
||||||
|
.getType()
|
||||||
|
.template dyn_cast<TensorType>()
|
||||||
|
.getElementType();
|
||||||
|
auto result_type = RankedTensorType::get(
|
||||||
|
llvm::SmallVector<int64_t, 6>(targeted_rank,
|
||||||
|
RankedTensorType::kDynamicSize),
|
||||||
|
result_element_type);
|
||||||
Value result = if_builder.create<ChloOpTy>(
|
Value result = if_builder.create<ChloOpTy>(
|
||||||
loc, ArrayRef<Type>{reshaped_type},
|
loc, ArrayRef<Type>{result_type},
|
||||||
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
|
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
|
||||||
Value reshaped_result = if_builder.create<TensorCastOp>(
|
Value reshaped_result = if_builder.create<TensorCastOp>(
|
||||||
loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
|
loc, UnrankedTensorType::get(result_element_type), result);
|
||||||
if_builder.create<scf::YieldOp>(loc, reshaped_result);
|
if_builder.create<scf::YieldOp>(loc, reshaped_result);
|
||||||
|
|
||||||
// Return the if_op, so the result can be used and the else block can be
|
|
||||||
// used for the next rank specialized step.
|
|
||||||
return if_op;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iterates over the desired ranks to be specialized and generates the code
|
// Iterates over the desired ranks to be specialized and generates the code
|
||||||
// snippet for each case.
|
// snippet for each case.
|
||||||
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
|
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
|
||||||
Value rhs) const {
|
Value rhs) const {
|
||||||
constexpr int max_rank_specialization = 7;
|
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
// Find the larger rank of the 2 operands.
|
// Find the larger rank of the 2 operands.
|
||||||
|
@ -388,26 +401,34 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
|
|
||||||
// Generate a list of nested if/else statements to handle rank
|
// Generate a list of nested if/else statements to handle rank
|
||||||
// specializations from 1-6.
|
// specializations from 1-6.
|
||||||
scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
|
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
|
||||||
rhs, greater_rank, 1);
|
rewriter, op, greater_rank, 1);
|
||||||
|
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
|
||||||
|
createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, 1);
|
||||||
|
|
||||||
// Put each subsequent rank specialization inside the else statement of the
|
// Put each subsequent rank specialization inside the else statement of the
|
||||||
// previous one.
|
// previous one.
|
||||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||||
for (int i = 2; i < max_rank_specialization; i++) {
|
constexpr int kMaxRankSpecialization = 6;
|
||||||
auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
|
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
||||||
rhs, greater_rank, i);
|
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
||||||
|
else_builder, op, greater_rank, i);
|
||||||
|
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
|
||||||
|
createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, i);
|
||||||
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
||||||
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
||||||
}
|
}
|
||||||
|
// Fire an assertion if none of the rank specializations applied (one of
|
||||||
// Fire an assertion if none of the rank specializations applied (one of the
|
// the ranks was greater than 6).
|
||||||
// ranks was greater than 6).
|
|
||||||
else_builder.create<AssertOp>(
|
else_builder.create<AssertOp>(
|
||||||
loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
|
loc,
|
||||||
"Input for dynamic binary op lowering was of a rank greater than 6");
|
GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
|
||||||
else_builder.create<scf::YieldOp>(loc, lhs);
|
kMaxRankSpecialization),
|
||||||
|
"Input for dynamic binary op lowering was of a rank greater than "
|
||||||
|
"6");
|
||||||
|
// Add the rank 6 specialization to the innermost else block.
|
||||||
|
createRankSpecializedBroadcastAndOp(else_builder, op, lhs, rhs,
|
||||||
|
kMaxRankSpecialization);
|
||||||
|
|
||||||
// Return the result of the outermost if statement.
|
// Return the result of the outermost if statement.
|
||||||
return if_op.getResult(0);
|
return if_op.getResult(0);
|
||||||
|
|
|
@ -276,8 +276,8 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: } else {
|
// CHECK-NEXT: } else {
|
||||||
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
|
// CHECK-NEXT: %[[C6:.*]] = constant 6 : index
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
|
// CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
|
||||||
|
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_6]]
|
||||||
// Handle rank 6 specialization
|
// Handle rank 6 specialization
|
||||||
// CHECK-NEXT: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
|
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
|
// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
|
// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
|
||||||
|
@ -288,12 +288,6 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
|
// CHECK-NEXT: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
|
||||||
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: } else {
|
|
||||||
// CHECK-NEXT: %false = constant false
|
|
||||||
// CHECK-NEXT: assert %false
|
|
||||||
// CHECK-NEXT: scf.yield %[[LHS]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: }
|
|
||||||
// CHECK-NEXT: scf.yield %[[VAL_64:.*]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
|
|
Loading…
Reference in New Issue