Implement PadOp based on attribute promotion (#71)

* enable promote attr for pad

* use optional arguments for pad

* shape infereance for pad

* Lowering Pad

* format file

* use DenseTensor for the attribute

* use Pad in ONNXRewrite

* fix the merge conflict

* fix the attr given to constantOp

* handle ONNXConstantOp in attribute promotion

* Fix bug when AttributePromotion is called more than once

* update ONNXOps.td.inc with correct version of onnx

* update onnx.md

* responses to review

* fix the build error

* change the implementation of Pad

* delete commented out code

* clang format

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
chentong319 2020-05-15 01:19:28 -04:00 committed by GitHub
parent 4a68597417
commit 23bea50404
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 530 additions and 182 deletions

View File

@ -3033,7 +3033,7 @@ ONNX Pad operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`data` | memref of any type values or tensor of any type values `data` | memref of any type values or tensor of any type values
`pads` | memref of any type values or tensor of any type values `pads` | memref of any type values or tensor of any type values or none type
`constant_value` | memref of any type values or tensor of any type values or none type `constant_value` | memref of any type values or tensor of any type values or none type
#### Results: #### Results:

View File

@ -303,9 +303,34 @@ private:
* Special handle for Pad operations. * Special handle for Pad operations.
*/ */
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) { void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
int nOps = node.input().size(); int nOps = node.input().size();
if (nOps == 2) { if (nOps == 2) {
buildOperation<mlir::ONNXPadConstantValueOp>(node, 2, nOut); llvm::SmallVector<int64_t, 2> dims;
dims.push_back(1);
llvm::SmallVector<float, 2> values;
values.push_back(0.);
auto elementType = builder_.getF32Type();
llvm::ArrayRef<int64_t> tensorDims(dims.data(), dims.size());
auto tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
auto constantDenseAttribute =
mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
// Use the special builder defined in ONNXOp.td.inc.
auto constantOp = builder_.create<mlir::ONNXConstantOp>(
UnknownLoc(), mlir::Attribute(), constantDenseAttribute);
mlir::Value constantResult = *(constantOp.getODSResults(0).begin());
std::vector<mlir::Value> inputs;
for (const auto &item : node.input())
if (initializedTensors.ContainKey(legalize_name(item))) {
inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
UnknownLoc(), builder_, legalize_name(item)));
} else if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
inputs.push_back(constantResult);
buildOutputAndOperation<mlir::ONNXPadOp>(node, inputs, nIn, nOut);
} else { } else {
buildOperation<mlir::ONNXPadOp>(node, nIn, nOut); buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
} }

View File

