[HLO] Add custom print/parse for convolution dimension numbers (in LMHLO)
PiperOrigin-RevId: 373379227
This commit is contained in:
parent
30779f0c2f
commit
e260aa771c
|
@ -21,10 +21,24 @@ limitations under the License.
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/Identifier.h"
|
#include "mlir/IR/Identifier.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
|
|
||||||
// Order matters, this .inc header is not self-contained, and relies on the
|
// Order matters, this .inc header is not self-contained, and relies on the
|
||||||
// #includes above.
|
// #includes above.
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace mhlo {
|
||||||
|
|
||||||
|
// Custom printer and parser for struct attributes.
|
||||||
|
void printConvolutionDimensions(OpAsmPrinter &p, Operation *op,
|
||||||
|
ConvDimensionNumbers dnums);
|
||||||
|
ParseResult parseConvolutionDimensions(OpAsmParser &parser,
|
||||||
|
ConvDimensionNumbers &dnums);
|
||||||
|
|
||||||
|
} // namespace mhlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_
|
||||||
|
|
|
@ -865,6 +865,12 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []> {
|
||||||
[](bool v) { return v; });
|
[](bool v) { return v; });
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(`operands`)`
|
||||||
|
`dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers)
|
||||||
|
attr-dict `:` functional-type(operands, results)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]> {
|
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]> {
|
||||||
|
|
|
@ -15,4 +15,209 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace mhlo {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
enum NonSpatialDim : int64_t {
|
||||||
|
IOBatch = -1, // Input or output batch dimension
|
||||||
|
IOFeature = -2, // Input or output feature dimension
|
||||||
|
KIFeature = -3, // Kernel input feature dimension
|
||||||
|
KOFeature = -4, // Kernel output feature dimensions.
|
||||||
|
};
|
||||||
|
|
||||||
|
char NonSpatialDimToString(NonSpatialDim dim) {
|
||||||
|
switch (dim) {
|
||||||
|
case IOBatch:
|
||||||
|
return 'b';
|
||||||
|
case IOFeature:
|
||||||
|
return 'f';
|
||||||
|
case KIFeature:
|
||||||
|
return 'i';
|
||||||
|
case KOFeature:
|
||||||
|
return 'o';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Custom printer and parser for struct attributes.
|
||||||
|
void printConvolutionDimensions(OpAsmPrinter &p, Operation * /*op*/,
|
||||||
|
ConvDimensionNumbers dnums) {
|
||||||
|
auto print_dim =
|
||||||
|
[&p](DenseIntElementsAttr spatial_dims,
|
||||||
|
ArrayRef<std::pair<IntegerAttr, NonSpatialDim>> non_spatial_dims) {
|
||||||
|
llvm::SmallVector<int64_t> dims(non_spatial_dims.size() +
|
||||||
|
spatial_dims.size());
|
||||||
|
// Fill each element of dims with a (< 0) NonSpatialDim enum or a (>=0)
|
||||||
|
// spatial dimension index.
|
||||||
|
for (const std::pair<IntegerAttr, NonSpatialDim> &non_spatial_dim :
|
||||||
|
non_spatial_dims) {
|
||||||
|
dims[non_spatial_dim.first.getInt()] = non_spatial_dim.second;
|
||||||
|
}
|
||||||
|
for (auto spatial_dim :
|
||||||
|
llvm::enumerate(spatial_dims.getValues<int64_t>())) {
|
||||||
|
dims[spatial_dim.value()] = static_cast<int64_t>(spatial_dim.index());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each dimension numbers will be printed as a comma separated list
|
||||||
|
// surrounded by square brackets, e.g., [b, 0, 1, 2, f]
|
||||||
|
p << '[';
|
||||||
|
llvm::interleaveComma(dims, p, [&](int64_t dim) {
|
||||||
|
if (dim >= 0) {
|
||||||
|
p << dim;
|
||||||
|
} else {
|
||||||
|
p << NonSpatialDimToString(static_cast<NonSpatialDim>(dim));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
p << ']';
|
||||||
|
};
|
||||||
|
|
||||||
|
print_dim(dnums.input_spatial_dimensions(),
|
||||||
|
{{dnums.input_batch_dimension(), IOBatch},
|
||||||
|
{dnums.input_feature_dimension(), IOFeature}});
|
||||||
|
p << "x";
|
||||||
|
print_dim(dnums.kernel_spatial_dimensions(),
|
||||||
|
{{dnums.kernel_input_feature_dimension(), KIFeature},
|
||||||
|
{dnums.kernel_output_feature_dimension(), KOFeature}});
|
||||||
|
p << "->";
|
||||||
|
print_dim(dnums.output_spatial_dimensions(),
|
||||||
|
{{dnums.output_batch_dimension(), IOBatch},
|
||||||
|
{dnums.output_feature_dimension(), IOFeature}});
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult parseConvolutionDimensions(OpAsmParser &parser,
|
||||||
|
ConvDimensionNumbers &dnums) {
|
||||||
|
// Parsing a single set of dim numbers gives the spatial dimensions as a
|
||||||
|
// single DenseIntElementsAttr and a list of non-spatial dimensions as
|
||||||
|
// IntegerAttrs (indexed by the NonSpatialDim enum).
|
||||||
|
using parse_dim_result_t = std::pair<
|
||||||
|
DenseIntElementsAttr,
|
||||||
|
std::unordered_map<NonSpatialDim, IntegerAttr, std::hash<int64_t>>>;
|
||||||
|
|
||||||
|
// Note that the allowed_non_spatial_dims is a set (as opposed to unordered
|
||||||
|
// set) because its used to print a list of allowed non spatial dims in the
|
||||||
|
// error messages, so making it a set keeps the error messages deterministic.
|
||||||
|
auto parse_dims =
|
||||||
|
[&](std::set<NonSpatialDim, std::greater<>> allowed_non_spatial_dims,
|
||||||
|
parse_dim_result_t &parsed_dims) -> ParseResult {
|
||||||
|
// Parse the starting [
|
||||||
|
if (parser.parseLSquare()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
llvm::SmallVector<int64_t> spatial_dims;
|
||||||
|
std::unordered_map<NonSpatialDim, IntegerAttr, std::hash<int64_t>>
|
||||||
|
non_spatial_dims;
|
||||||
|
|
||||||
|
int64_t index = 0;
|
||||||
|
do {
|
||||||
|
int64_t spatial_dim;
|
||||||
|
OptionalParseResult parseResult =
|
||||||
|
parser.parseOptionalInteger(spatial_dim);
|
||||||
|
if (parseResult.hasValue()) {
|
||||||
|
if (parseResult.getValue().failed()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
// We were successful in parsing an integer. Add its index to the
|
||||||
|
// spatial dims.
|
||||||
|
spatial_dims.push_back(index);
|
||||||
|
} else {
|
||||||
|
// We did not parse an integer. We expect a keyword token.
|
||||||
|
StringRef keyword;
|
||||||
|
if (parser.parseKeyword(&keyword)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (keyword.size() != 1 || allowed_non_spatial_dims.empty()) {
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"Unexpected keyword ")
|
||||||
|
<< keyword;
|
||||||
|
}
|
||||||
|
// Check if the keyword matches one of the allowed non-spatial dims.
|
||||||
|
// If so, add it to the non_spatial dims and remove it from the
|
||||||
|
// allowed set so that it won't be allowed again.
|
||||||
|
bool is_allowed = false;
|
||||||
|
for (NonSpatialDim allowed : allowed_non_spatial_dims) {
|
||||||
|
if (keyword[0] == NonSpatialDimToString(allowed)) {
|
||||||
|
non_spatial_dims.insert(
|
||||||
|
{allowed, parser.getBuilder().getI64IntegerAttr(index)});
|
||||||
|
allowed_non_spatial_dims.erase(allowed);
|
||||||
|
is_allowed = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_allowed) {
|
||||||
|
mlir::InFlightDiagnostic diag = parser.emitError(
|
||||||
|
parser.getCurrentLocation(), "Unexpected dimension ");
|
||||||
|
diag << keyword << ", expecting ";
|
||||||
|
llvm::interleaveComma(
|
||||||
|
allowed_non_spatial_dims, diag,
|
||||||
|
[&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); });
|
||||||
|
return diag;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
index++;
|
||||||
|
} while (parser.parseOptionalComma().succeeded());
|
||||||
|
|
||||||
|
// Make sure all expected non-spatial dimensions are parsed.
|
||||||
|
if (!allowed_non_spatial_dims.empty()) {
|
||||||
|
mlir::InFlightDiagnostic diag =
|
||||||
|
parser.emitError(parser.getCurrentLocation(), "Expected dimensions ");
|
||||||
|
llvm::interleaveComma(
|
||||||
|
allowed_non_spatial_dims, diag,
|
||||||
|
[&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); });
|
||||||
|
diag << " not specified";
|
||||||
|
return diag;
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse ending ]
|
||||||
|
if (parser.parseRSquare()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed_dims = std::make_pair(
|
||||||
|
parser.getBuilder().getI64TensorAttr(spatial_dims), non_spatial_dims);
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
parse_dim_result_t parsed_dims;
|
||||||
|
if (parse_dims({IOBatch, IOFeature}, parsed_dims)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
DenseIntElementsAttr input_spatial_dimensions = parsed_dims.first;
|
||||||
|
IntegerAttr input_batch_dimension = parsed_dims.second[IOBatch];
|
||||||
|
IntegerAttr input_feature_dimension = parsed_dims.second[IOFeature];
|
||||||
|
if (parser.parseKeyword("x")) return failure();
|
||||||
|
if (parse_dims({KIFeature, KOFeature}, parsed_dims)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
DenseIntElementsAttr kernel_spatial_dimensions = parsed_dims.first;
|
||||||
|
IntegerAttr kernel_input_feature_dimension = parsed_dims.second[KIFeature];
|
||||||
|
IntegerAttr kernel_output_feature_dimension = parsed_dims.second[KOFeature];
|
||||||
|
if (parser.parseArrow()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (parse_dims({IOBatch, IOFeature}, parsed_dims)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
DenseIntElementsAttr output_spatial_dimensions = parsed_dims.first;
|
||||||
|
IntegerAttr output_batch_dimension = parsed_dims.second[IOBatch];
|
||||||
|
IntegerAttr output_feature_dimension = parsed_dims.second[IOFeature];
|
||||||
|
dnums = ConvDimensionNumbers::get(
|
||||||
|
input_batch_dimension, input_feature_dimension, input_spatial_dimensions,
|
||||||
|
kernel_input_feature_dimension, kernel_output_feature_dimension,
|
||||||
|
kernel_spatial_dimensions, output_batch_dimension,
|
||||||
|
output_feature_dimension, output_spatial_dimensions,
|
||||||
|
parser.getBuilder().getContext());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mhlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
|
@ -518,7 +518,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
|
||||||
-> tensor<3x5x5x4xf32> {
|
-> tensor<3x5x5x4xf32> {
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
// CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
|
// CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x5x5x4xf32>
|
||||||
// CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
// CHECK: lmhlo.convolution(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||||
// CHECK-SAME: padding = dense<[
|
// CHECK-SAME: padding = dense<[
|
||||||
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||||
// CHECK-SAME: rhs_dilation = dense<[1, 2]>
|
// CHECK-SAME: rhs_dilation = dense<[1, 2]>
|
||||||
|
|
|
@ -171,6 +171,128 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @convolution
|
||||||
|
// CHECK: lmhlo.convolution
|
||||||
|
// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
|
||||||
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
|
"lmhlo.convolution"(%arg0, %arg1, %arg2) {batch_group_count = 1 : i64,
|
||||||
|
dimension_numbers = {input_batch_dimension = 0 : i64,
|
||||||
|
input_feature_dimension = 3 : i64,
|
||||||
|
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
||||||
|
kernel_input_feature_dimension = 2 : i64,
|
||||||
|
kernel_output_feature_dimension = 3 : i64,
|
||||||
|
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||||
|
output_batch_dimension = 0 : i64,
|
||||||
|
output_feature_dimension = 3 : i64,
|
||||||
|
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
|
||||||
|
feature_group_count = 1 : i64,
|
||||||
|
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||||
|
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
|
||||||
|
window_strides = dense<[2, 1]> : tensor<2xi64>}
|
||||||
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @convolution
|
||||||
|
// CHECK: lmhlo.convolution
|
||||||
|
// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
|
||||||
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
|
lmhlo.convolution(%arg0, %arg1, %arg2)
|
||||||
|
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
|
||||||
|
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64,
|
||||||
|
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||||
|
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
|
||||||
|
window_strides = dense<[2, 1]> : tensor<2xi64>}
|
||||||
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
|
// expected-error@+2{{Unexpected dimension c, expecting b, f}}
|
||||||
|
lmhlo.convolution(%arg0, %arg1, %arg2)
|
||||||
|
dim_numbers = [c, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
|
||||||
|
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64,
|
||||||
|
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||||
|
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
|
||||||
|
window_strides = dense<[2, 1]> : tensor<2xi64>}
|
||||||
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
|
// expected-error@+2{{Unexpected dimension b, expecting i, o}}
|
||||||
|
lmhlo.convolution(%arg0, %arg1, %arg2)
|
||||||
|
dim_numbers = [b, 0, 1, f]x[0, 1, b, o]->[b, 0, 1, f]
|
||||||
|
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64,
|
||||||
|
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||||
|
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
|
||||||
|
window_strides = dense<[2, 1]> : tensor<2xi64>}
|
||||||
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
|
// expected-error@+2{{Unexpected dimension i, expecting o}}
|
||||||
|
lmhlo.convolution(%arg0, %arg1, %arg2)
|
||||||
|
dim_numbers = [b, 0, 1, f]x[0, 1, i, i]->[b, 0, 1, f]
|
||||||
|
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64,
|
||||||
|
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||||
|
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
|
||||||
|
window_strides = dense<[2, 1]> : tensor<2xi64>}
|
||||||
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
|
// expected-error@+2{{Expected dimensions f not specified}}
|
||||||
|
lmhlo.convolution(%arg0, %arg1, %arg2)
|
||||||
|
dim_numbers = [b, 0, 1]x[0, 1, i, o]->[b, 0, 1, f]
|
||||||
|
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64,
|
||||||
|
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||||
|
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
|
||||||
|
window_strides = dense<[2, 1]> : tensor<2xi64>}
|
||||||
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
|
// expected-error@+2{{Unexpected keyword b}}
|
||||||
|
lmhlo.convolution(%arg0, %arg1, %arg2)
|
||||||
|
dim_numbers = [b, 0, 1, f]x[0, 1, i, o, b]->[b, 0, 1, f]
|
||||||
|
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64,
|
||||||
|
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||||
|
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
|
||||||
|
window_strides = dense<[2, 1]> : tensor<2xi64>}
|
||||||
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @convolution(%arg0: memref<2x2x3x4xf32>, %arg1: memref<3x5x5x3xf32>, %arg2: memref<3x5x5x4xf32>) {
|
||||||
|
// expected-error@+2{{expected '['}}
|
||||||
|
lmhlo.convolution(%arg0, %arg1, %arg2)
|
||||||
|
dim_numbers = {b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
|
||||||
|
{ batch_group_count = 1 : i64, feature_group_count = 1 : i64,
|
||||||
|
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||||
|
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
|
||||||
|
window_strides = dense<[2, 1]> : tensor<2xi64>}
|
||||||
|
: (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () return
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
// CHECK-LABEL: func @exp
|
// CHECK-LABEL: func @exp
|
||||||
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
"lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
"lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||||
|
|
Loading…
Reference in New Issue