diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md
index 9d11120..c742c69 100644
--- a/docs/Dialects/onnx.md
+++ b/docs/Dialects/onnx.md
@@ -3033,7 +3033,7 @@ ONNX Pad operation
 | Operand | Description |
 | :-----: | ----------- |
 `data` | memref of any type values or tensor of any type values
-`pads` | memref of any type values or tensor of any type values
+`pads` | memref of any type values or tensor of any type values or none type
 `constant_value` | memref of any type values or tensor of any type values or none type
 
 #### Results:
diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp
index 068b222..aecbbdb 100644
--- a/src/Builder/FrontendDialectTransformer.cpp
+++ b/src/Builder/FrontendDialectTransformer.cpp
@@ -303,9 +303,34 @@ private:
    * Special handler for Pad operations.
    */
   void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
+    int nOps = node.input().size();
     if (nOps == 2) {
-      buildOperation<mlir::ONNXPadOp>(node, 2, nOut);
+      llvm::SmallVector<int64_t, 2> dims;
+      dims.push_back(1);
+      llvm::SmallVector<float, 2> values;
+      values.push_back(0.);
+      auto elementType = builder_.getF32Type();
+      llvm::ArrayRef<int64_t> tensorDims(dims.data(), dims.size());
+      auto tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
+      auto constantDenseAttribute =
+          mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
+
+      // Use the special builder defined in ONNXOp.td.inc.
+      auto constantOp = builder_.create<mlir::ONNXConstantOp>(
+          UnknownLoc(), mlir::Attribute(), constantDenseAttribute);
+      mlir::Value constantResult = *(constantOp.getODSResults(0).begin());
+      std::vector<mlir::Value> inputs;
+      for (const auto &item : node.input())
+        if (initializedTensors.ContainKey(legalize_name(item))) {
+          inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
+              UnknownLoc(), builder_, legalize_name(item)));
+        } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
+          inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
+        }
+      inputs.push_back(constantResult);
+
+      buildOutputAndOperation<mlir::ONNXPadOp>(node, inputs, nIn, nOut);
     } else {
       buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
     }
   }
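Note: with this change, an opset-11 Pad node that arrives with only the `data` and `pads` inputs gets a synthesized `constant_value` of 0.0f (a one-element f32 tensor), so downstream passes can always assume three operands. A rough sketch of the imported IR for such a node (function name and shapes are illustrative, not taken from the patch):

    func @imported_pad(%data : tensor<16x16xf32>, %pads : tensor<4xi64>) -> tensor<*xf32> {
      // Synthesized default constant_value = 0.0f.
      %cv = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
      %0 = "onnx.Pad"(%data, %pads, %cv) {mode = "constant"} : (tensor<16x16xf32>, tensor<4xi64>, tensor<1xf32>) -> tensor<*xf32>
      return %0 : tensor<*xf32>
    }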
diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc
index c37b59b..94e8b54 100644
--- a/src/Builder/OpBuildTable.inc
+++ b/src/Builder/OpBuildTable.inc
@@ -5,314 +5,314 @@
 //********************************************************
 
 if (opName == "Abs")
-  return buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Acos")
-  return buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Acosh")
-  return buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Add")
-  return buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "And")
-  return buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "ArgMax")
-  return buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "ArgMin")
-  return buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Asin")
-  return buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Asinh")
-  return buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Atan")
-  return buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Atanh")
-  return buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "AveragePool")
-  return buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "BatchNormalization")
-  return ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
+  ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
 if (opName == "BitShift")
-  return buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Cast")
-  return buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Ceil")
-  return buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Clip")
-  return buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "Compress")
-  return buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Concat")
-  return buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
 if (opName == "ConcatFromSequence")
-  return buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Constant")
-  return buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
 if (opName == "ConstantOfShape")
-  return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Conv")
-  return buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "ConvInteger")
-  return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
 if (opName == "ConvTranspose")
-  return buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "Cos")
-  return buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Cosh")
-  return buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "CumSum")
-  return buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "DepthToSpace")
-  return buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "DequantizeLinear")
-  return buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "Det")
-  return buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Div")
-  return buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Dropout")
-  return buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
+  buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
 if (opName == "DynamicQuantizeLinear")
-  return buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
+  buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
 if (opName == "Elu")
-  return buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Equal")
-  return buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Erf")
-  return buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Exp")
-  return buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Expand")
-  return buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "EyeLike")
-  return buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Flatten")
-  return buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Floor")
-  return buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "GRU")
-  return buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
+  buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
 if (opName == "Gather")
-  return buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "GatherElements")
-  return buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "GatherND")
-  return buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Gemm")
-  return buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "GlobalAveragePool")
-  return buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "GlobalLpPool")
-  return buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "GlobalMaxPool")
-  return buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Greater")
-  return buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "HardSigmoid")
-  return buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Hardmax")
-  return buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Identity")
-  return buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "If")
-  return buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
+  buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
 if (opName == "InstanceNormalization")
-  return buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "IsInf")
-  return buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "IsNaN")
-  return buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "LRN")
-  return buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "LSTM")
-  return buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
+  buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
 if (opName == "LeakyRelu")
-  return buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Less")
-  return buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Log")
-  return buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "LogSoftmax")
-  return buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Loop")
-  return buildOperation<mlir::ONNXLoopOp>(node);
+  buildOperation<mlir::ONNXLoopOp>(node);
 if (opName == "LpNormalization")
-  return buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "LpPool")
-  return buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "MatMul")
-  return buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "MatMulInteger")
-  return buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
 if (opName == "Max")
-  return buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
 if (opName == "MaxPool")
-  return ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
+  ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
 if (opName == "MaxRoiPool")
-  return buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "MaxUnpool")
-  return buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "Mean")
-  return buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
 if (opName == "MeanVarianceNormalization")
-  return buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Min")
-  return buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
 if (opName == "Mod")
-  return buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Mul")
-  return buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Multinomial")
-  return buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Neg")
-  return buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "NonMaxSuppression")
-  return buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
 if (opName == "NonZero")
-  return buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Not")
-  return buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "OneHot")
-  return buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "Or")
-  return buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "PRelu")
-  return buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Pad")
-  return ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "Pow")
-  return buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "QLinearConv")
-  return buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
 if (opName == "QLinearMatMul")
-  return buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
 if (opName == "QuantizeLinear")
-  return buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "RNN")
-  return buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
+  buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
 if (opName == "RandomNormal")
-  return buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
 if (opName == "RandomNormalLike")
-  return buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "RandomUniform")
-  return buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
 if (opName == "RandomUniformLike")
-  return buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Range")
-  return buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "Reciprocal")
-  return buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "ReduceL1")
-  return buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "ReduceL2")
-  return buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "ReduceLogSum")
-  return buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "ReduceLogSumExp")
-  return buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "ReduceMax")
-  return buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "ReduceMean")
-  return buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "ReduceMin")
-  return buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "ReduceProd")
-  return buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "ReduceSum")
-  return buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "ReduceSumSquare")
-  return buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Relu")
-  return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Reshape")
-  return ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Resize")
-  return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
 if (opName == "ReverseSequence")
-  return buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "RoiAlign")
-  return buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "Round")
-  return buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Scan")
-  return buildOperation<mlir::ONNXScanOp>(node);
+  buildOperation<mlir::ONNXScanOp>(node);
 if (opName == "Scatter")
-  return buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "ScatterElements")
-  return buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "ScatterND")
-  return buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "Selu")
-  return buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "SequenceAt")
-  return buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "SequenceConstruct")
-  return buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
 if (opName == "SequenceEmpty")
-  return buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
 if (opName == "SequenceErase")
-  return buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "SequenceInsert")
-  return buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "SequenceLength")
-  return buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Shape")
-  return buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Shrink")
-  return buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Sigmoid")
-  return buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Sign")
-  return buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Sin")
-  return buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Sinh")
-  return buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Size")
-  return buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Slice")
-  return buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
 if (opName == "Softmax")
-  return buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Softplus")
-  return buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Softsign")
-  return buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "SpaceToDepth")
-  return buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Split")
-  return buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
+  buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
 if (opName == "SplitToSequence")
-  return buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Sqrt")
-  return buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Squeeze")
-  return buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "StringNormalizer")
-  return buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Sub")
-  return buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Sum")
-  return buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
 if (opName == "Tan")
-  return buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Tanh")
-  return buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "TfIdfVectorizer")
-  return buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "ThresholdedRelu")
-  return buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Tile")
-  return buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "TopK")
-  return buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
+  buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
 if (opName == "Transpose")
-  return buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Unique")
-  return buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
+  buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
 if (opName == "Unsqueeze")
-  return buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Upsample")
-  return buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Where")
-  return buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
 if (opName == "Xor")
-  return buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt
index f147f8f..a972031 100644
--- a/src/Conversion/ONNXToKrnl/CMakeLists.txt
+++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -15,6 +15,7 @@ add_library(OMONNXToKrnl
   Tensor/Identity.cpp
   Tensor/Reshape.cpp
   Tensor/PadConstantValuePad.cpp
+  Tensor/Pad.cpp
   Tensor/Transpose.cpp
   Tensor/Unsqueeze.cpp
   Tensor/Constant.cpp
diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
index f2e04b7..c3de275 100644
--- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
+++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -93,6 +93,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   // Tensor
   populateLoweringONNXReshapeOpPattern(patterns, &getContext());
   populateLoweringONNXPadConstantValuePadOpPattern(patterns, &getContext());
+  populateLoweringONNXPadOpPattern(patterns, &getContext());
  populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
   populateLoweringONNXTransposeOpPattern(patterns, &getContext());
   populateLoweringONNXIdentityOpPattern(patterns, &getContext());
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
index 97c019e..b5079d5 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -237,6 +237,9 @@ void populateLoweringONNXTransposeOpPattern(
 void populateLoweringONNXPadConstantValuePadOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 
+void populateLoweringONNXPadOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 void populateLoweringONNXReshapeOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Pad.cpp b/src/Conversion/ONNXToKrnl/Tensor/Pad.cpp
new file mode 100644
index 0000000..7421225
--- /dev/null
+++ b/src/Conversion/ONNXToKrnl/Tensor/Pad.cpp
@@ -0,0 +1,120 @@
+//===-----------------------Pad.cpp - Lowering Pad Op -------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the ONNX Pad Operator to the Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+
+using namespace mlir;
+
+struct ONNXPadOpLowering : public ConversionPattern {
+  ONNXPadOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXPadOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    ONNXPadOp myOp = llvm::dyn_cast<ONNXPadOp>(op);
+    ONNXPadOpOperandAdaptor operandAdaptor(operands);
+    auto tensorType = myOp.output().getType();
+
+    auto loc = op->getLoc();
+
+    // Only constant padding is supported now.
+    auto padMode = myOp.mode();
+    if (padMode != "constant")
+      emitError(loc, "unsupported mode for Pad");
+    DenseElementsAttr constantValAttr =
+        myOp.getAttr("constant_value")
+            .dyn_cast_or_null<mlir::DenseElementsAttr>();
+    if (!constantValAttr)
+      emitError(loc, "unsupported value");
+
+    DenseElementsAttr padsAttributes =
+        myOp.getAttr("pads").dyn_cast_or_null<mlir::DenseElementsAttr>();
+    if (!padsAttributes)
+      emitError(loc, "Pad: unknown pads");
+
+    auto memRefType = convertToMemRefType(tensorType);
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else
+      emitError(loc, "unexpected output with non-constant shape");
+
+    // Number of loops.
+    auto memRefShape = memRefType.getShape();
+    int64_t rank = memRefShape.size();
+
+    // Copy the padding amounts into a temporary SmallVector.
+    SmallVector<int64_t, 2> pads(rank * 2, -1);
+    auto padsIt = padsAttributes.getValues<IntegerAttr>().begin();
+    for (int i = 0; i < rank * 2; ++i)
+      pads[i] = (*padsIt++).cast<IntegerAttr>().getInt();
+
+    // Get the padding value.
+    auto valueAttr = (*constantValAttr.getValues<FloatAttr>().begin());
+
+    // Iterate over the loop nest using the output shape.
+    BuildKrnlLoop padLoops(rewriter, loc, rank);
+    padLoops.createDefineAndOptimizeOp();
+    for (int i = 0; i < rank; ++i)
+      padLoops.pushBounds(0, alloc, i);
+    padLoops.createIterateOp();
+
+    // Iterate over the loop nest using the input shape.
+    BuildKrnlLoop valueLoops(rewriter, loc, rank);
+    valueLoops.createDefineAndOptimizeOp();
+    for (int i = 0; i < rank; ++i)
+      valueLoops.pushBounds(0, operandAdaptor.data(), i);
+    valueLoops.createIterateOp();
+
+    // Copy the input data into the output.
+    rewriter.setInsertionPointToStart(valueLoops.getIterateBlock());
+
+    SmallVector<Value, 4> inLoopIVs;
+    for (int i = 0; i < rank; ++i)
+      inLoopIVs.emplace_back(valueLoops.getInductionVar(i));
+
+    SmallVector<Value, 4> outLoopIVs;
+    for (int i = 0; i < rank; ++i) {
+      // Calculate the index for the load and store.
+      if (pads[i] == 0) {
+        outLoopIVs.emplace_back(valueLoops.getInductionVar(i));
+      } else {
+        auto outIV = rewriter.create<AddIOp>(loc,
+            rewriter.create<ConstantIndexOp>(loc, pads[i]),
+            valueLoops.getInductionVar(i));
+        outLoopIVs.emplace_back(outIV);
+      }
+    }
+
+    auto originValue =
+        rewriter.create<LoadOp>(loc, operandAdaptor.data(), inLoopIVs);
+    rewriter.create<StoreOp>(loc, originValue, alloc, outLoopIVs);
+    rewriter.setInsertionPointToStart(padLoops.getIterateBlock());
+
+    SmallVector<Value, 4> outLoopIVs1;
+    for (int i = 0; i < rank; ++i)
+      outLoopIVs1.emplace_back(padLoops.getInductionVar(i));
+
+    auto paddingValue = rewriter.create<ConstantOp>(loc, valueAttr);
+    rewriter.create<StoreOp>(loc, paddingValue, alloc, outLoopIVs1);
+
+    // Replace the original op with the generated code.
+    rewriter.replaceOp(op, alloc);
+
+    return success();
+  }
+};
+
+void populateLoweringONNXPadOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXPadOpLowering>(ctx);
+}
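Note: the lowering allocates the output buffer and emits two krnl loop nests: the first walks the output shape and stores the padding value everywhere, the second walks the input shape and copies each element to its output position, offset by that axis's begin-pad (end pads need no explicit handling because the first nest already filled them). A condensed sketch for a 16x16 input padded to 18x20 with pads = [0, 3, 2, 1]; the `...` stands for the krnl define/optimize plumbing, and the exact IR is pinned down by `test_pad1` in test/mlir/onnx/onnx_lowering.mlir later in this patch:

    %res = alloc() : memref<18x20xf32>
    // Nest 1: fill the whole output with the padding value.
    krnl.iterate(...) with (... %i = 0 to 18, ... %j = 0 to 20) {
      %cst = constant 0.000000e+00 : f32
      store %cst, %res[%i, %j] : memref<18x20xf32>
    }
    // Nest 2: copy the input, shifting the second index by its begin-pad (3).
    krnl.iterate(...) with (... %i = 0 to 16, ... %j = 0 to 16) {
      %c3 = constant 3 : index
      %jj = addi %c3, %j : index
      %v = load %arg0[%i, %j] : memref<16x16xf32>
      store %v, %res[%i, %jj] : memref<18x20xf32>
    }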
diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index 16cd67b..ff733ca 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -1473,6 +1473,53 @@ bool ONNXMaxPoolSingleOutOp::inferShapes() {
 
 //===----------------------------------------------------------------------===//
 
+bool ONNXPadOp::inferShapes() {
+  // Cannot infer shape if no shape exists.
+  if (!data().getType().isa<RankedTensorType>()) {
+    emitError("Pad: unknown input shape");
+    return false;
+  }
+
+  // Cannot infer the shape if the pads are not constant.
+  DenseElementsAttr padsAttributes =
+      getAttr("pads").dyn_cast_or_null<DenseElementsAttr>();
+
+  if (!padsAttributes) {
+    emitError("Pad: unknown pads");
+    return false;
+  }
+
+  auto dataTy = data().getType().cast<RankedTensorType>();
+  auto dataShape = dataTy.getShape();
+  auto dataRank = dataTy.getRank();
+  SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
+
+  // Get pads from the value attribute.
+  SmallVector<int64_t, 2> pads(dataRank * 2, -1);
+  auto valueIt = padsAttributes.getValues<IntegerAttr>().begin();
+  for (int64_t i = 0; i < dataRank * 2; ++i)
+    pads[i] = (*valueIt++).cast<IntegerAttr>().getInt();
+
+  // Pads consists of two values for each axis of data.
+  // The two values specify the number of elements padded before and after
+  // respectively.
+  for (int64_t i = 0; i < dataRank; ++i) {
+    int64_t p1 = pads[i];
+    int64_t p2 = pads[i + dataRank];
+    // Pads must be non-negative constants.
+    if (p1 < 0 || p2 < 0) {
+      emitError("padding value cannot be negative");
+      return false;
+    }
+    if (outputShape[i] != -1)
+      outputShape[i] += p1 + p2;
+  }
+
+  auto outputType = RankedTensorType::get(outputShape, dataTy.getElementType());
+  getResult().setType(outputType);
+  return true;
+}
+
 static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
   // Cannot infer shape if no shape exists.
   if (!data.getType().isa<RankedTensorType>())
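Note: per the ONNX spec, `pads` stores all begin pads followed by all end pads, so output axis i is dataShape[i] + pads[i] + pads[i + rank]; dynamic dimensions (encoded as -1) are left untouched. A worked example matching `test_Pad_1` added later in this patch: a 16x13 input with pads = [0, 2, 2, 4] yields 16+0+2 = 18 and 13+2+4 = 19:

    %0 = "onnx.Pad"(%arg0, %cst, %cst) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<16x13xf32>, none, none) -> tensor<18x19xf32>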
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index f047ae3..39a785f 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -348,7 +348,17 @@ def ONNXConstantOp:ONNX_Op<"Constant",
   let arguments = (ins OptionalAttr<AnyAttr>:$sparse_value,
     OptionalAttr<AnyAttr>:$value);
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
-}
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState &state, Attribute sparse_value, Attribute value", [{
+      if (value) {
+        auto tensorType = value.getType();
+        build(builder, state, tensorType, sparse_value, value);
+      } else {
+        auto tensorType = sparse_value.getType();
+        build(builder, state, tensorType, sparse_value, value);
+      }
+    }]>
+  ];
+}
 
 def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
   [NoSideEffect]> {
@@ -1913,7 +1923,7 @@ def ONNXPReluOp:ONNX_Op<"PRelu",
 }
 
 def ONNXPadOp:ONNX_Op<"Pad",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"PromotableConstOperandsOpInterface">]> {
   let summary = "ONNX Pad operation";
   let description = [{
     "Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, "
@@ -1999,10 +2009,27 @@ def ONNXPadOp:ONNX_Op<"Pad",
     ""
   }];
   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
-    AnyTypeOf<[AnyMemRef, AnyTensor]>:$pads,
+    AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$pads,
     AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$constant_value,
     DefaultValuedAttr<StrAttr, "constant">:$mode);
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState &state, Value data, Value pads, Value constant_value, StringAttr mode", [{
+      auto elementType = data.getType().cast<TensorType>().getElementType();
+      build(builder, state, UnrankedTensorType::get(elementType), data, pads, constant_value, mode);
+    }]>,
+    OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto elementType = operands[0].getType().cast<TensorType>().getElementType();
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(UnrankedTensorType::get(elementType));
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    std::map<std::string, size_t> promotableConstOperands() {
+      return {{"pads", 1}, {"constant_value", 2}};
+    }
+  }];
 }
 
 def ONNXPowOp:ONNX_Op<"Pow",
diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp
index c806e4e..1c7d164 100644
--- a/src/MainUtils.cpp
+++ b/src/MainUtils.cpp
@@ -66,6 +66,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) {
   pm.addPass(mlir::createDecomposeONNXToONNXPass());
   pm.addPass(mlir::createShapeInferencePass());
   pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createAttributePromotionPass());
   pm.addPass(mlir::createShapeInferencePass());
   pm.addPass(mlir::createAttributePromotionPass());
 }
@@ -178,4 +179,4 @@ void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
         (outputBaseName + ".onnx.mlir").c_str());
   }
 }
-}
\ No newline at end of file
+}
diff --git a/src/Transform/ONNX/AttributePromotion.cpp b/src/Transform/ONNX/AttributePromotion.cpp
index 35721cd..461c911 100644
--- a/src/Transform/ONNX/AttributePromotion.cpp
+++ b/src/Transform/ONNX/AttributePromotion.cpp
@@ -17,6 +17,8 @@
 #include "src/Interface/PromotableConstOperandsOpInterface.hpp"
 #include "src/Pass/Passes.hpp"
 
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+
 using namespace mlir;
 
 namespace {
@@ -60,12 +62,25 @@ public:
         // move it to an attribute, and use None to indicate the absence
         // of the original operand value.
         auto operandToPromote = op->getOperand(i);
-        if (auto constantOp = dyn_cast_or_null(
+        if (auto constantOp = dyn_cast_or_null<mlir::ONNXConstantOp>(
                 operandToPromote.getDefiningOp())) {
-          op->setAttr(name, constantOp.value());
+          if (constantOp.valueAttr() &&
+              !constantOp.valueAttr().dyn_cast_or_null())
+            op->setAttr(name, constantOp.valueAttr());
+          if (constantOp.sparse_valueAttr() &&
+              !constantOp.sparse_valueAttr().dyn_cast_or_null())
+            op->setAttr(name, constantOp.sparse_valueAttr());
           getOrCreateNoneValue(none, f);
           op->setOperand(i, *none);
         }
+        if (auto constantOp = dyn_cast_or_null<mlir::ConstantOp>(
+                operandToPromote.getDefiningOp())) {
+          if (!constantOp.valueAttr().dyn_cast_or_null()) {
+            op->setAttr(name, constantOp.value());
+            getOrCreateNoneValue(none, f);
+            op->setOperand(i, *none);
+          }
+        }
       }
     }
   });
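Note: with the promotable-const-operands interface on `onnx.Pad`, the promotion pass folds constant `pads`/`constant_value` operands into attributes of the consuming op and replaces the operands with the shared unit `none` value; after this change that works whether the constant comes from `onnx.Constant` or from a standard `constant` op. A before/after sketch (the `tensor<?x10xf32>` input type is illustrative, mirroring the promotion tests):

    // Before promotion: pads and constant_value are constant operands.
    %shape = constant dense<[0, 2, 2, 4]> : tensor<4xi32>
    %cv = constant dense<[0.]> : tensor<1xf32>
    %0 = "onnx.Pad"(%arg0, %shape, %cv) {mode = "constant"} : (tensor<?x10xf32>, tensor<4xi32>, tensor<1xf32>) -> tensor<*xf32>

    // After promotion: the values live on the op as attributes.
    %none = constant unit
    %0 = "onnx.Pad"(%arg0, %none, %none) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<?x10xf32>, none, none) -> tensor<*xf32>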
diff --git a/src/Transform/ONNX/ONNXRewrite.cpp b/src/Transform/ONNX/ONNXRewrite.cpp
index a6f1b3f..b390419 100644
--- a/src/Transform/ONNX/ONNXRewrite.cpp
+++ b/src/Transform/ONNX/ONNXRewrite.cpp
@@ -41,6 +41,16 @@ ArrayAttr createArrayAttrOfZeros(
   return rewriter.getI64ArrayAttr(vals);
 }
 
+DenseElementsAttr createDenseFloatAttrOfValue(
+    PatternRewriter &rewriter, Value origValue, float constantValue) {
+  Type elementType = origValue.getType().cast<TensorType>().getElementType();
+  SmallVector<float, 1> wrapper(1, 0);
+  wrapper[0] = constantValue;
+  return DenseElementsAttr::get(
+      RankedTensorType::get(wrapper.size(), elementType),
+      llvm::makeArrayRef(wrapper));
+}
+
 // Pad an ArrayAttr with zeros.
 //
 // pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
@@ -52,7 +62,7 @@ ArrayAttr createArrayAttrOfZeros(
 //   nZeros            nZeros
 //
 // This function is used for padding the attribute in Conv.
-ArrayAttr insertZerosForNonPaddedDims(
+DenseElementsAttr insertZerosForNonPaddedDims(
     PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) {
   int nDims = (int)origAttrs.getValue().size() / 2;
   int nElements = (nDims + extensionLength) * 2;
@@ -64,7 +74,12 @@ ArrayAttr insertZerosForNonPaddedDims(
     pads[i + extensionLength] = beginPad;
     pads[nDims + extensionLength + i + extensionLength] = endPad;
   }
-  return rewriter.getI64ArrayAttr(pads);
+
+  mlir::Type elementType = rewriter.getIntegerType(64);
+  llvm::ArrayRef<int64_t> tensorDims(pads.data(), pads.size());
+  mlir::ShapedType tensorType =
+      mlir::RankedTensorType::get(tensorDims, elementType);
+  return rewriter.getI64TensorAttr(llvm::makeArrayRef(pads));
 }
 
 /// Include the patterns defined in the Declarative Rewrite framework.
diff --git a/src/Transform/ONNX/ONNXRewrite.td b/src/Transform/ONNX/ONNXRewrite.td
index 496cdc2..4857348 100644
--- a/src/Transform/ONNX/ONNXRewrite.td
+++ b/src/Transform/ONNX/ONNXRewrite.td
@@ -24,14 +24,17 @@ include "src/Dialect/ONNX/ONNXOps.td"
 ///   dag benefitsAdded = (addBenefit 0)
 /// >;
 
+def GetNullAttr :
+  NativeCodeCall<"Attribute()">;
+
 // Create a StringAttr from a string.
 class StringAttrOfValue<string val>:
   NativeCodeCall<"$_builder.getStringAttr(\"" # val # "\")">;
 
-// Create a FloatAttr from an integer value.
+// Create a DenseElementsAttr from an integer value.
 // It seems TableGen does not support the `float` type, so we cannot pass a float value.
 class FloatAttrOfValue<int val>:
-  NativeCodeCall<"FloatAttr::get($0.getType().cast<TensorType>().getElementType(), " # val # ")">;
+  NativeCodeCall<"createDenseFloatAttrOfValue($_builder, $0, " # val # ")">;
 
 // Create an ArrayAttr of IntegerAttr(s) of zero values.
 // This function is used for the padding attribute in Conv.
@@ -82,10 +85,15 @@ def ConvOpPaddingPattern: Pat<
     $pads, $strides),
   (ONNXConvOp
-    (ONNXPadConstantValuePadOp $x,
-      (insertZerosForNonPaddedDims<2> $pads),
-      (FloatAttrOfValue<0> $res),
-      (StringAttrOfValue<"constant">)),
+
+    (ONNXPadOp $x,
+      (ONNXConstantOp (GetNullAttr),
+        (insertZerosForNonPaddedDims<2> $pads)),
+      (ONNXConstantOp (GetNullAttr),
+        (FloatAttrOfValue<0> $res)),
+      (StringAttrOfValue<"constant">)),
+
     $w, $b, $auto_pad, $dilation, $group, $kernel_shape,
     (createArrayAttrOfZerosFrom $pads),
     $strides),
diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp
index e18f8fe..ad1def7 100644
--- a/src/Transform/ONNX/ShapeInferencePass.cpp
+++ b/src/Transform/ONNX/ShapeInferencePass.cpp
@@ -118,6 +118,7 @@ public:
           op->getName().getStringRef() != "onnx.Softmax" &&
           op->getName().getStringRef() != "onnx.Sqrt" &&
           op->getName().getStringRef() != "onnx.Conv" &&
+          op->getName().getStringRef() != "onnx.Pad" &&
           op->getName().getStringRef() != "onnx.PadConstantPad" &&
           op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
           op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
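Note: `insertZerosForNonPaddedDims<2>` turns Conv's spatial-only pads [B1, B2, E1, E2] into a full-rank pads vector by inserting two zeros (for the unpadded N and C axes) in front of each half. For pads = [2, 3, 4, 5] on a rank-4 NCHW input this yields [0, 0, 2, 3, 0, 0, 4, 5], now materialized as an `onnx.Constant` feeding `onnx.Pad` rather than as an array attribute (see `test_conv_split` in the canonicalization test below):

    %pads = "onnx.Constant"() {value = dense<[0, 0, 2, 3, 0, 0, 4, 5]> : tensor<8xi64>} : () -> tensor<8xi64>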
"onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 3, 0, 0, 4, 5]} : (tensor<1x9x32x64xf32>) -> tensor<1x9x38x72xf32> - // CHECK-NEXT: %1 = "onnx.Conv"(%0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32> - // CHECK-NEXT: return %1 : tensor<*xf32> + // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[0, 0, 2, 3, 0, 0, 4, 5]> : tensor<8xi64>} : () -> tensor<8xi64> + // CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32> + // CHECK-NEXT: %2 = "onnx.Pad"(%arg0, %0, %1) {mode = "constant"} : (tensor<1x9x32x64xf32>, tensor<8xi64>, tensor<1xf32>) -> tensor<*xf32> + // CHECK-NEXT: %3 = "onnx.Conv"(%2, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<*xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32> + // CHECK-NEXT: return %3 : tensor<*xf32> } // ----- @@ -93,4 +96,3 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1 // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32> // return [[GEMM]] : tensor<*xf32> } - diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 57d71f3..10ff7cd 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1637,6 +1637,32 @@ func @test_constant_pad1(%arg0: tensor<16x16xf32>) -> tensor<18x20xf32> { // CHECK: } } +func @test_pad1(%arg0: tensor<16x16xf32>) -> tensor<18x20xf32> { + %cst = constant unit + %0 = "onnx.Pad"(%arg0, %cst, %cst) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 3, 2, 1]> : tensor<4xi32>} : (tensor<16x16xf32>, none, none) -> tensor<18x20xf32> + return %0 : tensor<18x20xf32> + // CHECK-LABEL: test_pad1 + // CHECK: [[RES:%.+]] = alloc() : memref<18x20xf32> + // CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 18, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 20) { + // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32 + // CHECK: store [[CST]], [[RES]][%arg1, %arg2] : memref<18x20xf32> + // CHECK: } + // CHECK: [[DEF_LOOPS2:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS2:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 16, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 16) { + // CHECK: [[CST1:%.+]] = constant 3 : index + // CHECK: [[ADD:%.+]] = addi [[CST1]], %arg2 : index + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<16x16xf32> + // CHECK: store [[LOAD]], [[RES]][%arg1, [[ADD]]] : memref<18x20xf32> + // CHECK: } +} + // ----- func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> { diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index faed4ee..ef2ba33 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -354,6 
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index faed4ee..ef2ba33 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -354,6 +354,19 @@ func @test_conv_12(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>, %arg2
 /// Test shape inference for PadConstantValuePad.
 //===----------------------------------------------------------------------===//
 
+/// Test Pad_1
+func @test_Pad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {
+  %cst = constant unit
+  %0 = "onnx.Pad"(%arg0, %cst, %cst) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<16x13xf32>, none, none) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_Pad_1
+  // CHECK-NEXT: [[NONE:%.+]] = constant unit
+  // CHECK: [[RES:%.+]] = "onnx.Pad"(%arg0, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<16x13xf32>, none, none) -> tensor<18x19xf32>
+  // CHECK: return [[RES]] : tensor<18x19xf32>
+}
+
+
 /// Test PadConstantValuePad_1
 func @test_PadConstantValuePad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {
   %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<*xf32>
diff --git a/test/mlir/transform/attribute_promotion.mlir b/test/mlir/transform/attribute_promotion.mlir
index a7555fa..563bbb7 100644
--- a/test/mlir/transform/attribute_promotion.mlir
+++ b/test/mlir/transform/attribute_promotion.mlir
@@ -10,6 +10,16 @@ func @test_should_promote_to_attribute(%arg0 : tensor<?x10xf32>) -> tensor<*xf32
   // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
 }
 
+func @test_should_promote_to_attribute_1(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %shape = "onnx.Constant"() { value = dense<[6, 7, 42]> : tensor<3xi32>}: () -> tensor<3xi32>
+  %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+  // CHECK-LABEL: test_should_promote_to_attribute_1
+  // CHECK-NEXT: [[NONE:%.+]] = constant unit
+  // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
+  // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
+}
+
 func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : tensor<*xi64>) -> tensor<*xf32> {
   %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
@@ -29,4 +39,15 @@ func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor<?x10xf32
   // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
   // CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32>
   // CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32>
-}
\ No newline at end of file
+}
+
+func @test_should_promote_to_attribute1(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %shape = constant dense<[0, 2, 2, 4]> : tensor<4xi32>
+  %constant_value = constant dense<[0.]> : tensor<1xf32>
+  %0 = "onnx.Pad"(%arg0, %shape, %constant_value) {mode = "constant"} : (tensor<?x10xf32>, tensor<4xi32>, tensor<1xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+  // CHECK-LABEL: test_should_promote_to_attribute1
+  // CHECK-NEXT: [[NONE:%.+]] = constant unit
+  // CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<?x10xf32>, none, none) -> tensor<*xf32>
+  // CHECK-NEXT: return [[PAD]] : tensor<*xf32>
+}
diff --git a/utils/gen_doc.py b/utils/gen_doc.py
index 46afb5d..0c55118 100644
--- a/utils/gen_doc.py
+++ b/utils/gen_doc.py
@@ -64,7 +64,7 @@ OpsWithShapeInference = [
     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
     'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
     'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
-    'LSTM', 'GRU', 'Split'
+    'LSTM', 'GRU', 'Split', 'Pad'
 ]
 
 # Operations supporting canonicalization.
@@ -77,7 +77,8 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
 # should proceed. The key is the operation's name and the value is a list of
 # tuples, whose first item is the attribute/operand name, and the second item is
 # the index at which such operand occurs in the list of the operation's inputs.
-OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]}
+OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)],
+                                  "Pad": [("pads", 1), ("constant_value", 2)]}
 
 # Add an Op in this list if the Op needs result type deduction which is required
 # when writing declarative rewriting rules. Deduced type is always
@@ -87,7 +88,24 @@ OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]}
 # Currently, there are only two build methods generated:
 #  - one with operands and attributes having a separate parameter, and
 #  - one with operands and attributes having aggregated parameters.
-custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare']
+custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad']
+
+
+# A dictionary to add any special definition for an operation.
+custom_definition_misc = dict([ ('Constant',
+  ''' let builders = [
+  OpBuilder<"Builder *builder, OperationState &state, Attribute sparse_value, Attribute value", [{
+   if (value) {
+    auto tensorType = value.getType();
+    build(builder, state, tensorType, sparse_value, value);
+   } else {
+    auto tensorType = sparse_value.getType();
+    build(builder, state, tensorType, sparse_value, value);
+   }
+  }]>
+  ];'''
+  )])
+
 
 SNIPPETS = collect_snippets()
 SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
@@ -254,7 +272,7 @@ def get_operands_or_results(schema, is_input):
         # nullable in case it migrates to be an attribute.
         if schema.name in OpsWithPromotableConstOperands:
             idxs = dict(OpsWithPromotableConstOperands[schema.name]).values()
-            if i in idxs:
+            if i in idxs and not OpSchema.FormalParameterOption.Optional == value.option:
                 types.append("NoneType")
 
         if OpSchema.FormalParameterOption.Optional == value.option:
@@ -451,6 +469,10 @@ def gen_op_def(schema):
     if schema.name in OpsWithPromotableConstOperands:
         s = get_promotable_const_operands_func(
             s, indent, OpsWithPromotableConstOperands[schema.name])
+
+    if schema.name in custom_definition_misc:
+        s += custom_definition_misc[schema.name]
+
     s += '}\n\n'
     return s
@@ -492,7 +514,7 @@ def gen_op_importer(schema, file):
         "/* expected_num_operands = */ {}".format(expected_num_operands))
     args.append(
         '/* expected_num_results = */ {}'.format(expected_num_results))
-    s += inc_indent(indent) + "return {}({});\n".format(
+    s += inc_indent(indent) + "{}({});\n".format(
         handler_func, ", ".join(args))
     file.write(s)