Add GPU-specific LMHLO-level ops

- Introduce operations in a new lmhlo_gpu dialect that map to GPU library function calls
  in the XLA:GPU backend.
- Add basic unit tests as well.

PiperOrigin-RevId: 337132166
Rahul Joshi 2020-10-14 11:23:08 -07:00 committed by TensorFlow MLIR Team
parent 8506f1f26a
commit f6b4e6758a
9 changed files with 467 additions and 2 deletions
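
For a feel of the new op form, note that these ops operate on pre-allocated memref buffers (including an untyped scratch buffer and, where needed, an info buffer for the underlying library call) rather than producing tensor results. A minimal example, lifted from the tests/lhlo_gpu_ops.mlir file added below:

  func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
    // Scratch space and info output buffer for the GPU library call.
    %scratch = alloc() : memref<32xi8>
    %info = alloc() : memref<32xi32>
    "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true }
        : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
    return
  }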


@@ -27,5 +27,6 @@ endfunction()
 add_mlir_hlo_dialect(chlo_ops chlo)
 add_mlir_hlo_dialect(hlo_ops mhlo)
 add_mlir_hlo_dialect(lhlo_ops lmhlo)
+add_mlir_hlo_dialect(lhlo_gpu_ops lmhlo_gpu)
 add_mlir_interface(infer_fusibility_op_interface)


@@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the LMHLO GPU dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {
class OpBuilder;
} // namespace mlir
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"
namespace mlir {
namespace lmhlo_gpu {
class LmhloGpuDialect : public Dialect {
public:
explicit LmhloGpuDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "lmhlo_gpu"; }
};
} // namespace lmhlo_gpu
} // end namespace mlir
#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc"
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_


@@ -0,0 +1,230 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for LMHLO-level GPU operations.
// Because these are LMHLO level operations, they operate on memrefs.
#ifndef LHLO_GPU_OPS
#define LHLO_GPU_OPS
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
def LHLO_GPU_Dialect : Dialect {
let name = "lmhlo_gpu";
let cppNamespace = "::mlir::lmhlo_gpu";
}
class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> :
Op<LHLO_GPU_Dialect, mnemonic,
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
// Type for scratch buffers used by GPU library calls (memref<?xi8>)
def UntypedBuffer : MemRefRankOf<[I8], [1]>;
// Cholesky info output buffer type.
def I32Buffer : MemRefOf<[I32]>;
//===----------------------------------------------------------------------===//
// LMHLO ops representing batch norm library functions.
//===----------------------------------------------------------------------===//
// Note: these are semantically different from the similar LHLO ops, as the GPU
// library calls generate or consume standard deviation, whereas the LHLO ops
// generate or consume variance (= std-dev ^ 2).
def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">,
BASE_HLO_BatchNormGradOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$stddev,
Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand, // gradient of $operand.
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale,
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">,
BASE_HLO_BatchNormInferenceOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$stddev,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
F32Attr:$epsilon,
I64Attr:$feature_index);
}
def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">,
BASE_HLO_BatchNormTrainingOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean,
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_stddev,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
//===----------------------------------------------------------------------===//
// LMHLO ops representing convolution library functions.
//===----------------------------------------------------------------------===//
def ActivationModeNone : StrEnumAttrCase<"None">;
def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">;
def ActivationModeTanh : StrEnumAttrCase<"Tanh">;
def ActivationModeRelu : StrEnumAttrCase<"Relu">;
def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">;
def ActivationModeReluX : StrEnumAttrCase<"ReluX">;
def ActivationModeBandPass : StrEnumAttrCase<"BandPass">;
def ActivationAttr : StrEnumAttr<"Activation",
"Activation applied with fused convolution",
[ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh,
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
ActivationModeBandPass]>;
def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig",
LHLO_GPU_Dialect, [
StructFieldAttr<"algorithm", I64Attr>,
StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> {
let description = "GPU Convolution backend configuration";
}
def GpuConvolutionAttributes {
dag attributes = !con(
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
(ins F64Attr:$result_scale),
(ins ConvolutionBackendConfigAttr:$backend_config));
}
def GpuFusedConvolutionAttributes {
dag attributes = !con(
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
(ins F64Attr:$result_scale,
ActivationAttr:$activation_mode,
F64Attr:$side_input_scale),
(ins ConvolutionBackendConfigAttr:$backend_config));
}
def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes.attributes);
}
def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
Arg<LHLO_Buffer, "", [MemWrite]>:$d_input,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes.attributes);
}
def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
Arg<LHLO_Buffer, "", [MemWrite]>:$d_filter,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes.attributes);
}
// output = activation(result_scale * conv(input, filter) +
// side_input * side_input_scale +
// bias)
def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
Arg<LHLO_Buffer, "", [MemRead]>:$side_input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
GpuFusedConvolutionAttributes.attributes);
}
//===----------------------------------------------------------------------===//
// LMHLO ops representing other library functions.
//===----------------------------------------------------------------------===//
// TODO(jurahul): Share this with the MHLO dialect.
def DotDimensionNumbersAttr : StructAttr<"DotDimensionNumbers", LHLO_GPU_Dialect, [
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
]> {
let description = "Structure of dimension information for dot product";
}
// output = alpha * (lhs * rhs)
// (no beta operand: the underlying GEMM library call is issued with beta = 0.0)
def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
DotDimensionNumbersAttr:$dot_dimension_numbers,
F64Attr:$alpha,
I64Attr:$batch_size,
I64Attr:$algorithm);
}
// output = alpha * (lhs * rhs) + beta * bias
def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
DotDimensionNumbersAttr:$dot_dimension_numbers,
F64Attr:$alpha,
F64Attr:$beta,
I64Attr:$batch_size,
I64Attr:$algorithm);
}
def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch,
Arg<I32Buffer, "", [MemWrite]>:$info,
BoolAttr:$is_upper);
}
#endif // LHLO_GPU_OPS


@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-// This file defines the operations used in the LXLA dialect.
+// This file defines the operations used in the LHLO dialect.
 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_


@@ -66,6 +66,14 @@ add_mlir_dialect_library(LmhloDialect
 )
 target_link_libraries(LmhloDialect PUBLIC MLIRIR)
 
+add_mlir_dialect_library(LmhloGPUDialect
+  lhlo_gpu_ops.cc
+
+  DEPENDS
+  MLIRlhlo_gpu_opsIncGen
+)
+target_link_libraries(LmhloGPUDialect PUBLIC MLIRIR)
+
 add_mlir_dialect_library(MhloRegisterDialects
   init.cc
@@ -73,10 +81,12 @@ DEPENDS
   MLIRchlo_opsIncGen
   MLIRhlo_opsIncGen
   MLIRlhlo_opsIncGen
+  MLIRlhlo_gpu_opsIncGen
 )
 target_link_libraries(MhloRegisterDialects
   PUBLIC
   ChloDialect
   MhloDialect
   LmhloDialect
+  LmhloGPUDialect
 )


@@ -15,13 +15,15 @@ limitations under the License.
 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/register.h"
 
 void mlir::mhlo::registerAllMhloDialects(mlir::DialectRegistry &registry) {
   // clang-format off
   registry.insert<mlir::chlo::HloClientDialect,
+                  mlir::mhlo::MhloDialect,
                   mlir::lmhlo::LmhloDialect,
-                  mlir::mhlo::MhloDialect>();
+                  mlir::lmhlo_gpu::LmhloGpuDialect>();
   // clang-format on
 }


@@ -0,0 +1,66 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the LMHLO GPU dialect.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace mlir {
namespace lmhlo_gpu {
LmhloGpuDialect::LmhloGpuDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<LmhloGpuDialect>()) {
addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc"
>();
}
// TODO(jurahul): Add verification for operand shapes and ranks.
} // namespace lmhlo_gpu
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc"

tests/lhlo_gpu_ops.mlir

@@ -0,0 +1,99 @@
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
// CHECK-LABEL: func @batch_norm_grad_memrefs
func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
%grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>,
%grad_offset: memref<8xf32>) -> () {
"lmhlo_gpu.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
return
}
// CHECK-LABEL: func @batch_norm_inference_memrefs
func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
"lmhlo_gpu.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
return
}
// CHECK-LABEL: func @batch_norm_training_memrefs
func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>,
%batch_var: memref<8xf32>) -> () {
"lmhlo_gpu.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
return
}
// CHECK-LABEL: func @conv_forward
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
%scratch = alloc() : memref<32xi8>
// This defines a 2D convolution over an 8x8 single-channel input using a 2x2
// filter, producing a 7x7 spatial output (memref<1x1x7x7xf16>). The 1x1x8x8 input shape is (N, C, H, W).
"lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch)
{ dimension_numbers = {input_batch_dimension = 0 : i64,
input_feature_dimension = 1 : i64,
input_spatial_dimensions = dense<[2,3]> : tensor<2xi64>,
kernel_input_feature_dimension = 0 : i64,
kernel_output_feature_dimension = 1 : i64,
kernel_spatial_dimensions = dense<[2,3]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 1 : i64,
output_spatial_dimensions = dense<[2,3]> : tensor<2xi64>},
window_strides = dense<[1, 1]> : tensor<2xi64>,
padding = dense<[0,0]> : tensor<2xi64>,
lhs_dilation = dense<[1,1]> : tensor<2xi64>,
rhs_dilation = dense<[1,1]> : tensor<2xi64>,
feature_group_count = 1,
batch_group_count = 1,
result_scale = 1.0,
backend_config = {algorithm=0, tensor_ops_enabled = true }
}
: (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> ()
return
}
// -----
// CHECK-LABEL: func @gemm
func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) {
"lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = {
lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
alpha = 0.5,
batch_size = 1,
algorithm = 0}
: (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()
return
}
// CHECK-LABEL: func @gemm_bias
func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
%bias: memref<5x5xf32>, %output:memref<5x5xf32>) {
"lmhlo_gpu.gemm_bias"(%lhs, %rhs, %bias, %output) { dot_dimension_numbers = {
lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
alpha = 0.5,
beta = 1.0,
batch_size = 1,
algorithm = 0}
: (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>, memref<5x5xf32>) -> ()
return
}
// CHECK-LABEL: func @cholesky
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
%scratch = alloc() : memref<32xi8>
%info = alloc() : memref<32xi32>
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true }
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
return
}
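
The cases above do not yet exercise the fused convolution. A hedged sketch of what such a test might look like (not part of this change), reusing the attributes from the conv_forward case above plus the fused-specific activation_mode and side_input_scale attributes declared in the TableGen file:

  // CHECK-LABEL: func @conv_forward_fused
  func @conv_forward_fused(%input : memref<1x1x8x8xf16>, %filter : memref<1x1x2x2xf16>,
                           %bias : memref<1xf16>, %side_input : memref<1x1x7x7xf16>,
                           %output : memref<1x1x7x7xf16>) {
    %scratch = alloc() : memref<32xi8>
    // output = activation(result_scale * conv(input, filter) +
    //                     side_input * side_input_scale + bias)
    "lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %side_input, %output, %scratch)
      { dimension_numbers = {input_batch_dimension = 0 : i64,
                             input_feature_dimension = 1 : i64,
                             input_spatial_dimensions = dense<[2,3]> : tensor<2xi64>,
                             kernel_input_feature_dimension = 0 : i64,
                             kernel_output_feature_dimension = 1 : i64,
                             kernel_spatial_dimensions = dense<[2,3]> : tensor<2xi64>,
                             output_batch_dimension = 0 : i64,
                             output_feature_dimension = 1 : i64,
                             output_spatial_dimensions = dense<[2,3]> : tensor<2xi64>},
        window_strides = dense<[1, 1]> : tensor<2xi64>,
        padding = dense<[0, 0]> : tensor<2xi64>,
        lhs_dilation = dense<[1, 1]> : tensor<2xi64>,
        rhs_dilation = dense<[1, 1]> : tensor<2xi64>,
        feature_group_count = 1,
        batch_group_count = 1,
        result_scale = 1.0,
        activation_mode = "Relu",
        side_input_scale = 1.0,
        backend_config = {algorithm = 0, tensor_ops_enabled = true}
      }
      : (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1xf16>, memref<1x1x7x7xf16>,
         memref<1x1x7x7xf16>, memref<32xi8>) -> ()
    return
  }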


@@ -15,6 +15,7 @@ limitations under the License.
 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
 #include "mlir/InitAllDialects.h"
@@ -31,6 +32,7 @@ int main(int argc, char **argv) {
   registry.insert<mlir::mhlo::MhloDialect>();
   registry.insert<mlir::chlo::HloClientDialect>();
   registry.insert<mlir::lmhlo::LmhloDialect>();
+  registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
   return failed(
       mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry));