Add `mhlo.all_gather` op to MHLO dialect.

Adds import/export/verifier support as well. Also makes `channel_handle` uniform across `mhlo.all_reduce` and `mhlo.all_gather`.

PiperOrigin-RevId: 377323468
commit aba16adfa5 (parent 4fc2e87a42)
@@ -909,6 +909,26 @@ def HLO_WhileOp: HLO_Op<"while", [
   let hasCustomHLOConverter = 1;
 }
 
+def HLO_AllGatherOp : HLO_Op<"all_gather", [SameOperandsAndResultElementType]> {
+  string summary = "AllGather operator";
+
+  string description = [{
+    Performs concatenation across replicas.
+
+    See https://www.tensorflow.org/xla/operation_semantics#allgather
+  }];
+
+  let arguments = (ins
+    HLO_Tensor:$operand,
+    I64Attr:$all_gather_dim,
+    I64ElementsAttr:$replica_groups,
+    OptionalAttr<ChannelHandle>:$channel_handle
+  );
+  let results = (outs HLO_Tensor);
+
+  let hasCustomHLOConverter = 1;
+}
+
 def HLO_AllReduceOp : HLO_Op<"all_reduce",
     [SameOperandsAndResultType]> {
   let summary = "AllReduce operator";

@@ -921,7 +941,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
   let arguments = (ins
     HLO_Tensor:$operand,
     I64ElementsAttr:$replica_groups,
-    OptionalAttr<ChannelHandle>:$channel_id
+    OptionalAttr<ChannelHandle>:$channel_handle
   );
   let regions = (region SizedRegion<1>:$computation);
   let results = (outs HLO_Tensor);
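Note: per the linked XLA semantics, the result's gather dimension is the operand's gather dimension scaled by the replica-group size; the verifier added below enforces only the weaker multiple-of condition. A hand-written sketch (not part of this diff) of a shape-consistent use: with replica groups of size 4, gathering a tensor<128x32xf32> along dimension 1 yields tensor<128x128xf32>. It also exercises the `channel_handle` attribute, now spelled identically on `mhlo.all_reduce` after the rename in the second hunk.

func @allgather_shape_sketch(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> {
  // Each of the two replica groups has 4 participants, so the gather
  // dimension (dim 1) grows from 32 to 4 * 32 = 128.
  %0 = "mhlo.all_gather"(%arg0) {
    all_gather_dim = 1 : i64,
    channel_handle = {handle = 1 : i64, type = 0 : i64},
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  } : (tensor<128x32xf32>) -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}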
@@ -725,6 +725,34 @@ static LogicalResult Verify(AllToAllOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AllGatherOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(AllGatherOp op) {
+  // If operand and result are both ranked, then the size of the gather
+  // dimension in the result should be a multiple of the size of the gather
+  // dimension in the operand.
+  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
+  auto resultType = op.getType().dyn_cast<RankedTensorType>();
+  uint64_t allGatherDimIndex = op.all_gather_dim();
+  if (!operandType || !resultType ||
+      operandType.isDynamicDim(allGatherDimIndex) ||
+      resultType.isDynamicDim(allGatherDimIndex))
+    return success();
+  if (operandType.getDimSize(allGatherDimIndex) == 0)
+    return op.emitOpError() << "operand gather dimension cannot be zero.";
+  if ((resultType.getDimSize(allGatherDimIndex) %
+       operandType.getDimSize(allGatherDimIndex)) != 0)
+    return op.emitOpError()
+           << "result gather dimension has size "
+           << resultType.getDimSize(allGatherDimIndex)
+           << ", expected to be a multiple of operand gather dimension size "
+           << operandType.getDimSize(allGatherDimIndex);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
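Note: the verifier deliberately returns success when either type is unranked or dynamic along the gather dimension. The tests below cover a dynamic result dimension; as a complementary hand-written sketch (not part of this diff), the converse case also verifies, because the operand's gather dimension is dynamic even though the result's is static:

func @allgather_dynamic_operand_dim(%arg0: tensor<128x?xf32>) -> tensor<128x256xf32> {
  // operandType.isDynamicDim(1) is true, so the multiple-of check is skipped.
  %0 = "mhlo.all_gather"(%arg0) {
    all_gather_dim = 1 : i64,
    channel_handle = {handle = 1 : i64, type = 0 : i64},
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  } : (tensor<128x?xf32>) -> tensor<128x256xf32>
  return %0 : tensor<128x256xf32>
}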
@@ -52,6 +52,42 @@ func @alltoall_invalid_split_dim_size(%data: tensor<4x16xf32>) -> tensor<16x4xf32> {
 
 // -----
 
+func @allgather_incompatible_types(%arg0: tensor<128x32xf32>) -> tensor<128x100xf32> {
+  // expected-error@+1 {{result gather dimension has size 100, expected to be a multiple of operand gather dimension size 32}}
+  %0 = "mhlo.all_gather"(%arg0) {
+    all_gather_dim = 1 : i64,
+    channel_handle = {handle = 1 : i64, type = 0 : i64},
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+  } : (tensor<128x32xf32>) -> tensor<128x100xf32>
+  return %0 : tensor<128x100xf32>
+}
+
+// -----
+
+func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0x32xf32>) -> tensor<128x100xf32> {
+  // expected-error@+1 {{operand gather dimension cannot be zero}}
+  %0 = "mhlo.all_gather"(%arg0) {
+    all_gather_dim = 1 : i64,
+    channel_handle = {handle = 1 : i64, type = 0 : i64},
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+  } : (tensor<128x0x32xf32>) -> tensor<128x100xf32>
+  return %0 : tensor<128x100xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @allgather_dynamic_gather_dim
+func @allgather_dynamic_gather_dim(%arg0: tensor<128x32xf32>) -> tensor<128x?xf32> {
+  %0 = "mhlo.all_gather"(%arg0) {
+    all_gather_dim = 1 : i64,
+    channel_handle = {handle = 1 : i64, type = 0 : i64},
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+  } : (tensor<128x32xf32>) -> tensor<128x?xf32>
+  return %0 : tensor<128x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @broadcast
 func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
   %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
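Note: the tests above exercise only the new `mhlo.all_gather` verifier; the `channel_id` to `channel_handle` rename on `mhlo.all_reduce` is not covered here. A hand-written sketch (not part of this diff) of an all_reduce written against the renamed attribute, with a summation region supplied for `$computation`:

func @allreduce_channel_handle(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // The single-block region combines pairs of per-replica elements.
  %0 = "mhlo.all_reduce"(%arg0) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %sum = mhlo.add %lhs, %rhs : tensor<f32>
    "mhlo.return"(%sum) : (tensor<f32>) -> ()
  }) {
    channel_handle = {handle = 1 : i64, type = 0 : i64},
    replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
  } : (tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}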