[HLO] Add custom print/parse for window attributes of convolutions (in LMHLO)
PiperOrigin-RevId: 373807616
This commit is contained in:
parent
e4caaaf921
commit
a361253e4f
1
BUILD
1
BUILD
|
@ -328,6 +328,7 @@ cc_library(
|
||||||
hdrs = ["include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"],
|
hdrs = ["include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"],
|
||||||
includes = ["include"],
|
includes = ["include"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
],
|
],
|
||||||
|
|
|
@ -19,6 +19,8 @@ limitations under the License.
|
||||||
// This file defines functionality shared between chlo/mhlo/lhlo dialects.
|
// This file defines functionality shared between chlo/mhlo/lhlo dialects.
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -26,7 +28,22 @@ namespace hlo {
|
||||||
|
|
||||||
// Verifies the source target pairs attached to collective permute.
|
// Verifies the source target pairs attached to collective permute.
|
||||||
LogicalResult VerifyCollectivePermuteSourceTargetPairs(
|
LogicalResult VerifyCollectivePermuteSourceTargetPairs(
|
||||||
Operation *op, DenseIntElementsAttr attr);
|
Operation* op, DenseIntElementsAttr attr);
|
||||||
|
|
||||||
|
// Custom formatting for convolution window attributes.
|
||||||
|
void printWindowAttributes(OpAsmPrinter& p, Operation* op,
|
||||||
|
llvm::Optional<DenseIntElementsAttr> window_strides,
|
||||||
|
llvm::Optional<DenseIntElementsAttr> padding,
|
||||||
|
llvm::Optional<DenseIntElementsAttr> lhs_dilation,
|
||||||
|
llvm::Optional<DenseIntElementsAttr> rhs_dilation,
|
||||||
|
llvm::Optional<DenseElementsAttr> window_reversal);
|
||||||
|
|
||||||
|
ParseResult parseWindowAttributes(OpAsmParser& parser,
|
||||||
|
DenseIntElementsAttr& window_strides,
|
||||||
|
DenseIntElementsAttr& padding,
|
||||||
|
DenseIntElementsAttr& lhs_dilation,
|
||||||
|
DenseIntElementsAttr& rhs_dilation,
|
||||||
|
DenseElementsAttr& window_reversal);
|
||||||
|
|
||||||
} // namespace hlo
|
} // namespace hlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -868,7 +868,10 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []> {
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
`(`operands`)`
|
`(`operands`)`
|
||||||
`dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers)
|
`dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,`
|
||||||
|
`window` `=` `{` custom<WindowAttributes>($window_strides, $padding,
|
||||||
|
$lhs_dilation, $rhs_dilation,
|
||||||
|
$window_reversal) `}`
|
||||||
attr-dict `:` functional-type(operands, results)
|
attr-dict `:` functional-type(operands, results)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,11 +15,14 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/StringSet.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace hlo {
|
namespace hlo {
|
||||||
|
|
||||||
// Verifies the source target pairs attached to collective permute.
|
// Verifies the source target pairs attached to collective permute.
|
||||||
LogicalResult VerifyCollectivePermuteSourceTargetPairs(
|
LogicalResult VerifyCollectivePermuteSourceTargetPairs(
|
||||||
Operation *op, DenseIntElementsAttr attr) {
|
Operation *op, DenseIntElementsAttr attr) {
|
||||||
|
@ -50,5 +53,164 @@ LogicalResult VerifyCollectivePermuteSourceTargetPairs(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Custom formatting for convolution window attributes.
|
||||||
|
void printWindowAttribute(OpAsmPrinter &p, DenseElementsAttr attribute) {
|
||||||
|
if (attribute.getType().getElementType().isInteger(/*width=*/1)) {
|
||||||
|
// boolean attribute.
|
||||||
|
llvm::interleaveComma(attribute.getBoolValues(), p,
|
||||||
|
[&](bool b) { p << (b ? 1 : 0); });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (attribute.getType().getRank() == 2) {
|
||||||
|
// Padding is Nx2 attribute.
|
||||||
|
auto it = attribute.getValues<int64_t>().begin();
|
||||||
|
std::vector<std::pair<int64_t, int64_t>> values(attribute.getNumElements() /
|
||||||
|
2);
|
||||||
|
for (auto &item : values) {
|
||||||
|
int64_t first = *it;
|
||||||
|
++it;
|
||||||
|
int64_t second = *it;
|
||||||
|
++it;
|
||||||
|
item = {first, second};
|
||||||
|
}
|
||||||
|
llvm::interleaveComma(
|
||||||
|
values, p, [&](const std::pair<int64_t, int64_t> pair) {
|
||||||
|
p << '[' << pair.first << ", " << pair.second << ']';
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
llvm::interleaveComma(attribute.getValues<int64_t>(), p);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void printWindowAttributes(OpAsmPrinter &p, Operation *op,
|
||||||
|
llvm::Optional<DenseIntElementsAttr> window_strides,
|
||||||
|
llvm::Optional<DenseIntElementsAttr> padding,
|
||||||
|
llvm::Optional<DenseIntElementsAttr> lhs_dilation,
|
||||||
|
llvm::Optional<DenseIntElementsAttr> rhs_dilation,
|
||||||
|
llvm::Optional<DenseElementsAttr> window_reversal) {
|
||||||
|
using pair_t = std::pair<DenseElementsAttr, StringRef>;
|
||||||
|
std::array<pair_t, 5> printed_attributes = {{
|
||||||
|
{window_strides ? *window_strides : nullptr, "stride"},
|
||||||
|
{padding ? *padding : nullptr, "pad"},
|
||||||
|
{lhs_dilation ? *lhs_dilation : nullptr, "lhs_dilate"},
|
||||||
|
{rhs_dilation ? *rhs_dilation : nullptr, "rhs_dilate"},
|
||||||
|
{window_reversal ? *window_reversal : nullptr, "reverse"},
|
||||||
|
}};
|
||||||
|
|
||||||
|
// Do not print attributes that do no exist.
|
||||||
|
auto non_null_attributes = llvm::make_filter_range(
|
||||||
|
printed_attributes,
|
||||||
|
[](const pair_t &a) { return static_cast<bool>(a.first); });
|
||||||
|
|
||||||
|
llvm::interleaveComma(non_null_attributes, p, [&](const pair_t &a) {
|
||||||
|
p << a.second << " = [";
|
||||||
|
printWindowAttribute(p, a.first);
|
||||||
|
p << "]";
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult parseWindowAttributes(OpAsmParser &parser,
|
||||||
|
DenseIntElementsAttr &window_strides,
|
||||||
|
DenseIntElementsAttr &padding,
|
||||||
|
DenseIntElementsAttr &lhs_dilation,
|
||||||
|
DenseIntElementsAttr &rhs_dilation,
|
||||||
|
DenseElementsAttr &window_reversal) {
|
||||||
|
StringRef attribute_name;
|
||||||
|
|
||||||
|
// Helper to parse an array of the form [ e0, e1, .. ]
|
||||||
|
auto parse_array = [&](std::function<ParseResult(void)> parse_element,
|
||||||
|
llvm::Optional<size_t> expected_size =
|
||||||
|
llvm::None) -> ParseResult {
|
||||||
|
if (parser.parseLSquare()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
size_t size = 0;
|
||||||
|
do {
|
||||||
|
if (parse_element()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
size++;
|
||||||
|
} while (parser.parseOptionalComma().succeeded());
|
||||||
|
if (parser.parseRSquare()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (expected_size && size != *expected_size) {
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"Expected array with")
|
||||||
|
<< *expected_size << " elements, got " << size
|
||||||
|
<< " elements instead";
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
llvm::StringSet<> allowed_attribute_names{
|
||||||
|
{"stride", "pad", "lhs_dilate", "rhs_dilate", "reverse"}};
|
||||||
|
|
||||||
|
while (parser.parseOptionalKeyword(&attribute_name).succeeded()) {
|
||||||
|
// Verify that the attribute name is valid and erase it.
|
||||||
|
if (!allowed_attribute_names.erase(attribute_name)) {
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"Unexpected keyword ")
|
||||||
|
<< attribute_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parser.parseEqual()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse the attribute value. We need to support either 1D and Nx2 array of
|
||||||
|
// integers to parse.
|
||||||
|
llvm::SmallVector<int64_t> values;
|
||||||
|
auto int64_parser = [&]() {
|
||||||
|
return parser.parseInteger(values.emplace_back(0));
|
||||||
|
};
|
||||||
|
|
||||||
|
if (attribute_name == "pad") {
|
||||||
|
// Parse a 2D array of integers.
|
||||||
|
auto inner_parser = [&]() {
|
||||||
|
return parse_array(int64_parser, /*expected_size=*/2);
|
||||||
|
};
|
||||||
|
if (parse_array(inner_parser)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
const int64_t size = static_cast<int64_t>(values.size());
|
||||||
|
// values should be filled with the Nx2 padding values.
|
||||||
|
auto ty = RankedTensorType::get({size / 2, 2},
|
||||||
|
parser.getBuilder().getIntegerType(64));
|
||||||
|
padding = DenseIntElementsAttr::get(ty, values);
|
||||||
|
} else {
|
||||||
|
// Parse 1D array of integers.
|
||||||
|
if (parse_array(int64_parser)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
const int64_t size = static_cast<int64_t>(values.size());
|
||||||
|
if (attribute_name == "reverse") {
|
||||||
|
auto ty = RankedTensorType::get({size},
|
||||||
|
parser.getBuilder().getIntegerType(1));
|
||||||
|
auto bool_vector = llvm::to_vector<4>(
|
||||||
|
llvm::map_range(values, [](int64_t v) { return v != 0; }));
|
||||||
|
window_reversal = DenseElementsAttr::get(ty, bool_vector);
|
||||||
|
} else {
|
||||||
|
auto attr = parser.getBuilder().getI64TensorAttr(values);
|
||||||
|
|
||||||
|
if (attribute_name == "stride") {
|
||||||
|
window_strides = attr;
|
||||||
|
} else if (attribute_name == "lhs_dilate") {
|
||||||
|
lhs_dilation = attr;
|
||||||
|
} else if (attribute_name == "rhs_dilate") {
|
||||||
|
rhs_dilation = attr;
|
||||||
|
} else {
|
||||||
|
llvm_unreachable("Unexpected attribute name");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// continue parsing if there is a comma at the end.
|
||||||
|
if (parser.parseOptionalComma().failed()) break;
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace hlo
|
} // namespace hlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -363,6 +363,11 @@ LogicalResult WhileOp::moveOutOfLoop(ArrayRef<Operation*> ops) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// suppress warning.
|
||||||
|
|
||||||
|
using mlir::hlo::parseWindowAttributes;
|
||||||
|
using mlir::hlo::printWindowAttributes;
|
||||||
|
|
||||||
} // namespace lmhlo
|
} // namespace lmhlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -519,10 +519,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
// CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
|
// CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
|
||||||
// CHECK: lmhlo.convolution(%{{.+}}, %{{.+}}, %[[OUT]])
|
// CHECK: lmhlo.convolution(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||||
// CHECK-SAME: padding = dense<[
|
// CHECK-SAME{LITERAL}: window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
|
||||||
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
|
||||||
// CHECK-SAME: rhs_dilation = dense<[1, 2]>
|
|
||||||
// CHECK-SAME: window_strides = dense<[2, 1]>
|
|
||||||
%out = "mhlo.convolution"(%filter, %input) {
|
%out = "mhlo.convolution"(%filter, %input) {
|
||||||
batch_group_count = 1 : i64,
|
batch_group_count = 1 : i64,
|
||||||
dimension_numbers = {
|
dimension_numbers = {
|
||||||
|
|
|
@ -198,13 +198,12 @@ func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2:
|
||||||
// CHECK-LABEL: func @convolution
|
// CHECK-LABEL: func @convolution
|
||||||
// CHECK: lmhlo.convolution
|
// CHECK: lmhlo.convolution
|
||||||
// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
|
// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
|
||||||
|
// CHECK-SAME{LITERAL}: window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
|
||||||
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
lmhlo.convolution(%arg0, %arg1, %arg2)
|
lmhlo.convolution(%arg0, %arg1, %arg2)
|
||||||
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
|
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
|
||||||
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64,
|
window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
|
||||||
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64}
|
||||||
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
|
|
||||||
window_strides = dense<[2, 1]> : tensor<2xi64>}
|
|
||||||
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -292,6 +291,48 @@ func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
|
// expected-error@+3{{Expected array with2 elements, got 3 elements instead}}
|
||||||
|
lmhlo.convolution(%arg0, %arg1, %arg2)
|
||||||
|
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
|
||||||
|
window = {stride = [2, 1], pad = [[0, 1, 2], [0, 1]], rhs_dilate = [1, 2]}
|
||||||
|
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64}
|
||||||
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
|
// expected-error@+3{{Unexpected keyword stide}}
|
||||||
|
lmhlo.convolution(%arg0, %arg1, %arg2)
|
||||||
|
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
|
||||||
|
window = {stide = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
|
||||||
|
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64}
|
||||||
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
||||||
|
}
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
|
// expected-error@+3{{expected integer value}}
|
||||||
|
lmhlo.convolution(%arg0, %arg1, %arg2)
|
||||||
|
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
|
||||||
|
window = {stride = [2, b], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]}
|
||||||
|
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64}
|
||||||
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
||||||
|
}
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
|
// expected-error@+3{{Unexpected keyword stride}}
|
||||||
|
lmhlo.convolution(%arg0, %arg1, %arg2)
|
||||||
|
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
|
||||||
|
window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2], stride=[2,1]}
|
||||||
|
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64}
|
||||||
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: func @exp
|
// CHECK-LABEL: func @exp
|
||||||
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
|
|
Loading…
Reference in New Issue