[MLIR][HLO] Add scalar cases for binary rank specialization

For rank specialization clusters that have only two operands, we can materialize two extra cases in which either of them is a scalar. This avoids redundant index computations in these cases.

PiperOrigin-RevId: 375037390
commit 3daf65578a
parent a7884196f5
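The change, in brief: for a binary cluster the lowering now tries an lhs-scalar case and an rhs-scalar case before falling back to the existing equal-shapes and generic rank cases. Below is a minimal, self-contained C++ model of that dispatch order; the `Shape` alias, `NumElements`, and `SelectCase` are illustrative names only, since the actual pass emits this logic as a cascade of `scf.if` ops on shape values:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;  // toy stand-in for a runtime shape

int64_t NumElements(const Shape &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Mirrors the order of the materialized cases: the scalar cases are checked
// first, then equal shapes, then the generic rank-specialized fallback.
std::string SelectCase(const Shape &lhs, const Shape &rhs) {
  if (NumElements(lhs) == 1) return "lhs-scalar: flatten rhs to rank 1";
  if (NumElements(rhs) == 1) return "rhs-scalar: flatten lhs to rank 1";
  if (lhs == rhs) return "equal shapes: flatten both to rank 1";
  return "generic: reduce to minimal broadcast ranks 1..8";
}

int main() {
  std::cout << SelectCase({1}, {4, 5}) << "\n";        // lhs-scalar
  std::cout << SelectCase({4, 5}, {4, 5}) << "\n";     // equal shapes
  std::cout << SelectCase({2, 1, 5}, {3, 5}) << "\n";  // generic
}
```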
@@ -256,31 +256,66 @@ SmallVector<Value, 8> MaterializeFinalReshape(
       }));
 }
 
-SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
+Value MaterializeScalarRankSpecializationCase(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
-    Value non_scalar_operand) {
-  // Flatten the non-scalar operand.
-  Value flat_shape = b.create<tensor::FromElementsOp>(
-      loc, b.create<shape::NumElementsOp>(
-               loc, b.getIndexType(),
-               b.create<shape::ShapeOfOp>(loc, non_scalar_operand))
-               .result());
-  Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
-      loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
-      non_scalar_operand, flat_shape);
-
-  // Materialize ranked variants for the element-wise operations.
-  BlockAndValueMapping bvm;
-  for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
-    Value operand = std::get<1>(it);
-    bvm.map(std::get<0>(it),
-            operand == non_scalar_operand ? flat_non_scalar_operand : operand);
+    const SmallVector<Value, 8> &shapes, int64_t non_scalar_idx,
+    function_ref<void(OpBuilder &, Location)> else_builder_fn) {
+  // Materialize predicate: All operands except one are scalars.
+  Value one = b.create<ConstantIndexOp>(loc, 1);
+  Value all_others_are_scalar;
+  for (auto it : llvm::enumerate(shapes)) {
+    if (it.index() == non_scalar_idx) continue;
+    auto literal =
+        b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                         b.create<shape::NumElementsOp>(loc, it.value()), one);
+    all_others_are_scalar =
+        all_others_are_scalar
+            ? b.create<AndOp>(loc, all_others_are_scalar, literal).getResult()
+            : literal.result();
   }
-  SmallVector<Value, 8> unshaped_results =
-      MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1);
-
-  // Restore the results' expected shape.
-  return MaterializeFinalReshape(b, loc, op, unshaped_results);
+
+  auto if_op = b.create<scf::IfOp>(
+      loc, op->getResultTypes(), all_others_are_scalar,
+      [&](OpBuilder &b, Location loc) {
+        // Flatten the non-scalar operand.
+        Value flat_shape = b.create<tensor::FromElementsOp>(
+            loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(),
+                                                shapes[non_scalar_idx])
+                     .result());
+        Value non_scalar_operand = op.operands()[non_scalar_idx];
+        Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
+            loc,
+            DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
+            non_scalar_operand, flat_shape);
+
+        // Derive ranked operands.
+        auto ranked_operands =
+            llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+              if (v == non_scalar_operand) return flat_non_scalar_operand;
+              return b
+                  .create<tensor::CastOp>(
+                      loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0), v)
+                  .dest();
+            }));
+
+        // Materialize ranked variants for the element-wise operations.
+        BlockAndValueMapping bvm;
+        for (auto it : llvm::zip(op.getBody()->getArguments(), ranked_operands))
+          bvm.map(std::get<0>(it), std::get<1>(it));
+        Value unshaped_result =
+            MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1)
+                .front();
+
+        // Return as unranked tensor for compatibility with the other cases.
+        b.create<scf::YieldOp>(
+            loc, b.create<tensor::CastOp>(
+                     loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
+                     unshaped_result)
+                     .dest());
+      },
+      else_builder_fn);
+
+  return if_op.results().front();
+}
+
 Value MaterializeEqualShapesRankSpecializationCase(
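An editorial aside on the predicate loop above: `all_others_are_scalar` starts out null, and the conjunction is only materialized from the second literal onward, so a binary cluster pays for a single `num_elements`/`cmpi` and no `and`. A hedged plain-C++ rendering of that fold (toy types and names, not the MLIR builder API):

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Folds "num_elements == 1" over all operands except the designated
// non-scalar one, mirroring the loop in MaterializeScalarRankSpecializationCase.
bool AllOthersAreScalar(const std::vector<int64_t> &num_elements,
                        size_t non_scalar_idx) {
  std::optional<bool> all_others_are_scalar;  // null until the first literal
  for (size_t i = 0; i < num_elements.size(); ++i) {
    if (i == non_scalar_idx) continue;
    bool literal = (num_elements[i] == 1);
    // Only conjoin once a previous literal exists; the first one is taken as-is.
    all_others_are_scalar =
        all_others_are_scalar ? (*all_others_are_scalar && literal) : literal;
  }
  // The pass only calls this with at least two operands, so the fallback to
  // "vacuously true" is never hit in practice.
  return all_others_are_scalar.value_or(true);
}

int main() {
  std::cout << AllOthersAreScalar({1, 7}, 1) << "\n";  // 1: operand 0 is scalar
  std::cout << AllOthersAreScalar({4, 7}, 1) << "\n";  // 0: operand 0 is not
}
```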
@@ -313,7 +348,7 @@ Value MaterializeEqualShapesRankSpecializationCase(
   Value flat_shape = b.create<tensor::FromElementsOp>(
       loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape)
               .result());
-  SmallVector<Value, 8> flat_operands =
+  auto flat_operands =
       llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
         return b
             .create<mhlo::DynamicReshapeOp>(
@@ -355,7 +390,7 @@ Value MaterializeTargetRankSpecializationCase(
       loc, extent_tensor_ty,
       mlir::DenseIntElementsAttr::get(extent_tensor_ty,
                                       SmallVector<int64_t, 6>(target_rank, 1)));
-  SmallVector<Value, 2> ranked_operands;
+  SmallVector<Value, 8> ranked_operands;
   for (auto it : llvm::zip(op.operands(), shapes)) {
     Value operand, shape;
     std::tie(operand, shape) = it;
@@ -464,6 +499,72 @@ Value MaterializeGenericRankSpecializationCases(
       b, loc, op, reduced_shapes, max_rank, kMinTargetRank, kMaxTargetRank);
 }
 
+Value MaterializeDefaultRankSpecializationCases(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    const SmallVector<Value, 8> &shapes) {
+  return MaterializeEqualShapesRankSpecializationCase(
+      b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(
+            loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes));
+      });
+}
+
+Value MaterializeRankSpecializationForExactlyTwoOperands(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
+  assert(op->getNumOperands() == 2 && op.getNumResults() == 1 &&
+         "The rank specialization strategy for clusters with exactly two "
+         "operands supports only one result.");
+
+  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+    return b.create<shape::ShapeOfOp>(loc, v).result();
+  }));
+
+  // Materialize all the different cases.
+  Value unshaped_result = MaterializeScalarRankSpecializationCase(
+      b, loc, op, shapes, /*non_scalar_idx=*/1,
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(
+            loc, MaterializeScalarRankSpecializationCase(
+                     b, loc, op, shapes, /*non_scalar_idx=*/0,
+                     [&](OpBuilder &b, Location loc) {
+                       b.create<scf::YieldOp>(
+                           loc, MaterializeDefaultRankSpecializationCases(
+                                    b, loc, op, shapes));
+                     }));
+      });
+
+  // Materialize final reshape once and for all rank specialization cases.
+  return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
+}
+
+SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    Value non_scalar_operand) {
+  // Flatten the non-scalar operand.
+  Value flat_shape = b.create<tensor::FromElementsOp>(
+      loc, b.create<shape::NumElementsOp>(
+               loc, b.getIndexType(),
+               b.create<shape::ShapeOfOp>(loc, non_scalar_operand))
+               .result());
+  Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
+      loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
+      non_scalar_operand, flat_shape);
+
+  // Materialize ranked variants for the element-wise operations.
+  BlockAndValueMapping bvm;
+  for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
+    Value operand = std::get<1>(it);
+    bvm.map(std::get<0>(it),
+            operand == non_scalar_operand ? flat_non_scalar_operand : operand);
+  }
+  SmallVector<Value, 8> unshaped_results =
+      MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1);
+
+  // Restore the results' expected shape.
+  return MaterializeFinalReshape(b, loc, op, unshaped_results);
+}
+
+// Materialize generic rank specialization.
 Value MaterializeDefaultRankSpecialization(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
   auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
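Worth noting how `MaterializeRankSpecializationForExactlyTwoOperands` composes its cases: each case receives its fallback as an `else_builder_fn` continuation, which is what produces the nested `scf.if` chain pinned down by the test below. A sketch of that composition pattern with `std::function` in place of MLIR region builders (all names here are illustrative):

```cpp
#include <functional>
#include <iostream>
#include <string>

using ElseBuilderFn = std::function<std::string()>;

// Stands in for MaterializeScalarRankSpecializationCase: handle the case if
// its predicate holds, otherwise defer to the caller-supplied continuation.
std::string ScalarCase(bool predicate, int non_scalar_idx,
                       ElseBuilderFn else_builder_fn) {
  if (predicate)
    return "scalar case, non_scalar_idx=" + std::to_string(non_scalar_idx);
  return else_builder_fn();
}

int main() {
  bool lhs_is_scalar = false, rhs_is_scalar = false;
  // Mirrors the nesting in the pass: the lhs-scalar case first
  // (non_scalar_idx=1), then rhs-scalar (non_scalar_idx=0), then the shared
  // default (equal-shapes + generic) cases.
  std::string result = ScalarCase(lhs_is_scalar, /*non_scalar_idx=*/1, [&] {
    return ScalarCase(rhs_is_scalar, /*non_scalar_idx=*/0,
                      [&] { return "default rank specialization cases"; });
  });
  std::cout << result << "\n";  // falls through to the default-cases branch
}
```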
@@ -471,11 +572,8 @@ Value MaterializeDefaultRankSpecialization(
   }));
 
   // Materialize all the different cases.
-  Value unshaped_result = MaterializeEqualShapesRankSpecializationCase(
-      b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(
-            loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes));
-      });
+  Value unshaped_result =
+      MaterializeDefaultRankSpecializationCases(b, loc, op, shapes);
 
   // Materialize final reshape once and for all rank specialization cases.
   return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
@@ -487,6 +585,11 @@ struct LowerRankSpecializationClusterPattern
 
   LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                 PatternRewriter &rewriter) const override {
+    // Restoring the result shape currently relies on all operands being used
+    // for a single result. The result shape is then the broadcasted shape of
+    // all operands.
+    if (op.getNumResults() != 1) return failure();
+
     // TODO(frgossen): If there is a single operand, we can flatten it
     // completely and apply a non-broadcasting operation.
 
@@ -501,10 +604,13 @@ struct LowerRankSpecializationClusterPattern
       return success();
     }
 
-    // Restoring the result shape currently relies on all operands being used
-    // for a single result. The result shape is then the broadcasted shape of
-    // all operands.
-    if (op.getNumResults() != 1) return failure();
+    // If there are only two operands, we can consider extra cases in which
+    // either operand is scalar.
+    if (op->getNumOperands() == 2) {
+      rewriter.replaceOp(op, MaterializeRankSpecializationForExactlyTwoOperands(
+                                 rewriter, loc, op));
+      return success();
+    }
 
     // For all other cases, reshape the operands to match in rank, apply the
     // operation, and restore the expected shape.
@@ -346,12 +346,12 @@ func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>)
 }
 
 // CHECK-SCF-LABEL: @mixed
-// CHECK-SCF-DAG: %[[TMP0:.*]] = chlo.tan %{{.*}} : tensor<?xf32>
-// CHECK-SCF-DAG: %[[TMP1:.*]] = "mhlo.sqrt"(%{{.*}}) : (tensor<?xf32>)
-// CHECK-SCF-DAG: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] : (tensor<?xf32>, tensor<?xf32>)
-// CHECK-SCF-DAG: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %{{.*}} : (tensor<?xf32>, tensor<?xf32>)
-// CHECK-SCF-DAG: %[[TMP4:.*]] = "mhlo.sqrt"(%[[TMP3]]) : (tensor<?xf32>)
-// CHECK-SCF: chlo.tan %[[TMP4]] : tensor<?xf32>
+// CHECK-SCF-DAG: %[[TMP0:.*]] = chlo.tan %{{.*}} : tensor<?xf32>
+// CHECK-SCF-DAG: %[[TMP1:.*]] = "mhlo.sqrt"(%{{.*}}) : (tensor<?xf32>)
+// CHECK-SCF-DAG: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %{{.*}} : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[TMP4:.*]] = "mhlo.sqrt"(%[[TMP3]]) : (tensor<?xf32>)
+// CHECK-SCF: chlo.tan %[[TMP4]] : tensor<?xf32>
 
 // -----
 
@@ -446,3 +446,198 @@ func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%{{.*}}) : (tensor<?x?x?x?x?x?x?x?xf32>)
 // CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %{{.*}}, %[[TMP0]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
 // CHECK-SCF: chlo.broadcast_select %[[PRED]], %{{.*}}, %[[TMP1]] : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+
+// -----
+
+// CHECK-LABEL: @mul
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
+func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
+  // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
+  // CHECK: %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]]
+  // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]])
+  // CHECK: return %[[RES]]
+  %0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// CHECK-SCF-LABEL: @mul
+// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
+// CHECK-SCF-DAG: %[[C1:.*]] = constant 1
+// CHECK-SCF-DAG: %[[C2:.*]] = constant 2
+// CHECK-SCF-DAG: %[[C3:.*]] = constant 3
+// CHECK-SCF-DAG: %[[C4:.*]] = constant 4
+// CHECK-SCF-DAG: %[[C5:.*]] = constant 5
+// CHECK-SCF-DAG: %[[C6:.*]] = constant 6
+// CHECK-SCF-DAG: %[[C7:.*]] = constant 7
+// CHECK-SCF-DAG: %[[C8:.*]] = constant 8
+// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_7:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[ONE_SHAPE_8:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1, 1]
+// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
+// Lhs scalar case:
+// CHECK-SCF-DAG: %[[LHS_N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[LHS_SCALAR:.*]] = cmpi eq, %[[LHS_N]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_LHS_SCALAR:.*]] = scf.if %[[LHS_SCALAR]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[SCALAR:.*]] = tensor.cast %[[ARG0]]
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[FLAT_NON_SCALAR]] : (tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Rhs scalar case:
+// CHECK-SCF-DAG: %[[RHS_N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[RHS_SCALAR:.*]] = cmpi eq, %[[RHS_N]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_RHS_SCALAR:.*]] = scf.if %[[RHS_SCALAR]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[SCALAR:.*]] = tensor.cast %[[ARG1]]
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[SCALAR]] : (tensor<?xf32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Equal shapes case:
+// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]]
+// CHECK-SCF-DAG: %[[SHAPE:.*]] = shape.any %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[FLAT_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Find maximum reduced rank.
+// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#0
+// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#1
+// CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
+// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
+// Generic case 1:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 2:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
+// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 3:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
+// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 4:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]]
+// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 5:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]]
+// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 6:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]]
+// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 7:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]]
+// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Generic case 8:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]]
+// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_EQ_SHAPES]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_RHS_SCALAR]]
+// Reshape the result.
+// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
+// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_LHS_SCALAR]], %[[RES_SHAPE]])
+// CHECK-SCF: return %[[RES]]