From e260aa771ca1888606d5b0e61b3805661fad8f4b Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Wed, 12 May 2021 08:51:40 -0700 Subject: [PATCH] [HLO] Add custom print/parse for convolution dimension numbers (in LMHLO) PiperOrigin-RevId: 373379227 --- .../Dialect/mhlo/IR/hlo_ops_base_structs.h | 14 ++ include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 6 + lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc | 205 ++++++++++++++++++ tests/hlo-legalize-to-lhlo.mlir | 2 +- tests/lhlo_ops.mlir | 122 +++++++++++ 5 files changed, 348 insertions(+), 1 deletion(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h index 70247d7..052770d 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h @@ -21,10 +21,24 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Identifier.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/Types.h" // Order matters, this .inc header is not self-contained, and relies on the // #includes above. #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc" +namespace mlir { +namespace mhlo { + +// Custom printer and parser for struct attributes. +void printConvolutionDimensions(OpAsmPrinter &p, Operation *op, + ConvDimensionNumbers dnums); +ParseResult parseConvolutionDimensions(OpAsmParser &parser, + ConvDimensionNumbers &dnums); + +} // namespace mhlo +} // namespace mlir + #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index c99977c..6e847c2 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -865,6 +865,12 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []> { [](bool v) { return v; }); } }]; + + let assemblyFormat = [{ + `(`operands`)` + `dim_numbers` `=` custom($dimension_numbers) + attr-dict `:` functional-type(operands, results) + }]; } def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]> { diff --git a/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc b/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc index 90da125..248180e 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc @@ -15,4 +15,209 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" +#include +#include + #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" + +namespace mlir { +namespace mhlo { + +namespace { +enum NonSpatialDim : int64_t { + IOBatch = -1, // Input or output batch dimension + IOFeature = -2, // Input or output feature dimension + KIFeature = -3, // Kernel input feature dimension + KOFeature = -4, // Kernel output feature dimensions. +}; + +char NonSpatialDimToString(NonSpatialDim dim) { + switch (dim) { + case IOBatch: + return 'b'; + case IOFeature: + return 'f'; + case KIFeature: + return 'i'; + case KOFeature: + return 'o'; + } +} +} // namespace + +// Custom printer and parser for struct attributes. +void printConvolutionDimensions(OpAsmPrinter &p, Operation * /*op*/, + ConvDimensionNumbers dnums) { + auto print_dim = + [&p](DenseIntElementsAttr spatial_dims, + ArrayRef> non_spatial_dims) { + llvm::SmallVector dims(non_spatial_dims.size() + + spatial_dims.size()); + // Fill each element of dims with a (< 0) NonSpatialDim enum or a (>=0) + // spatial dimension index. + for (const std::pair &non_spatial_dim : + non_spatial_dims) { + dims[non_spatial_dim.first.getInt()] = non_spatial_dim.second; + } + for (auto spatial_dim : + llvm::enumerate(spatial_dims.getValues())) { + dims[spatial_dim.value()] = static_cast(spatial_dim.index()); + } + + // Each dimension numbers will be printed as a comma separated list + // surrounded by square brackets, e.g., [b, 0, 1, 2, f] + p << '['; + llvm::interleaveComma(dims, p, [&](int64_t dim) { + if (dim >= 0) { + p << dim; + } else { + p << NonSpatialDimToString(static_cast(dim)); + } + }); + p << ']'; + }; + + print_dim(dnums.input_spatial_dimensions(), + {{dnums.input_batch_dimension(), IOBatch}, + {dnums.input_feature_dimension(), IOFeature}}); + p << "x"; + print_dim(dnums.kernel_spatial_dimensions(), + {{dnums.kernel_input_feature_dimension(), KIFeature}, + {dnums.kernel_output_feature_dimension(), KOFeature}}); + p << "->"; + print_dim(dnums.output_spatial_dimensions(), + {{dnums.output_batch_dimension(), IOBatch}, + {dnums.output_feature_dimension(), IOFeature}}); +} + +ParseResult parseConvolutionDimensions(OpAsmParser &parser, + ConvDimensionNumbers &dnums) { + // Parsing a single set of dim numbers gives the spatial dimensions as a + // single DenseIntElementsAttr and a list of non-spatial dimensions as + // IntegerAttrs (indexed by the NonSpatialDim enum). + using parse_dim_result_t = std::pair< + DenseIntElementsAttr, + std::unordered_map>>; + + // Note that the allowed_non_spatial_dims is a set (as opposed to unordered + // set) because its used to print a list of allowed non spatial dims in the + // error messages, so making it a set keeps the error messages deterministic. + auto parse_dims = + [&](std::set> allowed_non_spatial_dims, + parse_dim_result_t &parsed_dims) -> ParseResult { + // Parse the starting [ + if (parser.parseLSquare()) { + return failure(); + } + llvm::SmallVector spatial_dims; + std::unordered_map> + non_spatial_dims; + + int64_t index = 0; + do { + int64_t spatial_dim; + OptionalParseResult parseResult = + parser.parseOptionalInteger(spatial_dim); + if (parseResult.hasValue()) { + if (parseResult.getValue().failed()) { + return failure(); + } + // We were successful in parsing an integer. Add its index to the + // spatial dims. + spatial_dims.push_back(index); + } else { + // We did not parse an integer. We expect a keyword token. + StringRef keyword; + if (parser.parseKeyword(&keyword)) { + return failure(); + } + if (keyword.size() != 1 || allowed_non_spatial_dims.empty()) { + return parser.emitError(parser.getCurrentLocation(), + "Unexpected keyword ") + << keyword; + } + // Check if the keyword matches one of the allowed non-spatial dims. + // If so, add it to the non_spatial dims and remove it from the + // allowed set so that it won't be allowed again. + bool is_allowed = false; + for (NonSpatialDim allowed : allowed_non_spatial_dims) { + if (keyword[0] == NonSpatialDimToString(allowed)) { + non_spatial_dims.insert( + {allowed, parser.getBuilder().getI64IntegerAttr(index)}); + allowed_non_spatial_dims.erase(allowed); + is_allowed = true; + break; + } + } + + if (!is_allowed) { + mlir::InFlightDiagnostic diag = parser.emitError( + parser.getCurrentLocation(), "Unexpected dimension "); + diag << keyword << ", expecting "; + llvm::interleaveComma( + allowed_non_spatial_dims, diag, + [&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); }); + return diag; + } + } + index++; + } while (parser.parseOptionalComma().succeeded()); + + // Make sure all expected non-spatial dimensions are parsed. + if (!allowed_non_spatial_dims.empty()) { + mlir::InFlightDiagnostic diag = + parser.emitError(parser.getCurrentLocation(), "Expected dimensions "); + llvm::interleaveComma( + allowed_non_spatial_dims, diag, + [&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); }); + diag << " not specified"; + return diag; + } + + // parse ending ] + if (parser.parseRSquare()) { + return failure(); + } + + parsed_dims = std::make_pair( + parser.getBuilder().getI64TensorAttr(spatial_dims), non_spatial_dims); + return success(); + }; + + parse_dim_result_t parsed_dims; + if (parse_dims({IOBatch, IOFeature}, parsed_dims)) { + return failure(); + } + DenseIntElementsAttr input_spatial_dimensions = parsed_dims.first; + IntegerAttr input_batch_dimension = parsed_dims.second[IOBatch]; + IntegerAttr input_feature_dimension = parsed_dims.second[IOFeature]; + if (parser.parseKeyword("x")) return failure(); + if (parse_dims({KIFeature, KOFeature}, parsed_dims)) { + return failure(); + } + DenseIntElementsAttr kernel_spatial_dimensions = parsed_dims.first; + IntegerAttr kernel_input_feature_dimension = parsed_dims.second[KIFeature]; + IntegerAttr kernel_output_feature_dimension = parsed_dims.second[KOFeature]; + if (parser.parseArrow()) { + return failure(); + } + if (parse_dims({IOBatch, IOFeature}, parsed_dims)) { + return failure(); + } + DenseIntElementsAttr output_spatial_dimensions = parsed_dims.first; + IntegerAttr output_batch_dimension = parsed_dims.second[IOBatch]; + IntegerAttr output_feature_dimension = parsed_dims.second[IOFeature]; + dnums = ConvDimensionNumbers::get( + input_batch_dimension, input_feature_dimension, input_spatial_dimensions, + kernel_input_feature_dimension, kernel_output_feature_dimension, + kernel_spatial_dimensions, output_batch_dimension, + output_feature_dimension, output_spatial_dimensions, + parser.getBuilder().getContext()); + + return success(); +} + +} // namespace mhlo +} // namespace mlir diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index 554ce1e..704c9b8 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -518,7 +518,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> { %c0 = constant 0 : index // CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32> - // CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]]) + // CHECK: lmhlo.convolution(%{{.+}}, %{{.+}}, %[[OUT]]) // CHECK-SAME: padding = dense<[ // CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64> // CHECK-SAME: rhs_dilation = dense<[1, 2]> diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir index 2a504dd..b9fecfd 100644 --- a/tests/lhlo_ops.mlir +++ b/tests/lhlo_ops.mlir @@ -171,6 +171,128 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () { // ----- +// CHECK-LABEL: func @convolution +// CHECK: lmhlo.convolution +// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] +func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { + "lmhlo.convolution"(%arg0, %arg1, %arg2) {batch_group_count = 1 : i64, + dimension_numbers = {input_batch_dimension = 0 : i64, + input_feature_dimension = 3 : i64, + input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, + kernel_input_feature_dimension = 2 : i64, + kernel_output_feature_dimension = 3 : i64, + kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 3 : i64, + output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, + feature_group_count = 1 : i64, + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + rhs_dilation = dense<[1, 2]> : tensor<2xi64>, + window_strides = dense<[2, 1]> : tensor<2xi64>} + : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @convolution +// CHECK: lmhlo.convolution +// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] +func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { + lmhlo.convolution(%arg0, %arg1, %arg2) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + { batch_group_count = 1 : i64, feature_group_count = 1 : i64, + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + rhs_dilation = dense<[1, 2]> : tensor<2xi64>, + window_strides = dense<[2, 1]> : tensor<2xi64>} + : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return +} + +// ----- + +func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { + // expected-error@+2{{Unexpected dimension c, expecting b, f}} + lmhlo.convolution(%arg0, %arg1, %arg2) + dim_numbers = [c, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + { batch_group_count = 1 : i64, feature_group_count = 1 : i64, + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + rhs_dilation = dense<[1, 2]> : tensor<2xi64>, + window_strides = dense<[2, 1]> : tensor<2xi64>} + : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return + return +} + +// ----- + +func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { + // expected-error@+2{{Unexpected dimension b, expecting i, o}} + lmhlo.convolution(%arg0, %arg1, %arg2) + dim_numbers = [b, 0, 1, f]x[0, 1, b, o]->[b, 0, 1, f] + { batch_group_count = 1 : i64, feature_group_count = 1 : i64, + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + rhs_dilation = dense<[1, 2]> : tensor<2xi64>, + window_strides = dense<[2, 1]> : tensor<2xi64>} + : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return + return +} + +// ----- + +func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { + // expected-error@+2{{Unexpected dimension i, expecting o}} + lmhlo.convolution(%arg0, %arg1, %arg2) + dim_numbers = [b, 0, 1, f]x[0, 1, i, i]->[b, 0, 1, f] + { batch_group_count = 1 : i64, feature_group_count = 1 : i64, + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + rhs_dilation = dense<[1, 2]> : tensor<2xi64>, + window_strides = dense<[2, 1]> : tensor<2xi64>} + : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return + return +} + +// ----- + +func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { + // expected-error@+2{{Expected dimensions f not specified}} + lmhlo.convolution(%arg0, %arg1, %arg2) + dim_numbers = [b, 0, 1]x[0, 1, i, o]->[b, 0, 1, f] + { batch_group_count = 1 : i64, feature_group_count = 1 : i64, + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + rhs_dilation = dense<[1, 2]> : tensor<2xi64>, + window_strides = dense<[2, 1]> : tensor<2xi64>} + : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return + return +} + +// ----- + +func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { + // expected-error@+2{{Unexpected keyword b}} + lmhlo.convolution(%arg0, %arg1, %arg2) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o, b]->[b, 0, 1, f] + { batch_group_count = 1 : i64, feature_group_count = 1 : i64, + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + rhs_dilation = dense<[1, 2]> : tensor<2xi64>, + window_strides = dense<[2, 1]> : tensor<2xi64>} + : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return + return +} + +// ----- + +func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) { + // expected-error@+2{{expected '['}} + lmhlo.convolution(%arg0, %arg1, %arg2) + dim_numbers = {b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + { batch_group_count = 1 : i64, feature_group_count = 1 : i64, + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + rhs_dilation = dense<[1, 2]> : tensor<2xi64>, + window_strides = dense<[2, 1]> : tensor<2xi64>} + : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return + return +} + +// ----- // CHECK-LABEL: func @exp func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { "lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()