[MLIR][HLO] Extend rank specialization clustering pass

Also cluster operations that operate on operands of the same shape; these
implicitly satisfy the broadcasting semantics requirement. Also add test cases
for patterns that appear in currently generated MLIR kernels.

PiperOrigin-RevId: 374191950
A. Unique TensorFlower, 2021-05-17 07:30:45 -07:00, committed by TensorFlow MLIR Team
parent b82bbf4dd1
commit c514c73390
2 changed files with 91 additions and 20 deletions
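At the IR level, the extension means that a purely elementwise operation on same-shaped operands is now clusterable even though it implements no broadcasting. A minimal sketch, taken from the new @angle test case added in this commit:

  %0 = "mhlo.imag"(%arg) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>
  %1 = "mhlo.real"(%arg) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>
  %2 = mhlo.atan2 %0, %1 : tensor<*xf32>

Previously, the binary mhlo.atan2 could not join a cluster because it lacks the CHLO broadcasting traits; with this change, all three operations end up in a single chlo.rank_specialization_cluster.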

File: lib/Dialect/mhlo/transforms/rank_specialization.cc

@@ -27,6 +27,7 @@ limitations under the License.
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -52,35 +53,36 @@ namespace {
 ///   original shape afterwards.
 /// - Broadcasting semantics: All operations must implement broadcasting
 ///   semantics. Most importantly, this allows extending operand shapes such
-///   that they match in rank.
+///   that they match in rank. Operations that require all their operands to
+///   be of the same shape also fulfill this requirement.
 /// - Shape reification: All operations must implement
 ///   `InferShapedTypeOpInterface`. This is later needed to compute and to
 ///   restore the desired result shape.
 bool IsClusterable(Operation *op) {
   if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
-  unsigned int num_operands = op->getNumOperands();
-  if (num_operands == 0) return false;
-  if (num_operands == 1) return op->hasTrait<OpTrait::Elementwise>();
-  return op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
-         op->hasTrait<chlo::OpTrait::Broadcasting>();
+  if (op->getNumOperands() == 0) return false;
+  return (op->hasTrait<OpTrait::Elementwise>() &&
+          op->hasTrait<OpTrait::SameOperandsAndResultShape>()) ||
+         (op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
+          op->hasTrait<chlo::OpTrait::Broadcasting>());
 }
 
 struct RankSpecializationClusterPattern : public RewritePattern {
   explicit RankSpecializationClusterPattern(MLIRContext *ctx)
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
 
-  LogicalResult matchAndRewrite(Operation *root_op,
+  LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     // Only apply to operations that have not been clustered yet.
-    if (root_op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
+    if (op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
       return failure();
     }
 
     // Only cluster when rank specialization is needed.
-    if (!IsClusterable(root_op) ||
-        !llvm::any_of(root_op->getOperandTypes(),
-                      [](Type ty) { return ty.isa<UnrankedTensorType>(); })) {
+    if (!IsClusterable(op) || !llvm::any_of(op->getOperandTypes(), [](Type ty) {
+          return ty.isa<UnrankedTensorType>();
+        })) {
       return failure();
     }
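Note the zero-operand early return: constants such as mhlo.constant are never clusterable themselves, but they can still feed a cluster as operands (the new @relu test below shows the scalar zero captured as a block argument of the cluster). A small sketch of how the new predicate classifies operations, with ops chosen from the tests in this commit:

  %c = mhlo.constant dense<0.0> : tensor<f32>  // not clusterable: zero operands
  %0 = mhlo.atan2 %a, %b : tensor<*xf32>       // clusterable: Elementwise + SameOperandsAndResultShape
  %1 = chlo.broadcast_maximum %c, %a
      : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>  // clusterable: BroadcastingElementwise + Broadcasting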
@@ -88,20 +90,26 @@ struct RankSpecializationClusterPattern : public RewritePattern {
     SmallVector<Operation *, 16> cluster;
     llvm::SmallSet<Value, 16> operand_set;
     llvm::SmallSet<Value, 16> result_set;
-    Operation *new_op = root_op;
-    while (new_op != nullptr && IsClusterable(new_op)) {
+
+    Operation *root_op = op;
+    while (root_op->getNextNode() != nullptr &&
+           IsClusterable(root_op->getNextNode()))
+      root_op = root_op->getNextNode();
+
+    Operation *it = root_op;
+    while (it != nullptr && IsClusterable(it)) {
       // Find results that escape the cluster.
-      for (OpOperand &use : new_op->getUses()) {
+      for (OpOperand &use : it->getUses()) {
         if (!llvm::is_contained(cluster, use.getOwner()))
           result_set.insert(use.get());
       }
 
       // Update cluster operands.
-      for (OpResult v : new_op->getResults()) operand_set.erase(Value(v));
-      for (OpOperand &v : new_op->getOpOperands()) operand_set.insert(v.get());
+      for (OpResult v : it->getResults()) operand_set.erase(Value(v));
+      for (OpOperand &v : it->getOpOperands()) operand_set.insert(v.get());
 
-      cluster.push_back(new_op);
-      new_op = new_op->getPrevNode();
+      cluster.push_back(it);
+      it = it->getPrevNode();
     }
 
     // Create `RankSpecializationClusterOp`.
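The forward walk first advances root_op to the last clusterable op in the block; the backward walk then collects the maximal run of clusterable ops into one cluster, tracking which values cross the cluster boundary. For the @angle test below, the resulting cluster looks roughly as follows (a sketch assembled from the test's CHECK lines, with readable value names):

  %res = "chlo.rank_specialization_cluster"(%arg) ({
  ^bb0(%arg_ : tensor<*xcomplex<f32>>):
    %imag = "mhlo.imag"(%arg_) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>
    %real = "mhlo.real"(%arg_) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>
    %tmp = mhlo.atan2 %imag, %real : tensor<*xf32>
    "chlo.rank_specialization_cluster_yield"(%tmp) : (tensor<*xf32>) -> ()
  }) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>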
@@ -109,7 +117,7 @@ struct RankSpecializationClusterPattern : public RewritePattern {
     auto results = llvm::to_vector<16>(result_set);
     auto result_types = llvm::to_vector<16>(
         llvm::map_range(result_set, [](Value v) { return v.getType(); }));
-    Location loc = root_op->getLoc();
+    Location loc = op->getLoc();
     auto cluster_op = rewriter.create<chlo::RankSpecializationClusterOp>(
         loc, result_types, operands);
@@ -141,7 +149,7 @@ struct RankSpecializationClusterPattern : public RewritePattern {
     }
     auto replacements = llvm::to_vector<16>(llvm::map_range(
         it->getResults(), [&](Value v) { return bvm.lookup(v); }));
-    rewriter.replaceOp(root_op, replacements);
+    rewriter.replaceOp(it, replacements);
   }
 
   return success();

File: tests/rank_specialization.mlir

@@ -126,3 +126,66 @@ func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>)
   %5 = chlo.tan %4 : tensor<*xf32> -> tensor<*xf32>
   return %5 : tensor<*xf32>
 }
+
+// -----
+
+// Constant cluster operand.
+// CHECK-LABEL: @relu
+// CHECK-SAME:  (%[[ARG:.*]]: tensor<*xf32>)
+func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00>
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]], %[[C0]])
+  // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>, %[[C0_:.*]]: tensor<f32>):
+  // CHECK:   %[[TMP:.*]] = chlo.broadcast_maximum %[[ARG_]], %[[C0_]]
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP]])
+  // CHECK: return %[[RES]]
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %1 = chlo.broadcast_maximum %0, %arg
+      : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+}
+
+// -----
+
+// Cluster with binary non-broadcasting operation.
+// CHECK-LABEL: @angle
+// CHECK-SAME:  (%[[ARG:.*]]: tensor<*xcomplex<f32>>)
+func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]])
+  // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xcomplex<f32>>):
+  // CHECK:   %[[IMAG:.*]] = "mhlo.imag"(%[[ARG_]])
+  // CHECK:   %[[REAL:.*]] = "mhlo.real"(%[[ARG_]])
+  // CHECK:   %[[TMP:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]]
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP]])
+  // CHECK: return %[[RES]]
+  %0 = "mhlo.imag"(%arg) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>
+  %1 = "mhlo.real"(%arg) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>
+  %2 = mhlo.atan2 %0, %1 : tensor<*xf32>
+  return %2 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @xlogy
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
+func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00>
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[C0]], %[[ARG0]], %[[ARG1]])
+  // CHECK: ^bb0(%[[C0_:.*]]: tensor<f32>, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
+  // CHECK:   %[[TMP0:.*]] = chlo.broadcast_compare %[[ARG0_]], %[[C0_]] {comparison_direction = "EQ"}
+  // CHECK:   %[[TMP1:.*]] = "mhlo.log"(%[[ARG1_]])
+  // CHECK:   %[[TMP2:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[TMP1]]
+  // CHECK:   %[[TMP3:.*]] = chlo.broadcast_select %[[TMP0]], %[[C0_]], %[[TMP2]]
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP3]])
+  // CHECK: return %[[RES]]
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %1 = tensor.cast %0 : tensor<f32> to tensor<f32>
+  %2 = chlo.broadcast_compare %arg0, %1 {comparison_direction = "EQ"}
+      : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+  %3 = "mhlo.log"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
+  %4 = chlo.broadcast_multiply %arg0, %3
+      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %5 = chlo.broadcast_select %2, %1, %4
+      : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  return %5 : tensor<*xf32>
+}
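For reference, FileCheck tests like these are driven by a RUN line at the top of the test file. The invocation is not part of this diff, and the pass flag below is an assumption rather than something the commit shows:

  // RUN: mlir-hlo-opt %s --mhlo-rank-specialization-cluster | FileCheck %s
  // (flag name assumed; check the pass registration in the repository)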