diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc
index 42d8e19..7f286f9 100644
--- a/lib/Dialect/mhlo/transforms/rank_specialization.cc
+++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -52,35 +53,36 @@ namespace {
 /// original shape afterwards.
 /// - Broadcasting semantics: All operations must implement broadcasting
 ///   semantics. Most importantly, this allows extending operand shapes such
-///   that they match in rank.
+///   that they match in rank. Operations that require all their operands to
+///   be of the same shape also fulfill this requirement.
 /// - Shape reification: All operations must implement
 ///   `InferShapedTypeOpInterface`. This is later needed to compute and to
 ///   restore the desired result shape.
 bool IsClusterable(Operation *op) {
   if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
-  unsigned int num_operands = op->getNumOperands();
-  if (num_operands == 0) return false;
-  if (num_operands == 1) return op->hasTrait();
-  return op->hasTrait() &&
-         op->hasTrait();
+  if (op->getNumOperands() == 0) return false;
+  return (op->hasTrait() &&
+          op->hasTrait()) ||
+         (op->hasTrait() &&
+          op->hasTrait());
 }
 
 struct RankSpecializationClusterPattern : public RewritePattern {
   explicit RankSpecializationClusterPattern(MLIRContext *ctx)
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
 
-  LogicalResult matchAndRewrite(Operation *root_op,
+  LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     // Only apply to operations that have not been clustered yet.
-    if (root_op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
+    if (op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
       return failure();
     }
 
     // Only cluster when rank specialization is needed.
-    if (!IsClusterable(root_op) ||
-        !llvm::any_of(root_op->getOperandTypes(),
-                      [](Type ty) { return ty.isa<UnrankedTensorType>(); })) {
+    if (!IsClusterable(op) || !llvm::any_of(op->getOperandTypes(), [](Type ty) {
+          return ty.isa<UnrankedTensorType>();
+        })) {
       return failure();
     }
 
@@ -88,20 +90,26 @@ struct RankSpecializationClusterPattern : public RewritePattern {
     SmallVector cluster;
     llvm::SmallSet operand_set;
     llvm::SmallSet result_set;
-    Operation *new_op = root_op;
-    while (new_op != nullptr && IsClusterable(new_op)) {
+
+    Operation *root_op = op;
+    while (root_op->getNextNode() != nullptr &&
+           IsClusterable(root_op->getNextNode()))
+      root_op = root_op->getNextNode();
+
+    Operation *it = root_op;
+    while (it != nullptr && IsClusterable(it)) {
       // Find results that escape the cluster.
-      for (OpOperand &use : new_op->getUses()) {
+      for (OpOperand &use : it->getUses()) {
         if (!llvm::is_contained(cluster, use.getOwner()))
           result_set.insert(use.get());
       }
 
       // Update cluster operands.
-      for (OpResult v : new_op->getResults()) operand_set.erase(Value(v));
-      for (OpOperand &v : new_op->getOpOperands()) operand_set.insert(v.get());
+      for (OpResult v : it->getResults()) operand_set.erase(Value(v));
+      for (OpOperand &v : it->getOpOperands()) operand_set.insert(v.get());
 
-      cluster.push_back(new_op);
-      new_op = new_op->getPrevNode();
+      cluster.push_back(it);
+      it = it->getPrevNode();
     }
 
     // Create `RankSpecializationClusterOp`.
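
As a reading aid for the two hunks above, here is a minimal, MLIR-free model of the clustering walk: starting from the op the pattern matched, it first advances to the last adjacent clusterable op (the new cluster root) and then collects the cluster backwards while tracking which values enter the cluster and which results escape it. The Op struct, the integer value ids, and the example block are all invented for this sketch; the real pattern walks Operation* chains via getNextNode()/getPrevNode() and collects SSA Values exactly as shown in the diff.

// rank_specialization_walk_sketch.cc -- illustration only, not part of the patch.
#include <algorithm>
#include <cstdio>
#include <set>
#include <vector>

// Stand-in for an operation in a straight-line block. `clusterable` plays the
// role of IsClusterable(op); values are plain integers (function arguments use
// ids >= 100 so no op in the block produces them).
struct Op {
  bool clusterable;
  std::vector<int> operands;
  std::vector<int> results;
};

int main() {
  // Loosely models the @xlogy test below: %0 = constant; %1 = compare(arg0, %0);
  // %2 = log(arg1); %3 = multiply(arg0, %2); %4 = select(%1, %0, %3); return %4.
  std::vector<Op> block = {
      {false, {}, {0}},        // constant: zero operands, never clustered
      {true, {100, 0}, {1}},   // compare(arg0, %0)
      {true, {101}, {2}},      // log(arg1)
      {true, {100, 2}, {3}},   // multiply(arg0, %2)
      {true, {1, 0, 3}, {4}},  // select(%1, %0, %3)
      {false, {4}, {}},        // return %4: not clusterable
  };

  // Phase 1: advance from the matched op to the last adjacent clusterable op;
  // that op becomes the cluster root (mirrors the getNextNode() loop).
  std::size_t matched = 1;  // pretend the pattern fired on the compare
  std::size_t root = matched;
  while (root + 1 < block.size() && block[root + 1].clusterable) ++root;

  // Phase 2: walk backwards from the root (mirrors the getPrevNode() loop),
  // collecting members, external operands, and escaping results.
  std::vector<std::size_t> cluster;
  std::set<int> operand_set, result_set;
  for (std::size_t i = root + 1; i-- > 0 && block[i].clusterable;) {
    // A result escapes if some op outside the current cluster uses it.
    for (std::size_t user = 0; user < block.size(); ++user) {
      if (std::find(cluster.begin(), cluster.end(), user) != cluster.end())
        continue;
      for (int v : block[i].results)
        if (std::count(block[user].operands.begin(),
                       block[user].operands.end(), v))
          result_set.insert(v);
    }
    // Values now produced inside the cluster stop being external operands.
    for (int v : block[i].results) operand_set.erase(v);
    for (int v : block[i].operands) operand_set.insert(v);
    cluster.push_back(i);
  }

  for (int v : operand_set) std::printf("cluster operand: %d\n", v);
  for (int v : result_set) std::printf("escaping result: %d\n", v);
  return 0;
}

On this input the model reports the constant value and the two function arguments as cluster operands and the select result as the only escaping result, which matches the cluster interface the @xlogy test below checks for.
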
@@ -109,7 +117,7 @@ struct RankSpecializationClusterPattern : public RewritePattern {
     auto results = llvm::to_vector<16>(result_set);
     auto result_types = llvm::to_vector<16>(
         llvm::map_range(result_set, [](Value v) { return v.getType(); }));
-    Location loc = root_op->getLoc();
+    Location loc = op->getLoc();
     auto cluster_op = rewriter.create<chlo::RankSpecializationClusterOp>(
         loc, result_types, operands);
 
@@ -141,7 +149,7 @@ struct RankSpecializationClusterPattern : public RewritePattern {
       }
       auto replacements = llvm::to_vector<16>(llvm::map_range(
           it->getResults(), [&](Value v) { return bvm.lookup(v); }));
-      rewriter.replaceOp(root_op, replacements);
+      rewriter.replaceOp(it, replacements);
     }
 
     return success();
diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir
index 32c3ad2..5850062 100644
--- a/tests/rank-specialization.mlir
+++ b/tests/rank-specialization.mlir
@@ -126,3 +126,66 @@ func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>)
   %5 = chlo.tan %4 : tensor<*xf32> -> tensor<*xf32>
   return %5 : tensor<*xf32>
 }
+
+// -----
+
+// Constant cluster operand.
+// CHECK-LABEL: @relu
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @relu(%arg : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00>
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]], %[[C0]])
+  // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>, %[[C0_:.*]]: tensor<f32>):
+  // CHECK: %[[TMP:.*]] = chlo.broadcast_maximum %[[ARG_]], %[[C0_]]
+  // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]])
+  // CHECK: return %[[RES]]
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %1 = chlo.broadcast_maximum %0, %arg
+      : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+}
+
+// -----
+
+// Cluster with binary non-broadcasting operation.
+// CHECK-LABEL: @angle
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xcomplex<f32>>)
+func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]])
+  // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xcomplex<f32>>):
+  // CHECK: %[[IMAG:.*]] = "mhlo.imag"(%[[ARG_]])
+  // CHECK: %[[REAL:.*]] = "mhlo.real"(%[[ARG_]])
+  // CHECK: %[[TMP:.*]] = mhlo.atan2 %[[IMAG]], %[[REAL]]
+  // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]])
+  // CHECK: return %[[RES]]
+  %0 = "mhlo.imag"(%arg) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>
+  %1 = "mhlo.real"(%arg) : (tensor<*xcomplex<f32>>) -> tensor<*xf32>
+  %2 = mhlo.atan2 %0, %1 : tensor<*xf32>
+  return %2 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @xlogy
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
+func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[C0:.*]] = mhlo.constant dense<0.000000e+00>
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[C0]], %[[ARG0]], %[[ARG1]])
+  // CHECK: ^bb0(%[[C0_:.*]]: tensor<f32>, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
+  // CHECK: %[[TMP0:.*]] = chlo.broadcast_compare %[[ARG0_]], %[[C0_]] {comparison_direction = "EQ"}
+  // CHECK: %[[TMP1:.*]] = "mhlo.log"(%[[ARG1_]])
+  // CHECK: %[[TMP2:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[TMP1]]
+  // CHECK: %[[TMP3:.*]] = chlo.broadcast_select %[[TMP0]], %[[C0_]], %[[TMP2]]
+  // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP3]])
+  // CHECK: return %[[RES]]
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %1 = tensor.cast %0 : tensor<f32> to tensor<f32>
+  %2 = chlo.broadcast_compare %arg0, %1 {comparison_direction = "EQ"}
+      : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
+  %3 = "mhlo.log"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
+  %4 = chlo.broadcast_multiply %arg0, %3
+      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %5 = chlo.broadcast_select %2, %1, %4
+      : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+  return %5 : tensor<*xf32>
+}
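
For context on how a RewritePattern such as RankSpecializationClusterPattern is typically driven so that output like the FileCheck expectations above is produced, here is a rough sketch of a greedy-rewrite harness. The pass name and its wiring are hypothetical, not the pass actually defined in rank_specialization.cc; only standard MLIR APIs (PassWrapper, RewritePatternSet, applyPatternsAndFoldGreedily) are assumed, and the sketch is assumed to live in the same file as the pattern from the diff.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
// Hypothetical, illustration-only pass; the real pass registration in the
// mlir-hlo repository may differ in name, anchor op, and options.
struct RankSpecializationClusterSketchPass
    : public mlir::PassWrapper<RankSpecializationClusterSketchPass,
                               mlir::OperationPass<mlir::FuncOp>> {
  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);
    // RankSpecializationClusterPattern is the pattern from the diff above.
    patterns.add<RankSpecializationClusterPattern>(ctx);
    // The greedy driver applies the pattern and folds trivially foldable ops
    // until a fixed point is reached.
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                        std::move(patterns))))
      signalPassFailure();
  }
};
}  // namespace

Anchoring such a pass on FuncOp keeps the rewrite local to one function at a time, which lines up with the per-function test cases in rank-specialization.mlir.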