From 41f663ce47a39c53683c4080b403a476070d2ae2 Mon Sep 17 00:00:00 2001
From: Rahul Joshi
Date: Thu, 20 May 2021 12:12:57 -0700
Subject: [PATCH] [HLO] Adopt custom syntax for convolution dimensions and
 window attributes (HLO)

PiperOrigin-RevId: 374923250
---
 include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td |  9 +++++++
 lib/Dialect/mhlo/IR/hlo_ops.cc              |  3 +++
 tests/ops.mlir                              | 26 +++++++++++++++++++++
 3 files changed, 38 insertions(+)

diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 6c2949e..eceac56 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1409,6 +1409,15 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]> {
                                       [](bool v) { return v; });
     }
   }];
+
+  let assemblyFormat = [{
+    `(`operands`)`
+    `dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,`
+    `window` `=` `{` custom<WindowAttributes>($window_strides, $padding,
+                                              $lhs_dilation, $rhs_dilation,
+                                              $window_reversal) `}`
+    attr-dict `:` functional-type(operands, results)
+  }];
 }
 
 def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]> {
diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index 0373b87..d917ae1 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -3460,6 +3460,9 @@ OpFoldResult ScatterOp::fold(ArrayRef<Attribute> operands) {
   return DenseElementsAttr::get(base_type, results);
 }
 
+using mlir::hlo::parseWindowAttributes;
+using mlir::hlo::printWindowAttributes;
+
 }  // namespace mhlo
 }  // namespace mlir
 
diff --git a/tests/ops.mlir b/tests/ops.mlir
index f67c0ac..a437aa0 100644
--- a/tests/ops.mlir
+++ b/tests/ops.mlir
@@ -1440,3 +1440,29 @@ func @rng_uniform_invalid(%arg0: tensor, %arg1: tensor, %arg2: tensor<
   %0 = "mhlo.rng_uniform"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor<7xi64>) -> tensor
   return
 }
+
+// -----
+// CHECK: func @conv2d_generic
+// CHECK: mhlo.convolution
+// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
+// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+func @conv2d_generic(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
+    {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
+    feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  return %0 : tensor<1x8x8x16xf32>
+}
+
+// CHECK: func @conv2d
+// CHECK: mhlo.convolution
+// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
+// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+func @conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
+  %0 = mhlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = ["DEFAULT", "DEFAULT"]} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
+  return %0 : tensor<1x8x8x16xf32>
+}
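
Usage note: in the custom dim_numbers syntax exercised by the tests above, `b` marks the batch dimension, `f` the feature dimension, `i`/`o` the kernel input/output feature dimensions, and digits name the spatial dimensions; the three bracketed lists describe the input, kernel, and output layouts. The sketch below restates the @conv2d test case with those correspondences spelled out in comments; shapes and attribute values are copied from the test, and the function name is only illustrative.

  // [b, 0, 1, f]: input and output are NHWC (batch, spatial 0, spatial 1, feature).
  // [0, 1, i, o]: kernel is HWIO (spatial 0, spatial 1, input feature, output feature).
  func @conv2d_nhwc_hwio(%lhs: tensor<1x8x8x207xf32>, %rhs: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
    %0 = mhlo.convolution(%lhs, %rhs)
           dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
           // One entry per spatial dimension; pad lists [low, high] padding.
           window = {stride = [1, 1], pad = [[1, 1], [1, 1]],
                     lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
           {batch_group_count = 1 : i64, feature_group_count = 1 : i64,
            precision_config = ["DEFAULT", "DEFAULT"]} :
         (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
    return %0 : tensor<1x8x8x16xf32>
  }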