Add shape inference and names (#266)

* Add shape inference and names

 - Add shape inference for PRelu
 - Fix shape inference for group conv
   for ConvTranspose
 - Add input and output names for
   graphs (functions)
 - Add support for (u)int8 tensor
   attributes

* Fix format issues

* Revert formatting for gen_onnx_mlir.py

* Pads can have ArrayAttr and DenseElementsAttr so support both

* NumInputs is the number of graph inputs that don't have initializers

* Add test for 2D batchnorm

* Fix typo in define_loops in new 2d BN test

* Change 'name' to 'onnx_node_name'

* Fix Batchnorm for 2D I/O and add lowering test

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Aman LaChapelle 2020-08-27 12:46:27 -07:00 committed by GitHub
parent 11a5029c10
commit 24d0a2ac71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 278 additions and 42 deletions

View File

@ -166,6 +166,15 @@ mlir::DenseElementsAttr onnxTensorProtoToDenseElmAttr(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
case (onnx::TensorProto::DOUBLE): {
const auto &arrayAttrInitializer =
CreateArrayAttribute<double>(initializer);
auto elmType = builder.getF64Type();
auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
denseElmAttr = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
case (onnx::TensorProto::INT8): {
const auto &arrayAttrInitializer =
CreateArrayAttribute<int32_t>(initializer);
@ -175,6 +184,15 @@ mlir::DenseElementsAttr onnxTensorProtoToDenseElmAttr(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
case (onnx::TensorProto::UINT8): {
const auto &arrayAttrInitializer =
CreateArrayAttribute<int32_t>(initializer);
auto elmType = builder.getIntegerType(8, false);
auto tensorType = mlir::RankedTensorType::get(tensorDims, elmType);
denseElmAttr = mlir::DenseElementsAttr::get(
tensorType, llvm::makeArrayRef(arrayAttrInitializer));
break;
}
case (onnx::TensorProto::INT32): {
const auto &arrayAttrInitializer =
CreateArrayAttribute<int32_t>(initializer);

View File

@ -155,6 +155,12 @@ private:
auto attr = node.attribute(i);
attributes.push_back(convertOnnxAttributeProtoToMlirNamedAttribute(attr));
}
// If the node has a name, then import it.
if (node.has_name()) {
attributes.push_back(builder_.getNamedAttr(
"onnx_node_name", builder_.getStringAttr(node.name())));
}
return attributes;
}
@ -397,6 +403,88 @@ private:
}
}
void ImportNodeSlice(const onnx::NodeProto &node) {
std::array<mlir::Value, 5> inVals = {
nullptr,
};
for (const auto &item : llvm::enumerate(node.input())) {
if (initializedTensors.ContainKey(legalize_name(item.value()))) {
inVals[item.index()] = initializedTensors.EmitInitializerForInputTensor(
UnknownLoc(), builder_, legalize_name(item.value()));
} else if (frontend_symbols_.ContainKey(legalize_name(item.value()))) {
inVals[item.index()] =
frontend_symbols_.GetTensorByOnnxName(item.value());
} else {
assert(false && "Unknown input");
}
}
// Data input is imported but starts, ends, axes, and steps may come from
// attributes, and need to be created as constant ops.
const auto elementType = builder_.getIntegerType(64);
const auto tensorType = mlir::RankedTensorType::get({1}, elementType);
const auto attributes = ImportNodeAttributes(node);
for (auto attr : attributes) {
if (auto arrayAttr = attr.second.dyn_cast<mlir::ArrayAttr>()) {
auto constantDenseAttribute =
mlir::DenseElementsAttr::get(tensorType, arrayAttr.getValue());
auto constantOp = builder_.create<mlir::ONNXConstantOp>(
UnknownLoc(), mlir::Attribute(), constantDenseAttribute);
mlir::Value constantValue = constantOp.output();
// Map from ONNX attributes to indices, which are
// matched with ONNXSliceOp::build ordering.
auto inputIdx = llvm::StringSwitch<int>(attr.first)
.Case("starts", 1)
.Case("ends", 2)
.Case("axes", 3)
.Case("steps", 4)
.Default(-1);
if (inputIdx < 0)
continue;
assert(inVals[inputIdx] == nullptr &&
"This input has already been filled in");
inVals[inputIdx] = constantValue;
}
}
assert(inVals[1] != nullptr && "Slice requires a starts attribute");
assert(inVals[2] != nullptr && "Slice requires an ends attribute");
const auto startsType = inVals[1].getType().dyn_cast<RankedTensorType>();
assert(startsType != nullptr && "starts type is not a RankedTensorType");
auto startsDim = startsType.getShape()[0];
// If axes is not specified, default to [0, ..., ndim-1]
if (inVals[3] == nullptr) {
SmallVector<int64_t, 1> vals = {};
for (size_t s = 0; s < startsDim; ++s)
vals.emplace_back(s);
auto constantDenseAttribute =
mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(vals));
auto constantOp = builder_.create<mlir::ONNXConstantOp>(
UnknownLoc(), mlir::Attribute(), constantDenseAttribute);
mlir::Value constantResult = constantOp.output();
inVals[3] = constantResult;
}
// If steps is not specified, default to [1, ..., 1]
if (inVals[4] == nullptr) {
SmallVector<int64_t, 1> vals(startsDim, 1);
auto constantDenseAttribute =
mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(vals));
auto constantOp = builder_.create<mlir::ONNXConstantOp>(
UnknownLoc(), mlir::Attribute(), constantDenseAttribute);
mlir::Value constantResult = constantOp.output();
inVals[4] = constantResult;
}
int nIn = mlir::ONNXSliceOp::getNumberOfOperands();
int nOut = mlir::ONNXSliceOp::getNumberOfResults();
const auto in = std::vector<mlir::Value>(inVals.begin(), inVals.end());
buildOutputAndOperation<mlir::ONNXSliceOp>(node, in, nIn, nOut);
}
void ImportNode(const onnx::NodeProto &node) {
llvm::StringRef opName = node.op_type();
@ -406,7 +494,9 @@ private:
// the generic operator is used
// one known reeason is the optional input
(this->*(import_handler_map_[opName.str()]))(node);
auto found = import_handler_map_.find(opName.str());
assert(found != import_handler_map_.end() && "Could not find op importer");
(this->*(found->second))(node);
}
void InitHandlerMap() {
@ -452,20 +542,41 @@ private:
// * maintain a list of the defined graph
llvm::SmallVector<mlir::Type, 4> arg_types;
// Get a list of function attributes - including names of inputs and outputs
llvm::SmallVector<mlir::NamedAttribute, 4> funcAttrs;
llvm::SmallVector<llvm::StringRef, 4> inputNames;
llvm::SmallVector<llvm::StringRef, 4> outputNames;
// Import the input tensor types that are not constant and not initialized.
for (const auto &input : graph.input())
if (!initializedTensors.ContainKey(legalize_name(input.name())))
int numInputs = 0;
for (const auto &input : graph.input()) {
if (!initializedTensors.ContainKey(legalize_name(input.name()))) {
inputNames.push_back(input.name());
arg_types.emplace_back(ImportInputTensorType(input));
// numInputs is the number of graph inputs not contained within the
// initializer
++numInputs;
}
}
for (const auto &output : graph.output()) {
outputNames.push_back(output.name());
}
funcAttrs.emplace_back(builder_.getNamedAttr(
"input_names", builder_.getStrArrayAttr(inputNames)));
funcAttrs.emplace_back(builder_.getNamedAttr(
"output_names", builder_.getStrArrayAttr(outputNames)));
// Create the main function.
auto funcType = builder_.getFunctionType(arg_types, {});
auto mainFunc =
mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {});
auto mainFunc = mlir::FuncOp::create(UnknownLoc(), name, funcType,
/* attrs = */ llvm::makeArrayRef(funcAttrs));
// Emit the entry point operation which specifies the number of user
// inputs and outputs.
auto entryPoint = mlir::ONNXEntryPointOp::create(UnknownLoc(), mainFunc,
/*numInputs=*/graph.input().size() - graph.initializer().size(),
/*numInputs=*/numInputs,
/*numOutputs=*/graph.output().size());
// Get the entru block inside the main function and set the insertion point

