diff --git a/BUILD b/BUILD index bcf0956..5a64899 100644 --- a/BUILD +++ b/BUILD @@ -328,6 +328,7 @@ cc_library( hdrs = ["include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"], includes = ["include"], deps = [ + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h index e5b4477..6aae3d1 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h @@ -19,6 +19,8 @@ limitations under the License. // This file defines functionality shared between chlo/mhlo/lhlo dialects. #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" namespace mlir { @@ -26,7 +28,22 @@ namespace hlo { // Verifies the source target pairs attached to collective permute. LogicalResult VerifyCollectivePermuteSourceTargetPairs( - Operation *op, DenseIntElementsAttr attr); + Operation* op, DenseIntElementsAttr attr); + +// Custom formatting for convolution window attributes. +void printWindowAttributes(OpAsmPrinter& p, Operation* op, + llvm::Optional window_strides, + llvm::Optional padding, + llvm::Optional lhs_dilation, + llvm::Optional rhs_dilation, + llvm::Optional window_reversal); + +ParseResult parseWindowAttributes(OpAsmParser& parser, + DenseIntElementsAttr& window_strides, + DenseIntElementsAttr& padding, + DenseIntElementsAttr& lhs_dilation, + DenseIntElementsAttr& rhs_dilation, + DenseElementsAttr& window_reversal); } // namespace hlo } // namespace mlir diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 6e847c2..fb47c8b 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -868,7 +868,10 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []> { let assemblyFormat = [{ `(`operands`)` - `dim_numbers` `=` custom($dimension_numbers) + `dim_numbers` `=` custom($dimension_numbers) `,` + `window` `=` `{` custom($window_strides, $padding, + $lhs_dilation, $rhs_dilation, + $window_reversal) `}` attr-dict `:` functional-type(operands, results) }]; } diff --git a/lib/Dialect/mhlo/IR/hlo_ops_common.cc b/lib/Dialect/mhlo/IR/hlo_ops_common.cc index 06bb29e..15b5e0a 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops_common.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops_common.cc @@ -15,11 +15,14 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSet.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" namespace mlir { namespace hlo { - // Verifies the source target pairs attached to collective permute. LogicalResult VerifyCollectivePermuteSourceTargetPairs( Operation *op, DenseIntElementsAttr attr) { @@ -50,5 +53,164 @@ LogicalResult VerifyCollectivePermuteSourceTargetPairs( return success(); } +namespace { +// Custom formatting for convolution window attributes. +void printWindowAttribute(OpAsmPrinter &p, DenseElementsAttr attribute) { + if (attribute.getType().getElementType().isInteger(/*width=*/1)) { + // boolean attribute. + llvm::interleaveComma(attribute.getBoolValues(), p, + [&](bool b) { p << (b ? 1 : 0); }); + return; + } + if (attribute.getType().getRank() == 2) { + // Padding is Nx2 attribute. + auto it = attribute.getValues().begin(); + std::vector> values(attribute.getNumElements() / + 2); + for (auto &item : values) { + int64_t first = *it; + ++it; + int64_t second = *it; + ++it; + item = {first, second}; + } + llvm::interleaveComma( + values, p, [&](const std::pair pair) { + p << '[' << pair.first << ", " << pair.second << ']'; + }); + } else { + llvm::interleaveComma(attribute.getValues(), p); + } +} +} // namespace + +void printWindowAttributes(OpAsmPrinter &p, Operation *op, + llvm::Optional window_strides, + llvm::Optional padding, + llvm::Optional lhs_dilation, + llvm::Optional rhs_dilation, + llvm::Optional window_reversal) { + using pair_t = std::pair; + std::array printed_attributes = {{ + {window_strides ? *window_strides : nullptr, "stride"}, + {padding ? *padding : nullptr, "pad"}, + {lhs_dilation ? *lhs_dilation : nullptr, "lhs_dilate"}, + {rhs_dilation ? *rhs_dilation : nullptr, "rhs_dilate"}, + {window_reversal ? *window_reversal : nullptr, "reverse"}, + }}; + + // Do not print attributes that do no exist. + auto non_null_attributes = llvm::make_filter_range( + printed_attributes, + [](const pair_t &a) { return static_cast(a.first); }); + + llvm::interleaveComma(non_null_attributes, p, [&](const pair_t &a) { + p << a.second << " = ["; + printWindowAttribute(p, a.first); + p << "]"; + }); +} + +ParseResult parseWindowAttributes(OpAsmParser &parser, + DenseIntElementsAttr &window_strides, + DenseIntElementsAttr &padding, + DenseIntElementsAttr &lhs_dilation, + DenseIntElementsAttr &rhs_dilation, + DenseElementsAttr &window_reversal) { + StringRef attribute_name; + + // Helper to parse an array of the form [ e0, e1, .. ] + auto parse_array = [&](std::function parse_element, + llvm::Optional expected_size = + llvm::None) -> ParseResult { + if (parser.parseLSquare()) { + return failure(); + } + size_t size = 0; + do { + if (parse_element()) { + return failure(); + } + size++; + } while (parser.parseOptionalComma().succeeded()); + if (parser.parseRSquare()) { + return failure(); + } + if (expected_size && size != *expected_size) { + return parser.emitError(parser.getCurrentLocation(), + "Expected array with") + << *expected_size << " elements, got " << size + << " elements instead"; + } + return success(); + }; + + llvm::StringSet<> allowed_attribute_names{ + {"stride", "pad", "lhs_dilate", "rhs_dilate", "reverse"}}; + + while (parser.parseOptionalKeyword(&attribute_name).succeeded()) { + // Verify that the attribute name is valid and erase it. + if (!allowed_attribute_names.erase(attribute_name)) { + return parser.emitError(parser.getCurrentLocation(), + "Unexpected keyword ") + << attribute_name; + } + + if (parser.parseEqual()) { + return failure(); + } + + // parse the attribute value. We need to support either 1D and Nx2 array of + // integers to parse. + llvm::SmallVector values; + auto int64_parser = [&]() { + return parser.parseInteger(values.emplace_back(0)); + }; + + if (attribute_name == "pad") { + // Parse a 2D array of integers. + auto inner_parser = [&]() { + return parse_array(int64_parser, /*expected_size=*/2); + }; + if (parse_array(inner_parser)) { + return failure(); + } + const int64_t size = static_cast(values.size()); + // values should be filled with the Nx2 padding values. + auto ty = RankedTensorType::get({size / 2, 2}, + parser.getBuilder().getIntegerType(64)); + padding = DenseIntElementsAttr::get(ty, values); + } else { + // Parse 1D array of integers. + if (parse_array(int64_parser)) { + return failure(); + } + const int64_t size = static_cast(values.size()); + if (attribute_name == "reverse") { + auto ty = RankedTensorType::get({size}, + parser.getBuilder().getIntegerType(1)); + auto bool_vector = llvm::to_vector<4>( + llvm::map_range(values, [](int64_t v) { return v != 0; })); + window_reversal = DenseElementsAttr::get(ty, bool_vector); + } else { + auto attr = parser.getBuilder().getI64TensorAttr(values); + + if (attribute_name == "stride") { + window_strides = attr; + } else if (attribute_name == "lhs_dilate") { + lhs_dilation = attr; + } else if (attribute_name == "rhs_dilate") { + rhs_dilation = attr; + } else { + llvm_unreachable("Unexpected attribute name"); + } + } + } + // continue parsing if there is a comma at the end. + if (parser.parseOptionalComma().failed()) break; + } + return success(); +} + } // namespace hlo } // namespace mlir diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index 5456f7e..72be3a0 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -363,6 +363,11 @@ LogicalResult WhileOp::moveOutOfLoop(ArrayRef ops) { return success(); } +// suppress warning. + +using mlir::hlo::parseWindowAttributes; +using mlir::hlo::printWindowAttributes; + } // namespace lmhlo } // namespace mlir diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 704c9b8..467c4b6 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -519,10 +519,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) %c0 = constant 0 : index // CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32> // CHECK: lmhlo.convolution(%{{.+}}, %{{.+}}, %[[OUT]]) - // CHECK-SAME: padding = dense<[ - // CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64> - // CHECK-SAME: rhs_dilation = dense<[1, 2]> - // CHECK-SAME: window_strides = dense<[2, 1]> + // CHECK-SAME{LITERAL}: window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]} %out = "mhlo.convolution"(%filter, %input) { batch_group_count = 1 : i64, dimension_numbers = { diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir index b9fecfd..9bd1308 100644 --- a/tests/lhlo_ops.mlir +++ b/tests/lhlo_ops.mlir @@ -198,13 +198,12 @@ func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: // CHECK-LABEL: func @convolution // CHECK: lmhlo.convolution // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] +// CHECK-SAME{LITERAL}: window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]} func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { lmhlo.convolution(%arg0, %arg1, %arg2) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] - { batch_group_count = 1 : i64, feature_group_count = 1 : i64, - padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, - rhs_dilation = dense<[1, 2]> : tensor<2xi64>, - window_strides = dense<[2, 1]> : tensor<2xi64>} + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]} + { batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return } @@ -292,6 +291,48 @@ func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: return } +// ----- + +func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { + // expected-error@+3{{Expected array with2 elements, got 3 elements instead}} + lmhlo.convolution(%arg0, %arg1, %arg2) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [2, 1], pad = [[0, 1, 2], [0, 1]], rhs_dilate = [1, 2]} + { batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return +} + +// ----- + +func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { + // expected-error@+3{{Unexpected keyword stide}} + lmhlo.convolution(%arg0, %arg1, %arg2) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stide = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]} + { batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return +} +// ----- + +func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { + // expected-error@+3{{expected integer value}} + lmhlo.convolution(%arg0, %arg1, %arg2) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [2, b], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]} + { batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return +} +// ----- + +func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { + // expected-error@+3{{Unexpected keyword stride}} + lmhlo.convolution(%arg0, %arg1, %arg2) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2], stride=[2,1]} + { batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return +} + // ----- // CHECK-LABEL: func @exp func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {