Add GPU-specific LMHLO-level ops

- Introduce operations in a new lmhlo_gpu dialect that map to GPU library function calls in the XLA:GPU backend.
- Add basic unit tests as well.

PiperOrigin-RevId: 337132166
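For orientation, here is one of the new ops in generic MLIR form, copied from the unit tests added in this change (the Cholesky library call, with its untyped scratch buffer and i32 info buffer):

    func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
      %scratch = alloc() : memref<32xi8>
      %info = alloc() : memref<32xi32>
      "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true }
          : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
      return
    }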
parent 8506f1f26a
commit f6b4e6758a
				|  | @ -27,5 +27,6 @@ endfunction() | |||
| add_mlir_hlo_dialect(chlo_ops chlo) | ||||
| add_mlir_hlo_dialect(hlo_ops mhlo) | ||||
| add_mlir_hlo_dialect(lhlo_ops lmhlo) | ||||
| add_mlir_hlo_dialect(lhlo_gpu_ops lmhlo_gpu) | ||||
| 
 | ||||
| add_mlir_interface(infer_fusibility_op_interface) | ||||
|  |  | |||
|  | @ -0,0 +1,55 @@ | |||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| // This file defines the operations used in the LMHLO GPU dialect.
 | ||||
| 
 | ||||
| #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ | ||||
| #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ | ||||
| 
 | ||||
| #include "llvm/ADT/StringRef.h" | ||||
| #include "mlir/IR/Attributes.h" | ||||
| #include "mlir/IR/Dialect.h" | ||||
| #include "mlir/IR/Location.h" | ||||
| #include "mlir/IR/MLIRContext.h" | ||||
| #include "mlir/IR/OpDefinition.h" | ||||
| #include "mlir/IR/Operation.h" | ||||
| #include "mlir/IR/StandardTypes.h" | ||||
| #include "mlir/IR/Types.h" | ||||
| #include "mlir/Interfaces/CopyOpInterface.h" | ||||
| #include "mlir/Interfaces/SideEffectInterfaces.h" | ||||
| #include "mlir/Interfaces/ViewLikeInterface.h" | ||||
| 
 | ||||
| namespace mlir { | ||||
| class OpBuilder; | ||||
| }  // namespace mlir
 | ||||
| 
 | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc" | ||||
| 
 | ||||
| namespace mlir { | ||||
| namespace lmhlo_gpu { | ||||
| 
 | ||||
| class LmhloGpuDialect : public Dialect { | ||||
|  public: | ||||
|   explicit LmhloGpuDialect(MLIRContext *context); | ||||
|   static StringRef getDialectNamespace() { return "lmhlo_gpu"; } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace lmhlo_gpu
 | ||||
| }  // end namespace mlir
 | ||||
| 
 | ||||
| #define GET_OP_CLASSES | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc" | ||||
| 
 | ||||
| #endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_
 | ||||
|  | @ -0,0 +1,230 @@ | |||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| // This is the operation definition file for LMHLO-level GPU operations. | ||||
| // Because these are LMHLO-level operations, they operate on memrefs. | ||||
| 
 | ||||
| #ifndef LHLO_GPU_OPS | ||||
| #define LHLO_GPU_OPS | ||||
| 
 | ||||
| include "mlir/IR/OpBase.td" | ||||
| include "mlir/Interfaces/SideEffectInterfaces.td" | ||||
| include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" | ||||
| 
 | ||||
| 
 | ||||
| def LHLO_GPU_Dialect : Dialect { | ||||
|   let name = "lmhlo_gpu"; | ||||
|   let cppNamespace = "::mlir::lmhlo_gpu"; | ||||
| } | ||||
| 
 | ||||
| class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> : | ||||
|   Op<LHLO_GPU_Dialect, mnemonic, | ||||
|     !listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>; | ||||
| 
 | ||||
| // Type for scratch buffers used by GPU library calls (memref<?xi8>) | ||||
| def UntypedBuffer : MemRefRankOf<[I8], [1]>; | ||||
| 
 | ||||
| // Cholesky info output buffer type. | ||||
| def I32Buffer : MemRefOf<[I32]>; | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===// | ||||
| // LMHLO ops representing batch norm library functions. | ||||
| //===----------------------------------------------------------------------===// | ||||
| 
 | ||||
| // Note: these are semantically different from the similar LHLO ops, as the | ||||
| // GPU library calls generate or consume standard deviation, whereas the LHLO | ||||
| // ops generate or consume variance (= std-dev ^ 2). | ||||
| 
 | ||||
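To make the distinction concrete, here is a sketch of the GPU-library form (shapes borrowed from the unit tests added in this change); the buffer passed in the stddev position must hold sqrt(variance), whereas the corresponding lmhlo op carries the variance buffer itself:

    func @batch_norm_inference(%operand: memref<8x8x8x8xf32>, %scale: memref<8xf32>,
                               %offset: memref<8xf32>, %mean: memref<8xf32>,
                               %stddev: memref<8xf32>, %output: memref<8x8x8x8xf32>) {
      // %stddev holds sqrt(variance); an lmhlo.batch_norm_inference would take the
      // variance buffer in this position instead.
      "lmhlo_gpu.batch_norm_inference"(%operand, %scale, %offset, %mean, %stddev, %output)
          {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
          : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
      return
    }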
| def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">, | ||||
|     BASE_HLO_BatchNormGradOp { | ||||
|   let arguments = (ins | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$operand, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$scale, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$mean, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$stddev, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$grad_output, | ||||
|     Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand,  // gradient of $operand. | ||||
|     Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale, | ||||
|     Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset, | ||||
|     F32Attr:$epsilon, | ||||
|     I64Attr:$feature_index | ||||
|   ); | ||||
| } | ||||
| 
 | ||||
| def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">, | ||||
|     BASE_HLO_BatchNormInferenceOp { | ||||
|   let arguments = (ins | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$operand, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$scale, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$offset, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$mean, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$stddev, | ||||
|     Arg<LHLO_Buffer, "", [MemWrite]>:$output, | ||||
|     F32Attr:$epsilon, | ||||
|     I64Attr:$feature_index); | ||||
| } | ||||
| 
 | ||||
| def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">, | ||||
|     BASE_HLO_BatchNormTrainingOp { | ||||
| 
 | ||||
|   let arguments = (ins | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$operand, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$scale, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$offset, | ||||
|     Arg<LHLO_Buffer, "", [MemWrite]>:$output, | ||||
|     Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean, | ||||
|     Arg<LHLO_Buffer, "", [MemWrite]>:$batch_stddev, | ||||
|     F32Attr:$epsilon, | ||||
|     I64Attr:$feature_index | ||||
|   ); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===// | ||||
| // LMHLO ops representing convolution library functions. | ||||
| //===----------------------------------------------------------------------===// | ||||
| 
 | ||||
| def ActivationModeNone : StrEnumAttrCase<"None">; | ||||
| def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">; | ||||
| def ActivationModeTanh : StrEnumAttrCase<"Tanh">; | ||||
| def ActivationModeRelu : StrEnumAttrCase<"Relu">; | ||||
| def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">; | ||||
| def ActivationModeReluX : StrEnumAttrCase<"ReluX">; | ||||
| def ActivationModeBandPass : StrEnumAttrCase<"BandPass">; | ||||
| 
 | ||||
| def ActivationAttr : StrEnumAttr<"Activation", | ||||
|     "Activation applied with fused convolution", | ||||
|     [ActivationModeNone,  ActivationModeSigmoid, ActivationModeTanh, | ||||
|      ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, | ||||
|      ActivationModeBandPass]>; | ||||
| 
 | ||||
| def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig", | ||||
|                                           LHLO_GPU_Dialect, [ | ||||
|    StructFieldAttr<"algorithm", I64Attr>, | ||||
|    StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> { | ||||
|    let description = "GPU Convolution backend configuration"; | ||||
| } | ||||
| 
 | ||||
| def GpuConvolutionAttributes { | ||||
|   dag attributes = !con( | ||||
|     ConvolutionAttributes<LHLO_GPU_Dialect>.attributes, | ||||
|     (ins F64Attr:$result_scale), | ||||
|     (ins ConvolutionBackendConfigAttr:$backend_config)); | ||||
| } | ||||
| 
 | ||||
| def GpuFusedConvolutionAttributes { | ||||
|   dag attributes = !con( | ||||
|     ConvolutionAttributes<LHLO_GPU_Dialect>.attributes, | ||||
|     (ins F64Attr:$result_scale, | ||||
|          ActivationAttr:$activation_mode, | ||||
|          F64Attr:$side_input_scale), | ||||
|     (ins ConvolutionBackendConfigAttr:$backend_config)); | ||||
| } | ||||
| 
 | ||||
| def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> { | ||||
|   let arguments = !con( | ||||
|     (ins | ||||
|        Arg<LHLO_Buffer, "", [MemRead]>:$input, | ||||
|        Arg<LHLO_Buffer, "", [MemRead]>:$filter, | ||||
|        Arg<LHLO_Buffer, "", [MemWrite]>:$output, | ||||
|        Arg<UntypedBuffer, "", [MemWrite]>:$scratch), | ||||
|      GpuConvolutionAttributes.attributes); | ||||
| } | ||||
| 
 | ||||
| def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> { | ||||
|   let arguments = !con( | ||||
|     (ins | ||||
|        Arg<LHLO_Buffer, "", [MemRead]>:$d_output, | ||||
|        Arg<LHLO_Buffer, "", [MemRead]>:$filter, | ||||
|        Arg<LHLO_Buffer, "", [MemWrite]>:$d_input, | ||||
|        Arg<UntypedBuffer, "", [MemWrite]>:$scratch), | ||||
|      GpuConvolutionAttributes.attributes); | ||||
| } | ||||
| 
 | ||||
| def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> { | ||||
|   let arguments = !con( | ||||
|     (ins | ||||
|        Arg<LHLO_Buffer, "", [MemRead]>:$input, | ||||
|        Arg<LHLO_Buffer, "", [MemRead]>:$d_output, | ||||
|        Arg<LHLO_Buffer, "", [MemWrite]>:$d_filter, | ||||
|        Arg<UntypedBuffer, "", [MemWrite]>:$scratch), | ||||
|      GpuConvolutionAttributes.attributes); | ||||
| } | ||||
| 
 | ||||
| // output = activation(result_scale * conv(input, filter) + | ||||
| //                     side_input * side_input_scale + | ||||
| //                     bias) | ||||
| def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> { | ||||
|   let arguments = !con( | ||||
|     (ins | ||||
|        Arg<LHLO_Buffer, "", [MemRead]>:$input, | ||||
|        Arg<LHLO_Buffer, "", [MemRead]>:$filter, | ||||
|        Arg<LHLO_Buffer, "", [MemRead]>:$bias, | ||||
|        Arg<LHLO_Buffer, "", [MemRead]>:$side_input, | ||||
|        Arg<LHLO_Buffer, "", [MemWrite]>:$output, | ||||
|        Arg<UntypedBuffer, "", [MemWrite]>:$scratch), | ||||
|      GpuFusedConvolutionAttributes.attributes); | ||||
| } | ||||
| 
 | ||||
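A sketch of the fused form in generic MLIR syntax, modeled on the conv_forward unit test added in this change and extended with the fused operands and attributes; the bias and side_input shapes and all attribute values are illustrative assumptions only (the op has no verifier yet):

    func @conv_forward_fused(%input: memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>,
                             %bias: memref<1xf16>, %side_input: memref<1x1x7x7xf16>,
                             %output: memref<1x1x7x7xf16>) {
      %scratch = alloc() : memref<32xi8>
      "lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %side_input, %output, %scratch)
        { dimension_numbers = {input_batch_dimension = 0 : i64,
                               input_feature_dimension = 1 : i64,
                               input_spatial_dimensions = dense<[2,3]> : tensor<2xi64>,
                               kernel_input_feature_dimension = 0 : i64,
                               kernel_output_feature_dimension = 1 : i64,
                               kernel_spatial_dimensions = dense<[2,3]> : tensor<2xi64>,
                               output_batch_dimension = 0 : i64,
                               output_feature_dimension = 1 : i64,
                               output_spatial_dimensions = dense<[2,3]> : tensor<2xi64>},
          window_strides = dense<[1, 1]> : tensor<2xi64>,
          padding = dense<[0, 0]> : tensor<2xi64>,
          lhs_dilation = dense<[1, 1]> : tensor<2xi64>,
          rhs_dilation = dense<[1, 1]> : tensor<2xi64>,
          feature_group_count = 1,
          batch_group_count = 1,
          result_scale = 1.0,
          activation_mode = "Relu",
          side_input_scale = 1.0,
          backend_config = {algorithm = 0, tensor_ops_enabled = true} }
        : (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1xf16>, memref<1x1x7x7xf16>,
           memref<1x1x7x7xf16>, memref<32xi8>) -> ()
      return
    }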
| //===----------------------------------------------------------------------===// | ||||
| // LMHLO ops representing other library functions. | ||||
| //===----------------------------------------------------------------------===// | ||||
| 
 | ||||
| // TODO(jurahul): Share this with the MHLO dialect. | ||||
| def DotDimensionNumbersAttr : StructAttr<"DotDimensionNumbers", LHLO_GPU_Dialect, [ | ||||
|                 StructFieldAttr<"lhs_batching_dimensions",   I64ElementsAttr>, | ||||
|                 StructFieldAttr<"rhs_batching_dimensions",   I64ElementsAttr>, | ||||
|                 StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, | ||||
|                 StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> | ||||
|   ]> { | ||||
|   let description = "Structure of dimension information for dot product"; | ||||
| } | ||||
| 
 | ||||
| // output = alpha * (lhs * rhs) | ||||
| // Verify: beta = 0.0 | ||||
| def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { | ||||
|   let arguments = (ins | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$lhs, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$rhs, | ||||
|     Arg<LHLO_Buffer, "", [MemWrite]>:$output, | ||||
|     DotDimensionNumbersAttr:$dot_dimension_numbers, | ||||
|     F64Attr:$alpha, | ||||
|     I64Attr:$batch_size, | ||||
|     I64Attr:$algorithm); | ||||
| } | ||||
| 
 | ||||
| // output = alpha * (lhs * rhs) + beta * bias | ||||
| def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> { | ||||
|   let arguments = (ins | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$lhs, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$rhs, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$bias, | ||||
|     Arg<LHLO_Buffer, "", [MemWrite]>:$output, | ||||
|     DotDimensionNumbersAttr:$dot_dimension_numbers, | ||||
|     F64Attr:$alpha, | ||||
|     F64Attr:$beta, | ||||
|     I64Attr:$batch_size, | ||||
|     I64Attr:$algorithm); | ||||
| } | ||||
| 
 | ||||
| def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { | ||||
|   let arguments = (ins | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$input, | ||||
|     Arg<LHLO_Buffer, "", [MemWrite]>:$output, | ||||
|     Arg<UntypedBuffer, "", [MemWrite]>:$scratch, | ||||
|     Arg<I32Buffer, "", [MemWrite]>:$info, | ||||
|     BoolAttr:$is_upper); | ||||
| } | ||||
| 
 | ||||
| #endif // LHLO_GPU_OPS | ||||
|  | @ -13,7 +13,7 @@ See the License for the specific language governing permissions and | |||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| // This file defines the operations used in the LXLA dialect.
 | ||||
| // This file defines the operations used in the LHLO dialect.
 | ||||
| 
 | ||||
| #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ | ||||
| #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ | ||||
|  |  | |||
|  | @ -66,6 +66,14 @@ add_mlir_dialect_library(LmhloDialect | |||
| ) | ||||
| target_link_libraries(LmhloDialect PUBLIC MLIRIR) | ||||
| 
 | ||||
| add_mlir_dialect_library(LmhloGPUDialect | ||||
|   lhlo_gpu_ops.cc | ||||
| 
 | ||||
|   DEPENDS | ||||
|   MLIRlhlo_gpu_opsIncGen | ||||
| ) | ||||
| target_link_libraries(LmhloGPUDialect PUBLIC MLIRIR) | ||||
| 
 | ||||
| 
 | ||||
| add_mlir_dialect_library(MhloRegisterDialects | ||||
|   init.cc | ||||
|  | @ -73,10 +81,12 @@ DEPENDS | |||
|   MLIRchlo_opsIncGen | ||||
|   MLIRhlo_opsIncGen | ||||
|   MLIRlhlo_opsIncGen | ||||
|   MLIRlhlo_gpu_opsIncGen | ||||
| ) | ||||
| target_link_libraries(MhloRegisterDialects | ||||
|   PUBLIC | ||||
|   ChloDialect | ||||
|   MhloDialect | ||||
|   LmhloDialect | ||||
|   LmhloGPUDialect | ||||
| ) | ||||
|  |  | |||
|  | @ -15,13 +15,15 @@ limitations under the License. | |||
| 
 | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/register.h" | ||||
| 
 | ||||
| void mlir::mhlo::registerAllMhloDialects(mlir::DialectRegistry ®istry) { | ||||
|   // clang-format off
 | ||||
|   registry.insert<mlir::chlo::HloClientDialect, | ||||
|                   mlir::mhlo::MhloDialect, | ||||
|                   mlir::lmhlo::LmhloDialect, | ||||
|                   mlir::mhlo::MhloDialect>(); | ||||
|                   mlir::lmhlo_gpu::LmhloGpuDialect>(); | ||||
|   // clang-format on
 | ||||
| } | ||||
|  |  | |||
|  | @ -0,0 +1,66 @@ | |||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| // This file defines the operations used in the LMHLO GPU dialect.
 | ||||
| 
 | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" | ||||
| 
 | ||||
| #include <assert.h> | ||||
| #include <stddef.h> | ||||
| #include <stdint.h> | ||||
| 
 | ||||
| #include "llvm/ADT/APFloat.h" | ||||
| #include "llvm/ADT/APInt.h" | ||||
| #include "llvm/ADT/ArrayRef.h" | ||||
| #include "llvm/ADT/STLExtras.h" | ||||
| #include "llvm/ADT/SmallVector.h" | ||||
| #include "llvm/ADT/StringRef.h" | ||||
| #include "llvm/Support/FormatVariadic.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc" | ||||
| #include "mlir/Dialect/StandardOps/IR/Ops.h" | ||||
| #include "mlir/IR/Attributes.h" | ||||
| #include "mlir/IR/Builders.h" | ||||
| #include "mlir/IR/Dialect.h" | ||||
| #include "mlir/IR/Location.h" | ||||
| #include "mlir/IR/MLIRContext.h" | ||||
| #include "mlir/IR/OpDefinition.h" | ||||
| #include "mlir/IR/OpImplementation.h" | ||||
| #include "mlir/IR/Operation.h" | ||||
| #include "mlir/IR/OperationSupport.h" | ||||
| #include "mlir/IR/PatternMatch.h" | ||||
| #include "mlir/IR/StandardTypes.h" | ||||
| #include "mlir/IR/TypeUtilities.h" | ||||
| #include "mlir/IR/Types.h" | ||||
| #include "mlir/IR/Value.h" | ||||
| 
 | ||||
| namespace mlir { | ||||
| namespace lmhlo_gpu { | ||||
| 
 | ||||
| LmhloGpuDialect::LmhloGpuDialect(MLIRContext *context) | ||||
|     : Dialect(getDialectNamespace(), context, TypeID::get<LmhloGpuDialect>()) { | ||||
|   addOperations< | ||||
| #define GET_OP_LIST | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc" | ||||
|       >(); | ||||
| } | ||||
| 
 | ||||
| // TODO(jurahul): Add verification for operand shapes and ranks.
 | ||||
| 
 | ||||
| }  // namespace lmhlo_gpu
 | ||||
| }  // namespace mlir
 | ||||
| 
 | ||||
| #define GET_OP_CLASSES | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc" | ||||
|  | @ -0,0 +1,99 @@ | |||
| // RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s | ||||
| 
 | ||||
| // CHECK-LABEL: func @batch_norm_grad_memrefs | ||||
| func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, | ||||
|                               %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>, | ||||
|                               %grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>, | ||||
|                               %grad_offset: memref<8xf32>) -> () { | ||||
|   "lmhlo_gpu.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} | ||||
|       : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, | ||||
|          memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: func @batch_norm_inference_memrefs | ||||
| func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, | ||||
|                                    %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () { | ||||
|   "lmhlo_gpu.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} | ||||
|       : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> () | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: func @batch_norm_training_memrefs | ||||
| func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, | ||||
|                                   %output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>, | ||||
|                                   %batch_var: memref<8xf32>) -> () { | ||||
|   "lmhlo_gpu.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} | ||||
|       : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: func @conv_forward | ||||
| func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) { | ||||
|   %scratch = alloc() : memref<32xi8> | ||||
|   // This defines a 2D convolution over an 8x8 single-channel input using a | ||||
|   // 2x2 filter, producing a 7x7xf16 output. The 1x1x8x8 shape is (N, C, H, W). | ||||
|   "lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch) | ||||
|     { dimension_numbers = {input_batch_dimension = 0 : i64, | ||||
|                            input_feature_dimension = 1 : i64, | ||||
|                            input_spatial_dimensions = dense<[2,3]> : tensor<2xi64>, | ||||
|                            kernel_input_feature_dimension = 0 : i64, | ||||
|                            kernel_output_feature_dimension = 1 : i64, | ||||
|                            kernel_spatial_dimensions = dense<[2,3]> : tensor<2xi64>, | ||||
|                            output_batch_dimension = 0 : i64, | ||||
|                            output_feature_dimension = 1 : i64, | ||||
|                            output_spatial_dimensions = dense<[2,3]> : tensor<2xi64>}, | ||||
|       window_strides = dense<[1, 1]> : tensor<2xi64>, | ||||
|       padding = dense<[0,0]> : tensor<2xi64>, | ||||
|       lhs_dilation = dense<[1,1]> : tensor<2xi64>, | ||||
|       rhs_dilation = dense<[1,1]> : tensor<2xi64>, | ||||
|       feature_group_count = 1, | ||||
|       batch_group_count = 1, | ||||
|       result_scale = 1.0, | ||||
|       backend_config = {algorithm=0, tensor_ops_enabled = true } | ||||
|     } | ||||
|     : (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> () | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @gemm | ||||
| func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) { | ||||
|   "lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = { | ||||
|        lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||
|        rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||
|        lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||
|        rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, | ||||
|        alpha = 0.5, | ||||
|        batch_size = 1, | ||||
|        algorithm = 0} | ||||
|     : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> () | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| // CHECK-LABEL: func @gemm_bias | ||||
| func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, | ||||
|                 %bias: memref<5x5xf32>, %output:memref<5x5xf32>) { | ||||
|   "lmhlo_gpu.gemm_bias"(%lhs, %rhs, %bias, %output) { dot_dimension_numbers = { | ||||
|        lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||
|        rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||
|        lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||
|        rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, | ||||
|        alpha = 0.5, | ||||
|        beta = 1.0, | ||||
|        batch_size = 1, | ||||
|        algorithm = 0} | ||||
|     : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>, memref<5x5xf32>) -> () | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: func @cholesky | ||||
| func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) { | ||||
|   %scratch = alloc() : memref<32xi8> | ||||
|   %info = alloc() : memref<32xi32> | ||||
|   "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true } | ||||
|       : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> () | ||||
|   return | ||||
| } | ||||
|  | @ -15,6 +15,7 @@ limitations under the License. | |||
| 
 | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" | ||||
| #include "mlir/InitAllDialects.h" | ||||
|  | @ -31,6 +32,7 @@ int main(int argc, char **argv) { | |||
|   registry.insert<mlir::mhlo::MhloDialect>(); | ||||
|   registry.insert<mlir::chlo::HloClientDialect>(); | ||||
|   registry.insert<mlir::lmhlo::LmhloDialect>(); | ||||
|   registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>(); | ||||
| 
 | ||||
|   return failed( | ||||
|       mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); | ||||
|  |  | |||