[MLIR][HLO] Add `rank_specialization_cluster` op to CHLO
The operation will be used to cluster compatible operations that can be rank-specialized collectively.

PiperOrigin-RevId: 373128557
commit 96a47345cc
parent 86b7eb434c
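For illustration, an informal sketch derived from the example added in this change (the unclustered form and the SSA names are illustrative, not additional code in the commit): two compatible element-wise CHLO ops on unranked tensors, first standing alone and then grouped into a single `chlo.rank_specialization_cluster` so they can be rank-specialized as one unit.

```mlir
// Unclustered: each op would otherwise be rank-specialized on its own.
%0 = chlo.broadcast_multiply %arg0, %arg1
    : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%1 = chlo.broadcast_add %0, %arg2
    : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>

// Clustered: the same computation inside one cluster region, as in the
// example added to the op documentation below.
%2 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
  %3 = chlo.broadcast_multiply %arg0_, %arg1_
      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  %4 = chlo.broadcast_add %3, %arg2_
      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  "chlo.rank_specialization_cluster_yield"(%4) : (tensor<*xf32>) -> ()
}) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
```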
BUILD
@@ -426,6 +426,7 @@ cc_library(
         ":infer_fusibility_op_interface",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:ControlFlowInterfaces",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:InferTypeOpInterface",
         "@llvm-project//mlir:Pass",
@@ -26,6 +26,7 @@ limitations under the License.
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -30,6 +30,7 @@ limitations under the License.
 #define CHLO_OPS

 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
@@ -754,4 +755,54 @@ def HLOClient_MinimumBroadcastShapesOp :
 }

+def HLOClient_RankSpecializationClusterOp
+    : HLOClient_Op<"rank_specialization_cluster", [
+        SingleBlockImplicitTerminator<"RankSpecializationClusterYieldOp">,
+        RecursiveSideEffects]> {
+
+  let summary = "Cluster of operations that will be rank-specialized together.";
+
+  let description = [{
+    Groups compatible element-wise operations together so that they can be
+    rank-specialized together. The operation takes and yields a variadic number
+    of (unranked) tensor operands. Its body region holds one block with one
+    block argument per input tensor of the same type. All operations in this
+    block must only operate on these block arguments. Results are returned
+    through the `rank_specialization_cluster_yield` operation.
+
+    Example:
+
+    ```
+    %0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
+    ^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
+      %1 = chlo.broadcast_multiply %arg0_, %arg1_
+          : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+      %2 = chlo.broadcast_add %1, %arg2_
+          : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+      "chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> ()
+    }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    ```
+  }];
+
+  let arguments = (ins Variadic<HLO_Tensor>:$operands);
+  let results = (outs Variadic<HLO_Tensor>:$results);
+  let regions = (region SizedRegion<1>:$body);
+}
+
+def HLOClient_RankSpecializationClusterYieldOp
+    : HLOClient_Op<"rank_specialization_cluster_yield", [NoSideEffect,
+        ReturnLike, Terminator, HasParent<"RankSpecializationClusterOp">]> {
+
+  let summary = "Yield operation for `rank_specialization_cluster`";
+  let description = [{
+    This operation yields the results from within the
+    `chlo.rank_specialization_cluster` operation's region. The operation takes
+    an arbitrary number of operands and produces no results. The operand number
+    and types must match the number and types of the parent
+    `rank_specialization_cluster` operation's results.
+  }];
+
+  let arguments = (ins Variadic<HLO_Tensor>:$results);
+}
+
 #endif  // CHLO_OPS
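Not part of this commit: a hypothetical sketch of how a later clustering pass could materialize such an op using the ODS-generated builders. The helper name `CreateEmptyCluster` and the header path are assumptions; the accessor spelling (`body()`) follows the verifier code added in this change, and exact builder signatures may differ across MLIR versions.

```cpp
// Hypothetical sketch, assuming the default ODS-generated builders of this
// MLIR era; not code from this commit.
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"  // assumed header path
#include "mlir/IR/Builders.h"

namespace mlir {
namespace chlo {

RankSpecializationClusterOp CreateEmptyCluster(OpBuilder &b, Location loc,
                                               TypeRange result_types,
                                               ValueRange operands) {
  // The generated builder creates the op with an (empty) `body` region.
  auto cluster =
      b.create<RankSpecializationClusterOp>(loc, result_types, operands);
  // Give the body one block argument per operand, of the same type, as the
  // verifier added in this change requires.
  b.createBlock(&cluster.body(), /*insertPt=*/{}, TypeRange(operands));
  // The caller is then expected to clone the clustered ops into
  // cluster.body().front(), remap their operands to the block arguments, and
  // terminate the block with chlo.rank_specialization_cluster_yield.
  return cluster;
}

}  // namespace chlo
}  // namespace mlir
```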
@@ -418,6 +418,31 @@ LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
   return success();
 }

+//===----------------------------------------------------------------------===//
+// RankSpecializationClusterOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(RankSpecializationClusterOp op) {
+  if (op.body().getArgumentTypes() != op.getOperandTypes())
+    return op.emitOpError() << "block argument types must match operand types";
+
+  // All operands of nested ops must be defined in the body or declared by the
+  // cluster.
+  Block* body = op.getBody();
+  for (Operation& nested : body->without_terminator()) {
+    if (!llvm::all_of(nested.getOpOperands(), [&](OpOperand& operand) {
+          Operation* def = operand.get().getDefiningOp();
+          if (def != nullptr && def->getBlock() == body) return true;
+          return llvm::is_contained(body->getArguments(), operand.get());
+        })) {
+      return op.emitOpError()
+             << "nested ops must not depend on implicit operands";
+    }
+  }
+
+  return success();
+}
+
 }  // namespace chlo
 }  // namespace mlir
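Not shown in this diff: in this generation of ODS, a free `Verify(OpType)` function like the one above is normally reached through the `verifier` field, either set per op or inherited from the dialect's op base class. A hedged sketch of what that hook typically looks like (an assumption about wiring that lives outside the changed lines):

```tablegen
// Assumed ODS hook, not part of this diff: dispatch op verification to the
// static Verify() overload defined in the dialect's .cc file.
let verifier = [{ return Verify(*this); }];
```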
@@ -24,3 +24,45 @@ func @minimum_broadcast_shapes_one_operand(%arg: tensor<?xindex>) {
   %0 = chlo.minimum_broadcast_shapes %arg : tensor<?xindex> -> tensor<?xindex>
   return
 }
+
+// -----
+
+func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
+    %arg2 : tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
+    %1 = chlo.broadcast_multiply %arg0_, %arg1_
+        : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    %2 = chlo.broadcast_add %1, %arg2_
+        : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    "chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> ()
+  }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func @rank_specialization_cluster(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
+  // expected-error @+1{{block argument types must match operand types}}
+  %0 = "chlo.rank_specialization_cluster"(%arg0) ({
+  ^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>):
+    "chlo.rank_specialization_cluster_yield"(%arg0_) : (tensor<*xf32>) -> ()
+  }) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
+    %arg2 : tensor<*xf32>) -> tensor<*xf32> {
+  // expected-error @+1{{nested ops must not depend on implicit operands}}
+  %0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
+    %1 = chlo.broadcast_multiply %arg0_, %arg1_
+        : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    %2 = chlo.broadcast_add %1, %arg2
+        : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    "chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> ()
+  }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
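For context, and as an assumption not visible in this capture: a test file like this is normally driven by a lit RUN line that feeds it through the project's opt tool with MLIR's standard split-input-file and diagnostic-verification flags, roughly along these lines (tool name and exact flags are assumptions):

```mlir
// Hypothetical RUN line; the real test file may also pipe into FileCheck for
// the positive cases.
// RUN: mlir-hlo-opt %s -split-input-file -verify-diagnostics
```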