[MLIR][HLO] Extend rank specialization clustering pass
Also cluster operations that operate on same shape operands. These implicitly satisfy the broadcasting semantics requirement. Also, add test cases for some cases that appear in the current MLIR-generated kernels. PiperOrigin-RevId: 374191950
This commit is contained in:
parent
b82bbf4dd1
commit
c514c73390
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
|
@ -52,35 +53,36 @@ namespace {
|
|||
/// original shape afterwards.
|
||||
/// - Broadcasting semantics: All operations must implement broadcasting
|
||||
/// semantics. Most importantly, this allows extending operand shapes such
|
||||
/// that they match in rank.
|
||||
/// that they match in rank. Operations that require all their operands to
|
||||
/// be of the same shape also fulfill this requirement.
|
||||
/// - Shape reification: All operations must implement
|
||||
/// `InferShapedTypeOpInterface`. This is later needed to compute and to
|
||||
/// restore the desired result shape.
|
||||
|
||||
bool IsClusterable(Operation *op) {
|
||||
if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
|
||||
unsigned int num_operands = op->getNumOperands();
|
||||
if (num_operands == 0) return false;
|
||||
if (num_operands == 1) return op->hasTrait<OpTrait::Elementwise>();
|
||||
return op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
|
||||
op->hasTrait<chlo::OpTrait::Broadcasting>();
|
||||
if (op->getNumOperands() == 0) return false;
|
||||
return (op->hasTrait<OpTrait::Elementwise>() &&
|
||||
op->hasTrait<OpTrait::SameOperandsAndResultShape>()) ||
|
||||
(op->hasTrait<chlo::OpTrait::BroadcastingElementwise>() &&
|
||||
op->hasTrait<chlo::OpTrait::Broadcasting>());
|
||||
}
|
||||
|
||||
struct RankSpecializationClusterPattern : public RewritePattern {
|
||||
explicit RankSpecializationClusterPattern(MLIRContext *ctx)
|
||||
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *root_op,
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only apply to operations that have not been clustered yet.
|
||||
if (root_op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
|
||||
if (op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Only cluster when rank specialization is needed.
|
||||
if (!IsClusterable(root_op) ||
|
||||
!llvm::any_of(root_op->getOperandTypes(),
|
||||
[](Type ty) { return ty.isa<UnrankedTensorType>(); })) {
|
||||
if (!IsClusterable(op) || !llvm::any_of(op->getOperandTypes(), [](Type ty) {
|
||||
return ty.isa<UnrankedTensorType>();
|
||||
})) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
@ -88,20 +90,26 @@ struct RankSpecializationClusterPattern : public RewritePattern {
|
|||
SmallVector<Operation *, 16> cluster;
|
||||
llvm::SmallSet<Value, 16> operand_set;
|
||||
llvm::SmallSet<Value, 16> result_set;
|
||||
Operation *new_op = root_op;
|
||||
while (new_op != nullptr && IsClusterable(new_op)) {
|
||||
|
||||
Operation *root_op = op;
|
||||
while (root_op->getNextNode() != nullptr &&
|
||||
IsClusterable(root_op->getNextNode()))
|
||||
root_op = root_op->getNextNode();
|
||||
|
||||
Operation *it = root_op;
|
||||
while (it != nullptr && IsClusterable(it)) {
|
||||
// Find results that escape the cluster.
|
||||
for (OpOperand &use : new_op->getUses()) {
|
||||
for (OpOperand &use : it->getUses()) {
|
||||
if (!llvm::is_contained(cluster, use.getOwner()))
|
||||
result_set.insert(use.get());
|
||||
}
|
||||
|
||||
// Update cluster operands.
|
||||
for (OpResult v : new_op->getResults()) operand_set.erase(Value(v));
|
||||
for (OpOperand &v : new_op->getOpOperands()) operand_set.insert(v.get());
|
||||
for (OpResult v : it->getResults()) operand_set.erase(Value(v));
|
||||
for (OpOperand &v : it->getOpOperands()) operand_set.insert(v.get());
|
||||
|
||||
cluster.push_back(new_op);
|
||||
new_op = new_op->getPrevNode();
|
||||
cluster.push_back(it);
|
||||
it = it->getPrevNode();
|
||||
}
|
||||
|
||||
// Create `RankSpecializationClusterOp`.
|
||||
|
@ -109,7 +117,7 @@ struct RankSpecializationClusterPattern : public RewritePattern {
|
|||
auto results = llvm::to_vector<16>(result_set);
|
||||
auto result_types = llvm::to_vector<16>(
|
||||
llvm::map_range(result_set, [](Value v) { return v.getType(); }));
|
||||
Location loc = root_op->getLoc();
|
||||
Location loc = op->getLoc();
|
||||
auto cluster_op = rewriter.create<chlo::RankSpecializationClusterOp>(
|
||||
loc, result_types, operands);
|
||||
|
||||
|
@ -141,7 +149,7 @@ struct RankSpecializationClusterPattern : public RewritePattern {
|
|||
}
|
||||
auto replacements = llvm::to_vector<16>(llvm::map_range(
|
||||
it->getResults(), [&](Value v) { return bvm.lookup(v); }));
|
||||
rewriter.replaceOp(root_op, replacements);
|
||||
rewriter.replaceOp(it, replacements);
|
||||
}
|
||||
|
||||
return success();
|
||||
|
|
|
@ -126,3 +126,66 @@ func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>)
|
|||
%5 = chlo.tan %4 : tensor<*xf32> -> tensor<*xf32>
|
||||
return %5 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Constant cluster operand.
|
||||
// CHECK-LABEL: @relu
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
|
||||
func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00>
|
||||
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]], %[[C0]])
|
||||
// CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>, %[[C0_:.*]]: tensor<f32>):
|
||||
// CHECK: %[[TMP:.*]] = chlo.broadcast_maximum %[[ARG_]], %[[C0_]]
|
||||
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]])
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%1 = chlo.broadcast_maximum %0, %arg
|
||||
: (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %1 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Cluster with binary non-broadcasting operation.
|
||||
// CHECK-LABEL: @angle
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xcomplex<f32>>)
|
||||
func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
|
||||
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]])
|
||||
// CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xcomplex<f32>>):
|
||||
// CHECK: %[[IMAG:.*]] = "mhlo.imag"(%[[ARG_]])
|
||||
// CHECK: %[[REAL:.*]] = "mhlo.real"(%[[ARG_]])
|
||||
// CHECK: %[[TMP:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]]
|
||||
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]])
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = "mhlo.imag"(%arg) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>
|
||||
%1 = "mhlo.real"(%arg) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>
|
||||
%2 = mhlo.atan2 %0, %1 : tensor<*xf32>
|
||||
return %2 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @xlogy
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
|
||||
func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00>
|
||||
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[C0]], %[[ARG0]], %[[ARG1]])
|
||||
// CHECK: ^bb0(%[[C0_:.*]]: tensor<f32>, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
|
||||
// CHECK: %[[TMP0:.*]] = chlo.broadcast_compare %[[ARG0_]], %[[C0_]] {comparison_direction = "EQ"}
|
||||
// CHECK: %[[TMP1:.*]] = "mhlo.log"(%[[ARG1_]])
|
||||
// CHECK: %[[TMP2:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[TMP1]]
|
||||
// CHECK: %[[TMP3:.*]] = chlo.broadcast_select %[[TMP0]], %[[C0_]], %[[TMP2]]
|
||||
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP3]])
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%1 = tensor.cast %0 : tensor<f32> to tensor<f32>
|
||||
%2 = chlo.broadcast_compare %arg0, %1 {comparison_direction = "EQ"}
|
||||
: (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
|
||||
%3 = "mhlo.log"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%4 = chlo.broadcast_multiply %arg0, %3
|
||||
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%5 = chlo.broadcast_select %2, %1, %4
|
||||
: (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %5 : tensor<*xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue