[MLIR][HLO] Add `rank_specialization_cluster` op to CHLO
The operation will be used to cluster compatible operations that can be rank- specialized collectively. PiperOrigin-RevId: 373128557
This commit is contained in:
parent
86b7eb434c
commit
96a47345cc
1
BUILD
1
BUILD
|
@ -426,6 +426,7 @@ cc_library(
|
|||
":infer_fusibility_op_interface",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:ControlFlowInterfaces",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:InferTypeOpInterface",
|
||||
"@llvm-project//mlir:Pass",
|
||||
|
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||
#define CHLO_OPS
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||
|
@ -754,4 +755,54 @@ def HLOClient_MinimumBroadcastShapesOp :
|
|||
|
||||
}
|
||||
|
||||
def HLOClient_RankSpecializationClusterOp
|
||||
: HLOClient_Op<"rank_specialization_cluster", [
|
||||
SingleBlockImplicitTerminator<"RankSpecializationClusterYieldOp">,
|
||||
RecursiveSideEffects]> {
|
||||
|
||||
let summary = "Cluster of operations that will be rank-specialized together.";
|
||||
|
||||
let description = [{
|
||||
Groups compatible element-wise operatons together so that they can be
|
||||
rank-specialized together. The operation takes and yields a variadic number
|
||||
of (unranked) tensor operands. Its body region holds one block with one
|
||||
block argument per input tensor of the same type. All operations in this
|
||||
block must only operate on these block arguments. Results are returned
|
||||
through the `rank_specialization_cluster_yield` operation.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
%0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
|
||||
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
|
||||
%1 = chlo.broadcast_multiply %arg0_, %arg1_
|
||||
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%2 = chlo.broadcast_add %1, %arg2_
|
||||
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
"chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<HLO_Tensor>:$operands);
|
||||
let results = (outs Variadic<HLO_Tensor>:$results);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
}
|
||||
|
||||
def HLOClient_RankSpecializationClusterYieldOp
|
||||
: HLOClient_Op<"rank_specialization_cluster_yield", [NoSideEffect,
|
||||
ReturnLike, Terminator, HasParent<"RankSpecializationClusterOp">]> {
|
||||
|
||||
let summary = "Yield operation for `rank_specialization_cluster`";
|
||||
let description = [{
|
||||
This operation yields the results from within the
|
||||
`chlo.rank_specialization_cluster` operation's region. The operation takes
|
||||
an arbitrary number of operands and produces no results. The operand number
|
||||
and types must match the number and types of the parent
|
||||
`rank_specialization_cluster` operation's results.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<HLO_Tensor>:$results);
|
||||
}
|
||||
|
||||
#endif // CHLO_OPS
|
||||
|
|
|
@ -418,6 +418,31 @@ LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RankSpecializationClusterOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(RankSpecializationClusterOp op) {
|
||||
if (op.body().getArgumentTypes() != op.getOperandTypes())
|
||||
return op.emitOpError() << "block argument types must match operand types";
|
||||
|
||||
// All operands of nested ops must be defined in the body or declared by the
|
||||
// cluster.
|
||||
Block* body = op.getBody();
|
||||
for (Operation& nested : body->without_terminator()) {
|
||||
if (!llvm::all_of(nested.getOpOperands(), [&](OpOperand& operand) {
|
||||
Operation* def = operand.get().getDefiningOp();
|
||||
if (def != nullptr && def->getBlock() == body) return true;
|
||||
return llvm::is_contained(body->getArguments(), operand.get());
|
||||
})) {
|
||||
return op.emitOpError()
|
||||
<< "nested ops must not depend on implicit operands";
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace chlo
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -24,3 +24,45 @@ func @minimum_broadcast_shapes_one_operand(%arg: tensor<?xindex>) {
|
|||
%0 = chlo.minimum_broadcast_shapes %arg : tensor<?xindex> -> tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
|
||||
%arg2 : tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
|
||||
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
|
||||
%1 = chlo.broadcast_multiply %arg0_, %arg1_
|
||||
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%2 = chlo.broadcast_add %1, %arg2_
|
||||
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
"chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @rank_specialization_cluster(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
|
||||
// expected-error @+1{{block argument types must match operand types}}
|
||||
%0 = "chlo.rank_specialization_cluster"(%arg0) ({
|
||||
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>):
|
||||
"chlo.rank_specialization_cluster_yield"(%arg0_) : (tensor<*xf32>) -> ()
|
||||
}) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
|
||||
%arg2 : tensor<*xf32>) -> tensor<*xf32> {
|
||||
// expected-error @+1{{nested ops must not depend on implicit operands}}
|
||||
%0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
|
||||
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
|
||||
%1 = chlo.broadcast_multiply %arg0_, %arg1_
|
||||
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%2 = chlo.broadcast_add %1, %arg2
|
||||
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
"chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue