diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
index 21a25fb..d36e634 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
@@ -92,19 +92,11 @@ def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">,
 // LMHLO ops representing convolution library functions.
 //===----------------------------------------------------------------------===//
 
-def GpuConvolutionAttributes {
+class GpuConvolutionAttributes<dag extraAttribs> {
   dag attributes = !con(
     ConvolutionAttributes.attributes,
     (ins F64Attr:$result_scale),
-    (ins ConvolutionBackendConfigAttr:$backend_config));
-}
-
-def GpuFusedConvolutionAttributes {
-  dag attributes = !con(
-    ConvolutionAttributes.attributes,
-    (ins F64Attr:$result_scale,
-         ActivationAttr:$activation_mode,
-         F64Attr:$side_input_scale),
+    extraAttribs,
     (ins ConvolutionBackendConfigAttr:$backend_config));
 }
 
@@ -114,8 +106,8 @@ def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> {
       Arg:$input,
       Arg:$filter,
       Arg:$output,
-      Arg:$scratch),
-    GpuConvolutionAttributes.attributes);
+      Arg:$scratch),
+    GpuConvolutionAttributes<(ins)>.attributes);
 }
 
 def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
@@ -124,8 +116,8 @@ def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
       Arg:$d_output,
      Arg:$filter,
       Arg:$d_input,
-      Arg:$scratch),
-    GpuConvolutionAttributes.attributes);
+      Arg:$scratch),
+    GpuConvolutionAttributes<(ins)>.attributes);
 }
 
 def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
@@ -134,14 +126,27 @@ def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
       Arg:$input,
       Arg:$d_output,
       Arg:$d_filter,
-      Arg:$scratch),
-    GpuConvolutionAttributes.attributes);
+      Arg:$scratch),
+    GpuConvolutionAttributes<(ins)>.attributes);
+}
+
+// output = activation(result_scale * conv(input, filter) + bias)
+def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
+  let arguments = !con(
+    (ins
+      Arg:$input,
+      Arg:$filter,
+      Arg:$bias,
+      Arg:$output,
+      Arg:$scratch),
+    GpuConvolutionAttributes<(ins
+        ActivationAttr:$activation_mode)>.attributes);
 }
 
 // output = activation(result_scale * conv(input, filter) +
 //                     side_input * side_input_scale +
 //                     bias)
-def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
+def LHLOGPU_ConvForwardFusedSideInputOp : LHLOGPU_Op<"conv_forward_fused_with_side_input"> {
   let arguments = !con(
     (ins
       Arg:$input,
@@ -149,8 +154,10 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
       Arg:$bias,
       Arg:$side_input,
       Arg:$output,
-      Arg:$scratch),
-    GpuFusedConvolutionAttributes.attributes);
+      Arg:$scratch),
+    GpuConvolutionAttributes<(ins
+        ActivationAttr:$activation_mode,
+        F64Attr:$side_input_scale)>.attributes);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td
index 9cd7741..2f5b542 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td
@@ -40,8 +40,6 @@ def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
 
 def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
 
-def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
-
-def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
+def LHLO_Buffer : MemRefOf<[AnyFloat, AnyInteger, AnyComplex]>;
 
 #endif // LHLO_OPS_BASE
diff --git a/tests/lhlo_gpu_ops.mlir b/tests/lhlo_gpu_ops.mlir
index bd5df38..35bf59b 100644
--- a/tests/lhlo_gpu_ops.mlir
+++ b/tests/lhlo_gpu_ops.mlir
@@ -56,7 +56,112 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %
   return
 }
 