@ -5,314 +5,314 @@
//******************************************************** //********************************************************
if (opName == "Abs") if (opName == "Abs")
return buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Acos") if (opName == "Acos")
return buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Acosh") if (opName == "Acosh")
return buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Add") if (opName == "Add")
return buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "And") if (opName == "And")
return buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "ArgMax") if (opName == "ArgMax")
return buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ArgMin") if (opName == "ArgMin")
return buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Asin") if (opName == "Asin")
return buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Asinh") if (opName == "Asinh")
return buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Atan") if (opName == "Atan")
return buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Atanh") if (opName == "Atanh")
return buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "AveragePool") if (opName == "AveragePool")
return buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "BatchNormalization") if (opName == "BatchNormalization")
return ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5); ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
if (opName == "BitShift") if (opName == "BitShift")
return buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Cast") if (opName == "Cast")
return buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Ceil") if (opName == "Ceil")
return buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Clip") if (opName == "Clip")
return buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Compress") if (opName == "Compress")
return buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Concat") if (opName == "Concat")
return buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "ConcatFromSequence") if (opName == "ConcatFromSequence")
return buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Constant") if (opName == "Constant")
return buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "ConstantOfShape") if (opName == "ConstantOfShape")
return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Conv") if (opName == "Conv")
return buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ConvInteger") if (opName == "ConvInteger")
return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "ConvTranspose") if (opName == "ConvTranspose")
return buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Cos") if (opName == "Cos")
return buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Cosh") if (opName == "Cosh")
return buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "CumSum") if (opName == "CumSum")
return buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "DepthToSpace") if (opName == "DepthToSpace")
return buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "DequantizeLinear") if (opName == "DequantizeLinear")
return buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Det") if (opName == "Det")
return buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Div") if (opName == "Div")
return buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Dropout") if (opName == "Dropout")
return buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
if (opName == "DynamicQuantizeLinear") if (opName == "DynamicQuantizeLinear")
return buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3); buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
if (opName == "Elu") if (opName == "Elu")
return buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Equal") if (opName == "Equal")
return buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Erf") if (opName == "Erf")
return buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Exp") if (opName == "Exp")
return buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Expand") if (opName == "Expand")
return buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "EyeLike") if (opName == "EyeLike")
return buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Flatten") if (opName == "Flatten")
return buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Floor") if (opName == "Floor")
return buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "GRU") if (opName == "GRU")
return buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2); buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
if (opName == "Gather") if (opName == "Gather")
return buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "GatherElements") if (opName == "GatherElements")
return buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "GatherND") if (opName == "GatherND")
return buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Gemm") if (opName == "Gemm")
return buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "GlobalAveragePool") if (opName == "GlobalAveragePool")
return buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "GlobalLpPool") if (opName == "GlobalLpPool")
return buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "GlobalMaxPool") if (opName == "GlobalMaxPool")
return buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Greater") if (opName == "Greater")
return buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "HardSigmoid") if (opName == "HardSigmoid")
return buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Hardmax") if (opName == "Hardmax")
return buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Identity") if (opName == "Identity")
return buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "If") if (opName == "If")
return buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1); buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
if (opName == "InstanceNormalization") if (opName == "InstanceNormalization")
return buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "IsInf") if (opName == "IsInf")
return buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "IsNaN") if (opName == "IsNaN")
return buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LRN") if (opName == "LRN")
return buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LSTM") if (opName == "LSTM")
return buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3); buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
if (opName == "LeakyRelu") if (opName == "LeakyRelu")
return buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Less") if (opName == "Less")
return buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Log") if (opName == "Log")
return buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LogSoftmax") if (opName == "LogSoftmax")
return buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Loop") if (opName == "Loop")
return buildOperation<mlir::ONNXLoopOp>(node); buildOperation<mlir::ONNXLoopOp>(node);
if (opName == "LpNormalization") if (opName == "LpNormalization")
return buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "LpPool") if (opName == "LpPool")
return buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "MatMul") if (opName == "MatMul")
return buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "MatMulInteger") if (opName == "MatMulInteger")
return buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "Max") if (opName == "Max")
return buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "MaxPool") if (opName == "MaxPool")
return ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2); ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
if (opName == "MaxRoiPool") if (opName == "MaxRoiPool")
return buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "MaxUnpool") if (opName == "MaxUnpool")
return buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Mean") if (opName == "Mean")
return buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "MeanVarianceNormalization") if (opName == "MeanVarianceNormalization")
return buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Min") if (opName == "Min")
return buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "Mod") if (opName == "Mod")
return buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Mul") if (opName == "Mul")
return buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Multinomial") if (opName == "Multinomial")
return buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Neg") if (opName == "Neg")
return buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "NonMaxSuppression") if (opName == "NonMaxSuppression")
return buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1); buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
if (opName == "NonZero") if (opName == "NonZero")
return buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Not") if (opName == "Not")
return buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "OneHot") if (opName == "OneHot")
return buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Or") if (opName == "Or")
return buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "PRelu") if (opName == "PRelu")
return buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Pad") if (opName == "Pad")
return ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Pow") if (opName == "Pow")
return buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "QLinearConv") if (opName == "QLinearConv")
return buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1); buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
if (opName == "QLinearMatMul") if (opName == "QLinearMatMul")
return buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1); buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
if (opName == "QuantizeLinear") if (opName == "QuantizeLinear")
return buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "RNN") if (opName == "RNN")
return buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2); buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
if (opName == "RandomNormal") if (opName == "RandomNormal")
return buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "RandomNormalLike") if (opName == "RandomNormalLike")
return buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "RandomUniform") if (opName == "RandomUniform")
return buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "RandomUniformLike") if (opName == "RandomUniformLike")
return buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Range") if (opName == "Range")
return buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Reciprocal") if (opName == "Reciprocal")
return buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceL1") if (opName == "ReduceL1")
return buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceL2") if (opName == "ReduceL2")
return buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceLogSum") if (opName == "ReduceLogSum")
return buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceLogSumExp") if (opName == "ReduceLogSumExp")
return buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceMax") if (opName == "ReduceMax")
return buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceMean") if (opName == "ReduceMean")
return buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceMin") if (opName == "ReduceMin")
return buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceProd") if (opName == "ReduceProd")
return buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceSum") if (opName == "ReduceSum")
return buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ReduceSumSquare") if (opName == "ReduceSumSquare")
return buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Relu") if (opName == "Relu")
return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Reshape") if (opName == "Reshape")
return ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Resize") if (opName == "Resize")
return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "ReverseSequence") if (opName == "ReverseSequence")
return buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "RoiAlign") if (opName == "RoiAlign")
return buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Round") if (opName == "Round")
return buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Scan") if (opName == "Scan")
return buildOperation<mlir::ONNXScanOp>(node); buildOperation<mlir::ONNXScanOp>(node);
if (opName == "Scatter") if (opName == "Scatter")
return buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ScatterElements") if (opName == "ScatterElements")
return buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ScatterND") if (opName == "ScatterND")
return buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Selu") if (opName == "Selu")
return buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "SequenceAt") if (opName == "SequenceAt")
return buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "SequenceConstruct") if (opName == "SequenceConstruct")
return buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "SequenceEmpty") if (opName == "SequenceEmpty")
return buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
if (opName == "SequenceErase") if (opName == "SequenceErase")
return buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "SequenceInsert") if (opName == "SequenceInsert")
return buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "SequenceLength") if (opName == "SequenceLength")
return buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Shape") if (opName == "Shape")
return buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Shrink") if (opName == "Shrink")
return buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sigmoid") if (opName == "Sigmoid")
return buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sign") if (opName == "Sign")
return buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sin") if (opName == "Sin")
return buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sinh") if (opName == "Sinh")
return buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Size") if (opName == "Size")
return buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Slice") if (opName == "Slice")
return buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
if (opName == "Softmax") if (opName == "Softmax")
return buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Softplus") if (opName == "Softplus")
return buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Softsign") if (opName == "Softsign")
return buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "SpaceToDepth") if (opName == "SpaceToDepth")
return buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Split") if (opName == "Split")
return buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1); buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
if (opName == "SplitToSequence") if (opName == "SplitToSequence")
return buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Sqrt") if (opName == "Sqrt")
return buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Squeeze") if (opName == "Squeeze")
return buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "StringNormalizer") if (opName == "StringNormalizer")
return buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Sub") if (opName == "Sub")
return buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Sum") if (opName == "Sum")
return buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
if (opName == "Tan") if (opName == "Tan")
return buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Tanh") if (opName == "Tanh")
return buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "TfIdfVectorizer") if (opName == "TfIdfVectorizer")
return buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "ThresholdedRelu") if (opName == "ThresholdedRelu")
return buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Tile") if (opName == "Tile")
return buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "TopK") if (opName == "TopK")
return buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2); buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
if (opName == "Transpose") if (opName == "Transpose")
return buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Unique") if (opName == "Unique")
return buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4); buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
if (opName == "Unsqueeze") if (opName == "Unsqueeze")
return buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Upsample") if (opName == "Upsample")
return buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
if (opName == "Where") if (opName == "Where")
return buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "Xor") if (opName == "Xor")
return buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1); buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);

