[HLO] Adopt custom syntax for convolution dims and window attributes for LMHLO_GPU

PiperOrigin-RevId: 374889917
This commit is contained in:
Rahul Joshi 2021-05-20 09:40:28 -07:00 committed by TensorFlow MLIR Team
parent 57aeb5ab16
commit fc88cf1ff4
5 changed files with 85 additions and 72 deletions

1
BUILD
View File

@ -592,6 +592,7 @@ cc_library(
":hlo",
":hlo_ops_base_enums",
":hlo_ops_base_structs",
":hlo_ops_common",
":infer_fusibility_op_interface",
":lhlo_gpu_ops_enums",
":lhlo_gpu_ops_inc_gen",

View File

@ -115,7 +115,19 @@ class GpuConvolutionAttributes<dag extraAttribs> {
(ins ConvolutionBackendConfigAttr:$backend_config));
}
def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> {
// Provide a custom assembly format for all LHLO_GPU convolution operations.
class LHLOGPU_ConvBaseOp<string mnemonic> : LHLOGPU_Op<mnemonic> {
let assemblyFormat = [{
`(`operands`)`
`dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,`
`window` `=` `{` custom<WindowAttributes>($window_strides, $padding,
$lhs_dilation, $rhs_dilation,
$window_reversal) `}`
attr-dict `:` functional-type(operands, results)
}];
}
def LHLOGPU_ConvForwardOp : LHLOGPU_ConvBaseOp<"conv_forward"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
@ -125,7 +137,7 @@ def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> {
GpuConvolutionAttributes<(ins)>.attributes);
}
def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
def LHLOGPU_ConvBackwardInputOp : LHLOGPU_ConvBaseOp<"conv_backwardinput"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
@ -135,7 +147,7 @@ def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
GpuConvolutionAttributes<(ins)>.attributes);
}
def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_ConvBaseOp<"conv_backwardfilter"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
@ -146,7 +158,7 @@ def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
}
// output = activation(result_scale * conv(input, filter) + bias)
def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
def LHLOGPU_ConvForwardFusedOp : LHLOGPU_ConvBaseOp<"conv_forward_fused"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
@ -161,7 +173,8 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
// output = activation(result_scale * conv(input, filter) +
// side_input * side_input_scale +
// bias)
def LHLOGPU_ConvForwardFusedSideInputOp : LHLOGPU_Op<"conv_forward_fused_with_side_input"> {
def LHLOGPU_ConvForwardFusedSideInputOp :
LHLOGPU_ConvBaseOp<"conv_forward_fused_with_side_input"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,

View File

@ -85,8 +85,11 @@ add_mlir_dialect_library(LmhloGPUDialect
DEPENDS
MLIRlhlo_gpu_opsIncGen
)
target_link_libraries(LmhloGPUDialect PUBLIC MLIRIR)
target_link_libraries(LmhloGPUDialect
PUBLIC
MLIRIR
HloOpsCommon
)
add_mlir_dialect_library(MhloRegisterDialects
init.cc

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@ -57,6 +58,9 @@ LmhloGpuDialect::LmhloGpuDialect(MLIRContext *context)
// TODO(jurahul): Add verification for operand shapes and ranks.
using mlir::hlo::parseWindowAttributes;
using mlir::hlo::printWindowAttributes;
} // namespace lmhlo_gpu
} // namespace mlir

View File

