[HLO] Add LMHLO CollectivePermute verification.

- Extract verification of source target pairs attached to collective permute into a common
  helper function and use that to verify both MHLO and LMHLO variants.
- Change MlirGpuTestBase::ParseMlirModule to allow returning back a failure, and use
  that to update the mlir_gpu_compile_test to check the new behavior.

PiperOrigin-RevId: 362156962
This commit is contained in:
Rahul Joshi 2021-03-10 15:36:22 -08:00 committed by TensorFlow MLIR Team
parent 4f16b10ce2
commit 9902e6ee32
8 changed files with 155 additions and 28 deletions

15
BUILD
View File

@ -232,6 +232,17 @@ gentbl(
deps = [":hlo_ops_td_files"], deps = [":hlo_ops_td_files"],
) )
cc_library(
name = "hlo_ops_common",
srcs = ["lib/Dialect/mhlo/IR/hlo_ops_common.cc"],
hdrs = ["include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"],
includes = ["include"],
deps = [
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
cc_library( cc_library(
name = "lhlo_gpu_ops_structs", name = "lhlo_gpu_ops_structs",
srcs = [ srcs = [
@ -399,14 +410,15 @@ cc_library(
], ],
includes = ["include"], includes = ["include"],
deps = [ deps = [
"hlo_ops_pattern_gen",
":canonicalize_inc_gen", ":canonicalize_inc_gen",
":chlo_ops_inc_gen", ":chlo_ops_inc_gen",
":convert_op_folder", ":convert_op_folder",
":hlo_ops_base_enums", ":hlo_ops_base_enums",
":hlo_ops_base_inc_gen", ":hlo_ops_base_inc_gen",
":hlo_ops_base_structs", ":hlo_ops_base_structs",
":hlo_ops_common",
":hlo_ops_inc_gen", ":hlo_ops_inc_gen",
":hlo_ops_pattern_gen",
":infer_fusibility_op_interface", ":infer_fusibility_op_interface",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
@ -443,6 +455,7 @@ cc_library(
":hlo_ops_base_enums", ":hlo_ops_base_enums",
":hlo_ops_base_inc_gen", ":hlo_ops_base_inc_gen",
":hlo_ops_base_structs", ":hlo_ops_base_structs",
":hlo_ops_common",
":lhlo_ops_inc_gen", ":lhlo_ops_inc_gen",
":lhlo_ops_structs_inc_gen", ":lhlo_ops_structs_inc_gen",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",

View File

@ -0,0 +1,34 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H_
// This file defines functionality shared between chlo/mhlo/lhlo dialects.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
namespace mlir {
namespace hlo {
// Verifies the source target pairs attached to collective permute.
LogicalResult VerifyCollectivePermuteSourceTargetPairs(
Operation *op, DenseIntElementsAttr attr);
} // namespace hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H_

View File

@ -584,6 +584,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
I64ElementsAttr:$source_target_pairs, I64ElementsAttr:$source_target_pairs,
OptionalAttr<ChannelHandle>:$channel_id OptionalAttr<ChannelHandle>:$channel_id
); );
let verifier = [{ return Verify(*this); }];
} }
def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp { def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp {

View File

@ -40,6 +40,9 @@ add_mlir_library(MhloInferFusibilityOpInterface
MLIRinfer_fusibility_op_interfaceIncGen MLIRinfer_fusibility_op_interfaceIncGen
) )
add_mlir_library(HloOpsCommon
hlo_ops_common.cc
)
add_mlir_dialect_library(MhloDialect add_mlir_dialect_library(MhloDialect
hlo_ops.cc hlo_ops.cc
@ -57,6 +60,7 @@ target_link_libraries(MhloDialect
MLIRIR MLIRIR
MhloInferFusibilityOpInterface MhloInferFusibilityOpInterface
MLIRMhloUtils MLIRMhloUtils
HloOpsCommon
) )
@ -67,7 +71,11 @@ add_mlir_dialect_library(LmhloDialect
DEPENDS DEPENDS
MLIRlhlo_opsIncGen MLIRlhlo_opsIncGen
) )
target_link_libraries(LmhloDialect PUBLIC MLIRIR) target_link_libraries(LmhloDialect
PUBLIC
MLIRIR
HloOpsCommon
)
add_mlir_dialect_library(LmhloGPUDialect add_mlir_dialect_library(LmhloGPUDialect
lhlo_gpu_ops.cc lhlo_gpu_ops.cc

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h" #include "llvm/Support/MathExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir-hlo/utils/convert_op_folder.h" #include "mlir-hlo/utils/convert_op_folder.h"
#include "mlir-hlo/utils/hlo_utils.h" #include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
@ -479,32 +480,8 @@ LogicalResult AbsOp::inferReturnTypes(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult Verify(CollectivePermuteOp op) { static LogicalResult Verify(CollectivePermuteOp op) {
// Check that source target pair is Nx2 tensor. return mlir::hlo::VerifyCollectivePermuteSourceTargetPairs(
auto type = op.source_target_pairs().getType().dyn_cast<RankedTensorType>(); op, op.source_target_pairs());
if (type.getRank() != 2)
return op.emitError() << "expect source_target_pairs attribute to be of "
"rank 2, but got rank "
<< type.getRank();
if (type.getShape()[1] != 2)
return op.emitError()
<< "expect source_target_pairs attribute of shape (N, 2), but got ("
<< type.getShape() << ")";
// Check source target pairs for duplicate sources or targets
llvm::DenseSet<int64_t> sources;
llvm::DenseSet<int64_t> targets;
for (auto i = op.source_target_pairs().begin(),
e = op.source_target_pairs().end();
i != e; ++i) {
auto val = (*i).getSExtValue();
if (i.getIndex() % 2 == 0) {
bool is_unique = sources.insert(val).second;
if (!is_unique) return op.emitError() << "duplicate sources not allowed.";
} else {
bool is_unique = targets.insert(val).second;
if (!is_unique) return op.emitError() << "duplicate targets not allowed.";
}
}
return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -0,0 +1,54 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir/IR/BuiltinTypes.h"
namespace mlir {
namespace hlo {
// Verifies the source target pairs attached to collective permute.
LogicalResult VerifyCollectivePermuteSourceTargetPairs(
Operation *op, DenseIntElementsAttr attr) {
auto type = attr.getType().dyn_cast<RankedTensorType>();
if (type.getRank() != 2)
return op->emitError() << "expect source_target_pairs attribute to be of "
"rank 2, but got rank "
<< type.getRank();
if (type.getShape()[1] != 2)
return op->emitError()
<< "expect source_target_pairs attribute of shape (N, 2), but got ("
<< type.getShape() << ")";
// Check source target pairs for duplicate sources or targets.
llvm::DenseSet<int64_t> sources;
llvm::DenseSet<int64_t> targets;
for (auto i = attr.begin(), e = attr.end(); i != e; ++i) {
auto val = (*i).getSExtValue();
if (i.getIndex() % 2 == 0) {
bool is_unique = sources.insert(val).second;
if (!is_unique)
return op->emitError() << "duplicate sources not allowed.";
} else {
bool is_unique = targets.insert(val).second;
if (!is_unique)
return op->emitError() << "duplicate targets not allowed.";
}
}
return success();
}
} // namespace hlo
} // namespace mlir

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
@ -132,6 +133,15 @@ static LogicalResult Verify(AllReduceOp op) {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// CollectivePermuteOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(CollectivePermuteOp op) {
return mlir::hlo::VerifyCollectivePermuteSourceTargetPairs(
op, op.source_target_pairs());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ConstOp. // ConstOp.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -788,6 +788,36 @@ func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128
// ----- // -----
func @invalid_collective_permute(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
// expected-error@+1{{expect source_target_pairs attribute of shape (N, 2), but got (1, 3)}}
"lmhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[2, 3, 4]]> : tensor<1x3xi64>
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
return
}
// -----
func @invalid_collective_permute(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
// expected-error@+1{{duplicate sources not allowed.}}
"lmhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[1,2], [1,3]]> : tensor<2x2xi64>
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
return
}
// -----
func @invalid_collective_permute(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
// expected-error@+1{{duplicate targets not allowed.}}
"lmhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[1,2], [0,2]]> : tensor<2x2xi64>
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @fft_memrefs // CHECK-LABEL: func @fft_memrefs
func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () { func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
"lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> () "lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()