[MLIR][HLO] Add `rank_specialization_cluster` op to CHLO

The operation will be used to cluster compatible operations that can be rank-
specialized collectively.

PiperOrigin-RevId: 373128557
This commit is contained in:
A. Unique TensorFlower 2021-05-11 05:17:01 -07:00 committed by TensorFlow MLIR Team
parent 86b7eb434c
commit 96a47345cc
5 changed files with 120 additions and 0 deletions

1
BUILD
View File

@ -426,6 +426,7 @@ cc_library(
":infer_fusibility_op_interface",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Pass",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -30,6 +30,7 @@ limitations under the License.
#define CHLO_OPS
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
@ -754,4 +755,54 @@ def HLOClient_MinimumBroadcastShapesOp :
}
// Region-holding op that groups element-wise operations so that they can be
// rank-specialized as a unit. The single-block body is terminated by
// `RankSpecializationClusterYieldOp`.
def HLOClient_RankSpecializationClusterOp
    : HLOClient_Op<"rank_specialization_cluster", [
        SingleBlockImplicitTerminator<"RankSpecializationClusterYieldOp">,
        RecursiveSideEffects]> {
  let summary = "Cluster of operations that will be rank-specialized together.";

  let description = [{
    Groups compatible element-wise operations together so that they can be
    rank-specialized together. The operation takes a variadic number of
    (unranked) tensor operands and produces a variadic number of tensor
    results. Its body region holds one block with one block argument per input
    tensor of the same type. All operations in this block must only operate on
    these block arguments or on values defined within the block; values from
    the enclosing scope must not be used implicitly. Results are returned
    through the `rank_specialization_cluster_yield` operation.

    Example:

    ```
    %0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
    ^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
      %1 = chlo.broadcast_multiply %arg0_, %arg1_
          : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
      %2 = chlo.broadcast_add %1, %arg2_
          : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
      "chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> ()
    }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
    ```
  }];

  // Inputs forwarded to the body as block arguments (one per operand).
  let arguments = (ins Variadic<HLO_Tensor>:$operands);
  // Values produced by the cluster; supplied by the yield terminator.
  let results = (outs Variadic<HLO_Tensor>:$results);
  // Exactly one block holding the clustered operations.
  let regions = (region SizedRegion<1>:$body);
}
// Terminator of the `rank_specialization_cluster` body. The trait list below
// marks it as side-effect free, return-like, a terminator, and restricts its
// parent to `RankSpecializationClusterOp`.
def HLOClient_RankSpecializationClusterYieldOp
: HLOClient_Op<"rank_specialization_cluster_yield", [NoSideEffect,
ReturnLike, Terminator, HasParent<"RankSpecializationClusterOp">]> {
let summary = "Yield operation for `rank_specialization_cluster`";
let description = [{
This operation yields the results from within the
`chlo.rank_specialization_cluster` operation's region. The operation takes
an arbitrary number of operands and produces no results. The operand number
and types must match the number and types of the parent
`rank_specialization_cluster` operation's results.
}];
// The yielded tensors; named `results` because they become the results of the
// enclosing cluster op. The op itself declares no results.
let arguments = (ins Variadic<HLO_Tensor>:$results);
}
#endif // CHLO_OPS

View File

@ -418,6 +418,31 @@ LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
return success();
}
//===----------------------------------------------------------------------===//
// RankSpecializationClusterOp
//===----------------------------------------------------------------------===//
// Verifies the structural invariants of a `rank_specialization_cluster` op.
//
// Checked in this order; the first violation emits an op error and fails:
//   1. The body's block argument types equal the op's operand types.
//   2. No non-terminator op in the body uses a value that is neither a body
//      block argument nor the result of an op located directly in the body.
//   3. The terminator's operand types equal the op's result types, as required
//      by the `rank_specialization_cluster_yield` description.
//
// Returns success() when all invariants hold.
static LogicalResult Verify(RankSpecializationClusterOp op) {
  if (op.body().getArgumentTypes() != op.getOperandTypes())
    return op.emitOpError() << "block argument types must match operand types";

  // All operands of nested ops must be defined in the body or declared by the
  // cluster.
  Block* body = op.getBody();
  for (Operation& nested : body->without_terminator()) {
    if (!llvm::all_of(nested.getOpOperands(), [&](OpOperand& operand) {
          // Results of ops that live directly in the body are fine.
          Operation* def = operand.get().getDefiningOp();
          if (def != nullptr && def->getBlock() == body) return true;
          // Otherwise the value must be one of the body's block arguments.
          return llvm::is_contained(body->getArguments(), operand.get());
        })) {
      return op.emitOpError()
             << "nested ops must not depend on implicit operands";
    }
  }

  // The yielded values become the cluster's results, so their number and
  // types must agree with the declared result types. This runs last so that
  // the diagnostics above keep their precedence.
  Operation* yield = body->getTerminator();
  if (yield->getOperandTypes() != op.getOperation()->getResultTypes())
    return op.emitOpError() << "yielded types must match result types";

  return success();
}
} // namespace chlo
} // namespace mlir

View File

@ -24,3 +24,45 @@ func @minimum_broadcast_shapes_one_operand(%arg: tensor<?xindex>) {
%0 = chlo.minimum_broadcast_shapes %arg : tensor<?xindex> -> tensor<?xindex>
return
}
// -----
// Well-formed cluster: three operands, three matching block arguments, and
// nested ops that only use block arguments or values defined in the body.
// No diagnostic is expected for this case.
func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
    %arg2 : tensor<*xf32>) -> tensor<*xf32> {
  %cluster = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
  ^bb0(%lhs : tensor<*xf32>, %rhs : tensor<*xf32>, %addend : tensor<*xf32>):
    %product = chlo.broadcast_multiply %lhs, %rhs
        : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
    %sum = chlo.broadcast_add %product, %addend
        : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
    "chlo.rank_specialization_cluster_yield"(%sum) : (tensor<*xf32>) -> ()
  }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %cluster : tensor<*xf32>
}
// -----
// Negative case: the cluster takes one operand but its body block declares two
// block arguments, so the block argument types do not match the operand types.
func @rank_specialization_cluster(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
// The annotation below must stay on the line directly above the cluster op.
// expected-error @+1{{block argument types must match operand types}}
%0 = "chlo.rank_specialization_cluster"(%arg0) ({
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>):
"chlo.rank_specialization_cluster_yield"(%arg0_) : (tensor<*xf32>) -> ()
}) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Negative case: a nested op reads a value from the enclosing function instead
// of the corresponding block argument.
func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
%arg2 : tensor<*xf32>) -> tensor<*xf32> {
// The annotation below must stay on the line directly above the cluster op.
// expected-error @+1{{nested ops must not depend on implicit operands}}
%0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
%1 = chlo.broadcast_multiply %arg0_, %arg1_
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// Offending use: `%arg2` is the function argument, not block argument `%arg2_`.
%2 = chlo.broadcast_add %1, %arg2
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
"chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> ()
}) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}