@ -1,4 +1,5 @@
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt
/// | FileCheck %s
// CHECK-LABEL: func @batch_norm_grad_memrefs
func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
@ -28,8 +29,11 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
return
}
// CHECK-LABEL: func @conv_forward
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
// CHECK-LABEL: func @conv_forward_generic
// CHECK: lmhlo_gpu.conv_forward
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [1, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
func @conv_forward_generic(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
%scratch = memref.alloc() : memref<32xi8>
// This defined a 2D convolution over a 8x8 single channel input using a 2x2
// filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W)
@ -44,7 +48,7 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %
output_feature_dimension = 1 : i64,
output_spatial_dimensions = dense<[2,3]> : tensor<2xi64>},
window_strides = dense<[1, 1]> : tensor<2xi64>,
padding = dense<[0,0]> : tensor<2xi64>,
padding = dense<[[0, 0], [1, 0]]> : tensor<2x2xi64>,
lhs_dilation = dense<[1,1]> : tensor<2xi64>,
rhs_dilation = dense<[1,1]> : tensor<2xi64>,
feature_group_count = 1,
@ -59,71 +63,80 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %
return
}
// CHECK-LABEL: func @conv_forward
// CHECK: lmhlo_gpu.conv_forward
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [1, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) {
%scratch = memref.alloc() : memref<32xi8>
// This defined a 2D convolution over a 8x8 single channel input using a 2x2
// filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W)
lmhlo_gpu.conv_forward(%input, %filter, %output, %scratch)
dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1],
window = {stride = [1, 1], pad = [[0, 0], [1, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
{ feature_group_count = 1, batch_group_count = 1, result_scale = 1.0,
backend_config = {algorithm=0,
operand_0_layout = [3,2,1,0],
operand_1_layout = [3,2,1,0],
result_layout = [3,2,1,0],
tensor_ops_enabled = true}}
: (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> ()
return
}
// CHECK-LABEL: func @conv_backfilter
// CHECK: lmhlo_gpu.conv_backwardfilter
// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
%scratch = memref.alloc() : memref<23328xui8>
"lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
lmhlo_gpu.conv_backwardfilter(%input, %filter, %output, %scratch)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
{ backend_config = {algorithm = 1 : i64,
operand_0_layout = [3,2,1,0],
operand_1_layout = [3,2,1,0],
result_layout = [3,2,1,0],
tensor_ops_enabled = false},
batch_group_count = 1 : i64,
dimension_numbers = {input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 1 : i64,
lhs_dilation = dense<1> : tensor<2xi64>,
padding = dense<0> : tensor<2xi64>,
precision_config = [],
result_scale = 1.000000e+00 : f64,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>}
result_scale = 1.000000e+00 : f64}
: (memref<3x56x56x16xf64>, memref<3x3x3x64xf64>, memref<54x54x16x64xf64>, memref<23328xui8>) -> ()
return
}
// CHECK-LABEL: func @conv_backinput
// CHECK: lmhlo_gpu.conv_backwardinput
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[3, 0], [1, 5]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]}
func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
%scratch = memref.alloc() : memref<32xui8>
"lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
lmhlo_gpu.conv_backwardinput(%input, %filter, %output, %scratch)
dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
window = {stride = [1, 1], pad = [[3, 0], [1, 5]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]}
{ backend_config = {algorithm = 1 : i64,
operand_0_layout = [3,2,1,0],
operand_1_layout = [3,2,1,0],
result_layout = [3,2,1,0],
tensor_ops_enabled = false},
batch_group_count = 1 : i64,
dimension_numbers = {input_batch_dimension = 0 : i64,
input_feature_dimension = 1 : i64,
input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
kernel_input_feature_dimension = 1 : i64,
kernel_output_feature_dimension = 0 : i64,
kernel_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 1 : i64,
output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
feature_group_count = 1 : i64,
lhs_dilation = dense<1> : tensor<2xi64>,
padding = dense<3> : tensor<2xi64>,
precision_config = [],
result_scale = 1.000000e+00 : f64,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>,
window_reversal = dense<true>: tensor<2xi1>}
result_scale = 1.000000e+00 : f64}
: (memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, memref<4x3x16x16xf64>, memref<32xui8>) -> ()
return
}
// CHECK-LABEL: func @conv_fused
// CHECK: lmhlo_gpu.conv_forward_fused
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
%scratch = memref.alloc() : memref<32xui8>
"lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
lmhlo_gpu.conv_forward_fused(%input, %filter, %bias, %output, %scratch)
dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1],
window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
{activation_mode = "Relu",
backend_config = {algorithm = 1 : i64,
operand_0_layout = [3,2,1,0],
@ -131,30 +144,22 @@ func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>,
result_layout = [3,2,1,0],
tensor_ops_enabled = false},
batch_group_count = 1 : i64,
dimension_numbers = {input_batch_dimension = 0 : i64,
input_feature_dimension = 1 : i64,
input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 1 : i64,
output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
feature_group_count = 1 : i64,
lhs_dilation = dense<1> : tensor<2xi64>,
padding = dense<1> : tensor<2xi64>,
precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"],
result_scale = 1.000000e+00 : f64,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>}
result_scale = 1.000000e+00 : f64}
: (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<32xui8>) -> ()
return
}
// CHECK-LABEL: func @conv_fused_side_input
// CHECK: lmhlo_gpu.conv_forward_fused_with_side_input
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) {
%scratch = memref.alloc() : memref<0xui8>
"lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
lmhlo_gpu.conv_forward_fused_with_side_input(%input, %filter, %bias, %side_input, %output, %scratch)
dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1],
window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
{activation_mode = "Relu",
backend_config = {algorithm = 1 : i64,
operand_0_layout = [3,2,1,0],
@ -162,23 +167,10 @@ func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x
result_layout = [3,2,1,0],
tensor_ops_enabled = false},
batch_group_count = 1 : i64,
dimension_numbers = {input_batch_dimension = 0 : i64,
input_feature_dimension = 1 : i64,
input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 1 : i64,
output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
feature_group_count = 1 : i64,
lhs_dilation = dense<1> : tensor<2xi64>,
padding = dense<1> : tensor<2xi64>,
precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"],
result_scale = 1.000000e+00 : f64,
rhs_dilation = dense<1> : tensor<2xi64>,
side_input_scale = 1.000000e+00 : f64,
window_strides = dense<1> : tensor<2xi64>}
side_input_scale = 1.000000e+00 : f64}
: (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<0xui8>) -> ()
return
}