View File

@ -15,6 +15,7 @@ add_library(OMONNXToKrnl
Tensor/Identity.cpp Tensor/Identity.cpp
Tensor/Reshape.cpp Tensor/Reshape.cpp
Tensor/PadConstantValuePad.cpp Tensor/PadConstantValuePad.cpp
Tensor/Pad.cpp
Tensor/Transpose.cpp Tensor/Transpose.cpp
Tensor/Unsqueeze.cpp Tensor/Unsqueeze.cpp
Tensor/Constant.cpp Tensor/Constant.cpp

View File

@ -93,6 +93,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
// Tensor // Tensor
populateLoweringONNXReshapeOpPattern(patterns, &getContext()); populateLoweringONNXReshapeOpPattern(patterns, &getContext());
populateLoweringONNXPadConstantValuePadOpPattern(patterns, &getContext()); populateLoweringONNXPadConstantValuePadOpPattern(patterns, &getContext());
populateLoweringONNXPadOpPattern(patterns, &getContext());
populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext()); populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
populateLoweringONNXTransposeOpPattern(patterns, &getContext()); populateLoweringONNXTransposeOpPattern(patterns, &getContext());
populateLoweringONNXIdentityOpPattern(patterns, &getContext()); populateLoweringONNXIdentityOpPattern(patterns, &getContext());

View File

@ -237,6 +237,9 @@ void populateLoweringONNXTransposeOpPattern(
void populateLoweringONNXPadConstantValuePadOpPattern( void populateLoweringONNXPadConstantValuePadOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx); OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXPadOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
void populateLoweringONNXReshapeOpPattern( void populateLoweringONNXReshapeOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx); OwningRewritePatternList &patterns, MLIRContext *ctx);