View File

@ -269,7 +269,7 @@ import_handler_map_["Sinh"] =
import_handler_map_["Size"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSizeOp>;
import_handler_map_["Slice"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSliceOp>;
&onnx_mlir::detail::FrontendGenImpl::ImportNodeSlice;
import_handler_map_["Softmax"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSoftmaxOp>;
import_handler_map_["Softplus"] =

View File

@ -106,6 +106,9 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
loopIVs.emplace_back(loopCIVs[0]); // Insert C back.
for (int i = 1; i < args.size(); ++i)
loopIVs.emplace_back(args[i]);
} else if (rank == 2) {
loopIVs.emplace_back(args[0]);
loopIVs.emplace_back(loopCIVs[0]); // Insert C back.
} else {
loopIVs.emplace_back(args[0]);
}

View File

@ -699,6 +699,16 @@ LogicalResult ONNXSeluOp::inferShapes() {
return success();
}
//===----------------------------------------------------------------------===//
// PRelu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXPReluOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXPReluOp::inferShapes() {
getResult().setType(getOperand(0).getType());
return success();
}
//===----------------------------------------------------------------------===//
// Reciprocal
//===----------------------------------------------------------------------===//
@ -1164,14 +1174,13 @@ LogicalResult ONNXBatchNormalizationTestModeOp::inferShapes() {
// Check whether the shapes of scale, bias, mean and variance are valid.
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
// In case of N, C is assumed to be 1.
// 2-D tensors are assumed to be of shape NxC
// Shapes of scale, bias, mean and variance must be C.
int64_t c = -1;
if (inputTensorTy.getShape().size() == 1) {
c = 1;
} else if (inputTensorTy.getShape().size() > 2) {
} else if (inputTensorTy.getShape().size() >= 2) {
c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
} else {
return emitError("Wrong rank for the input");
}
if (c != -1) {
@ -1419,8 +1428,10 @@ LogicalResult ONNXConvOp::inferShapes() {
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (xShape[1] != -1 && weightShape[1] != -1 &&
xShape[1] != (weightShape[1] * group))
return emitError("Channel dimension mismatch");
xShape[1] != (weightShape[1] * group)) {
return emitOpError("Channel dimension mismatch")
<< xTy << " " << weightTy << " " << group;
}
// Check the size of bias.
if (hasBias) {
@ -1500,7 +1511,7 @@ LogicalResult ONNXConvOp::inferShapes() {
LogicalResult ONNXConvTransposeOp::inferShapes() {
// Generic shape for data input X, weight tensor W, and optional bias B
// X: (N x C x D1 x D2 ... x Dn)
// W: (M x C/group x k1 x k2 x ... x kn)
// W: (C x M/group x k1 x k2 x ... x kn)
// B: (M) Optional
bool hasBias = !B().getType().isa<NoneType>();
@ -1535,10 +1546,15 @@ LogicalResult ONNXConvTransposeOp::inferShapes() {
if (!groupAttr())
groupAttr(builder.getI64IntegerAttr(group));
// Check that the X.shape[1] == (W.shape[0] * group) == C condition holds.
if (xShape[1] != -1 && weightShape[0] != -1 &&
xShape[1] != (weightShape[0] * group)) {
return emitError("Channel dimension mismatch");
int64_t inChannels = weightShape[0];
int64_t outChannels = weightShape[1] * group;
// Check that the X.shape[1] == W.shape[0] == C && X.shape[1] % group == 0
// condition holds.
if (xShape[1] != -1 && inChannels != -1 && xShape[1] != inChannels &&
xShape[1] % group != 0) {
return emitOpError("Channel dimension mismatch")
<< xTy << " " << weightTy << " " << group;
}
// Check the size of bias.
@ -1548,9 +1564,9 @@ LogicalResult ONNXConvTransposeOp::inferShapes() {
if (bShape.size() != 1) {
return emitError("bias should be one dimensional");
}
if (bShape[0] != weightShape[1]) {
if (bShape[0] != outChannels) {
return emitError(
"bias should have same dimensions as weight's second dimension");
"bias should have same dimensions as number of output channels");
}
}
@ -1601,8 +1617,9 @@ LogicalResult ONNXConvTransposeOp::inferShapes() {
SmallVector<int64_t, 4> outputDims;
// Insert batch size.
outputDims.emplace_back(xShape[0]);
// Insert number of filters being applied (number of output channels).
outputDims.emplace_back(weightShape[1]);
// Insert number of filters being applied (number of output channels *
// groups).
outputDims.emplace_back(outChannels);
// Compute and insert spatial dims.
insertConvTransposeSpatialDim(outputDims, xShape, kernelShape, padsOpt,
stridesOpt, outputPads, outputShape, dilationsOpt);
@ -1730,22 +1747,29 @@ LogicalResult ONNXPadOp::inferShapes() {
if (!data().getType().isa<RankedTensorType>())
return emitError("Pad: unknown input shape");
// Cannot infer if the pads is not constant
DenseElementsAttr padsAttributes =
getAttr("pads").dyn_cast_or_null<mlir::DenseElementsAttr>();
if (!padsAttributes)
return emitError("Pad: unknown pads");
auto dataTy = data().getType().cast<RankedTensorType>();
auto dataShape = dataTy.getShape();
auto dataRank = dataTy.getRank();
SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
// Get pads from valueAttribute.
Attribute padattr = getAttr("pads");
SmallVector<int64_t, 2> pads(dataRank * 2, -1);
auto valueIt = padsAttributes.getValues<IntegerAttr>().begin();
// Sometimes it's an ArrayAttr and sometimes it's a DenseElementsAttr, so
// handle both cases.
if (ArrayAttr padsAttributes = padattr.dyn_cast_or_null<mlir::ArrayAttr>()) {
auto valueIt = padsAttributes.getValue().begin();
for (int64_t i = 0; i < dataRank * 2; ++i)
pads[i] = (*valueIt++).cast<IntegerAttr>().getInt();
} else if (DenseElementsAttr padsAttributes =
padattr.dyn_cast_or_null<mlir::DenseElementsAttr>()) {
auto valueIt = padsAttributes.getValues<IntegerAttr>().begin();
for (int64_t i = 0; i < dataRank * 2; ++i)
pads[i] = (*valueIt++).getInt();
} else {
// Cannot infer if the pads is not constant
return emitError("Pad: unknown pads ") << getAttr("pads");
}
// Pads consists of two values for each axis of data.
// The two values specify the number of elements padded before and after
@ -2012,9 +2036,11 @@ LogicalResult ONNXConcatOp::inferShapes() {
return emitError("Concat axis being concatenated is "
"expected to be known at compile time for now");
} else if (currShape[j] != commonShape[j]) {
return emitError(
"Concat input dimensions must be all identical, "
"except for dimension on the axis of the concatenation");
return emitError("Concat input dimensions must be all identical, "
"except for dimension on the axis of the "
"concatenation. Expected something compatible with: ")
<< commonType << " but got " << getOperand(i).getType()
<< " instead.";
}
}
cummulativeAxisSize += currShape[axisIndex];
@ -2687,6 +2713,14 @@ LogicalResult ONNXSliceOp::inferShapes() {
outputDims[axis] = q;
}
// Fill in the rest of the dimensions - assume they're untouched.
for (int i = 0, e = outputDims.size(); i < e; ++i) {
if (llvm::any_of(axesValue, [i](int64_t a) { return a == i; })) {
continue;
}
outputDims[i] = dataShape[i];
}
getResult().setType(RankedTensorType::get(outputDims, elementType));
return success();
}

View File

@ -3090,7 +3090,7 @@ def ONNXOrOp:ONNX_Op<"Or",
}
def ONNXPReluOp:ONNX_Op<"PRelu",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX PRelu operation";
let description = [{
"PRelu takes input data (Tensor<T>) and slope tensor as input, and produces one"

View File

@ -1398,6 +1398,35 @@ func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %a
// -----
func @test_batchnorm_testmode_2d(%arg0: tensor<10x3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>, %arg3: tensor<3xf32>, %arg4: tensor<3xf32>) -> tensor<10x3xf32> {
%0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<10x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<10x3xf32>
return %0 : tensor<10x3xf32>
// CHECK-LABEL: test_batchnorm_testmode_2d
// CHECK: [[RES:%.+]] = alloc() : memref<10x3xf32>
// CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: krnl.iterate([[DEF_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg5 = 0 to 3) {
// CHECK: [[SCALE:%.+]] = affine.load %arg1[%arg5] : memref<3xf32>
// CHECK: [[BIAS:%.+]] = affine.load %arg2[%arg5] : memref<3xf32>
// CHECK: [[MEAN:%.+]] = affine.load %arg3[%arg5] : memref<3xf32>
// CHECK: [[VARIANCE:%.+]] = affine.load %arg4[%arg5] : memref<3xf32>
// CHECK: krnl.iterate([[DEF_LOOPS]]#0) with ([[DEF_LOOPS]]#0 -> %arg6 = 0 to 10) {
// CHECK: [[LOADED_VAL:%.+]] = affine.load %arg0[%arg6, %arg5] : memref<10x3xf32>
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
// CHECK: [[DIVISOR:%.+]] = sqrt [[ADJUSTED_VARIANCE]] : f32
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
// CHECK: affine.store [[SHIFT_SCALE_NORM]], [[RES]][%arg6, %arg5] : memref<10x3xf32>
// CHECK: }
// CHECK: }
// CHECK: return [[RES]] : memref<10x3xf32>
}
// -----
func @test_abs_float(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Abs"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()

View File

@ -350,6 +350,32 @@ func @test_conv_12(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>, %arg2
// -----
//===----------------------------------------------------------------------===//
/// Test shape inference for ConvTranspose.
//===----------------------------------------------------------------------===//
func @test_conv_transpose_1(%arg0 : tensor<1x64x36x48xf32>, %arg1 : tensor<64x1x2x2xf32>) -> tensor<*xf32> {
%cst = constant unit
%0 = "onnx.ConvTranspose"(%arg0, %arg1, %cst) {dilations = [1, 1], kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x64x36x48xf32>, tensor<64x1x2x2xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_transpose_1
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvTranspose"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [2, 2], output_shape = [1, 1, 72, 96], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x64x36x48xf32>, tensor<64x1x2x2xf32>, none) -> tensor<1x1x72x96xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x1x72x96xf32>
}
func @test_conv_transpose_2(%arg0 : tensor<1x64x36x48xf32>, %arg1 : tensor<64x1x2x2xf32>) -> tensor<*xf32> {
%cst = constant unit
%0 = "onnx.ConvTranspose"(%arg0, %arg1, %cst) {dilations = [1, 1], group = 64 : i64, kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x64x36x48xf32>, tensor<64x1x2x2xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_transpose_2
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvTranspose"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 64 : i64, kernel_shape = [2, 2], output_shape = [1, 64, 72, 96], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x64x36x48xf32>, tensor<64x1x2x2xf32>, none) -> tensor<1x64x72x96xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x64x72x96xf32>
}
// -----
//===----------------------------------------------------------------------===//
/// Test shape inference for PadConstantValuePad.
//===----------------------------------------------------------------------===//
@ -357,15 +383,26 @@ func @test_conv_12(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>, %arg2
/// Test Pad_1
func @test_Pad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {
%cst = constant unit
%0 = "onnx.Pad"(%arg0, %cst, %cst) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<16x13xf32>, none, none) -> tensor<*xf32>
%0 = "onnx.Pad"(%arg0, %cst, %cst) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = [0, 2, 2, 4]} : (tensor<16x13xf32>, none, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_Pad_1
// CHECK-NEXT: [[NONE:%.+]] = constant unit
// CHECK: [[RES:%.+]] = "onnx.Pad"(%arg0, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<16x13xf32>, none, none) -> tensor<18x19xf32>
// CHECK: [[RES:%.+]] = "onnx.Pad"(%arg0, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = [0, 2, 2, 4]} : (tensor<16x13xf32>, none, none) -> tensor<18x19xf32>
// CHECK: return [[RES]] : tensor<18x19xf32>
}
/// Test Pad_2
func @test_Pad_2(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {
%cst = constant unit
%0 = "onnx.Pad"(%arg0, %cst, %cst) {mode = "edge", pads = [0, 2, 2, 4]} : (tensor<16x13xf32>, none, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_Pad_2
// CHECK-NEXT: [[NONE:%.+]] = constant unit
// CHECK: [[RES:%.+]] = "onnx.Pad"(%arg0, [[NONE]], [[NONE]]) {mode = "edge", pads = [0, 2, 2, 4]} : (tensor<16x13xf32>, none, none) -> tensor<18x19xf32>
// CHECK: return [[RES]] : tensor<18x19xf32>
}
/// Test PadConstantValuePad_1
func @test_PadConstantValuePad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {

View File

@ -1,7 +1,9 @@
find_package(Python 3 REQUIRED COMPONENTS Interpreter)
# Invoke gen_onnx_mlir.py to obtain ONNXOps.td.inc, OpBuildTable.inc.
add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/ONNXOps.td.inc
${CMAKE_CURRENT_SOURCE_DIR}/OpBuildTable.inc
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py
COMMAND Python::Interpreter ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py)
# Move the generated files to respective destinations:

View File

@ -239,6 +239,7 @@ special_op_handler = dict([
("MaxPool", "ImportNodeMaxPool"),
("BatchNormalization", "ImportNodeBatchNormalization"),
("Pad", "ImportNodePad"),
("Slice", "ImportNodeSlice"),
#("Transpose", "ImportNodeTranspose")
])
@ -284,6 +285,7 @@ OpsWithShapeInference=[
'Or',
'Pad',
'Pow',
'PRelu',
'QuantizeLinear',
'RNN',
'Reciprocal',