[MLIR:HLO] Extend CustomCall to support multiple outputs.
- Extend MHLO CustomCall to have multiple tensors as results.
- Extend LHLO CustomCall to have multiple memrefs for output operands.
- Fix HLO->LHLO and XLA HLO->LHLO mapping for CustomCall to set up the
  operand_segment_sizes attribute correctly.

PiperOrigin-RevId: 342067762
parent af1914a174
commit 1958f228ec
@@ -925,7 +925,7 @@ def HLO_CustomCallOp: HLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
     DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
     DefaultValuedAttr<StrAttr, "">:$backend_config
   );
-  let results = (outs HLO_Tensor);
+  let results = (outs Variadic<HLO_Tensor>);
   let hasCustomHLOConverter = 1;
 }
 
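With `results` now variadic, a multi-result custom call can be built through the TableGen-generated builder. A minimal sketch, assuming `builder`, `loc`, and `inputs` are in scope and that the default generated overload (result types, operands, then the three attributes) is used; the exact overload may differ between MLIR versions:

    // Sketch only: construct a two-result mhlo.custom_call.
    auto f16Ty = RankedTensorType::get({4, 4}, builder.getF16Type());
    auto call = builder.create<mhlo::CustomCallOp>(
        loc, TypeRange{f16Ty, f16Ty}, /*args=*/inputs,
        /*call_target_name=*/builder.getStringAttr("foo"),
        /*has_side_effect=*/builder.getBoolAttr(false),
        /*backend_config=*/builder.getStringAttr(""));
    Value out0 = call.getResult(0);  // first output tensor
    Value out1 = call.getResult(1);  // second output tensor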
@@ -264,10 +264,11 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
   let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
 }
 
-def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
+def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]>,
+                        BASE_HLO_CustomCallOp {
   let arguments = (ins
     Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
     StrAttr:$call_target_name,
     DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
     DefaultValuedAttr<StrAttr, "">:$backend_config
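With two variadic operand groups, the AttrSizedOperandSegments trait obliges every lmhlo.custom_call to carry an operand_segment_sizes attribute that splits the flat operand list into `args` and `output`. A sketch of how the TableGen-generated accessors consume that split; the helper function is an assumption for illustration, not part of the patch:

    // Sketch only: for an op carrying operand_segment_sizes = dense<[2, 1]>,
    // args() yields operands [0, 2) and output() yields operand [2, 3).
    void inspectSegments(lmhlo::CustomCallOp callOp) {
      auto inputs = callOp.args();     // first segment: the MemRead buffers
      auto outputs = callOp.output();  // second segment: the MemWrite buffers
      assert(inputs.size() + outputs.size() == callOp.getNumOperands());
    }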
@@ -165,6 +165,32 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
   }
 };
 
+struct HloToLhloCustomCallOpConverter
+    : public BaseOpConversion<mhlo::CustomCallOp> {
+ public:
+  using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;
+
+  LogicalResult matchAndRewrite(
+      mhlo::CustomCallOp hloOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    Operation* op = hloOp.getOperation();
+    SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
+    if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
+
+    auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
+        op->getLoc(), llvm::None, buffer_args, op->getAttrs());
+    // Setup AttrSizedOperandSegments attribute to indicate number of operands
+    // for args and outputs.
+    const int32_t segments[2] = {static_cast<int32_t>(operands.size()),
+                                 static_cast<int32_t>(op->getNumResults())};
+    lhloOp.setAttr(lhloOp.getOperandSegmentSizeAttr(),
+                   rewriter.getI32VectorAttr(segments));
+
+    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
+    return success();
+  }
+};
+
 struct HloToLhloDynamicBroadcastInDimOpConverter
     : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
  public:
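Spelled out, the invariant the converter relies on (an inference from the code above, not text from the patch): ConvertResults is assumed to allocate one buffer per HLO result and append it to buffer_args, so the vector holds the inputs followed by the outputs, and the final replaceOp hands exactly that tail back as the op's replacement values:

    // Schematic layout of buffer_args after ConvertResults, for n inputs and
    // m results. The multi-output test below hits n = m = 2, i.e.
    // dense<[2, 2]>, which MLIR prints in splat form as dense<2>.
    //   buffer_args = [in_0, ..., in_{n-1}, out_0, ..., out_{m-1}]
    //   segments    = {n, m}
    //   replaceOp(op, buffer_args[n .. n+m))  // results -> output buffers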
@@ -572,6 +598,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
                                         OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<
+      HloToLhloCustomCallOpConverter,
       HloToLhloDotGeneralOpConverter,
       HloToLhloDynamicBroadcastInDimOpConverter,
       HloToLhloDynamicReshapeConverter,
@@ -588,7 +615,6 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
       HloToLhloOpConverter<mhlo::ConvertOp>,
       HloToLhloOpConverter<mhlo::CopyOp>,
       HloToLhloOpConverter<mhlo::CosOp>,
-      HloToLhloOpConverter<mhlo::CustomCallOp>,
       HloToLhloOpConverter<mhlo::DivOp>,
       HloToLhloOpConverter<mhlo::DotOp>,
       HloToLhloOpConverter<mhlo::ExpOp>,
@@ -588,7 +588,7 @@ func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
   %arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
   %arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
-  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false}
+  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<[2, 1]> : vector<2xi32>}
   %result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
                    {backend_config = "", call_target_name = "foo", has_side_effect = false}
                    : (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>
@@ -598,6 +598,22 @@ func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
 
 // -----
 
+// CHECK-LABEL: func @custom_call_multiout
+// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
+func @custom_call_multiout(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
+  %arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
+  %arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
+  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}, %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<2> : vector<2xi32>}
+  %temp:2 = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
+            {backend_config = "", call_target_name = "foo", has_side_effect = false}
+            : (tensor<2x2xf32>, tensor<2x3xf32>) -> (tensor<4x4xf16>, tensor<4x4xf16>)
+  %result_tensor = "mhlo.add"(%temp#0, %temp#1) : (tensor<4x4xf16>, tensor<4x4xf16>) -> tensor<4x4xf16>
+  tensor_store %result_tensor, %result: memref<4x4xf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @isfinite
 func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
   %arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
@@ -1281,3 +1281,12 @@ func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
   %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 3 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
   return %result : tensor<1x128x512xf32>
 }
+
+// -----
+
+// CHECK: func @custom_call_multiple_outputs
+func @custom_call_multiple_outputs(%x: tensor<2xf32>) -> tensor<2xf32> {
+  %0:2 = "mhlo.custom_call"(%x) {backend_config="", call_target_name = "foo", has_side_effect = false} : (tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
+  %1 = "mhlo.add"(%0#0, %0#1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  return %1 : tensor<2xf32>
+}