From f6b4e6758a0373f6203d81ab5f26d5ebae1809c8 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Wed, 14 Oct 2020 11:23:08 -0700 Subject: [PATCH] Add GPU specific LMHLO level ops - Introduce operations in a new lmhlo_gpu dialect that map to GPU library function calls in the XLA:GPU backend. - Add basic unit tests as well. PiperOrigin-RevId: 337132166 --- .../mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt | 1 + .../mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h | 55 +++++ .../mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td | 230 ++++++++++++++++++ include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h | 2 +- lib/Dialect/mhlo/IR/CMakeLists.txt | 10 + lib/Dialect/mhlo/IR/init.cc | 4 +- lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc | 66 +++++ tests/lhlo_gpu_ops.mlir | 99 ++++++++ tools/mlir-hlo-opt/mlir-hlo-opt.cpp | 2 + 9 files changed, 467 insertions(+), 2 deletions(-) create mode 100644 include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h create mode 100644 include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td create mode 100644 lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc create mode 100644 tests/lhlo_gpu_ops.mlir diff --git a/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt b/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt index 09bdca8..79c04b3 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt +++ b/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt @@ -27,5 +27,6 @@ endfunction() add_mlir_hlo_dialect(chlo_ops chlo) add_mlir_hlo_dialect(hlo_ops mhlo) add_mlir_hlo_dialect(lhlo_ops lmhlo) +add_mlir_hlo_dialect(lhlo_gpu_ops lmhlo_gpu) add_mlir_interface(infer_fusibility_op_interface) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h new file mode 100644 index 0000000..68ab6ce --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the LMHLO GPU dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" + +namespace mlir { +class OpBuilder; +} // namespace mlir + +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc" + +namespace mlir { +namespace lmhlo_gpu { + +class LmhloGpuDialect : public Dialect { + public: + explicit LmhloGpuDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "lmhlo_gpu"; } +}; + +} // namespace lmhlo_gpu +} // namespace mlir + +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td new file mode 100644 index 0000000..ac5830c --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td @@ -0,0 +1,230 @@ +/* Copyright 2020 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the operation definition file for LMHLO level GPU operations. +// Because these are LMHLO level operations, they operate on memrefs. + +#ifndef LHLO_GPU_OPS +#define LHLO_GPU_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" + + +def LHLO_GPU_Dialect : Dialect { + let name = "lmhlo_gpu"; + let cppNamespace = "::mlir::lmhlo_gpu"; +} + +class LHLOGPU_Op traits = []> : + Op], traits)>; + +// Type for scratch buffers used by GPU library calls (1-D memref of i8). +def UntypedBuffer : MemRefRankOf<[I8], [1]>; + +// Cholesky info output buffer type (memref of i32). +def I32Buffer : MemRefOf<[I32]>; + +//===----------------------------------------------------------------------===// +// LMHLO ops representing batch norm library functions. +//===----------------------------------------------------------------------===// + +// Note: these are semantically different from the similar LHLO ops, as the GPU +// library calls generate or consume standard deviation, whereas LHLO ops generate or +// consume variance (= std-dev ^ 2). + +def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">, + BASE_HLO_BatchNormGradOp { + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$mean, + Arg:$stddev, + Arg:$grad_output, + Arg:$grad_operand, // gradient of $operand. 
+ Arg:$grad_scale, + Arg:$grad_offset, + F32Attr:$epsilon, + I64Attr:$feature_index + ); +} + +def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">, + BASE_HLO_BatchNormInferenceOp { + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$mean, + Arg:$stddev, + Arg:$output, + F32Attr:$epsilon, + I64Attr:$feature_index); +} + +def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">, + BASE_HLO_BatchNormTrainingOp { + + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$output, + Arg:$batch_mean, + Arg:$batch_stddev, + F32Attr:$epsilon, + I64Attr:$feature_index + ); +} + +//===----------------------------------------------------------------------===// +// LMHLO ops representing convolution library functions. +//===----------------------------------------------------------------------===// + +def ActivationModeNone : StrEnumAttrCase<"None">; +def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">; +def ActivationModeTanh : StrEnumAttrCase<"Tanh">; +def ActivationModeRelu : StrEnumAttrCase<"Relu">; +def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">; +def ActivationModeReluX : StrEnumAttrCase<"ReluX">; +def ActivationModeBandPass : StrEnumAttrCase<"BandPass">; + +def ActivationAttr : StrEnumAttr<"Activation", + "Activation applied with fused convolution", + [ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh, + ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, + ActivationModeBandPass]>; + +def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig", + LHLO_GPU_Dialect, [ + StructFieldAttr<"algorithm", I64Attr>, + StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> { + let description = "GPU Convolution backend configuration"; +} + +def GpuConvolutionAttributes { + dag attributes = !con( + ConvolutionAttributes.attributes, + (ins F64Attr:$result_scale), + (ins ConvolutionBackendConfigAttr:$backend_config)); +} + +def GpuFusedConvolutionAttributes { + dag attributes = 
!con( + ConvolutionAttributes.attributes, + (ins F64Attr:$result_scale, + ActivationAttr:$activation_mode, + F64Attr:$side_input_scale), + (ins ConvolutionBackendConfigAttr:$backend_config)); +} + +def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$filter, + Arg:$output, + Arg:$scratch), + GpuConvolutionAttributes.attributes); +} + +def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> { + let arguments = !con( + (ins + Arg:$d_output, + Arg:$filter, + Arg:$d_input, + Arg:$scratch), + GpuConvolutionAttributes.attributes); +} + +def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$d_output, + Arg:$d_filter, + Arg:$scratch), + GpuConvolutionAttributes.attributes); +} + +// output = activation(result_scale * conv(input, filter) + +// side_input * side_input_scale + +// bias) +def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$filter, + Arg:$bias, + Arg:$side_input, + Arg:$output, + Arg:$scratch), + GpuFusedConvolutionAttributes.attributes); +} + +//===----------------------------------------------------------------------===// +// LMHLO ops representing other library functions. +//===----------------------------------------------------------------------===// + +// TODO(jurahul): Share this with the MHLO dialect. 
+def DotDimensionNumbersAttr : StructAttr<"DotDimensionNumbers", LHLO_GPU_Dialect, [ + StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>, + StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>, + StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, + StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> + ]> { + let description = "Structure of dimension information for dot product"; +} + +// output = alpha * (lhs * rhs) +// Verify: beta = 0.0 +def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$output, + DotDimensionNumbersAttr:$dot_dimension_numbers, + F64Attr:$alpha, + I64Attr:$batch_size, + I64Attr:$algorithm); +} + +// output = alpha * (lhs * rhs) + beta * bias +def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$bias, + Arg:$output, + DotDimensionNumbersAttr:$dot_dimension_numbers, + F64Attr:$alpha, + F64Attr:$beta, + I64Attr:$batch_size, + I64Attr:$algorithm); +} + +def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { + let arguments = (ins + Arg:$input, + Arg:$output, + Arg:$scratch, + Arg:$info, + BoolAttr:$is_upper); +} + +#endif // LHLO_GPU_OPS diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index cc24e17..c0c0494 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines the operations used in the LXLA dialect. +// This file defines the operations used in the LHLO dialect. 
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ diff --git a/lib/Dialect/mhlo/IR/CMakeLists.txt b/lib/Dialect/mhlo/IR/CMakeLists.txt index d7bb505..7b5da34 100644 --- a/lib/Dialect/mhlo/IR/CMakeLists.txt +++ b/lib/Dialect/mhlo/IR/CMakeLists.txt @@ -66,6 +66,14 @@ add_mlir_dialect_library(LmhloDialect ) target_link_libraries(LmhloDialect PUBLIC MLIRIR) +add_mlir_dialect_library(LmhloGPUDialect + lhlo_gpu_ops.cc + + DEPENDS + MLIRlhlo_gpu_opsIncGen +) +target_link_libraries(LmhloGPUDialect PUBLIC MLIRIR) + add_mlir_dialect_library(MhloRegisterDialects init.cc @@ -73,10 +81,12 @@ DEPENDS MLIRchlo_opsIncGen MLIRhlo_opsIncGen MLIRlhlo_opsIncGen + MLIRlhlo_gpu_opsIncGen ) target_link_libraries(MhloRegisterDialects PUBLIC ChloDialect MhloDialect LmhloDialect + LmhloGPUDialect ) diff --git a/lib/Dialect/mhlo/IR/init.cc b/lib/Dialect/mhlo/IR/init.cc index 503b100..ca8c6a8 100644 --- a/lib/Dialect/mhlo/IR/init.cc +++ b/lib/Dialect/mhlo/IR/init.cc @@ -15,13 +15,15 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/register.h" void mlir::mhlo::registerAllMhloDialects(mlir::DialectRegistry ®istry) { // clang-format off registry.insert(); + mlir::lmhlo_gpu::LmhloGpuDialect>(); // clang-format on } diff --git a/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc b/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc new file mode 100644 index 0000000..33f4e2c --- /dev/null +++ b/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the LMHLO GPU dialect. + +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" + +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" + +namespace mlir { +namespace lmhlo_gpu { + +LmhloGpuDialect::LmhloGpuDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc" + >(); +} + +// TODO(jurahul): Add verification for operand shapes and ranks. 
+ +} // namespace lmhlo_gpu +} // namespace mlir + +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc" diff --git a/tests/lhlo_gpu_ops.mlir b/tests/lhlo_gpu_ops.mlir new file mode 100644 index 0000000..9e5ce67 --- /dev/null +++ b/tests/lhlo_gpu_ops.mlir @@ -0,0 +1,99 @@ +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s + +// CHECK-LABEL: func @batch_norm_grad_memrefs +func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, + %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>, + %grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>, + %grad_offset: memref<8xf32>) -> () { + "lmhlo_gpu.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, + memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () + return +} + +// CHECK-LABEL: func @batch_norm_inference_memrefs +func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, + %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () { + "lmhlo_gpu.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> () + return +} + +// CHECK-LABEL: func @batch_norm_training_memrefs +func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, + %output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>, + %batch_stddev: memref<8xf32>) -> () { + "lmhlo_gpu.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_stddev) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, 
memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () + return +} + +// CHECK-LABEL: func @conv_forward +func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) { + %scratch = alloc() : memref<32xi8> + // This defines a 2D convolution over an 8x8 single-channel input using a 2x2 + // filter, with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W) + "lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch) + { dimension_numbers = {input_batch_dimension = 0 : i64, + input_feature_dimension = 1 : i64, + input_spatial_dimensions = dense<[2,3]> : tensor<2xi64>, + kernel_input_feature_dimension = 0 : i64, + kernel_output_feature_dimension = 1 : i64, + kernel_spatial_dimensions = dense<[2,3]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 1 : i64, + output_spatial_dimensions = dense<[2,3]> : tensor<2xi64>}, + window_strides = dense<[1, 1]> : tensor<2xi64>, + padding = dense<[0,0]> : tensor<2xi64>, + lhs_dilation = dense<[1,1]> : tensor<2xi64>, + rhs_dilation = dense<[1,1]> : tensor<2xi64>, + feature_group_count = 1, + batch_group_count = 1, + result_scale = 1.0, + backend_config = {algorithm=0, tensor_ops_enabled = true } + } + : (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @gemm +func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) { + "lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = { + lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, + rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, + lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, + rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, + alpha = 0.5, + batch_size = 1, + algorithm = 0} + : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> () + return +} + + +// CHECK-LABEL: func @gemm_bias +func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, + %bias: 
memref<5x5xf32>, %output:memref<5x5xf32>) { + "lmhlo_gpu.gemm_bias"(%lhs, %rhs, %bias, %output) { dot_dimension_numbers = { + lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, + rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, + lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, + rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, + alpha = 0.5, + beta = 1.0, + batch_size = 1, + algorithm = 0} + : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>, memref<5x5xf32>) -> () + return +} + +// CHECK-LABEL: func @cholesky +func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) { + %scratch = alloc() : memref<32xi8> + %info = alloc() : memref<32xi32> + "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true } + : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> () + return +} diff --git a/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tools/mlir-hlo-opt/mlir-hlo-opt.cpp index d0c0e3c..ed96dd5 100644 --- a/tools/mlir-hlo-opt/mlir-hlo-opt.cpp +++ b/tools/mlir-hlo-opt/mlir-hlo-opt.cpp @@ -15,6 +15,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" #include "mlir/InitAllDialects.h" @@ -31,6 +32,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); return failed( mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry));