[HLO] Adopt custom syntax for convolution dims and window attributes for LMHLO_GPU
PiperOrigin-RevId: 374889917
This commit is contained in:
parent
57aeb5ab16
commit
fc88cf1ff4
1
BUILD
1
BUILD
|
@ -592,6 +592,7 @@ cc_library(
|
|||
":hlo",
|
||||
":hlo_ops_base_enums",
|
||||
":hlo_ops_base_structs",
|
||||
":hlo_ops_common",
|
||||
":infer_fusibility_op_interface",
|
||||
":lhlo_gpu_ops_enums",
|
||||
":lhlo_gpu_ops_inc_gen",
|
||||
|
|
|
@ -115,7 +115,19 @@ class GpuConvolutionAttributes<dag extraAttribs> {
|
|||
(ins ConvolutionBackendConfigAttr:$backend_config));
|
||||
}
|
||||
|
||||
def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> {
|
||||
// Provide a custom assembly format for all LHLO_GPU convolution operations.
|
||||
class LHLOGPU_ConvBaseOp<string mnemonic> : LHLOGPU_Op<mnemonic> {
|
||||
let assemblyFormat = [{
|
||||
`(`operands`)`
|
||||
`dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,`
|
||||
`window` `=` `{` custom<WindowAttributes>($window_strides, $padding,
|
||||
$lhs_dilation, $rhs_dilation,
|
||||
$window_reversal) `}`
|
||||
attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
}
|
||||
|
||||
def LHLOGPU_ConvForwardOp : LHLOGPU_ConvBaseOp<"conv_forward"> {
|
||||
let arguments = !con(
|
||||
(ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||
|
@ -125,7 +137,7 @@ def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> {
|
|||
GpuConvolutionAttributes<(ins)>.attributes);
|
||||
}
|
||||
|
||||
def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
|
||||
def LHLOGPU_ConvBackwardInputOp : LHLOGPU_ConvBaseOp<"conv_backwardinput"> {
|
||||
let arguments = !con(
|
||||
(ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
|
||||
|
@ -135,7 +147,7 @@ def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
|
|||
GpuConvolutionAttributes<(ins)>.attributes);
|
||||
}
|
||||
|
||||
def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
|
||||
def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_ConvBaseOp<"conv_backwardfilter"> {
|
||||
let arguments = !con(
|
||||
(ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||
|
@ -146,7 +158,7 @@ def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
|
|||
}
|
||||
|
||||
// output = activation(result_scale * conv(input, filter) + bias)
|
||||
def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
|
||||
def LHLOGPU_ConvForwardFusedOp : LHLOGPU_ConvBaseOp<"conv_forward_fused"> {
|
||||
let arguments = !con(
|
||||
(ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||
|
@ -161,7 +173,8 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
|
|||
// output = activation(result_scale * conv(input, filter) +
|
||||
// side_input * side_input_scale +
|
||||
// bias)
|
||||
def LHLOGPU_ConvForwardFusedSideInputOp : LHLOGPU_Op<"conv_forward_fused_with_side_input"> {
|
||||
def LHLOGPU_ConvForwardFusedSideInputOp :
|
||||
LHLOGPU_ConvBaseOp<"conv_forward_fused_with_side_input"> {
|
||||
let arguments = !con(
|
||||
(ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||
|
|
|
@ -85,8 +85,11 @@ add_mlir_dialect_library(LmhloGPUDialect
|
|||
DEPENDS
|
||||
MLIRlhlo_gpu_opsIncGen
|
||||
)
|
||||
target_link_libraries(LmhloGPUDialect PUBLIC MLIRIR)
|
||||
|
||||
target_link_libraries(LmhloGPUDialect
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
HloOpsCommon
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MhloRegisterDialects
|
||||
init.cc
|
||||
|
|
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -57,6 +58,9 @@ LmhloGpuDialect::LmhloGpuDialect(MLIRContext *context)
|
|||
|
||||
// TODO(jurahul): Add verification for operand shapes and ranks.
|
||||
|
||||
using mlir::hlo::parseWindowAttributes;
|
||||
using mlir::hlo::printWindowAttributes;
|
||||
|
||||
} // namespace lmhlo_gpu
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
|
||||
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt
|
||||
/// | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @batch_norm_grad_memrefs
|
||||
func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
|
||||
|
@ -28,8 +29,11 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @conv_forward
|
||||
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
|
||||
// CHECK-LABEL: func @conv_forward_generic
|
||||
// CHECK: lmhlo_gpu.conv_forward
|
||||
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1]
|
||||
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [1, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
|
||||
func @conv_forward_generic(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
|
||||
%scratch = memref.alloc() : memref<32xi8>
|
||||
// This defined a 2D convolution over a 8x8 single channel input using a 2x2
|
||||
// filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W)
|
||||
|
@ -44,7 +48,7 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %
|
|||
output_feature_dimension = 1 : i64,
|
||||
output_spatial_dimensions = dense<[2,3]> : tensor<2xi64>},
|
||||
window_strides = dense<[1, 1]> : tensor<2xi64>,
|
||||
padding = dense<[0,0]> : tensor<2xi64>,
|
||||
padding = dense<[[0, 0], [1, 0]]> : tensor<2x2xi64>,
|
||||
lhs_dilation = dense<[1,1]> : tensor<2xi64>,
|
||||
rhs_dilation = dense<[1,1]> : tensor<2xi64>,
|
||||
feature_group_count = 1,
|
||||
|
@ -59,71 +63,80 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @conv_forward
|
||||
// CHECK: lmhlo_gpu.conv_forward
|
||||
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1]
|
||||
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [1, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
|
||||
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
|
||||
%scratch = memref.alloc() : memref<32xi8>
|
||||
// This defined a 2D convolution over a 8x8 single channel input using a 2x2
|
||||
// filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W)
|
||||
lmhlo_gpu.conv_forward(%input, %filter, %output, %scratch)
|
||||
dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1],
|
||||
window = {stride = [1, 1], pad = [[0, 0], [1, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
|
||||
{ feature_group_count = 1, batch_group_count = 1, result_scale = 1.0,
|
||||
backend_config = {algorithm=0,
|
||||
operand_0_layout = [3,2,1,0],
|
||||
operand_1_layout = [3,2,1,0],
|
||||
result_layout = [3,2,1,0],
|
||||
tensor_ops_enabled = true}}
|
||||
: (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @conv_backfilter
|
||||
// CHECK: lmhlo_gpu.conv_backwardfilter
|
||||
// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
|
||||
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
|
||||
func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
|
||||
%scratch = memref.alloc() : memref<23328xui8>
|
||||
"lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
|
||||
lmhlo_gpu.conv_backwardfilter(%input, %filter, %output, %scratch)
|
||||
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
|
||||
window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
|
||||
{ backend_config = {algorithm = 1 : i64,
|
||||
operand_0_layout = [3,2,1,0],
|
||||
operand_1_layout = [3,2,1,0],
|
||||
result_layout = [3,2,1,0],
|
||||
tensor_ops_enabled = false},
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 2 : i64,
|
||||
kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
|
||||
feature_group_count = 1 : i64,
|
||||
lhs_dilation = dense<1> : tensor<2xi64>,
|
||||
padding = dense<0> : tensor<2xi64>,
|
||||
precision_config = [],
|
||||
result_scale = 1.000000e+00 : f64,
|
||||
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||
window_strides = dense<1> : tensor<2xi64>}
|
||||
result_scale = 1.000000e+00 : f64}
|
||||
: (memref<3x56x56x16xf64>, memref<3x3x3x64xf64>, memref<54x54x16x64xf64>, memref<23328xui8>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @conv_backinput
|
||||
// CHECK: lmhlo_gpu.conv_backwardinput
|
||||
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1]
|
||||
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[3, 0], [1, 5]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]}
|
||||
func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
|
||||
%scratch = memref.alloc() : memref<32xui8>
|
||||
"lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
|
||||
lmhlo_gpu.conv_backwardinput(%input, %filter, %output, %scratch)
|
||||
dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
|
||||
window = {stride = [1, 1], pad = [[3, 0], [1, 5]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]}
|
||||
{ backend_config = {algorithm = 1 : i64,
|
||||
operand_0_layout = [3,2,1,0],
|
||||
operand_1_layout = [3,2,1,0],
|
||||
result_layout = [3,2,1,0],
|
||||
tensor_ops_enabled = false},
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 1 : i64,
|
||||
input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 1 : i64,
|
||||
kernel_output_feature_dimension = 0 : i64,
|
||||
kernel_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 1 : i64,
|
||||
output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
|
||||
feature_group_count = 1 : i64,
|
||||
lhs_dilation = dense<1> : tensor<2xi64>,
|
||||
padding = dense<3> : tensor<2xi64>,
|
||||
precision_config = [],
|
||||
result_scale = 1.000000e+00 : f64,
|
||||
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||
window_strides = dense<1> : tensor<2xi64>,
|
||||
window_reversal = dense<true>: tensor<2xi1>}
|
||||
result_scale = 1.000000e+00 : f64}
|
||||
: (memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, memref<4x3x16x16xf64>, memref<32xui8>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @conv_fused
|
||||
// CHECK: lmhlo_gpu.conv_forward_fused
|
||||
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1]
|
||||
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
|
||||
func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
|
||||
%scratch = memref.alloc() : memref<32xui8>
|
||||
"lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
|
||||
lmhlo_gpu.conv_forward_fused(%input, %filter, %bias, %output, %scratch)
|
||||
dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1],
|
||||
window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
|
||||
{activation_mode = "Relu",
|
||||
backend_config = {algorithm = 1 : i64,
|
||||
operand_0_layout = [3,2,1,0],
|
||||
|
@ -131,30 +144,22 @@ func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>,
|
|||
result_layout = [3,2,1,0],
|
||||
tensor_ops_enabled = false},
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 1 : i64,
|
||||
input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 2 : i64,
|
||||
kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 1 : i64,
|
||||
output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
|
||||
feature_group_count = 1 : i64,
|
||||
lhs_dilation = dense<1> : tensor<2xi64>,
|
||||
padding = dense<1> : tensor<2xi64>,
|
||||
precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"],
|
||||
result_scale = 1.000000e+00 : f64,
|
||||
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||
window_strides = dense<1> : tensor<2xi64>}
|
||||
result_scale = 1.000000e+00 : f64}
|
||||
: (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<32xui8>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @conv_fused_side_input
|
||||
// CHECK: lmhlo_gpu.conv_forward_fused_with_side_input
|
||||
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1]
|
||||
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
|
||||
func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) {
|
||||
%scratch = memref.alloc() : memref<0xui8>
|
||||
"lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
|
||||
lmhlo_gpu.conv_forward_fused_with_side_input(%input, %filter, %bias, %side_input, %output, %scratch)
|
||||
dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1],
|
||||
window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
|
||||
{activation_mode = "Relu",
|
||||
backend_config = {algorithm = 1 : i64,
|
||||
operand_0_layout = [3,2,1,0],
|
||||
|
@ -162,23 +167,10 @@ func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x
|
|||
result_layout = [3,2,1,0],
|
||||
tensor_ops_enabled = false},
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 1 : i64,
|
||||
input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 2 : i64,
|
||||
kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 1 : i64,
|
||||
output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
|
||||
feature_group_count = 1 : i64,
|
||||
lhs_dilation = dense<1> : tensor<2xi64>,
|
||||
padding = dense<1> : tensor<2xi64>,
|
||||
precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"],
|
||||
result_scale = 1.000000e+00 : f64,
|
||||
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||
side_input_scale = 1.000000e+00 : f64,
|
||||
window_strides = dense<1> : tensor<2xi64>}
|
||||
side_input_scale = 1.000000e+00 : f64}
|
||||
: (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<0xui8>) -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue