[MLIR:LHLO] Add optional call target arg mapping to LMHLO CustomCall operations.

- XLA:HLO -> LMHLO conversion drops all token arguments and return values, however
  custom calls that users write still expect to get buffer pointers for these token types.
- To be able to support this, add an optional call target argument mapping attribute to
  LMHLO custom calls. When this attribute is present, it indicates the number of
  arguments and returns that the custom call expects and also indicates which LMHLO
  arg() or output() maps to which arg or result number of the custom call.

PiperOrigin-RevId: 358826664
This commit is contained in:
Rahul Joshi 2021-02-22 08:41:59 -08:00 committed by TensorFlow MLIR Team
parent a9cc1dcfa0
commit 5adb7c6e12
11 changed files with 332 additions and 12 deletions

34
BUILD
View File

@ -201,14 +201,34 @@ gentbl(
],
)
gentbl(
name = "lhlo_ops_structs_inc_gen",
strip_include_prefix = "include",
tbl_outs = [
("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"),
("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td",
td_includes = [
"external/mlir-hlo/include",
"include",
],
td_relative_includes = [
"include",
],
td_srcs = [
":hlo_ops_td_files",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td",
],
)
gentbl(
name = "lhlo_ops_inc_gen",
strip_include_prefix = "include",
tbl_outs = [
("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"),
("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"),
("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"),
("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
@ -219,7 +239,11 @@ gentbl(
td_relative_includes = [
"include",
],
td_srcs = [":hlo_ops_td_files"],
td_srcs = [
":hlo_ops_td_files",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td",
],
)
gentbl(
@ -476,7 +500,11 @@ cc_library(
srcs = [
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc",
"lib/Dialect/mhlo/IR/lhlo_ops.cc",
"lib/Dialect/mhlo/IR/lhlo_ops_structs.cc",
],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h",

View File

@ -25,7 +25,6 @@ function(add_mlir_hlo_dialect dialect dialect_namespace)
endfunction()
add_mlir_hlo_dialect(chlo_ops chlo)
add_mlir_hlo_dialect(lhlo_ops lmhlo)
set(LLVM_TARGET_DEFINITIONS hlo_ops.td)
mlir_tablegen(hlo_ops.h.inc -gen-op-decls)
@ -36,6 +35,15 @@ mlir_tablegen(hlo_ops_base_enums.h.inc -gen-enum-decls)
mlir_tablegen(hlo_ops_base_enums.cc.inc -gen-enum-defs)
add_public_tablegen_target(MLIRhlo_opsIncGen)
set(LLVM_TARGET_DEFINITIONS lhlo_ops.td)
mlir_tablegen(lhlo_ops.h.inc -gen-op-decls)
mlir_tablegen(lhlo_ops.cc.inc -gen-op-defs)
set(LLVM_TARGET_DEFINITIONS lhlo_ops_structs.td)
mlir_tablegen(lhlo_ops_structs.h.inc -gen-struct-attr-decls)
mlir_tablegen(lhlo_ops_structs.cc.inc -gen-struct-attr-defs)
add_public_tablegen_target(MLIRlhlo_opsIncGen)
add_dependencies(mlir-headers MLIRlhlo_opsIncGen)
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td)
mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls)
mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs)

View File

@ -0,0 +1,27 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef LHLO_DIALECT
#define LHLO_DIALECT
include "mlir/IR/OpBase.td"
// We define the dialect here so that both structs and ops can refer to it.
def LHLO_Dialect : Dialect {
let name = "lmhlo";
let cppNamespace = "::mlir::lmhlo";
}
#endif // LHLO_DIALECT

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"

View File

@ -37,12 +37,9 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
def LHLO_Dialect : Dialect {
let name = "lmhlo";
let cppNamespace = "::mlir::lmhlo";
}
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td"
//===----------------------------------------------------------------------===//
// LMHLO nullary op definitions.
@ -274,8 +271,10 @@ def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]>,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
StrAttr:$call_target_name,
DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
DefaultValuedAttr<StrAttr, "">:$backend_config
DefaultValuedAttr<StrAttr, "">:$backend_config,
OptionalAttr<CustomCallTargetArgMapping>:$target_arg_mapping
);
let verifier = [{ return Verify(*this); }];
}
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,30 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines structures used in LMHLO dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_STRUCTS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_STRUCTS_H_
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Types.h"
// Order matters, this .inc header is not self-contained, and relies on the
// #includes above.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_STRUCTS_H_

View File

@ -0,0 +1,40 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef LHLO_OPS_STRUCTS
#define LHLO_OPS_STRUCTS
include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td"
// This structure defines information about how arguments to the LHLO custom
// call operation relate to the arguments of the target function. In most cases
// the mapping will be 1:1, but in certain cases, it may not be. As an example,
// tokens are not represented in the LHLO dialect, but the custom call target
// might still expect to see buffer arguments corresponding to tokens, in which
// case the mapping will not be 1:1.
def CustomCallTargetArgMapping : StructAttr<"CustomCallTargetArgMapping",
LHLO_Dialect, [
// number of buffer expected by the target for arguments.
StructFieldAttr<"num_args", I64Attr>,
// number of buffer expected by the target for results.
StructFieldAttr<"num_results", I64Attr>,
// map each custom call op arg to its position in target args.
StructFieldAttr<"args_to_target_args", I64ArrayAttr>,
// map each custom call op arg to its position in target results.
StructFieldAttr<"results_to_target_results", I64ArrayAttr>]> {
let summary = "Custom call operands to target argument mapping info";
}
#endif // LHLO_OPS_STRUCTS

View File

@ -62,6 +62,7 @@ target_link_libraries(MhloDialect
add_mlir_dialect_library(LmhloDialect
lhlo_ops.cc
lhlo_ops_structs.cc
DEPENDS
MLIRlhlo_opsIncGen

View File

@ -21,6 +21,8 @@ limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include <unordered_set>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
@ -162,6 +164,56 @@ void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
results.insert<EraseConstOp>(context);
}
//===----------------------------------------------------------------------===//
// CustomCallOp.
//===----------------------------------------------------------------------===//
static LogicalResult Verify(CustomCallOp op) {
if (op.target_arg_mapping()) {
CustomCallTargetArgMapping mapping = *op.target_arg_mapping();
auto verify_mapping = [&](int64_t target_num, size_t op_num,
ArrayAttr mapping,
StringRef kind) -> LogicalResult {
if (target_num < op_num)
return op.emitOpError("number of target " + kind + " (")
<< target_num << ") cannot be less than the number of " << kind
<< "(" << op_num << ") for the operation";
if (mapping.size() != op_num)
return op.emitOpError("number of entries in the mapping for " + kind +
" (")
<< mapping.size() << ") should match the number of " << kind
<< " for the operation (" << op_num << ")";
std::unordered_set<int64_t> entries;
// Each entry in the mapping should be < target_num and an entry cannot
// appear more than once.
for (Attribute entry : mapping) {
int64_t int_entry = entry.cast<IntegerAttr>().getInt();
// ODS verification will ensure that these entries are integers.
if (!entries.insert(int_entry).second)
return op.emitOpError("entry ")
<< int_entry
<< " cannot appear more than once in the mapping for " << kind;
if (int_entry < 0 || int_entry >= target_num)
return op.emitOpError(
"entries in mapping for " + kind +
" must be >= 0 and less than target's number of " + kind +
" (")
<< target_num << ")";
}
return success();
};
if (failed(verify_mapping(mapping.num_args().getInt(), op.args().size(),
mapping.args_to_target_args(), "args")) ||
failed(verify_mapping(mapping.num_results().getInt(),
op.output().size(),
mapping.results_to_target_results(), "results")))
return failure();
}
return success();
}
} // namespace lmhlo
} // namespace mlir

