[HLO] Add custom print/parse for window attributes of convolutions (in LMHLO)

PiperOrigin-RevId: 373807616
Rahul Joshi 2021-05-14 09:46:42 -07:00 committed by TensorFlow MLIR Team
parent e4caaaf921
commit a361253e4f
7 changed files with 238 additions and 12 deletions
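For context, this commit switches the textual form of lmhlo.convolution window attributes from explicit dense attributes in the attribute dictionary to a compact window clause. A rough before/after sketch, reconstructed from the test updates in this commit (the trailing "..." stands for the elided functional type, which is unchanged):

Before:
  lmhlo.convolution(%arg0, %arg1, %arg2)
     dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
     { batch_group_count = 1 : i64, feature_group_count = 1 : i64,
       padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
       rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
       window_strides = dense<[2, 1]> : tensor<2xi64>} : ...

After:
  lmhlo.convolution(%arg0, %arg1, %arg2)
     dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
     window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
     { batch_group_count = 1 : i64, feature_group_count = 1 : i64} : ...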

BUILD

@@ -328,6 +328,7 @@ cc_library(
    hdrs = ["include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"],
    includes = ["include"],
    deps = [
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
    ],


@@ -19,6 +19,8 @@ limitations under the License.

// This file defines functionality shared between chlo/mhlo/lhlo dialects.

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"

namespace mlir {
@@ -26,7 +28,22 @@ namespace hlo {

// Verifies the source target pairs attached to collective permute.
LogicalResult VerifyCollectivePermuteSourceTargetPairs(
    Operation* op, DenseIntElementsAttr attr);

// Custom formatting for convolution window attributes.
void printWindowAttributes(OpAsmPrinter& p, Operation* op,
                           llvm::Optional<DenseIntElementsAttr> window_strides,
                           llvm::Optional<DenseIntElementsAttr> padding,
                           llvm::Optional<DenseIntElementsAttr> lhs_dilation,
                           llvm::Optional<DenseIntElementsAttr> rhs_dilation,
                           llvm::Optional<DenseElementsAttr> window_reversal);

ParseResult parseWindowAttributes(OpAsmParser& parser,
                                  DenseIntElementsAttr& window_strides,
                                  DenseIntElementsAttr& padding,
                                  DenseIntElementsAttr& lhs_dilation,
                                  DenseIntElementsAttr& rhs_dilation,
                                  DenseElementsAttr& window_reversal);

}  // namespace hlo
}  // namespace mlir


@@ -868,7 +868,10 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []> {

  let assemblyFormat = [{
    `(`operands`)`
    `dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,`
    `window` `=` `{` custom<WindowAttributes>($window_strides, $padding,
                                              $lhs_dilation, $rhs_dilation,
                                              $window_reversal) `}`
    attr-dict `:` functional-type(operands, results)
  }];
}
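For reference, each variable captured by the custom<WindowAttributes> directive above corresponds to one keyword inside the braces: $window_strides to stride, $padding to pad, $lhs_dilation to lhs_dilate, $rhs_dilation to rhs_dilate, and $window_reversal to reverse; attributes that are absent are simply omitted from the clause. An illustrative clause with all five keywords (the values here are made up for this sketch; reverse is printed as 0/1 flags):

  // Illustrative only; attribute values are invented for the sketch.
  window = {stride = [2, 1], pad = [[0, 1], [0, 1]], lhs_dilate = [1, 1],
            rhs_dilate = [1, 2], reverse = [0, 0]}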


@@ -15,11 +15,14 @@ limitations under the License.

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSet.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

namespace mlir {
namespace hlo {

// Verifies the source target pairs attached to collective permute.
LogicalResult VerifyCollectivePermuteSourceTargetPairs(
    Operation *op, DenseIntElementsAttr attr) {
@@ -50,5 +53,164 @@ LogicalResult VerifyCollectivePermuteSourceTargetPairs(
  return success();
}
namespace {
// Custom formatting for convolution window attributes.
void printWindowAttribute(OpAsmPrinter &p, DenseElementsAttr attribute) {
  if (attribute.getType().getElementType().isInteger(/*width=*/1)) {
    // Boolean attribute.
    llvm::interleaveComma(attribute.getBoolValues(), p,
                          [&](bool b) { p << (b ? 1 : 0); });
    return;
  }
  if (attribute.getType().getRank() == 2) {
    // Padding is an Nx2 attribute.
    auto it = attribute.getValues<int64_t>().begin();
    std::vector<std::pair<int64_t, int64_t>> values(attribute.getNumElements() /
                                                    2);
    for (auto &item : values) {
      int64_t first = *it;
      ++it;
      int64_t second = *it;
      ++it;
      item = {first, second};
    }
    llvm::interleaveComma(
        values, p, [&](const std::pair<int64_t, int64_t> pair) {
          p << '[' << pair.first << ", " << pair.second << ']';
        });
  } else {
    llvm::interleaveComma(attribute.getValues<int64_t>(), p);
  }
}
}  // namespace
void printWindowAttributes(OpAsmPrinter &p, Operation *op,
                           llvm::Optional<DenseIntElementsAttr> window_strides,
                           llvm::Optional<DenseIntElementsAttr> padding,
                           llvm::Optional<DenseIntElementsAttr> lhs_dilation,
                           llvm::Optional<DenseIntElementsAttr> rhs_dilation,
                           llvm::Optional<DenseElementsAttr> window_reversal) {
  using pair_t = std::pair<DenseElementsAttr, StringRef>;
  std::array<pair_t, 5> printed_attributes = {{
      {window_strides ? *window_strides : nullptr, "stride"},
      {padding ? *padding : nullptr, "pad"},
      {lhs_dilation ? *lhs_dilation : nullptr, "lhs_dilate"},
      {rhs_dilation ? *rhs_dilation : nullptr, "rhs_dilate"},
      {window_reversal ? *window_reversal : nullptr, "reverse"},
  }};

  // Do not print attributes that do not exist.
  auto non_null_attributes = llvm::make_filter_range(
      printed_attributes,
      [](const pair_t &a) { return static_cast<bool>(a.first); });

  llvm::interleaveComma(non_null_attributes, p, [&](const pair_t &a) {
    p << a.second << " = [";
    printWindowAttribute(p, a.first);
    p << "]";
  });
}
ParseResult parseWindowAttributes(OpAsmParser &parser,
                                  DenseIntElementsAttr &window_strides,
                                  DenseIntElementsAttr &padding,
                                  DenseIntElementsAttr &lhs_dilation,
                                  DenseIntElementsAttr &rhs_dilation,
                                  DenseElementsAttr &window_reversal) {
  StringRef attribute_name;

  // Helper to parse an array of the form [ e0, e1, .. ].
  auto parse_array = [&](std::function<ParseResult(void)> parse_element,
                         llvm::Optional<size_t> expected_size =
                             llvm::None) -> ParseResult {
    if (parser.parseLSquare()) {
      return failure();
    }
    size_t size = 0;
    do {
      if (parse_element()) {
        return failure();
      }
      size++;
    } while (parser.parseOptionalComma().succeeded());
    if (parser.parseRSquare()) {
      return failure();
    }
    if (expected_size && size != *expected_size) {
      return parser.emitError(parser.getCurrentLocation(),
                              "Expected array with ")
             << *expected_size << " elements, got " << size
             << " elements instead";
    }
    return success();
  };
  llvm::StringSet<> allowed_attribute_names{
      {"stride", "pad", "lhs_dilate", "rhs_dilate", "reverse"}};

  while (parser.parseOptionalKeyword(&attribute_name).succeeded()) {
    // Verify that the attribute name is valid and erase it.
    if (!allowed_attribute_names.erase(attribute_name)) {
      return parser.emitError(parser.getCurrentLocation(),
                              "Unexpected keyword ")
             << attribute_name;
    }
    if (parser.parseEqual()) {
      return failure();
    }

    // Parse the attribute value. We need to support either a 1D array or an
    // Nx2 array of integers.
    llvm::SmallVector<int64_t> values;
    auto int64_parser = [&]() {
      return parser.parseInteger(values.emplace_back(0));
    };

    if (attribute_name == "pad") {
      // Parse a 2D array of integers.
      auto inner_parser = [&]() {
        return parse_array(int64_parser, /*expected_size=*/2);
      };
      if (parse_array(inner_parser)) {
        return failure();
      }
      const int64_t size = static_cast<int64_t>(values.size());
      // `values` now holds the flattened Nx2 padding values.
      auto ty = RankedTensorType::get({size / 2, 2},
                                      parser.getBuilder().getIntegerType(64));
      padding = DenseIntElementsAttr::get(ty, values);
    } else {
      // Parse a 1D array of integers.
      if (parse_array(int64_parser)) {
        return failure();
      }
      const int64_t size = static_cast<int64_t>(values.size());
      if (attribute_name == "reverse") {
        auto ty = RankedTensorType::get(
            {size}, parser.getBuilder().getIntegerType(1));
        auto bool_vector = llvm::to_vector<4>(
            llvm::map_range(values, [](int64_t v) { return v != 0; }));
        window_reversal = DenseElementsAttr::get(ty, bool_vector);
      } else {
        auto attr = parser.getBuilder().getI64TensorAttr(values);
        if (attribute_name == "stride") {
          window_strides = attr;
        } else if (attribute_name == "lhs_dilate") {
          lhs_dilation = attr;
        } else if (attribute_name == "rhs_dilate") {
          rhs_dilation = attr;
        } else {
          llvm_unreachable("Unexpected attribute name");
        }
      }
    }
    // Continue parsing if there is a comma at the end.
    if (parser.parseOptionalComma().failed()) break;
  }
  return success();
}
}  // namespace hlo
}  // namespace mlir


@@ -363,6 +363,11 @@ LogicalResult WhileOp::moveOutOfLoop(ArrayRef<Operation*> ops) {
  return success();
}

// Suppress warning.
using mlir::hlo::parseWindowAttributes;
using mlir::hlo::printWindowAttributes;

}  // namespace lmhlo
}  // namespace mlir


@@ -519,10 +519,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
  %c0 = constant 0 : index
  // CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
  // CHECK: lmhlo.convolution(%{{.+}}, %{{.+}}, %[[OUT]])
  // CHECK-SAME{LITERAL}: window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
  %out = "mhlo.convolution"(%filter, %input) {
    batch_group_count = 1 : i64,
    dimension_numbers = {


@@ -198,13 +198,12 @@ func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2:
// CHECK-LABEL: func @convolution
// CHECK: lmhlo.convolution
// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
// CHECK-SAME{LITERAL}: window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  lmhlo.convolution(%arg0, %arg1, %arg2)
     dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
     window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
     { batch_group_count = 1 : i64, feature_group_count = 1 : i64}
  : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}

@@ -292,6 +291,48 @@ func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2:
  return
}
// -----

func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  // expected-error@+3{{Expected array with 2 elements, got 3 elements instead}}
  lmhlo.convolution(%arg0, %arg1, %arg2)
     dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
     window = {stride = [2, 1], pad = [[0, 1, 2], [0, 1]], rhs_dilate = [1, 2]}
     { batch_group_count = 1 : i64, feature_group_count = 1 : i64}
  : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}

// -----

func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  // expected-error@+3{{Unexpected keyword stide}}
  lmhlo.convolution(%arg0, %arg1, %arg2)
     dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
     window = {stide = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
     { batch_group_count = 1 : i64, feature_group_count = 1 : i64}
  : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}

// -----

func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  // expected-error@+3{{expected integer value}}
  lmhlo.convolution(%arg0, %arg1, %arg2)
     dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
     window = {stride = [2, b], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
     { batch_group_count = 1 : i64, feature_group_count = 1 : i64}
  : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}

// -----

func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  // expected-error@+3{{Unexpected keyword stride}}
  lmhlo.convolution(%arg0, %arg1, %arg2)
     dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
     window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2], stride = [2, 1]}
     { batch_group_count = 1 : i64, feature_group_count = 1 : i64}
  : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}
// -----

// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {