[HLO] Add AllReduceScatter to MHLO and LMHLO dialects.

PiperOrigin-RevId: 379296198
This commit is contained in:
Rahul Joshi 2021-06-14 09:36:23 -07:00 committed by TensorFlow MLIR Team
parent dbfa4b1537
commit a6011d0279
8 changed files with 220 additions and 0 deletions

View File

@ -954,6 +954,26 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
let hasCustomHLOConverter = 1;
}
def HLO_AllReduceScatterOp : HLO_Op<"all_reduce_scatter",
[SameOperandsAndResultElementType]> {
let summary = "AllReduceScatter operator";
let description = [{
Performs all_reduce followed by a scatter.
See https://www.tensorflow.org/xla/operation_semantics#allreducescatter
}];
let arguments = (ins
HLO_Tensor:$operand,
I64Attr:$scatter_dimension,
I64ElementsAttr:$replica_groups,
OptionalAttr<ChannelHandle>:$channel_handle
);
let regions = (region SizedRegion<1>:$computation);
let results = (outs HLO_Tensor);
let hasCustomHLOConverter = 1;
}
def HLO_AllToAllOp : HLO_Op<"all_to_all",
[NoSideEffect, SameOperandsElementType, SameOperandsShape]> {

View File

@ -30,6 +30,10 @@ namespace hlo {
LogicalResult VerifyCollectivePermuteSourceTargetPairs(
Operation* op, DenseIntElementsAttr attr);
LogicalResult VerifyAllReduceScatter(Operation* op, TypeRange operand_types,
TypeRange result_types,
uint64_t scatter_dimension);
// Custom formatting for convolution window attributes.
void printWindowAttributes(OpAsmPrinter& p, Operation* op,
llvm::Optional<DenseIntElementsAttr> window_strides,

View File

@ -1105,6 +1105,19 @@ def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce", [SameOperand
let regions = (region SizedRegion<1>:$computation);
}
def LHLO_AllReduceScatterOp : LHLO_CollectiveCommunicationOp<"all_reduce_scatter", [SameOperandsElementType]> {
let summary = "AllReduceScatter operator";
let description = [{
Performs all_reduce followed by a scatter.
See https://www.tensorflow.org/xla/operation_semantics#allreducescatter
}];
let arguments = !con(
arguments_base,
(ins I64Attr:$scatter_dimension));
let regions = (region SizedRegion<1>:$computation);
}
def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all", [SameOperandsElementType]> {
let arguments = !con(
arguments_base,

View File

@ -196,6 +196,18 @@ Value MaybeCastTo(OpBuilder& b, Location loc, Value value, Type type) {
} // namespace
//===----------------------------------------------------------------------===//
// AllReduceScatterOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(AllReduceScatterOp op) {
return mlir::hlo::VerifyAllReduceScatter(
op,
/*operand_types=*/{op.operand().getType()},
/*result_types=*/{op.getType()},
/*scatter_dimension=*/op.scatter_dimension());
}
//===----------------------------------------------------------------------===//
// ConstOp
//===----------------------------------------------------------------------===//

View File

@ -53,6 +53,51 @@ LogicalResult VerifyCollectivePermuteSourceTargetPairs(
return success();
}
LogicalResult VerifyAllReduceScatter(Operation *op, TypeRange operand_types,
TypeRange result_types,
uint64_t scatter_dimension) {
// If operand and result are both ranked, then the size of the scatter
// dimension in the operand should be a multiple of the size of the scatter
// dimension in the result.
for (auto it : llvm::zip(operand_types, result_types)) {
auto operand_type = std::get<0>(it).cast<ShapedType>();
auto result_type = std::get<1>(it).cast<ShapedType>();
if (!operand_type.hasRank() || !result_type.hasRank()) continue;
if (operand_type.getRank() != result_type.getRank())
return op->emitOpError() << "operand and result should have same rank";
if (scatter_dimension >= operand_type.getRank())
return op->emitOpError()
<< "scatter dim should be less than operand/result rank";
if (operand_type.isDynamicDim(scatter_dimension) ||
result_type.isDynamicDim(scatter_dimension))
continue;
if (operand_type.getDimSize(scatter_dimension) == 0)
return op->emitOpError() << "operand scatter dimension cannot be zero";
if (result_type.getDimSize(scatter_dimension) == 0)
return op->emitOpError() << "result scatter dimension cannot be zero";
if ((operand_type.getDimSize(scatter_dimension) %
result_type.getDimSize(scatter_dimension)) != 0)
return op->emitOpError()
<< "operand scatter dimension has size "
<< operand_type.getDimSize(scatter_dimension)
<< ", expected to be a multiple of result scatter dimension size "
<< result_type.getDimSize(scatter_dimension);
// Non scatter dimensions should be equal.
for (uint64_t index : llvm::seq<uint64_t>(0, operand_type.getRank())) {
if (index == scatter_dimension || operand_type.isDynamicDim(index) ||
result_type.isDynamicDim(index))
continue;
if (operand_type.getDimSize(index) != result_type.getDimSize(index))
return op->emitOpError()
<< "non scatter dimensions should be same for operand ("
<< operand_type.getDimSize(index) << ") and result ("
<< result_type.getDimSize(index) << ")";
}
}
return success();
}
namespace {
// Custom formatting for convolution window attributes.
void printWindowAttribute(OpAsmPrinter &p, DenseElementsAttr attribute) {

View File

@ -159,6 +159,21 @@ static LogicalResult Verify(AllReduceOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// AllReduceScatterOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(AllReduceScatterOp op) {
if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/true)))
return failure();
if (failed(mlir::hlo::VerifyAllReduceScatter(
op, /*operand_types=*/op.operands().getTypes(),
/*result_types=*/op.results().getTypes(),
/*scatter_dimension=*/op.scatter_dimension())))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// CaseOp
//===----------------------------------------------------------------------===//

View File

@ -32,6 +32,19 @@ func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf16>) {
// -----
// CHECK-LABEL: func @reduce_scatter
func @reduce_scatter(%data: memref<4x16xf32>, %result:memref<4x4xf32>) {
"lmhlo.all_reduce_scatter"(%data, %result) ( {
// reduction computation
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
scatter_dimension = 1 : i64} : (memref<4x16xf32>, memref<4x4xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @mixed_types_allgather
func @mixed_types_allgather(%a0: memref<1x1xf32>, %a1:memref<1x1xi32>) {
"lmhlo.all_gather"(%a0, %a1, %a0, %a1) {all_gather_dimension = 0 : i64,

View File

@ -13,6 +13,104 @@ func private @invalid_type() -> !mhlo.foobar
// -----
// CHECK-LABEL: func @reduce_scatter
func @reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
%0 = "mhlo.all_reduce_scatter"(%data) ( {
// reduction computation
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
// -----
func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x5xf32> {
// expected-error@+1 {{operand scatter dimension has size 16, expected to be a multiple of result scatter dimension size 5}}
%0 = "mhlo.all_reduce_scatter"(%data) ( {
// reduction computation
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x5xf32>
return %0 : tensor<4x5xf32>
}
// -----
func @invalid_reduce_scatter(%data: tensor<4x0xf32>) -> tensor<4x4xf32> {
// expected-error@+1 {{operand scatter dimension cannot be zero}}
%0 = "mhlo.all_reduce_scatter"(%data) ( {
// reduction computation
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
scatter_dimension = 1 : i64} : (tensor<4x0xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
// -----
func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x0xf32> {
// expected-error@+1 {{result scatter dimension cannot be zero}}
%0 = "mhlo.all_reduce_scatter"(%data) ( {
// reduction computation
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x0xf32>
return %0 : tensor<4x0xf32>
}
// -----
func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4xf32> {
// expected-error@+1 {{operand and result should have same rank}}
%0 = "mhlo.all_reduce_scatter"(%data) ( {
// reduction computation
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
// expected-error@+1 {{scatter dim should be less than operand/result rank}}
%0 = "mhlo.all_reduce_scatter"(%data) ( {
// reduction computation
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
scatter_dimension = 4 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
// -----
func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<3x4xf32> {
// expected-error@+1 {{non scatter dimensions should be same for operand (4) and result (3)}}
%0 = "mhlo.all_reduce_scatter"(%data) ( {
// reduction computation
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<3x4xf32>
return %0 : tensor<3x4xf32>
}
// -----
// CHECK-LABEL: func @alltoall
func @alltoall(%data: tensor<4x16xf32>) -> tensor<16x4xf32> {
%0 = "mhlo.all_to_all"(%data) {