diff --git a/BUILD b/BUILD index 041ed1c..77d153f 100644 --- a/BUILD +++ b/BUILD @@ -232,6 +232,17 @@ gentbl( deps = [":hlo_ops_td_files"], ) +cc_library( + name = "hlo_ops_common", + srcs = ["lib/Dialect/mhlo/IR/hlo_ops_common.cc"], + hdrs = ["include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"], + includes = ["include"], + deps = [ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "lhlo_gpu_ops_structs", srcs = [ @@ -399,14 +410,15 @@ cc_library( ], includes = ["include"], deps = [ - "hlo_ops_pattern_gen", ":canonicalize_inc_gen", ":chlo_ops_inc_gen", ":convert_op_folder", ":hlo_ops_base_enums", ":hlo_ops_base_inc_gen", ":hlo_ops_base_structs", + ":hlo_ops_common", ":hlo_ops_inc_gen", + ":hlo_ops_pattern_gen", ":infer_fusibility_op_interface", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -443,6 +455,7 @@ cc_library( ":hlo_ops_base_enums", ":hlo_ops_base_inc_gen", ":hlo_ops_base_structs", + ":hlo_ops_common", ":lhlo_ops_inc_gen", ":lhlo_ops_structs_inc_gen", "@llvm-project//llvm:Support", diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h new file mode 100644 index 0000000..e5b4477 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h @@ -0,0 +1,34 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H_ + +// This file defines functionality shared between chlo/mhlo/lhlo dialects. + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" + +namespace mlir { +namespace hlo { + +// Verifies the source target pairs attached to collective permute. +LogicalResult VerifyCollectivePermuteSourceTargetPairs( + Operation *op, DenseIntElementsAttr attr); + +} // namespace hlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 64bc6a2..d1aa3fe 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -584,6 +584,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>, I64ElementsAttr:$source_target_pairs, OptionalAttr:$channel_id ); + let verifier = [{ return Verify(*this); }]; } def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp { diff --git a/lib/Dialect/mhlo/IR/CMakeLists.txt b/lib/Dialect/mhlo/IR/CMakeLists.txt index 575c578..35019fd 100644 --- a/lib/Dialect/mhlo/IR/CMakeLists.txt +++ b/lib/Dialect/mhlo/IR/CMakeLists.txt @@ -40,6 +40,9 @@ add_mlir_library(MhloInferFusibilityOpInterface MLIRinfer_fusibility_op_interfaceIncGen ) +add_mlir_library(HloOpsCommon + hlo_ops_common.cc +) add_mlir_dialect_library(MhloDialect hlo_ops.cc @@ -57,6 +60,7 @@ target_link_libraries(MhloDialect MLIRIR MhloInferFusibilityOpInterface MLIRMhloUtils + HloOpsCommon ) @@ -67,7 +71,11 @@ add_mlir_dialect_library(LmhloDialect DEPENDS MLIRlhlo_opsIncGen ) -target_link_libraries(LmhloDialect PUBLIC MLIRIR) +target_link_libraries(LmhloDialect + PUBLIC + MLIRIR + HloOpsCommon +) add_mlir_dialect_library(LmhloGPUDialect lhlo_gpu_ops.cc diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 401dc16..c2a0fa8 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -36,6 +36,7 @@ limitations under the License. #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" #include "mlir-hlo/utils/convert_op_folder.h" #include "mlir-hlo/utils/hlo_utils.h" #include "mlir/Dialect/Shape/IR/Shape.h" @@ -479,32 +480,8 @@ LogicalResult AbsOp::inferReturnTypes( //===----------------------------------------------------------------------===// static LogicalResult Verify(CollectivePermuteOp op) { - // Check that source target pair is Nx2 tensor. - auto type = op.source_target_pairs().getType().dyn_cast(); - if (type.getRank() != 2) - return op.emitError() << "expect source_target_pairs attribute to be of " - "rank 2, but got rank " - << type.getRank(); - if (type.getShape()[1] != 2) - return op.emitError() - << "expect source_target_pairs attribute of shape (N, 2), but got (" - << type.getShape() << ")"; - // Check source target pairs for duplicate sources or targets - llvm::DenseSet sources; - llvm::DenseSet targets; - for (auto i = op.source_target_pairs().begin(), - e = op.source_target_pairs().end(); - i != e; ++i) { - auto val = (*i).getSExtValue(); - if (i.getIndex() % 2 == 0) { - bool is_unique = sources.insert(val).second; - if (!is_unique) return op.emitError() << "duplicate sources not allowed."; - } else { - bool is_unique = targets.insert(val).second; - if (!is_unique) return op.emitError() << "duplicate targets not allowed."; - } - } - return success(); + return mlir::hlo::VerifyCollectivePermuteSourceTargetPairs( + op, op.source_target_pairs()); } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/mhlo/IR/hlo_ops_common.cc b/lib/Dialect/mhlo/IR/hlo_ops_common.cc new file mode 100644 index 0000000..06bb29e --- /dev/null +++ b/lib/Dialect/mhlo/IR/hlo_ops_common.cc @@ -0,0 +1,54 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" + +#include "mlir/IR/BuiltinTypes.h" + +namespace mlir { +namespace hlo { + +// Verifies the source target pairs attached to collective permute. +LogicalResult VerifyCollectivePermuteSourceTargetPairs( + Operation *op, DenseIntElementsAttr attr) { + auto type = attr.getType().dyn_cast(); + if (type.getRank() != 2) + return op->emitError() << "expect source_target_pairs attribute to be of " + "rank 2, but got rank " + << type.getRank(); + if (type.getShape()[1] != 2) + return op->emitError() + << "expect source_target_pairs attribute of shape (N, 2), but got (" + << type.getShape() << ")"; + // Check source target pairs for duplicate sources or targets. + llvm::DenseSet sources; + llvm::DenseSet targets; + for (auto i = attr.begin(), e = attr.end(); i != e; ++i) { + auto val = (*i).getSExtValue(); + if (i.getIndex() % 2 == 0) { + bool is_unique = sources.insert(val).second; + if (!is_unique) + return op->emitError() << "duplicate sources not allowed."; + } else { + bool is_unique = targets.insert(val).second; + if (!is_unique) + return op->emitError() << "duplicate targets not allowed."; + } + } + return success(); +} + +} // namespace hlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index 8614bd5..960e643 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -31,6 +31,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" @@ -132,6 +133,15 @@ static LogicalResult Verify(AllReduceOp op) { return success(); } +//===----------------------------------------------------------------------===// +// CollectivePermuteOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(CollectivePermuteOp op) { + return mlir::hlo::VerifyCollectivePermuteSourceTargetPairs( + op, op.source_target_pairs()); +} + //===----------------------------------------------------------------------===// // ConstOp. //===----------------------------------------------------------------------===// diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir index 5d9113c..a7fc702 100644 --- a/tests/lhlo_ops.mlir +++ b/tests/lhlo_ops.mlir @@ -788,6 +788,36 @@ func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128 // ----- +func @invalid_collective_permute(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () { + // expected-error@+1{{expect source_target_pairs attribute of shape (N, 2), but got (1, 3)}} + "lmhlo.collective_permute"(%arg0, %arg_out) { + source_target_pairs = dense<[[2, 3, 4]]> : tensor<1x3xi64> + } : (memref<128x32xf32>, memref<128x32xf32>) -> () + return +} + +// ----- + +func @invalid_collective_permute(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () { + // expected-error@+1{{duplicate sources not allowed.}} + "lmhlo.collective_permute"(%arg0, %arg_out) { + source_target_pairs = dense<[[1,2], [1,3]]> : tensor<2x2xi64> + } : (memref<128x32xf32>, memref<128x32xf32>) -> () + return +} + +// ----- + +func @invalid_collective_permute(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () { + // expected-error@+1{{duplicate targets not allowed.}} + "lmhlo.collective_permute"(%arg0, %arg_out) { + source_target_pairs = dense<[[1,2], [0,2]]> : tensor<2x2xi64> + } : (memref<128x32xf32>, memref<128x32xf32>) -> () + return +} + +// ----- + // CHECK-LABEL: func @fft_memrefs func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex>) -> () { "lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex>) -> ()