View File

@ -0,0 +1,120 @@
//===-----------------------Pad.cpp - Lowering Pad Op -------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Pad Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
using namespace mlir;
struct ONNXPadOpLowering : public ConversionPattern {
ONNXPadOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXPadOp::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
ONNXPadOp myOp = llvm::dyn_cast<ONNXPadOp>(op);
ONNXPadOpOperandAdaptor operandAdaptor(operands);
auto tensorType = myOp.output().getType();
auto loc = op->getLoc();
// Only constant padding is supported now.
auto padMode = myOp.mode();
if (padMode != "constant")
emitError(loc, "unsupported mode for Pad");
DenseElementsAttr constantValAttr =
myOp.getAttr("constant_value")
.dyn_cast_or_null<mlir::DenseElementsAttr>();
if (!constantValAttr)
emitError(loc, "unsupported value");
DenseElementsAttr padsAttributes =
myOp.getAttr("pads").dyn_cast_or_null<mlir::DenseElementsAttr>();
if (!padsAttributes)
emitError(loc, "Pad: unknown pads");
auto memRefType = convertToMemRefType(tensorType);
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
emitError(loc, "unexpected output has non-Constant shape");
// Number of loops
auto memRefShape = memRefType.getShape();
int64_t rank = memRefShape.size();
// get the padding vector into a temporary smallvector
SmallVector<int64_t, 2> pads(rank * 2, -1);
auto padsIt = padsAttributes.getValues<IntegerAttr>().begin();
for (int i = 0; i < rank * 2; ++i)
pads[i] = (*padsIt++).cast<IntegerAttr>().getInt();
// get the padding value
auto valueAttr = (*constantValAttr.getValues<FloatAttr>().begin());
// Iterate over the loop nest using the output shape.
BuildKrnlLoop padLoops(rewriter, loc, rank);
padLoops.createDefineAndOptimizeOp();
for (int i = 0; i < rank; ++i)
padLoops.pushBounds(0, alloc, i);
padLoops.createIterateOp();
// Iterate over the loop nest using the input shape.
BuildKrnlLoop valueLoops(rewriter, loc, rank);
valueLoops.createDefineAndOptimizeOp();
for (int i = 0; i < rank; ++i)
valueLoops.pushBounds(0, operandAdaptor.data(), i);
valueLoops.createIterateOp();
// Copy the input data into the output.
rewriter.setInsertionPointToStart(valueLoops.getIterateBlock());
SmallVector<Value, 4> inLoopIVs;
for (int i = 0; i < rank; ++i)
inLoopIVs.emplace_back(valueLoops.getInductionVar(i));
SmallVector<Value, 4> outLoopIVs;
for (int i = 0; i < rank; ++i) {
// Calculate the index for the load and store.
if (pads[i] == 0) {
outLoopIVs.emplace_back(valueLoops.getInductionVar(i));
} else {
auto outIV = rewriter.create<AddIOp>(loc,
rewriter.create<ConstantIndexOp>(loc, pads[i]),
valueLoops.getInductionVar(i));
outLoopIVs.emplace_back(outIV);
}
}
auto originValue =
rewriter.create<LoadOp>(loc, operandAdaptor.data(), inLoopIVs);
rewriter.create<StoreOp>(loc, originValue, alloc, outLoopIVs);
rewriter.setInsertionPointToStart(padLoops.getIterateBlock());
SmallVector<Value, 4> outLoopIVs1;
for (int i = 0; i < rank; ++i)
outLoopIVs1.emplace_back(padLoops.getInductionVar(i));
auto paddingValue = rewriter.create<ConstantOp>(loc, valueAttr);
rewriter.create<StoreOp>(loc, paddingValue, alloc, outLoopIVs1);
// Replace the original op with the generated code.
rewriter.replaceOp(op, alloc);
return success();
}
};
void populateLoweringONNXPadOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXPadOpLowering>(ctx);
}

View File