View File

@ -13,5 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"

View File

@ -989,3 +989,136 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
}) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}
// -----
// CHECK-LABEL: func @valid_custom_call
func @valid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
backend_config = "",
call_target_name = "foo",
has_side_effects = false,
operand_segment_sizes = dense<2> : vector<2xi32>,
target_arg_mapping = {
num_args = 4 : i64,
num_results = 3 : i64,
args_to_target_args = [0,3],
results_to_target_results = [1,2]
}
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
// expected-error @+1 {{number of entries in the mapping for args (1) should match the number of args for the operation (2)}}
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
backend_config = "",
call_target_name = "foo",
has_side_effects = false,
operand_segment_sizes = dense<2> : vector<2xi32>,
target_arg_mapping = {
num_args = 4 : i64,
num_results = 3 : i64,
args_to_target_args = [0],
results_to_target_results = [1,2]
}
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
// expected-error @+1 {{number of entries in the mapping for results (1) should match the number of results for the operation (2)}}
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
backend_config = "",
call_target_name = "foo",
has_side_effects = false,
operand_segment_sizes = dense<2> : vector<2xi32>,
target_arg_mapping = {
num_args = 4 : i64,
num_results = 3 : i64,
args_to_target_args = [0, 3],
results_to_target_results = [1]
}
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
// expected-error @+1 {{entry 0 cannot appear more than once in the mapping for args}}
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
backend_config = "",
call_target_name = "foo",
has_side_effects = false,
operand_segment_sizes = dense<2> : vector<2xi32>,
target_arg_mapping = {
num_args = 4 : i64,
num_results = 3 : i64,
args_to_target_args = [0, 0],
results_to_target_results = [1, 2]
}
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
// expected-error @+1 {{entry 1 cannot appear more than once in the mapping for results}}
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
backend_config = "",
call_target_name = "foo",
has_side_effects = false,
operand_segment_sizes = dense<2> : vector<2xi32>,
target_arg_mapping = {
num_args = 4 : i64,
num_results = 3 : i64,
args_to_target_args = [0, 1],
results_to_target_results = [1, 1]
}
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
// expected-error @+1 {{entries in mapping for args must be >= 0 and less than target's number of args (4)}}
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
backend_config = "",
call_target_name = "foo",
has_side_effects = false,
operand_segment_sizes = dense<2> : vector<2xi32>,
target_arg_mapping = {
num_args = 4 : i64,
num_results = 3 : i64,
args_to_target_args = [0, 6],
results_to_target_results = [1, 2]
}
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () {
// expected-error @+1 {{entries in mapping for results must be >= 0 and less than target's number of results (3)}}
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) {
backend_config = "",
call_target_name = "foo",
has_side_effects = false,
operand_segment_sizes = dense<2> : vector<2xi32>,
target_arg_mapping = {
num_args = 4 : i64,
num_results = 3 : i64,
args_to_target_args = [0, 1],
results_to_target_results = [1, 3]
}
} : (memref<1xf32>, memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}