Introduce CustomCall operation in LHLO Dialect
- And add conversion from MHLO CustomCall to LHLO CustomCall - According to XLA documentation, the called function should not be side effecting, so marking the argument MemRefs as MemRead. PiperOrigin-RevId: 334737196
This commit is contained in:
parent
e2ffba3f61
commit
bce128b070
|
@ -289,6 +289,16 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
|
||||||
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
StrAttr:$call_target_name,
|
||||||
|
DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
|
||||||
|
DefaultValuedAttr<StrAttr, "">:$backend_config
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// LMHLO tuple op definitions.
|
// LMHLO tuple op definitions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -50,6 +50,7 @@ MAP_HLO_TO_LHLO(ConvOp);
|
||||||
MAP_HLO_TO_LHLO(ConvertOp);
|
MAP_HLO_TO_LHLO(ConvertOp);
|
||||||
MAP_HLO_TO_LHLO(CopyOp);
|
MAP_HLO_TO_LHLO(CopyOp);
|
||||||
MAP_HLO_TO_LHLO(CosOp);
|
MAP_HLO_TO_LHLO(CosOp);
|
||||||
|
MAP_HLO_TO_LHLO(CustomCallOp);
|
||||||
MAP_HLO_TO_LHLO(DivOp);
|
MAP_HLO_TO_LHLO(DivOp);
|
||||||
MAP_HLO_TO_LHLO(DotOp);
|
MAP_HLO_TO_LHLO(DotOp);
|
||||||
MAP_HLO_TO_LHLO(ExpOp);
|
MAP_HLO_TO_LHLO(ExpOp);
|
||||||
|
|
|
@ -498,6 +498,7 @@ void populateHLOToLHLOConversionPattern(
|
||||||
HloToLhloOpConverter<mhlo::ConvertOp>,
|
HloToLhloOpConverter<mhlo::ConvertOp>,
|
||||||
HloToLhloOpConverter<mhlo::CopyOp>,
|
HloToLhloOpConverter<mhlo::CopyOp>,
|
||||||
HloToLhloOpConverter<mhlo::CosOp>,
|
HloToLhloOpConverter<mhlo::CosOp>,
|
||||||
|
HloToLhloOpConverter<mhlo::CustomCallOp>,
|
||||||
HloToLhloOpConverter<mhlo::DivOp>,
|
HloToLhloOpConverter<mhlo::DivOp>,
|
||||||
HloToLhloOpConverter<mhlo::DotOp>,
|
HloToLhloOpConverter<mhlo::DotOp>,
|
||||||
HloToLhloOpConverter<mhlo::ExpOp>,
|
HloToLhloOpConverter<mhlo::ExpOp>,
|
||||||
|
|
|
@ -586,3 +586,18 @@ func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// BOTH-LABEL: func @custom_call
|
||||||
|
// BOTH-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>)
|
||||||
|
func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) {
|
||||||
|
%arg0_tensor = tensor_load %arg0 : memref<2x2xf32>
|
||||||
|
%arg1_tensor = tensor_load %arg1 : memref<2x3xf32>
|
||||||
|
// BOTH: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false}
|
||||||
|
%result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor)
|
||||||
|
{backend_config = "", call_target_name = "foo", has_side_effect = false}
|
||||||
|
: (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>
|
||||||
|
tensor_store %result_tensor, %result: memref<4x4xf16>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue