[MLIR:LHLO] Add optional call target arg mapping to LMHLO CustomCall operations.
- XLA:HLO -> LMHLO conversion drops all token arguments and return values, however custom calls that users write still expect to get buffer pointers for these token types. - To be able to support this, add an optional call target argument mapping attribute to LMHLO custom calls. When this attribute is present, it indicates the number of arguments and returns that the custom call expects and also indicates which LMHLO arg() or output() maps to which arg or result number of the custom call. PiperOrigin-RevId: 358826664
This commit is contained in:
parent
a9cc1dcfa0
commit
5adb7c6e12
34
BUILD
34
BUILD
|
@ -201,14 +201,34 @@ gentbl(
|
|||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "lhlo_ops_structs_inc_gen",
|
||||
strip_include_prefix = "include",
|
||||
tbl_outs = [
|
||||
("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"),
|
||||
("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td",
|
||||
td_includes = [
|
||||
"external/mlir-hlo/include",
|
||||
"include",
|
||||
],
|
||||
td_relative_includes = [
|
||||
"include",
|
||||
],
|
||||
td_srcs = [
|
||||
":hlo_ops_td_files",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "lhlo_ops_inc_gen",
|
||||
strip_include_prefix = "include",
|
||||
tbl_outs = [
|
||||
("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"),
|
||||
("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"),
|
||||
("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"),
|
||||
("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
|
||||
|
@ -219,7 +239,11 @@ gentbl(
|
|||
td_relative_includes = [
|
||||
"include",
|
||||
],
|
||||
td_srcs = [":hlo_ops_td_files"],
|
||||
td_srcs = [
|
||||
":hlo_ops_td_files",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
|
@ -476,7 +500,11 @@ cc_library(
|
|||
srcs = [
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc",
|
||||
"lib/Dialect/mhlo/IR/lhlo_ops.cc",
|
||||
"lib/Dialect/mhlo/IR/lhlo_ops_structs.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h",
|
||||
|
|
|
@ -25,7 +25,6 @@ function(add_mlir_hlo_dialect dialect dialect_namespace)
|
|||
endfunction()
|
||||
|
||||
add_mlir_hlo_dialect(chlo_ops chlo)
|
||||
add_mlir_hlo_dialect(lhlo_ops lmhlo)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS hlo_ops.td)
|
||||
mlir_tablegen(hlo_ops.h.inc -gen-op-decls)
|
||||
|
@ -36,6 +35,15 @@ mlir_tablegen(hlo_ops_base_enums.h.inc -gen-enum-decls)
|
|||
mlir_tablegen(hlo_ops_base_enums.cc.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(MLIRhlo_opsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS lhlo_ops.td)
|
||||
mlir_tablegen(lhlo_ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(lhlo_ops.cc.inc -gen-op-defs)
|
||||
set(LLVM_TARGET_DEFINITIONS lhlo_ops_structs.td)
|
||||
mlir_tablegen(lhlo_ops_structs.h.inc -gen-struct-attr-decls)
|
||||
mlir_tablegen(lhlo_ops_structs.cc.inc -gen-struct-attr-defs)
|
||||
add_public_tablegen_target(MLIRlhlo_opsIncGen)
|
||||
add_dependencies(mlir-headers MLIRlhlo_opsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td)
|
||||
mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs)
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef LHLO_DIALECT
|
||||
#define LHLO_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
// We define the dialect here so that both structs and ops can refer to it.
|
||||
def LHLO_Dialect : Dialect {
|
||||
let name = "lmhlo";
|
||||
let cppNamespace = "::mlir::lmhlo";
|
||||
}
|
||||
|
||||
#endif // LHLO_DIALECT
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
|
|
@ -37,12 +37,9 @@ include "mlir/IR/OpBase.td"
|
|||
include "mlir/Interfaces/CopyOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
|
||||
|
||||
def LHLO_Dialect : Dialect {
|
||||
let name = "lmhlo";
|
||||
let cppNamespace = "::mlir::lmhlo";
|
||||
}
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LMHLO nullary op definitions.
|
||||
|
@ -274,8 +271,10 @@ def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]>,
|
|||
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
|
||||
StrAttr:$call_target_name,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
|
||||
DefaultValuedAttr<StrAttr, "">:$backend_config
|
||||
DefaultValuedAttr<StrAttr, "">:$backend_config,
|
||||
OptionalAttr<CustomCallTargetArgMapping>:$target_arg_mapping
|
||||
);
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines structures used in LMHLO dialect.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_STRUCTS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_STRUCTS_H_
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
// Order matters, this .inc header is not self-contained, and relies on the
|
||||
// #includes above.
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_STRUCTS_H_
|
|
@ -0,0 +1,40 @@
|
|||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef LHLO_OPS_STRUCTS
|
||||
#define LHLO_OPS_STRUCTS
|
||||
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td"
|
||||
|
||||
// This structure defines information about how arguments to the LHLO custom
|
||||
// call operation relate to the arguments of the target function. In most cases
|
||||
// the mapping will be 1:1, but in certain cases, it may not be. As an example,
|
||||
// tokens are not represented in the LHLO dialect, but the custom call target
|
||||
// might still expect to see buffer arguments corresponding to tokens, in which
|
||||
// case the mapping will not be 1:1.
|
||||
def CustomCallTargetArgMapping : StructAttr<"CustomCallTargetArgMapping",
|
||||
LHLO_Dialect, [
|
||||
// number of buffer expected by the target for arguments.
|
||||
StructFieldAttr<"num_args", I64Attr>,
|
||||
// number of buffer expected by the target for results.
|
||||
StructFieldAttr<"num_results", I64Attr>,
|
||||
// map each custom call op arg to its position in target args.
|
||||
StructFieldAttr<"args_to_target_args", I64ArrayAttr>,
|
||||
// map each custom call op arg to its position in target results.
|
||||
StructFieldAttr<"results_to_target_results", I64ArrayAttr>]> {
|
||||
let summary = "Custom call operands to target argument mapping info";
|
||||
}
|
||||
|
||||
#endif // LHLO_OPS_STRUCTS
|
|
@ -62,6 +62,7 @@ target_link_libraries(MhloDialect
|
|||
|
||||
add_mlir_dialect_library(LmhloDialect
|
||||
lhlo_ops.cc
|
||||
lhlo_ops_structs.cc
|
||||
|
||||
DEPENDS
|
||||
MLIRlhlo_opsIncGen
|
||||
|
|
|
@ -21,6 +21,8 @@ limitations under the License.
|
|||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
@ -162,6 +164,56 @@ void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
|||
results.insert<EraseConstOp>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CustomCallOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(CustomCallOp op) {
|
||||
if (op.target_arg_mapping()) {
|
||||
CustomCallTargetArgMapping mapping = *op.target_arg_mapping();
|
||||
auto verify_mapping = [&](int64_t target_num, size_t op_num,
|
||||
ArrayAttr mapping,
|
||||
StringRef kind) -> LogicalResult {
|
||||
if (target_num < op_num)
|
||||
return op.emitOpError("number of target " + kind + " (")
|
||||
<< target_num << ") cannot be less than the number of " << kind
|
||||
<< "(" << op_num << ") for the operation";
|
||||
|
||||
if (mapping.size() != op_num)
|
||||
return op.emitOpError("number of entries in the mapping for " + kind +
|
||||
" (")
|
||||
<< mapping.size() << ") should match the number of " << kind
|
||||
<< " for the operation (" << op_num << ")";
|
||||
|
||||
std::unordered_set<int64_t> entries;
|
||||
// Each entry in the mapping should be < target_num and an entry cannot
|
||||
// appear more than once.
|
||||
for (Attribute entry : mapping) {
|
||||
int64_t int_entry = entry.cast<IntegerAttr>().getInt();
|
||||
// ODS verification will ensure that these entries are integers.
|
||||
if (!entries.insert(int_entry).second)
|
||||
return op.emitOpError("entry ")
|
||||
<< int_entry
|
||||
<< " cannot appear more than once in the mapping for " << kind;
|
||||
if (int_entry < 0 || int_entry >= target_num)
|
||||
return op.emitOpError(
|
||||
"entries in mapping for " + kind +
|
||||
" must be >= 0 and less than target's number of " + kind +
|
||||
" (")
|
||||
<< target_num << ")";
|
||||
}
|
||||
return success();
|
||||
};
|
||||
if (failed(verify_mapping(mapping.num_args().getInt(), op.args().size(),
|
||||
mapping.args_to_target_args(), "args")) ||
|
||||
failed(verify_mapping(mapping.num_results().getInt(),
|
||||
op.output().size(),
|
||||
mapping.results_to_target_results(), "results")))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -13,5 +13,6 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h"
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"
|
||||
|
|
|
@ -989,3 +989,136 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
|||
}) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @valid_custom_call
|
||||
func @valid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
|
||||
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
|
||||
backend_config = "",
|
||||
call_target_name = "foo",
|
||||
has_side_effects = false,
|
||||
operand_segment_sizes = dense<2> : vector<2xi32>,
|
||||
target_arg_mapping = {
|
||||
num_args = 4 : i64,
|
||||
num_results = 3 : i64,
|
||||
args_to_target_args = [0,3],
|
||||
results_to_target_results = [1,2]
|
||||
}
|
||||
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{number of entries in the mapping for args (1) should match the number of args for the operation (2)}}
|
||||
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
|
||||
backend_config = "",
|
||||
call_target_name = "foo",
|
||||
has_side_effects = false,
|
||||
operand_segment_sizes = dense<2> : vector<2xi32>,
|
||||
target_arg_mapping = {
|
||||
num_args = 4 : i64,
|
||||
num_results = 3 : i64,
|
||||
args_to_target_args = [0],
|
||||
results_to_target_results = [1,2]
|
||||
}
|
||||
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{number of entries in the mapping for results (1) should match the number of results for the operation (2)}}
|
||||
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
|
||||
backend_config = "",
|
||||
call_target_name = "foo",
|
||||
has_side_effects = false,
|
||||
operand_segment_sizes = dense<2> : vector<2xi32>,
|
||||
target_arg_mapping = {
|
||||
num_args = 4 : i64,
|
||||
num_results = 3 : i64,
|
||||
args_to_target_args = [0, 3],
|
||||
results_to_target_results = [1]
|
||||
}
|
||||
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{entry 0 cannot appear more than once in the mapping for args}}
|
||||
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
|
||||
backend_config = "",
|
||||
call_target_name = "foo",
|
||||
has_side_effects = false,
|
||||
operand_segment_sizes = dense<2> : vector<2xi32>,
|
||||
target_arg_mapping = {
|
||||
num_args = 4 : i64,
|
||||
num_results = 3 : i64,
|
||||
args_to_target_args = [0, 0],
|
||||
results_to_target_results = [1, 2]
|
||||
}
|
||||
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{entry 1 cannot appear more than once in the mapping for results}}
|
||||
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
|
||||
backend_config = "",
|
||||
call_target_name = "foo",
|
||||
has_side_effects = false,
|
||||
operand_segment_sizes = dense<2> : vector<2xi32>,
|
||||
target_arg_mapping = {
|
||||
num_args = 4 : i64,
|
||||
num_results = 3 : i64,
|
||||
args_to_target_args = [0, 1],
|
||||
results_to_target_results = [1, 1]
|
||||
}
|
||||
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{entries in mapping for args must be >= 0 and less than target's number of args (4)}}
|
||||
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
|
||||
backend_config = "",
|
||||
call_target_name = "foo",
|
||||
has_side_effects = false,
|
||||
operand_segment_sizes = dense<2> : vector<2xi32>,
|
||||
target_arg_mapping = {
|
||||
num_args = 4 : i64,
|
||||
num_results = 3 : i64,
|
||||
args_to_target_args = [0, 6],
|
||||
results_to_target_results = [1, 2]
|
||||
}
|
||||
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{entries in mapping for results must be >= 0 and less than target's number of results (3)}}
|
||||
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
|
||||
backend_config = "",
|
||||
call_target_name = "foo",
|
||||
has_side_effects = false,
|
||||
operand_segment_sizes = dense<2> : vector<2xi32>,
|
||||
target_arg_mapping = {
|
||||
num_args = 4 : i64,
|
||||
num_results = 3 : i64,
|
||||
args_to_target_args = [0, 1],
|
||||
results_to_target_results = [1, 3]
|
||||
}
|
||||
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue