[MLIR][HLO] Find shape equivalences and use them for better rank specialization

Find shape equivalence classes among the operands and use them for better rank
specialization. If all operands are known to have the same shape, we can
flatten them to rank one. If there are exactly two shape equivalence classes
among the non-scalar operands, we can generalize the existing scalar rank
specialization cases from single operands to whole equivalence classes.

PiperOrigin-RevId: 378844575
Authored by A. Unique TensorFlower on 2021-06-11 03:59:02 -07:00; committed by TensorFlow MLIR Team
parent 5cca8a14e3
commit bd5752f0bf
2 changed files with 169 additions and 86 deletions
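To make the grouping mechanics concrete, here is a minimal standalone sketch (not part of the commit) of how `llvm::EquivalenceClasses`, the union-find container behind the new `FindNonScalarShapeEquivalences` helper in the diff below, partitions values. Plain integers stand in for `mlir::Value` operands, and the specific unions are hypothetical:

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // Hypothetical cluster: operands 0 and 1 feed an op with the
  // SameOperandsAndResultShape trait, so their shapes must agree; operand 2
  // is only related to them through a broadcasting op and stays separate.
  llvm::EquivalenceClasses<int> eqs;
  eqs.unionSets(0, 1);
  eqs.insert(2);

  // Membership queries like these are what the pass uses when it converts
  // the union-find structure into per-class operand lists.
  llvm::outs() << "0 ~ 1: " << eqs.isEquivalent(0, 1) << "\n";  // prints 1
  llvm::outs() << "0 ~ 2: " << eqs.isEquivalent(0, 2) << "\n";  // prints 0
  return 0;
}

The same calls are made on `mlir::Value`s in the pass itself, which is why the comment on the `operator<` helper is extended: the container relies on a strict ordering of its elements.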


@@ -14,6 +1,7 @@ limitations under the License.
==============================================================================*/
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
@@ -38,7 +39,8 @@ limitations under the License.
namespace mlir {
/// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
/// Needed to build `llvm::SmallSet`s and `llvm::EquivalenceClasses` of
/// `mlir::Value`s.
static bool operator<(const Value &lhs, const Value &rhs) {
return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
}
@@ -308,18 +310,6 @@ Type DeriveUnrankedTensorTypes(Type ty) {
return ty;
}
Optional<Value> FindUniqueNonScalar(ValueRange values) {
Value unique_non_scalar;
for (Value v : values) {
if (!IsScalarTensorType(v.getType())) {
if (unique_non_scalar) return llvm::None;
unique_non_scalar = v;
}
}
if (!unique_non_scalar) return llvm::None;
return unique_non_scalar;
}
SmallVector<Value, 8> MaterializeRankedOperations(
OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
chlo::RankSpecializationClusterOp op) {
@@ -375,20 +365,35 @@ SmallVector<Value, 8> MaterializeFinalReshape(
}));
}
Value MaterializeFlatShape(OpBuilder &b, Location loc, ValueRange same_shapes) {
assert(!same_shapes.empty() && "Expected at least one shape.");
Value shape = same_shapes.size() == 1
? same_shapes.front()
: b.create<shape::AnyOp>(loc, same_shapes.front().getType(),
same_shapes);
return b.create<tensor::FromElementsOp>(
loc,
b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape).result());
}
Value MaterializeScalarRankSpecializationCase(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
const SmallVector<Value, 8> &shapes, int64_t non_scalar_idx,
const SmallVector<Value, 8> &shapes, ValueRange non_scalars_of_same_shape,
function_ref<void(OpBuilder &, Location)> else_builder_fn) {
// Materialize predicate: All operands except one are scalars.
// Materialize predicate: All operands are scalars, except the expected
// non-scalars.
Value one = b.create<ConstantIndexOp>(loc, 1);
Value all_others_are_scalar;
for (auto it : llvm::enumerate(shapes)) {
if (it.index() == non_scalar_idx) continue;
// For statically known scalars, there is no need to test.
if (IsScalarTensorType(op.getOperand(it.index()).getType())) continue;
for (auto it : llvm::zip(op.operands(), shapes)) {
Value operand, shape;
std::tie(operand, shape) = it;
if (llvm::is_contained(non_scalars_of_same_shape, operand) ||
IsScalarTensorType(operand.getType())) {
continue;
}
auto literal =
b.create<CmpIOp>(loc, CmpIPredicate::eq,
b.create<shape::NumElementsOp>(loc, it.value()), one);
b.create<shape::NumElementsOp>(loc, shape), one);
all_others_are_scalar =
all_others_are_scalar
? b.create<mlir::AndOp>(loc, all_others_are_scalar, literal)
@@ -399,24 +404,31 @@ Value MaterializeScalarRankSpecializationCase(
auto if_op = b.create<scf::IfOp>(
loc, op->getResultTypes(), all_others_are_scalar,
[&](OpBuilder &b, Location loc) {
// Flatten the non-scalar operand.
Value flat_shape = b.create<tensor::FromElementsOp>(
loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(),
shapes[non_scalar_idx])
.result());
Value non_scalar_operand = op.operands()[non_scalar_idx];
Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
loc,
DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
non_scalar_operand, flat_shape);
// Compute flat non-scalar shape.
SmallVector<Value, 4> non_scalar_shapes;
for (auto it : llvm::zip(op.operands(), shapes)) {
Value operand, shape;
std::tie(operand, shape) = it;
if (llvm::is_contained(non_scalars_of_same_shape, operand))
non_scalar_shapes.push_back(shape);
}
Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);
// Derive ranked operands.
auto ranked_operands =
llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
if (v == non_scalar_operand) return flat_non_scalar_operand;
if (IsScalarTensorType(v.getType())) return v;
if (!llvm::is_contained(non_scalars_of_same_shape, v)) {
return b
.create<mhlo::ReshapeOp>(
loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0), v)
loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0),
v)
.getResult();
}
return b
.create<mhlo::DynamicReshapeOp>(
loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
flat_shape)
.getResult();
}));
@@ -464,14 +476,7 @@ Value MaterializeEqualShapesRankSpecializationCase(
loc, op->getResultTypes(), all_shapes_eq_or_scalar,
[&](OpBuilder &b, Location loc) {
// Flatten non-scalar operands.
Value shape = non_scalar_shapes.front();
for (Value s : llvm::drop_begin(non_scalar_shapes)) {
shape = b.create<shape::AnyOp>(loc, shape.getType(),
ValueRange{shape, s});
}
Value flat_shape = b.create<tensor::FromElementsOp>(
loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape)
.result());
Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);
auto flat_operands =
llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
if (IsScalarTensorType(v.getType())) return v;
@@ -631,18 +636,15 @@ Value MaterializeDefaultRankSpecializationCases(
});
}
SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
SmallVector<Value, 8>
MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
Value non_scalar_operand) {
// Flatten the non-scalar operand.
Value flat_shape = b.create<tensor::FromElementsOp>(
loc, b.create<shape::NumElementsOp>(
loc, b.getIndexType(),
b.create<shape::ShapeOfOp>(loc, non_scalar_operand))
.result());
Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
non_scalar_operand, flat_shape);
ValueRange non_scalars_of_same_shape) {
// Compute flat operand shape.
auto non_scalar_shapes = llvm::to_vector<4>(llvm::map_range(
non_scalars_of_same_shape,
[&](Value v) { return b.create<shape::ShapeOfOp>(loc, v).result(); }));
Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);
// Materialize ranked variants for the element-wise operations.
BlockAndValueMapping bvm;
@@ -650,8 +652,14 @@ SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
Value operand;
Value bb_arg;
std::tie(bb_arg, operand) = it;
bvm.map(bb_arg,
operand == non_scalar_operand ? flat_non_scalar_operand : operand);
if (!IsScalarTensorType(operand.getType())) {
assert(llvm::is_contained(non_scalars_of_same_shape, operand) &&
"Expected all non-scalars in the same shape equivalence class.");
operand = b.create<mhlo::DynamicReshapeOp>(
loc, DeriveRankedTensorTypes(operand.getType(), /*rank=*/1), operand,
flat_shape);
}
bvm.map(bb_arg, operand);
}
SmallVector<Value, 8> unshaped_results =
MaterializeRankedOperations(b, loc, bvm, op);
@@ -660,24 +668,24 @@ SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
return MaterializeFinalReshape(b, loc, op, unshaped_results);
}
Value MaterializeRankSpecializationForTwoNonScalarOperands(
Value MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
ValueRange non_scalar_operands, int64_t max_target_rank) {
assert(non_scalar_operands.size() == 2);
SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs,
int64_t max_target_rank) {
assert(non_scalar_eqs.size() == 2 &&
"Expect two non-scalar equivalence classes.");
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
return b.create<shape::ShapeOfOp>(loc, v).result();
}));
auto non_scalar_lhs = llvm::find(op.operands(), non_scalar_operands[0]);
auto non_scalar_rhs = llvm::find(op.operands(), non_scalar_operands[1]);
ValueRange lhs_non_scalar_eqs = non_scalar_eqs[0];
ValueRange rhs_non_scalar_eqs = non_scalar_eqs[1];
// Materialize all the different cases.
Value unshaped_result = MaterializeScalarRankSpecializationCase(
b, loc, op, shapes, non_scalar_rhs.getIndex(),
[&](OpBuilder &b, Location loc) {
b, loc, op, shapes, rhs_non_scalar_eqs, [&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, MaterializeScalarRankSpecializationCase(
b, loc, op, shapes, non_scalar_lhs.getIndex(),
b, loc, op, shapes, lhs_non_scalar_eqs,
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, MaterializeDefaultRankSpecializationCases(
@@ -705,6 +713,54 @@ Value MaterializeDefaultRankSpecialization(OpBuilder &b, Location loc,
return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
}
// This is a very limited form of shape inference. It is correct but incomplete.
SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
chlo::RankSpecializationClusterOp op) {
llvm::EquivalenceClasses<Value> eqs;
// Bridge the equivalences between operands and block arguments.
for (auto it : llvm::zip(op.operands(), op.getBody()->getArguments()))
eqs.unionSets(std::get<0>(it), std::get<1>(it));
// Find equalities through `SameOperandsAndResultShape` trait.
auto union_sets = [&](ValueRange vs) {
if (vs.empty()) return;
Value repr = vs.front();
for (Value v : vs.drop_front()) eqs.unionSets(repr, v);
};
for (Operation &nested_op : op.getBody()->without_terminator()) {
if (nested_op.hasTrait<OpTrait::SameOperandsAndResultShape>()) {
union_sets(nested_op.getOperands());
union_sets(nested_op.getResults());
if (!nested_op.getOperands().empty() && !nested_op.getResults().empty())
eqs.unionSets(nested_op.getResult(0), nested_op.getOperand(0));
}
// TODO(frgossen): Replace this with a check for the appropriate trait when
// that is available.
if (auto select_op = llvm::dyn_cast<mhlo::SelectOp>(nested_op)) {
union_sets(
{select_op.on_true(), select_op.on_false(), select_op.getResult()});
}
}
// Convert to a list-like equivalence class representation.
SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs;
for (Value v : op.operands()) {
if (IsScalarTensorType(v.getType())) continue;
bool inserted = false;
for (auto &eq_class : non_scalar_eqs) {
if (eqs.isEquivalent(eq_class.front(), v)) {
eq_class.push_back(v);
inserted = true;
break;
}
}
if (!inserted) non_scalar_eqs.push_back(SmallVector<Value, 4>({v}));
}
return non_scalar_eqs;
}
struct LowerRankSpecializationClusterPattern
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
LowerRankSpecializationClusterPattern(MLIRContext *ctx,
@@ -719,31 +775,27 @@ struct LowerRankSpecializationClusterPattern
// all operands.
if (op.getNumResults() != 1) return failure();
// If there is only a single non-scalar operand, we can flatten that operand
// completely.
// If there is only a single non-scalar shape equivalence class, we can
// flatten those operands completely.
SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs =
FindNonScalarShapeEquivalences(op);
Location loc = op.getLoc();
auto non_scalar_operands =
llvm::to_vector<2>(llvm::make_filter_range(op.operands(), [](Value v) {
return !IsScalarTensorType(v.getType());
}));
if (non_scalar_operands.size() == 1) {
rewriter.replaceOp(op,
MaterializeRankSpecializationForSingleNonScalarOperand(
rewriter, loc, op, non_scalar_operands.front()));
if (non_scalar_eqs.size() == 1) {
rewriter.replaceOp(
op,
MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
rewriter, loc, op, non_scalar_eqs.front()));
return success();
}
// If there are exactly two unranked operands and all others are known to be
// scalars, we can consider two extra cases: If either of the unranked
// operands turns out to be a scalar at runtime, we can, again, apply the
// trick for a single non-scalar operand.
if (non_scalar_operands.size() == 2 &&
llvm::all_of(non_scalar_operands, [](Value v) {
return v.getType().isa<UnrankedTensorType>();
})) {
// If there are exactly two non-scalar shape equivalence classes, we can
// consider two extra cases: If either of the operand classes turns out to
// be all-scalars at runtime, we can, again, flatten all operands.
if (non_scalar_eqs.size() == 2) {
rewriter.replaceOp(
op, MaterializeRankSpecializationForTwoNonScalarOperands(
rewriter, loc, op, non_scalar_operands, max_target_rank));
op,
MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
rewriter, loc, op, non_scalar_eqs, max_target_rank));
return success();
}


@@ -51,9 +51,8 @@ func @compare_const_like(%arg0 : tensor<*xf32>) -> tensor<*xi1> {
// CHECK-SCF-DAG: %[[EQ21:.*]] = shape.shape_eq %[[SHAPE_ARG2]], %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = and %[[EQ20]], %[[EQ21]]
// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]]
// CHECK-SCF-DAG: %[[S20:.*]] = shape.any %[[SHAPE_ARG2]], %[[SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[S201:.*]] = shape.any %[[S20]], %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S201]]
// CHECK-SCF-DAG: %[[ANY_SHAPE:.*]] = shape.any %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[ANY_SHAPE]]
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
@@ -584,3 +583,35 @@ func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
}) : (tensor<*xf64>, tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)
return %2 : tensor<*xf64>
}
// -----
// CHECK-LABEL: @all_equal_shapes_inferrable
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>)
func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
-> tensor<*xf64> {
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
// CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>)
// CHECK: %[[INNER_RES:.*]] = mhlo.add %[[ARG0_]], %[[ARG1_]]
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]])
// CHECK: return %[[RES]]
%0 = "mhlo.add"(%arg0, %arg1)
: (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
return %0 : tensor<*xf64>
}
// CHECK-SCF-LABEL: @all_equal_shapes_inferrable
// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>)
// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S0]], %[[S1]]
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]]
// CHECK-SCF-DAG: %[[FLAT_S:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF-DAG: %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_S]])
// CHECK-SCF-DAG: %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_S]])
// CHECK-SCF: %[[FLAT_RES:.*]] = mhlo.add %[[FLAT0]], %[[FLAT1]]
// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
// CHECK-SCF-DAG: %[[RES_S:.*]] = shape.broadcast %[[S0]], %[[S1]]
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]])
// CHECK-SCF: return %[[RES]]