[MLIR:HLO] Generate enum decls for HLO and LHLO GPU dialects.

- Split out enum definitions in hlo dialect into a separate .td file (similar to structs)
  and generate enum decl/defs for these enums.
- Also split out the LHLO GPU enums into a separate .td file and generate enum
  decl/defs for these enums as well.
- Remove unused dialect from ConvolutionAttributes and generate lhlo_gpu enums.
- Add appropriate namespace for all the enums.

PiperOrigin-RevId: 345277240
This commit is contained in:
Rahul Joshi 2020-12-02 11:38:26 -08:00 committed by TensorFlow MLIR Team
parent d7bd5233ab
commit dbbdfea95b
15 changed files with 274 additions and 114 deletions

View File

@ -32,6 +32,8 @@ mlir_tablegen(hlo_ops.h.inc -gen-op-decls)
mlir_tablegen(hlo_ops.cc.inc -gen-op-defs) mlir_tablegen(hlo_ops.cc.inc -gen-op-defs)
mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls) mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls)
mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs) mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs)
mlir_tablegen(hlo_ops_base_enums.h.inc -gen-enum-decls)
mlir_tablegen(hlo_ops_base_enums.cc.inc -gen-enum-defs)
add_public_tablegen_target(MLIRhlo_opsIncGen) add_public_tablegen_target(MLIRhlo_opsIncGen)
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td) set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td)
@ -40,6 +42,9 @@ mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs)
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td) set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td)
mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls) mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls)
mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs) mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs)
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_enums.td)
mlir_tablegen(lhlo_gpu_ops_enums.h.inc -gen-enum-decls)
mlir_tablegen(lhlo_gpu_ops_enums.cc.inc -gen-enum-defs)
add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen) add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen)
add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen) add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen)

View File

@ -34,6 +34,7 @@ limitations under the License.
// clang-format off // clang-format off
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
// clang-format on // clang-format on

View File

@ -41,7 +41,9 @@ def HLO_OUTPUT_FUSION : StrEnumAttrCase<"kOutput">;
def HLO_CUSTOM_FUSION : StrEnumAttrCase<"kCustom">; def HLO_CUSTOM_FUSION : StrEnumAttrCase<"kCustom">;
def HLO_FusionKindAttr : StrEnumAttr<"FusionKind", "fusion kind", [ def HLO_FusionKindAttr : StrEnumAttr<"FusionKind", "fusion kind", [
HLO_LOOP_FUSION, HLO_INPUT_FUSION, HLO_OUTPUT_FUSION, HLO_CUSTOM_FUSION HLO_LOOP_FUSION, HLO_INPUT_FUSION, HLO_OUTPUT_FUSION, HLO_CUSTOM_FUSION
]>; ]> {
let cppNamespace = "::mlir::mhlo";
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MHLO nullary op definitions. // MHLO nullary op definitions.
@ -896,7 +898,7 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
(ins (ins
HLO_Tensor:$lhs, HLO_Tensor:$lhs,
HLO_Tensor:$rhs), HLO_Tensor:$rhs),
ConvolutionAttributes<HLO_Dialect>.attributes); ConvolutionAttributes.attributes);
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }

View File