@ -1473,6 +1473,53 @@ bool ONNXMaxPoolSingleOutOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool ONNXPadOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!data().getType().isa<RankedTensorType>()) {
emitError("Pad: unknown input shape");
return false;
}
// Cannot infer if the pads is not constant
DenseElementsAttr padsAttributes =
getAttr("pads").dyn_cast_or_null<mlir::DenseElementsAttr>();
if (!padsAttributes) {
emitError("Pad: unknown pads");
return false;
}
auto dataTy = data().getType().cast<RankedTensorType>();
auto dataShape = dataTy.getShape();
auto dataRank = dataTy.getRank();
SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
// Get pads from valueAttribute.
SmallVector<int64_t, 2> pads(dataRank * 2, -1);
auto valueIt = padsAttributes.getValues<IntegerAttr>().begin();
for (int64_t i = 0; i < dataRank * 2; ++i)
pads[i] = (*valueIt++).cast<IntegerAttr>().getInt();
// Pads consists of two values for each axis of data.
// The two values specify the number of elements padded before and after
// respectively.
for (int64_t i = 0; i < dataRank; ++i) {
int64_t p1 = pads[i];
int64_t p2 = pads[i + dataRank];
// Have to non-negative constant
if (p1 < 0 || p2 < 0) {
emitError("padding value can not be negative");
return false;
}
if (outputShape[i] != -1)
outputShape[i] += p1 + p2;
}
auto outputType = RankedTensorType::get(outputShape, dataTy.getElementType());
getResult().setType(outputType);
return true;
}
static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) { static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!data.getType().isa<RankedTensorType>()) if (!data.getType().isa<RankedTensorType>())

View File

@ -348,7 +348,17 @@ def ONNXConstantOp:ONNX_Op<"Constant",
let arguments = (ins OptionalAttr<AnyAttr>:$sparse_value, let arguments = (ins OptionalAttr<AnyAttr>:$sparse_value,
OptionalAttr<AnyAttr>:$value); OptionalAttr<AnyAttr>:$value);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
} let builders = [
OpBuilder<"Builder *builder, OperationState &state, Attribute sparse_value, Attribute value", [{
if (value) {
auto tensorType = value.getType();
build(builder, state, tensorType, sparse_value, value);
} else {
auto tensorType = sparse_value.getType();
build(builder, state, tensorType, sparse_value, value);
}
}]>
];}
def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
[NoSideEffect]> { [NoSideEffect]> {
@ -1913,7 +1923,7 @@ def ONNXPReluOp:ONNX_Op<"PRelu",
} }
def ONNXPadOp:ONNX_Op<"Pad", def ONNXPadOp:ONNX_Op<"Pad",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"PromotableConstOperandsOpInterface">]> {
let summary = "ONNX Pad operation"; let summary = "ONNX Pad operation";
let description = [{ let description = [{
"Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, " "Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, "
@ -1999,10 +2009,27 @@ def ONNXPadOp:ONNX_Op<"Pad",
"" ""
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$pads, AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$pads,
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$constant_value, AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$constant_value,
DefaultValuedAttr<StrAttr, "constant">:$mode); DefaultValuedAttr<StrAttr, "constant">:$mode);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
let builders = [
OpBuilder<"Builder *builder, OperationState &state, Value data, Value pads, Value constant_value, StringAttr mode", [{
auto elementType = data.getType().cast<TensorType>().getElementType();
build(builder, state, UnrankedTensorType::get(elementType), data, pads, constant_value, mode);
}]>,
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
std::map<std::string, size_t> promotableConstOperands() {
return {{"pads", 1}, {"constant_value", 2}};
}
}];
} }
def ONNXPowOp:ONNX_Op<"Pow", def ONNXPowOp:ONNX_Op<"Pow",

View File

@ -66,6 +66,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) {
pm.addPass(mlir::createDecomposeONNXToONNXPass()); pm.addPass(mlir::createDecomposeONNXToONNXPass());
pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createAttributePromotionPass());
pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createAttributePromotionPass()); pm.addPass(mlir::createAttributePromotionPass());
} }

View File

@ -17,6 +17,8 @@
#include "src/Interface/PromotableConstOperandsOpInterface.hpp" #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
#include "src/Pass/Passes.hpp" #include "src/Pass/Passes.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
namespace { namespace {
@ -60,12 +62,25 @@ public:
// move it to an attribute, and use None to indicate the absence // move it to an attribute, and use None to indicate the absence
// of the original operand value. // of the original operand value.
auto operandToPromote = op->getOperand(i); auto operandToPromote = op->getOperand(i);
if (auto constantOp = dyn_cast_or_null<ConstantOp>( if (auto constantOp = dyn_cast_or_null<mlir::ONNXConstantOp>(
operandToPromote.getDefiningOp())) { operandToPromote.getDefiningOp())) {
op->setAttr(name, constantOp.value()); if (constantOp.valueAttr() &&
!constantOp.valueAttr().dyn_cast_or_null<UnitAttr>())
op->setAttr(name, constantOp.valueAttr());
if (constantOp.sparse_valueAttr() &&
!constantOp.sparse_valueAttr().dyn_cast_or_null<UnitAttr>())
op->setAttr(name, constantOp.sparse_valueAttr());
getOrCreateNoneValue(none, f); getOrCreateNoneValue(none, f);
op->setOperand(i, *none); op->setOperand(i, *none);
} }
if (auto constantOp = dyn_cast_or_null<ConstantOp>(
operandToPromote.getDefiningOp())) {
if (!constantOp.valueAttr().dyn_cast_or_null<UnitAttr>()) {
op->setAttr(name, constantOp.value());
getOrCreateNoneValue(none, f);
op->setOperand(i, *none);
}
}
} }
} }
}); });

