From c514c73390640556f8f242f403bfff57cb23e569 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 17 May 2021 07:30:45 -0700 Subject: [PATCH] [MLIR][HLO] Extend rank specialization clustering pass Also cluster operations that operate on same shape operands. These implicitly satisfy the broadcasting semantics requirement. Also, add test cases for some cases that appear in the current MLIR-generated kernels. PiperOrigin-RevId: 374191950 --- .../mhlo/transforms/rank_specialization.cc | 48 ++++++++------ tests/rank-specialization.mlir | 63 +++++++++++++++++++ 2 files changed, 91 insertions(+), 20 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc index 42d8e19..7f286f9 100644 --- a/lib/Dialect/mhlo/transforms/rank_specialization.cc +++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/InferTypeOpInterface.h" @@ -52,35 +53,36 @@ namespace { /// original shape afterwards. /// - Broadcasting semantics: All operations must implement broadcasting /// semantics. Most importantly, this allows extending operand shapes such -/// that they match in rank. +/// that they match in rank. Operations that require all their operands to +/// be of the same shape also fulfill this requirement. /// - Shape reification: All operations must implement /// `InferShapedTypeOpInterface`. This is later needed to compute and to /// restore the desired result shape. bool IsClusterable(Operation *op) { if (!llvm::isa(op)) return false; - unsigned int num_operands = op->getNumOperands(); - if (num_operands == 0) return false; - if (num_operands == 1) return op->hasTrait(); - return op->hasTrait() && - op->hasTrait(); + if (op->getNumOperands() == 0) return false; + return (op->hasTrait() && + op->hasTrait()) || + (op->hasTrait() && + op->hasTrait()); } struct RankSpecializationClusterPattern : public RewritePattern { explicit RankSpecializationClusterPattern(MLIRContext *ctx) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} - LogicalResult matchAndRewrite(Operation *root_op, + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { // Only apply to operations that have not been clustered yet. - if (root_op->getParentOfType()) { + if (op->getParentOfType()) { return failure(); } // Only cluster when rank specialization is needed. - if (!IsClusterable(root_op) || - !llvm::any_of(root_op->getOperandTypes(), - [](Type ty) { return ty.isa(); })) { + if (!IsClusterable(op) || !llvm::any_of(op->getOperandTypes(), [](Type ty) { + return ty.isa(); + })) { return failure(); } @@ -88,20 +90,26 @@ struct RankSpecializationClusterPattern : public RewritePattern { SmallVector cluster; llvm::SmallSet operand_set; llvm::SmallSet result_set; - Operation *new_op = root_op; - while (new_op != nullptr && IsClusterable(new_op)) { + + Operation *root_op = op; + while (root_op->getNextNode() != nullptr && + IsClusterable(root_op->getNextNode())) + root_op = root_op->getNextNode(); + + Operation *it = root_op; + while (it != nullptr && IsClusterable(it)) { // Find results that escape the cluster. - for (OpOperand &use : new_op->getUses()) { + for (OpOperand &use : it->getUses()) { if (!llvm::is_contained(cluster, use.getOwner())) result_set.insert(use.get()); } // Update cluster operands. - for (OpResult v : new_op->getResults()) operand_set.erase(Value(v)); - for (OpOperand &v : new_op->getOpOperands()) operand_set.insert(v.get()); + for (OpResult v : it->getResults()) operand_set.erase(Value(v)); + for (OpOperand &v : it->getOpOperands()) operand_set.insert(v.get()); - cluster.push_back(new_op); - new_op = new_op->getPrevNode(); + cluster.push_back(it); + it = it->getPrevNode(); } // Create `RankSpecializationClusterOp`. @@ -109,7 +117,7 @@ struct RankSpecializationClusterPattern : public RewritePattern { auto results = llvm::to_vector<16>(result_set); auto result_types = llvm::to_vector<16>( llvm::map_range(result_set, [](Value v) { return v.getType(); })); - Location loc = root_op->getLoc(); + Location loc = op->getLoc(); auto cluster_op = rewriter.create( loc, result_types, operands); @@ -141,7 +149,7 @@ struct RankSpecializationClusterPattern : public RewritePattern { } auto replacements = llvm::to_vector<16>(llvm::map_range( it->getResults(), [&](Value v) { return bvm.lookup(v); })); - rewriter.replaceOp(root_op, replacements); + rewriter.replaceOp(it, replacements); } return success(); diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index 32c3ad2..5850062 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -126,3 +126,66 @@ func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>) %5 = chlo.tan %4 : tensor<*xf32> -> tensor<*xf32> return %5 : tensor<*xf32> } + +// ----- + +// Constant cluster operand. +// CHECK-LABEL: @relu +// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) +func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]], %[[C0]]) + // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>, %[[C0_:.*]]: tensor): + // CHECK: %[[TMP:.*]] = chlo.broadcast_maximum %[[ARG_]], %[[C0_]] + // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]]) + // CHECK: return %[[RES]] + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = chlo.broadcast_maximum %0, %arg + : (tensor, tensor<*xf32>) -> tensor<*xf32> + return %1 : tensor<*xf32> +} + +// ----- + +// Cluster with binary non-broadcasting operation. +// CHECK-LABEL: @angle +// CHECK-SAME: (%[[ARG:.*]]: tensor<*xcomplex>) +func @angle(%arg : tensor<*xcomplex>) -> tensor<*xf32> { + // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) + // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xcomplex>): + // CHECK: %[[IMAG:.*]] = "mhlo.imag"(%[[ARG_]]) + // CHECK: %[[REAL:.*]] = "mhlo.real"(%[[ARG_]]) + // CHECK: %[[TMP:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]] + // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]]) + // CHECK: return %[[RES]] + %0 = "mhlo.imag"(%arg) : (tensor<*xcomplex>) -> tensor<*xf32> + %1 = "mhlo.real"(%arg) : (tensor<*xcomplex>) -> tensor<*xf32> + %2 = mhlo.atan2 %0, %1 : tensor<*xf32> + return %2 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: @xlogy +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) +func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[C0]], %[[ARG0]], %[[ARG1]]) + // CHECK: ^bb0(%[[C0_:.*]]: tensor, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>): + // CHECK: %[[TMP0:.*]] = chlo.broadcast_compare %[[ARG0_]], %[[C0_]] {comparison_direction = "EQ"} + // CHECK: %[[TMP1:.*]] = "mhlo.log"(%[[ARG1_]]) + // CHECK: %[[TMP2:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[TMP1]] + // CHECK: %[[TMP3:.*]] = chlo.broadcast_select %[[TMP0]], %[[C0_]], %[[TMP2]] + // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP3]]) + // CHECK: return %[[RES]] + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = tensor.cast %0 : tensor to tensor + %2 = chlo.broadcast_compare %arg0, %1 {comparison_direction = "EQ"} + : (tensor<*xf32>, tensor) -> tensor<*xi1> + %3 = "mhlo.log"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> + %4 = chlo.broadcast_multiply %arg0, %3 + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %5 = chlo.broadcast_select %2, %1, %4 + : (tensor<*xi1>, tensor, tensor<*xf32>) -> tensor<*xf32> + return %5 : tensor<*xf32> +}