[HLO] Add AllReduceScatter to MHLO and LMHLO dialects.
PiperOrigin-RevId: 379296198
commit a6011d0279
parent dbfa4b1537
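For context, the new MHLO op reduces across the replicas in each group and then scatters the reduced result along scatter_dimension. A minimal sketch of its use, mirroring the positive test case added below (the single 4-replica group and the 4x16 -> 4x4 shapes are taken from that test):

  // One replica group of 4 devices; the scatter dimension (1) of the
  // all-reduced 4x16 operand is split 4 ways, giving a 4x4 result per replica.
  %0 = "mhlo.all_reduce_scatter"(%data) ( {
    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
      %1 = mhlo.add %arg2, %arg3 : tensor<f32>
      "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32>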
@@ -954,6 +954,26 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
   let hasCustomHLOConverter = 1;
 }
 
+def HLO_AllReduceScatterOp : HLO_Op<"all_reduce_scatter",
+    [SameOperandsAndResultElementType]> {
+  let summary = "AllReduceScatter operator";
+  let description = [{
+    Performs all_reduce followed by a scatter.
+
+    See https://www.tensorflow.org/xla/operation_semantics#allreducescatter
+  }];
+
+  let arguments = (ins
+    HLO_Tensor:$operand,
+    I64Attr:$scatter_dimension,
+    I64ElementsAttr:$replica_groups,
+    OptionalAttr<ChannelHandle>:$channel_handle
+  );
+  let regions = (region SizedRegion<1>:$computation);
+  let results = (outs HLO_Tensor);
+  let hasCustomHLOConverter = 1;
+}
+
 def HLO_AllToAllOp : HLO_Op<"all_to_all",
     [NoSideEffect, SameOperandsElementType, SameOperandsShape]> {
@@ -30,6 +30,10 @@ namespace hlo {
 LogicalResult VerifyCollectivePermuteSourceTargetPairs(
     Operation* op, DenseIntElementsAttr attr);
 
+LogicalResult VerifyAllReduceScatter(Operation* op, TypeRange operand_types,
+                                     TypeRange result_types,
+                                     uint64_t scatter_dimension);
+
 // Custom formatting for convolution window attributes.
 void printWindowAttributes(OpAsmPrinter& p, Operation* op,
                            llvm::Optional<DenseIntElementsAttr> window_strides,
@@ -1105,6 +1105,19 @@ def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce", [SameOperand
   let regions = (region SizedRegion<1>:$computation);
 }
 
+def LHLO_AllReduceScatterOp : LHLO_CollectiveCommunicationOp<"all_reduce_scatter", [SameOperandsElementType]> {
+  let summary = "AllReduceScatter operator";
+  let description = [{
+    Performs all_reduce followed by a scatter.
+
+    See https://www.tensorflow.org/xla/operation_semantics#allreducescatter
+  }];
+  let arguments = !con(
+    arguments_base,
+    (ins I64Attr:$scatter_dimension));
+  let regions = (region SizedRegion<1>:$computation);
+}
+
 def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all", [SameOperandsElementType]> {
   let arguments = !con(
     arguments_base,
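Note that, unlike the tensor-based MHLO op above, the LMHLO variant is buffer-based: its operand and result memrefs appear to come in through arguments_base (hence no separate results list here), and only scatter_dimension is added on top. A minimal sketch, mirroring the lmhlo test case added below (%data is the 4x16 source buffer, %result the 4x4 destination buffer):

  // Same reduction, buffer form: the destination memref is passed in
  // explicitly rather than returned.
  "lmhlo.all_reduce_scatter"(%data, %result) ( {
    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
      %1 = mhlo.add %arg2, %arg3 : tensor<f32>
      "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
      scatter_dimension = 1 : i64} : (memref<4x16xf32>, memref<4x4xf32>) -> ()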
@@ -196,6 +196,18 @@ Value MaybeCastTo(OpBuilder& b, Location loc, Value value, Type type) {
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// AllReduceScatterOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(AllReduceScatterOp op) {
+  return mlir::hlo::VerifyAllReduceScatter(
+      op,
+      /*operand_types=*/{op.operand().getType()},
+      /*result_types=*/{op.getType()},
+      /*scatter_dimension=*/op.scatter_dimension());
+}
+
 //===----------------------------------------------------------------------===//
 // ConstOp
 //===----------------------------------------------------------------------===//
@@ -53,6 +53,51 @@ LogicalResult VerifyCollectivePermuteSourceTargetPairs(
   return success();
 }
 
+LogicalResult VerifyAllReduceScatter(Operation *op, TypeRange operand_types,
+                                     TypeRange result_types,
+                                     uint64_t scatter_dimension) {
+  // If operand and result are both ranked, then the size of the scatter
+  // dimension in the operand should be a multiple of the size of the scatter
+  // dimension in the result.
+  for (auto it : llvm::zip(operand_types, result_types)) {
+    auto operand_type = std::get<0>(it).cast<ShapedType>();
+    auto result_type = std::get<1>(it).cast<ShapedType>();
+    if (!operand_type.hasRank() || !result_type.hasRank()) continue;
+    if (operand_type.getRank() != result_type.getRank())
+      return op->emitOpError() << "operand and result should have same rank";
+    if (scatter_dimension >= operand_type.getRank())
+      return op->emitOpError()
+             << "scatter dim should be less than operand/result rank";
+    if (operand_type.isDynamicDim(scatter_dimension) ||
+        result_type.isDynamicDim(scatter_dimension))
+      continue;
+    if (operand_type.getDimSize(scatter_dimension) == 0)
+      return op->emitOpError() << "operand scatter dimension cannot be zero";
+    if (result_type.getDimSize(scatter_dimension) == 0)
+      return op->emitOpError() << "result scatter dimension cannot be zero";
+    if ((operand_type.getDimSize(scatter_dimension) %
+         result_type.getDimSize(scatter_dimension)) != 0)
+      return op->emitOpError()
+             << "operand scatter dimension has size "
+             << operand_type.getDimSize(scatter_dimension)
+             << ", expected to be a multiple of result scatter dimension size "
+             << result_type.getDimSize(scatter_dimension);
+
+    // Non scatter dimensions should be equal.
+    for (uint64_t index : llvm::seq<uint64_t>(0, operand_type.getRank())) {
+      if (index == scatter_dimension || operand_type.isDynamicDim(index) ||
+          result_type.isDynamicDim(index))
+        continue;
+      if (operand_type.getDimSize(index) != result_type.getDimSize(index))
+        return op->emitOpError()
+               << "non scatter dimensions should be same for operand ("
+               << operand_type.getDimSize(index) << ") and result ("
+               << result_type.getDimSize(index) << ")";
+    }
+  }
+  return success();
+}
+
 namespace {
 // Custom formatting for convolution window attributes.
 void printWindowAttribute(OpAsmPrinter &p, DenseElementsAttr attribute) {
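As a worked instance of these checks, using the shapes from the tests added below: an operand of tensor<4x16xf32> with scatter_dimension = 1 and a result of tensor<4x4xf32> passes, since 16 % 4 == 0 and the non-scatter dimension agrees (4 == 4); a result of tensor<4x5xf32> is rejected because 16 is not a multiple of 5, and scatter_dimension = 4 is rejected because it is not less than the rank (2).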
@@ -159,6 +159,21 @@ static LogicalResult Verify(AllReduceOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AllReduceScatterOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(AllReduceScatterOp op) {
+  if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/true)))
+    return failure();
+  if (failed(mlir::hlo::VerifyAllReduceScatter(
+          op, /*operand_types=*/op.operands().getTypes(),
+          /*result_types=*/op.results().getTypes(),
+          /*scatter_dimension=*/op.scatter_dimension())))
+    return failure();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CaseOp
 //===----------------------------------------------------------------------===//
@@ -32,6 +32,19 @@ func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf16>) {
 
 // -----
 
+// CHECK-LABEL: func @reduce_scatter
+func @reduce_scatter(%data: memref<4x16xf32>, %result:memref<4x4xf32>) {
+  "lmhlo.all_reduce_scatter"(%data, %result) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (memref<4x16xf32>, memref<4x4xf32>) -> ()
+  return
+}
+// -----
+
 // CHECK-LABEL: func @mixed_types_allgather
 func @mixed_types_allgather(%a0: memref<1x1xf32>, %a1:memref<1x1xi32>) {
   "lmhlo.all_gather"(%a0, %a1, %a0, %a1) {all_gather_dimension = 0 : i64,
@@ -13,6 +13,104 @@ func private @invalid_type() -> !mhlo.foobar
 
 // -----
 
+// CHECK-LABEL: func @reduce_scatter
+func @reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32>
+  return %0 : tensor<4x4xf32>
+}
+
+// -----
+
+func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x5xf32> {
+  // expected-error@+1 {{operand scatter dimension has size 16, expected to be a multiple of result scatter dimension size 5}}
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x5xf32>
+  return %0 : tensor<4x5xf32>
+}
+
+// -----
+
+func @invalid_reduce_scatter(%data: tensor<4x0xf32>) -> tensor<4x4xf32> {
+  // expected-error@+1 {{operand scatter dimension cannot be zero}}
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (tensor<4x0xf32>) -> tensor<4x4xf32>
+  return %0 : tensor<4x4xf32>
+}
+
+// -----
+
+func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x0xf32> {
+  // expected-error@+1 {{result scatter dimension cannot be zero}}
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x0xf32>
+  return %0 : tensor<4x0xf32>
+}
+
+// -----
+
+func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4xf32> {
+  // expected-error@+1 {{operand and result should have same rank}}
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
+  // expected-error@+1 {{scatter dim should be less than operand/result rank}}
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 4 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32>
+  return %0 : tensor<4x4xf32>
+}
+
+// -----
+
+func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<3x4xf32> {
+  // expected-error@+1 {{non scatter dimensions should be same for operand (4) and result (3)}}
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<3x4xf32>
+  return %0 : tensor<3x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @alltoall
 func @alltoall(%data: tensor<4x16xf32>) -> tensor<16x4xf32> {
   %0 = "mhlo.all_to_all"(%data) {