View File

@ -41,6 +41,16 @@ ArrayAttr createArrayAttrOfZeros(
return rewriter.getI64ArrayAttr(vals); return rewriter.getI64ArrayAttr(vals);
} }
DenseElementsAttr createDenseFloatAttrOfValue(
PatternRewriter &rewriter, Value origValue, float constantValue) {
Type elementType = origValue.getType().cast<TensorType>().getElementType();
SmallVector<float, 1> wrapper(1, 0);
wrapper[0] = constantValue;
return DenseElementsAttr::get(
RankedTensorType::get(wrapper.size(), elementType),
llvm::makeArrayRef(wrapper));
}
// Pad a ArrayAttr with zeros. // Pad a ArrayAttr with zeros.
// //
// pads = [B1, B2, ... Bk, E1, E2, ..., Ek] // pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
@ -52,7 +62,7 @@ ArrayAttr createArrayAttrOfZeros(
// nZeros nZeros // nZeros nZeros
// //
// This function is used for padding attribute in Conv. // This function is used for padding attribute in Conv.
ArrayAttr insertZerosForNonPaddedDims( DenseElementsAttr insertZerosForNonPaddedDims(
PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) { PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) {
int nDims = (int)origAttrs.getValue().size() / 2; int nDims = (int)origAttrs.getValue().size() / 2;
int nElements = (nDims + extensionLength) * 2; int nElements = (nDims + extensionLength) * 2;
@ -64,7 +74,12 @@ ArrayAttr insertZerosForNonPaddedDims(
pads[i + extensionLength] = beginPad; pads[i + extensionLength] = beginPad;
pads[nDims + extensionLength + i + extensionLength] = endPad; pads[nDims + extensionLength + i + extensionLength] = endPad;
} }
return rewriter.getI64ArrayAttr(pads);
mlir::Type elementType = rewriter.getIntegerType(64);
llvm::ArrayRef<int64_t> tensorDims(pads.data(), pads.size());
mlir::ShapedType tensorType =
mlir::RankedTensorType::get(tensorDims, elementType);
return rewriter.getI64TensorAttr(llvm::makeArrayRef(pads));
} }
/// Include the patterns defined in the Declarative Rewrite framework. /// Include the patterns defined in the Declarative Rewrite framework.

View File

@ -24,14 +24,17 @@ include "src/Dialect/ONNX/ONNXOps.td"
/// dag benefitsAdded = (addBenefit 0) /// dag benefitsAdded = (addBenefit 0)
/// >; /// >;
def GetNullAttr :
NativeCodeCall<"Attribute()">;
// Create a StringAttr from a string. // Create a StringAttr from a string.
class StringAttrOfValue<string val>: class StringAttrOfValue<string val>:
NativeCodeCall<"$_builder.getStringAttr(\"" # val # "\")">; NativeCodeCall<"$_builder.getStringAttr(\"" # val # "\")">;
// Create a FloatAttr from an interger value. // Create a DenseElementsAttr from an interger value.
// It seems Table-gen does not support `float` type, so we can not pass a float value. // It seems Table-gen does not support `float` type, so we can not pass a float value.
class FloatAttrOfValue<int val>: class FloatAttrOfValue<int val>:
NativeCodeCall<"FloatAttr::get($0.getType().cast<TensorType>().getElementType(), " # val # ")">; NativeCodeCall<"createDenseFloatAttrOfValue($_builder, $0, " # val # ")">;
// Create an ArrayAttr of IntergerAttr(s) of zero values. // Create an ArrayAttr of IntergerAttr(s) of zero values.
// This function is used for padding attribute in Conv. // This function is used for padding attribute in Conv.
@ -82,10 +85,15 @@ def ConvOpPaddingPattern: Pat<
$pads, $pads,
$strides), $strides),
(ONNXConvOp (ONNXConvOp
(ONNXPadConstantValuePadOp $x,
(insertZerosForNonPaddedDims<2> $pads), (ONNXPadOp $x,
(FloatAttrOfValue<0> $res), (ONNXConstantOp (GetNullAttr),
(StringAttrOfValue<"constant">)), (insertZerosForNonPaddedDims<2> $pads)),
(ONNXConstantOp (GetNullAttr),
(FloatAttrOfValue<0> $res)),
(StringAttrOfValue<"constant">)),
$w, $b, $auto_pad, $dilation, $group, $kernel_shape, $w, $b, $auto_pad, $dilation, $group, $kernel_shape,
(createArrayAttrOfZerosFrom $pads), (createArrayAttrOfZerosFrom $pads),
$strides), $strides),

