diff --git a/BUILD b/BUILD index b4dfa36..6fad5e9 100644 --- a/BUILD +++ b/BUILD @@ -201,14 +201,34 @@ gentbl( ], ) +gentbl( + name = "lhlo_ops_structs_inc_gen", + strip_include_prefix = "include", + tbl_outs = [ + ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"), + ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td", + td_includes = [ + "external/mlir-hlo/include", + "include", + ], + td_relative_includes = [ + "include", + ], + td_srcs = [ + ":hlo_ops_td_files", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td", + ], +) + gentbl( name = "lhlo_ops_inc_gen", strip_include_prefix = "include", tbl_outs = [ ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"), ("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"), - ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"), - ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td", @@ -219,7 +239,11 @@ gentbl( td_relative_includes = [ "include", ], - td_srcs = [":hlo_ops_td_files"], + td_srcs = [ + ":hlo_ops_td_files", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td", + ], ) gentbl( @@ -476,7 +500,11 @@ cc_library( srcs = [ "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc", "lib/Dialect/mhlo/IR/lhlo_ops.cc", + "lib/Dialect/mhlo/IR/lhlo_ops_structs.cc", ], hdrs = [ "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h", diff --git a/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt b/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt index d64f7cc..1e412f4 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt +++ b/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt @@ -25,7 +25,6 @@ function(add_mlir_hlo_dialect dialect dialect_namespace) endfunction() add_mlir_hlo_dialect(chlo_ops chlo) -add_mlir_hlo_dialect(lhlo_ops lmhlo) set(LLVM_TARGET_DEFINITIONS hlo_ops.td) mlir_tablegen(hlo_ops.h.inc -gen-op-decls) @@ -36,6 +35,15 @@ mlir_tablegen(hlo_ops_base_enums.h.inc -gen-enum-decls) mlir_tablegen(hlo_ops_base_enums.cc.inc -gen-enum-defs) add_public_tablegen_target(MLIRhlo_opsIncGen) +set(LLVM_TARGET_DEFINITIONS lhlo_ops.td) +mlir_tablegen(lhlo_ops.h.inc -gen-op-decls) +mlir_tablegen(lhlo_ops.cc.inc -gen-op-defs) +set(LLVM_TARGET_DEFINITIONS lhlo_ops_structs.td) +mlir_tablegen(lhlo_ops_structs.h.inc -gen-struct-attr-decls) +mlir_tablegen(lhlo_ops_structs.cc.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRlhlo_opsIncGen) +add_dependencies(mlir-headers MLIRlhlo_opsIncGen) + set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td) mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls) mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td new file mode 100644 index 0000000..7cddf4e --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td @@ -0,0 +1,27 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef LHLO_DIALECT +#define LHLO_DIALECT + +include "mlir/IR/OpBase.td" + +// We define the dialect here so that both structs and ops can refer to it. +def LHLO_Dialect : Dialect { + let name = "lmhlo"; + let cppNamespace = "::mlir::lmhlo"; +} + +#endif // LHLO_DIALECT diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index 7dfbfd6..deee865 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index d723aad..e2f835a 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -37,12 +37,9 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" - -def LHLO_Dialect : Dialect { - let name = "lmhlo"; - let cppNamespace = "::mlir::lmhlo"; -} +include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td" //===----------------------------------------------------------------------===// // LMHLO nullary op definitions. @@ -274,8 +271,10 @@ def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]>, Arg, "", [MemWrite]>:$output, StrAttr:$call_target_name, DefaultValuedAttr:$has_side_effect, - DefaultValuedAttr:$backend_config + DefaultValuedAttr:$backend_config, + OptionalAttr:$target_arg_mapping ); + let verifier = [{ return Verify(*this); }]; } //===----------------------------------------------------------------------===// diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h new file mode 100644 index 0000000..8b14843 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h @@ -0,0 +1,30 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines structures used in LMHLO dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_STRUCTS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_STRUCTS_H_ + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/Types.h" + +// Order matters, this .inc header is not self-contained, and relies on the +// #includes above. +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_STRUCTS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td new file mode 100644 index 0000000..d9ae1ca --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td @@ -0,0 +1,40 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef LHLO_OPS_STRUCTS +#define LHLO_OPS_STRUCTS + +include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td" + +// This structure defines information about how arguments to the LHLO custom +// call operation relate to the arguments of the target function. In most cases +// the mapping will be 1:1, but in certain cases, it may not be. As an example, +// tokens are not represented in the LHLO dialect, but the custom call target +// might still expect to see buffer arguments corresponding to tokens, in which +// case the mapping will not be 1:1. +def CustomCallTargetArgMapping : StructAttr<"CustomCallTargetArgMapping", + LHLO_Dialect, [ + // number of buffer expected by the target for arguments. + StructFieldAttr<"num_args", I64Attr>, + // number of buffer expected by the target for results. + StructFieldAttr<"num_results", I64Attr>, + // map each custom call op arg to its position in target args. + StructFieldAttr<"args_to_target_args", I64ArrayAttr>, + // map each custom call op arg to its position in target results. + StructFieldAttr<"results_to_target_results", I64ArrayAttr>]> { + let summary = "Custom call operands to target argument mapping info"; +} + +#endif // LHLO_OPS_STRUCTS diff --git a/lib/Dialect/mhlo/IR/CMakeLists.txt b/lib/Dialect/mhlo/IR/CMakeLists.txt index 98983e7..575c578 100644 --- a/lib/Dialect/mhlo/IR/CMakeLists.txt +++ b/lib/Dialect/mhlo/IR/CMakeLists.txt @@ -62,6 +62,7 @@ target_link_libraries(MhloDialect add_mlir_dialect_library(LmhloDialect lhlo_ops.cc + lhlo_ops_structs.cc DEPENDS MLIRlhlo_opsIncGen diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index 048623a..8614bd5 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include + #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -162,6 +164,56 @@ void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results, results.insert(context); } +//===----------------------------------------------------------------------===// +// CustomCallOp. +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(CustomCallOp op) { + if (op.target_arg_mapping()) { + CustomCallTargetArgMapping mapping = *op.target_arg_mapping(); + auto verify_mapping = [&](int64_t target_num, size_t op_num, + ArrayAttr mapping, + StringRef kind) -> LogicalResult { + if (target_num < op_num) + return op.emitOpError("number of target " + kind + " (") + << target_num << ") cannot be less than the number of " << kind + << "(" << op_num << ") for the operation"; + + if (mapping.size() != op_num) + return op.emitOpError("number of entries in the mapping for " + kind + + " (") + << mapping.size() << ") should match the number of " << kind + << " for the operation (" << op_num << ")"; + + std::unordered_set entries; + // Each entry in the mapping should be < target_num and an entry cannot + // appear more than once. + for (Attribute entry : mapping) { + int64_t int_entry = entry.cast().getInt(); + // ODS verification will ensure that these entries are integers. + if (!entries.insert(int_entry).second) + return op.emitOpError("entry ") + << int_entry + << " cannot appear more than once in the mapping for " << kind; + if (int_entry < 0 || int_entry >= target_num) + return op.emitOpError( + "entries in mapping for " + kind + + " must be >= 0 and less than target's number of " + kind + + " (") + << target_num << ")"; + } + return success(); + }; + if (failed(verify_mapping(mapping.num_args().getInt(), op.args().size(), + mapping.args_to_target_args(), "args")) || + failed(verify_mapping(mapping.num_results().getInt(), + op.output().size(), + mapping.results_to_target_results(), "results"))) + return failure(); + } + return success(); +} + } // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc b/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc index 83dd4e6..b1b453b 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc @@ -13,5 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc" -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h" + +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc" diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir index 021ea31..5d9113c 100644 --- a/tests/lhlo_ops.mlir +++ b/tests/lhlo_ops.mlir @@ -989,3 +989,136 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, }) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> () return } + +// ----- + +// CHECK-LABEL: func @valid_custom_call +func @valid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + backend_config = "", + call_target_name = "foo", + has_side_effects = false, + operand_segment_sizes = dense<2> : vector<2xi32>, + target_arg_mapping = { + num_args = 4 : i64, + num_results = 3 : i64, + args_to_target_args = [0,3], + results_to_target_results = [1,2] + } + } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { + // expected-error @+1 {{number of entries in the mapping for args (1) should match the number of args for the operation (2)}} + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + backend_config = "", + call_target_name = "foo", + has_side_effects = false, + operand_segment_sizes = dense<2> : vector<2xi32>, + target_arg_mapping = { + num_args = 4 : i64, + num_results = 3 : i64, + args_to_target_args = [0], + results_to_target_results = [1,2] + } + } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { + // expected-error @+1 {{number of entries in the mapping for results (1) should match the number of results for the operation (2)}} + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + backend_config = "", + call_target_name = "foo", + has_side_effects = false, + operand_segment_sizes = dense<2> : vector<2xi32>, + target_arg_mapping = { + num_args = 4 : i64, + num_results = 3 : i64, + args_to_target_args = [0, 3], + results_to_target_results = [1] + } + } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { + // expected-error @+1 {{entry 0 cannot appear more than once in the mapping for args}} + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + backend_config = "", + call_target_name = "foo", + has_side_effects = false, + operand_segment_sizes = dense<2> : vector<2xi32>, + target_arg_mapping = { + num_args = 4 : i64, + num_results = 3 : i64, + args_to_target_args = [0, 0], + results_to_target_results = [1, 2] + } + } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { + // expected-error @+1 {{entry 1 cannot appear more than once in the mapping for results}} + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + backend_config = "", + call_target_name = "foo", + has_side_effects = false, + operand_segment_sizes = dense<2> : vector<2xi32>, + target_arg_mapping = { + num_args = 4 : i64, + num_results = 3 : i64, + args_to_target_args = [0, 1], + results_to_target_results = [1, 1] + } + } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { + // expected-error @+1 {{entries in mapping for args must be >= 0 and less than target's number of args (4)}} + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + backend_config = "", + call_target_name = "foo", + has_side_effects = false, + operand_segment_sizes = dense<2> : vector<2xi32>, + target_arg_mapping = { + num_args = 4 : i64, + num_results = 3 : i64, + args_to_target_args = [0, 6], + results_to_target_results = [1, 2] + } + } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { + // expected-error @+1 {{entries in mapping for results must be >= 0 and less than target's number of results (3)}} + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + backend_config = "", + call_target_name = "foo", + has_side_effects = false, + operand_segment_sizes = dense<2> : vector<2xi32>, + target_arg_mapping = { + num_args = 4 : i64, + num_results = 3 : i64, + args_to_target_args = [0, 1], + results_to_target_results = [1, 3] + } + } : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + return +}