[MLIR:LHLO_GPU] Add fused convolution operation without any side inputs.

- Add a variant of the fused convolution that does not need a side input and side input scale.
- Rename the existing one to `ConvForwardFusedSideInputOp`.
- Update tests to exercise all variants of the convolution ops in the GPU dialect.
- Eliminate unused `LHLO_ExtentBuffer` and changed LHLO_Buffer to allow any integer element
  type to match what XLA can generate sometimes for scratch buffers.

PiperOrigin-RevId: 345701569
This commit is contained in:
Rahul Joshi 2020-12-04 10:08:37 -08:00 committed by TensorFlow MLIR Team
parent 3691e39f62
commit e48881af81
3 changed files with 133 additions and 23 deletions

View File

@ -92,19 +92,11 @@ def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">,
// LMHLO ops representing convolution library functions.
//===----------------------------------------------------------------------===//
def GpuConvolutionAttributes {
class GpuConvolutionAttributes<dag extraAttribs> {
dag attributes = !con(
ConvolutionAttributes.attributes,
(ins F64Attr:$result_scale),
(ins ConvolutionBackendConfigAttr:$backend_config));
}
def GpuFusedConvolutionAttributes {
dag attributes = !con(
ConvolutionAttributes.attributes,
(ins F64Attr:$result_scale,
ActivationAttr:$activation_mode,
F64Attr:$side_input_scale),
extraAttribs,
(ins ConvolutionBackendConfigAttr:$backend_config));
}
@ -114,8 +106,8 @@ def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> {
Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes.attributes);
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes<(ins)>.attributes);
}
def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
@ -124,8 +116,8 @@ def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
Arg<LHLO_Buffer, "", [MemWrite]>:$d_input,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes.attributes);
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes<(ins)>.attributes);
}
def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
@ -134,14 +126,27 @@ def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
Arg<LHLO_Buffer, "", [MemWrite]>:$d_filter,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes.attributes);
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes<(ins)>.attributes);
}
// output = activation(result_scale * conv(input, filter) + bias)
def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes<(ins
ActivationAttr:$activation_mode)>.attributes);
}
// output = activation(result_scale * conv(input, filter) +
// side_input * side_input_scale +
// bias)
def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
def LHLOGPU_ConvForwardFusedSideInputOp : LHLOGPU_Op<"conv_forward_fused_with_side_input"> {
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
@ -149,8 +154,10 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
Arg<LHLO_Buffer, "", [MemRead]>:$side_input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
GpuFusedConvolutionAttributes.attributes);
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch),
GpuConvolutionAttributes<(ins
ActivationAttr:$activation_mode,
F64Attr:$side_input_scale)>.attributes);
}
//===----------------------------------------------------------------------===//

View File

@ -40,8 +40,6 @@ def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnyInteger, AnyComplex]>;
#endif // LHLO_OPS_BASE

View File

@ -56,7 +56,112 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %
return
}
// -----
// CHECK-LABEL: func @conv_backfilter
func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
%scratch = alloc() : memref<23328xui8>
"lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
{ backend_config = {algorithm = 1 : i64, tensor_ops_enabled = false},
batch_group_count = 1 : i64,
dimension_numbers = {input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 1 : i64,
lhs_dilation = dense<1> : tensor<2xi64>,
padding = dense<0> : tensor<2xi64>,
precision_config = [],
result_scale = 1.000000e+00 : f64,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>}
: (memref<3x56x56x16xf64>, memref<3x3x3x64xf64>, memref<54x54x16x64xf64>, memref<23328xui8>) -> ()
return
}
// CHECK-LABEL: func @conv_backinput
func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
%scratch = alloc() : memref<32xui8>
"lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
{ backend_config = {algorithm = 1 : i64, tensor_ops_enabled = false},
batch_group_count = 1 : i64,
dimension_numbers = {input_batch_dimension = 0 : i64,
input_feature_dimension = 1 : i64,
input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
kernel_input_feature_dimension = 1 : i64,
kernel_output_feature_dimension = 0 : i64,
kernel_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 1 : i64,
output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
feature_group_count = 1 : i64,
lhs_dilation = dense<1> : tensor<2xi64>,
padding = dense<3> : tensor<2xi64>,
precision_config = [],
result_scale = 1.000000e+00 : f64,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>}
: (memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, memref<4x3x16x16xf64>, memref<32xui8>) -> ()
return
}
// CHECK-LABEL: func @conv_fused
func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
%scratch = alloc() : memref<32xui8>
"lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
{activation_mode = "Relu",
backend_config = {algorithm = 0 : i64, tensor_ops_enabled = false},
batch_group_count = 1 : i64,
dimension_numbers = {input_batch_dimension = 0 : i64,
input_feature_dimension = 1 : i64,
input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 1 : i64,
output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
feature_group_count = 1 : i64,
lhs_dilation = dense<1> : tensor<2xi64>,
padding = dense<1> : tensor<2xi64>,
precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"],
result_scale = 1.000000e+00 : f64,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>}
: (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<32xui8>) -> ()
return
}
// CHECK-LABEL: func @conv_fused_side_input
func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) {
%scratch = alloc() : memref<0xui8>
"lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
{activation_mode = "Relu",
backend_config = {algorithm = 0 : i64, tensor_ops_enabled = false},
batch_group_count = 1 : i64,
dimension_numbers = {input_batch_dimension = 0 : i64,
input_feature_dimension = 1 : i64,
input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 1 : i64,
output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
feature_group_count = 1 : i64,
lhs_dilation = dense<1> : tensor<2xi64>,
padding = dense<1> : tensor<2xi64>,
precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"],
result_scale = 1.000000e+00 : f64,
rhs_dilation = dense<1> : tensor<2xi64>,
side_input_scale = 1.000000e+00 : f64,
window_strides = dense<1> : tensor<2xi64>}
: (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<0xui8>) -> ()
return
}
// CHECK-LABEL: func @gemm
func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) {