[HLO] Add custom print/parse for convolution dimension numbers (in LMHLO)
PiperOrigin-RevId: 373379227
parent 30779f0c2f
commit e260aa771c

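Summary of the new syntax (taken from the tests added below): the LMHLO convolution's dimension numbers now print inline in input x kernel -> output order, e.g. dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], rather than only as the verbose dimension_numbers dictionary attribute. The round-trip tests further down exercise both the generic and the custom form.
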
@@ -21,10 +21,24 @@ limitations under the License.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"

// Order matters, this .inc header is not self-contained, and relies on the
// #includes above.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc"

namespace mlir {
namespace mhlo {

// Custom printer and parser for struct attributes.
void printConvolutionDimensions(OpAsmPrinter &p, Operation *op,
                                ConvDimensionNumbers dnums);
ParseResult parseConvolutionDimensions(OpAsmParser &parser,
                                       ConvDimensionNumbers &dnums);

}  // namespace mhlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_

@@ -865,6 +865,12 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []> {
        [](bool v) { return v; });
    }
  }];

  let assemblyFormat = [{
    `(`operands`)`
       `dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers)
       attr-dict `:` functional-type(operands, results)
  }];
}

def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]> {

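For context (this is standard MLIR declarative-assembly behavior, not something introduced by this diff): a custom<ConvolutionDimensions>(...) directive in the assemblyFormat is resolved to free functions named printConvolutionDimensions and parseConvolutionDimensions with the signatures declared in the header hunk above, which is why those declarations and the definitions in the next hunk accompany this TableGen change.
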
@@ -15,4 +15,209 @@ limitations under the License.

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"

#include <set>
#include <unordered_map>

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"

namespace mlir {
namespace mhlo {

namespace {
enum NonSpatialDim : int64_t {
  IOBatch = -1,    // Input or output batch dimension.
  IOFeature = -2,  // Input or output feature dimension.
  KIFeature = -3,  // Kernel input feature dimension.
  KOFeature = -4,  // Kernel output feature dimension.
};

char NonSpatialDimToString(NonSpatialDim dim) {
  switch (dim) {
    case IOBatch:
      return 'b';
    case IOFeature:
      return 'f';
    case KIFeature:
      return 'i';
    case KOFeature:
      return 'o';
  }
}
}  // namespace

// Custom printer and parser for struct attributes.
void printConvolutionDimensions(OpAsmPrinter &p, Operation * /*op*/,
                                ConvDimensionNumbers dnums) {
  auto print_dim =
      [&p](DenseIntElementsAttr spatial_dims,
           ArrayRef<std::pair<IntegerAttr, NonSpatialDim>> non_spatial_dims) {
        llvm::SmallVector<int64_t> dims(non_spatial_dims.size() +
                                        spatial_dims.size());
        // Fill each element of dims with a (< 0) NonSpatialDim enum or a
        // (>= 0) spatial dimension index.
        for (const std::pair<IntegerAttr, NonSpatialDim> &non_spatial_dim :
             non_spatial_dims) {
          dims[non_spatial_dim.first.getInt()] = non_spatial_dim.second;
        }
        for (auto spatial_dim :
             llvm::enumerate(spatial_dims.getValues<int64_t>())) {
          dims[spatial_dim.value()] = static_cast<int64_t>(spatial_dim.index());
        }

        // The dimension numbers are printed as a comma-separated list
        // surrounded by square brackets, e.g., [b, 0, 1, 2, f].
        p << '[';
        llvm::interleaveComma(dims, p, [&](int64_t dim) {
          if (dim >= 0) {
            p << dim;
          } else {
            p << NonSpatialDimToString(static_cast<NonSpatialDim>(dim));
          }
        });
        p << ']';
      };

  print_dim(dnums.input_spatial_dimensions(),
            {{dnums.input_batch_dimension(), IOBatch},
             {dnums.input_feature_dimension(), IOFeature}});
  p << "x";
  print_dim(dnums.kernel_spatial_dimensions(),
            {{dnums.kernel_input_feature_dimension(), KIFeature},
             {dnums.kernel_output_feature_dimension(), KOFeature}});
  p << "->";
  print_dim(dnums.output_spatial_dimensions(),
            {{dnums.output_batch_dimension(), IOBatch},
             {dnums.output_feature_dimension(), IOFeature}});
}

ParseResult parseConvolutionDimensions(OpAsmParser &parser,
                                       ConvDimensionNumbers &dnums) {
  // Parsing a single set of dim numbers gives the spatial dimensions as a
  // single DenseIntElementsAttr and a list of non-spatial dimensions as
  // IntegerAttrs (indexed by the NonSpatialDim enum).
  using parse_dim_result_t = std::pair<
      DenseIntElementsAttr,
      std::unordered_map<NonSpatialDim, IntegerAttr, std::hash<int64_t>>>;

  // Note that allowed_non_spatial_dims is an ordered set (as opposed to an
  // unordered set) because it is used to print the list of allowed non-spatial
  // dims in error messages, and keeping it ordered keeps those messages
  // deterministic.
  auto parse_dims =
      [&](std::set<NonSpatialDim, std::greater<>> allowed_non_spatial_dims,
          parse_dim_result_t &parsed_dims) -> ParseResult {
    // Parse the starting [.
    if (parser.parseLSquare()) {
      return failure();
    }
    llvm::SmallVector<int64_t> spatial_dims;
    std::unordered_map<NonSpatialDim, IntegerAttr, std::hash<int64_t>>
        non_spatial_dims;

    int64_t index = 0;
    do {
      int64_t spatial_dim;
      OptionalParseResult parseResult =
          parser.parseOptionalInteger(spatial_dim);
      if (parseResult.hasValue()) {
        if (parseResult.getValue().failed()) {
          return failure();
        }
        // Successfully parsed an integer: record the current position as a
        // spatial dimension.
        spatial_dims.push_back(index);
      } else {
        // We did not parse an integer. We expect a keyword token.
        StringRef keyword;
        if (parser.parseKeyword(&keyword)) {
          return failure();
        }
        if (keyword.size() != 1 || allowed_non_spatial_dims.empty()) {
          return parser.emitError(parser.getCurrentLocation(),
                                  "Unexpected keyword ")
                 << keyword;
        }
        // Check if the keyword matches one of the allowed non-spatial dims.
        // If so, add it to the non_spatial dims and remove it from the
        // allowed set so that it won't be allowed again.
        bool is_allowed = false;
        for (NonSpatialDim allowed : allowed_non_spatial_dims) {
          if (keyword[0] == NonSpatialDimToString(allowed)) {
            non_spatial_dims.insert(
                {allowed, parser.getBuilder().getI64IntegerAttr(index)});
            allowed_non_spatial_dims.erase(allowed);
            is_allowed = true;
            break;
          }
        }

        if (!is_allowed) {
          mlir::InFlightDiagnostic diag = parser.emitError(
              parser.getCurrentLocation(), "Unexpected dimension ");
          diag << keyword << ", expecting ";
          llvm::interleaveComma(
              allowed_non_spatial_dims, diag,
              [&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); });
          return diag;
        }
      }
      index++;
    } while (parser.parseOptionalComma().succeeded());

    // Make sure all expected non-spatial dimensions were parsed.
    if (!allowed_non_spatial_dims.empty()) {
      mlir::InFlightDiagnostic diag =
          parser.emitError(parser.getCurrentLocation(), "Expected dimensions ");
      llvm::interleaveComma(
          allowed_non_spatial_dims, diag,
          [&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); });
      diag << " not specified";
      return diag;
    }

    // Parse the ending ].
    if (parser.parseRSquare()) {
      return failure();
    }

    parsed_dims = std::make_pair(
        parser.getBuilder().getI64TensorAttr(spatial_dims), non_spatial_dims);
    return success();
  };

  parse_dim_result_t parsed_dims;
  if (parse_dims({IOBatch, IOFeature}, parsed_dims)) {
    return failure();
  }
  DenseIntElementsAttr input_spatial_dimensions = parsed_dims.first;
  IntegerAttr input_batch_dimension = parsed_dims.second[IOBatch];
  IntegerAttr input_feature_dimension = parsed_dims.second[IOFeature];
  if (parser.parseKeyword("x")) return failure();
  if (parse_dims({KIFeature, KOFeature}, parsed_dims)) {
    return failure();
  }
  DenseIntElementsAttr kernel_spatial_dimensions = parsed_dims.first;
  IntegerAttr kernel_input_feature_dimension = parsed_dims.second[KIFeature];
  IntegerAttr kernel_output_feature_dimension = parsed_dims.second[KOFeature];
  if (parser.parseArrow()) {
    return failure();
  }
  if (parse_dims({IOBatch, IOFeature}, parsed_dims)) {
    return failure();
  }
  DenseIntElementsAttr output_spatial_dimensions = parsed_dims.first;
  IntegerAttr output_batch_dimension = parsed_dims.second[IOBatch];
  IntegerAttr output_feature_dimension = parsed_dims.second[IOFeature];
  dnums = ConvDimensionNumbers::get(
      input_batch_dimension, input_feature_dimension, input_spatial_dimensions,
      kernel_input_feature_dimension, kernel_output_feature_dimension,
      kernel_spatial_dimensions, output_batch_dimension,
      output_feature_dimension, output_spatial_dimensions,
      parser.getBuilder().getContext());

  return success();
}

}  // namespace mhlo
}  // namespace mlir

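As a worked example of the printing scheme above, here is a standalone C++ sketch (not part of the commit; it mirrors the dims-array construction from printConvolutionDimensions using plain STL types in place of MLIR attributes). For an NHWC-style operand with batch dimension 0, feature dimension 3, and spatial dimensions [1, 2], the array becomes {IOBatch, 0, 1, IOFeature}, which prints as [b, 0, 1, f].

// Standalone illustration of the dims-array construction used by
// printConvolutionDimensions; attributes are replaced by plain integers.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

enum NonSpatialDim : int64_t { IOBatch = -1, IOFeature = -2 };

char NonSpatialDimToChar(NonSpatialDim dim) {
  return dim == IOBatch ? 'b' : 'f';
}

// Returns e.g. "[b, 0, 1, f]" for batch at dim 0, feature at dim 3, and
// spatial dims at positions 1 and 2.
std::string PrintDims(
    const std::vector<int64_t> &spatial_dims,
    const std::vector<std::pair<int64_t, NonSpatialDim>> &non_spatial_dims) {
  std::vector<int64_t> dims(spatial_dims.size() + non_spatial_dims.size());
  // Non-spatial dims: the stored value is the dimension's position, so it
  // indexes into `dims`; the enum marks what kind of dimension sits there.
  for (const auto &non_spatial_dim : non_spatial_dims) {
    dims[non_spatial_dim.first] = non_spatial_dim.second;
  }
  // Spatial dims: entry i of spatial_dims holds the position of the i-th
  // spatial dimension, so that position receives the index i.
  for (size_t i = 0; i < spatial_dims.size(); ++i) {
    dims[spatial_dims[i]] = static_cast<int64_t>(i);
  }
  std::string out = "[";
  for (size_t i = 0; i < dims.size(); ++i) {
    if (i > 0) out += ", ";
    if (dims[i] >= 0) {
      out += std::to_string(dims[i]);
    } else {
      out += NonSpatialDimToChar(static_cast<NonSpatialDim>(dims[i]));
    }
  }
  return out + "]";
}

int main() {
  // NHWC input: batch at dim 0, feature at dim 3, spatial dims at 1 and 2.
  std::cout << PrintDims({1, 2}, {{0, IOBatch}, {3, IOFeature}}) << "\n";
  // Prints: [b, 0, 1, f]
}
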
@@ -518,7 +518,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
                  -> tensor<3x5x5x4xf32> {
  %c0 = constant 0 : index
  // CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
  // CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
  // CHECK: lmhlo.convolution(%{{.+}}, %{{.+}}, %[[OUT]])
  // CHECK-SAME: padding = dense<[
  // CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
  // CHECK-SAME: rhs_dilation = dense<[1, 2]>

@@ -171,6 +171,128 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {

// -----

// CHECK-LABEL: func @convolution
// CHECK: lmhlo.convolution
// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  "lmhlo.convolution"(%arg0, %arg1, %arg2) {batch_group_count = 1 : i64,
      dimension_numbers = {input_batch_dimension = 0 : i64,
                           input_feature_dimension = 3 : i64,
                           input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
                           kernel_input_feature_dimension = 2 : i64,
                           kernel_output_feature_dimension = 3 : i64,
                           kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
                           output_batch_dimension = 0 : i64,
                           output_feature_dimension = 3 : i64,
                           output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
      feature_group_count = 1 : i64,
      padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
      rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
      window_strides = dense<[2, 1]> : tensor<2xi64>}
    : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @convolution
// CHECK: lmhlo.convolution
// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  lmhlo.convolution(%arg0, %arg1, %arg2)
      dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
      { batch_group_count = 1 : i64, feature_group_count = 1 : i64,
        padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
        rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
        window_strides = dense<[2, 1]> : tensor<2xi64>}
    : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}

// -----

func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  // expected-error@+2{{Unexpected dimension c, expecting b, f}}
  lmhlo.convolution(%arg0, %arg1, %arg2)
      dim_numbers = [c, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
      { batch_group_count = 1 : i64, feature_group_count = 1 : i64,
        padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
        rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
        window_strides = dense<[2, 1]> : tensor<2xi64>}
    : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}

// -----

func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  // expected-error@+2{{Unexpected dimension b, expecting i, o}}
  lmhlo.convolution(%arg0, %arg1, %arg2)
      dim_numbers = [b, 0, 1, f]x[0, 1, b, o]->[b, 0, 1, f]
      { batch_group_count = 1 : i64, feature_group_count = 1 : i64,
        padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
        rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
        window_strides = dense<[2, 1]> : tensor<2xi64>}
    : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}

// -----

func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  // expected-error@+2{{Unexpected dimension i, expecting o}}
  lmhlo.convolution(%arg0, %arg1, %arg2)
      dim_numbers = [b, 0, 1, f]x[0, 1, i, i]->[b, 0, 1, f]
      { batch_group_count = 1 : i64, feature_group_count = 1 : i64,
        padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
        rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
        window_strides = dense<[2, 1]> : tensor<2xi64>}
    : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}

// -----

func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  // expected-error@+2{{Expected dimensions f not specified}}
  lmhlo.convolution(%arg0, %arg1, %arg2)
      dim_numbers = [b, 0, 1]x[0, 1, i, o]->[b, 0, 1, f]
      { batch_group_count = 1 : i64, feature_group_count = 1 : i64,
        padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
        rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
        window_strides = dense<[2, 1]> : tensor<2xi64>}
    : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}

// -----

func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  // expected-error@+2{{Unexpected keyword b}}
  lmhlo.convolution(%arg0, %arg1, %arg2)
      dim_numbers = [b, 0, 1, f]x[0, 1, i, o, b]->[b, 0, 1, f]
      { batch_group_count = 1 : i64, feature_group_count = 1 : i64,
        padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
        rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
        window_strides = dense<[2, 1]> : tensor<2xi64>}
    : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}

// -----

func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
  // expected-error@+2{{expected '['}}
  lmhlo.convolution(%arg0, %arg1, %arg2)
      dim_numbers = {b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
      { batch_group_count = 1 : i64, feature_group_count = 1 : i64,
        padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
        rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
        window_strides = dense<[2, 1]> : tensor<2xi64>}
    : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
  return
}

// -----
// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
  "lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()