[NFC] Extract common attributes of MHLO and LMHLO ConvOp

- Create a common class that hold ConvOp attributes that are common to MHLO and
  LHLO Dialects and use Tablegen DAG concat to append these common attributes to
  dialect specific arguments in ODS for ConvOp.

PiperOrigin-RevId: 336114003
This commit is contained in:
Rahul Joshi 2020-10-08 10:28:02 -07:00 committed by TensorFlow MLIR Team
parent 3a99598b59
commit 41436ea0d9
4 changed files with 95 additions and 91 deletions

View File

@ -912,39 +912,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }
// TODO(hinsu): Make this struct dialect independent so that it can be shared
// between HLO and LHLO dialect.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp { def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
let arguments = (ins let arguments = !con(
HLO_Tensor:$lhs, (ins
HLO_Tensor:$rhs, HLO_Tensor:$lhs,
// Default value: one for each of the spatial dimension. HLO_Tensor:$rhs),
OptionalAttr<I64ElementsAttr>:$window_strides, ConvolutionAttributes<HLO_Dialect>.attributes);
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }

View File

@ -1007,6 +1007,42 @@ class BASE_HLO_ConcatenateOp {
}]; }];
} }
//===----------------------------------------------------------------------===//
// Common convolution attributes
//===----------------------------------------------------------------------===//
class ConvDimensionNumbersBase<Dialect dialect>
: StructAttr<"ConvDimensionNumbers", dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
class ConvolutionAttributes<Dialect dialect> {
dag attributes = (ins
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides,
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbersBase<dialect>:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
}
class BASE_HLO_ConvOp { class BASE_HLO_ConvOp {
string summary = "Convolution operator"; string summary = "Convolution operator";

View File

@ -37,38 +37,13 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
def LHLO_Dialect : Dialect { def LHLO_Dialect : Dialect {
let name = "lmhlo"; let name = "lmhlo";
let cppNamespace = "::mlir::lmhlo"; let cppNamespace = "::mlir::lmhlo";
} }
//===----------------------------------------------------------------------===//
// LMHLO type definitions.
//===----------------------------------------------------------------------===//
// Any integer tensor types
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
// Any floating-point tensor types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
// Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LMHLO nullary op definitions. // LMHLO nullary op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -593,40 +568,13 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
); );
} }
// TODO(bondhugula): Make this struct dialect independent so that it can be
// shared between the HLO and LHLO dialects.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp { def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
let arguments = (ins let arguments = !con(
Arg<LHLO_Buffer, "", [MemRead]>:$lhs, (ins
Arg<LHLO_Buffer, "", [MemRead]>:$rhs, Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output, Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
// Default value: one for each of the spatial dimension. Arg<LHLO_Buffer, "", [MemWrite]>:$output),
OptionalAttr<I64ElementsAttr>:$window_strides, ConvolutionAttributes<LHLO_Dialect>.attributes);
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
} }
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp { def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {

View File

@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef LHLO_OPS_BASE
#define LHLO_OPS_BASE
include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
//===----------------------------------------------------------------------===//
// LMHLO type definitions.
//===----------------------------------------------------------------------===//
// Any integer tensor types
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
// Any floating-point tensor types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
// Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
#endif // LHLO_OPS_BASE