[MLIR][HLO] Extend rank specialization clustering pass

Also cluster operations that operate on operands of the same shape; these
implicitly satisfy the broadcasting semantics requirement. Also add test cases
for patterns that appear in the current MLIR-generated kernels.

PiperOrigin-RevId: 374191950

Author: A. Unique TensorFlower
Committed by: TensorFlow MLIR Team
Date: 2021-05-17 07:30:45 -07:00
Commit: c514c73390
Parent: b82bbf4dd1

2 changed files with 91 additions and 20 deletions
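For orientation, a minimal before/after sketch of what the clustering pass produces (not part of this commit; the function and op names are chosen for illustration). With this change, an op such as mhlo.atan2, which requires same-shaped operands rather than implementing broadcast semantics, can now be part of a cluster:

// Before: a chain of clusterable ops on unranked operands.
func @sketch(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
  %0 = mhlo.atan2 %a, %b : tensor<*xf32>
  %1 = "mhlo.abs"(%0) : (tensor<*xf32>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

// After (schematic): one rank specialization cluster wraps the whole chain
// and yields the final value.
func @sketch(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
  %res = "chlo.rank_specialization_cluster"(%a, %b) ({
  ^bb0(%a_ : tensor<*xf32>, %b_ : tensor<*xf32>):
    %0 = mhlo.atan2 %a_, %b_ : tensor<*xf32>
    %1 = "mhlo.abs"(%0) : (tensor<*xf32>) -> tensor<*xf32>
    "chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf32>) -> ()
  }) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %res : tensor<*xf32>
}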

@@ -27,6 +27,7 @@ limitations under the License.
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -52,35 +53,36 @@ namespace {
 ///   original shape afterwards.
 /// - Broadcasting semantics: All operations must implement broadcasting
 ///   semantics. Most importantly, this allows extending operand shapes such
-///   that they match in rank.
+///   that they match in rank. Operations that require all their operands to
+///   be of the same shape also fulfill this requirement.
 /// - Shape reification: All operations must implement
 ///   `InferShapedTypeOpInterface`. This is later needed to compute and to
 ///   restore the desired result shape.
 bool IsClusterable(Operation *op) {
   if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
-  unsigned int num_operands = op->getNumOperands();
-  if (num_operands == 0) return false;
-  if (num_operands == 1) return op->hasTrait<OpTrait::Elementwise>();
-  return op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
-         op->hasTrait<chlo::OpTrait::Broadcasting>();
+  if (op->getNumOperands() == 0) return false;
+  return (op->hasTrait<OpTrait::Elementwise>() &&
+          op->hasTrait<OpTrait::SameOperandsAndResultShape>()) ||
+         (op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
+          op->hasTrait<chlo::OpTrait::Broadcasting>());
 }
 
 struct RankSpecializationClusterPattern : public RewritePattern {
   explicit RankSpecializationClusterPattern(MLIRContext *ctx)
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
 
-  LogicalResult matchAndRewrite(Operation *root_op,
+  LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     // Only apply to operations that have not been clustered yet.
-    if (root_op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
+    if (op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
       return failure();
     }
 
     // Only cluster when rank specialization is needed.
-    if (!IsClusterable(root_op) ||
-        !llvm::any_of(root_op->getOperandTypes(),
-                      [](Type ty) { return ty.isa<UnrankedTensorType>(); })) {
+    if (!IsClusterable(op) || !llvm::any_of(op->getOperandTypes(), [](Type ty) {
+          return ty.isa<UnrankedTensorType>();
+        })) {
       return failure();
     }
@@ -88,20 +90,26 @@ struct RankSpecializationClusterPattern : public RewritePattern {
     SmallVector<Operation *, 16> cluster;
     llvm::SmallSet<Value, 16> operand_set;
     llvm::SmallSet<Value, 16> result_set;
-    Operation *new_op = root_op;
-    while (new_op != nullptr && IsClusterable(new_op)) {
+
+    Operation *root_op = op;
+    while (root_op->getNextNode() != nullptr &&
+           IsClusterable(root_op->getNextNode()))
+      root_op = root_op->getNextNode();
+
+    Operation *it = root_op;
+    while (it != nullptr && IsClusterable(it)) {
       // Find results that escape the cluster.
-      for (OpOperand &use : new_op->getUses()) {
+      for (OpOperand &use : it->getUses()) {
         if (!llvm::is_contained(cluster, use.getOwner()))
           result_set.insert(use.get());
       }
 
       // Update cluster operands.
-      for (OpResult v : new_op->getResults()) operand_set.erase(Value(v));
-      for (OpOperand &v : new_op->getOpOperands()) operand_set.insert(v.get());
+      for (OpResult v : it->getResults()) operand_set.erase(Value(v));
+      for (OpOperand &v : it->getOpOperands()) operand_set.insert(v.get());
 
-      cluster.push_back(new_op);
-      new_op = new_op->getPrevNode();
+      cluster.push_back(it);
+      it = it->getPrevNode();
     }
 
     // Create `RankSpecializationClusterOp`.
@@ -109,7 +117,7 @@ struct RankSpecializationClusterPattern : public RewritePattern {
     auto results = llvm::to_vector<16>(result_set);
     auto result_types = llvm::to_vector<16>(
         llvm::map_range(result_set, [](Value v) { return v.getType(); }));
-    Location loc = root_op->getLoc();
+    Location loc = op->getLoc();
     auto cluster_op = rewriter.create<chlo::RankSpecializationClusterOp>(
         loc, result_types, operands);
@@ -141,7 +149,7 @@ struct RankSpecializationClusterPattern : public RewritePattern {
     }
     auto replacements = llvm::to_vector<16>(llvm::map_range(
         it->getResults(), [&](Value v) { return bvm.lookup(v); }));
-    rewriter.replaceOp(root_op, replacements);
+    rewriter.replaceOp(it, replacements);
   }
 
   return success();
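Note on the rewritten traversal above: instead of growing a cluster backwards from whichever op the driver happens to match, the pattern now first walks forward via getNextNode() to the last clusterable op in the block and only then collects the cluster backwards via getPrevNode(). The resulting cluster is therefore maximal and independent of the match order. A sketch of the consequence (op names chosen for illustration; all three are unary, element-wise, and same-shape):

// Matching any of these three ops yields the same single cluster: the
// pattern first advances to %2 (the last clusterable op), then gathers
// %2, %1, %0 back-to-front.
%0 = "mhlo.log"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
%1 = "mhlo.negate"(%0) : (tensor<*xf32>) -> tensor<*xf32>
%2 = "mhlo.exponential"(%1) : (tensor<*xf32>) -> tensor<*xf32>

The test diff below exercises the new clustering behavior on patterns taken from generated kernels.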
@@ -126,3 +126,66 @@ func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>)
   %5 = chlo.tan %4 : tensor<*xf32> -> tensor<*xf32>
   return %5 : tensor<*xf32>
 }
+
+// -----
+
+// Constant cluster operand.
+// CHECK-LABEL: @relu
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00>
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]], %[[C0]])
+  // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>, %[[C0_:.*]]: tensor<f32>):
+  // CHECK:   %[[TMP:.*]] = chlo.broadcast_maximum %[[ARG_]], %[[C0_]]
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP]])
+  // CHECK: return %[[RES]]
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %1 = chlo.broadcast_maximum %0, %arg
+      : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+}
+
+// -----
+
+// Cluster with binary non-broadcasting operation.
+// CHECK-LABEL: @angle
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xcomplex<f32>>)
+func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]])
+  // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xcomplex<f32>>):
+  // CHECK:   %[[IMAG:.*]] = "mhlo.imag"(%[[ARG_]])
+  // CHECK:   %[[REAL:.*]] = "mhlo.real"(%[[ARG_]])
+  // CHECK:   %[[TMP:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]]
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP]])
+  // CHECK: return %[[RES]]
+  %0 = "mhlo.imag"(%arg) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>
+  %1 = "mhlo.real"(%arg) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>
+  %2 = mhlo.atan2 %0, %1 : tensor<*xf32>
+  return %2 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @xlogy
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
+func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00>
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[C0]], %[[ARG0]], %[[ARG1]])
+  // CHECK: ^bb0(%[[C0_:.*]]: tensor<f32>, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
+  // CHECK:   %[[TMP0:.*]] = chlo.broadcast_compare %[[ARG0_]], %[[C0_]] {comparison_direction = "EQ"}
+  // CHECK:   %[[TMP1:.*]] = "mhlo.log"(%[[ARG1_]])
+  // CHECK:   %[[TMP2:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[TMP1]]
+  // CHECK:   %[[TMP3:.*]] = chlo.broadcast_select %[[TMP0]], %[[C0_]], %[[TMP2]]
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP3]])
+  // CHECK: return %[[RES]]
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %1 = tensor.cast %0 : tensor<f32> to tensor<f32>
+  %2 = chlo.broadcast_compare %arg0, %1 {comparison_direction = "EQ"}
+      : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+  %3 = "mhlo.log"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
+  %4 = chlo.broadcast_multiply %arg0, %3
+      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %5 = chlo.broadcast_select %2, %1, %4
+      : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  return %5 : tensor<*xf32>
+}
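The RUN line of this test file sits outside the displayed hunk. Assuming the pass keeps its existing registration under the flag mhlo-rank-specialization-cluster (an assumption; the registration is not shown in this diff), the tests would be driven roughly as:

// RUN: mlir-hlo-opt %s --mhlo-rank-specialization-cluster | FileCheck %s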