Make LMHLO's Dot have the same power as MHLO's DotGeneral.
PiperOrigin-RevId: 337391565
This commit is contained in:
parent
8fe5bc89bb
commit
51cd4200b6
|
@ -25,8 +25,22 @@ function(add_mlir_hlo_dialect dialect dialect_namespace)
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
add_mlir_hlo_dialect(chlo_ops chlo)
|
add_mlir_hlo_dialect(chlo_ops chlo)
|
||||||
add_mlir_hlo_dialect(hlo_ops mhlo)
|
|
||||||
add_mlir_hlo_dialect(lhlo_ops lmhlo)
|
add_mlir_hlo_dialect(lhlo_ops lmhlo)
|
||||||
add_mlir_hlo_dialect(lhlo_gpu_ops lmhlo_gpu)
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS hlo_ops.td)
|
||||||
|
mlir_tablegen(hlo_ops.h.inc -gen-op-decls)
|
||||||
|
mlir_tablegen(hlo_ops.cc.inc -gen-op-defs)
|
||||||
|
mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls)
|
||||||
|
mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs)
|
||||||
|
add_public_tablegen_target(MLIRhlo_opsIncGen)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td)
|
||||||
|
mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls)
|
||||||
|
mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs)
|
||||||
|
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td)
|
||||||
|
mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls)
|
||||||
|
mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs)
|
||||||
|
add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen)
|
||||||
|
add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen)
|
||||||
|
|
||||||
add_mlir_interface(infer_fusibility_op_interface)
|
add_mlir_interface(infer_fusibility_op_interface)
|
||||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
|
||||||
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
|
@ -32,7 +33,7 @@ limitations under the License.
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
|
|
@ -25,11 +25,6 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
|
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
|
||||||
|
|
||||||
def HLO_Dialect : Dialect {
|
|
||||||
let name = "mhlo";
|
|
||||||
let cppNamespace = "::mlir::mhlo";
|
|
||||||
}
|
|
||||||
|
|
||||||
class HLO_Op<string mnemonic, list<OpTrait> traits> :
|
class HLO_Op<string mnemonic, list<OpTrait> traits> :
|
||||||
Op<HLO_Dialect, mnemonic, traits> {
|
Op<HLO_Dialect, mnemonic, traits> {
|
||||||
// Whether this operation has a custom conversion to HLO or not.
|
// Whether this operation has a custom conversion to HLO or not.
|
||||||
|
@ -136,8 +131,8 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
||||||
}
|
}
|
||||||
LogicalResult reifyReturnTypeShapes(
|
LogicalResult reifyReturnTypeShapes(
|
||||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
return deriveShapeFromFirstOperand(&builder, getOperation(),
|
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||||
&reifiedReturnShapes);
|
&reifiedReturnShapes);
|
||||||
}
|
}
|
||||||
bool inferInputOutputShapeEquality(int input, int output) {
|
bool inferInputOutputShapeEquality(int input, int output) {
|
||||||
return true;
|
return true;
|
||||||
|
@ -153,7 +148,7 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
|
||||||
[NoSideEffect, SameOperandsAndResultShape],
|
[NoSideEffect, SameOperandsAndResultShape],
|
||||||
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
|
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"OpBuilder &builder, OperationState &result, Value operand"
|
"Value operand"
|
||||||
>];
|
>];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,8 +163,7 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp<
|
||||||
BASE_HLO_ConvertOp {
|
BASE_HLO_ConvertOp {
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"OpBuilder &, OperationState &tblgen_state, Value operand, "
|
"Value operand, Type result_element_ty"
|
||||||
"Type result_element_ty"
|
|
||||||
>];
|
>];
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
@ -293,8 +287,8 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
|
||||||
}
|
}
|
||||||
LogicalResult reifyReturnTypeShapes(
|
LogicalResult reifyReturnTypeShapes(
|
||||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
return deriveShapeFromFirstOperand(&builder, getOperation(),
|
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||||
&reifiedReturnShapes);
|
&reifiedReturnShapes);
|
||||||
}
|
}
|
||||||
bool inferInputsShapeEquality(int lhs, int rhs) {
|
bool inferInputsShapeEquality(int lhs, int rhs) {
|
||||||
return true;
|
return true;
|
||||||
|
@ -458,7 +452,7 @@ def HLO_SendOp : HLO_Op<"send", []> {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_TensorOrTuple:$operand,
|
HLO_TensorOrTuple:$operand,
|
||||||
HLO_Token:$token,
|
HLO_Token:$token,
|
||||||
ChannelHandle<HLO_Dialect>:$channel_id,
|
ChannelHandle:$channel_id,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
|
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -483,7 +477,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Token:$token,
|
HLO_Token:$token,
|
||||||
ChannelHandle<HLO_Dialect>:$channel_id,
|
ChannelHandle:$channel_id,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
|
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -587,7 +581,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Tensor:$operand,
|
HLO_Tensor:$operand,
|
||||||
I64ElementsAttr:$replica_groups,
|
I64ElementsAttr:$replica_groups,
|
||||||
OptionalAttr<ChannelHandle<HLO_Dialect>>:$channel_id
|
OptionalAttr<ChannelHandle>:$channel_id
|
||||||
);
|
);
|
||||||
let regions = (region SizedRegion<1>:$computation);
|
let regions = (region SizedRegion<1>:$computation);
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
|
@ -959,15 +953,6 @@ def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp {
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [
|
|
||||||
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
|
|
||||||
]> {
|
|
||||||
let description = "Structure of dimension information for dot product";
|
|
||||||
}
|
|
||||||
|
|
||||||
def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp {
|
def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Tensor:$lhs,
|
HLO_Tensor:$lhs,
|
||||||
|
@ -1029,14 +1014,6 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect,
|
|
||||||
[StructFieldAttr<"offset_dims", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"start_index_map", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"index_vector_dim", I64Attr>]> {
|
|
||||||
let description = "Structure of dimension information for gather";
|
|
||||||
}
|
|
||||||
|
|
||||||
def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp {
|
def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_Tensor:$operand,
|
HLO_Tensor:$operand,
|
||||||
|
@ -1114,7 +1091,7 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>,
|
||||||
HLO_Tensor:$operand,
|
HLO_Tensor:$operand,
|
||||||
HLO_Tensor:$scatter_indices,
|
HLO_Tensor:$scatter_indices,
|
||||||
HLO_Tensor:$updates,
|
HLO_Tensor:$updates,
|
||||||
ScatterDimensionNumbers<HLO_Dialect>:$scatter_dimension_numbers,
|
ScatterDimensionNumbers:$scatter_dimension_numbers,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
|
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
|
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
|
||||||
);
|
);
|
||||||
|
|
|
@ -18,6 +18,13 @@ limitations under the License.
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
def HLO_Dialect : Dialect {
|
||||||
|
let name = "mhlo";
|
||||||
|
let cppNamespace = "::mlir::mhlo";
|
||||||
|
}
|
||||||
|
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td"
|
||||||
|
|
||||||
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
|
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
|
||||||
|
|
||||||
// TODO(hinsu): Use signed integers instead of signless integer which is being
|
// TODO(hinsu): Use signed integers instead of signless integer which is being
|
||||||
|
@ -614,15 +621,6 @@ class BASE_HLO_CaseOp {
|
||||||
// XLA parallelism related op definitions.
|
// XLA parallelism related op definitions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Represents a unique identifier for each Send/Recv instruction pair or
|
|
||||||
// optionally for collective instructions (AllReduce, CollectivePermute,
|
|
||||||
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
|
|
||||||
class ChannelHandle<Dialect dialect> : StructAttr<"ChannelHandle", dialect, [
|
|
||||||
StructFieldAttr<"handle", I64Attr>,
|
|
||||||
StructFieldAttr<"type", I64Attr>]> {
|
|
||||||
let description = "two 64-bit integers 'handle' and 'type'";
|
|
||||||
}
|
|
||||||
|
|
||||||
class BASE_HLO_ReplicaIdOp {
|
class BASE_HLO_ReplicaIdOp {
|
||||||
string summary = "ReplicaId operator";
|
string summary = "ReplicaId operator";
|
||||||
|
|
||||||
|
@ -712,6 +710,7 @@ def HLO_PrecisionConfigAttr:
|
||||||
OptionalAttr<
|
OptionalAttr<
|
||||||
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
|
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Fast Fourier Transform Type enum definitions.
|
// Fast Fourier Transform Type enum definitions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1011,21 +1010,6 @@ class BASE_HLO_ConcatenateOp {
|
||||||
// Common convolution attributes
|
// Common convolution attributes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class ConvDimensionNumbersBase<Dialect dialect>
|
|
||||||
: StructAttr<"ConvDimensionNumbers", dialect, [
|
|
||||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
|
||||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
|
||||||
|
|
||||||
let description = "Structure of dimension information for conv op";
|
|
||||||
}
|
|
||||||
|
|
||||||
class ConvolutionAttributes<Dialect dialect> {
|
class ConvolutionAttributes<Dialect dialect> {
|
||||||
dag attributes = (ins
|
dag attributes = (ins
|
||||||
// Default value: one for each of the spatial dimension.
|
// Default value: one for each of the spatial dimension.
|
||||||
|
@ -1036,7 +1020,7 @@ class ConvolutionAttributes<Dialect dialect> {
|
||||||
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
||||||
// Default value: one for each of the spatial dimension.
|
// Default value: one for each of the spatial dimension.
|
||||||
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
||||||
ConvDimensionNumbersBase<dialect>:$dimension_numbers,
|
ConvDimensionNumbers:$dimension_numbers,
|
||||||
I64Attr:$feature_group_count,
|
I64Attr:$feature_group_count,
|
||||||
I64Attr:$batch_group_count,
|
I64Attr:$batch_group_count,
|
||||||
HLO_PrecisionConfigAttr:$precision_config
|
HLO_PrecisionConfigAttr:$precision_config
|
||||||
|
@ -1164,15 +1148,6 @@ class BASE_HLO_ReshapeOp {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
class ScatterDimensionNumbers<Dialect dialect> : StructAttr<
|
|
||||||
"ScatterDimensionNumbers", dialect, [
|
|
||||||
StructFieldAttr<"update_window_dims", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"index_vector_dim", I64Attr>]> {
|
|
||||||
let description = "Structure of dimension information for scatter";
|
|
||||||
}
|
|
||||||
|
|
||||||
class BASE_HLO_ScatterOp {
|
class BASE_HLO_ScatterOp {
|
||||||
string summary = "Scatter operator";
|
string summary = "Scatter operator";
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file defines structures used in MHLO and LMHLO.
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/Identifier.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
|
|
||||||
|
// Order matters, this .inc header is not self-contained, and relies on the
|
||||||
|
// #includes above.
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc"
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
|
|
@ -0,0 +1,73 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef HLO_OPS_BASE_STRUCTS
|
||||||
|
#define HLO_OPS_BASE_STRUCTS
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dot dimensions enum definitions.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [
|
||||||
|
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
|
||||||
|
]> {
|
||||||
|
let description = "Structure of dimension information for dot product";
|
||||||
|
}
|
||||||
|
|
||||||
|
def ScatterDimensionNumbers : StructAttr<
|
||||||
|
"ScatterDimensionNumbers", HLO_Dialect, [
|
||||||
|
StructFieldAttr<"update_window_dims", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"index_vector_dim", I64Attr>]> {
|
||||||
|
let description = "Structure of dimension information for scatter";
|
||||||
|
}
|
||||||
|
|
||||||
|
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
|
||||||
|
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
||||||
|
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
||||||
|
|
||||||
|
let description = "Structure of dimension information for conv op";
|
||||||
|
}
|
||||||
|
|
||||||
|
def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect,
|
||||||
|
[StructFieldAttr<"offset_dims", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"start_index_map", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"index_vector_dim", I64Attr>]> {
|
||||||
|
let description = "Structure of dimension information for gather";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Represents a unique identifier for each Send/Recv instruction pair or
|
||||||
|
// optionally for collective instructions (AllReduce, CollectivePermute,
|
||||||
|
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
|
||||||
|
def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [
|
||||||
|
StructFieldAttr<"handle", I64Attr>,
|
||||||
|
StructFieldAttr<"type", I64Attr>]> {
|
||||||
|
let description = "two 64-bit integers 'handle' and 'type'";
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // HLO_OPS_BASE_STRUCTS
|
|
@ -19,6 +19,10 @@ limitations under the License.
|
||||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_
|
||||||
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
|
@ -28,6 +32,7 @@ limitations under the License.
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
#include "mlir/Interfaces/CopyOpInterface.h"
|
#include "mlir/Interfaces/CopyOpInterface.h"
|
||||||
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||||
|
|
||||||
|
@ -35,7 +40,6 @@ namespace mlir {
|
||||||
class OpBuilder;
|
class OpBuilder;
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"
|
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace lmhlo_gpu {
|
namespace lmhlo_gpu {
|
||||||
|
|
|
@ -22,13 +22,10 @@ limitations under the License.
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
|
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td"
|
||||||
|
|
||||||
|
|
||||||
def LHLO_GPU_Dialect : Dialect {
|
|
||||||
let name = "lmhlo_gpu";
|
|
||||||
let cppNamespace = "::mlir::lmhlo_gpu";
|
|
||||||
}
|
|
||||||
|
|
||||||
class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> :
|
class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
Op<LHLO_GPU_Dialect, mnemonic,
|
Op<LHLO_GPU_Dialect, mnemonic,
|
||||||
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
|
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
|
||||||
|
@ -109,13 +106,6 @@ def ActivationAttr : StrEnumAttr<"Activation",
|
||||||
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
|
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
|
||||||
ActivationModeBandPass]>;
|
ActivationModeBandPass]>;
|
||||||
|
|
||||||
def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig",
|
|
||||||
LHLO_GPU_Dialect, [
|
|
||||||
StructFieldAttr<"algorithm", I64Attr>,
|
|
||||||
StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> {
|
|
||||||
let description = "GPU Convolution backend configuration";
|
|
||||||
}
|
|
||||||
|
|
||||||
def GpuConvolutionAttributes {
|
def GpuConvolutionAttributes {
|
||||||
dag attributes = !con(
|
dag attributes = !con(
|
||||||
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
|
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
|
||||||
|
@ -181,16 +171,6 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
|
||||||
// LMHLO ops representing other library functions.
|
// LMHLO ops representing other library functions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// TODO(jurahul): Share this with the MHLO dialect.
|
|
||||||
def DotDimensionNumbersAttr : StructAttr<"DotDimensionNumbers", LHLO_GPU_Dialect, [
|
|
||||||
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
|
|
||||||
]> {
|
|
||||||
let description = "Structure of dimension information for dot product";
|
|
||||||
}
|
|
||||||
|
|
||||||
// output = alpha * (lhs * rhs)
|
// output = alpha * (lhs * rhs)
|
||||||
// Verify: beta = 0.0
|
// Verify: beta = 0.0
|
||||||
def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
|
def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
|
||||||
|
@ -198,7 +178,7 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$output,
|
Arg<LHLO_Buffer, "", [MemRead]>:$output,
|
||||||
DotDimensionNumbersAttr:$dot_dimension_numbers,
|
DotDimensionNumbers:$dot_dimension_numbers,
|
||||||
F64Attr:$alpha,
|
F64Attr:$alpha,
|
||||||
I64Attr:$batch_size,
|
I64Attr:$batch_size,
|
||||||
I64Attr:$algorithm);
|
I64Attr:$algorithm);
|
||||||
|
@ -211,7 +191,7 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
|
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$output,
|
Arg<LHLO_Buffer, "", [MemRead]>:$output,
|
||||||
DotDimensionNumbersAttr:$dot_dimension_numbers,
|
DotDimensionNumbers:$dot_dimension_numbers,
|
||||||
F64Attr:$alpha,
|
F64Attr:$alpha,
|
||||||
F64Attr:$beta,
|
F64Attr:$beta,
|
||||||
I64Attr:$batch_size,
|
I64Attr:$batch_size,
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// We define the dialect here so that both structs and ops can refer to it.
|
||||||
|
|
||||||
|
#ifndef LHLO_GPU_OPS_BASE
|
||||||
|
#define LHLO_GPU_OPS_BASE
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
def LHLO_GPU_Dialect : Dialect {
|
||||||
|
let name = "lmhlo_gpu";
|
||||||
|
let cppNamespace = "::mlir::lmhlo_gpu";
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // LHLO_GPU_OPS_BASE
|
|
@ -0,0 +1,30 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
* ==============================================================================*/
|
||||||
|
|
||||||
|
// This file defines structures used in the LMHLO_GPU dialect.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/Identifier.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
|
|
||||||
|
// Order matters, this .inc header is not self-contained, and relies on the
|
||||||
|
// #includes above.
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_
|
|
@ -0,0 +1,29 @@
|
||||||
|
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef LHLO_GPU_OPS_STRUCTS
|
||||||
|
#define LHLO_GPU_OPS_STRUCTS
|
||||||
|
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td"
|
||||||
|
|
||||||
|
def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig",
|
||||||
|
LHLO_GPU_Dialect, [
|
||||||
|
StructFieldAttr<"algorithm", I64Attr>,
|
||||||
|
StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> {
|
||||||
|
let description = "GPU Convolution backend configuration";
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // LHLO_GPU_OPS_STRUCTS
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
|
||||||
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
|
@ -33,11 +34,6 @@ limitations under the License.
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class OpBuilder;
|
class OpBuilder;
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
namespace lmhlo {
|
namespace lmhlo {
|
||||||
|
|
||||||
class LmhloDialect : public Dialect {
|
class LmhloDialect : public Dialect {
|
||||||
|
|
|
@ -592,6 +592,7 @@ def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||||
|
DotDimensionNumbers:$dot_dimension_numbers,
|
||||||
HLO_PrecisionConfigAttr:$precision_config,
|
HLO_PrecisionConfigAttr:$precision_config,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
);
|
);
|
||||||
|
@ -623,7 +624,7 @@ def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp {
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
|
Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$updates,
|
Arg<LHLO_Buffer, "", [MemRead]>:$updates,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
ScatterDimensionNumbers<LHLO_Dialect>:$scatter_dimension_numbers,
|
ScatterDimensionNumbers:$scatter_dimension_numbers,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
|
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
|
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
|
||||||
);
|
);
|
||||||
|
@ -699,7 +700,7 @@ def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
I64ElementsAttr:$replica_groups,
|
I64ElementsAttr:$replica_groups,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
|
DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
|
||||||
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id,
|
OptionalAttr<ChannelHandle>:$channel_id,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
|
DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
|
||||||
);
|
);
|
||||||
let regions = (region SizedRegion<1>:$computation);
|
let regions = (region SizedRegion<1>:$computation);
|
||||||
|
@ -712,7 +713,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
I64ElementsAttr:$source_target_pairs,
|
I64ElementsAttr:$source_target_pairs,
|
||||||
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id
|
OptionalAttr<ChannelHandle>:$channel_id
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -43,6 +43,7 @@ add_mlir_library(MhloInferFusibilityOpInterface
|
||||||
|
|
||||||
add_mlir_dialect_library(MhloDialect
|
add_mlir_dialect_library(MhloDialect
|
||||||
hlo_ops.cc
|
hlo_ops.cc
|
||||||
|
hlo_ops_base_structs.cc
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
MLIRhlo_opsIncGen
|
MLIRhlo_opsIncGen
|
||||||
|
@ -68,6 +69,7 @@ target_link_libraries(LmhloDialect PUBLIC MLIRIR)
|
||||||
|
|
||||||
add_mlir_dialect_library(LmhloGPUDialect
|
add_mlir_dialect_library(LmhloGPUDialect
|
||||||
lhlo_gpu_ops.cc
|
lhlo_gpu_ops.cc
|
||||||
|
lhlo_gpu_ops_structs.cc
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
MLIRlhlo_gpu_opsIncGen
|
MLIRlhlo_gpu_opsIncGen
|
||||||
|
|
|
@ -63,8 +63,6 @@ namespace mlir {
|
||||||
#include "hlo_patterns.cc.inc"
|
#include "hlo_patterns.cc.inc"
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc"
|
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace mhlo {
|
namespace mhlo {
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"
|
|
@ -28,8 +28,6 @@ limitations under the License.
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc"
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
|
|
@ -29,7 +29,6 @@ limitations under the License.
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
|
|
@ -126,6 +126,60 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// This specialization exists so that LMHLO's Dot can be given a specific set of
|
||||||
|
// dimension numbers, when lowering from MHLO's Dot, which does not have
|
||||||
|
// dimension numbers (it uses DotGeneral for this generalized notion of dot
|
||||||
|
// products). When these two dialects are in sync with respect to the
|
||||||
|
// Dot/DotGeneral issue, this specialization should be deleted.
|
||||||
|
template <>
|
||||||
|
class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
|
||||||
|
public:
|
||||||
|
using BaseOpConversion<mhlo::DotOp>::BaseOpConversion;
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mhlo::DotOp hloOp, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
Operation* op = hloOp.getOperation();
|
||||||
|
const auto& original_results = op->getResults();
|
||||||
|
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
|
||||||
|
for (auto result : llvm::enumerate(original_results)) {
|
||||||
|
RankedTensorType resultType =
|
||||||
|
result.value().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!resultType) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (resultType.hasStaticShape()) {
|
||||||
|
buffer_args.push_back(
|
||||||
|
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||||
|
} else {
|
||||||
|
SmallVector<Value, 1> results_shape;
|
||||||
|
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||||
|
if (!shape_type_op) return failure();
|
||||||
|
if (failed(
|
||||||
|
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
|
||||||
|
return failure();
|
||||||
|
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
||||||
|
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(silvasean): Move this helper to MLIR core.
|
||||||
|
auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
|
||||||
|
auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
|
||||||
|
rewriter.getIntegerType(64));
|
||||||
|
return DenseIntElementsAttr::get(type, integers);
|
||||||
|
};
|
||||||
|
auto dotOp = rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None,
|
||||||
|
buffer_args, op->getAttrs());
|
||||||
|
// MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O].
|
||||||
|
auto dimension_numbers = mhlo::DotDimensionNumbers::get(
|
||||||
|
make_elements_attr({}), make_elements_attr({}), make_elements_attr({1}),
|
||||||
|
make_elements_attr({0}), rewriter.getContext());
|
||||||
|
dotOp.dot_dimension_numbersAttr(dimension_numbers);
|
||||||
|
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct HloToLhloDynamicBroadcastInDimOpConverter
|
struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||||
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
|
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -236,6 +290,43 @@ struct HloToLhloDynamicReshapeConverter
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct HloToLhloDotGeneralOpConverter
|
||||||
|
: public BaseOpConversion<mhlo::DotGeneralOp> {
|
||||||
|
using BaseOpConversion<mhlo::DotGeneralOp>::BaseOpConversion;
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mhlo::DotGeneralOp dotGeneralOp, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
Operation* op = dotGeneralOp.getOperation();
|
||||||
|
|
||||||
|
if (op->getResults().empty()) return failure();
|
||||||
|
OpResult result = op->getResults()[0];
|
||||||
|
RankedTensorType resultType = result.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!resultType) return failure();
|
||||||
|
|
||||||
|
// The third buffer argument will be filled with what used to be the return
|
||||||
|
// type of the DotGeneral.
|
||||||
|
if (operands.size() != 2) return failure();
|
||||||
|
std::array<Value, 3> bufferArgs = {operands[0], operands[1], {}};
|
||||||
|
|
||||||
|
if (resultType.hasStaticShape()) {
|
||||||
|
bufferArgs[2] = InsertAlloc(op->getLoc(), result, &rewriter);
|
||||||
|
} else {
|
||||||
|
SmallVector<Value, 1> results_shape;
|
||||||
|
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||||
|
if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
bufferArgs[2] = InsertDynamicAllocAndDealloc(
|
||||||
|
op->getLoc(), result, results_shape.front(), &rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None, bufferArgs,
|
||||||
|
op->getAttrs());
|
||||||
|
rewriter.replaceOp(op, bufferArgs[2]);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
||||||
public:
|
public:
|
||||||
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
|
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
|
||||||
|
@ -485,6 +576,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
|
HloToLhloDotGeneralOpConverter,
|
||||||
HloToLhloDynamicBroadcastInDimOpConverter,
|
HloToLhloDynamicBroadcastInDimOpConverter,
|
||||||
HloToLhloDynamicReshapeConverter,
|
HloToLhloDynamicReshapeConverter,
|
||||||
HloToLhloOpConverter<mhlo::AbsOp>,
|
HloToLhloOpConverter<mhlo::AbsOp>,
|
||||||
|
|
|
@ -192,7 +192,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
|
||||||
lmhlo::ConvOp op, ArrayRef<Value> args,
|
lmhlo::ConvOp op, ArrayRef<Value> args,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
// Check validity of dimension information.
|
// Check validity of dimension information.
|
||||||
if (const lmhlo::ConvDimensionNumbers& dimensionNumbers =
|
if (const mhlo::ConvDimensionNumbers& dimensionNumbers =
|
||||||
op.dimension_numbers()) {
|
op.dimension_numbers()) {
|
||||||
const int inputSpatialRank =
|
const int inputSpatialRank =
|
||||||
llvm::size(dimensionNumbers.input_spatial_dimensions());
|
llvm::size(dimensionNumbers.input_spatial_dimensions());
|
||||||
|
|
|
@ -59,6 +59,20 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We don't currently support batching dimensions, or multiple contraction
|
||||||
|
// dimensions.
|
||||||
|
mhlo::DotDimensionNumbers dot_dimension_numbers =
|
||||||
|
op.dot_dimension_numbers();
|
||||||
|
if (dot_dimension_numbers.lhs_batching_dimensions().size() > 0 ||
|
||||||
|
dot_dimension_numbers.rhs_batching_dimensions().size() > 0)
|
||||||
|
return failure();
|
||||||
|
if (dot_dimension_numbers.lhs_contracting_dimensions().size() != 1 ||
|
||||||
|
*dot_dimension_numbers.lhs_contracting_dimensions().begin() != 1 ||
|
||||||
|
dot_dimension_numbers.rhs_contracting_dimensions().size() != 1 ||
|
||||||
|
*dot_dimension_numbers.rhs_contracting_dimensions().begin() != 0) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult map_status = success();
|
LogicalResult map_status = success();
|
||||||
auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
|
auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
|
||||||
SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]},
|
SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]},
|
||||||
|
|
|
@ -511,7 +511,13 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
||||||
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
|
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
|
||||||
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||||
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
|
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
|
||||||
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
|
||||||
|
// dot_dimension_numbers = {
|
||||||
|
// lhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||||
|
// lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
|
||||||
|
// rhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||||
|
// rhs_contracting_dimensions = dense<0> : tensor<1xi64>}}
|
||||||
|
// : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
||||||
%dot = "mhlo.dot"(%arg0, %arg0)
|
%dot = "mhlo.dot"(%arg0, %arg0)
|
||||||
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
|
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
|
||||||
// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])
|
// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])
|
||||||
|
@ -632,4 +638,4 @@ func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
|
||||||
shape.assuming_yield %7 : tensor<?xf16>
|
shape.assuming_yield %7 : tensor<?xf16>
|
||||||
}
|
}
|
||||||
return %2 : tensor<?xf16>
|
return %2 : tensor<?xf16>
|
||||||
}
|
}
|
||||||
|
|
|
@ -158,7 +158,14 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs:
|
||||||
// CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32
|
// CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32
|
||||||
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32>
|
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32>
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
"lmhlo.dot"(%lhs, %rhs, %result) :
|
"lmhlo.dot"(%lhs, %rhs, %result) {
|
||||||
|
dot_dimension_numbers = {
|
||||||
|
lhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||||
|
rhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||||
|
lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
|
||||||
|
rhs_contracting_dimensions = dense<0> : tensor<1xi64>
|
||||||
|
}
|
||||||
|
} :
|
||||||
(memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> ()
|
(memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -175,7 +182,14 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
|
||||||
// CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32
|
// CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32
|
||||||
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32>
|
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32>
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
"lmhlo.dot"(%lhs, %rhs, %result) :
|
"lmhlo.dot"(%lhs, %rhs, %result) {
|
||||||
|
dot_dimension_numbers = {
|
||||||
|
lhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||||
|
rhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||||
|
lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
|
||||||
|
rhs_contracting_dimensions = dense<0> : tensor<1xi64>
|
||||||
|
}
|
||||||
|
} :
|
||||||
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
|
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue