[MLIR:HLO] Extend AllReduce to support multiple inputs and results (to model tuples).

- Instead of SameTypeOperands, add custom verification to check if operands and
  results pairwise have the same type.

PiperOrigin-RevId: 353986341
This commit is contained in:
Rahul Joshi 2021-01-26 17:23:49 -08:00 committed by TensorFlow MLIR Team
parent 471fc63c11
commit 44deae2aa1
3 changed files with 72 additions and 6 deletions

View File

@ -544,17 +544,18 @@ def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
);
}
def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameVariadicOperandSize]>,
BASE_HLO_AllReduceOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
I64ElementsAttr:$replica_groups,
DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
OptionalAttr<ChannelHandle>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
);
let regions = (region SizedRegion<1>:$computation);
let verifier = [{ return Verify(*this); }];
}
def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,

View File

@ -48,7 +48,7 @@ limitations under the License.
namespace mlir {
namespace lmhlo {
LmhloDialect::LmhloDialect(MLIRContext *context)
LmhloDialect::LmhloDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context, TypeID::get<LmhloDialect>()) {
addOperations<
#define GET_OP_LIST
@ -56,6 +56,41 @@ LmhloDialect::LmhloDialect(MLIRContext *context)
>();
}
//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(AllReduceOp op) {
// AllReduce had variadic operands and results that have the same size.
// Each memeber of the operand should have the same type as the corresponding
// member of the result.
for (auto it : llvm::enumerate(
llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
Type operandType = std::get<0>(it.value());
Type resultType = std::get<1>(it.value());
if (operandType != resultType)
return op.emitOpError("requires operand #")
<< it.index() << " (type: " << operandType << ") and result #"
<< it.index() << " (type: " << resultType << ") to have same type";
}
// Since AllReduce has a single reduction computation attached to it (which is
// applied over all the operands and results), they all need to have the same
// element type. Since we already check that each operand and corresponding
// result has the same type, its sufficient to check just the memref element
// type for each operands.
Type elementType =
op.operands().front().getType().cast<MemRefType>().getElementType();
bool allMatch = llvm::all_of(
op.operands().drop_front().getType(), [elementType](Type type) {
return type.cast<MemRefType>().getElementType() == elementType;
});
if (!allMatch)
return op.emitOpError("requires all operands to have same element type");
return success();
}
//===----------------------------------------------------------------------===//
// ConstOp.
//===----------------------------------------------------------------------===//
@ -99,10 +134,10 @@ namespace lmhlo {
// TODO(cheshire): Support folding, reuse code from hlo_ops.cc.
void FusionOp::build(OpBuilder &builder, OperationState &result,
void FusionOp::build(OpBuilder& builder, OperationState& result,
ArrayRef<NamedAttribute> attributes) {
result.addAttributes(attributes);
Region *bodyRegion = result.addRegion();
Region* bodyRegion = result.addRegion();
FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
}

View File

@ -2,6 +2,36 @@
// -----
func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf32>) {
// expected-error@+1 {{requires operand #1 (type: 'memref<3xf32>') and result #1 (type: 'memref<2xf32>') to have same type}}
"lmhlo.all_reduce"(%input0, %input1, %input0, %input0) ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%add = mhlo.add %arg0, %arg1 : tensor<f32>
"mhlo.return"(%add) : (tensor<f32>) -> ()
})
{channel_id = {handle = 1 : i64, type = 0 : i64}, constrain_layout = false,
replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>,
use_global_device_ids = false} : (memref<2xf32>, memref<3xf32>, memref<2xf32>, memref<2xf32>) -> ()
return
}
// -----
func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf16>) {
// expected-error@+1 {{requires all operands to have same element type}}
"lmhlo.all_reduce"(%input0, %input1, %input0, %input1) ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%add = mhlo.add %arg0, %arg1 : tensor<f32>
"mhlo.return"(%add) : (tensor<f32>) -> ()
})
{channel_id = {handle = 1 : i64, type = 0 : i64}, constrain_layout = false,
replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>,
use_global_device_ids = false} : (memref<2xf32>, memref<3xf16>, memref<2xf32>, memref<3xf16>) -> ()
return
}
// -----
// CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()