diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 37f11de..507f7c1 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -912,39 +912,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", let results = (outs HLO_Tensor); } -// TODO(hinsu): Make this struct dialect independent so that it can be shared -// between HLO and LHLO dialect. -def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [ - StructFieldAttr<"input_batch_dimension",I64Attr>, - StructFieldAttr<"input_feature_dimension", I64Attr>, - StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"output_batch_dimension", I64Attr>, - StructFieldAttr<"output_feature_dimension", I64Attr>, - StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { - - let description = "Structure of dimension information for conv op"; -} - def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp { - let arguments = (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - // Default value: one for each of the spatial dimension. - OptionalAttr:$window_strides, - // Default value: zero for each of the spatial dimension. - OptionalAttr:$padding, - // Default value: one for each of the spatial dimension. - OptionalAttr:$lhs_dilation, - // Default value: one for each of the spatial dimension. - OptionalAttr:$rhs_dilation, - ConvDimensionNumbers:$dimension_numbers, - I64Attr:$feature_group_count, - I64Attr:$batch_group_count, - HLO_PrecisionConfigAttr:$precision_config - ); + let arguments = !con( + (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs), + ConvolutionAttributes.attributes); let results = (outs HLO_Tensor); } diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index 6386972..cba2dc3 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -1007,6 +1007,42 @@ class BASE_HLO_ConcatenateOp { }]; } +//===----------------------------------------------------------------------===// +// Common convolution attributes +//===----------------------------------------------------------------------===// + +class ConvDimensionNumbersBase + : StructAttr<"ConvDimensionNumbers", dialect, [ + StructFieldAttr<"input_batch_dimension",I64Attr>, + StructFieldAttr<"input_feature_dimension", I64Attr>, + StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"output_batch_dimension", I64Attr>, + StructFieldAttr<"output_feature_dimension", I64Attr>, + StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { + + let description = "Structure of dimension information for conv op"; +} + +class ConvolutionAttributes { + dag attributes = (ins + // Default value: one for each of the spatial dimension. + OptionalAttr:$window_strides, + // Default value: zero for each of the spatial dimension. + OptionalAttr:$padding, + // Default value: one for each of the spatial dimension. + OptionalAttr:$lhs_dilation, + // Default value: one for each of the spatial dimension. + OptionalAttr:$rhs_dilation, + ConvDimensionNumbersBase:$dimension_numbers, + I64Attr:$feature_group_count, + I64Attr:$batch_group_count, + HLO_PrecisionConfigAttr:$precision_config + ); +} + class BASE_HLO_ConvOp { string summary = "Convolution operator"; diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index eddbd95..4aa2f3d 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -37,38 +37,13 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" def LHLO_Dialect : Dialect { let name = "lmhlo"; let cppNamespace = "::mlir::lmhlo"; } -//===----------------------------------------------------------------------===// -// LMHLO type definitions. -//===----------------------------------------------------------------------===// - -// Any integer tensor types -def LHLO_IntBuffer : MemRefOf<[HLO_Int]>; - -// Any floating-point tensor types -def LHLO_FpBuffer : MemRefOf<[AnyFloat]>; - -def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>; - -def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>; - -def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; - -// Any integer or floating-point tensor types -def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; - -def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; - -def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; - -def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>; - //===----------------------------------------------------------------------===// // LMHLO nullary op definitions. //===----------------------------------------------------------------------===// @@ -593,40 +568,13 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp { ); } -// TODO(bondhugula): Make this struct dialect independent so that it can be -// shared between the HLO and LHLO dialects. -def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [ - StructFieldAttr<"input_batch_dimension",I64Attr>, - StructFieldAttr<"input_feature_dimension", I64Attr>, - StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"output_batch_dimension", I64Attr>, - StructFieldAttr<"output_feature_dimension", I64Attr>, - StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { - - let description = "Structure of dimension information for conv op"; -} - def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp { - let arguments = (ins - Arg:$lhs, - Arg:$rhs, - Arg:$output, - // Default value: one for each of the spatial dimension. - OptionalAttr:$window_strides, - // Default value: zero for each of the spatial dimension. - OptionalAttr:$padding, - // Default value: one for each of the spatial dimension. - OptionalAttr:$lhs_dilation, - // Default value: one for each of the spatial dimension. - OptionalAttr:$rhs_dilation, - ConvDimensionNumbers:$dimension_numbers, - I64Attr:$feature_group_count, - I64Attr:$batch_group_count, - HLO_PrecisionConfigAttr:$precision_config - ); + let arguments = !con( + (ins + Arg:$lhs, + Arg:$rhs, + Arg:$output), + ConvolutionAttributes.attributes); } def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp { diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td new file mode 100644 index 0000000..9cd7741 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef LHLO_OPS_BASE +#define LHLO_OPS_BASE + +include "mlir/IR/OpBase.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" + +//===----------------------------------------------------------------------===// +// LMHLO type definitions. +//===----------------------------------------------------------------------===// + +// Any integer tensor types +def LHLO_IntBuffer : MemRefOf<[HLO_Int]>; + +// Any floating-point tensor types +def LHLO_FpBuffer : MemRefOf<[AnyFloat]>; + +def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>; + +def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>; + +def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; + +// Any integer or floating-point tensor types +def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; + +def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; + +def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; + +def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>; + +#endif // LHLO_OPS_BASE