[XLA:GPU] Add AllReduce{Start,Done} to MLIR LHLO dialect.

PiperOrigin-RevId: 379455720
Authored by Chris Jones on 2021-06-15 03:54:24 -07:00; committed by TensorFlow MLIR Team
parent 399dae666d
commit 5fbdac34a9
5 changed files with 140 additions and 58 deletions

BUILD

@@ -546,6 +546,7 @@ cc_library(
    hdrs = [
        "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h",
        "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h",
        "include/mlir-hlo/utils/lhlo_utils.h",
    ],
    includes = ["include"],
    deps = [
@@ -590,6 +591,7 @@ cc_library(
        ":hlo_ops_base_structs",
        ":hlo_ops_common",
        ":infer_fusibility_op_interface",
        ":lhlo",
        ":lhlo_gpu_ops_enums",
        ":lhlo_gpu_ops_inc_gen",
        ":lhlo_gpu_ops_structs",

include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td

@@ -230,4 +230,31 @@ def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
      BoolAttr:$is_lower);
}

def LHLOGPU_AllReduceStartOp :
    LHLOGPU_Op<"all_reduce_start", [SameOperandsElementType, SameVariadicOperandSize]> {
  let summary = "AllReduceStart operator";
  let description = [{
    Performs an asynchronous custom reduction across replicas.
  }];
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
    I64ElementsAttr:$replica_groups,
    DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
    OptionalAttr<ChannelHandle>:$channel_id,
    DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
  );
  let regions = (region SizedRegion<1>:$computation);
  let verifier = [{ return Verify(*this); }];
}

def LHLOGPU_AllReduceDoneOp :
    LHLOGPU_Op<"all_reduce_done", [SameVariadicOperandSize]> {
  let summary = "AllReduceDone operator";
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results
  );
}
#endif // LHLO_GPU_OPS
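
For orientation, here is a hypothetical sketch of a lowered start/done pair in generic MLIR textual form. The buffer shapes, attribute values, and the `mhlo.add` reduction body are illustrative assumptions, not output of this change:

"lmhlo_gpu.all_reduce_start"(%input, %output) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
  %sum = mhlo.add %lhs, %rhs : tensor<f32>
  "mhlo.return"(%sum) : (tensor<f32>) -> ()
}) {replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
    constrain_layout = false,
    use_global_device_ids = false}
  : (memref<8xf32>, memref<8xf32>) -> ()

// Waits for the asynchronous reduction started above to complete.
"lmhlo_gpu.all_reduce_done"(%input, %output)
  : (memref<8xf32>, memref<8xf32>) -> ()

Note how the done op reads the same operand buffers and writes the same result buffers (the `MemRead`/`MemWrite` traits above), which expresses its dependency on the in-flight reduction.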

include/mlir-hlo/utils/lhlo_utils.h

@@ -0,0 +1,100 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_LHLO_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_LHLO_UTILS_H_

#include "llvm/ADT/SmallSet.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

namespace mlir {
namespace lmhlo {

// Verifies replica groups attached to collective communication operations.
// If the attribute is not empty, it must be a rank 2 tensor, and each replica
// must appear exactly once. If `is_uniform_sized` is true, we also check that
// each group is of the same size. If the operation has
// `use_global_device_ids` set, the replica groups attribute must not be
// empty.
template <typename OpT>
LogicalResult VerifyReplicaGroups(OpT op, bool is_uniform_sized) {
  DenseIntElementsAttr attr = op.replica_groups();
  auto replica_group_type = attr.getType().dyn_cast<RankedTensorType>();
  if (!replica_group_type || replica_group_type.getRank() != 2 ||
      !replica_group_type.getElementType().isInteger(/*width=*/64))
    return op.emitOpError(
        "replica groups should be a rank 2 tensor of 64 bit integers");

  if (replica_group_type.getShape().equals(ArrayRef<int64_t>{0, 0})) {
    if (op.use_global_device_ids()) {
      return op.emitOpError(
          "if `use_global_device_ids` is set, the replica groups cannot be "
          "empty");
    }
    return success();
  }

  int64_t max_replica_id_seen = 0;
  llvm::SmallSet<int64_t, 8> replica_seen;
  for (int64_t id : attr.getValues<int64_t>()) {
    // Replica groups are stored in a 2D tensor. If the op supports
    // non-uniform groups, null replica IDs are stored as -1.
    if (id == -1) {
      if (is_uniform_sized) {
        return op.emitOpError("Invalid replica id -1");
      }
      continue;
    }
    if (!replica_seen.insert(id).second) {
      return op.emitOpError("replica id #") << id << " seen more than once";
    }
    max_replica_id_seen = std::max(max_replica_id_seen, id);
  }

  for (int64_t id = 0; id <= max_replica_id_seen; id++) {
    if (!replica_seen.contains(id)) {
      return op.emitOpError("replica id #")
             << id << " not seen in replica groups";
    }
  }
  return success();
}
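
// As a worked illustration of the rules above (editor's example, not part of
// the original header; the ids and group shapes are made up):
//
//   // Accepted when `is_uniform_sized` is false: ids 0..3 each appear
//   // exactly once, and -1 pads the smaller group {3}.
//   replica_groups = dense<[[0, 1, 2], [3, -1, -1]]> : tensor<2x3xi64>
//
//   // Rejected: id 2 never appears, so the coverage loop reports
//   // "replica id #2 not seen in replica groups".
//   replica_groups = dense<[[0, 1], [3, 4]]> : tensor<2x2xi64>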

template <typename OpT>
static LogicalResult VerifyAllReduce(OpT op) {
  if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false)))
    return failure();

  // AllReduce has variadic operands and results that have the same size.
  // Each member of the operand should have the same type as the corresponding
  // member of the result.
  for (auto it : llvm::enumerate(
           llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
    Type operandType = std::get<0>(it.value());
    Type resultType = std::get<1>(it.value());
    if (operandType != resultType)
      return op.emitOpError("requires operand #")
             << it.index() << " (type: " << operandType << ") and result #"
             << it.index() << " (type: " << resultType << ") to have same type";
  }
  return success();
}

}  // namespace lmhlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_LHLO_UTILS_H_
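
To see the pairwise operand/result type check in action, consider a minimal sketch in generic textual IR. It assumes the lmhlo `all_reduce` mnemonic and elides the reduction region for brevity (both assumptions for the example; a real op carries the reduction computation in the region):

"lmhlo.all_reduce"(%in, %out) ({
  // reduction region elided
}) {replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>}
  : (memref<8xf32>, memref<8xf16>) -> ()
// error: requires operand #0 (type: memref<8xf32>) and result #0
// (type: memref<8xf16>) to have same type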

lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc

@@ -29,6 +29,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir-hlo/utils/lhlo_utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
@@ -61,6 +62,14 @@ LmhloGpuDialect::LmhloGpuDialect(MLIRContext *context)
using mlir::hlo::parseWindowAttributes;
using mlir::hlo::printWindowAttributes;
//===----------------------------------------------------------------------===//
// AllReduceStartOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(AllReduceStartOp op) {
  return lmhlo::VerifyAllReduce(op);
}
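
// Editorial note: the `let verifier = [{ return Verify(*this); }]` entry on
// AllReduceStartOp in the .td file routes the TableGen-generated verify()
// hook to this free function, which reuses the shared lmhlo::VerifyAllReduce.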

}  // namespace lmhlo_gpu
}  // namespace mlir

lib/Dialect/mhlo/IR/lhlo_ops.cc

@@ -33,6 +33,7 @@ limitations under the License.
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
#include "mlir-hlo/utils/lhlo_utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
@@ -86,46 +87,6 @@ static LogicalResult Verify(AbsOp op) {
 // AllToAllOp
 //===----------------------------------------------------------------------===//
 
-// Verifies replica groups attached to collective communication operations.
-// If the attribute is not empty, it must be a rank 2 tensor, and each replica
-// should appear exactly once. If `is_uniform_sized` is true, then we also check
-// that each group is of the same size. If the operation has
-// `use_global_device_id` set, then replica group cannot be empty.
-template <typename OpT>
-LogicalResult VerifyReplicaGroups(OpT op, bool is_uniform_sized) {
-  DenseIntElementsAttr attr = op.replica_groups();
-  auto replica_group_type = attr.getType().dyn_cast<RankedTensorType>();
-  if (!replica_group_type || replica_group_type.getRank() != 2 ||
-      !replica_group_type.getElementType().isInteger(/*width=*/64))
-    return op.emitOpError(
-        "replica groups should be a rank 2 tensor of 64 bit integers");
-  if (replica_group_type.getShape().equals(ArrayRef<int64_t>{0, 0}))
-    return success();
-  int64_t max_replica_id_seen = 0;
-  llvm::SmallSet<int64_t, 8> replica_seen;
-  for (int64_t id : attr.getValues<int64_t>()) {
-    if (is_uniform_sized && id == -1) {
-      return op.emitOpError("Invalid replica id -1");
-    }
-    if (id != -1) {
-      if (!replica_seen.insert(id).second) {
-        return op.emitOpError("replica id #") << id << " seen more than once";
-      }
-      max_replica_id_seen = std::max(max_replica_id_seen, id);
-    }
-  }
-  for (int64_t id = 0; id <= max_replica_id_seen; id++) {
-    if (!replica_seen.contains(id)) {
-      return op.emitOpError("replica id #")
-             << id << " not seen in replica groups";
-    }
-  }
-  return success();
-}
-
 // TODO(jurahul): Add verification for output shape.
 static LogicalResult Verify(AllGatherOp op) {
   return VerifyReplicaGroups(op, /*is_uniform_sized=*/true);
@@ -140,24 +101,7 @@ static LogicalResult Verify(AllToAllOp op) {
 // AllReduceOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult Verify(AllReduceOp op) {
-  if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false)))
-    return failure();
-  // AllReduce has variadic operands and results that have the same size.
-  // Each member of the operand should have the same type as the corresponding
-  // member of the result.
-  for (auto it : llvm::enumerate(
-           llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
-    Type operandType = std::get<0>(it.value());
-    Type resultType = std::get<1>(it.value());
-    if (operandType != resultType)
-      return op.emitOpError("requires operand #")
-             << it.index() << " (type: " << operandType << ") and result #"
-             << it.index() << " (type: " << resultType << ") to have same type";
-  }
-  return success();
-}
+static LogicalResult Verify(AllReduceOp op) { return VerifyAllReduce(op); }
 
 //===----------------------------------------------------------------------===//
 // AllReduceScatterOp