-// -----
+// CHECK-LABEL: func @conv_backfilter
+func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
+  %scratch = alloc() : memref<23328xui8>
+  "lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
+    { backend_config = {algorithm = 1 : i64, tensor_ops_enabled = false},
+      batch_group_count = 1 : i64,
+      dimension_numbers = {input_batch_dimension = 0 : i64,
+                           input_feature_dimension = 3 : i64,
+                           input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+                           kernel_input_feature_dimension = 2 : i64,
+                           kernel_output_feature_dimension = 3 : i64,
+                           kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+                           output_batch_dimension = 0 : i64,
+                           output_feature_dimension = 3 : i64,
+                           output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
+      feature_group_count = 1 : i64,
+      lhs_dilation = dense<1> : tensor<2xi64>,
+      padding = dense<0> : tensor<2xi64>,
+      precision_config = [],
+      result_scale = 1.000000e+00 : f64,
+      rhs_dilation = dense<1> : tensor<2xi64>,
+      window_strides = dense<1> : tensor<2xi64>}
+    : (memref<3x56x56x16xf64>, memref<3x3x3x64xf64>, memref<54x54x16x64xf64>, memref<23328xui8>) -> ()
+  return
+}
+
+// CHECK-LABEL: func @conv_backinput
+func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
+  %scratch = alloc() : memref<32xui8>
+  "lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
+    { backend_config = {algorithm = 1 : i64, tensor_ops_enabled = false},
+      batch_group_count = 1 : i64,
+      dimension_numbers = {input_batch_dimension = 0 : i64,
+                           input_feature_dimension = 1 : i64,
+                           input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
+                           kernel_input_feature_dimension = 1 : i64,
+                           kernel_output_feature_dimension = 0 : i64,
+                           kernel_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
+                           output_batch_dimension = 0 : i64,
+                           output_feature_dimension = 1 : i64,
+                           output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
+      feature_group_count = 1 : i64,
+      lhs_dilation = dense<1> : tensor<2xi64>,
+      padding = dense<3> : tensor<2xi64>,
+      precision_config = [],
+      result_scale = 1.000000e+00 : f64,
+      rhs_dilation = dense<1> : tensor<2xi64>,
+      window_strides = dense<1> : tensor<2xi64>}
+    : (memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, memref<4x3x16x16xf64>, memref<32xui8>) -> ()
+  return
+}
+
+// CHECK-LABEL: func @conv_fused
+func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
+  %scratch = alloc() : memref<32xui8>
+  "lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
+    {activation_mode = "Relu",
+     backend_config = {algorithm = 0 : i64, tensor_ops_enabled = false},
+     batch_group_count = 1 : i64,
+     dimension_numbers = {input_batch_dimension = 0 : i64,
+                          input_feature_dimension = 1 : i64,
+                          input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
+                          kernel_input_feature_dimension = 2 : i64,
+                          kernel_output_feature_dimension = 3 : i64,
+                          kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+                          output_batch_dimension = 0 : i64,
+                          output_feature_dimension = 1 : i64,
+                          output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
+     feature_group_count = 1 : i64,
+     lhs_dilation = dense<1> : tensor<2xi64>,
+     padding = dense<1> : tensor<2xi64>,
+     precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"],
+     result_scale = 1.000000e+00 : f64,
+     rhs_dilation = dense<1> : tensor<2xi64>,
+     window_strides = dense<1> : tensor<2xi64>}
+    : (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<32xui8>) -> ()
+  return
+}
+
+// CHECK-LABEL: func @conv_fused_side_input
+func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) {
+  %scratch = alloc() : memref<0xui8>
+  "lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
+    {activation_mode = "Relu",
+     backend_config = {algorithm = 0 : i64, tensor_ops_enabled = false},
+     batch_group_count = 1 : i64,
+     dimension_numbers = {input_batch_dimension = 0 : i64,
+                          input_feature_dimension = 1 : i64,
+                          input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
+                          kernel_input_feature_dimension = 2 : i64,
+                          kernel_output_feature_dimension = 3 : i64,
+                          kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+                          output_batch_dimension = 0 : i64,
+                          output_feature_dimension = 1 : i64,
+                          output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
+     feature_group_count = 1 : i64,
+     lhs_dilation = dense<1> : tensor<2xi64>,
+     padding = dense<1> : tensor<2xi64>,
+     precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"],
+     result_scale = 1.000000e+00 : f64,
+     rhs_dilation = dense<1> : tensor<2xi64>,
+     side_input_scale = 1.000000e+00 : f64,
+     window_strides = dense<1> : tensor<2xi64>}
+    : (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<0xui8>) -> ()
+  return
+}
 
 // CHECK-LABEL: func @gemm
 func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>) {