From 5fbdac34a96d61df1e0dfb45c5b894973b11b425 Mon Sep 17 00:00:00 2001
From: Chris Jones
Date: Tue, 15 Jun 2021 03:54:24 -0700
Subject: [PATCH] [XLA:GPU] Add AllReduce{Start,Done} to MLIR LHLO dialect.

PiperOrigin-RevId: 379455720
---
 BUILD                                          |   2 +
 .../mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td   |  27 +++++
 include/mlir-hlo/utils/lhlo_utils.h            | 100 ++++++++++++++++++
 lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc            |   9 ++
 lib/Dialect/mhlo/IR/lhlo_ops.cc                |  60 +----------
 5 files changed, 140 insertions(+), 58 deletions(-)
 create mode 100644 include/mlir-hlo/utils/lhlo_utils.h

diff --git a/BUILD b/BUILD
index 65b049c..ede0238 100644
--- a/BUILD
+++ b/BUILD
@@ -546,6 +546,7 @@ cc_library(
     hdrs = [
         "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h",
         "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h",
+        "include/mlir-hlo/utils/lhlo_utils.h",
     ],
     includes = ["include"],
     deps = [
@@ -590,6 +591,7 @@ cc_library(
         ":hlo_ops_base_structs",
         ":hlo_ops_common",
         ":infer_fusibility_op_interface",
+        ":lhlo",
         ":lhlo_gpu_ops_enums",
         ":lhlo_gpu_ops_inc_gen",
         ":lhlo_gpu_ops_structs",
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
index f087a99..2c0b18d 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
@@ -230,4 +230,31 @@ def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
     BoolAttr:$is_lower);
 }
 
+def LHLOGPU_AllReduceStartOp :
+  LHLOGPU_Op<"all_reduce_start", [SameOperandsElementType, SameVariadicOperandSize]> {
+  let summary = "AllReduceStart operator";
+  let description = [{
+    Performs an asynchronous custom reduction across replicas.
+  }];
+  let arguments = (ins
+    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
+    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
+    I64ElementsAttr:$replica_groups,
+    DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
+    OptionalAttr<ChannelHandle>:$channel_id,
+    DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
+  );
+  let regions = (region SizedRegion<1>:$computation);
+  let verifier = [{ return Verify(*this); }];
+}
+
+def LHLOGPU_AllReduceDoneOp:
+  LHLOGPU_Op<"all_reduce_done", [SameVariadicOperandSize]> {
+  let summary = "AllReduceDone operator";
+  let arguments = (ins
+    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
+    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results
+  );
+}
+
 #endif // LHLO_GPU_OPS
diff --git a/include/mlir-hlo/utils/lhlo_utils.h b/include/mlir-hlo/utils/lhlo_utils.h
new file mode 100644
index 0000000..0dbbb18
--- /dev/null
+++ b/include/mlir-hlo/utils/lhlo_utils.h
@@ -0,0 +1,100 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_LHLO_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_LHLO_UTILS_H_
+
+#include "llvm/ADT/SmallSet.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace lmhlo {
+
+// Verifies replica groups attached to collective communication operations.
+// If the attribute is not empty, it must be a rank 2 tensor, and each replica
+// must appear exactly once. If `is_uniform_sized` is true, we also check that
+// each group is of the same size. If the operation has
+// `use_global_device_ids` set, the replica groups cannot be empty.
+template <typename OpT>
+LogicalResult VerifyReplicaGroups(OpT op, bool is_uniform_sized) {
+  DenseIntElementsAttr attr = op.replica_groups();
+  auto replica_group_type = attr.getType().dyn_cast<RankedTensorType>();
+  if (!replica_group_type || replica_group_type.getRank() != 2 ||
+      !replica_group_type.getElementType().isInteger(/*width=*/64))
+    return op.emitOpError(
+        "replica groups should be a rank 2 tensor of 64 bit integers");
+
+  if (replica_group_type.getShape().equals(ArrayRef<int64_t>{0, 0})) {
+    if (op.use_global_device_ids()) {
+      return op.emitOpError(
+          "if `use_global_device_ids` is set, the replica groups cannot be "
+          "empty");
+    }
+    return success();
+  }
+
+  int64_t max_replica_id_seen = 0;
+  llvm::SmallSet<int64_t, 8> replica_seen;
+  for (int64_t id : attr.getValues<int64_t>()) {
+    // Replica groups are stored in a 2D tensor. If the op supports non-uniform
+    // groups, null replica IDs are stored as -1.
+    if (id == -1) {
+      if (is_uniform_sized) {
+        return op.emitOpError("Invalid replica id -1");
+      }
+      continue;
+    }
+
+    if (!replica_seen.insert(id).second) {
+      return op.emitOpError("replica id #") << id << " seen more than once";
+    }
+    max_replica_id_seen = std::max(max_replica_id_seen, id);
+  }
+
+  for (int64_t id = 0; id <= max_replica_id_seen; id++) {
+    if (!replica_seen.contains(id)) {
+      return op.emitOpError("replica id #")
+             << id << " not seen in replica groups";
+    }
+  }
+  return success();
+}
+
+template <typename OpT>
+static LogicalResult VerifyAllReduce(OpT op) {
+  if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false)))
+    return failure();
+
+  // AllReduce has variadic operands and results that have the same size.
+  // Each member of the operands should have the same type as the
+  // corresponding member of the results.
+  for (auto it : llvm::enumerate(
+           llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
+    Type operandType = std::get<0>(it.value());
+    Type resultType = std::get<1>(it.value());
+    if (operandType != resultType)
+      return op.emitOpError("requires operand #")
+             << it.index() << " (type: " << operandType << ") and result #"
+             << it.index() << " (type: " << resultType << ") to have same type";
+  }
+  return success();
+}
+
+}  // namespace lmhlo
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_LHLO_UTILS_H_
diff --git a/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc b/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
index 42c97ac..4f6d407 100644
--- a/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
+++ b/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" +#include "mlir-hlo/utils/lhlo_utils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -61,6 +62,14 @@ LmhloGpuDialect::LmhloGpuDialect(MLIRContext *context) using mlir::hlo::parseWindowAttributes; using mlir::hlo::printWindowAttributes; +//===----------------------------------------------------------------------===// +// AllReduceStartOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(AllReduceStartOp op) { + return lmhlo::VerifyAllReduce(op); +} + } // namespace lmhlo_gpu } // namespace mlir diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index 9e00c91..9fc472b 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -33,6 +33,7 @@ limitations under the License. #include "llvm/Support/FormatVariadic.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" +#include "mlir-hlo/utils/lhlo_utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" @@ -86,46 +87,6 @@ static LogicalResult Verify(AbsOp op) { // AllToAllOp //===----------------------------------------------------------------------===// -// Verifies replica groups attached to collective communication operations. -// If the attribute is not empty, it must be a rank 2 tensor, and each replica -// should appear exactly once. If `is_uniform_sized` is true, then we also check -// that each group is of the same size. If the operation has -// `use_global_device_id` set, then replica group cannot be empty. -template -LogicalResult VerifyReplicaGroups(OpT op, bool is_uniform_sized) { - DenseIntElementsAttr attr = op.replica_groups(); - auto replica_group_type = attr.getType().dyn_cast(); - if (!replica_group_type || replica_group_type.getRank() != 2 || - !replica_group_type.getElementType().isInteger(/*width=*/64)) - return op.emitOpError( - "replica groups should be a rank 2 tensor of 64 bit integers"); - - if (replica_group_type.getShape().equals(ArrayRef{0, 0})) - return success(); - - int64_t max_replica_id_seen = 0; - llvm::SmallSet replica_seen; - for (int64_t id : attr.getValues()) { - if (is_uniform_sized && id == -1) { - return op.emitOpError("Invalid replica id -1"); - } - if (id != -1) { - if (!replica_seen.insert(id).second) { - return op.emitOpError("replica id #") << id << " seen more than once"; - } - max_replica_id_seen = std::max(max_replica_id_seen, id); - } - } - - for (int64_t id = 0; id <= max_replica_id_seen; id++) { - if (!replica_seen.contains(id)) { - return op.emitOpError("replica id #") - << id << " not seen in replica groups"; - } - } - return success(); -} - // TODO(jurahul): Add verification for output shape. static LogicalResult Verify(AllGatherOp op) { return VerifyReplicaGroups(op, /*is_uniform_sized=*/true); @@ -140,24 +101,7 @@ static LogicalResult Verify(AllToAllOp op) { // AllReduceOp //===----------------------------------------------------------------------===// -static LogicalResult Verify(AllReduceOp op) { - if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false))) - return failure(); - - // AllReduce has variadic operands and results that have the same size. - // Each member of the operand should have the same type as the corresponding - // member of the result. 
- for (auto it : llvm::enumerate( - llvm::zip(op.operands().getTypes(), op.results().getTypes()))) { - Type operandType = std::get<0>(it.value()); - Type resultType = std::get<1>(it.value()); - if (operandType != resultType) - return op.emitOpError("requires operand #") - << it.index() << " (type: " << operandType << ") and result #" - << it.index() << " (type: " << resultType << ") to have same type"; - } - return success(); -} +static LogicalResult Verify(AllReduceOp op) { return VerifyAllReduce(op); } //===----------------------------------------------------------------------===// // AllReduceScatterOp
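
Note on the verifier logic (illustration only; nothing below is part of the
patch): the replica-group rules that VerifyReplicaGroups enforces do not
depend on MLIR, so they can be sketched as a small standalone C++ program.
The function ValidReplicaGroups and the sample groups are hypothetical names
for this sketch; it assumes the verifier's conventions: replica groups form a
rank-2 matrix, rows of non-uniform groups are padded with -1, and every
replica id from 0 to the maximum id seen must appear exactly once.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

// Hypothetical stand-in for VerifyReplicaGroups: returns true iff `groups`
// satisfies the rules the patch's verifier enforces.
bool ValidReplicaGroups(const std::vector<std::vector<int64_t>>& groups,
                        bool is_uniform_sized) {
  // An empty attribute is accepted here; the real verifier additionally
  // rejects it when `use_global_device_ids` is set.
  if (groups.empty()) return true;

  int64_t max_id_seen = 0;
  std::set<int64_t> seen;
  for (const std::vector<int64_t>& group : groups) {
    for (int64_t id : group) {
      if (id == -1) {
        // -1 is padding for ragged (non-uniform) groups; otherwise an error.
        if (is_uniform_sized) return false;
        continue;
      }
      // Each replica id may appear at most once across all groups.
      if (!seen.insert(id).second) return false;
      max_id_seen = std::max(max_id_seen, id);
    }
  }
  // No replica id in [0, max_id_seen] may be skipped.
  for (int64_t id = 0; id <= max_id_seen; ++id) {
    if (seen.count(id) == 0) return false;
  }
  return true;
}

int main() {
  std::cout << ValidReplicaGroups({{0, 1}, {2, 3}}, true) << "\n";  // 1
  std::cout << ValidReplicaGroups({{0, 1}, {3, 3}}, true) << "\n";  // 0: id 3 repeats, id 2 missing
  std::cout << ValidReplicaGroups({{0, 1, 2}, {3, -1, -1}}, false)
            << "\n";                                                // 1: -1 pads the ragged row
  std::cout << ValidReplicaGroups({{0, -1}, {1, 2}}, true) << "\n"; // 0: -1 in uniform groups
}

The padding convention is also why an op like AllGatherOp calls
VerifyReplicaGroups with is_uniform_sized=true while VerifyAllReduce passes
false: uniform-size collectives never produce ragged rows, so any -1 there is
a genuine error rather than padding.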