Make LMHLO's Dot have the same power as MHLO's DotGeneral.

PiperOrigin-RevId: 337391565
This commit is contained in:
A. Unique TensorFlower 2020-10-15 15:08:30 -07:00 committed by TensorFlow MLIR Team
parent 8fe5bc89bb
commit 51cd4200b6
25 changed files with 427 additions and 113 deletions

View File

@ -25,8 +25,22 @@ function(add_mlir_hlo_dialect dialect dialect_namespace)
endfunction() endfunction()
add_mlir_hlo_dialect(chlo_ops chlo) add_mlir_hlo_dialect(chlo_ops chlo)
add_mlir_hlo_dialect(hlo_ops mhlo)
add_mlir_hlo_dialect(lhlo_ops lmhlo) add_mlir_hlo_dialect(lhlo_ops lmhlo)
add_mlir_hlo_dialect(lhlo_gpu_ops lmhlo_gpu)
set(LLVM_TARGET_DEFINITIONS hlo_ops.td)
mlir_tablegen(hlo_ops.h.inc -gen-op-decls)
mlir_tablegen(hlo_ops.cc.inc -gen-op-defs)
mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls)
mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs)
add_public_tablegen_target(MLIRhlo_opsIncGen)
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td)
mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls)
mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs)
set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td)
mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls)
mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs)
add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen)
add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen)
add_mlir_interface(infer_fusibility_op_interface) add_mlir_interface(infer_fusibility_op_interface)

View File

@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"
@ -32,7 +33,7 @@ limitations under the License.
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
// clang-format off // clang-format off
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
// clang-format on // clang-format on

View File