View File

@ -118,6 +118,7 @@ public:
op->getName().getStringRef() != "onnx.Softmax" && op->getName().getStringRef() != "onnx.Softmax" &&
op->getName().getStringRef() != "onnx.Sqrt" && op->getName().getStringRef() != "onnx.Sqrt" &&
op->getName().getStringRef() != "onnx.Conv" && op->getName().getStringRef() != "onnx.Conv" &&
op->getName().getStringRef() != "onnx.Pad" &&
op->getName().getStringRef() != "onnx.PadConstantPad" && op->getName().getStringRef() != "onnx.PadConstantPad" &&
op->getName().getStringRef() != "onnx.PadConstantValuePad" && op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" && op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&

View File

@ -62,10 +62,13 @@ func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>
%cst = constant unit %cst = constant unit
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32> %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-NEXT: %cst = constant unit // CHECK-NEXT: %cst = constant unit
// CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 3, 0, 0, 4, 5]} : (tensor<1x9x32x64xf32>) -> tensor<1x9x38x72xf32> // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[0, 0, 2, 3, 0, 0, 4, 5]> : tensor<8xi64>} : () -> tensor<8xi64>
// CHECK-NEXT: %1 = "onnx.Conv"(%0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32> // CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
// CHECK-NEXT: return %1 : tensor<*xf32> // CHECK-NEXT: %2 = "onnx.Pad"(%arg0, %0, %1) {mode = "constant"} : (tensor<1x9x32x64xf32>, tensor<8xi64>, tensor<1xf32>) -> tensor<*xf32>
// CHECK-NEXT: %3 = "onnx.Conv"(%2, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<*xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
// CHECK-NEXT: return %3 : tensor<*xf32>
} }
// ----- // -----
@ -93,4 +96,3 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32> // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
// return [[GEMM]] : tensor<*xf32> // return [[GEMM]] : tensor<*xf32>
} }

View File

@ -1637,6 +1637,32 @@ func @test_constant_pad1(%arg0: tensor<16x16xf32>) -> tensor<18x20xf32> {
// CHECK: } // CHECK: }
} }
func @test_pad1(%arg0: tensor<16x16xf32>) -> tensor<18x20xf32> {
%cst = constant unit
%0 = "onnx.Pad"(%arg0, %cst, %cst) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 3, 2, 1]> : tensor<4xi32>} : (tensor<16x16xf32>, none, none) -> tensor<18x20xf32>
return %0 : tensor<18x20xf32>
// CHECK-LABEL: test_pad1
// CHECK: [[RES:%.+]] = alloc() : memref<18x20xf32>
// CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 18, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 20) {
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
// CHECK: store [[CST]], [[RES]][%arg1, %arg2] : memref<18x20xf32>
// CHECK: }
// CHECK: [[DEF_LOOPS2:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS2:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 16, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 16) {
// CHECK: [[CST1:%.+]] = constant 3 : index
// CHECK: [[ADD:%.+]] = addi [[CST1]], %arg2 : index
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<16x16xf32>
// CHECK: store [[LOAD]], [[RES]][%arg1, [[ADD]]] : memref<18x20xf32>
// CHECK: }
}
// ----- // -----
func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> { func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> {

View File

@ -354,6 +354,19 @@ func @test_conv_12(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>, %arg2
/// Test shape inference for PadConstantValuePad. /// Test shape inference for PadConstantValuePad.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Test Pad_1
func @test_Pad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {
%cst = constant unit
%0 = "onnx.Pad"(%arg0, %cst, %cst) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<16x13xf32>, none, none) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_Pad_1
// CHECK-NEXT: [[NONE:%.+]] = constant unit
// CHECK: [[RES:%.+]] = "onnx.Pad"(%arg0, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<16x13xf32>, none, none) -> tensor<18x19xf32>
// CHECK: return [[RES]] : tensor<18x19xf32>
}
/// Test PadConstantValuePad_1 /// Test PadConstantValuePad_1
func @test_PadConstantValuePad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> { func @test_PadConstantValuePad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {
%0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<*xf32> %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<*xf32>

View File

@ -10,6 +10,16 @@ func @test_should_promote_to_attribute(%arg0 : tensor<?x10xf32>) -> tensor<*xf32
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
} }
func @test_should_promote_to_attribute_1(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%shape = "onnx.Constant"() { value = dense<[6, 7, 42]> : tensor<3xi32>}: () -> tensor<3xi32>
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: test_should_promote_to_attribute_1
// CHECK-NEXT: [[NONE:%.+]] = constant unit
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
}
func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : tensor<*xi64>) -> tensor<*xf32> { func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : tensor<*xi64>) -> tensor<*xf32> {
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32> %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
@ -30,3 +40,14 @@ func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor<?x10xf3
// CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32> // CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32>
// CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32> // CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32>
} }
func @test_should_promote_to_attribute1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
%shape = constant dense<[0, 2, 2, 4]> : tensor<4xi32>
%constant_value = constant dense<[0.]> : tensor<1xf32>
%0 = "onnx.Pad"(%arg0, %shape, %constant_value) {mode = "constant"} : (tensor<?x?xf32>, tensor<4xi32>, tensor<1xf32>)-> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: test_should_promote_to_attribute1
// CHECK-NEXT: [[NONE:%.+]] = constant unit
// CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32>
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
}

