From 41436ea0d907668e7fb05b8c398f0fb70744c331 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Thu, 8 Oct 2020 10:28:02 -0700 Subject: [PATCH] [NFC] Extract common attributes of MHLO and LMHLO ConvOp - Create a common class that hold ConvOp attributes that are common to MHLO and LHLO Dialects and use Tablegen DAG concat to append these common attributes to dialect specific arguments in ODS for ConvOp. PiperOrigin-RevId: 336114003 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 37 ++--------- .../mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td | 36 ++++++++++ include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 66 ++----------------- .../mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td | 47 +++++++++++++ 4 files changed, 95 insertions(+), 91 deletions(-) create mode 100644 include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 37f11de..507f7c1 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -912,39 +912,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", let results = (outs HLO_Tensor); } -// TODO(hinsu): Make this struct dialect independent so that it can be shared -// between HLO and LHLO dialect. -def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [ - StructFieldAttr<"input_batch_dimension",I64Attr>, - StructFieldAttr<"input_feature_dimension", I64Attr>, - StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"output_batch_dimension", I64Attr>, - StructFieldAttr<"output_feature_dimension", I64Attr>, - StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { - - let description = "Structure of dimension information for conv op"; -} - def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp { - let arguments = (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - // Default value: one for each of the spatial dimension. - OptionalAttr:$window_strides, - // Default value: zero for each of the spatial dimension. - OptionalAttr:$padding, - // Default value: one for each of the spatial dimension. - OptionalAttr:$lhs_dilation, - // Default value: one for each of the spatial dimension. - OptionalAttr:$rhs_dilation, - ConvDimensionNumbers:$dimension_numbers, - I64Attr:$feature_group_count, - I64Attr:$batch_group_count, - HLO_PrecisionConfigAttr:$precision_config - ); + let arguments = !con( + (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs), + ConvolutionAttributes.attributes); let results = (outs HLO_Tensor); } diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index 6386972..cba2dc3 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -1007,6 +1007,42 @@ class BASE_HLO_ConcatenateOp { }]; } +//===----------------------------------------------------------------------===// +// Common convolution attributes +//===----------------------------------------------------------------------===// + +class ConvDimensionNumbersBase + : StructAttr<"ConvDimensionNumbers", dialect, [ + StructFieldAttr<"input_batch_dimension",I64Attr>, + StructFieldAttr<"input_feature_dimension", I64Attr>, + StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"output_batch_dimension", I64Attr>, + StructFieldAttr<"output_feature_dimension", I64Attr>, + StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { + + let description = "Structure of dimension information for conv op"; +} + +class ConvolutionAttributes { + dag attributes = (ins + // Default value: one for each of the spatial dimension. + OptionalAttr:$window_strides, + // Default value: zero for each of the spatial dimension. + OptionalAttr:$padding, + // Default value: one for each of the spatial dimension. + OptionalAttr:$lhs_dilation, + // Default value: one for each of the spatial dimension. + OptionalAttr:$rhs_dilation, + ConvDimensionNumbersBase:$dimension_numbers, + I64Attr:$feature_group_count, + I64Attr:$batch_group_count, + HLO_PrecisionConfigAttr:$precision_config + ); +} + class BASE_HLO_ConvOp { string summary = "Convolution operator"; diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index eddbd95..4aa2f3d 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -37,38 +37,13 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" def LHLO_Dialect : Dialect { let name = "lmhlo"; let cppNamespace = "::mlir::lmhlo"; } -//===----------------------------------------------------------------------===// -// LMHLO type definitions. -//===----------------------------------------------------------------------===// - -// Any integer tensor types -def LHLO_IntBuffer : MemRefOf<[HLO_Int]>; - -// Any floating-point tensor types -def LHLO_FpBuffer : MemRefOf<[AnyFloat]>; - -def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>; - -def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>; - -def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; - -// Any integer or floating-point tensor types -def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; - -def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; - -def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; - -def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>; - //===----------------------------------------------------------------------===// // LMHLO nullary op definitions. //===----------------------------------------------------------------------===// @@ -593,40 +568,13 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp { ); } -// TODO(bondhugula): Make this struct dialect independent so that it can be -// shared between the HLO and LHLO dialects. -def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [ - StructFieldAttr<"input_batch_dimension",I64Attr>, - StructFieldAttr<"input_feature_dimension", I64Attr>, - StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"output_batch_dimension", I64Attr>, - StructFieldAttr<"output_feature_dimension", I64Attr>, - StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { - - let description = "Structure of dimension information for conv op"; -} - def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp { - let arguments = (ins - Arg:$lhs, - Arg:$rhs, - Arg:$output, - // Default value: one for each of the spatial dimension. - OptionalAttr:$window_strides, - // Default value: zero for each of the spatial dimension. - OptionalAttr:$padding, - // Default value: one for each of the spatial dimension. - OptionalAttr:$lhs_dilation, - // Default value: one for each of the spatial dimension. - OptionalAttr:$rhs_dilation, - ConvDimensionNumbers:$dimension_numbers, - I64Attr:$feature_group_count, - I64Attr:$batch_group_count, - HLO_PrecisionConfigAttr:$precision_config - ); + let arguments = !con( + (ins + Arg:$lhs, + Arg:$rhs, + Arg:$output), + ConvolutionAttributes.attributes); } def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp { diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td new file mode 100644 index 0000000..9cd7741 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef LHLO_OPS_BASE +#define LHLO_OPS_BASE + +include "mlir/IR/OpBase.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" + +//===----------------------------------------------------------------------===// +// LMHLO type definitions. +//===----------------------------------------------------------------------===// + +// Any integer tensor types +def LHLO_IntBuffer : MemRefOf<[HLO_Int]>; + +// Any floating-point tensor types +def LHLO_FpBuffer : MemRefOf<[AnyFloat]>; + +def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>; + +def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>; + +def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; + +// Any integer or floating-point tensor types +def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; + +def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; + +def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; + +def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>; + +#endif // LHLO_OPS_BASE