@ -25,11 +25,6 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
def HLO_Dialect : Dialect {
let name = "mhlo";
let cppNamespace = "::mlir::mhlo";
}
class HLO_Op<string mnemonic, list<OpTrait> traits> : class HLO_Op<string mnemonic, list<OpTrait> traits> :
Op<HLO_Dialect, mnemonic, traits> { Op<HLO_Dialect, mnemonic, traits> {
// Whether this operation has a custom conversion to HLO or not. // Whether this operation has a custom conversion to HLO or not.
@ -136,8 +131,8 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
} }
LogicalResult reifyReturnTypeShapes( LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) { OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return deriveShapeFromFirstOperand(&builder, getOperation(), return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes); &reifiedReturnShapes);
} }
bool inferInputOutputShapeEquality(int input, int output) { bool inferInputOutputShapeEquality(int input, int output) {
return true; return true;
@ -153,7 +148,7 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
[NoSideEffect, SameOperandsAndResultShape], [NoSideEffect, SameOperandsAndResultShape],
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
let builders = [OpBuilder< let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value operand" "Value operand"
>]; >];
} }
@ -168,8 +163,7 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp<
BASE_HLO_ConvertOp { BASE_HLO_ConvertOp {
let builders = [OpBuilder< let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value operand, " "Value operand, Type result_element_ty"
"Type result_element_ty"
>]; >];
let hasFolder = 1; let hasFolder = 1;
@ -293,8 +287,8 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
} }
LogicalResult reifyReturnTypeShapes( LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) { OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return deriveShapeFromFirstOperand(&builder, getOperation(), return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
&reifiedReturnShapes); &reifiedReturnShapes);
} }
bool inferInputsShapeEquality(int lhs, int rhs) { bool inferInputsShapeEquality(int lhs, int rhs) {
return true; return true;
@ -458,7 +452,7 @@ def HLO_SendOp : HLO_Op<"send", []> {
let arguments = (ins let arguments = (ins
HLO_TensorOrTuple:$operand, HLO_TensorOrTuple:$operand,
HLO_Token:$token, HLO_Token:$token,
ChannelHandle<HLO_Dialect>:$channel_id, ChannelHandle:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
); );
@ -483,7 +477,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
let arguments = (ins let arguments = (ins
HLO_Token:$token, HLO_Token:$token,
ChannelHandle<HLO_Dialect>:$channel_id, ChannelHandle:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
); );
@ -587,7 +581,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
let arguments = (ins let arguments = (ins
HLO_Tensor:$operand, HLO_Tensor:$operand,
I64ElementsAttr:$replica_groups, I64ElementsAttr:$replica_groups,
OptionalAttr<ChannelHandle<HLO_Dialect>>:$channel_id OptionalAttr<ChannelHandle>:$channel_id
); );
let regions = (region SizedRegion<1>:$computation); let regions = (region SizedRegion<1>:$computation);
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
@ -959,15 +953,6 @@ def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp {
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }
def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
]> {
let description = "Structure of dimension information for dot product";
}
def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp { def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp {
let arguments = (ins let arguments = (ins
HLO_Tensor:$lhs, HLO_Tensor:$lhs,
@ -1029,14 +1014,6 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }
def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect,
[StructFieldAttr<"offset_dims", I64ElementsAttr>,
StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
StructFieldAttr<"start_index_map", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for gather";
}
def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp {
let arguments = (ins let arguments = (ins
HLO_Tensor:$operand, HLO_Tensor:$operand,
@ -1114,7 +1091,7 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>,
HLO_Tensor:$operand, HLO_Tensor:$operand,
HLO_Tensor:$scatter_indices, HLO_Tensor:$scatter_indices,
HLO_Tensor:$updates, HLO_Tensor:$updates,
ScatterDimensionNumbers<HLO_Dialect>:$scatter_dimension_numbers, ScatterDimensionNumbers:$scatter_dimension_numbers,
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted, DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
DefaultValuedAttr<BoolAttr, "false">:$unique_indices DefaultValuedAttr<BoolAttr, "false">:$unique_indices
); );

View File

@ -18,6 +18,13 @@ limitations under the License.
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
def HLO_Dialect : Dialect {
let name = "mhlo";
let cppNamespace = "::mlir::mhlo";
}
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td"
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">; def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
// TODO(hinsu): Use signed integers instead of signless integer which is being // TODO(hinsu): Use signed integers instead of signless integer which is being
@ -614,15 +621,6 @@ class BASE_HLO_CaseOp {
// XLA parallelism related op definitions. // XLA parallelism related op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
class ChannelHandle<Dialect dialect> : StructAttr<"ChannelHandle", dialect, [
StructFieldAttr<"handle", I64Attr>,
StructFieldAttr<"type", I64Attr>]> {
let description = "two 64-bit integers 'handle' and 'type'";
}
class BASE_HLO_ReplicaIdOp { class BASE_HLO_ReplicaIdOp {
string summary = "ReplicaId operator"; string summary = "ReplicaId operator";
@ -712,6 +710,7 @@ def HLO_PrecisionConfigAttr:
OptionalAttr< OptionalAttr<
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>; TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Fast Fourier Transform Type enum definitions. // Fast Fourier Transform Type enum definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1011,21 +1010,6 @@ class BASE_HLO_ConcatenateOp {
// Common convolution attributes // Common convolution attributes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class ConvDimensionNumbersBase<Dialect dialect>
: StructAttr<"ConvDimensionNumbers", dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
class ConvolutionAttributes<Dialect dialect> { class ConvolutionAttributes<Dialect dialect> {
dag attributes = (ins dag attributes = (ins
// Default value: one for each of the spatial dimension. // Default value: one for each of the spatial dimension.
@ -1036,7 +1020,7 @@ class ConvolutionAttributes<Dialect dialect> {
OptionalAttr<I64ElementsAttr>:$lhs_dilation, OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension. // Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation, OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbersBase<dialect>:$dimension_numbers, ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count, I64Attr:$feature_group_count,
I64Attr:$batch_group_count, I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config HLO_PrecisionConfigAttr:$precision_config
@ -1164,15 +1148,6 @@ class BASE_HLO_ReshapeOp {
}]; }];
} }
class ScatterDimensionNumbers<Dialect dialect> : StructAttr<
"ScatterDimensionNumbers", dialect, [
StructFieldAttr<"update_window_dims", I64ElementsAttr>,
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for scatter";
}
class BASE_HLO_ScatterOp { class BASE_HLO_ScatterOp {
string summary = "Scatter operator"; string summary = "Scatter operator";

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines structures used in MHLO and LMHLO.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
// Order matters, this .inc header is not self-contained, and relies on the
// #includes above.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc"
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_

View File

@ -0,0 +1,73 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef HLO_OPS_BASE_STRUCTS
#define HLO_OPS_BASE_STRUCTS
//===----------------------------------------------------------------------===//
// Dot dimensions enum definitions.
//===----------------------------------------------------------------------===//
def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
]> {
let description = "Structure of dimension information for dot product";
}
def ScatterDimensionNumbers : StructAttr<
"ScatterDimensionNumbers", HLO_Dialect, [
StructFieldAttr<"update_window_dims", I64ElementsAttr>,
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for scatter";
}
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect,
[StructFieldAttr<"offset_dims", I64ElementsAttr>,
StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
StructFieldAttr<"start_index_map", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for gather";
}
// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [
StructFieldAttr<"handle", I64Attr>,
StructFieldAttr<"type", I64Attr>]> {
let description = "two 64-bit integers 'handle' and 'type'";
}
#endif // HLO_OPS_BASE_STRUCTS

View File

@ -19,6 +19,10 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
@ -28,6 +32,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h"
@ -35,7 +40,6 @@ namespace mlir {
class OpBuilder; class OpBuilder;
} // namespace mlir } // namespace mlir
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"
namespace mlir { namespace mlir {
namespace lmhlo_gpu { namespace lmhlo_gpu {

View File

@ -22,13 +22,10 @@ limitations under the License.
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td"
def LHLO_GPU_Dialect : Dialect {
let name = "lmhlo_gpu";
let cppNamespace = "::mlir::lmhlo_gpu";
}
class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> : class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> :
Op<LHLO_GPU_Dialect, mnemonic, Op<LHLO_GPU_Dialect, mnemonic,
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>; !listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
@ -109,13 +106,6 @@ def ActivationAttr : StrEnumAttr<"Activation",
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
ActivationModeBandPass]>; ActivationModeBandPass]>;
def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig",
LHLO_GPU_Dialect, [
StructFieldAttr<"algorithm", I64Attr>,
StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> {
let description = "GPU Convolution backend configuration";
}
def GpuConvolutionAttributes { def GpuConvolutionAttributes {
dag attributes = !con( dag attributes = !con(
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes, ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
@ -181,16 +171,6 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
// LMHLO ops representing other library functions. // LMHLO ops representing other library functions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TODO(jurahul): Share this with the MHLO dialect.
def DotDimensionNumbersAttr : StructAttr<"DotDimensionNumbers", LHLO_GPU_Dialect, [
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
]> {
let description = "Structure of dimension information for dot product";
}
// output = alpha * (lhs * rhs) // output = alpha * (lhs * rhs)
// Verify: beta = 0.0 // Verify: beta = 0.0
def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
@ -198,7 +178,7 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
Arg<LHLO_Buffer, "", [MemRead]>:$lhs, Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs, Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemRead]>:$output, Arg<LHLO_Buffer, "", [MemRead]>:$output,
DotDimensionNumbersAttr:$dot_dimension_numbers, DotDimensionNumbers:$dot_dimension_numbers,
F64Attr:$alpha, F64Attr:$alpha,
I64Attr:$batch_size, I64Attr:$batch_size,
I64Attr:$algorithm); I64Attr:$algorithm);
@ -211,7 +191,7 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
Arg<LHLO_Buffer, "", [MemRead]>:$rhs, Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemRead]>:$bias, Arg<LHLO_Buffer, "", [MemRead]>:$bias,
Arg<LHLO_Buffer, "", [MemRead]>:$output, Arg<LHLO_Buffer, "", [MemRead]>:$output,
DotDimensionNumbersAttr:$dot_dimension_numbers, DotDimensionNumbers:$dot_dimension_numbers,
F64Attr:$alpha, F64Attr:$alpha,
F64Attr:$beta, F64Attr:$beta,
I64Attr:$batch_size, I64Attr:$batch_size,

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// We define the dialect here so that both structs and ops can refer to it.
#ifndef LHLO_GPU_OPS_BASE
#define LHLO_GPU_OPS_BASE
include "mlir/IR/OpBase.td"
def LHLO_GPU_Dialect : Dialect {
let name = "lmhlo_gpu";
let cppNamespace = "::mlir::lmhlo_gpu";
}
#endif // LHLO_GPU_OPS_BASE

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ==============================================================================*/
// This file defines structures used in the LMHLO_GPU dialect.
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
// Order matters, this .inc header is not self-contained, and relies on the
// #includes above.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_

View File

@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef LHLO_GPU_OPS_STRUCTS
#define LHLO_GPU_OPS_STRUCTS
include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td"
def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig",
LHLO_GPU_Dialect, [
StructFieldAttr<"algorithm", I64Attr>,
StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> {
let description = "GPU Convolution backend configuration";
}
#endif // LHLO_GPU_OPS_STRUCTS

View File

@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
@ -33,11 +34,6 @@ limitations under the License.
namespace mlir { namespace mlir {
class OpBuilder; class OpBuilder;
} // namespace mlir
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"
namespace mlir {
namespace lmhlo { namespace lmhlo {
class LmhloDialect : public Dialect { class LmhloDialect : public Dialect {

View File

@ -592,6 +592,7 @@ def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {
let arguments = (ins let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs, Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs, Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
DotDimensionNumbers:$dot_dimension_numbers,
HLO_PrecisionConfigAttr:$precision_config, HLO_PrecisionConfigAttr:$precision_config,
Arg<LHLO_Buffer, "", [MemWrite]>:$output Arg<LHLO_Buffer, "", [MemWrite]>:$output
); );
@ -623,7 +624,7 @@ def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp {
Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices, Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
Arg<LHLO_Buffer, "", [MemRead]>:$updates, Arg<LHLO_Buffer, "", [MemRead]>:$updates,
Arg<LHLO_Buffer, "", [MemWrite]>:$output, Arg<LHLO_Buffer, "", [MemWrite]>:$output,
ScatterDimensionNumbers<LHLO_Dialect>:$scatter_dimension_numbers, ScatterDimensionNumbers:$scatter_dimension_numbers,
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted, DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
DefaultValuedAttr<BoolAttr, "false">:$unique_indices DefaultValuedAttr<BoolAttr, "false">:$unique_indices
); );
@ -699,7 +700,7 @@ def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
Arg<LHLO_Buffer, "", [MemWrite]>:$output, Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$replica_groups, I64ElementsAttr:$replica_groups,
DefaultValuedAttr<BoolAttr, "false">:$constrain_layout, DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id, OptionalAttr<ChannelHandle>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
); );
let regions = (region SizedRegion<1>:$computation); let regions = (region SizedRegion<1>:$computation);
@ -712,7 +713,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
Arg<LHLO_Buffer, "", [MemRead]>:$operand, Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output, Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$source_target_pairs, I64ElementsAttr:$source_target_pairs,
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id OptionalAttr<ChannelHandle>:$channel_id
); );
} }

View File

@ -43,6 +43,7 @@ add_mlir_library(MhloInferFusibilityOpInterface
add_mlir_dialect_library(MhloDialect add_mlir_dialect_library(MhloDialect
hlo_ops.cc hlo_ops.cc
hlo_ops_base_structs.cc
DEPENDS DEPENDS
MLIRhlo_opsIncGen MLIRhlo_opsIncGen
@ -68,6 +69,7 @@ target_link_libraries(LmhloDialect PUBLIC MLIRIR)
add_mlir_dialect_library(LmhloGPUDialect add_mlir_dialect_library(LmhloGPUDialect
lhlo_gpu_ops.cc lhlo_gpu_ops.cc
lhlo_gpu_ops_structs.cc
DEPENDS DEPENDS
MLIRlhlo_gpu_opsIncGen MLIRlhlo_gpu_opsIncGen

View File

@ -63,8 +63,6 @@ namespace mlir {
#include "hlo_patterns.cc.inc" #include "hlo_patterns.cc.inc"
} // namespace mlir } // namespace mlir
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc"
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {

View File

@ -0,0 +1,18 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"

View File

@ -28,8 +28,6 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"

View File

@ -0,0 +1,18 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"

View File

@ -29,7 +29,6 @@ limitations under the License.
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"

View File

@ -0,0 +1,17 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h"

View File

@ -126,6 +126,60 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
} }
}; };
// This specialization exists so that LMHLO's Dot can be given a specific set of
// dimension numbers, when lowering from MHLO's Dot, which does not have
// dimension numbers (it uses DotGeneral for this generalized notion of dot
// products). When these two dialects are in sync with respect to the
// Dot/DotGeneral issue, this specialization should be deleted.
template <>
class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
public:
using BaseOpConversion<mhlo::DotOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::DotOp hloOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Operation* op = hloOp.getOperation();
const auto& original_results = op->getResults();
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
for (auto result : llvm::enumerate(original_results)) {
RankedTensorType resultType =
result.value().getType().dyn_cast<RankedTensorType>();
if (!resultType) {
return failure();
}
if (resultType.hasStaticShape()) {
buffer_args.push_back(
InsertAlloc(op->getLoc(), result.value(), &rewriter));
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
if (failed(
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
return failure();
buffer_args.push_back(InsertDynamicAllocAndDealloc(
op->getLoc(), result.value(), results_shape.front(), &rewriter));
}
}
// TODO(silvasean): Move this helper to MLIR core.
auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
rewriter.getIntegerType(64));
return DenseIntElementsAttr::get(type, integers);
};
auto dotOp = rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None,
buffer_args, op->getAttrs());
// MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O].
auto dimension_numbers = mhlo::DotDimensionNumbers::get(
make_elements_attr({}), make_elements_attr({}), make_elements_attr({1}),
make_elements_attr({0}), rewriter.getContext());
dotOp.dot_dimension_numbersAttr(dimension_numbers);
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
return success();
}
};
struct HloToLhloDynamicBroadcastInDimOpConverter struct HloToLhloDynamicBroadcastInDimOpConverter
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> { : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
public: public:
@ -236,6 +290,43 @@ struct HloToLhloDynamicReshapeConverter
} }
}; };
struct HloToLhloDotGeneralOpConverter
: public BaseOpConversion<mhlo::DotGeneralOp> {
using BaseOpConversion<mhlo::DotGeneralOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::DotGeneralOp dotGeneralOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Operation* op = dotGeneralOp.getOperation();
if (op->getResults().empty()) return failure();
OpResult result = op->getResults()[0];
RankedTensorType resultType = result.getType().dyn_cast<RankedTensorType>();
if (!resultType) return failure();
// The third buffer argument will be filled with what used to be the return
// type of the DotGeneral.
if (operands.size() != 2) return failure();
std::array<Value, 3> bufferArgs = {operands[0], operands[1], {}};
if (resultType.hasStaticShape()) {
bufferArgs[2] = InsertAlloc(op->getLoc(), result, &rewriter);
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
return failure();
bufferArgs[2] = InsertDynamicAllocAndDealloc(
op->getLoc(), result, results_shape.front(), &rewriter);
}
rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None, bufferArgs,
op->getAttrs());
rewriter.replaceOp(op, bufferArgs[2]);
return success();
}
};
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> { struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
public: public:
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion; using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
@ -485,6 +576,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
// clang-format off // clang-format off
patterns->insert< patterns->insert<
HloToLhloDotGeneralOpConverter,
HloToLhloDynamicBroadcastInDimOpConverter, HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloDynamicReshapeConverter, HloToLhloDynamicReshapeConverter,
HloToLhloOpConverter<mhlo::AbsOp>, HloToLhloOpConverter<mhlo::AbsOp>,

View File

@ -192,7 +192,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
lmhlo::ConvOp op, ArrayRef<Value> args, lmhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information. // Check validity of dimension information.
if (const lmhlo::ConvDimensionNumbers& dimensionNumbers = if (const mhlo::ConvDimensionNumbers& dimensionNumbers =
op.dimension_numbers()) { op.dimension_numbers()) {
const int inputSpatialRank = const int inputSpatialRank =
llvm::size(dimensionNumbers.input_spatial_dimensions()); llvm::size(dimensionNumbers.input_spatial_dimensions());

View File

@ -59,6 +59,20 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
return failure(); return failure();
} }
// We don't currently support batching dimensions, or multiple contraction
// dimensions.
mhlo::DotDimensionNumbers dot_dimension_numbers =
op.dot_dimension_numbers();
if (dot_dimension_numbers.lhs_batching_dimensions().size() > 0 ||
dot_dimension_numbers.rhs_batching_dimensions().size() > 0)
return failure();
if (dot_dimension_numbers.lhs_contracting_dimensions().size() != 1 ||
*dot_dimension_numbers.lhs_contracting_dimensions().begin() != 1 ||
dot_dimension_numbers.rhs_contracting_dimensions().size() != 1 ||
*dot_dimension_numbers.rhs_contracting_dimensions().begin() != 0) {
return failure();
}
LogicalResult map_status = success(); LogicalResult map_status = success();
auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) { auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]}, SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]},

View File

@ -511,7 +511,13 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) // PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc // BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () // BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
// dot_dimension_numbers = {
// lhs_batching_dimensions = dense<> : tensor<0xi64>,
// lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
// rhs_batching_dimensions = dense<> : tensor<0xi64>,
// rhs_contracting_dimensions = dense<0> : tensor<1xi64>}}
// : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "mhlo.dot"(%arg0, %arg0) %dot = "mhlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]]) // PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])

View File

@ -158,7 +158,14 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs:
// CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32 // CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32> // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32>
// CHECK: return // CHECK: return
"lmhlo.dot"(%lhs, %rhs, %result) : "lmhlo.dot"(%lhs, %rhs, %result) {
dot_dimension_numbers = {
lhs_batching_dimensions = dense<> : tensor<0xi64>,
rhs_batching_dimensions = dense<> : tensor<0xi64>,
lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
rhs_contracting_dimensions = dense<0> : tensor<1xi64>
}
} :
(memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> () (memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> ()
return return
} }
@ -175,7 +182,14 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
// CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32 // CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32> // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32>
// CHECK: return // CHECK: return
"lmhlo.dot"(%lhs, %rhs, %result) : "lmhlo.dot"(%lhs, %rhs, %result) {
dot_dimension_numbers = {
lhs_batching_dimensions = dense<> : tensor<0xi64>,
rhs_batching_dimensions = dense<> : tensor<0xi64>,
lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
rhs_contracting_dimensions = dense<0> : tensor<1xi64>
}
} :
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> () (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
return return
} }