diff --git a/BUILD b/BUILD
index 5a64899..999c0de 100644
--- a/BUILD
+++ b/BUILD
@@ -592,6 +592,7 @@ cc_library(
         ":hlo",
         ":hlo_ops_base_enums",
         ":hlo_ops_base_structs",
+        ":hlo_ops_common",
         ":infer_fusibility_op_interface",
         ":lhlo_gpu_ops_enums",
         ":lhlo_gpu_ops_inc_gen",
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
index 791d11d..f087a99 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td
@@ -115,7 +115,19 @@ class GpuConvolutionAttributes<dag extraAttribs> {
     (ins ConvolutionBackendConfigAttr:$backend_config));
 }
 
-def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> {
+// Provide a custom assembly format for all LHLO_GPU convolution operations.
+class LHLOGPU_ConvBaseOp<string mnemonic> : LHLOGPU_Op<mnemonic> {
+  let assemblyFormat = [{
+    `(`operands`)`
+       `dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,`
+       `window` `=` `{` custom<WindowAttributes>($window_strides, $padding,
+                                                 $lhs_dilation, $rhs_dilation,
+                                                 $window_reversal) `}`
+       attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+def LHLOGPU_ConvForwardOp : LHLOGPU_ConvBaseOp<"conv_forward"> {
   let arguments = !con(
     (ins
        Arg<LHLO_Buffer, "", [MemRead]>:$input,
@@ -125,7 +137,7 @@ def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> {
      GpuConvolutionAttributes<(ins)>.attributes);
 }
 
-def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
+def LHLOGPU_ConvBackwardInputOp : LHLOGPU_ConvBaseOp<"conv_backwardinput"> {
   let arguments = !con(
     (ins
        Arg<LHLO_Buffer, "", [MemRead]>:$d_output,
@@ -135,7 +147,7 @@ def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> {
      GpuConvolutionAttributes<(ins)>.attributes);
 }
 
-def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
+def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_ConvBaseOp<"conv_backwardfilter"> {
   let arguments = !con(
     (ins
        Arg<LHLO_Buffer, "", [MemRead]>:$input,
@@ -146,7 +158,7 @@ def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> {
 }
 
 // output = activation(result_scale * conv(input, filter) + bias)
-def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
+def LHLOGPU_ConvForwardFusedOp : LHLOGPU_ConvBaseOp<"conv_forward_fused"> {
   let arguments = !con(
     (ins
        Arg<LHLO_Buffer, "", [MemRead]>:$input,
@@ -161,7 +173,8 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> {
 // output = activation(result_scale * conv(input, filter) +
 //                     side_input * side_input_scale +
 //                     bias)
-def LHLOGPU_ConvForwardFusedSideInputOp : LHLOGPU_Op<"conv_forward_fused_with_side_input"> {
+def LHLOGPU_ConvForwardFusedSideInputOp :
+    LHLOGPU_ConvBaseOp<"conv_forward_fused_with_side_input"> {
   let arguments = !con(
     (ins
        Arg<LHLO_Buffer, "", [MemRead]>:$input,
diff --git a/lib/Dialect/mhlo/IR/CMakeLists.txt b/lib/Dialect/mhlo/IR/CMakeLists.txt
index 35019fd..6786de6 100644
--- a/lib/Dialect/mhlo/IR/CMakeLists.txt
+++ b/lib/Dialect/mhlo/IR/CMakeLists.txt
@@ -85,8 +85,11 @@ add_mlir_dialect_library(LmhloGPUDialect
 
   DEPENDS
   MLIRlhlo_gpu_opsIncGen
 )
-target_link_libraries(LmhloGPUDialect PUBLIC MLIRIR)
-
+target_link_libraries(LmhloGPUDialect
+  PUBLIC
+  MLIRIR
+  HloOpsCommon
+)
 add_mlir_dialect_library(MhloRegisterDialects
   init.cc
diff --git a/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc b/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
index 572cc43..42c97ac 100644
--- a/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
+++ b/lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -57,6 +58,9 @@ LmhloGpuDialect::LmhloGpuDialect(MLIRContext *context) // TODO(jurahul): Add verification for operand shapes and ranks. +using mlir::hlo::parseWindowAttributes; +using mlir::hlo::printWindowAttributes; + } // namespace lmhlo_gpu } // namespace mlir diff --git a/tests/lhlo_gpu_ops.mlir b/tests/lhlo_gpu_ops.mlir index 4ffd0f4..2db1817 100644 --- a/tests/lhlo_gpu_ops.mlir +++ b/tests/lhlo_gpu_ops.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt +/// | FileCheck %s // CHECK-LABEL: func @batch_norm_grad_memrefs func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, @@ -28,8 +29,11 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3 return } -// CHECK-LABEL: func @conv_forward -func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) { +// CHECK-LABEL: func @conv_forward_generic +// CHECK: lmhlo_gpu.conv_forward +// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1] +// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [1, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} +func @conv_forward_generic(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) { %scratch = memref.alloc() : memref<32xi8> // This defined a 2D convolution over a 8x8 single channel input using a 2x2 // filter and with an output of 7x7xf16. The 1x1x8x8 is (N, C, H, W) @@ -44,7 +48,7 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, % output_feature_dimension = 1 : i64, output_spatial_dimensions = dense<[2,3]> : tensor<2xi64>}, window_strides = dense<[1, 1]> : tensor<2xi64>, - padding = dense<[0,0]> : tensor<2xi64>, + padding = dense<[[0, 0], [1, 0]]> : tensor<2x2xi64>, lhs_dilation = dense<[1,1]> : tensor<2xi64>, rhs_dilation = dense<[1,1]> : tensor<2xi64>, feature_group_count = 1, @@ -59,71 +63,80 @@ func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, % return } +// CHECK-LABEL: func @conv_forward +// CHECK: lmhlo_gpu.conv_forward +// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1] +// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [1, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} +func @conv_forward(%input : memref<1x1x8x8xf16>, %filter: memref<1x1x2x2xf16>, %output: memref<1x1x7x7xf16>) { + %scratch = memref.alloc() : memref<32xi8> + // This defined a 2D convolution over a 8x8 single channel input using a 2x2 + // filter and with an output of 7x7xf16. 
+  lmhlo_gpu.conv_forward(%input, %filter, %output, %scratch)
+    dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1],
+    window = {stride = [1, 1], pad = [[0, 0], [1, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+    { feature_group_count = 1, batch_group_count = 1, result_scale = 1.0,
+      backend_config = {algorithm=0,
+                        operand_0_layout = [3,2,1,0],
+                        operand_1_layout = [3,2,1,0],
+                        result_layout = [3,2,1,0],
+                        tensor_ops_enabled = true}}
+    : (memref<1x1x8x8xf16>, memref<1x1x2x2xf16>, memref<1x1x7x7xf16>, memref<32xi8>) -> ()
+  return
+}
+
 // CHECK-LABEL: func @conv_backfilter
+// CHECK: lmhlo_gpu.conv_backwardfilter
+// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
+// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
 func @conv_backfilter(%input : memref<3x56x56x16xf64>, %filter: memref<3x3x3x64xf64>, %output: memref<54x54x16x64xf64>) {
   %scratch = memref.alloc() : memref<23328xui8>
-  "lmhlo_gpu.conv_backwardfilter"(%input, %filter, %output, %scratch)
+  lmhlo_gpu.conv_backwardfilter(%input, %filter, %output, %scratch)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
     { backend_config = {algorithm = 1 : i64,
                         operand_0_layout = [3,2,1,0],
                         operand_1_layout = [3,2,1,0],
                         result_layout = [3,2,1,0],
                         tensor_ops_enabled = false},
      batch_group_count = 1 : i64,
-     dimension_numbers = {input_batch_dimension = 0 : i64,
-                          input_feature_dimension = 3 : i64,
-                          input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
-                          kernel_input_feature_dimension = 2 : i64,
-                          kernel_output_feature_dimension = 3 : i64,
-                          kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
-                          output_batch_dimension = 0 : i64,
-                          output_feature_dimension = 3 : i64,
-                          output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
      feature_group_count = 1 : i64,
-     lhs_dilation = dense<1> : tensor<2xi64>,
-     padding = dense<0> : tensor<2xi64>,
      precision_config = [],
-     result_scale = 1.000000e+00 : f64,
-     rhs_dilation = dense<1> : tensor<2xi64>,
-     window_strides = dense<1> : tensor<2xi64>}
+     result_scale = 1.000000e+00 : f64}
     : (memref<3x56x56x16xf64>, memref<3x3x3x64xf64>, memref<54x54x16x64xf64>, memref<23328xui8>) -> ()
   return
 }
 
 // CHECK-LABEL: func @conv_backinput
+// CHECK: lmhlo_gpu.conv_backwardinput
+// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1]
+// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[3, 0], [1, 5]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]}
 func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf64>, %output : memref<4x3x16x16xf64>) {
   %scratch = memref.alloc() : memref<32xui8>
-  "lmhlo_gpu.conv_backwardinput"(%input, %filter, %output, %scratch)
+  lmhlo_gpu.conv_backwardinput(%input, %filter, %output, %scratch)
+    dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
+    window = {stride = [1, 1], pad = [[3, 0], [1, 5]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]}
     { backend_config = {algorithm = 1 : i64,
                         operand_0_layout = [3,2,1,0],
                         operand_1_layout = [3,2,1,0],
                         result_layout = [3,2,1,0],
                         tensor_ops_enabled = false},
      batch_group_count = 1 : i64,
-     dimension_numbers = {input_batch_dimension = 0 : i64,
-                          input_feature_dimension = 1 : i64,
-                          input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
-                          kernel_input_feature_dimension = 1 : i64,
-                          kernel_output_feature_dimension = 0 : i64,
-                          kernel_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
-                          output_batch_dimension = 0 : i64,
-                          output_feature_dimension = 1 : i64,
-                          output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
      feature_group_count = 1 : i64,
-     lhs_dilation = dense<1> : tensor<2xi64>,
-     padding = dense<3> : tensor<2xi64>,
      precision_config = [],
-     result_scale = 1.000000e+00 : f64,
-     rhs_dilation = dense<1> : tensor<2xi64>,
-     window_strides = dense<1> : tensor<2xi64>,
-     window_reversal = dense<true> : tensor<2xi1>}
+     result_scale = 1.000000e+00 : f64}
     : (memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, memref<4x3x16x16xf64>, memref<32xui8>) -> ()
   return
 }
 
 // CHECK-LABEL: func @conv_fused
+// CHECK: lmhlo_gpu.conv_forward_fused
+// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1]
+// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
 func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %output : memref<1x32x9x9xf16>) {
   %scratch = memref.alloc() : memref<32xui8>
-  "lmhlo_gpu.conv_forward_fused"(%input, %filter, %bias, %output, %scratch)
+  lmhlo_gpu.conv_forward_fused(%input, %filter, %bias, %output, %scratch)
+    dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1],
+    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
     {activation_mode = "Relu",
     backend_config = {algorithm = 1 : i64,
                       operand_0_layout = [3,2,1,0],
@@ -131,30 +144,22 @@ func @conv_fused(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>,
                       result_layout = [3,2,1,0],
                       tensor_ops_enabled = false},
     batch_group_count = 1 : i64,
-    dimension_numbers = {input_batch_dimension = 0 : i64,
-                         input_feature_dimension = 1 : i64,
-                         input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
-                         kernel_input_feature_dimension = 2 : i64,
-                         kernel_output_feature_dimension = 3 : i64,
-                         kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
-                         output_batch_dimension = 0 : i64,
-                         output_feature_dimension = 1 : i64,
-                         output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
     feature_group_count = 1 : i64,
-    lhs_dilation = dense<1> : tensor<2xi64>,
-    padding = dense<1> : tensor<2xi64>,
     precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"],
-    result_scale = 1.000000e+00 : f64,
-    rhs_dilation = dense<1> : tensor<2xi64>,
-    window_strides = dense<1> : tensor<2xi64>}
+    result_scale = 1.000000e+00 : f64}
     : (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<32xui8>) -> ()
   return
 }
 
 // CHECK-LABEL: func @conv_fused_side_input
+// CHECK: lmhlo_gpu.conv_forward_fused_with_side_input
+// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1]
+// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
 func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>, %output : memref<1x32x9x9xf16>) {
   %scratch = memref.alloc() : memref<0xui8>
-  "lmhlo_gpu.conv_forward_fused_with_side_input"(%input, %filter, %bias, %side_input, %output, %scratch)
+  lmhlo_gpu.conv_forward_fused_with_side_input(%input, %filter, %bias, %side_input, %output, %scratch)
+    dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1],
+    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
    {activation_mode = "Relu",
    backend_config = {algorithm = 1 : i64,
                      operand_0_layout = [3,2,1,0],
@@ -162,23 +167,10 @@ func @conv_fused_side_input(%input : memref<1x17x9x9xf16>, %filter : memref<3x3x
                      result_layout = [3,2,1,0],
                      tensor_ops_enabled = false},
    batch_group_count = 1 : i64,
-   dimension_numbers = {input_batch_dimension = 0 : i64,
-                        input_feature_dimension = 1 : i64,
-                        input_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>,
-                        kernel_input_feature_dimension = 2 : i64,
-                        kernel_output_feature_dimension = 3 : i64,
-                        kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
-                        output_batch_dimension = 0 : i64,
-                        output_feature_dimension = 1 : i64,
-                        output_spatial_dimensions = dense<[2, 3]> : tensor<2xi64>},
    feature_group_count = 1 : i64,
-   lhs_dilation = dense<1> : tensor<2xi64>,
-   padding = dense<1> : tensor<2xi64>,
    precision_config = ["DEFAULT", "DEFAULT", "DEFAULT"],
    result_scale = 1.000000e+00 : f64,
-   rhs_dilation = dense<1> : tensor<2xi64>,
-   side_input_scale = 1.000000e+00 : f64,
-   window_strides = dense<1> : tensor<2xi64>}
+   side_input_scale = 1.000000e+00 : f64}
    : (memref<1x17x9x9xf16>, memref<3x3x17x32xf16>, memref<32xf16>, memref<32xf16>, memref<1x32x9x9xf16>, memref<0xui8>) -> ()
   return
 }