[NFC] Extract common attributes of MHLO and LMHLO ConvOp
- Create a common class that hold ConvOp attributes that are common to MHLO and LHLO Dialects and use Tablegen DAG concat to append these common attributes to dialect specific arguments in ODS for ConvOp. PiperOrigin-RevId: 336114003
This commit is contained in:
parent
3a99598b59
commit
41436ea0d9
|
@ -912,39 +912,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(hinsu): Make this struct dialect independent so that it can be shared
|
|
||||||
// between HLO and LHLO dialect.
|
|
||||||
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
|
|
||||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
|
||||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
|
||||||
|
|
||||||
let description = "Structure of dimension information for conv op";
|
|
||||||
}
|
|
||||||
|
|
||||||
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
|
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
|
||||||
let arguments = (ins
|
let arguments = !con(
|
||||||
HLO_Tensor:$lhs,
|
(ins
|
||||||
HLO_Tensor:$rhs,
|
HLO_Tensor:$lhs,
|
||||||
// Default value: one for each of the spatial dimension.
|
HLO_Tensor:$rhs),
|
||||||
OptionalAttr<I64ElementsAttr>:$window_strides,
|
ConvolutionAttributes<HLO_Dialect>.attributes);
|
||||||
// Default value: zero for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$padding,
|
|
||||||
// Default value: one for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
|
||||||
// Default value: one for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
|
||||||
ConvDimensionNumbers:$dimension_numbers,
|
|
||||||
I64Attr:$feature_group_count,
|
|
||||||
I64Attr:$batch_group_count,
|
|
||||||
HLO_PrecisionConfigAttr:$precision_config
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1007,6 +1007,42 @@ class BASE_HLO_ConcatenateOp {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Common convolution attributes
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class ConvDimensionNumbersBase<Dialect dialect>
|
||||||
|
: StructAttr<"ConvDimensionNumbers", dialect, [
|
||||||
|
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
||||||
|
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
||||||
|
|
||||||
|
let description = "Structure of dimension information for conv op";
|
||||||
|
}
|
||||||
|
|
||||||
|
class ConvolutionAttributes<Dialect dialect> {
|
||||||
|
dag attributes = (ins
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$window_strides,
|
||||||
|
// Default value: zero for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$padding,
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
||||||
|
ConvDimensionNumbersBase<dialect>:$dimension_numbers,
|
||||||
|
I64Attr:$feature_group_count,
|
||||||
|
I64Attr:$batch_group_count,
|
||||||
|
HLO_PrecisionConfigAttr:$precision_config
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
class BASE_HLO_ConvOp {
|
class BASE_HLO_ConvOp {
|
||||||
string summary = "Convolution operator";
|
string summary = "Convolution operator";
|
||||||
|
|
||||||
|
|
|
@ -37,38 +37,13 @@ include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Interfaces/CopyOpInterface.td"
|
include "mlir/Interfaces/CopyOpInterface.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
|
||||||
|
|
||||||
def LHLO_Dialect : Dialect {
|
def LHLO_Dialect : Dialect {
|
||||||
let name = "lmhlo";
|
let name = "lmhlo";
|
||||||
let cppNamespace = "::mlir::lmhlo";
|
let cppNamespace = "::mlir::lmhlo";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// LMHLO type definitions.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// Any integer tensor types
|
|
||||||
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
|
|
||||||
|
|
||||||
// Any floating-point tensor types
|
|
||||||
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
|
|
||||||
|
|
||||||
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
|
|
||||||
|
|
||||||
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
|
|
||||||
|
|
||||||
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
|
|
||||||
|
|
||||||
// Any integer or floating-point tensor types
|
|
||||||
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
|
|
||||||
|
|
||||||
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
|
|
||||||
|
|
||||||
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
|
|
||||||
|
|
||||||
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// LMHLO nullary op definitions.
|
// LMHLO nullary op definitions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -593,40 +568,13 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(bondhugula): Make this struct dialect independent so that it can be
|
|
||||||
// shared between the HLO and LHLO dialects.
|
|
||||||
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
|
|
||||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
|
||||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
|
||||||
|
|
||||||
let description = "Structure of dimension information for conv op";
|
|
||||||
}
|
|
||||||
|
|
||||||
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
|
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
|
||||||
let arguments = (ins
|
let arguments = !con(
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
(ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||||
// Default value: one for each of the spatial dimension.
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output),
|
||||||
OptionalAttr<I64ElementsAttr>:$window_strides,
|
ConvolutionAttributes<LHLO_Dialect>.attributes);
|
||||||
// Default value: zero for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$padding,
|
|
||||||
// Default value: one for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
|
||||||
// Default value: one for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
|
||||||
ConvDimensionNumbers:$dimension_numbers,
|
|
||||||
I64Attr:$feature_group_count,
|
|
||||||
I64Attr:$batch_group_count,
|
|
||||||
HLO_PrecisionConfigAttr:$precision_config
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
|
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef LHLO_OPS_BASE
|
||||||
|
#define LHLO_OPS_BASE
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// LMHLO type definitions.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Any integer tensor types
|
||||||
|
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
|
||||||
|
|
||||||
|
// Any floating-point tensor types
|
||||||
|
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
|
||||||
|
|
||||||
|
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
|
||||||
|
|
||||||
|
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
|
||||||
|
|
||||||
|
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
|
||||||
|
|
||||||
|
// Any integer or floating-point tensor types
|
||||||
|
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
|
||||||
|
|
||||||
|
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
|
||||||
|
|
||||||
|
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
|
||||||
|
|
||||||
|
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
|
||||||
|
|
||||||
|
#endif // LHLO_OPS_BASE
|
Loading…
Reference in New Issue