@ -23,6 +23,7 @@ def HLO_Dialect : Dialect {
let cppNamespace = "::mlir::mhlo"; let cppNamespace = "::mlir::mhlo";
} }
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td"
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">; def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
@ -692,77 +693,7 @@ class BASE_HLO_TupleOp {
}]; }];
} }
//===----------------------------------------------------------------------===//
// Precision Config enum definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA PrecisionConfig proto enum.
def HLO_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">;
def HLO_PRECISION_HIGH : StrEnumAttrCase<"HIGH">;
def HLO_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">;
def HLO_PrecisionAttr : StrEnumAttr<"Precision",
"XLA precision for an operand. Has backend specific meaning.",
[HLO_PRECISION_DEFAULT, HLO_PRECISION_HIGH, HLO_PRECISION_HIGHEST]>;
// TODO(b/129153247) See if it's possible to also validate the size.
def HLO_PrecisionConfigAttr:
OptionalAttr<
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
//===----------------------------------------------------------------------===//
// Fast Fourier Transform Type enum definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA FftType proto enum.
def HLO_FFT_TYPE_FFT : StrEnumAttrCase<"FFT">;
def HLO_FFT_TYPE_IFFT : StrEnumAttrCase<"IFFT">;
def HLO_FFT_TYPE_RFFT : StrEnumAttrCase<"RFFT">;
def HLO_FFT_TYPE_IRFFT : StrEnumAttrCase<"IRFFT">;
def HLO_FftTypeAttr : StrEnumAttr<"FftType",
"XLA fast fourier transform type.",
[HLO_FFT_TYPE_FFT, HLO_FFT_TYPE_IFFT,
HLO_FFT_TYPE_RFFT, HLO_FFT_TYPE_IRFFT]>;
//===----------------------------------------------------------------------===//
// Comparison op definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA ComparisonDirection enum.
def HLO_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">;
def HLO_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">;
def HLO_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">;
def HLO_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">;
def HLO_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">;
def HLO_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">;
def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
"Which comparison operation to perform.",
[
HLO_COMPARISON_DIRECTION_EQ,
HLO_COMPARISON_DIRECTION_NE,
HLO_COMPARISON_DIRECTION_GE,
HLO_COMPARISON_DIRECTION_GT,
HLO_COMPARISON_DIRECTION_LE,
HLO_COMPARISON_DIRECTION_LT
]>;
def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">;
def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">;
def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">;
def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">;
def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">;
def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType",
"Which comparison type to use.",
[
HLO_COMPARISON_TYPE_FLOAT,
HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
HLO_COMPARISON_TYPE_SIGNED,
HLO_COMPARISON_TYPE_UNSIGNED
]>;
class BASE_HLO_CompareOp { class BASE_HLO_CompareOp {
@ -783,13 +714,6 @@ class BASE_HLO_CompareOp {
// Quantize op definitions. // Quantize op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// These mirror the XLA ComparisonDirection enum.
def HLO_MIN_COMBINED : StrEnumAttrCase<"MIN_COMBINED">;
def HLO_DequantizeModeAttr : StrEnumAttr<"DequantizeMode",
"Dequantization mode. Only MIN_COMBINED is supported.",
[HLO_MIN_COMBINED]>;
class BASE_HLO_DequantizeOp { class BASE_HLO_DequantizeOp {
string summary = "Dequantize operator"; string summary = "Dequantize operator";
@ -1029,7 +953,12 @@ class BASE_HLO_ConcatenateOp {
// Common convolution attributes // Common convolution attributes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class ConvolutionAttributes<Dialect dialect> { // TODO(b/129153247) See if it's possible to also validate the size.
def HLO_PrecisionConfigAttr:
OptionalAttr<
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
def ConvolutionAttributes {
dag attributes = (ins dag attributes = (ins
// Default value: one for each of the spatial dimension. // Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides, OptionalAttr<I64ElementsAttr>:$window_strides,
@ -1270,21 +1199,6 @@ class BASE_HLO_TransposeOp {
}]; }];
} }
// These mirror the XLA Transpose enum in Triangular Solve options.
def HLO_TRANSPOSE_INVALID : StrEnumAttrCase<"TRANSPOSE_INVALID">;
def HLO_NO_TRANSPOSE : StrEnumAttrCase<"NO_TRANSPOSE">;
def HLO_TRANSPOSE : StrEnumAttrCase<"TRANSPOSE">;
def HLO_ADJOINT : StrEnumAttrCase<"ADJOINT">;
def HLO_TransposeAttr : StrEnumAttr<"Transpose",
"Transpose options",
[
HLO_TRANSPOSE_INVALID,
HLO_NO_TRANSPOSE,
HLO_TRANSPOSE,
HLO_ADJOINT
]>;
class BASE_HLO_TriangularSolveOp { class BASE_HLO_TriangularSolveOp {
string summary = "TriangularSolve operator"; string summary = "TriangularSolve operator";

View File

@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines enums used in MHLO and LMHLO.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_ENUMS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_ENUMS_H_
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
// Order matters, this .inc header is not self-contained, and relies on the
// #includes above.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc"
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_ENUMS_H_

View File

@ -0,0 +1,119 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef HLO_OPS_BASE_ENUMS
#define HLO_OPS_BASE_ENUMS
//===----------------------------------------------------------------------===//
// Precision Config enum definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA PrecisionConfig proto enum.
def HLO_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">;
def HLO_PRECISION_HIGH : StrEnumAttrCase<"HIGH">;
def HLO_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">;
def HLO_PrecisionAttr : StrEnumAttr<"Precision",
"XLA precision for an operand. Has backend specific meaning.",
[HLO_PRECISION_DEFAULT, HLO_PRECISION_HIGH, HLO_PRECISION_HIGHEST]> {
let cppNamespace = "::mlir::mhlo";
}
//===----------------------------------------------------------------------===//
// Fast Fourier Transform Type enum definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA FftType proto enum.
def HLO_FFT_TYPE_FFT : StrEnumAttrCase<"FFT">;
def HLO_FFT_TYPE_IFFT : StrEnumAttrCase<"IFFT">;
def HLO_FFT_TYPE_RFFT : StrEnumAttrCase<"RFFT">;
def HLO_FFT_TYPE_IRFFT : StrEnumAttrCase<"IRFFT">;
def HLO_FftTypeAttr : StrEnumAttr<"FftType",
"XLA fast fourier transform type.",
[HLO_FFT_TYPE_FFT, HLO_FFT_TYPE_IFFT,
HLO_FFT_TYPE_RFFT, HLO_FFT_TYPE_IRFFT]> {
let cppNamespace = "::mlir::mhlo";
}
//===----------------------------------------------------------------------===//
// Comparison op definitions.
//===----------------------------------------------------------------------===//
// These mirror the XLA ComparisonDirection enum.
def HLO_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">;
def HLO_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">;
def HLO_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">;
def HLO_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">;
def HLO_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">;
def HLO_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">;
def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
"Which comparison operation to perform.",
[
HLO_COMPARISON_DIRECTION_EQ,
HLO_COMPARISON_DIRECTION_NE,
HLO_COMPARISON_DIRECTION_GE,
HLO_COMPARISON_DIRECTION_GT,
HLO_COMPARISON_DIRECTION_LE,
HLO_COMPARISON_DIRECTION_LT
]> {
let cppNamespace = "::mlir::mhlo";
}
def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">;
def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">;
def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">;
def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">;
def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">;
def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType",
"Which comparison type to use.",
[
HLO_COMPARISON_TYPE_FLOAT,
HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
HLO_COMPARISON_TYPE_SIGNED,
HLO_COMPARISON_TYPE_UNSIGNED
]> {
let cppNamespace = "::mlir::mhlo";
}
// These mirror the XLA Dequantize mode string enum.
def HLO_MIN_COMBINED : StrEnumAttrCase<"MIN_COMBINED">;
def HLO_DequantizeModeAttr : StrEnumAttr<"DequantizeMode",
"Dequantization mode. Only MIN_COMBINED is supported.",
[HLO_MIN_COMBINED]> {
let cppNamespace = "::mlir::mhlo";
}
// These mirror the XLA Transpose enum in Triangular Solve options.
def HLO_TRANSPOSE_INVALID : StrEnumAttrCase<"TRANSPOSE_INVALID">;
def HLO_NO_TRANSPOSE : StrEnumAttrCase<"NO_TRANSPOSE">;
def HLO_TRANSPOSE : StrEnumAttrCase<"TRANSPOSE">;
def HLO_ADJOINT : StrEnumAttrCase<"ADJOINT">;
def HLO_TransposeAttr : StrEnumAttr<"Transpose",
"Transpose options",
[
HLO_TRANSPOSE_INVALID,
HLO_NO_TRANSPOSE,
HLO_TRANSPOSE,
HLO_ADJOINT
]> {
let cppNamespace = "::mlir::mhlo";
}
#endif // HLO_OPS_BASE_ENUMS

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"

View File

@ -23,9 +23,9 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td"
class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> : class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> :
Op<LHLO_GPU_Dialect, mnemonic, Op<LHLO_GPU_Dialect, mnemonic,
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>; !listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
@ -92,30 +92,16 @@ def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">,
// LMHLO ops representing convolution library functions. // LMHLO ops representing convolution library functions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def ActivationModeNone : StrEnumAttrCase<"None">;
def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">;
def ActivationModeTanh : StrEnumAttrCase<"Relu">;
def ActivationModeRelu : StrEnumAttrCase<"Relu">;
def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">;
def ActivationModeReluX : StrEnumAttrCase<"ReluX">;
def ActivationModeBandPass : StrEnumAttrCase<"BandPass">;
def ActivationAttr : StrEnumAttr<"Activation",
"Activation applied with fused convolution",
[ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh,
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
ActivationModeBandPass]>;
def GpuConvolutionAttributes { def GpuConvolutionAttributes {
dag attributes = !con( dag attributes = !con(
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes, ConvolutionAttributes.attributes,
(ins F64Attr:$result_scale), (ins F64Attr:$result_scale),
(ins ConvolutionBackendConfigAttr:$backend_config)); (ins ConvolutionBackendConfigAttr:$backend_config));
} }
def GpuFusedConvolutionAttributes { def GpuFusedConvolutionAttributes {
dag attributes = !con( dag attributes = !con(
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes, ConvolutionAttributes.attributes,
(ins F64Attr:$result_scale, (ins F64Attr:$result_scale,
ActivationAttr:$activation_mode, ActivationAttr:$activation_mode,
F64Attr:$side_input_scale), F64Attr:$side_input_scale),

View File

@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ==============================================================================*/
// This file defines enums used in the LMHLO_GPU dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_ENUMS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_ENUMS_H_
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
// Order matters, this .inc header is not self-contained, and relies on the
// #includes above.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h.inc"
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_ENUMS_H_

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef LHLO_GPU_OPS_ENUMS
#define LHLO_GPU_OPS_ENUMS
include "mlir/IR/OpBase.td"
def ActivationModeNone : StrEnumAttrCase<"None">;
def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">;
def ActivationModeTanh : StrEnumAttrCase<"Tanh">;
def ActivationModeRelu : StrEnumAttrCase<"Relu">;
def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">;
def ActivationModeReluX : StrEnumAttrCase<"ReluX">;
def ActivationModeBandPass : StrEnumAttrCase<"BandPass">;
def ActivationAttr : StrEnumAttr<"Activation",
"Activation applied with fused convolution",
[ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh,
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
ActivationModeBandPass]> {
let cppNamespace = "::mlir::lmhlo_gpu";
}
#endif // LHLO_GPU_OPS_ENUMS

View File

@ -1,4 +1,3 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -415,7 +415,7 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
Arg<LHLO_Buffer, "", [MemRead]>:$lhs, Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs, Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output), Arg<LHLO_Buffer, "", [MemWrite]>:$output),
ConvolutionAttributes<LHLO_Dialect>.attributes); ConvolutionAttributes.attributes);
} }
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp { def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {

View File

@ -44,6 +44,7 @@ add_mlir_library(MhloInferFusibilityOpInterface
add_mlir_dialect_library(MhloDialect add_mlir_dialect_library(MhloDialect
hlo_ops.cc hlo_ops.cc
hlo_ops_base_structs.cc hlo_ops_base_structs.cc
hlo_ops_base_enums.cc
DEPENDS DEPENDS
MLIRhlo_opsIncGen MLIRhlo_opsIncGen
@ -70,6 +71,7 @@ target_link_libraries(LmhloDialect PUBLIC MLIRIR)
add_mlir_dialect_library(LmhloGPUDialect add_mlir_dialect_library(LmhloGPUDialect
lhlo_gpu_ops.cc lhlo_gpu_ops.cc
lhlo_gpu_ops_structs.cc lhlo_gpu_ops_structs.cc
lhlo_gpu_ops_enums.cc
DEPENDS DEPENDS
MLIRlhlo_gpu_opsIncGen MLIRlhlo_gpu_opsIncGen

View File

@ -0,0 +1,18 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.cc.inc"

View File

@ -0,0 +1,18 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.cc.inc"