diff --git a/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt b/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt index 79c04b3..3fa2b90 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt +++ b/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt @@ -25,8 +25,22 @@ function(add_mlir_hlo_dialect dialect dialect_namespace) endfunction() add_mlir_hlo_dialect(chlo_ops chlo) -add_mlir_hlo_dialect(hlo_ops mhlo) add_mlir_hlo_dialect(lhlo_ops lmhlo) -add_mlir_hlo_dialect(lhlo_gpu_ops lmhlo_gpu) + +set(LLVM_TARGET_DEFINITIONS hlo_ops.td) +mlir_tablegen(hlo_ops.h.inc -gen-op-decls) +mlir_tablegen(hlo_ops.cc.inc -gen-op-defs) +mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls) +mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRhlo_opsIncGen) + +set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td) +mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls) +mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs) +set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td) +mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls) +mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen) +add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen) add_mlir_interface(infer_fusibility_op_interface) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index 60ee4e6..b354189 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" @@ -32,7 +33,7 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" // clang-format off -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" // clang-format on diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index cb431bd..3defb65 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -25,11 +25,6 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" -def HLO_Dialect : Dialect { - let name = "mhlo"; - let cppNamespace = "::mlir::mhlo"; -} - class HLO_Op traits> : Op { // Whether this operation has a custom conversion to HLO or not. @@ -136,8 +131,8 @@ class HLO_UnaryElementwiseOp traits, } LogicalResult reifyReturnTypeShapes( OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { - return deriveShapeFromFirstOperand(&builder, getOperation(), - &reifiedReturnShapes); + return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); } bool inferInputOutputShapeEquality(int input, int output) { return true; @@ -153,7 +148,7 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultShape], TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value operand" + "Value operand" >]; } @@ -168,8 +163,7 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp< BASE_HLO_ConvertOp { let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value operand, " - "Type result_element_ty" + "Value operand, Type result_element_ty" >]; let hasFolder = 1; @@ -293,8 +287,8 @@ class HLO_BinaryElementwiseOp traits> : } LogicalResult reifyReturnTypeShapes( OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { - return deriveShapeFromFirstOperand(&builder, getOperation(), - &reifiedReturnShapes); + return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); } bool inferInputsShapeEquality(int lhs, int rhs) { return true; @@ -458,7 +452,7 @@ def HLO_SendOp : HLO_Op<"send", []> { let arguments = (ins HLO_TensorOrTuple:$operand, HLO_Token:$token, - ChannelHandle:$channel_id, + ChannelHandle:$channel_id, DefaultValuedAttr:$is_host_transfer ); @@ -483,7 +477,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> { let arguments = (ins HLO_Token:$token, - ChannelHandle:$channel_id, + ChannelHandle:$channel_id, DefaultValuedAttr:$is_host_transfer ); @@ -587,7 +581,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce", let arguments = (ins HLO_Tensor:$operand, I64ElementsAttr:$replica_groups, - OptionalAttr>:$channel_id + OptionalAttr:$channel_id ); let regions = (region SizedRegion<1>:$computation); let results = (outs HLO_Tensor); @@ -959,15 +953,6 @@ def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp { let results = (outs HLO_Tensor); } -def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [ - StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>, - StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>, - StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, - StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> - ]> { - let description = "Structure of dimension information for dot product"; -} - def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp { let arguments = (ins HLO_Tensor:$lhs, @@ -1029,14 +1014,6 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp { let results = (outs HLO_Tensor); } -def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, - [StructFieldAttr<"offset_dims", I64ElementsAttr>, - StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>, - StructFieldAttr<"start_index_map", I64ElementsAttr>, - StructFieldAttr<"index_vector_dim", I64Attr>]> { - let description = "Structure of dimension information for gather"; -} - def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { let arguments = (ins HLO_Tensor:$operand, @@ -1114,7 +1091,7 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>, HLO_Tensor:$operand, HLO_Tensor:$scatter_indices, HLO_Tensor:$updates, - ScatterDimensionNumbers:$scatter_dimension_numbers, + ScatterDimensionNumbers:$scatter_dimension_numbers, DefaultValuedAttr:$indices_are_sorted, DefaultValuedAttr:$unique_indices ); diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index cba2dc3..da8c921 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -18,6 +18,13 @@ limitations under the License. include "mlir/IR/OpBase.td" +def HLO_Dialect : Dialect { + let name = "mhlo"; + let cppNamespace = "::mlir::mhlo"; +} + +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td" + def HLO_Pred : TypeAlias; // TODO(hinsu): Use signed integers instead of signless integer which is being @@ -614,15 +621,6 @@ class BASE_HLO_CaseOp { // XLA parallelism related op definitions. //===----------------------------------------------------------------------===// -// Represents a unique identifier for each Send/Recv instruction pair or -// optionally for collective instructions (AllReduce, CollectivePermute, -// AllToAll). Non-positive channel_id handle is equivalent to no channel id. -class ChannelHandle : StructAttr<"ChannelHandle", dialect, [ - StructFieldAttr<"handle", I64Attr>, - StructFieldAttr<"type", I64Attr>]> { - let description = "two 64-bit integers 'handle' and 'type'"; -} - class BASE_HLO_ReplicaIdOp { string summary = "ReplicaId operator"; @@ -712,6 +710,7 @@ def HLO_PrecisionConfigAttr: OptionalAttr< TypedArrayAttrBase>; + //===----------------------------------------------------------------------===// // Fast Fourier Transform Type enum definitions. //===----------------------------------------------------------------------===// @@ -1011,21 +1010,6 @@ class BASE_HLO_ConcatenateOp { // Common convolution attributes //===----------------------------------------------------------------------===// -class ConvDimensionNumbersBase - : StructAttr<"ConvDimensionNumbers", dialect, [ - StructFieldAttr<"input_batch_dimension",I64Attr>, - StructFieldAttr<"input_feature_dimension", I64Attr>, - StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"output_batch_dimension", I64Attr>, - StructFieldAttr<"output_feature_dimension", I64Attr>, - StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { - - let description = "Structure of dimension information for conv op"; -} - class ConvolutionAttributes { dag attributes = (ins // Default value: one for each of the spatial dimension. @@ -1036,7 +1020,7 @@ class ConvolutionAttributes { OptionalAttr:$lhs_dilation, // Default value: one for each of the spatial dimension. OptionalAttr:$rhs_dilation, - ConvDimensionNumbersBase:$dimension_numbers, + ConvDimensionNumbers:$dimension_numbers, I64Attr:$feature_group_count, I64Attr:$batch_group_count, HLO_PrecisionConfigAttr:$precision_config @@ -1164,15 +1148,6 @@ class BASE_HLO_ReshapeOp { }]; } -class ScatterDimensionNumbers : StructAttr< - "ScatterDimensionNumbers", dialect, [ - StructFieldAttr<"update_window_dims", I64ElementsAttr>, - StructFieldAttr<"inserted_window_dims", I64ElementsAttr>, - StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>, - StructFieldAttr<"index_vector_dim", I64Attr>]> { - let description = "Structure of dimension information for scatter"; -} - class BASE_HLO_ScatterOp { string summary = "Scatter operator"; diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h new file mode 100644 index 0000000..3b78ff8 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines structures used in MHLO and LMHLO. + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +// Order matters, this .inc header is not self-contained, and relies on the +// #includes above. +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td new file mode 100644 index 0000000..d25eb51 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td @@ -0,0 +1,73 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef HLO_OPS_BASE_STRUCTS +#define HLO_OPS_BASE_STRUCTS + +//===----------------------------------------------------------------------===// +// Dot dimensions enum definitions. +//===----------------------------------------------------------------------===// + +def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>, + StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>, + StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, + StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> + ]> { + let description = "Structure of dimension information for dot product"; +} + +def ScatterDimensionNumbers : StructAttr< + "ScatterDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"update_window_dims", I64ElementsAttr>, + StructFieldAttr<"inserted_window_dims", I64ElementsAttr>, + StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>, + StructFieldAttr<"index_vector_dim", I64Attr>]> { + let description = "Structure of dimension information for scatter"; +} + +def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"input_batch_dimension",I64Attr>, + StructFieldAttr<"input_feature_dimension", I64Attr>, + StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"output_batch_dimension", I64Attr>, + StructFieldAttr<"output_feature_dimension", I64Attr>, + StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { + + let description = "Structure of dimension information for conv op"; +} + +def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, + [StructFieldAttr<"offset_dims", I64ElementsAttr>, + StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>, + StructFieldAttr<"start_index_map", I64ElementsAttr>, + StructFieldAttr<"index_vector_dim", I64Attr>]> { + let description = "Structure of dimension information for gather"; +} + + +// Represents a unique identifier for each Send/Recv instruction pair or +// optionally for collective instructions (AllReduce, CollectivePermute, +// AllToAll). Non-positive channel_id handle is equivalent to no channel id. +def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [ + StructFieldAttr<"handle", I64Attr>, + StructFieldAttr<"type", I64Attr>]> { + let description = "two 64-bit integers 'handle' and 'type'"; +} + +#endif // HLO_OPS_BASE_STRUCTS diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h index 68ab6ce..effa9ec 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h @@ -19,6 +19,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ #include "llvm/ADT/StringRef.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" @@ -28,6 +32,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" @@ -35,7 +40,6 @@ namespace mlir { class OpBuilder; } // namespace mlir -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc" namespace mlir { namespace lmhlo_gpu { diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td index ac5830c..b3708bf 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td @@ -22,13 +22,10 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td" -def LHLO_GPU_Dialect : Dialect { - let name = "lmhlo_gpu"; - let cppNamespace = "::mlir::lmhlo_gpu"; -} - class LHLOGPU_Op traits = []> : Op], traits)>; @@ -109,13 +106,6 @@ def ActivationAttr : StrEnumAttr<"Activation", ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, ActivationModeBandPass]>; -def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig", - LHLO_GPU_Dialect, [ - StructFieldAttr<"algorithm", I64Attr>, - StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> { - let description = "GPU Convolution backend configuration"; -} - def GpuConvolutionAttributes { dag attributes = !con( ConvolutionAttributes.attributes, @@ -181,16 +171,6 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> { // LMHLO ops representing other library functions. //===----------------------------------------------------------------------===// -// TODO(jurahul): Share this with the MHLO dialect. -def DotDimensionNumbersAttr : StructAttr<"DotDimensionNumbers", LHLO_GPU_Dialect, [ - StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>, - StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>, - StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, - StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> - ]> { - let description = "Structure of dimension information for dot product"; -} - // output = alpha * (lhs * rhs) // Verify: beta = 0.0 def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { @@ -198,7 +178,7 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { Arg:$lhs, Arg:$rhs, Arg:$output, - DotDimensionNumbersAttr:$dot_dimension_numbers, + DotDimensionNumbers:$dot_dimension_numbers, F64Attr:$alpha, I64Attr:$batch_size, I64Attr:$algorithm); @@ -211,7 +191,7 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> { Arg:$rhs, Arg:$bias, Arg:$output, - DotDimensionNumbersAttr:$dot_dimension_numbers, + DotDimensionNumbers:$dot_dimension_numbers, F64Attr:$alpha, F64Attr:$beta, I64Attr:$batch_size, diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td new file mode 100644 index 0000000..820e4ce --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// We define the dialect here so that both structs and ops can refer to it. + +#ifndef LHLO_GPU_OPS_BASE +#define LHLO_GPU_OPS_BASE + +include "mlir/IR/OpBase.td" + +def LHLO_GPU_Dialect : Dialect { + let name = "lmhlo_gpu"; + let cppNamespace = "::mlir::lmhlo_gpu"; +} + +#endif // LHLO_GPU_OPS_BASE diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h new file mode 100644 index 0000000..ff642b8 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ==============================================================================*/ + +// This file defines structures used in the LMHLO_GPU dialect. + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +// Order matters, this .inc header is not self-contained, and relies on the +// #includes above. +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc" + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td new file mode 100644 index 0000000..2236fc3 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td @@ -0,0 +1,29 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef LHLO_GPU_OPS_STRUCTS +#define LHLO_GPU_OPS_STRUCTS + +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td" + +def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig", + LHLO_GPU_Dialect, [ + StructFieldAttr<"algorithm", I64Attr>, + StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> { + let description = "GPU Convolution backend configuration"; +} + +#endif // LHLO_GPU_OPS_STRUCTS diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index c0c0494..9dc6d7a 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #include "llvm/ADT/StringRef.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" @@ -33,11 +34,6 @@ limitations under the License. namespace mlir { class OpBuilder; -} // namespace mlir - -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc" - -namespace mlir { namespace lmhlo { class LmhloDialect : public Dialect { diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index c013939..25d5e50 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -592,6 +592,7 @@ def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp { let arguments = (ins Arg:$lhs, Arg:$rhs, + DotDimensionNumbers:$dot_dimension_numbers, HLO_PrecisionConfigAttr:$precision_config, Arg:$output ); @@ -623,7 +624,7 @@ def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp { Arg:$scatter_indices, Arg:$updates, Arg:$output, - ScatterDimensionNumbers:$scatter_dimension_numbers, + ScatterDimensionNumbers:$scatter_dimension_numbers, DefaultValuedAttr:$indices_are_sorted, DefaultValuedAttr:$unique_indices ); @@ -699,7 +700,7 @@ def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>, Arg:$output, I64ElementsAttr:$replica_groups, DefaultValuedAttr:$constrain_layout, - OptionalAttr>:$channel_id, + OptionalAttr:$channel_id, DefaultValuedAttr:$use_global_device_ids ); let regions = (region SizedRegion<1>:$computation); @@ -712,7 +713,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>, Arg:$operand, Arg:$output, I64ElementsAttr:$source_target_pairs, - OptionalAttr>:$channel_id + OptionalAttr:$channel_id ); } diff --git a/lib/Dialect/mhlo/IR/CMakeLists.txt b/lib/Dialect/mhlo/IR/CMakeLists.txt index 7b5da34..7c0c11b 100644 --- a/lib/Dialect/mhlo/IR/CMakeLists.txt +++ b/lib/Dialect/mhlo/IR/CMakeLists.txt @@ -43,6 +43,7 @@ add_mlir_library(MhloInferFusibilityOpInterface add_mlir_dialect_library(MhloDialect hlo_ops.cc + hlo_ops_base_structs.cc DEPENDS MLIRhlo_opsIncGen @@ -68,6 +69,7 @@ target_link_libraries(LmhloDialect PUBLIC MLIRIR) add_mlir_dialect_library(LmhloGPUDialect lhlo_gpu_ops.cc + lhlo_gpu_ops_structs.cc DEPENDS MLIRlhlo_gpu_opsIncGen diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 86f048f..f8a92cc 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -63,8 +63,6 @@ namespace mlir { #include "hlo_patterns.cc.inc" } // namespace mlir -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc" - namespace mlir { namespace mhlo { diff --git a/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc b/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc new file mode 100644 index 0000000..90da125 --- /dev/null +++ b/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc @@ -0,0 +1,18 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" + +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc" diff --git a/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc b/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc index 33f4e2c..10c5c0c 100644 --- a/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc @@ -28,8 +28,6 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc" -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" diff --git a/lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc b/lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc new file mode 100644 index 0000000..cd2cfc5 --- /dev/null +++ b/lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc @@ -0,0 +1,18 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" + +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc" diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index cba0d3b..4524cf3 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -29,7 +29,6 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" diff --git a/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc b/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc new file mode 100644 index 0000000..83dd4e6 --- /dev/null +++ b/lib/Dialect/mhlo/IR/lhlo_ops_structs.cc @@ -0,0 +1,17 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 2659c58..7b401d5 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -126,6 +126,60 @@ class HloToLhloOpConverter : public BaseOpConversion { } }; +// This specialization exists so that LMHLO's Dot can be given a specific set of +// dimension numbers, when lowering from MHLO's Dot, which does not have +// dimension numbers (it uses DotGeneral for this generalized notion of dot +// products). When these two dialects are in sync with respect to the +// Dot/DotGeneral issue, this specialization should be deleted. +template <> +class HloToLhloOpConverter : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + LogicalResult matchAndRewrite( + mhlo::DotOp hloOp, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + Operation* op = hloOp.getOperation(); + const auto& original_results = op->getResults(); + SmallVector buffer_args(operands.begin(), operands.end()); + for (auto result : llvm::enumerate(original_results)) { + RankedTensorType resultType = + result.value().getType().dyn_cast(); + if (!resultType) { + return failure(); + } + if (resultType.hasStaticShape()) { + buffer_args.push_back( + InsertAlloc(op->getLoc(), result.value(), &rewriter)); + } else { + SmallVector results_shape; + auto shape_type_op = dyn_cast(op); + if (!shape_type_op) return failure(); + if (failed( + shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) + return failure(); + buffer_args.push_back(InsertDynamicAllocAndDealloc( + op->getLoc(), result.value(), results_shape.front(), &rewriter)); + } + } + + // TODO(silvasean): Move this helper to MLIR core. + auto make_elements_attr = [&rewriter](ArrayRef integers) { + auto type = RankedTensorType::get({static_cast(integers.size())}, + rewriter.getIntegerType(64)); + return DenseIntElementsAttr::get(type, integers); + }; + auto dotOp = rewriter.create(op->getLoc(), llvm::None, + buffer_args, op->getAttrs()); + // MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O]. + auto dimension_numbers = mhlo::DotDimensionNumbers::get( + make_elements_attr({}), make_elements_attr({}), make_elements_attr({1}), + make_elements_attr({0}), rewriter.getContext()); + dotOp.dot_dimension_numbersAttr(dimension_numbers); + rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); + return success(); + } +}; + struct HloToLhloDynamicBroadcastInDimOpConverter : public BaseOpConversion { public: @@ -236,6 +290,43 @@ struct HloToLhloDynamicReshapeConverter } }; +struct HloToLhloDotGeneralOpConverter + : public BaseOpConversion { + using BaseOpConversion::BaseOpConversion; + LogicalResult matchAndRewrite( + mhlo::DotGeneralOp dotGeneralOp, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + Operation* op = dotGeneralOp.getOperation(); + + if (op->getResults().empty()) return failure(); + OpResult result = op->getResults()[0]; + RankedTensorType resultType = result.getType().dyn_cast(); + if (!resultType) return failure(); + + // The third buffer argument will be filled with what used to be the return + // type of the DotGeneral. + if (operands.size() != 2) return failure(); + std::array bufferArgs = {operands[0], operands[1], {}}; + + if (resultType.hasStaticShape()) { + bufferArgs[2] = InsertAlloc(op->getLoc(), result, &rewriter); + } else { + SmallVector results_shape; + auto shape_type_op = dyn_cast(op); + if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) + return failure(); + + bufferArgs[2] = InsertDynamicAllocAndDealloc( + op->getLoc(), result, results_shape.front(), &rewriter); + } + + rewriter.create(op->getLoc(), llvm::None, bufferArgs, + op->getAttrs()); + rewriter.replaceOp(op, bufferArgs[2]); + return success(); + } +}; + struct HloToLhloReduceOpConverter : public BaseOpConversion { public: using BaseOpConversion::BaseOpConversion; @@ -485,6 +576,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off patterns->insert< + HloToLhloDotGeneralOpConverter, HloToLhloDynamicBroadcastInDimOpConverter, HloToLhloDynamicReshapeConverter, HloToLhloOpConverter, diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 57859b6..b64d662 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -192,7 +192,7 @@ struct ConvToLinalgConverter : public OpConversionPattern { lmhlo::ConvOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { // Check validity of dimension information. - if (const lmhlo::ConvDimensionNumbers& dimensionNumbers = + if (const mhlo::ConvDimensionNumbers& dimensionNumbers = op.dimension_numbers()) { const int inputSpatialRank = llvm::size(dimensionNumbers.input_spatial_dimensions()); diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index 2771afc..2041d22 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -59,6 +59,20 @@ struct DotOpConverter : public OpRewritePattern { return failure(); } + // We don't currently support batching dimensions, or multiple contraction + // dimensions. + mhlo::DotDimensionNumbers dot_dimension_numbers = + op.dot_dimension_numbers(); + if (dot_dimension_numbers.lhs_batching_dimensions().size() > 0 || + dot_dimension_numbers.rhs_batching_dimensions().size() > 0) + return failure(); + if (dot_dimension_numbers.lhs_contracting_dimensions().size() != 1 || + *dot_dimension_numbers.lhs_contracting_dimensions().begin() != 1 || + dot_dimension_numbers.rhs_contracting_dimensions().size() != 1 || + *dot_dimension_numbers.rhs_contracting_dimensions().begin() != 0) { + return failure(); + } + LogicalResult map_status = success(); auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) { SmallVector lhs_indices{ivs[0], ivs[2]}, diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 3caa4f0..f6fdc44 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -511,7 +511,13 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { // PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) // ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // BOTH-NEXT: %[[ALLOC:.*]] = alloc -// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () +// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) { +// dot_dimension_numbers = { +// lhs_batching_dimensions = dense<> : tensor<0xi64>, +// lhs_contracting_dimensions = dense<1> : tensor<1xi64>, +// rhs_batching_dimensions = dense<> : tensor<0xi64>, +// rhs_contracting_dimensions = dense<0> : tensor<1xi64>}} +// : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () %dot = "mhlo.dot"(%arg0, %arg0) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> // PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]]) @@ -632,4 +638,4 @@ func @shape_assuming_memref(%arg0: tensor) -> tensor { shape.assuming_yield %7 : tensor } return %2 : tensor -} \ No newline at end of file +} diff --git a/tests/lhlo-legalize-to-affine.mlir b/tests/lhlo-legalize-to-affine.mlir index 8781804..d020f7a 100644 --- a/tests/lhlo-legalize-to-affine.mlir +++ b/tests/lhlo-legalize-to-affine.mlir @@ -158,7 +158,14 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs: // CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32 // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32> // CHECK: return - "lmhlo.dot"(%lhs, %rhs, %result) : + "lmhlo.dot"(%lhs, %rhs, %result) { + dot_dimension_numbers = { + lhs_batching_dimensions = dense<> : tensor<0xi64>, + rhs_batching_dimensions = dense<> : tensor<0xi64>, + lhs_contracting_dimensions = dense<1> : tensor<1xi64>, + rhs_contracting_dimensions = dense<0> : tensor<1xi64> + } + } : (memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> () return } @@ -175,7 +182,14 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs: // CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32 // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32> // CHECK: return - "lmhlo.dot"(%lhs, %rhs, %result) : + "lmhlo.dot"(%lhs, %rhs, %result) { + dot_dimension_numbers = { + lhs_batching_dimensions = dense<> : tensor<0xi64>, + rhs_batching_dimensions = dense<> : tensor<0xi64>, + lhs_contracting_dimensions = dense<1> : tensor<1xi64>, + rhs_contracting_dimensions = dense<0> : tensor<1xi64> + } + } : (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> () return }