Add `mhlo.all_gather` op to MHLO dialect.

Adds import/export/verifier support as well.
Also makes `channel_handle` uniform across `mhlo.all_reduce` and `mhlo.all_gather`.

PiperOrigin-RevId: 377323468
Authored by A. Unique TensorFlower on 2021-06-03 10:44:36 -07:00; committed by TensorFlow MLIR Team
parent 4fc2e87a42
commit aba16adfa5
3 changed files with 85 additions and 1 deletion


@@ -909,6 +909,26 @@ def HLO_WhileOp: HLO_Op<"while", [
  let hasCustomHLOConverter = 1;
}
def HLO_AllGatherOp : HLO_Op<"all_gather", [SameOperandsAndResultElementType]> {
  let summary = "AllGather operator";
  let description = [{
    Performs concatenation across replicas.

    See https://www.tensorflow.org/xla/operation_semantics#allgather
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    I64Attr:$all_gather_dim,
    I64ElementsAttr:$replica_groups,
    OptionalAttr<ChannelHandle>:$channel_handle
  );
  let results = (outs HLO_Tensor);
  let hasCustomHLOConverter = 1;
}
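
As a usage sketch (shapes and attribute values are illustrative, mirroring the tests added below): gathering a tensor<128x32xf32> along dimension 1 across groups of four replicas produces a tensor<128x128xf32>:

  %0 = "mhlo.all_gather"(%arg0) {
    all_gather_dim = 1 : i64,
    channel_handle = {handle = 1 : i64, type = 0 : i64},
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  } : (tensor<128x32xf32>) -> tensor<128x128xf32>
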
def HLO_AllReduceOp : HLO_Op<"all_reduce",
    [SameOperandsAndResultType]> {
  let summary = "AllReduce operator";
@@ -921,7 +941,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$replica_groups,
-   OptionalAttr<ChannelHandle>:$channel_id
+   OptionalAttr<ChannelHandle>:$channel_handle
  );
  let regions = (region SizedRegion<1>:$computation);
  let results = (outs HLO_Tensor);
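
With the rename, `mhlo.all_reduce` carries the same `channel_handle` attribute as `mhlo.all_gather`. A minimal sketch in generic form (values are illustrative; the reduction region computes an elementwise maximum):

  %0 = "mhlo.all_reduce"(%arg0) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %max = "mhlo.maximum"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "mhlo.return"(%max) : (tensor<f32>) -> ()
  }) {
    channel_handle = {handle = 5 : i64, type = 2 : i64},
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  } : (tensor<10xf32>) -> tensor<10xf32>
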


@@ -725,6 +725,34 @@ static LogicalResult Verify(AllToAllOp op) {
  return success();
}
//===----------------------------------------------------------------------===//
// AllGatherOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(AllGatherOp op) {
  // If operand and result are both ranked, then the size of the gather
  // dimension in the result should be a multiple of the size of the gather
  // dimension in the operand.
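  // For example (numbers are illustrative, not from the original commit):
  // gathering tensor<128x32xf32> along dimension 1 over groups of four
  // replicas yields tensor<128x128xf32>, and 128 % 32 == 0, so verification
  // succeeds.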
  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
  auto resultType = op.getType().dyn_cast<RankedTensorType>();
  uint64_t allGatherDimIndex = op.all_gather_dim();
  if (!operandType || !resultType ||
      operandType.isDynamicDim(allGatherDimIndex) ||
      resultType.isDynamicDim(allGatherDimIndex))
    return success();
  if (operandType.getDimSize(allGatherDimIndex) == 0)
    return op.emitOpError() << "operand gather dimension cannot be zero";
  if ((resultType.getDimSize(allGatherDimIndex) %
       operandType.getDimSize(allGatherDimIndex)) != 0)
    return op.emitOpError()
           << "result gather dimension has size "
           << resultType.getDimSize(allGatherDimIndex)
           << ", expected to be a multiple of operand gather dimension size "
           << operandType.getDimSize(allGatherDimIndex);
  return success();
}
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//


@@ -52,6 +52,42 @@ func @alltoall_invalid_split_dim_size(%data: tensor<4x16xf32>) -> tensor<16x4xf3
// -----
func @allgather_incompatible_types(%arg0: tensor<128x32xf32>) -> tensor<128x100xf32> {
  // expected-error@+1 {{result gather dimension has size 100, expected to be a multiple of operand gather dimension size 32}}
  %0 = "mhlo.all_gather"(%arg0) {
    all_gather_dim = 1 : i64,
    channel_handle = {handle = 1 : i64, type = 0 : i64},
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  } : (tensor<128x32xf32>) -> tensor<128x100xf32>
  return %0 : tensor<128x100xf32>
}
// -----
func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0x32xf32>) -> tensor<128x100xf32> {
  // expected-error@+1 {{operand gather dimension cannot be zero}}
  %0 = "mhlo.all_gather"(%arg0) {
    all_gather_dim = 1 : i64,
    channel_handle = {handle = 1 : i64, type = 0 : i64},
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  } : (tensor<128x0x32xf32>) -> tensor<128x100xf32>
  return %0 : tensor<128x100xf32>
}
// -----
// CHECK-LABEL: func @allgather_dynamic_gather_dim
func @allgather_dynamic_gather_dim(%arg0: tensor<128x32xf32>) -> tensor<128x?xf32> {
  %0 = "mhlo.all_gather"(%arg0) {
    all_gather_dim = 1 : i64,
    channel_handle = {handle = 1 : i64, type = 0 : i64},
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  } : (tensor<128x32xf32>) -> tensor<128x?xf32>
  return %0 : tensor<128x?xf32>
}
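// -----
// A hypothetical passing case, not part of the original commit: with groups
// of four replicas, the result gather dimension 128 = 4 * 32 is a multiple
// of the operand gather dimension 32, so the op verifies.
// CHECK-LABEL: func @allgather_static_gather_dim
func @allgather_static_gather_dim(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> {
  %0 = "mhlo.all_gather"(%arg0) {
    all_gather_dim = 1 : i64,
    channel_handle = {handle = 1 : i64, type = 0 : i64},
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  } : (tensor<128x32xf32>) -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}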
// -----
// CHECK-LABEL: func @broadcast
func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
%0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>