[MLIR:HLO] Extend CustomCall to support multiple outputs.

- Extend MHLO CustomCall to have multiple tensors as results.
- Extend LHLO CustomCall to have multiple memrefs for output operands.
- Fix the HLO->LHLO and XLA HLO->LHLO mappings for CustomCall to set up the
  operand_segment_sizes attribute correctly.
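
As an illustration (a sketch mirroring the verifier test added below), an
MHLO custom call can now produce several tensor results:

    %0:2 = "mhlo.custom_call"(%x)
           {backend_config = "", call_target_name = "foo", has_side_effect = false}
           : (tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>)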

PiperOrigin-RevId: 342067762
Rahul Joshi, 2020-11-12 09:45:39 -08:00 (committed by TensorFlow MLIR Team)
parent af1914a174
commit 1958f228ec
5 changed files with 57 additions and 5 deletions


@@ -925,7 +925,7 @@ def HLO_CustomCallOp: HLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
     DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
     DefaultValuedAttr<StrAttr, "">:$backend_config
   );
-  let results = (outs HLO_Tensor);
+  let results = (outs Variadic<HLO_Tensor>);
   let hasCustomHLOConverter = 1;
 }


@@ -264,10 +264,11 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
   let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
 }
 
-def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
+def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]>,
+                        BASE_HLO_CustomCallOp {
   let arguments = (ins
     Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
     StrAttr:$call_target_name,
     DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
     DefaultValuedAttr<StrAttr, "">:$backend_config
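
Since AttrSizedOperandSegments ops must carry an operand_segment_sizes
attribute, a lowered call with two inputs and two outputs is partitioned
roughly like this (a sketch mirroring the lowering test below; buffer names
are illustrative):

    "lmhlo.custom_call"(%arg0, %arg1, %out0, %out1)
        {backend_config = "", call_target_name = "foo", has_side_effect = false,
         operand_segment_sizes = dense<2> : vector<2xi32>}
        : (memref<2x2xf32>, memref<2x3xf32>, memref<4x4xf16>, memref<4x4xf16>) -> ()

Here dense<2> : vector<2xi32> is the splat [2, 2]: the first two memrefs are
the read-only $args and the last two are the written $output buffers.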


@@ -165,6 +165,32 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
   }
 };
 
+struct HloToLhloCustomCallOpConverter
+    : public BaseOpConversion<mhlo::CustomCallOp> {
+ public:
+  using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;
+
+  LogicalResult matchAndRewrite(
+      mhlo::CustomCallOp hloOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    Operation* op = hloOp.getOperation();
+    SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
+    if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
+
+    auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
+        op->getLoc(), llvm::None, buffer_args, op->getAttrs());
+
+    // Setup AttrSizedOperandSegments attribute to indicate number of operands
+    // for args and outputs.
+    const int32_t segments[2] = {static_cast<int32_t>(operands.size()),
+                                 static_cast<int32_t>(op->getNumResults())};
+    lhloOp.setAttr(lhloOp.getOperandSegmentSizeAttr(),
+                   rewriter.getI32VectorAttr(segments));
+
+    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
+    return success();
+  }
+};
+
 struct HloToLhloDynamicBroadcastInDimOpConverter
     : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
  public:
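
For context: ConvertResults allocates one buffer per MHLO result and appends
it to buffer_args, so the lmhlo.custom_call operand list is laid out as
[args..., outputs...], and the final replaceOp maps each original tensor
result to its trailing output buffer. A sketch of the IR this converter
produces for a one-input, two-output call (assuming statically shaped
results, so a plain alloc suffices; names are illustrative):

    %out0 = alloc() : memref<2xf32>
    %out1 = alloc() : memref<2xf32>
    "lmhlo.custom_call"(%in, %out0, %out1)
        {backend_config = "", call_target_name = "foo", has_side_effect = false,
         operand_segment_sizes = dense<[1, 2]> : vector<2xi32>}
        : (memref<2xf32>, memref<2xf32>, memref<2xf32>) -> ()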
@@ -572,6 +598,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
                                         OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<
+      HloToLhloCustomCallOpConverter,
       HloToLhloDotGeneralOpConverter,
       HloToLhloDynamicBroadcastInDimOpConverter,
       HloToLhloDynamicReshapeConverter,
@@ -588,7 +615,6 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
       HloToLhloOpConverter<mhlo::ConvertOp>,
       HloToLhloOpConverter<mhlo::CopyOp>,
       HloToLhloOpConverter<mhlo::CosOp>,
-      HloToLhloOpConverter<mhlo::CustomCallOp>,
       HloToLhloOpConverter<mhlo::DivOp>,
       HloToLhloOpConverter<mhlo::DotOp>,
       HloToLhloOpConverter<mhlo::ExpOp>,


@@ -588,7 +588,7 @@ func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
   %arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
   %arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
-  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false}
+  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<[2, 1]> : vector<2xi32>}
   %result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
                    {backend_config = "", call_target_name = "foo", has_side_effect = false}
                    : (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>
@@ -598,6 +598,22 @@ func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
 
 // -----
 
+// CHECK-LABEL: func @custom_call_multiout
+// CHECK-SAME: ([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
+func @custom_call_multiout(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
+  %arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
+  %arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
+  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}, %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<2> : vector<2xi32>}
+  %temp:2 = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
+            {backend_config = "", call_target_name = "foo", has_side_effect = false}
+            : (tensor<2x2xf32>, tensor<2x3xf32>) -> (tensor<4x4xf16>, tensor<4x4xf16>)
+  %result_tensor = "mhlo.add"(%temp#0, %temp#1) : (tensor<4x4xf16>, tensor<4x4xf16>) -> tensor<4x4xf16>
+  tensor_store %result_tensor, %result : memref<4x4xf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @isfinite
 func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
   %arg0_tensor = tensor_load %arg0 : memref<2x2xf32>


@@ -1281,3 +1281,12 @@ func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
   %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 3 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
   return %result : tensor<1x128x512xf32>
 }
+
+// -----
+
+// CHECK: func @custom_call_multiple_outputs
+func @custom_call_multiple_outputs(%x: tensor<2xf32>) -> tensor<2xf32> {
+  %0:2 = "mhlo.custom_call"(%x) {backend_config = "", call_target_name = "foo", has_side_effect = false} : (tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
+  %1 = "mhlo.add"(%0#0, %0#1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  return %1 : tensor<2xf32>
+}