Make LMHLO's Dot have the same power as MHLO's DotGeneral.
PiperOrigin-RevId: 337391565
This commit is contained in:
parent
8fe5bc89bb
commit
51cd4200b6
|
@ -25,8 +25,22 @@ function(add_mlir_hlo_dialect dialect dialect_namespace)
|
|||
endfunction()
|
||||
|
||||
add_mlir_hlo_dialect(chlo_ops chlo)
|
||||
add_mlir_hlo_dialect(hlo_ops mhlo)
|
||||
add_mlir_hlo_dialect(lhlo_ops lmhlo)
|
||||
add_mlir_hlo_dialect(lhlo_gpu_ops lmhlo_gpu)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS hlo_ops.td)
|
||||
mlir_tablegen(hlo_ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(hlo_ops.cc.inc -gen-op-defs)
|
||||
mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls)
|
||||
mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs)
|
||||
add_public_tablegen_target(MLIRhlo_opsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td)
|
||||
mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs)
|
||||
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td)
|
||||
mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls)
|
||||
mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs)
|
||||
add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen)
|
||||
add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen)
|
||||
|
||||
add_mlir_interface(infer_fusibility_op_interface)
|
||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
|
@ -32,7 +33,7 @@ limitations under the License.
|
|||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
// clang-format off
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||
// clang-format on
|
||||
|
||||
|
|
|
@ -25,11 +25,6 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
|||
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
|
||||
|
||||
def HLO_Dialect : Dialect {
|
||||
let name = "mhlo";
|
||||
let cppNamespace = "::mlir::mhlo";
|
||||
}
|
||||
|
||||
class HLO_Op<string mnemonic, list<OpTrait> traits> :
|
||||
Op<HLO_Dialect, mnemonic, traits> {
|
||||
// Whether this operation has a custom conversion to HLO or not.
|
||||
|
@ -136,8 +131,8 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
|||
}
|
||||
LogicalResult reifyReturnTypeShapes(
|
||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
return deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
}
|
||||
bool inferInputOutputShapeEquality(int input, int output) {
|
||||
return true;
|
||||
|
@ -153,7 +148,7 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
|
|||
[NoSideEffect, SameOperandsAndResultShape],
|
||||
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, Value operand"
|
||||
"Value operand"
|
||||
>];
|
||||
}
|
||||
|
||||
|
@ -168,8 +163,7 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp<
|
|||
BASE_HLO_ConvertOp {
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &, OperationState &tblgen_state, Value operand, "
|
||||
"Type result_element_ty"
|
||||
"Value operand, Type result_element_ty"
|
||||
>];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
@ -293,8 +287,8 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
|
|||
}
|
||||
LogicalResult reifyReturnTypeShapes(
|
||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||
return deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
|
||||
&reifiedReturnShapes);
|
||||
}
|
||||
bool inferInputsShapeEquality(int lhs, int rhs) {
|
||||
return true;
|
||||
|
@ -458,7 +452,7 @@ def HLO_SendOp : HLO_Op<"send", []> {
|
|||
let arguments = (ins
|
||||
HLO_TensorOrTuple:$operand,
|
||||
HLO_Token:$token,
|
||||
ChannelHandle<HLO_Dialect>:$channel_id,
|
||||
ChannelHandle:$channel_id,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
|
||||
);
|
||||
|
||||
|
@ -483,7 +477,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
|
|||
|
||||
let arguments = (ins
|
||||
HLO_Token:$token,
|
||||
ChannelHandle<HLO_Dialect>:$channel_id,
|
||||
ChannelHandle:$channel_id,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
|
||||
);
|
||||
|
||||
|
@ -587,7 +581,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
|
|||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
I64ElementsAttr:$replica_groups,
|
||||
OptionalAttr<ChannelHandle<HLO_Dialect>>:$channel_id
|
||||
OptionalAttr<ChannelHandle>:$channel_id
|
||||
);
|
||||
let regions = (region SizedRegion<1>:$computation);
|
||||
let results = (outs HLO_Tensor);
|
||||
|
@ -959,15 +953,6 @@ def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp {
|
|||
let results = (outs HLO_Tensor);
|
||||
}
|
||||
|
||||
def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [
|
||||
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
|
||||
]> {
|
||||
let description = "Structure of dimension information for dot product";
|
||||
}
|
||||
|
||||
def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$lhs,
|
||||
|
@ -1029,14 +1014,6 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
|
|||
let results = (outs HLO_Tensor);
|
||||
}
|
||||
|
||||
def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect,
|
||||
[StructFieldAttr<"offset_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"start_index_map", I64ElementsAttr>,
|
||||
StructFieldAttr<"index_vector_dim", I64Attr>]> {
|
||||
let description = "Structure of dimension information for gather";
|
||||
}
|
||||
|
||||
def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
|
@ -1114,7 +1091,7 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>,
|
|||
HLO_Tensor:$operand,
|
||||
HLO_Tensor:$scatter_indices,
|
||||
HLO_Tensor:$updates,
|
||||
ScatterDimensionNumbers<HLO_Dialect>:$scatter_dimension_numbers,
|
||||
ScatterDimensionNumbers:$scatter_dimension_numbers,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
|
||||
);
|
||||
|
|
|
@ -18,6 +18,13 @@ limitations under the License.
|
|||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def HLO_Dialect : Dialect {
|
||||
let name = "mhlo";
|
||||
let cppNamespace = "::mlir::mhlo";
|
||||
}
|
||||
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td"
|
||||
|
||||
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
|
||||
|
||||
// TODO(hinsu): Use signed integers instead of signless integer which is being
|
||||
|
@ -614,15 +621,6 @@ class BASE_HLO_CaseOp {
|
|||
// XLA parallelism related op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Represents a unique identifier for each Send/Recv instruction pair or
|
||||
// optionally for collective instructions (AllReduce, CollectivePermute,
|
||||
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
|
||||
class ChannelHandle<Dialect dialect> : StructAttr<"ChannelHandle", dialect, [
|
||||
StructFieldAttr<"handle", I64Attr>,
|
||||
StructFieldAttr<"type", I64Attr>]> {
|
||||
let description = "two 64-bit integers 'handle' and 'type'";
|
||||
}
|
||||
|
||||
class BASE_HLO_ReplicaIdOp {
|
||||
string summary = "ReplicaId operator";
|
||||
|
||||
|
@ -712,6 +710,7 @@ def HLO_PrecisionConfigAttr:
|
|||
OptionalAttr<
|
||||
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Fast Fourier Transform Type enum definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1011,21 +1010,6 @@ class BASE_HLO_ConcatenateOp {
|
|||
// Common convolution attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class ConvDimensionNumbersBase<Dialect dialect>
|
||||
: StructAttr<"ConvDimensionNumbers", dialect, [
|
||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
||||
|
||||
let description = "Structure of dimension information for conv op";
|
||||
}
|
||||
|
||||
class ConvolutionAttributes<Dialect dialect> {
|
||||
dag attributes = (ins
|
||||
// Default value: one for each of the spatial dimension.
|
||||
|
@ -1036,7 +1020,7 @@ class ConvolutionAttributes<Dialect dialect> {
|
|||
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
||||
// Default value: one for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
||||
ConvDimensionNumbersBase<dialect>:$dimension_numbers,
|
||||
ConvDimensionNumbers:$dimension_numbers,
|
||||
I64Attr:$feature_group_count,
|
||||
I64Attr:$batch_group_count,
|
||||
HLO_PrecisionConfigAttr:$precision_config
|
||||
|
@ -1164,15 +1148,6 @@ class BASE_HLO_ReshapeOp {
|
|||
}];
|
||||
}
|
||||
|
||||
class ScatterDimensionNumbers<Dialect dialect> : StructAttr<
|
||||
"ScatterDimensionNumbers", dialect, [
|
||||
StructFieldAttr<"update_window_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"index_vector_dim", I64Attr>]> {
|
||||
let description = "Structure of dimension information for scatter";
|
||||
}
|
||||
|
||||
class BASE_HLO_ScatterOp {
|
||||
string summary = "Scatter operator";
|
||||
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines structures used in MHLO and LMHLO.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
// Order matters, this .inc header is not self-contained, and relies on the
|
||||
// #includes above.
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc"
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
|
|
@ -0,0 +1,73 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef HLO_OPS_BASE_STRUCTS
|
||||
#define HLO_OPS_BASE_STRUCTS
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dot dimensions enum definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [
|
||||
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
|
||||
]> {
|
||||
let description = "Structure of dimension information for dot product";
|
||||
}
|
||||
|
||||
def ScatterDimensionNumbers : StructAttr<
|
||||
"ScatterDimensionNumbers", HLO_Dialect, [
|
||||
StructFieldAttr<"update_window_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"index_vector_dim", I64Attr>]> {
|
||||
let description = "Structure of dimension information for scatter";
|
||||
}
|
||||
|
||||
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
|
||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
||||
|
||||
let description = "Structure of dimension information for conv op";
|
||||
}
|
||||
|
||||
def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect,
|
||||
[StructFieldAttr<"offset_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"start_index_map", I64ElementsAttr>,
|
||||
StructFieldAttr<"index_vector_dim", I64Attr>]> {
|
||||
let description = "Structure of dimension information for gather";
|
||||
}
|
||||
|
||||
|
||||
// Represents a unique identifier for each Send/Recv instruction pair or
|
||||
// optionally for collective instructions (AllReduce, CollectivePermute,
|
||||
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
|
||||
def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [
|
||||
StructFieldAttr<"handle", I64Attr>,
|
||||
StructFieldAttr<"type", I64Attr>]> {
|
||||
let description = "two 64-bit integers 'handle' and 'type'";
|
||||
}
|
||||
|
||||
#endif // HLO_OPS_BASE_STRUCTS
|
|
@ -19,6 +19,10 @@ limitations under the License.
|
|||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
|
@ -28,6 +32,7 @@ limitations under the License.
|
|||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/CopyOpInterface.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||
|
||||
|
@ -35,7 +40,6 @@ namespace mlir {
|
|||
class OpBuilder;
|
||||
} // namespace mlir
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo_gpu {
|
||||
|
|
|
@ -22,13 +22,10 @@ limitations under the License.
|
|||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td"
|
||||
|
||||
|
||||
def LHLO_GPU_Dialect : Dialect {
|
||||
let name = "lmhlo_gpu";
|
||||
let cppNamespace = "::mlir::lmhlo_gpu";
|
||||
}
|
||||
|
||||
class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<LHLO_GPU_Dialect, mnemonic,
|
||||
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
|
||||
|
@ -109,13 +106,6 @@ def ActivationAttr : StrEnumAttr<"Activation",
|
|||
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
|
||||
ActivationModeBandPass]>;
|
||||
|
||||
def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig",
|
||||
LHLO_GPU_Dialect, [
|
||||
StructFieldAttr<"algorithm", I64Attr>,
|
||||
StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> {
|
||||
let description = "GPU Convolution backend configuration";
|
||||
}
|
||||
|
||||
def GpuConvolutionAttributes {
|
||||
dag attributes = !con(
|
||||
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
|
||||
|
@ -181,16 +171,6 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
|
|||
// LMHLO ops representing other library functions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(jurahul): Share this with the MHLO dialect.
|
||||
def DotDimensionNumbersAttr : StructAttr<"DotDimensionNumbers", LHLO_GPU_Dialect, [
|
||||
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
|
||||
]> {
|
||||
let description = "Structure of dimension information for dot product";
|
||||
}
|
||||
|
||||
// output = alpha * (lhs * rhs)
|
||||
// Verify: beta = 0.0
|
||||
def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
|
||||
|
@ -198,7 +178,7 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
|
|||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$output,
|
||||
DotDimensionNumbersAttr:$dot_dimension_numbers,
|
||||
DotDimensionNumbers:$dot_dimension_numbers,
|
||||
F64Attr:$alpha,
|
||||
I64Attr:$batch_size,
|
||||
I64Attr:$algorithm);
|
||||
|
@ -211,7 +191,7 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
|
|||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$output,
|
||||
DotDimensionNumbersAttr:$dot_dimension_numbers,
|
||||
DotDimensionNumbers:$dot_dimension_numbers,
|
||||
F64Attr:$alpha,
|
||||
F64Attr:$beta,
|
||||
I64Attr:$batch_size,
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// We define the dialect here so that both structs and ops can refer to it.
|
||||
|
||||
#ifndef LHLO_GPU_OPS_BASE
|
||||
#define LHLO_GPU_OPS_BASE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def LHLO_GPU_Dialect : Dialect {
|
||||
let name = "lmhlo_gpu";
|
||||
let cppNamespace = "::mlir::lmhlo_gpu";
|
||||
}
|
||||
|
||||
#endif // LHLO_GPU_OPS_BASE
|
|
@ -0,0 +1,30 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
* ==============================================================================*/
|
||||
|
||||
// This file defines structures used in the LMHLO_GPU dialect.
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
// Order matters, this .inc header is not self-contained, and relies on the
|
||||
// #includes above.
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_
|
|
@ -0,0 +1,29 @@
|
|||
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef LHLO_GPU_OPS_STRUCTS
|
||||
#define LHLO_GPU_OPS_STRUCTS
|
||||
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td"
|
||||
|
||||
def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig",
|
||||
LHLO_GPU_Dialect, [
|
||||
StructFieldAttr<"algorithm", I64Attr>,
|
||||
StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> {
|
||||
let description = "GPU Convolution backend configuration";
|
||||
}
|
||||
|
||||
#endif // LHLO_GPU_OPS_STRUCTS
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
|
@ -33,11 +34,6 @@ limitations under the License.
|
|||
|
||||
namespace mlir {
|
||||
class OpBuilder;
|
||||
} // namespace mlir
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo {
|
||||
|
||||
class LmhloDialect : public Dialect {
|
||||
|
|
|
@ -592,6 +592,7 @@ def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {
|
|||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||
DotDimensionNumbers:$dot_dimension_numbers,
|
||||
HLO_PrecisionConfigAttr:$precision_config,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||
);
|
||||
|
@ -623,7 +624,7 @@ def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp {
|
|||
Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$updates,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
ScatterDimensionNumbers<LHLO_Dialect>:$scatter_dimension_numbers,
|
||||
ScatterDimensionNumbers:$scatter_dimension_numbers,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
|
||||
);
|
||||
|
@ -699,7 +700,7 @@ def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
|
|||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
I64ElementsAttr:$replica_groups,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
|
||||
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id,
|
||||
OptionalAttr<ChannelHandle>:$channel_id,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
|
||||
);
|
||||
let regions = (region SizedRegion<1>:$computation);
|
||||
|
@ -712,7 +713,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
|
|||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
I64ElementsAttr:$source_target_pairs,
|
||||
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id
|
||||
OptionalAttr<ChannelHandle>:$channel_id
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -43,6 +43,7 @@ add_mlir_library(MhloInferFusibilityOpInterface
|
|||
|
||||
add_mlir_dialect_library(MhloDialect
|
||||
hlo_ops.cc
|
||||
hlo_ops_base_structs.cc
|
||||
|
||||
DEPENDS
|
||||
MLIRhlo_opsIncGen
|
||||
|
@ -68,6 +69,7 @@ target_link_libraries(LmhloDialect PUBLIC MLIRIR)
|
|||
|
||||
add_mlir_dialect_library(LmhloGPUDialect
|
||||
lhlo_gpu_ops.cc
|
||||
lhlo_gpu_ops_structs.cc
|
||||
|
||||
DEPENDS
|
||||
MLIRlhlo_gpu_opsIncGen
|
||||
|
|
|
@ -63,8 +63,6 @@ namespace mlir {
|
|||
#include "hlo_patterns.cc.inc"
|
||||
} // namespace mlir
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"
|
|
@ -28,8 +28,6 @@ limitations under the License.
|
|||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
|
|
@ -29,7 +29,6 @@ limitations under the License.
|
|||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
|
|
@ -126,6 +126,60 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
|||
}
|
||||
};
|
||||
|
||||
// This specialization exists so that LMHLO's Dot can be given a specific set of
|
||||
// dimension numbers, when lowering from MHLO's Dot, which does not have
|
||||
// dimension numbers (it uses DotGeneral for this generalized notion of dot
|
||||
// products). When these two dialects are in sync with respect to the
|
||||
// Dot/DotGeneral issue, this specialization should be deleted.
|
||||
template <>
|
||||
class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
|
||||
public:
|
||||
using BaseOpConversion<mhlo::DotOp>::BaseOpConversion;
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::DotOp hloOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Operation* op = hloOp.getOperation();
|
||||
const auto& original_results = op->getResults();
|
||||
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
|
||||
for (auto result : llvm::enumerate(original_results)) {
|
||||
RankedTensorType resultType =
|
||||
result.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultType) {
|
||||
return failure();
|
||||
}
|
||||
if (resultType.hasStaticShape()) {
|
||||
buffer_args.push_back(
|
||||
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||
} else {
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shape_type_op) return failure();
|
||||
if (failed(
|
||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
|
||||
return failure();
|
||||
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
||||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(silvasean): Move this helper to MLIR core.
|
||||
auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
|
||||
auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
|
||||
rewriter.getIntegerType(64));
|
||||
return DenseIntElementsAttr::get(type, integers);
|
||||
};
|
||||
auto dotOp = rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None,
|
||||
buffer_args, op->getAttrs());
|
||||
// MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O].
|
||||
auto dimension_numbers = mhlo::DotDimensionNumbers::get(
|
||||
make_elements_attr({}), make_elements_attr({}), make_elements_attr({1}),
|
||||
make_elements_attr({0}), rewriter.getContext());
|
||||
dotOp.dot_dimension_numbersAttr(dimension_numbers);
|
||||
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
|
||||
public:
|
||||
|
@ -236,6 +290,43 @@ struct HloToLhloDynamicReshapeConverter
|
|||
}
|
||||
};
|
||||
|
||||
struct HloToLhloDotGeneralOpConverter
|
||||
: public BaseOpConversion<mhlo::DotGeneralOp> {
|
||||
using BaseOpConversion<mhlo::DotGeneralOp>::BaseOpConversion;
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::DotGeneralOp dotGeneralOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Operation* op = dotGeneralOp.getOperation();
|
||||
|
||||
if (op->getResults().empty()) return failure();
|
||||
OpResult result = op->getResults()[0];
|
||||
RankedTensorType resultType = result.getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultType) return failure();
|
||||
|
||||
// The third buffer argument will be filled with what used to be the return
|
||||
// type of the DotGeneral.
|
||||
if (operands.size() != 2) return failure();
|
||||
std::array<Value, 3> bufferArgs = {operands[0], operands[1], {}};
|
||||
|
||||
if (resultType.hasStaticShape()) {
|
||||
bufferArgs[2] = InsertAlloc(op->getLoc(), result, &rewriter);
|
||||
} else {
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
|
||||
return failure();
|
||||
|
||||
bufferArgs[2] = InsertDynamicAllocAndDealloc(
|
||||
op->getLoc(), result, results_shape.front(), &rewriter);
|
||||
}
|
||||
|
||||
rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None, bufferArgs,
|
||||
op->getAttrs());
|
||||
rewriter.replaceOp(op, bufferArgs[2]);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
||||
public:
|
||||
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
|
||||
|
@ -485,6 +576,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
|||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
HloToLhloDotGeneralOpConverter,
|
||||
HloToLhloDynamicBroadcastInDimOpConverter,
|
||||
HloToLhloDynamicReshapeConverter,
|
||||
HloToLhloOpConverter<mhlo::AbsOp>,
|
||||
|
|
|
@ -192,7 +192,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
|
|||
lmhlo::ConvOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
// Check validity of dimension information.
|
||||
if (const lmhlo::ConvDimensionNumbers& dimensionNumbers =
|
||||
if (const mhlo::ConvDimensionNumbers& dimensionNumbers =
|
||||
op.dimension_numbers()) {
|
||||
const int inputSpatialRank =
|
||||
llvm::size(dimensionNumbers.input_spatial_dimensions());
|
||||
|
|
|
@ -59,6 +59,20 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
|
|||
return failure();
|
||||
}
|
||||
|
||||
// We don't currently support batching dimensions, or multiple contraction
|
||||
// dimensions.
|
||||
mhlo::DotDimensionNumbers dot_dimension_numbers =
|
||||
op.dot_dimension_numbers();
|
||||
if (dot_dimension_numbers.lhs_batching_dimensions().size() > 0 ||
|
||||
dot_dimension_numbers.rhs_batching_dimensions().size() > 0)
|
||||
return failure();
|
||||
if (dot_dimension_numbers.lhs_contracting_dimensions().size() != 1 ||
|
||||
*dot_dimension_numbers.lhs_contracting_dimensions().begin() != 1 ||
|
||||
dot_dimension_numbers.rhs_contracting_dimensions().size() != 1 ||
|
||||
*dot_dimension_numbers.rhs_contracting_dimensions().begin() != 0) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult map_status = success();
|
||||
auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
|
||||
SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]},
|
||||
|
|
|
@ -511,7 +511,13 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
|||
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
|
||||
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
|
||||
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
||||
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
|
||||
// dot_dimension_numbers = {
|
||||
// lhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||
// lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
|
||||
// rhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||
// rhs_contracting_dimensions = dense<0> : tensor<1xi64>}}
|
||||
// : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
||||
%dot = "mhlo.dot"(%arg0, %arg0)
|
||||
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
|
||||
// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])
|
||||
|
@ -632,4 +638,4 @@ func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
|
|||
shape.assuming_yield %7 : tensor<?xf16>
|
||||
}
|
||||
return %2 : tensor<?xf16>
|
||||
}
|
||||
}
|
||||
|
|
|
@ -158,7 +158,14 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs:
|
|||
// CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32
|
||||
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32>
|
||||
// CHECK: return
|
||||
"lmhlo.dot"(%lhs, %rhs, %result) :
|
||||
"lmhlo.dot"(%lhs, %rhs, %result) {
|
||||
dot_dimension_numbers = {
|
||||
lhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||
rhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||
lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
|
||||
rhs_contracting_dimensions = dense<0> : tensor<1xi64>
|
||||
}
|
||||
} :
|
||||
(memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -175,7 +182,14 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
|
|||
// CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32
|
||||
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32>
|
||||
// CHECK: return
|
||||
"lmhlo.dot"(%lhs, %rhs, %result) :
|
||||
"lmhlo.dot"(%lhs, %rhs, %result) {
|
||||
dot_dimension_numbers = {
|
||||
lhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||
rhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||
lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
|
||||
rhs_contracting_dimensions = dense<0> : tensor<1xi64>
|
||||
}
|
||||
} :
|
||||
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue