Add GPU specific LMHLO level ops
- Introduce operations in a new lmhlo_gpu dialect that map to GPU library function calls in the XLA:GPU backend. - Add basic unit tests as well. PiperOrigin-RevId: 337132166
This commit is contained in:
parent
8506f1f26a
commit
f6b4e6758a
|
@ -27,5 +27,6 @@ endfunction()
|
|||
add_mlir_hlo_dialect(chlo_ops chlo)
|
||||
add_mlir_hlo_dialect(hlo_ops mhlo)
|
||||
add_mlir_hlo_dialect(lhlo_ops lmhlo)
|
||||
add_mlir_hlo_dialect(lhlo_gpu_ops lmhlo_gpu)
|
||||
|
||||
add_mlir_interface(infer_fusibility_op_interface)
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines the operations used in the LHLO dialect.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/CopyOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
class OpBuilder;
|
||||
} // namespace mlir
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo_gpu {
|
||||
|
||||
class LmhloGpuDialect : public Dialect {
|
||||
public:
|
||||
explicit LmhloGpuDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "lmhlo_gpu"; }
|
||||
};
|
||||
|
||||
} // namespace lmhlo_gpu
|
||||
} // end namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc"
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_
|
|
@ -0,0 +1,230 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This is the operation definition file for LHMLO level GPU operations.
|
||||
// Because these are LMHLO level operations, they operate on memrefs.
|
||||
|
||||
#ifndef LHLO_GPU_OPS
|
||||
#define LHLO_GPU_OPS
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
|
||||
|
||||
|
||||
def LHLO_GPU_Dialect : Dialect {
|
||||
let name = "lmhlo_gpu";
|
||||
let cppNamespace = "::mlir::lmhlo_gpu";
|
||||
}
|
||||
|
||||
class LHLOGPU_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<LHLO_GPU_Dialect, mnemonic,
|
||||
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
|
||||
|
||||
// Type for scratch buffers used by GPU library calls (memref<?xi8>)
|
||||
def UntypedBuffer : MemRefRankOf<[I8], [1]>;
|
||||
|
||||
// Cholesky info output buffer type.
|
||||
def I32Buffer : MemRefOf<[I32]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LMHLO ops representing batch norm library functions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Note: these are semantically different from similar LHLO as the GPU library
|
||||
// calls generate or consume standard deviation, whereas LHLO ops generate or
|
||||
// consume variance (= std-dev ^ 2).
|
||||
|
||||
def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">,
|
||||
BASE_HLO_BatchNormGradOp {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$stddev,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand, // gradient of $operand.
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset,
|
||||
F32Attr:$epsilon,
|
||||
I64Attr:$feature_index
|
||||
);
|
||||
}
|
||||
|
||||
def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">,
|
||||
BASE_HLO_BatchNormInferenceOp {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$stddev,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
F32Attr:$epsilon,
|
||||
I64Attr:$feature_index);
|
||||
}
|
||||
|
||||
def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">,
|
||||
BASE_HLO_BatchNormTrainingOp {
|
||||
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_stddev,
|
||||
F32Attr:$epsilon,
|
||||
I64Attr:$feature_index
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LMHLO ops representing convolution library functions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ActivationModeNone : StrEnumAttrCase<"None">;
|
||||
def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">;
|
||||
def ActivationModeTanh : StrEnumAttrCase<"Relu">;
|
||||
def ActivationModeRelu : StrEnumAttrCase<"Relu">;
|
||||
def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">;
|
||||
def ActivationModeReluX : StrEnumAttrCase<"ReluX">;
|
||||
def ActivationModeBandPass : StrEnumAttrCase<"BandPass">;
|
||||
|
||||
def ActivationAttr : StrEnumAttr<"Activation",
|
||||
"Activation applied with fused convolution",
|
||||
[ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh,
|
||||
ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX,
|
||||
ActivationModeBandPass]>;
|
||||
|
||||
def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig",
|
||||
LHLO_GPU_Dialect, [
|
||||
StructFieldAttr<"algorithm", I64Attr>,
|
||||
StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> {
|
||||
let description = "GPU Convolution backend configuration";
|
||||
}
|
||||
|
||||
def GpuConvolutionAttributes {
|
||||
dag attributes = !con(
|
||||
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
|
||||
(ins F64Attr:$result_scale),
|
||||
(ins ConvolutionBackendConfigAttr:$backend_config));
|
||||
}
|
||||
|
||||
def GpuFusedConvolutionAttributes {
|
||||
dag attributes = !con(
|
||||
ConvolutionAttributes<LHLO_GPU_Dialect>.attributes,
|
||||
(ins F64Attr:$result_scale,
|
||||
ActivationAttr:$activation_mode,
|
||||
F64Attr:$side_input_scale),
|
||||
(ins ConvolutionBackendConfigAttr:$backend_config));
|
||||
}
|
||||
|
||||
def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> {
|
||||
let arguments = !con(
|
||||
(ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
|
||||
GpuConvolutionAttributes.attributes);
|
||||
}
|
||||
|
||||
def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
|
||||
let arguments = !con(
|
||||
(ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$d_input,
|
||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
|
||||
GpuConvolutionAttributes.attributes);
|
||||
}
|
||||
|
||||
def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
|
||||
let arguments = !con(
|
||||
(ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$d_filter,
|
||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
|
||||
GpuConvolutionAttributes.attributes);
|
||||
}
|
||||
|
||||
// output = activation(result_scale * conv(input, filter) +
|
||||
// side_input * side_input_scale +
|
||||
// bias)
|
||||
def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
|
||||
let arguments = !con(
|
||||
(ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$side_input,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
|
||||
GpuFusedConvolutionAttributes.attributes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LMHLO ops representing other library functions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(jurahul): Share this with the MHLO dialect.
|
||||
def DotDimensionNumbersAttr : StructAttr<"DotDimensionNumbers", LHLO_GPU_Dialect, [
|
||||
StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
|
||||
]> {
|
||||
let description = "Structure of dimension information for dot product";
|
||||
}
|
||||
|
||||
// output = alpha * (lhs * rhs)
|
||||
// Verify: beta = 0.0
|
||||
def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$output,
|
||||
DotDimensionNumbersAttr:$dot_dimension_numbers,
|
||||
F64Attr:$alpha,
|
||||
I64Attr:$batch_size,
|
||||
I64Attr:$algorithm);
|
||||
}
|
||||
|
||||
// output = alpha(lhs * rhs) + beta * bias
|
||||
def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$output,
|
||||
DotDimensionNumbersAttr:$dot_dimension_numbers,
|
||||
F64Attr:$alpha,
|
||||
F64Attr:$beta,
|
||||
I64Attr:$batch_size,
|
||||
I64Attr:$algorithm);
|
||||
}
|
||||
|
||||
def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch,
|
||||
Arg<I32Buffer, "", [MemWrite]>:$info,
|
||||
BoolAttr:$is_upper);
|
||||
}
|
||||
|
||||
#endif // LHLO_GPU_OPS
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines the operations used in the LXLA dialect.
|
||||
// This file defines the operations used in the LHLO dialect.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
|
||||
|
|
|
@ -66,6 +66,14 @@ add_mlir_dialect_library(LmhloDialect
|
|||
)
|
||||
target_link_libraries(LmhloDialect PUBLIC MLIRIR)
|
||||
|
||||
add_mlir_dialect_library(LmhloGPUDialect
|
||||
lhlo_gpu_ops.cc
|
||||
|
||||
DEPENDS
|
||||
MLIRlhlo_gpu_opsIncGen
|
||||
)
|
||||
target_link_libraries(LmhloGPUDialect PUBLIC MLIRIR)
|
||||
|
||||
|
||||
add_mlir_dialect_library(MhloRegisterDialects
|
||||
init.cc
|
||||
|
@ -73,10 +81,12 @@ DEPENDS
|
|||
MLIRchlo_opsIncGen
|
||||
MLIRhlo_opsIncGen
|
||||
MLIRlhlo_opsIncGen
|
||||
MLIRlhlo_gpu_opsIncGen
|
||||
)
|
||||
target_link_libraries(MhloRegisterDialects
|
||||
PUBLIC
|
||||
ChloDialect
|
||||
MhloDialect
|
||||
LmhloDialect
|
||||
LmhloGPUDialect
|
||||
)
|
||||
|
|
|
@ -15,13 +15,15 @@ limitations under the License.
|
|||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
|
||||
|
||||
void mlir::mhlo::registerAllMhloDialects(mlir::DialectRegistry ®istry) {
|
||||
// clang-format off
|
||||
registry.insert<mlir::chlo::HloClientDialect,
|
||||
mlir::mhlo::MhloDialect,
|
||||
mlir::lmhlo::LmhloDialect,
|
||||
mlir::mhlo::MhloDialect>();
|
||||
mlir::lmhlo_gpu::LmhloGpuDialect>();
|
||||
// clang-format on
|
||||
}
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines the operations used in the LMHLO GPU dialect.
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lmhlo_gpu {
|
||||
|
||||
LmhloGpuDialect::LmhloGpuDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context, TypeID::get<LmhloGpuDialect>()) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
// TODO(jurahul): Add verification for operand shapes and ranks.
|
||||
|
||||
} // namespace lmhlo_gpu
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc"
|
|
@ -0,0 +1,99 @@
|
|||
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @batch_norm_grad_memrefs
|
||||
func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
|
||||
%arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
|
||||
%grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>,
|
||||
%grad_offset: memref<8xf32>) -> () {
|
||||
"lmhlo_gpu.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
|
||||
memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @batch_norm_inference_memrefs
|
||||
func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
|
||||
%arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
|
||||
"lmhlo_gpu.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @batch_norm_training_memrefs
|
||||
func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
|
||||
%output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>,
|
||||
%batch_var: memref<8xf32>) -> () {
|
||||
"lmhlo_gpu.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @conv_forward
|
||||
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
|
||||
%scratch = alloc() : memref<32xi8>
|
||||
// This defined a 2D convolution over a 8x8 single channel input using a 2x2
|
||||
// filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W)
|
||||
"lmhlo_gpu.conv_forward"(%input, %filter, %output, %scratch)
|
||||
{ dimension_numbers = {input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 1 : i64,
|
||||
input_spatial_dimensions = dense<[2,3]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 0 : i64,
|
||||
kernel_output_feature_dimension = 1 : i64,
|
||||
kernel_spatial_dimensions = dense<[2,3]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 1 : i64,
|
||||
output_spatial_dimensions = dense<[2,3]> : tensor<2xi64>},
|
||||
window_strides = dense<[1, 1]> : tensor<2xi64>,
|
||||
padding = dense<[0,0]> : tensor<2xi64>,
|
||||
lhs_dilation = dense<[1,1]> : tensor<2xi64>,
|
||||
rhs_dilation = dense<[1,1]> : tensor<2xi64>,
|
||||
feature_group_count = 1,
|
||||
batch_group_count = 1,
|
||||
result_scale = 1.0,
|
||||
backend_config = {algorithm=0, tensor_ops_enabled = true }
|
||||
}
|
||||
: (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @gemm
|
||||
func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) {
|
||||
"lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = {
|
||||
lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
|
||||
alpha = 0.5,
|
||||
batch_size = 1,
|
||||
algorithm = 0}
|
||||
: (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @gemm_bias
|
||||
func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
|
||||
%bias: memref<5x5xf32>, %output:memref<5x5xf32>) {
|
||||
"lmhlo_gpu.gemm_bias"(%lhs, %rhs, %bias, %output) { dot_dimension_numbers = {
|
||||
lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
|
||||
alpha = 0.5,
|
||||
beta = 1.0,
|
||||
batch_size = 1,
|
||||
algorithm = 0}
|
||||
: (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>, memref<5x5xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cholesky
|
||||
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
|
||||
%scratch = alloc() : memref<32xi8>
|
||||
%info = alloc() : memref<32xi32>
|
||||
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true }
|
||||
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
|
||||
return
|
||||
}
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
|
@ -31,6 +32,7 @@ int main(int argc, char **argv) {
|
|||
registry.insert<mlir::mhlo::MhloDialect>();
|
||||
registry.insert<mlir::chlo::HloClientDialect>();
|
||||
registry.insert<mlir::lmhlo::LmhloDialect>();
|
||||
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
|
||||
|
||||
return failed(
|
||||
mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry));
|
||||
|
|
Loading…
Reference in New Issue