[HLO] Add LMHLO CollectivePermute verification.
- Extract verification of source target pairs attached to collective permute into a common helper function and use that to verify both MHLO and LMHLO variants. - Change MlirGpuTestBase::ParseMlirModule to allow returning back a failure, and use that to update the mlir_gpu_compile_test to check the new behavior. PiperOrigin-RevId: 362156962
This commit is contained in:
parent
4f16b10ce2
commit
9902e6ee32
15
BUILD
15
BUILD
|
@ -232,6 +232,17 @@ gentbl(
|
|||
deps = [":hlo_ops_td_files"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_ops_common",
|
||||
srcs = ["lib/Dialect/mhlo/IR/hlo_ops_common.cc"],
|
||||
hdrs = ["include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lhlo_gpu_ops_structs",
|
||||
srcs = [
|
||||
|
@ -399,14 +410,15 @@ cc_library(
|
|||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
"hlo_ops_pattern_gen",
|
||||
":canonicalize_inc_gen",
|
||||
":chlo_ops_inc_gen",
|
||||
":convert_op_folder",
|
||||
":hlo_ops_base_enums",
|
||||
":hlo_ops_base_inc_gen",
|
||||
":hlo_ops_base_structs",
|
||||
":hlo_ops_common",
|
||||
":hlo_ops_inc_gen",
|
||||
":hlo_ops_pattern_gen",
|
||||
":infer_fusibility_op_interface",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
|
@ -443,6 +455,7 @@ cc_library(
|
|||
":hlo_ops_base_enums",
|
||||
":hlo_ops_base_inc_gen",
|
||||
":hlo_ops_base_structs",
|
||||
":hlo_ops_common",
|
||||
":lhlo_ops_inc_gen",
|
||||
":lhlo_ops_structs_inc_gen",
|
||||
"@llvm-project//llvm:Support",
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H_
|
||||
|
||||
// This file defines functionality shared between chlo/mhlo/lhlo dialects.
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace hlo {
|
||||
|
||||
// Verifies the source target pairs attached to collective permute.
|
||||
LogicalResult VerifyCollectivePermuteSourceTargetPairs(
|
||||
Operation *op, DenseIntElementsAttr attr);
|
||||
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H_
|
|
@ -584,6 +584,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
|
|||
I64ElementsAttr:$source_target_pairs,
|
||||
OptionalAttr<ChannelHandle>:$channel_id
|
||||
);
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp {
|
||||
|
|
|
@ -40,6 +40,9 @@ add_mlir_library(MhloInferFusibilityOpInterface
|
|||
MLIRinfer_fusibility_op_interfaceIncGen
|
||||
)
|
||||
|
||||
add_mlir_library(HloOpsCommon
|
||||
hlo_ops_common.cc
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MhloDialect
|
||||
hlo_ops.cc
|
||||
|
@ -57,6 +60,7 @@ target_link_libraries(MhloDialect
|
|||
MLIRIR
|
||||
MhloInferFusibilityOpInterface
|
||||
MLIRMhloUtils
|
||||
HloOpsCommon
|
||||
)
|
||||
|
||||
|
||||
|
@ -67,7 +71,11 @@ add_mlir_dialect_library(LmhloDialect
|
|||
DEPENDS
|
||||
MLIRlhlo_opsIncGen
|
||||
)
|
||||
target_link_libraries(LmhloDialect PUBLIC MLIRIR)
|
||||
target_link_libraries(LmhloDialect
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
HloOpsCommon
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(LmhloGPUDialect
|
||||
lhlo_gpu_ops.cc
|
||||
|
|
|
@ -36,6 +36,7 @@ limitations under the License.
|
|||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
|
||||
#include "mlir-hlo/utils/convert_op_folder.h"
|
||||
#include "mlir-hlo/utils/hlo_utils.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
|
@ -479,32 +480,8 @@ LogicalResult AbsOp::inferReturnTypes(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(CollectivePermuteOp op) {
|
||||
// Check that source target pair is Nx2 tensor.
|
||||
auto type = op.source_target_pairs().getType().dyn_cast<RankedTensorType>();
|
||||
if (type.getRank() != 2)
|
||||
return op.emitError() << "expect source_target_pairs attribute to be of "
|
||||
"rank 2, but got rank "
|
||||
<< type.getRank();
|
||||
if (type.getShape()[1] != 2)
|
||||
return op.emitError()
|
||||
<< "expect source_target_pairs attribute of shape (N, 2), but got ("
|
||||
<< type.getShape() << ")";
|
||||
// Check source target pairs for duplicate sources or targets
|
||||
llvm::DenseSet<int64_t> sources;
|
||||
llvm::DenseSet<int64_t> targets;
|
||||
for (auto i = op.source_target_pairs().begin(),
|
||||
e = op.source_target_pairs().end();
|
||||
i != e; ++i) {
|
||||
auto val = (*i).getSExtValue();
|
||||
if (i.getIndex() % 2 == 0) {
|
||||
bool is_unique = sources.insert(val).second;
|
||||
if (!is_unique) return op.emitError() << "duplicate sources not allowed.";
|
||||
} else {
|
||||
bool is_unique = targets.insert(val).second;
|
||||
if (!is_unique) return op.emitError() << "duplicate targets not allowed.";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
return mlir::hlo::VerifyCollectivePermuteSourceTargetPairs(
|
||||
op, op.source_target_pairs());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace hlo {
|
||||
|
||||
// Verifies the source target pairs attached to collective permute.
|
||||
LogicalResult VerifyCollectivePermuteSourceTargetPairs(
|
||||
Operation *op, DenseIntElementsAttr attr) {
|
||||
auto type = attr.getType().dyn_cast<RankedTensorType>();
|
||||
if (type.getRank() != 2)
|
||||
return op->emitError() << "expect source_target_pairs attribute to be of "
|
||||
"rank 2, but got rank "
|
||||
<< type.getRank();
|
||||
if (type.getShape()[1] != 2)
|
||||
return op->emitError()
|
||||
<< "expect source_target_pairs attribute of shape (N, 2), but got ("
|
||||
<< type.getShape() << ")";
|
||||
// Check source target pairs for duplicate sources or targets.
|
||||
llvm::DenseSet<int64_t> sources;
|
||||
llvm::DenseSet<int64_t> targets;
|
||||
for (auto i = attr.begin(), e = attr.end(); i != e; ++i) {
|
||||
auto val = (*i).getSExtValue();
|
||||
if (i.getIndex() % 2 == 0) {
|
||||
bool is_unique = sources.insert(val).second;
|
||||
if (!is_unique)
|
||||
return op->emitError() << "duplicate sources not allowed.";
|
||||
} else {
|
||||
bool is_unique = targets.insert(val).second;
|
||||
if (!is_unique)
|
||||
return op->emitError() << "duplicate targets not allowed.";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
|
@ -31,6 +31,7 @@ limitations under the License.
|
|||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
@ -132,6 +133,15 @@ static LogicalResult Verify(AllReduceOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CollectivePermuteOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(CollectivePermuteOp op) {
|
||||
return mlir::hlo::VerifyCollectivePermuteSourceTargetPairs(
|
||||
op, op.source_target_pairs());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -788,6 +788,36 @@ func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128
|
|||
|
||||
// -----
|
||||
|
||||
func @invalid_collective_permute(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
|
||||
// expected-error@+1{{expect source_target_pairs attribute of shape (N, 2), but got (1, 3)}}
|
||||
"lmhlo.collective_permute"(%arg0, %arg_out) {
|
||||
source_target_pairs = dense<[[2, 3, 4]]> : tensor<1x3xi64>
|
||||
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_collective_permute(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
|
||||
// expected-error@+1{{duplicate sources not allowed.}}
|
||||
"lmhlo.collective_permute"(%arg0, %arg_out) {
|
||||
source_target_pairs = dense<[[1,2], [1,3]]> : tensor<2x2xi64>
|
||||
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_collective_permute(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
|
||||
// expected-error@+1{{duplicate targets not allowed.}}
|
||||
"lmhlo.collective_permute"(%arg0, %arg_out) {
|
||||
source_target_pairs = dense<[[1,2], [0,2]]> : tensor<2x2xi64>
|
||||
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fft_memrefs
|
||||
func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
|
||||
"lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
|
||||
|
|
Loading…
Reference in New Issue