[MLIR][HLO] Add scalar cases for binary rank specialization
For rank specialization clusters that have only two operands, we can materialize two extra cases in which either of them is a scalar. This avoids redundant index computations in these cases.

PiperOrigin-RevId: 375037390
commit 3daf65578a
parent a7884196f5
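To make the new case structure concrete before the diff, here is a minimal hand-written sketch (not part of the commit) of what the lowering now emits for a two-operand cluster around chlo.broadcast_multiply on two tensor<*xf32> values. Op spellings follow the CHECK-SCF lines in the test below; the function name and the collapsed else branch are illustrative only:

    func @mul_sketch(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
      %c1 = constant 1 : index
      %shape0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
      %shape1 = shape.shape_of %arg1 : tensor<*xf32> -> tensor<?xindex>
      // Predicate: the lhs holds exactly one element, i.e. it is a scalar.
      %n0 = shape.num_elements %shape0 : tensor<?xindex> -> index
      %lhs_scalar = cmpi eq, %n0, %c1 : index
      %res = scf.if %lhs_scalar -> (tensor<*xf32>) {
        // Scalar lhs: flatten the rhs to rank 1; no broadcast index
        // computation is needed.
        %n1 = shape.num_elements %shape1 : tensor<?xindex> -> index
        %flat_shape = tensor.from_elements %n1 : tensor<1xindex>
        %flat_rhs = "mhlo.dynamic_reshape"(%arg1, %flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
        %scalar_lhs = tensor.cast %arg0 : tensor<*xf32> to tensor<f32>
        %tmp = chlo.broadcast_multiply %scalar_lhs, %flat_rhs : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
        %tmp_ = tensor.cast %tmp : tensor<?xf32> to tensor<*xf32>
        scf.yield %tmp_ : tensor<*xf32>
      } else {
        // In the real lowering, the symmetric rhs-scalar case, the
        // equal-shapes case, and the rank-specialized generic cases nest
        // here; collapsed to a single unranked op for brevity.
        %tmp = chlo.broadcast_multiply %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
        scf.yield %tmp : tensor<*xf32>
      }
      // A final mhlo.dynamic_reshape (MaterializeFinalReshape in the diff)
      // restores the broadcasted result shape; omitted in this sketch.
      return %res : tensor<*xf32>
    }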
@@ -256,31 +256,66 @@ SmallVector<Value, 8> MaterializeFinalReshape(
   }));
 }
 
-SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
+Value MaterializeScalarRankSpecializationCase(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
-    Value non_scalar_operand) {
-  // Flatten the non-scalar operand.
-  Value flat_shape = b.create<tensor::FromElementsOp>(
-      loc, b.create<shape::NumElementsOp>(
-               loc, b.getIndexType(),
-               b.create<shape::ShapeOfOp>(loc, non_scalar_operand))
-               .result());
-  Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
-      loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
-      non_scalar_operand, flat_shape);
-
-  // Materialize ranked variants for the element-wise operations.
-  BlockAndValueMapping bvm;
-  for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
-    Value operand = std::get<1>(it);
-    bvm.map(std::get<0>(it),
-            operand == non_scalar_operand ? flat_non_scalar_operand : operand);
-  }
-  SmallVector<Value, 8> unshaped_results =
-      MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1);
-
-  // Restore the results' expected shape.
-  return MaterializeFinalReshape(b, loc, op, unshaped_results);
+    const SmallVector<Value, 8> &shapes, int64_t non_scalar_idx,
+    function_ref<void(OpBuilder &, Location)> else_builder_fn) {
+  // Materialize predicate: All operands except one are scalars.
+  Value one = b.create<ConstantIndexOp>(loc, 1);
+  Value all_others_are_scalar;
+  for (auto it : llvm::enumerate(shapes)) {
+    if (it.index() == non_scalar_idx) continue;
+    auto literal =
+        b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                         b.create<shape::NumElementsOp>(loc, it.value()), one);
+    all_others_are_scalar =
+        all_others_are_scalar
+            ? b.create<AndOp>(loc, all_others_are_scalar, literal).getResult()
+            : literal.result();
+  }
+
+  auto if_op = b.create<scf::IfOp>(
+      loc, op->getResultTypes(), all_others_are_scalar,
+      [&](OpBuilder &b, Location loc) {
+        // Flatten the non-scalar operand.
+        Value flat_shape = b.create<tensor::FromElementsOp>(
+            loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(),
+                                                shapes[non_scalar_idx])
+                     .result());
+        Value non_scalar_operand = op.operands()[non_scalar_idx];
+        Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
+            loc,
+            DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
+            non_scalar_operand, flat_shape);
+
+        // Derive ranked operands.
+        auto ranked_operands =
+            llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+              if (v == non_scalar_operand) return flat_non_scalar_operand;
+              return b
+                  .create<tensor::CastOp>(
+                      loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0), v)
+                  .dest();
+            }));
+
+        // Materialize ranked variants for the element-wise operations.
+        BlockAndValueMapping bvm;
+        for (auto it : llvm::zip(op.getBody()->getArguments(), ranked_operands))
+          bvm.map(std::get<0>(it), std::get<1>(it));
+        Value unshaped_result =
+            MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1)
+                .front();
+
+        // Return as unranked tensor for compatibility with the other cases.
+        b.create<scf::YieldOp>(
+            loc, b.create<tensor::CastOp>(
+                     loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
+                     unshaped_result)
+                     .dest());
+      },
+      else_builder_fn);
+
+  return if_op.results().front();
 }
 
 Value MaterializeEqualShapesRankSpecializationCase(
@@ -313,7 +348,7 @@ Value MaterializeEqualShapesRankSpecializationCase(
   Value flat_shape = b.create<tensor::FromElementsOp>(
       loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape)
                .result());
-  SmallVector<Value, 8> flat_operands =
+  auto flat_operands =
       llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
         return b
             .create<mhlo::DynamicReshapeOp>(
@@ -355,7 +390,7 @@ Value MaterializeTargetRankSpecializationCase(
       loc, extent_tensor_ty,
       mlir::DenseIntElementsAttr::get(extent_tensor_ty,
                                       SmallVector<int64_t, 6>(target_rank, 1)));
-  SmallVector<Value, 2> ranked_operands;
+  SmallVector<Value, 8> ranked_operands;
   for (auto it : llvm::zip(op.operands(), shapes)) {
     Value operand, shape;
     std::tie(operand, shape) = it;
@@ -464,6 +499,72 @@ Value MaterializeGenericRankSpecializationCases(
       b, loc, op, reduced_shapes, max_rank, kMinTargetRank, kMaxTargetRank);
 }
 
+Value MaterializeDefaultRankSpecializationCases(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    const SmallVector<Value, 8> &shapes) {
+  return MaterializeEqualShapesRankSpecializationCase(
+      b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(
+            loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes));
+      });
+}
+
+Value MaterializeRankSpecializationForExactlyTwoOperands(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
+  assert(op->getNumOperands() == 2 && op.getNumResults() == 1 &&
+         "The rank specialization strategy for clusters with exactly two "
+         "operands supports only one result.");
+
+  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+    return b.create<shape::ShapeOfOp>(loc, v).result();
+  }));
+
+  // Materialize all the different cases.
+  Value unshaped_result = MaterializeScalarRankSpecializationCase(
+      b, loc, op, shapes, /*non_scalar_idx=*/1,
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(
+            loc, MaterializeScalarRankSpecializationCase(
+                     b, loc, op, shapes, /*non_scalar_idx=*/0,
+                     [&](OpBuilder &b, Location loc) {
+                       b.create<scf::YieldOp>(
+                           loc, MaterializeDefaultRankSpecializationCases(
+                                    b, loc, op, shapes));
+                     }));
+      });
+
+  // Materialize final reshape once and for all rank specialization cases.
+  return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
+}
+
+SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    Value non_scalar_operand) {
+  // Flatten the non-scalar operand.
+  Value flat_shape = b.create<tensor::FromElementsOp>(
+      loc, b.create<shape::NumElementsOp>(
+               loc, b.getIndexType(),
+               b.create<shape::ShapeOfOp>(loc, non_scalar_operand))
+               .result());
+  Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
+      loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
+      non_scalar_operand, flat_shape);
+
+  // Materialize ranked variants for the element-wise operations.
+  BlockAndValueMapping bvm;
+  for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
+    Value operand = std::get<1>(it);
+    bvm.map(std::get<0>(it),
+            operand == non_scalar_operand ? flat_non_scalar_operand : operand);
+  }
+  SmallVector<Value, 8> unshaped_results =
+      MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1);
+
+  // Restore the results' expected shape.
+  return MaterializeFinalReshape(b, loc, op, unshaped_results);
+}
+
+// Materialize generic rank specialization.
 Value MaterializeDefaultRankSpecialization(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
   auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
@@ -471,11 +572,8 @@ Value MaterializeDefaultRankSpecialization(
   }));
 
   // Materialize all the different cases.
-  Value unshaped_result = MaterializeEqualShapesRankSpecializationCase(
-      b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(
-            loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes));
-      });
+  Value unshaped_result =
+      MaterializeDefaultRankSpecializationCases(b, loc, op, shapes);
 
   // Materialize final reshape once and for all rank specialization cases.
   return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
@@ -487,6 +585,11 @@ struct LowerRankSpecializationClusterPattern
 
   LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                 PatternRewriter &rewriter) const override {
+    // Restoring the result shape currently relies on all operands being used
+    // for a single result. The result shape is then the broadcasted shape of
+    // all operands.
+    if (op.getNumResults() != 1) return failure();
+
     // TODO(frgossen): If there is a single operand, we can flatten it
     // completely and apply a non-broadcasting operation.
 
@@ -501,10 +604,13 @@ struct LowerRankSpecializationClusterPattern
       return success();
     }
 
-    // Restoring the result shape currently relies on all operands being used
-    // for a single result. The result shape is then the broadcasted shape of
-    // all operands.
-    if (op.getNumResults() != 1) return failure();
+    // If there are only two operands, we can consider extra cases in which
+    // either operand is scalar.
+    if (op->getNumOperands() == 2) {
+      rewriter.replaceOp(op, MaterializeRankSpecializationForExactlyTwoOperands(
+                                 rewriter, loc, op));
+      return success();
+    }
 
     // For all other cases, reshape the operands to match in rank, apply the
     // operation, and restore the expected shape.
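The lit test below exercises the new lowering end to end on a binary @mul cluster: FileCheck first matches the lhs-scalar and rhs-scalar cases, then the equal-shapes case, and finally the rank-specialized generic cases up to the maximum supported rank of 8.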
@@ -446,3 +446,198 @@ func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%{{.*}}) : (tensor<?x?x?x?x?x?x?x?xf32>)
 // CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %{{.*}}, %[[TMP0]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
 // CHECK-SCF: chlo.broadcast_select %[[PRED]], %{{.*}}, %[[TMP1]] : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+
+// -----
+
+// CHECK-LABEL: @mul
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
+func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
+  // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
+  // CHECK:   %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]]
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP]])
+  // CHECK: return %[[RES]]
+  %0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// CHECK-SCF-LABEL: @mul
+// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
+// CHECK-SCF-DAG: %[[C1:.*]] = constant 1
+// CHECK-SCF-DAG: %[[C2:.*]] = constant 2
+// CHECK-SCF-DAG: %[[C3:.*]] = constant 3
+// CHECK-SCF-DAG: %[[C4:.*]] = constant 4
+// CHECK-SCF-DAG: %[[C5:.*]] = constant 5
+// CHECK-SCF-DAG: %[[C6:.*]] = constant 6
+// CHECK-SCF-DAG: %[[C7:.*]] = constant 7
+// CHECK-SCF-DAG: %[[C8:.*]] = constant 8
+// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_7:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_8:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
+// Lhs scalar case:
+// CHECK-SCF-DAG: %[[LHS_N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[LHS_SCALAR:.*]] = cmpi eq, %[[LHS_N]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_LHS_SCALAR:.*]] = scf.if %[[LHS_SCALAR]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[SCALAR:.*]] = tensor.cast %[[ARG0]]
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[FLAT_NON_SCALAR]] : (tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Rhs scalar case:
+// CHECK-SCF-DAG: %[[RHS_N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[RHS_SCALAR:.*]] = cmpi eq, %[[RHS_N]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_RHS_SCALAR:.*]] = scf.if %[[RHS_SCALAR]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[SCALAR:.*]] = tensor.cast %[[ARG1]]
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[SCALAR]] : (tensor<?xf32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Equal shapes case:
+// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]]
+// CHECK-SCF-DAG: %[[SHAPE:.*]] = shape.any %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[FLAT_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Find maximum reduced rank.
+// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#0
+// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#1
+// CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
+// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
+// Generic case 1:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 2:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
+// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 3:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
+// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 4:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]]
+// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 5:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]]
+// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 6:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]]
+// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 7:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]]
+// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 8:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]]
+// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_EQ_SHAPES]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_RHS_SCALAR]]
+// Reshape the result.
+// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
+// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_LHS_SCALAR]], %[[RES_SHAPE]])
+// CHECK-SCF: return %[[RES]]