[MLIR][HLO] Find shape equivalences and use them for better rank specialization
Find shape equivalence classes among the operands and use them for better rank specialization. If all operands are known to be of the same shape, we can flatten them to rank one. If there are two shape equivalence classes, we can generalize the scalar rank specialization cases. PiperOrigin-RevId: 378844575
This commit is contained in:
parent
5cca8a14e3
commit
bd5752f0bf
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||||
|
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "llvm/ADT/EquivalenceClasses.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallSet.h"
|
#include "llvm/ADT/SmallSet.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
@ -38,7 +39,8 @@ limitations under the License.
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
/// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
|
/// Needed to build `llvm::SmallSet`s and `llvm::EquivalenceClasses` of
|
||||||
|
/// `mlir::Value`s.
|
||||||
static bool operator<(const Value &lhs, const Value &rhs) {
|
static bool operator<(const Value &lhs, const Value &rhs) {
|
||||||
return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
|
return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
|
||||||
}
|
}
|
||||||
|
@ -308,18 +310,6 @@ Type DeriveUnrankedTensorTypes(Type ty) {
|
||||||
return ty;
|
return ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
Optional<Value> FindUniqueNonScalar(ValueRange values) {
|
|
||||||
Value unique_non_scalar;
|
|
||||||
for (Value v : values) {
|
|
||||||
if (!IsScalarTensorType(v.getType())) {
|
|
||||||
if (unique_non_scalar) return llvm::None;
|
|
||||||
unique_non_scalar = v;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!unique_non_scalar) return llvm::None;
|
|
||||||
return unique_non_scalar;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value, 8> MaterializeRankedOperations(
|
SmallVector<Value, 8> MaterializeRankedOperations(
|
||||||
OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
|
OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
|
||||||
chlo::RankSpecializationClusterOp op) {
|
chlo::RankSpecializationClusterOp op) {
|
||||||
|
@ -375,20 +365,35 @@ SmallVector<Value, 8> MaterializeFinalReshape(
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value MaterializeFlatShape(OpBuilder &b, Location loc, ValueRange same_shapes) {
|
||||||
|
assert(!same_shapes.empty() && "Expected at least one shape.");
|
||||||
|
Value shape = same_shapes.size() == 1
|
||||||
|
? same_shapes.front()
|
||||||
|
: b.create<shape::AnyOp>(loc, same_shapes.front().getType(),
|
||||||
|
same_shapes);
|
||||||
|
return b.create<tensor::FromElementsOp>(
|
||||||
|
loc,
|
||||||
|
b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape).result());
|
||||||
|
}
|
||||||
|
|
||||||
Value MaterializeScalarRankSpecializationCase(
|
Value MaterializeScalarRankSpecializationCase(
|
||||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||||
const SmallVector<Value, 8> &shapes, int64_t non_scalar_idx,
|
const SmallVector<Value, 8> &shapes, ValueRange non_scalars_of_same_shape,
|
||||||
function_ref<void(OpBuilder &, Location)> else_builder_fn) {
|
function_ref<void(OpBuilder &, Location)> else_builder_fn) {
|
||||||
// Materialize predicate: All operands except one are scalars.
|
// Materialize predicate: All operands are scalars, except the expected
|
||||||
|
// non-scalars.
|
||||||
Value one = b.create<ConstantIndexOp>(loc, 1);
|
Value one = b.create<ConstantIndexOp>(loc, 1);
|
||||||
Value all_others_are_scalar;
|
Value all_others_are_scalar;
|
||||||
for (auto it : llvm::enumerate(shapes)) {
|
for (auto it : llvm::zip(op.operands(), shapes)) {
|
||||||
if (it.index() == non_scalar_idx) continue;
|
Value operand, shape;
|
||||||
// For statically known scalars, there is no need to test.
|
std::tie(operand, shape) = it;
|
||||||
if (IsScalarTensorType(op.getOperand(it.index()).getType())) continue;
|
if (llvm::is_contained(non_scalars_of_same_shape, operand) ||
|
||||||
|
IsScalarTensorType(operand.getType())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
auto literal =
|
auto literal =
|
||||||
b.create<CmpIOp>(loc, CmpIPredicate::eq,
|
b.create<CmpIOp>(loc, CmpIPredicate::eq,
|
||||||
b.create<shape::NumElementsOp>(loc, it.value()), one);
|
b.create<shape::NumElementsOp>(loc, shape), one);
|
||||||
all_others_are_scalar =
|
all_others_are_scalar =
|
||||||
all_others_are_scalar
|
all_others_are_scalar
|
||||||
? b.create<mlir::AndOp>(loc, all_others_are_scalar, literal)
|
? b.create<mlir::AndOp>(loc, all_others_are_scalar, literal)
|
||||||
|
@ -399,24 +404,31 @@ Value MaterializeScalarRankSpecializationCase(
|
||||||
auto if_op = b.create<scf::IfOp>(
|
auto if_op = b.create<scf::IfOp>(
|
||||||
loc, op->getResultTypes(), all_others_are_scalar,
|
loc, op->getResultTypes(), all_others_are_scalar,
|
||||||
[&](OpBuilder &b, Location loc) {
|
[&](OpBuilder &b, Location loc) {
|
||||||
// Flatten the non-scalar operand.
|
// Compute flat non-scalar shape.
|
||||||
Value flat_shape = b.create<tensor::FromElementsOp>(
|
SmallVector<Value, 4> non_scalar_shapes;
|
||||||
loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(),
|
for (auto it : llvm::zip(op.operands(), shapes)) {
|
||||||
shapes[non_scalar_idx])
|
Value operand, shape;
|
||||||
.result());
|
std::tie(operand, shape) = it;
|
||||||
Value non_scalar_operand = op.operands()[non_scalar_idx];
|
if (llvm::is_contained(non_scalars_of_same_shape, operand))
|
||||||
Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
|
non_scalar_shapes.push_back(shape);
|
||||||
loc,
|
}
|
||||||
DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
|
Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);
|
||||||
non_scalar_operand, flat_shape);
|
|
||||||
|
|
||||||
// Derive ranked operands.
|
// Derive ranked operands.
|
||||||
auto ranked_operands =
|
auto ranked_operands =
|
||||||
llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
|
llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
|
||||||
if (v == non_scalar_operand) return flat_non_scalar_operand;
|
if (IsScalarTensorType(v.getType())) return v;
|
||||||
|
if (!llvm::is_contained(non_scalars_of_same_shape, v)) {
|
||||||
return b
|
return b
|
||||||
.create<mhlo::ReshapeOp>(
|
.create<mhlo::ReshapeOp>(
|
||||||
loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0), v)
|
loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0),
|
||||||
|
v)
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
.create<mhlo::DynamicReshapeOp>(
|
||||||
|
loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
|
||||||
|
flat_shape)
|
||||||
.getResult();
|
.getResult();
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
@ -464,14 +476,7 @@ Value MaterializeEqualShapesRankSpecializationCase(
|
||||||
loc, op->getResultTypes(), all_shapes_eq_or_scalar,
|
loc, op->getResultTypes(), all_shapes_eq_or_scalar,
|
||||||
[&](OpBuilder &b, Location loc) {
|
[&](OpBuilder &b, Location loc) {
|
||||||
// Flatten non-scalar operands.
|
// Flatten non-scalar operands.
|
||||||
Value shape = non_scalar_shapes.front();
|
Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);
|
||||||
for (Value s : llvm::drop_begin(non_scalar_shapes)) {
|
|
||||||
shape = b.create<shape::AnyOp>(loc, shape.getType(),
|
|
||||||
ValueRange{shape, s});
|
|
||||||
}
|
|
||||||
Value flat_shape = b.create<tensor::FromElementsOp>(
|
|
||||||
loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape)
|
|
||||||
.result());
|
|
||||||
auto flat_operands =
|
auto flat_operands =
|
||||||
llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
|
llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
|
||||||
if (IsScalarTensorType(v.getType())) return v;
|
if (IsScalarTensorType(v.getType())) return v;
|
||||||
|
@ -631,18 +636,15 @@ Value MaterializeDefaultRankSpecializationCases(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
|
SmallVector<Value, 8>
|
||||||
|
MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
|
||||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||||
Value non_scalar_operand) {
|
ValueRange non_scalars_of_same_shape) {
|
||||||
// Flatten the non-scalar operand.
|
// Compute flat operand shape.
|
||||||
Value flat_shape = b.create<tensor::FromElementsOp>(
|
auto non_scalar_shapes = llvm::to_vector<4>(llvm::map_range(
|
||||||
loc, b.create<shape::NumElementsOp>(
|
non_scalars_of_same_shape,
|
||||||
loc, b.getIndexType(),
|
[&](Value v) { return b.create<shape::ShapeOfOp>(loc, v).result(); }));
|
||||||
b.create<shape::ShapeOfOp>(loc, non_scalar_operand))
|
Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);
|
||||||
.result());
|
|
||||||
Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
|
|
||||||
loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
|
|
||||||
non_scalar_operand, flat_shape);
|
|
||||||
|
|
||||||
// Materialize ranked variants for the element-wise operations.
|
// Materialize ranked variants for the element-wise operations.
|
||||||
BlockAndValueMapping bvm;
|
BlockAndValueMapping bvm;
|
||||||
|
@ -650,8 +652,14 @@ SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
|
||||||
Value operand;
|
Value operand;
|
||||||
Value bb_arg;
|
Value bb_arg;
|
||||||
std::tie(bb_arg, operand) = it;
|
std::tie(bb_arg, operand) = it;
|
||||||
bvm.map(bb_arg,
|
if (!IsScalarTensorType(operand.getType())) {
|
||||||
operand == non_scalar_operand ? flat_non_scalar_operand : operand);
|
assert(llvm::is_contained(non_scalars_of_same_shape, operand) &&
|
||||||
|
"Expected all non-scalars in the same shape equivalence class.");
|
||||||
|
operand = b.create<mhlo::DynamicReshapeOp>(
|
||||||
|
loc, DeriveRankedTensorTypes(operand.getType(), /*rank=*/1), operand,
|
||||||
|
flat_shape);
|
||||||
|
}
|
||||||
|
bvm.map(bb_arg, operand);
|
||||||
}
|
}
|
||||||
SmallVector<Value, 8> unshaped_results =
|
SmallVector<Value, 8> unshaped_results =
|
||||||
MaterializeRankedOperations(b, loc, bvm, op);
|
MaterializeRankedOperations(b, loc, bvm, op);
|
||||||
|
@ -660,24 +668,24 @@ SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
|
||||||
return MaterializeFinalReshape(b, loc, op, unshaped_results);
|
return MaterializeFinalReshape(b, loc, op, unshaped_results);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value MaterializeRankSpecializationForTwoNonScalarOperands(
|
Value MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
|
||||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||||
ValueRange non_scalar_operands, int64_t max_target_rank) {
|
SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs,
|
||||||
assert(non_scalar_operands.size() == 2);
|
int64_t max_target_rank) {
|
||||||
|
assert(non_scalar_eqs.size() == 2 &&
|
||||||
|
"Expect two non-scalar equivalence classes.");
|
||||||
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
|
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
|
||||||
return b.create<shape::ShapeOfOp>(loc, v).result();
|
return b.create<shape::ShapeOfOp>(loc, v).result();
|
||||||
}));
|
}));
|
||||||
auto non_scalar_lhs = llvm::find(op.operands(), non_scalar_operands[0]);
|
ValueRange lhs_non_scalar_eqs = non_scalar_eqs[0];
|
||||||
auto non_scalar_rhs = llvm::find(op.operands(), non_scalar_operands[1]);
|
ValueRange rhs_non_scalar_eqs = non_scalar_eqs[1];
|
||||||
|
|
||||||
// Materialize all the different cases.
|
// Materialize all the different cases.
|
||||||
Value unshaped_result = MaterializeScalarRankSpecializationCase(
|
Value unshaped_result = MaterializeScalarRankSpecializationCase(
|
||||||
b, loc, op, shapes, non_scalar_rhs.getIndex(),
|
b, loc, op, shapes, rhs_non_scalar_eqs, [&](OpBuilder &b, Location loc) {
|
||||||
[&](OpBuilder &b, Location loc) {
|
|
||||||
b.create<scf::YieldOp>(
|
b.create<scf::YieldOp>(
|
||||||
loc, MaterializeScalarRankSpecializationCase(
|
loc, MaterializeScalarRankSpecializationCase(
|
||||||
b, loc, op, shapes, non_scalar_lhs.getIndex(),
|
b, loc, op, shapes, lhs_non_scalar_eqs,
|
||||||
[&](OpBuilder &b, Location loc) {
|
[&](OpBuilder &b, Location loc) {
|
||||||
b.create<scf::YieldOp>(
|
b.create<scf::YieldOp>(
|
||||||
loc, MaterializeDefaultRankSpecializationCases(
|
loc, MaterializeDefaultRankSpecializationCases(
|
||||||
|
@ -705,6 +713,54 @@ Value MaterializeDefaultRankSpecialization(OpBuilder &b, Location loc,
|
||||||
return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
|
return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This is a very limited form of shape inference. It is correct but incomplete.
|
||||||
|
SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
|
||||||
|
chlo::RankSpecializationClusterOp op) {
|
||||||
|
llvm::EquivalenceClasses<Value> eqs;
|
||||||
|
|
||||||
|
// Bridge the equivalences between operands and block arguments.
|
||||||
|
for (auto it : llvm::zip(op.operands(), op.getBody()->getArguments()))
|
||||||
|
eqs.unionSets(std::get<0>(it), std::get<1>(it));
|
||||||
|
|
||||||
|
// Find equalities through `SameOperandsAndResultShape` trait.
|
||||||
|
auto union_sets = [&](ValueRange vs) {
|
||||||
|
if (vs.empty()) return;
|
||||||
|
Value repr = vs.front();
|
||||||
|
for (Value v : vs.drop_front()) eqs.unionSets(repr, v);
|
||||||
|
};
|
||||||
|
for (Operation &nested_op : op.getBody()->without_terminator()) {
|
||||||
|
if (nested_op.hasTrait<OpTrait::SameOperandsAndResultShape>()) {
|
||||||
|
union_sets(nested_op.getOperands());
|
||||||
|
union_sets(nested_op.getResults());
|
||||||
|
if (!nested_op.getOperands().empty() && !nested_op.getResults().empty())
|
||||||
|
eqs.unionSets(nested_op.getResult(0), nested_op.getOperand(0));
|
||||||
|
}
|
||||||
|
// TODO(frgossen): Replace this with a check for the appropriate trait when
|
||||||
|
// that is available.
|
||||||
|
if (auto select_op = llvm::dyn_cast<mhlo::SelectOp>(nested_op)) {
|
||||||
|
union_sets(
|
||||||
|
{select_op.on_true(), select_op.on_false(), select_op.getResult()});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to a list-like equivalence class representation.
|
||||||
|
SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs;
|
||||||
|
for (Value v : op.operands()) {
|
||||||
|
if (IsScalarTensorType(v.getType())) continue;
|
||||||
|
bool inserted = false;
|
||||||
|
for (auto &eq_class : non_scalar_eqs) {
|
||||||
|
if (eqs.isEquivalent(eq_class.front(), v)) {
|
||||||
|
eq_class.push_back(v);
|
||||||
|
inserted = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!inserted) non_scalar_eqs.push_back(SmallVector<Value, 4>({v}));
|
||||||
|
}
|
||||||
|
|
||||||
|
return non_scalar_eqs;
|
||||||
|
}
|
||||||
|
|
||||||
struct LowerRankSpecializationClusterPattern
|
struct LowerRankSpecializationClusterPattern
|
||||||
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
|
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
|
||||||
LowerRankSpecializationClusterPattern(MLIRContext *ctx,
|
LowerRankSpecializationClusterPattern(MLIRContext *ctx,
|
||||||
|
@ -719,31 +775,27 @@ struct LowerRankSpecializationClusterPattern
|
||||||
// all operands.
|
// all operands.
|
||||||
if (op.getNumResults() != 1) return failure();
|
if (op.getNumResults() != 1) return failure();
|
||||||
|
|
||||||
// If there is only a single non-scalar operand, we can flatten that operand
|
// If there is only a single non-scalar shape equivalence class, we can
|
||||||
// completely.
|
// flatten that operands completely.
|
||||||
|
SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs =
|
||||||
|
FindNonScalarShapeEquivalences(op);
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto non_scalar_operands =
|
if (non_scalar_eqs.size() == 1) {
|
||||||
llvm::to_vector<2>(llvm::make_filter_range(op.operands(), [](Value v) {
|
rewriter.replaceOp(
|
||||||
return !IsScalarTensorType(v.getType());
|
op,
|
||||||
}));
|
MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
|
||||||
if (non_scalar_operands.size() == 1) {
|
rewriter, loc, op, non_scalar_eqs.front()));
|
||||||
rewriter.replaceOp(op,
|
|
||||||
MaterializeRankSpecializationForSingleNonScalarOperand(
|
|
||||||
rewriter, loc, op, non_scalar_operands.front()));
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there are exactly two unranked operands and all others are known to be
|
// If there are exactly two non-scalar shape equivalence classes, we can
|
||||||
// scalars, we can consider two extra cases: If either of the unranked
|
// consider two extra cases: If either of the operand classes turns out to
|
||||||
// operands turns out to be a scalar at runtime, we can, again, apply the
|
// be all-scalars at runtime, we can, again, flatten all operands.
|
||||||
// trick for a single non-scalar operand.
|
if (non_scalar_eqs.size() == 2) {
|
||||||
if (non_scalar_operands.size() == 2 &&
|
|
||||||
llvm::all_of(non_scalar_operands, [](Value v) {
|
|
||||||
return v.getType().isa<UnrankedTensorType>();
|
|
||||||
})) {
|
|
||||||
rewriter.replaceOp(
|
rewriter.replaceOp(
|
||||||
op, MaterializeRankSpecializationForTwoNonScalarOperands(
|
op,
|
||||||
rewriter, loc, op, non_scalar_operands, max_target_rank));
|
MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
|
||||||
|
rewriter, loc, op, non_scalar_eqs, max_target_rank));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,9 +51,8 @@ func @compare_const_like(%arg0 : tensor<*xf32>) -> tensor<*xi1> {
|
||||||
// CHECK-SCF-DAG: %[[EQ21:.*]] = shape.shape_eq %[[SHAPE_ARG2]], %[[SHAPE_ARG1]]
|
// CHECK-SCF-DAG: %[[EQ21:.*]] = shape.shape_eq %[[SHAPE_ARG2]], %[[SHAPE_ARG1]]
|
||||||
// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = and %[[EQ20]], %[[EQ21]]
|
// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = and %[[EQ20]], %[[EQ21]]
|
||||||
// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]]
|
// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]]
|
||||||
// CHECK-SCF-DAG: %[[S20:.*]] = shape.any %[[SHAPE_ARG2]], %[[SHAPE_ARG0]]
|
// CHECK-SCF-DAG: %[[ANY_SHAPE:.*]] = shape.any %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
|
||||||
// CHECK-SCF-DAG: %[[S201:.*]] = shape.any %[[S20]], %[[SHAPE_ARG1]]
|
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[ANY_SHAPE]]
|
||||||
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S201]]
|
|
||||||
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
|
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
|
||||||
// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
|
// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
|
||||||
// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
|
// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
|
||||||
|
@ -584,3 +583,35 @@ func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
|
||||||
}) : (tensor<*xf64>, tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)
|
}) : (tensor<*xf64>, tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)
|
||||||
return %2 : tensor<*xf64>
|
return %2 : tensor<*xf64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @all_equal_shapes_inferrable
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>)
|
||||||
|
func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
|
||||||
|
-> tensor<*xf64> {
|
||||||
|
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
|
||||||
|
// CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>)
|
||||||
|
// CHECK: %[[INNER_RES:.*]] = mhlo.add %[[ARG0_]], %[[ARG1_]]
|
||||||
|
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]])
|
||||||
|
// CHECK: return %[[RES]]
|
||||||
|
%0 = "mhlo.add"(%arg0, %arg1)
|
||||||
|
: (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
|
||||||
|
return %0 : tensor<*xf64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-SCF-LABEL: @all_equal_shapes_inferrable
|
||||||
|
// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>)
|
||||||
|
// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
|
||||||
|
// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
|
||||||
|
// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S0]], %[[S1]]
|
||||||
|
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]]
|
||||||
|
// CHECK-SCF-DAG: %[[FLAT_S:.*]] = tensor.from_elements %[[N]]
|
||||||
|
// CHECK-SCF-DAG: %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_S]])
|
||||||
|
// CHECK-SCF-DAG: %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_S]])
|
||||||
|
// CHECK-SCF: %[[FLAT_RES:.*]] = mhlo.add %[[FLAT0]], %[[FLAT1]]
|
||||||
|
// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
|
||||||
|
// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
|
||||||
|
// CHECK-SCF-DAG: %[[RES_S:.*]] = shape.broadcast %8, %9
|
||||||
|
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]])
|
||||||
|
// CHECK-SCF: return %[[RES]]
|
||||||
|
|
Loading…
Reference in New Issue