[MLIR:LHLO_GPU] Add fused convolution operation without any side inputs.
- Add a variant of the fused convolution that does not need a side input or a side input scale.
- Rename the existing one to `ConvForwardFusedSideInputOp`.
- Update tests to exercise all variants of the convolution ops in the GPU dialect.
- Eliminate the unused `LHLO_ExtentBuffer` and change `LHLO_Buffer` to allow any integer element type, matching what XLA sometimes generates for scratch buffers.

PiperOrigin-RevId: 345701569
This commit is contained in:
parent
3691e39f62
commit
e48881af81
|
@ -92,19 +92,11 @@ def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">,
|
||||||
// LMHLO ops representing convolution library functions.
|
// LMHLO ops representing convolution library functions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def GpuConvolutionAttributes {
|
class GpuConvolutionAttributes<dag extraAttribs> {
|
||||||
dag attributes = !con(
|
dag attributes = !con(
|
||||||
ConvolutionAttributes.attributes,
|
ConvolutionAttributes.attributes,
|
||||||
(ins F64Attr:$result_scale),
|
(ins F64Attr:$result_scale),
|
||||||
(ins ConvolutionBackendConfigAttr:$backend_config));
|
extraAttribs,
|
||||||
}
|
|
||||||
|
|
||||||
def GpuFusedConvolutionAttributes {
|
|
||||||
dag attributes = !con(
|
|
||||||
ConvolutionAttributes.attributes,
|
|
||||||
(ins F64Attr:$result_scale,
|
|
||||||
ActivationAttr:$activation_mode,
|
|
||||||
F64Attr:$side_input_scale),
|
|
||||||
(ins ConvolutionBackendConfigAttr:$backend_config));
|
(ins ConvolutionBackendConfigAttr:$backend_config));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,8 +106,8 @@ def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> {
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
|
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
|
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch),
|
||||||
GpuConvolutionAttributes.attributes);
|
GpuConvolutionAttributes<(ins)>.attributes);
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
|
def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
|
||||||
|
@ -124,8 +116,8 @@ def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
|
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
|
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$d_input,
|
Arg<LHLO_Buffer, "", [MemWrite]>:$d_input,
|
||||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
|
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch),
|
||||||
GpuConvolutionAttributes.attributes);
|
GpuConvolutionAttributes<(ins)>.attributes);
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
|
def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
|
||||||
|
@ -134,14 +126,27 @@ def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
|
Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$d_filter,
|
Arg<LHLO_Buffer, "", [MemWrite]>:$d_filter,
|
||||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
|
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch),
|
||||||
GpuConvolutionAttributes.attributes);
|
GpuConvolutionAttributes<(ins)>.attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
// output = activation(result_scale * conv(input, filter) + bias)
|
||||||
|
def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
|
||||||
|
let arguments = !con(
|
||||||
|
(ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$filter,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch),
|
||||||
|
GpuConvolutionAttributes<(ins
|
||||||
|
ActivationAttr:$activation_mode)>.attributes);
|
||||||
}
|
}
|
||||||
|
|
||||||
// output = activation(result_scale * conv(input, filter) +
|
// output = activation(result_scale * conv(input, filter) +
|
||||||
// side_input * side_input_scale +
|
// side_input * side_input_scale +
|
||||||
// bias)
|
// bias)
|
||||||
def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
|
def LHLOGPU_ConvForwardFusedSideInputOp : LHLOGPU_Op<"conv_forward_fused_with_side_input"> {
|
||||||
let arguments = !con(
|
let arguments = !con(
|
||||||
(ins
|
(ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||||
|
@ -149,8 +154,10 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
|
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$side_input,
|
Arg<LHLO_Buffer, "", [MemRead]>:$side_input,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch),
|
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch),
|
||||||
GpuFusedConvolutionAttributes.attributes);
|
GpuConvolutionAttributes<(ins
|
||||||
|
ActivationAttr:$activation_mode,
|
||||||
|
F64Attr:$side_input_scale)>.attributes);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -40,8 +40,6 @@ def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
|
||||||
|
|
||||||
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
|
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
|
||||||
|
|
||||||
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
|
def LHLO_Buffer : MemRefOf<[AnyFloat, AnyInteger, AnyComplex]>;
|
||||||
|
|
||||||
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
|
|
||||||
|
|
||||||
#endif // LHLO_OPS_BASE
|
#endif // LHLO_OPS_BASE
|
||||||
|
|
|
@ -56,7 +56,112 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// CHECK-LABEL: func @conv_backfilter
|
||||||
|
func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
|
||||||
|
%scratch = alloc() : memref<23328xui8>
|
||||||
|
"lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
|
||||||
|
{ backend_config = {algorithm = 1 : i64, tensor_ops_enabled = false},
|
||||||
|
batch_group_count = 1 : i64,
|
||||||
|
dimension_numbers = {input_batch_dimension = 0 : i64,
|
||||||
|
input_feature_dimension = 3 : i64,
|
||||||
|
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
||||||
|
kernel_input_feature_dimension = 2 : i64,
|
||||||
|
kernel_output_feature_dimension = 3 : i64,
|
||||||
|
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||||
|
output_batch_dimension = 0 : i64,
|
||||||
|
output_feature_dimension = 3 : i64,
|
||||||
|
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
|
||||||
|
feature_group_count = 1 : i64,
|
||||||
|
lhs_dilation = dense<1> : tensor<2xi64>,
|
||||||
|
padding = dense<0> : tensor<2xi64>,
|
||||||
|
precision_config = [],
|
||||||
|
result_scale = 1.000000e+00 : f64,
|
||||||
|
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||||
|
window_strides = dense<1> : tensor<2xi64>}
|
||||||
|
: (memref<3x56x56x16xf64>, memref<3x3x3x64xf64>, memref<54x54x16x64xf64>, memref<23328xui8>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @conv_backinput
|
||||||
|
func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
|
||||||
|
%scratch = alloc() : memref<32xui8>
|
||||||
|
"lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
|
||||||
|
{ backend_config = {algorithm = 1 : i64, tensor_ops_enabled = false},
|
||||||
|
batch_group_count = 1 : i64,
|
||||||
|
dimension_numbers = {input_batch_dimension = 0 : i64,
|
||||||
|
input_feature_dimension = 1 : i64,
|
||||||
|
input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
|
||||||
|
kernel_input_feature_dimension = 1 : i64,
|
||||||
|
kernel_output_feature_dimension = 0 : i64,
|
||||||
|
kernel_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
|
||||||
|
output_batch_dimension = 0 : i64,
|
||||||
|
output_feature_dimension = 1 : i64,
|
||||||
|
output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
|
||||||
|
feature_group_count = 1 : i64,
|
||||||
|
lhs_dilation = dense<1> : tensor<2xi64>,
|
||||||
|
padding = dense<3> : tensor<2xi64>,
|
||||||
|
precision_config = [],
|
||||||
|
result_scale = 1.000000e+00 : f64,
|
||||||
|
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||||
|
window_strides = dense<1> : tensor<2xi64>}
|
||||||
|
: (memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, memref<4x3x16x16xf64>, memref<32xui8>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @conv_fused
|
||||||
|
func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
|
||||||
|
%scratch = alloc() : memref<32xui8>
|
||||||
|
"lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
|
||||||
|
{activation_mode = "Relu",
|
||||||
|
backend_config = {algorithm = 0 : i64, tensor_ops_enabled = false},
|
||||||
|
batch_group_count = 1 : i64,
|
||||||
|
dimension_numbers = {input_batch_dimension = 0 : i64,
|
||||||
|
input_feature_dimension = 1 : i64,
|
||||||
|
input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
|
||||||
|
kernel_input_feature_dimension = 2 : i64,
|
||||||
|
kernel_output_feature_dimension = 3 : i64,
|
||||||
|
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||||
|
output_batch_dimension = 0 : i64,
|
||||||
|
output_feature_dimension = 1 : i64,
|
||||||
|
output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
|
||||||
|
feature_group_count = 1 : i64,
|
||||||
|
lhs_dilation = dense<1> : tensor<2xi64>,
|
||||||
|
padding = dense<1> : tensor<2xi64>,
|
||||||
|
precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"],
|
||||||
|
result_scale = 1.000000e+00 : f64,
|
||||||
|
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||||
|
window_strides = dense<1> : tensor<2xi64>}
|
||||||
|
: (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<32xui8>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @conv_fused_side_input
|
||||||
|
func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) {
|
||||||
|
%scratch = alloc() : memref<0xui8>
|
||||||
|
"lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
|
||||||
|
{activation_mode = "Relu",
|
||||||
|
backend_config = {algorithm = 0 : i64, tensor_ops_enabled = false},
|
||||||
|
batch_group_count = 1 : i64,
|
||||||
|
dimension_numbers = {input_batch_dimension = 0 : i64,
|
||||||
|
input_feature_dimension = 1 : i64,
|
||||||
|
input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
|
||||||
|
kernel_input_feature_dimension = 2 : i64,
|
||||||
|
kernel_output_feature_dimension = 3 : i64,
|
||||||
|
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||||
|
output_batch_dimension = 0 : i64,
|
||||||
|
output_feature_dimension = 1 : i64,
|
||||||
|
output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
|
||||||
|
feature_group_count = 1 : i64,
|
||||||
|
lhs_dilation = dense<1> : tensor<2xi64>,
|
||||||
|
padding = dense<1> : tensor<2xi64>,
|
||||||
|
precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"],
|
||||||
|
result_scale = 1.000000e+00 : f64,
|
||||||
|
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||||
|
side_input_scale = 1.000000e+00 : f64,
|
||||||
|
window_strides = dense<1> : tensor<2xi64>}
|
||||||
|
: (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<0xui8>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @gemm
|
// CHECK-LABEL: func @gemm
|
||||||
func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) {
|
func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) {
|
||||||
|
|
Loading…
Reference in New Issue