View File

@ -64,7 +64,7 @@ OpsWithShapeInference = [
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN', 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
'LSTM', 'GRU', 'Split' 'LSTM', 'GRU', 'Split', 'Pad'
] ]
# Operations supporting canonicalization. # Operations supporting canonicalization.
@ -77,7 +77,8 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
# should proceed. The key is the operation's name and the value is a list of # should proceed. The key is the operation's name and the value is a list of
# tuples, whose first item is the attribute/operand name, and the second item is # tuples, whose first item is the attribute/operand name, and the second item is
# the index at which such operand occurs in the list of the operation's inputs. # the index at which such operand occurs in the list of the operation's inputs.
OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]} OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)],
"Pad": [("pads", 1), ("constant_value", 2)]}
# Add an Op in this list if the Op needs result type deduction which is required # Add an Op in this list if the Op needs result type deduction which is required
# when writing declarative rewriting rules. Deduced type is always # when writing declarative rewriting rules. Deduced type is always
@ -87,7 +88,24 @@ OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]}
# Currenlty, there are only two build methods generated: # Currenlty, there are only two build methods generated:
# - one with operands and attributes having a separate parameter, and # - one with operands and attributes having a separate parameter, and
# - one with operands and attributes having aggregated parameters. # - one with operands and attributes having aggregated parameters.
custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare'] custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad']
#a dictionary to add any special definition for an operation
custom_definition_misc = dict([ ('Constant',
''' let builders = [
OpBuilder<"Builder *builder, OperationState &state, Attribute sparse_value, Attribute value", [{
if (value) {
auto tensorType = value.getType();
build(builder, state, tensorType, sparse_value, value);
} else {
auto tensorType = sparse_value.getType();
build(builder, state, tensorType, sparse_value, value);
}
}]>
];'''
)])
SNIPPETS = collect_snippets() SNIPPETS = collect_snippets()
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations() SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
@ -254,7 +272,7 @@ def get_operands_or_results(schema, is_input):
# nullable in case it migrates to be an attribute. # nullable in case it migrates to be an attribute.
if schema.name in OpsWithPromotableConstOperands: if schema.name in OpsWithPromotableConstOperands:
idxs = dict(OpsWithPromotableConstOperands[schema.name]).values() idxs = dict(OpsWithPromotableConstOperands[schema.name]).values()
if i in idxs: if i in idxs and not OpSchema.FormalParameterOption.Optional == value.option:
types.append("NoneType") types.append("NoneType")
if OpSchema.FormalParameterOption.Optional == value.option: if OpSchema.FormalParameterOption.Optional == value.option:
@ -451,6 +469,10 @@ def gen_op_def(schema):
if schema.name in OpsWithPromotableConstOperands: if schema.name in OpsWithPromotableConstOperands:
s = get_promotable_const_operands_func( s = get_promotable_const_operands_func(
s, indent, OpsWithPromotableConstOperands[schema.name]) s, indent, OpsWithPromotableConstOperands[schema.name])
if ( schema.name in custom_definition_misc) :
s += custom_definition_misc[schema.name]
s += '}\n\n' s += '}\n\n'
return s return s
@ -492,7 +514,7 @@ def gen_op_importer(schema, file):
"/* expected_num_operands = */ {}".format(expected_num_operands)) "/* expected_num_operands = */ {}".format(expected_num_operands))
args.append( args.append(
'/* expected_num_results = */ {}'.format(expected_num_results)) '/* expected_num_results = */ {}'.format(expected_num_results))
s += inc_indent(indent) + "return {}({});\n".format( s += inc_indent(indent) + " {}({});\n".format(
handler_func, ", ".join(args)) handler_func, ", ".join(args))
file.write(s) file.write(s)