[MLIR:HLO] Extend CustomCall to support multiple outputs.
- Extend MHLO CustomCall to have multiple tensors as results.
- Extend LHLO CustomCall to have multiple memrefs for output operands.
- Fix the HLO->LHLO and XLA HLO->LHLO mappings for CustomCall to set up the
  operand_segment_sizes attribute correctly.

PiperOrigin-RevId: 342067762
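
At the MHLO level this means a custom call can now return more than one tensor. A minimal sketch of the new form (the call target "my_kernel", the operand names, and the shapes here are hypothetical, for illustration only):

    // One custom call producing two results; %out#0 and %out#1 are
    // independent SSA values downstream.
    %out:2 = "mhlo.custom_call"(%in0, %in1)
             {backend_config = "", call_target_name = "my_kernel", has_side_effect = false}
             : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
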
@@ -925,7 +925,7 @@ def HLO_CustomCallOp: HLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
     DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
     DefaultValuedAttr<StrAttr, "">:$backend_config
   );
-  let results = (outs HLO_Tensor);
+  let results = (outs Variadic<HLO_Tensor>);
   let hasCustomHLOConverter = 1;
 }

@@ -264,10 +264,11 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
   let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
 }
 
-def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
+def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]>,
+                        BASE_HLO_CustomCallOp {
   let arguments = (ins
     Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
     StrAttr:$call_target_name,
     DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
     DefaultValuedAttr<StrAttr, "">:$backend_config

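With both $args and $output variadic, the op carries a single flat operand list, so the AttrSizedOperandSegments trait (backed by the operand_segment_sizes attribute) is needed to recover the args/output split. A sketch of the resulting LMHLO form, assuming two inputs and two outputs (names and shapes are made up):

    // dense<[2, 2]> partitions the four operands: the first two are $args
    // (read), the last two are $output (written).
    "lmhlo.custom_call"(%in0, %in1, %out0, %out1)
        {backend_config = "", call_target_name = "my_kernel", has_side_effect = false,
         operand_segment_sizes = dense<[2, 2]> : vector<2xi32>}
        : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
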
@@ -165,6 +165,32 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
   }
 };
 
+struct HloToLhloCustomCallOpConverter
+    : public BaseOpConversion<mhlo::CustomCallOp> {
+ public:
+  using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;
+
+  LogicalResult matchAndRewrite(
+      mhlo::CustomCallOp hloOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    Operation* op = hloOp.getOperation();
+    SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
+    if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
+
+    auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
+        op->getLoc(), llvm::None, buffer_args, op->getAttrs());
+    // Setup AttrSizedOperandSegments attribute to indicate number of operands
+    // for args and outputs.
+    const int32_t segments[2] = {static_cast<int32_t>(operands.size()),
+                                 static_cast<int32_t>(op->getNumResults())};
+    lhloOp.setAttr(lhloOp.getOperandSegmentSizeAttr(),
+                   rewriter.getI32VectorAttr(segments));
+
+    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
+    return success();
+  }
+};
+
 struct HloToLhloDynamicBroadcastInDimOpConverter
     : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
  public:

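The pattern above works because ConvertResults appends one result buffer per MHLO result to buffer_args, so the LMHLO op sees the original operands first and the result buffers last; replaceOp then forwards exactly that trailing slice to the old op's users. A before/after sketch under those assumptions (value names and shapes are illustrative):

    // Before bufferization: tensor semantics, two results.
    %r:2 = "mhlo.custom_call"(%t)
           {backend_config = "", call_target_name = "foo", has_side_effect = false}
           : (tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>)

    // After: %buf0/%buf1 are buffers allocated for the results and replace
    // all uses of %r#0/%r#1; dense<[1, 2]> records one arg, two outputs.
    "lmhlo.custom_call"(%t_buf, %buf0, %buf1)
        {backend_config = "", call_target_name = "foo", has_side_effect = false,
         operand_segment_sizes = dense<[1, 2]> : vector<2xi32>}
        : (memref<2xf32>, memref<2xf32>, memref<2xf32>) -> ()
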
@@ -572,6 +598,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
                                         OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<
+      HloToLhloCustomCallOpConverter,
       HloToLhloDotGeneralOpConverter,
       HloToLhloDynamicBroadcastInDimOpConverter,
       HloToLhloDynamicReshapeConverter,

@@ -588,7 +615,6 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
       HloToLhloOpConverter<mhlo::ConvertOp>,
       HloToLhloOpConverter<mhlo::CopyOp>,
       HloToLhloOpConverter<mhlo::CosOp>,
-      HloToLhloOpConverter<mhlo::CustomCallOp>,
       HloToLhloOpConverter<mhlo::DivOp>,
       HloToLhloOpConverter<mhlo::DotOp>,
       HloToLhloOpConverter<mhlo::ExpOp>,

@@ -588,7 +588,7 @@ func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
 func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
   %arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
   %arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
-  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false}
+  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<[2, 1]> : vector<2xi32>}
   %result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
                    {backend_config = "", call_target_name = "foo", has_side_effect = false}
                    : (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>

@@ -598,6 +598,22 @@ func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
 
 // -----
 
+// CHECK-LABEL: func @custom_call_multiout
+// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
+func @custom_call_multiout(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
+  %arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
+  %arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
+  // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}, %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<2> : vector<2xi32>}
+  %temp:2 = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
+            {backend_config = "", call_target_name = "foo", has_side_effect = false}
+            : (tensor<2x2xf32>, tensor<2x3xf32>) -> (tensor<4x4xf16>, tensor<4x4xf16>)
+  %result_tensor = "mhlo.add"(%temp#0, %temp#1) : (tensor<4x4xf16>, tensor<4x4xf16>) -> tensor<4x4xf16>
+  tensor_store %result_tensor, %result: memref<4x4xf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @isfinite
 func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) {
   %arg0_tensor = tensor_load %arg0 : memref<2x2xf32>

@@ -1281,3 +1281,12 @@ func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> {
   %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 3 : i64} : (tensor<1x128x512xf32>, tensor<i32>) -> tensor<1x128x512xf32>
   return %result : tensor<1x128x512xf32>
 }
+
+// -----
+
+// CHECK: func @custom_call_multiple_outputs
+func @custom_call_multiple_outputs(%x: tensor<2xf32>) -> tensor<2xf32> {
+  %0:2 = "mhlo.custom_call"(%x) {backend_config="", call_target_name = "foo", has_side_effect = false} : (tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
+  %1 = "mhlo.add"(%0#0, %0#1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  return %1 : tensor<2xf32>
+}