Implement PadOp based on attribute promotion (#71)
* enable promote attr for pad * use optional arguments for pad * shape infereance for pad * Lowering Pad * format file * use DenseTensor for the attribute * use Pad in ONNXRewrite * fix the merge conflict * fix the attr given to constantOp * handle ONNXConstantOp in attribute promotion * Fix bug when AttributePromotion is called more than once * update ONNXOps.td.inc with correct version of onnx * update onnx.md * responses to review * fix the build error * change the implementation of Pad * delete commented out code * clang format Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
4a68597417
commit
23bea50404
|
@ -3033,7 +3033,7 @@ ONNX Pad operation
|
||||||
| Operand | Description |
|
| Operand | Description |
|
||||||
| :-----: | ----------- |
|
| :-----: | ----------- |
|
||||||
`data` | memref of any type values or tensor of any type values
|
`data` | memref of any type values or tensor of any type values
|
||||||
`pads` | memref of any type values or tensor of any type values
|
`pads` | memref of any type values or tensor of any type values or none type
|
||||||
`constant_value` | memref of any type values or tensor of any type values or none type
|
`constant_value` | memref of any type values or tensor of any type values or none type
|
||||||
|
|
||||||
#### Results:
|
#### Results:
|
||||||
|
|
|
@ -303,9 +303,34 @@ private:
|
||||||
* Special handle for Pad operations.
|
* Special handle for Pad operations.
|
||||||
*/
|
*/
|
||||||
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
|
void ImportNodePad(onnx::NodeProto node, int nIn, int nOut) {
|
||||||
|
|
||||||
int nOps = node.input().size();
|
int nOps = node.input().size();
|
||||||
if (nOps == 2) {
|
if (nOps == 2) {
|
||||||
buildOperation<mlir::ONNXPadConstantValueOp>(node, 2, nOut);
|
llvm::SmallVector<int64_t, 2> dims;
|
||||||
|
dims.push_back(1);
|
||||||
|
llvm::SmallVector<float, 2> values;
|
||||||
|
values.push_back(0.);
|
||||||
|
auto elementType = builder_.getF32Type();
|
||||||
|
llvm::ArrayRef<int64_t> tensorDims(dims.data(), dims.size());
|
||||||
|
auto tensorType = mlir::RankedTensorType::get(tensorDims, elementType);
|
||||||
|
auto constantDenseAttribute =
|
||||||
|
mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
|
||||||
|
|
||||||
|
// Use the special builder defined in ONNXOp.td.inc.
|
||||||
|
auto constantOp = builder_.create<mlir::ONNXConstantOp>(
|
||||||
|
UnknownLoc(), mlir::Attribute(), constantDenseAttribute);
|
||||||
|
mlir::Value constantResult = *(constantOp.getODSResults(0).begin());
|
||||||
|
std::vector<mlir::Value> inputs;
|
||||||
|
for (const auto &item : node.input())
|
||||||
|
if (initializedTensors.ContainKey(legalize_name(item))) {
|
||||||
|
inputs.push_back(initializedTensors.EmitInitializerForInputTensor(
|
||||||
|
UnknownLoc(), builder_, legalize_name(item)));
|
||||||
|
} else if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
|
}
|
||||||
|
inputs.push_back(constantResult);
|
||||||
|
|
||||||
|
buildOutputAndOperation<mlir::ONNXPadOp>(node, inputs, nIn, nOut);
|
||||||
} else {
|
} else {
|
||||||
buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
|
buildOperation<mlir::ONNXPadOp>(node, nIn, nOut);
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,314 +5,314 @@
|
||||||
//********************************************************
|
//********************************************************
|
||||||
|
|
||||||
if (opName == "Abs")
|
if (opName == "Abs")
|
||||||
return buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAbsOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Acos")
|
if (opName == "Acos")
|
||||||
return buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAcosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Acosh")
|
if (opName == "Acosh")
|
||||||
return buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAcoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Add")
|
if (opName == "Add")
|
||||||
return buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAddOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "And")
|
if (opName == "And")
|
||||||
return buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAndOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "ArgMax")
|
if (opName == "ArgMax")
|
||||||
return buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXArgMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ArgMin")
|
if (opName == "ArgMin")
|
||||||
return buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXArgMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Asin")
|
if (opName == "Asin")
|
||||||
return buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAsinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Asinh")
|
if (opName == "Asinh")
|
||||||
return buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAsinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Atan")
|
if (opName == "Atan")
|
||||||
return buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAtanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Atanh")
|
if (opName == "Atanh")
|
||||||
return buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAtanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "AveragePool")
|
if (opName == "AveragePool")
|
||||||
return buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "BatchNormalization")
|
if (opName == "BatchNormalization")
|
||||||
return ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
|
ImportNodeBatchNormalization(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 5);
|
||||||
if (opName == "BitShift")
|
if (opName == "BitShift")
|
||||||
return buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXBitShiftOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Cast")
|
if (opName == "Cast")
|
||||||
return buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXCastOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Ceil")
|
if (opName == "Ceil")
|
||||||
return buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXCeilOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Clip")
|
if (opName == "Clip")
|
||||||
return buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXClipOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "Compress")
|
if (opName == "Compress")
|
||||||
return buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXCompressOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Concat")
|
if (opName == "Concat")
|
||||||
return buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConcatOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ConcatFromSequence")
|
if (opName == "ConcatFromSequence")
|
||||||
return buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConcatFromSequenceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Constant")
|
if (opName == "Constant")
|
||||||
return buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConstantOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
||||||
if (opName == "ConstantOfShape")
|
if (opName == "ConstantOfShape")
|
||||||
return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Conv")
|
if (opName == "Conv")
|
||||||
return buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "ConvInteger")
|
if (opName == "ConvInteger")
|
||||||
return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
|
||||||
if (opName == "ConvTranspose")
|
if (opName == "ConvTranspose")
|
||||||
return buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXConvTransposeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "Cos")
|
if (opName == "Cos")
|
||||||
return buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXCosOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Cosh")
|
if (opName == "Cosh")
|
||||||
return buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXCoshOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "CumSum")
|
if (opName == "CumSum")
|
||||||
return buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXCumSumOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "DepthToSpace")
|
if (opName == "DepthToSpace")
|
||||||
return buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXDepthToSpaceOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "DequantizeLinear")
|
if (opName == "DequantizeLinear")
|
||||||
return buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXDequantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "Det")
|
if (opName == "Det")
|
||||||
return buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXDetOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Div")
|
if (opName == "Div")
|
||||||
return buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXDivOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Dropout")
|
if (opName == "Dropout")
|
||||||
return buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
|
buildOperation<mlir::ONNXDropoutOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
|
||||||
if (opName == "DynamicQuantizeLinear")
|
if (opName == "DynamicQuantizeLinear")
|
||||||
return buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
|
buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 3);
|
||||||
if (opName == "Elu")
|
if (opName == "Elu")
|
||||||
return buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXEluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Equal")
|
if (opName == "Equal")
|
||||||
return buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXEqualOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Erf")
|
if (opName == "Erf")
|
||||||
return buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXErfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Exp")
|
if (opName == "Exp")
|
||||||
return buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Expand")
|
if (opName == "Expand")
|
||||||
return buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXExpandOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "EyeLike")
|
if (opName == "EyeLike")
|
||||||
return buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXEyeLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Flatten")
|
if (opName == "Flatten")
|
||||||
return buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXFlattenOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Floor")
|
if (opName == "Floor")
|
||||||
return buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXFloorOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "GRU")
|
if (opName == "GRU")
|
||||||
return buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
|
buildOperation<mlir::ONNXGRUOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
|
||||||
if (opName == "Gather")
|
if (opName == "Gather")
|
||||||
return buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGatherOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "GatherElements")
|
if (opName == "GatherElements")
|
||||||
return buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGatherElementsOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "GatherND")
|
if (opName == "GatherND")
|
||||||
return buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGatherNDOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Gemm")
|
if (opName == "Gemm")
|
||||||
return buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGemmOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "GlobalAveragePool")
|
if (opName == "GlobalAveragePool")
|
||||||
return buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGlobalAveragePoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "GlobalLpPool")
|
if (opName == "GlobalLpPool")
|
||||||
return buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGlobalLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "GlobalMaxPool")
|
if (opName == "GlobalMaxPool")
|
||||||
return buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGlobalMaxPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Greater")
|
if (opName == "Greater")
|
||||||
return buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXGreaterOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "HardSigmoid")
|
if (opName == "HardSigmoid")
|
||||||
return buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXHardSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Hardmax")
|
if (opName == "Hardmax")
|
||||||
return buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXHardmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Identity")
|
if (opName == "Identity")
|
||||||
return buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXIdentityOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "If")
|
if (opName == "If")
|
||||||
return buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
|
buildOperation<mlir::ONNXIfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
|
||||||
if (opName == "InstanceNormalization")
|
if (opName == "InstanceNormalization")
|
||||||
return buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXInstanceNormalizationOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "IsInf")
|
if (opName == "IsInf")
|
||||||
return buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXIsInfOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "IsNaN")
|
if (opName == "IsNaN")
|
||||||
return buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXIsNaNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "LRN")
|
if (opName == "LRN")
|
||||||
return buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLRNOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "LSTM")
|
if (opName == "LSTM")
|
||||||
return buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
|
buildOperation<mlir::ONNXLSTMOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 3);
|
||||||
if (opName == "LeakyRelu")
|
if (opName == "LeakyRelu")
|
||||||
return buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLeakyReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Less")
|
if (opName == "Less")
|
||||||
return buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLessOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Log")
|
if (opName == "Log")
|
||||||
return buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLogOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "LogSoftmax")
|
if (opName == "LogSoftmax")
|
||||||
return buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLogSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Loop")
|
if (opName == "Loop")
|
||||||
return buildOperation<mlir::ONNXLoopOp>(node);
|
buildOperation<mlir::ONNXLoopOp>(node);
|
||||||
if (opName == "LpNormalization")
|
if (opName == "LpNormalization")
|
||||||
return buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLpNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "LpPool")
|
if (opName == "LpPool")
|
||||||
return buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXLpPoolOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "MatMul")
|
if (opName == "MatMul")
|
||||||
return buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMatMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "MatMulInteger")
|
if (opName == "MatMulInteger")
|
||||||
return buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMatMulIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
|
||||||
if (opName == "Max")
|
if (opName == "Max")
|
||||||
return buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMaxOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
||||||
if (opName == "MaxPool")
|
if (opName == "MaxPool")
|
||||||
return ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
|
ImportNodeMaxPool(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 2);
|
||||||
if (opName == "MaxRoiPool")
|
if (opName == "MaxRoiPool")
|
||||||
return buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMaxRoiPoolOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "MaxUnpool")
|
if (opName == "MaxUnpool")
|
||||||
return buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMaxUnpoolOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "Mean")
|
if (opName == "Mean")
|
||||||
return buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMeanOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
||||||
if (opName == "MeanVarianceNormalization")
|
if (opName == "MeanVarianceNormalization")
|
||||||
return buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Min")
|
if (opName == "Min")
|
||||||
return buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMinOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Mod")
|
if (opName == "Mod")
|
||||||
return buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXModOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Mul")
|
if (opName == "Mul")
|
||||||
return buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMulOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Multinomial")
|
if (opName == "Multinomial")
|
||||||
return buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXMultinomialOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Neg")
|
if (opName == "Neg")
|
||||||
return buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXNegOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "NonMaxSuppression")
|
if (opName == "NonMaxSuppression")
|
||||||
return buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXNonMaxSuppressionOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
|
||||||
if (opName == "NonZero")
|
if (opName == "NonZero")
|
||||||
return buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXNonZeroOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Not")
|
if (opName == "Not")
|
||||||
return buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXNotOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "OneHot")
|
if (opName == "OneHot")
|
||||||
return buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXOneHotOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "Or")
|
if (opName == "Or")
|
||||||
return buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXOrOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "PRelu")
|
if (opName == "PRelu")
|
||||||
return buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXPReluOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Pad")
|
if (opName == "Pad")
|
||||||
return ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
ImportNodePad(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "Pow")
|
if (opName == "Pow")
|
||||||
return buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXPowOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "QLinearConv")
|
if (opName == "QLinearConv")
|
||||||
return buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXQLinearConvOp>(node, /* expected_num_operands = */ 9, /* expected_num_results = */ 1);
|
||||||
if (opName == "QLinearMatMul")
|
if (opName == "QLinearMatMul")
|
||||||
return buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXQLinearMatMulOp>(node, /* expected_num_operands = */ 8, /* expected_num_results = */ 1);
|
||||||
if (opName == "QuantizeLinear")
|
if (opName == "QuantizeLinear")
|
||||||
return buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXQuantizeLinearOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "RNN")
|
if (opName == "RNN")
|
||||||
return buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
|
buildOperation<mlir::ONNXRNNOp>(node, /* expected_num_operands = */ 6, /* expected_num_results = */ 2);
|
||||||
if (opName == "RandomNormal")
|
if (opName == "RandomNormal")
|
||||||
return buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRandomNormalOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
||||||
if (opName == "RandomNormalLike")
|
if (opName == "RandomNormalLike")
|
||||||
return buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRandomNormalLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "RandomUniform")
|
if (opName == "RandomUniform")
|
||||||
return buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRandomUniformOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
||||||
if (opName == "RandomUniformLike")
|
if (opName == "RandomUniformLike")
|
||||||
return buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRandomUniformLikeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Range")
|
if (opName == "Range")
|
||||||
return buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRangeOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "Reciprocal")
|
if (opName == "Reciprocal")
|
||||||
return buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReciprocalOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ReduceL1")
|
if (opName == "ReduceL1")
|
||||||
return buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceL1Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ReduceL2")
|
if (opName == "ReduceL2")
|
||||||
return buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceL2Op>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ReduceLogSum")
|
if (opName == "ReduceLogSum")
|
||||||
return buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceLogSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ReduceLogSumExp")
|
if (opName == "ReduceLogSumExp")
|
||||||
return buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceLogSumExpOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ReduceMax")
|
if (opName == "ReduceMax")
|
||||||
return buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceMaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ReduceMean")
|
if (opName == "ReduceMean")
|
||||||
return buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceMeanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ReduceMin")
|
if (opName == "ReduceMin")
|
||||||
return buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceMinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ReduceProd")
|
if (opName == "ReduceProd")
|
||||||
return buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceProdOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ReduceSum")
|
if (opName == "ReduceSum")
|
||||||
return buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceSumOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ReduceSumSquare")
|
if (opName == "ReduceSumSquare")
|
||||||
return buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReduceSumSquareOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Relu")
|
if (opName == "Relu")
|
||||||
return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Reshape")
|
if (opName == "Reshape")
|
||||||
return ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Resize")
|
if (opName == "Resize")
|
||||||
return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
|
||||||
if (opName == "ReverseSequence")
|
if (opName == "ReverseSequence")
|
||||||
return buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXReverseSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "RoiAlign")
|
if (opName == "RoiAlign")
|
||||||
return buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRoiAlignOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "Round")
|
if (opName == "Round")
|
||||||
return buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXRoundOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Scan")
|
if (opName == "Scan")
|
||||||
return buildOperation<mlir::ONNXScanOp>(node);
|
buildOperation<mlir::ONNXScanOp>(node);
|
||||||
if (opName == "Scatter")
|
if (opName == "Scatter")
|
||||||
return buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXScatterOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "ScatterElements")
|
if (opName == "ScatterElements")
|
||||||
return buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXScatterElementsOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "ScatterND")
|
if (opName == "ScatterND")
|
||||||
return buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXScatterNDOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "Selu")
|
if (opName == "Selu")
|
||||||
return buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSeluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "SequenceAt")
|
if (opName == "SequenceAt")
|
||||||
return buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSequenceAtOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "SequenceConstruct")
|
if (opName == "SequenceConstruct")
|
||||||
return buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSequenceConstructOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
||||||
if (opName == "SequenceEmpty")
|
if (opName == "SequenceEmpty")
|
||||||
return buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSequenceEmptyOp>(node, /* expected_num_operands = */ 0, /* expected_num_results = */ 1);
|
||||||
if (opName == "SequenceErase")
|
if (opName == "SequenceErase")
|
||||||
return buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSequenceEraseOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "SequenceInsert")
|
if (opName == "SequenceInsert")
|
||||||
return buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSequenceInsertOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "SequenceLength")
|
if (opName == "SequenceLength")
|
||||||
return buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSequenceLengthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Shape")
|
if (opName == "Shape")
|
||||||
return buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Shrink")
|
if (opName == "Shrink")
|
||||||
return buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXShrinkOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Sigmoid")
|
if (opName == "Sigmoid")
|
||||||
return buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSigmoidOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Sign")
|
if (opName == "Sign")
|
||||||
return buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Sin")
|
if (opName == "Sin")
|
||||||
return buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSinOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Sinh")
|
if (opName == "Sinh")
|
||||||
return buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSinhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Size")
|
if (opName == "Size")
|
||||||
return buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSizeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Slice")
|
if (opName == "Slice")
|
||||||
return buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSliceOp>(node, /* expected_num_operands = */ 5, /* expected_num_results = */ 1);
|
||||||
if (opName == "Softmax")
|
if (opName == "Softmax")
|
||||||
return buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSoftmaxOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Softplus")
|
if (opName == "Softplus")
|
||||||
return buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSoftplusOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Softsign")
|
if (opName == "Softsign")
|
||||||
return buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSoftsignOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "SpaceToDepth")
|
if (opName == "SpaceToDepth")
|
||||||
return buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSpaceToDepthOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Split")
|
if (opName == "Split")
|
||||||
return buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
|
buildOperation<mlir::ONNXSplitOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ -1);
|
||||||
if (opName == "SplitToSequence")
|
if (opName == "SplitToSequence")
|
||||||
return buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSplitToSequenceOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Sqrt")
|
if (opName == "Sqrt")
|
||||||
return buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSqrtOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Squeeze")
|
if (opName == "Squeeze")
|
||||||
return buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "StringNormalizer")
|
if (opName == "StringNormalizer")
|
||||||
return buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXStringNormalizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Sub")
|
if (opName == "Sub")
|
||||||
return buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSubOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Sum")
|
if (opName == "Sum")
|
||||||
return buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXSumOp>(node, /* expected_num_operands = */ -1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Tan")
|
if (opName == "Tan")
|
||||||
return buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXTanOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Tanh")
|
if (opName == "Tanh")
|
||||||
return buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXTanhOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "TfIdfVectorizer")
|
if (opName == "TfIdfVectorizer")
|
||||||
return buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXTfIdfVectorizerOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "ThresholdedRelu")
|
if (opName == "ThresholdedRelu")
|
||||||
return buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXThresholdedReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Tile")
|
if (opName == "Tile")
|
||||||
return buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXTileOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "TopK")
|
if (opName == "TopK")
|
||||||
return buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
|
buildOperation<mlir::ONNXTopKOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 2);
|
||||||
if (opName == "Transpose")
|
if (opName == "Transpose")
|
||||||
return buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXTransposeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Unique")
|
if (opName == "Unique")
|
||||||
return buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
|
buildOperation<mlir::ONNXUniqueOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 4);
|
||||||
if (opName == "Unsqueeze")
|
if (opName == "Unsqueeze")
|
||||||
return buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXUnsqueezeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
|
||||||
if (opName == "Upsample")
|
if (opName == "Upsample")
|
||||||
return buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXUpsampleOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
if (opName == "Where")
|
if (opName == "Where")
|
||||||
return buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXWhereOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
|
||||||
if (opName == "Xor")
|
if (opName == "Xor")
|
||||||
return buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
buildOperation<mlir::ONNXXorOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
|
||||||
|
|
|
@ -15,6 +15,7 @@ add_library(OMONNXToKrnl
|
||||||
Tensor/Identity.cpp
|
Tensor/Identity.cpp
|
||||||
Tensor/Reshape.cpp
|
Tensor/Reshape.cpp
|
||||||
Tensor/PadConstantValuePad.cpp
|
Tensor/PadConstantValuePad.cpp
|
||||||
|
Tensor/Pad.cpp
|
||||||
Tensor/Transpose.cpp
|
Tensor/Transpose.cpp
|
||||||
Tensor/Unsqueeze.cpp
|
Tensor/Unsqueeze.cpp
|
||||||
Tensor/Constant.cpp
|
Tensor/Constant.cpp
|
||||||
|
|
|
@ -93,6 +93,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
|
||||||
// Tensor
|
// Tensor
|
||||||
populateLoweringONNXReshapeOpPattern(patterns, &getContext());
|
populateLoweringONNXReshapeOpPattern(patterns, &getContext());
|
||||||
populateLoweringONNXPadConstantValuePadOpPattern(patterns, &getContext());
|
populateLoweringONNXPadConstantValuePadOpPattern(patterns, &getContext());
|
||||||
|
populateLoweringONNXPadOpPattern(patterns, &getContext());
|
||||||
populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
|
populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
|
||||||
populateLoweringONNXTransposeOpPattern(patterns, &getContext());
|
populateLoweringONNXTransposeOpPattern(patterns, &getContext());
|
||||||
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
|
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
|
||||||
|
|
|
@ -237,6 +237,9 @@ void populateLoweringONNXTransposeOpPattern(
|
||||||
void populateLoweringONNXPadConstantValuePadOpPattern(
|
void populateLoweringONNXPadConstantValuePadOpPattern(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||||
|
|
||||||
|
void populateLoweringONNXPadOpPattern(
|
||||||
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||||
|
|
||||||
void populateLoweringONNXReshapeOpPattern(
|
void populateLoweringONNXReshapeOpPattern(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,120 @@
|
||||||
|
//===-----------------------Pad.cpp - Lowering Pad Op -------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file lowers the ONNX Pad Operator to Krnl dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
struct ONNXPadOpLowering : public ConversionPattern {
|
||||||
|
ONNXPadOpLowering(MLIRContext *ctx)
|
||||||
|
: ConversionPattern(mlir::ONNXPadOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
ONNXPadOp myOp = llvm::dyn_cast<ONNXPadOp>(op);
|
||||||
|
ONNXPadOpOperandAdaptor operandAdaptor(operands);
|
||||||
|
auto tensorType = myOp.output().getType();
|
||||||
|
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
|
// Only constant padding is supported now.
|
||||||
|
auto padMode = myOp.mode();
|
||||||
|
if (padMode != "constant")
|
||||||
|
emitError(loc, "unsupported mode for Pad");
|
||||||
|
DenseElementsAttr constantValAttr =
|
||||||
|
myOp.getAttr("constant_value")
|
||||||
|
.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||||
|
if (!constantValAttr)
|
||||||
|
emitError(loc, "unsupported value");
|
||||||
|
|
||||||
|
DenseElementsAttr padsAttributes =
|
||||||
|
myOp.getAttr("pads").dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||||
|
if (!padsAttributes)
|
||||||
|
emitError(loc, "Pad: unknown pads");
|
||||||
|
|
||||||
|
auto memRefType = convertToMemRefType(tensorType);
|
||||||
|
Value alloc;
|
||||||
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
|
|
||||||
|
if (hasAllConstantDimensions(memRefType))
|
||||||
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
|
else
|
||||||
|
emitError(loc, "unexpected output has non-Constant shape");
|
||||||
|
|
||||||
|
// Number of loops
|
||||||
|
auto memRefShape = memRefType.getShape();
|
||||||
|
int64_t rank = memRefShape.size();
|
||||||
|
|
||||||
|
// get the padding vector into a temporary smallvector
|
||||||
|
SmallVector<int64_t, 2> pads(rank * 2, -1);
|
||||||
|
auto padsIt = padsAttributes.getValues<IntegerAttr>().begin();
|
||||||
|
for (int i = 0; i < rank * 2; ++i)
|
||||||
|
pads[i] = (*padsIt++).cast<IntegerAttr>().getInt();
|
||||||
|
|
||||||
|
// get the padding value
|
||||||
|
auto valueAttr = (*constantValAttr.getValues<FloatAttr>().begin());
|
||||||
|
|
||||||
|
// Iterate over the loop nest using the output shape.
|
||||||
|
BuildKrnlLoop padLoops(rewriter, loc, rank);
|
||||||
|
padLoops.createDefineAndOptimizeOp();
|
||||||
|
for (int i = 0; i < rank; ++i)
|
||||||
|
padLoops.pushBounds(0, alloc, i);
|
||||||
|
padLoops.createIterateOp();
|
||||||
|
|
||||||
|
// Iterate over the loop nest using the input shape.
|
||||||
|
BuildKrnlLoop valueLoops(rewriter, loc, rank);
|
||||||
|
valueLoops.createDefineAndOptimizeOp();
|
||||||
|
for (int i = 0; i < rank; ++i)
|
||||||
|
valueLoops.pushBounds(0, operandAdaptor.data(), i);
|
||||||
|
valueLoops.createIterateOp();
|
||||||
|
|
||||||
|
// Copy the input data into the output.
|
||||||
|
rewriter.setInsertionPointToStart(valueLoops.getIterateBlock());
|
||||||
|
|
||||||
|
SmallVector<Value, 4> inLoopIVs;
|
||||||
|
for (int i = 0; i < rank; ++i)
|
||||||
|
inLoopIVs.emplace_back(valueLoops.getInductionVar(i));
|
||||||
|
|
||||||
|
SmallVector<Value, 4> outLoopIVs;
|
||||||
|
for (int i = 0; i < rank; ++i) {
|
||||||
|
// Calculate the index for the load and store.
|
||||||
|
if (pads[i] == 0) {
|
||||||
|
outLoopIVs.emplace_back(valueLoops.getInductionVar(i));
|
||||||
|
} else {
|
||||||
|
auto outIV = rewriter.create<AddIOp>(loc,
|
||||||
|
rewriter.create<ConstantIndexOp>(loc, pads[i]),
|
||||||
|
valueLoops.getInductionVar(i));
|
||||||
|
outLoopIVs.emplace_back(outIV);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto originValue =
|
||||||
|
rewriter.create<LoadOp>(loc, operandAdaptor.data(), inLoopIVs);
|
||||||
|
rewriter.create<StoreOp>(loc, originValue, alloc, outLoopIVs);
|
||||||
|
rewriter.setInsertionPointToStart(padLoops.getIterateBlock());
|
||||||
|
|
||||||
|
SmallVector<Value, 4> outLoopIVs1;
|
||||||
|
for (int i = 0; i < rank; ++i)
|
||||||
|
outLoopIVs1.emplace_back(padLoops.getInductionVar(i));
|
||||||
|
|
||||||
|
auto paddingValue = rewriter.create<ConstantOp>(loc, valueAttr);
|
||||||
|
rewriter.create<StoreOp>(loc, paddingValue, alloc, outLoopIVs1);
|
||||||
|
|
||||||
|
// Replace the original op with the generated code.
|
||||||
|
rewriter.replaceOp(op, alloc);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void populateLoweringONNXPadOpPattern(
|
||||||
|
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||||
|
patterns.insert<ONNXPadOpLowering>(ctx);
|
||||||
|
}
|
|
@ -1473,6 +1473,53 @@ bool ONNXMaxPoolSingleOutOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
bool ONNXPadOp::inferShapes() {
|
||||||
|
// Cannot infer shape if no shape exists.
|
||||||
|
if (!data().getType().isa<RankedTensorType>()) {
|
||||||
|
emitError("Pad: unknown input shape");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cannot infer if the pads is not constant
|
||||||
|
DenseElementsAttr padsAttributes =
|
||||||
|
getAttr("pads").dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||||
|
|
||||||
|
if (!padsAttributes) {
|
||||||
|
emitError("Pad: unknown pads");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto dataTy = data().getType().cast<RankedTensorType>();
|
||||||
|
auto dataShape = dataTy.getShape();
|
||||||
|
auto dataRank = dataTy.getRank();
|
||||||
|
SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
|
||||||
|
|
||||||
|
// Get pads from valueAttribute.
|
||||||
|
SmallVector<int64_t, 2> pads(dataRank * 2, -1);
|
||||||
|
auto valueIt = padsAttributes.getValues<IntegerAttr>().begin();
|
||||||
|
for (int64_t i = 0; i < dataRank * 2; ++i)
|
||||||
|
pads[i] = (*valueIt++).cast<IntegerAttr>().getInt();
|
||||||
|
|
||||||
|
// Pads consists of two values for each axis of data.
|
||||||
|
// The two values specify the number of elements padded before and after
|
||||||
|
// respectively.
|
||||||
|
for (int64_t i = 0; i < dataRank; ++i) {
|
||||||
|
int64_t p1 = pads[i];
|
||||||
|
int64_t p2 = pads[i + dataRank];
|
||||||
|
// Have to non-negative constant
|
||||||
|
if (p1 < 0 || p2 < 0) {
|
||||||
|
emitError("padding value can not be negative");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (outputShape[i] != -1)
|
||||||
|
outputShape[i] += p1 + p2;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto outputType = RankedTensorType::get(outputShape, dataTy.getElementType());
|
||||||
|
getResult().setType(outputType);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
|
static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!data.getType().isa<RankedTensorType>())
|
if (!data.getType().isa<RankedTensorType>())
|
||||||
|
|
|
@ -348,7 +348,17 @@ def ONNXConstantOp:ONNX_Op<"Constant",
|
||||||
let arguments = (ins OptionalAttr<AnyAttr>:$sparse_value,
|
let arguments = (ins OptionalAttr<AnyAttr>:$sparse_value,
|
||||||
OptionalAttr<AnyAttr>:$value);
|
OptionalAttr<AnyAttr>:$value);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
||||||
}
|
let builders = [
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, Attribute sparse_value, Attribute value", [{
|
||||||
|
if (value) {
|
||||||
|
auto tensorType = value.getType();
|
||||||
|
build(builder, state, tensorType, sparse_value, value);
|
||||||
|
} else {
|
||||||
|
auto tensorType = sparse_value.getType();
|
||||||
|
build(builder, state, tensorType, sparse_value, value);
|
||||||
|
}
|
||||||
|
}]>
|
||||||
|
];}
|
||||||
|
|
||||||
def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
|
def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
|
@ -1913,7 +1923,7 @@ def ONNXPReluOp:ONNX_Op<"PRelu",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXPadOp:ONNX_Op<"Pad",
|
def ONNXPadOp:ONNX_Op<"Pad",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, OpInterface<"PromotableConstOperandsOpInterface">]> {
|
||||||
let summary = "ONNX Pad operation";
|
let summary = "ONNX Pad operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, "
|
"Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, "
|
||||||
|
@ -1999,10 +2009,27 @@ def ONNXPadOp:ONNX_Op<"Pad",
|
||||||
""
|
""
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$pads,
|
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$pads,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$constant_value,
|
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$constant_value,
|
||||||
DefaultValuedAttr<StrAttr, "constant">:$mode);
|
DefaultValuedAttr<StrAttr, "constant">:$mode);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$output);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, Value data, Value pads, Value constant_value, StringAttr mode", [{
|
||||||
|
auto elementType = data.getType().cast<TensorType>().getElementType();
|
||||||
|
build(builder, state, UnrankedTensorType::get(elementType), data, pads, constant_value, mode);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
std::map<std::string, size_t> promotableConstOperands() {
|
||||||
|
return {{"pads", 1}, {"constant_value", 2}};
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXPowOp:ONNX_Op<"Pow",
|
def ONNXPowOp:ONNX_Op<"Pow",
|
||||||
|
|
|
@ -66,6 +66,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) {
|
||||||
pm.addPass(mlir::createDecomposeONNXToONNXPass());
|
pm.addPass(mlir::createDecomposeONNXToONNXPass());
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
|
pm.addPass(mlir::createAttributePromotionPass());
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
pm.addPass(mlir::createAttributePromotionPass());
|
pm.addPass(mlir::createAttributePromotionPass());
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
|
#include "src/Interface/PromotableConstOperandsOpInterface.hpp"
|
||||||
#include "src/Pass/Passes.hpp"
|
#include "src/Pass/Passes.hpp"
|
||||||
|
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -60,14 +62,27 @@ public:
|
||||||
// move it to an attribute, and use None to indicate the absence
|
// move it to an attribute, and use None to indicate the absence
|
||||||
// of the original operand value.
|
// of the original operand value.
|
||||||
auto operandToPromote = op->getOperand(i);
|
auto operandToPromote = op->getOperand(i);
|
||||||
|
if (auto constantOp = dyn_cast_or_null<mlir::ONNXConstantOp>(
|
||||||
|
operandToPromote.getDefiningOp())) {
|
||||||
|
if (constantOp.valueAttr() &&
|
||||||
|
!constantOp.valueAttr().dyn_cast_or_null<UnitAttr>())
|
||||||
|
op->setAttr(name, constantOp.valueAttr());
|
||||||
|
if (constantOp.sparse_valueAttr() &&
|
||||||
|
!constantOp.sparse_valueAttr().dyn_cast_or_null<UnitAttr>())
|
||||||
|
op->setAttr(name, constantOp.sparse_valueAttr());
|
||||||
|
getOrCreateNoneValue(none, f);
|
||||||
|
op->setOperand(i, *none);
|
||||||
|
}
|
||||||
if (auto constantOp = dyn_cast_or_null<ConstantOp>(
|
if (auto constantOp = dyn_cast_or_null<ConstantOp>(
|
||||||
operandToPromote.getDefiningOp())) {
|
operandToPromote.getDefiningOp())) {
|
||||||
|
if (!constantOp.valueAttr().dyn_cast_or_null<UnitAttr>()) {
|
||||||
op->setAttr(name, constantOp.value());
|
op->setAttr(name, constantOp.value());
|
||||||
getOrCreateNoneValue(none, f);
|
getOrCreateNoneValue(none, f);
|
||||||
op->setOperand(i, *none);
|
op->setOperand(i, *none);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Dispatch canonicalization pattern rewriters to eliminate redundant
|
// Dispatch canonicalization pattern rewriters to eliminate redundant
|
||||||
|
|
|
@ -41,6 +41,16 @@ ArrayAttr createArrayAttrOfZeros(
|
||||||
return rewriter.getI64ArrayAttr(vals);
|
return rewriter.getI64ArrayAttr(vals);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr createDenseFloatAttrOfValue(
|
||||||
|
PatternRewriter &rewriter, Value origValue, float constantValue) {
|
||||||
|
Type elementType = origValue.getType().cast<TensorType>().getElementType();
|
||||||
|
SmallVector<float, 1> wrapper(1, 0);
|
||||||
|
wrapper[0] = constantValue;
|
||||||
|
return DenseElementsAttr::get(
|
||||||
|
RankedTensorType::get(wrapper.size(), elementType),
|
||||||
|
llvm::makeArrayRef(wrapper));
|
||||||
|
}
|
||||||
|
|
||||||
// Pad a ArrayAttr with zeros.
|
// Pad a ArrayAttr with zeros.
|
||||||
//
|
//
|
||||||
// pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
|
// pads = [B1, B2, ... Bk, E1, E2, ..., Ek]
|
||||||
|
@ -52,7 +62,7 @@ ArrayAttr createArrayAttrOfZeros(
|
||||||
// nZeros nZeros
|
// nZeros nZeros
|
||||||
//
|
//
|
||||||
// This function is used for padding attribute in Conv.
|
// This function is used for padding attribute in Conv.
|
||||||
ArrayAttr insertZerosForNonPaddedDims(
|
DenseElementsAttr insertZerosForNonPaddedDims(
|
||||||
PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) {
|
PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) {
|
||||||
int nDims = (int)origAttrs.getValue().size() / 2;
|
int nDims = (int)origAttrs.getValue().size() / 2;
|
||||||
int nElements = (nDims + extensionLength) * 2;
|
int nElements = (nDims + extensionLength) * 2;
|
||||||
|
@ -64,7 +74,12 @@ ArrayAttr insertZerosForNonPaddedDims(
|
||||||
pads[i + extensionLength] = beginPad;
|
pads[i + extensionLength] = beginPad;
|
||||||
pads[nDims + extensionLength + i + extensionLength] = endPad;
|
pads[nDims + extensionLength + i + extensionLength] = endPad;
|
||||||
}
|
}
|
||||||
return rewriter.getI64ArrayAttr(pads);
|
|
||||||
|
mlir::Type elementType = rewriter.getIntegerType(64);
|
||||||
|
llvm::ArrayRef<int64_t> tensorDims(pads.data(), pads.size());
|
||||||
|
mlir::ShapedType tensorType =
|
||||||
|
mlir::RankedTensorType::get(tensorDims, elementType);
|
||||||
|
return rewriter.getI64TensorAttr(llvm::makeArrayRef(pads));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||||
|
|
|
@ -24,14 +24,17 @@ include "src/Dialect/ONNX/ONNXOps.td"
|
||||||
/// dag benefitsAdded = (addBenefit 0)
|
/// dag benefitsAdded = (addBenefit 0)
|
||||||
/// >;
|
/// >;
|
||||||
|
|
||||||
|
def GetNullAttr :
|
||||||
|
NativeCodeCall<"Attribute()">;
|
||||||
|
|
||||||
// Create a StringAttr from a string.
|
// Create a StringAttr from a string.
|
||||||
class StringAttrOfValue<string val>:
|
class StringAttrOfValue<string val>:
|
||||||
NativeCodeCall<"$_builder.getStringAttr(\"" # val # "\")">;
|
NativeCodeCall<"$_builder.getStringAttr(\"" # val # "\")">;
|
||||||
|
|
||||||
// Create a FloatAttr from an interger value.
|
// Create a DenseElementsAttr from an interger value.
|
||||||
// It seems Table-gen does not support `float` type, so we can not pass a float value.
|
// It seems Table-gen does not support `float` type, so we can not pass a float value.
|
||||||
class FloatAttrOfValue<int val>:
|
class FloatAttrOfValue<int val>:
|
||||||
NativeCodeCall<"FloatAttr::get($0.getType().cast<TensorType>().getElementType(), " # val # ")">;
|
NativeCodeCall<"createDenseFloatAttrOfValue($_builder, $0, " # val # ")">;
|
||||||
|
|
||||||
// Create an ArrayAttr of IntergerAttr(s) of zero values.
|
// Create an ArrayAttr of IntergerAttr(s) of zero values.
|
||||||
// This function is used for padding attribute in Conv.
|
// This function is used for padding attribute in Conv.
|
||||||
|
@ -82,10 +85,15 @@ def ConvOpPaddingPattern: Pat<
|
||||||
$pads,
|
$pads,
|
||||||
$strides),
|
$strides),
|
||||||
(ONNXConvOp
|
(ONNXConvOp
|
||||||
(ONNXPadConstantValuePadOp $x,
|
|
||||||
(insertZerosForNonPaddedDims<2> $pads),
|
(ONNXPadOp $x,
|
||||||
(FloatAttrOfValue<0> $res),
|
(ONNXConstantOp (GetNullAttr),
|
||||||
|
(insertZerosForNonPaddedDims<2> $pads)),
|
||||||
|
(ONNXConstantOp (GetNullAttr),
|
||||||
|
(FloatAttrOfValue<0> $res)),
|
||||||
(StringAttrOfValue<"constant">)),
|
(StringAttrOfValue<"constant">)),
|
||||||
|
|
||||||
|
|
||||||
$w, $b, $auto_pad, $dilation, $group, $kernel_shape,
|
$w, $b, $auto_pad, $dilation, $group, $kernel_shape,
|
||||||
(createArrayAttrOfZerosFrom $pads),
|
(createArrayAttrOfZerosFrom $pads),
|
||||||
$strides),
|
$strides),
|
||||||
|
|
|
@ -118,6 +118,7 @@ public:
|
||||||
op->getName().getStringRef() != "onnx.Softmax" &&
|
op->getName().getStringRef() != "onnx.Softmax" &&
|
||||||
op->getName().getStringRef() != "onnx.Sqrt" &&
|
op->getName().getStringRef() != "onnx.Sqrt" &&
|
||||||
op->getName().getStringRef() != "onnx.Conv" &&
|
op->getName().getStringRef() != "onnx.Conv" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Pad" &&
|
||||||
op->getName().getStringRef() != "onnx.PadConstantPad" &&
|
op->getName().getStringRef() != "onnx.PadConstantPad" &&
|
||||||
op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
|
op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
|
||||||
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
|
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
|
||||||
|
|
|
@ -62,10 +62,13 @@ func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>
|
||||||
%cst = constant unit
|
%cst = constant unit
|
||||||
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
|
%0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK-NEXT: %cst = constant unit
|
// CHECK-NEXT: %cst = constant unit
|
||||||
// CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 3, 0, 0, 4, 5]} : (tensor<1x9x32x64xf32>) -> tensor<1x9x38x72xf32>
|
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[0, 0, 2, 3, 0, 0, 4, 5]> : tensor<8xi64>} : () -> tensor<8xi64>
|
||||||
// CHECK-NEXT: %1 = "onnx.Conv"(%0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
|
// CHECK-NEXT: %1 = "onnx.Constant"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
|
||||||
// CHECK-NEXT: return %1 : tensor<*xf32>
|
// CHECK-NEXT: %2 = "onnx.Pad"(%arg0, %0, %1) {mode = "constant"} : (tensor<1x9x32x64xf32>, tensor<8xi64>, tensor<1xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: %3 = "onnx.Conv"(%2, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<*xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: return %3 : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -93,4 +96,3 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1
|
||||||
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
|
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
|
||||||
// return [[GEMM]] : tensor<*xf32>
|
// return [[GEMM]] : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1637,6 +1637,32 @@ func @test_constant_pad1(%arg0: tensor<16x16xf32>) -> tensor<18x20xf32> {
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_pad1(%arg0: tensor<16x16xf32>) -> tensor<18x20xf32> {
|
||||||
|
%cst = constant unit
|
||||||
|
%0 = "onnx.Pad"(%arg0, %cst, %cst) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 3, 2, 1]> : tensor<4xi32>} : (tensor<16x16xf32>, none, none) -> tensor<18x20xf32>
|
||||||
|
return %0 : tensor<18x20xf32>
|
||||||
|
// CHECK-LABEL: test_pad1
|
||||||
|
// CHECK: [[RES:%.+]] = alloc() : memref<18x20xf32>
|
||||||
|
// CHECK: [[DEF_LOOPS1:%.+]]:2 = krnl.define_loops 2
|
||||||
|
// CHECK: [[OPT_LOOPS1:%.+]]:2 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS1]]#0, [[OPT_LOOPS1]]#1) with ([[DEF_LOOPS1]]#0 -> %arg1 = 0 to 18, [[DEF_LOOPS1]]#1 -> %arg2 = 0 to 20) {
|
||||||
|
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
|
||||||
|
// CHECK: store [[CST]], [[RES]][%arg1, %arg2] : memref<18x20xf32>
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: [[DEF_LOOPS2:%.+]]:2 = krnl.define_loops 2
|
||||||
|
// CHECK: [[OPT_LOOPS2:%.+]]:2 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS2]]#0, [[OPT_LOOPS2]]#1) with ([[DEF_LOOPS2]]#0 -> %arg1 = 0 to 16, [[DEF_LOOPS2]]#1 -> %arg2 = 0 to 16) {
|
||||||
|
// CHECK: [[CST1:%.+]] = constant 3 : index
|
||||||
|
// CHECK: [[ADD:%.+]] = addi [[CST1]], %arg2 : index
|
||||||
|
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<16x16xf32>
|
||||||
|
// CHECK: store [[LOAD]], [[RES]][%arg1, [[ADD]]] : memref<18x20xf32>
|
||||||
|
// CHECK: }
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> {
|
func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> {
|
||||||
|
|
|
@ -354,6 +354,19 @@ func @test_conv_12(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>, %arg2
|
||||||
/// Test shape inference for PadConstantValuePad.
|
/// Test shape inference for PadConstantValuePad.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Test Pad_1
|
||||||
|
func @test_Pad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {
|
||||||
|
%cst = constant unit
|
||||||
|
%0 = "onnx.Pad"(%arg0, %cst, %cst) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<16x13xf32>, none, none) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_Pad_1
|
||||||
|
// CHECK-NEXT: [[NONE:%.+]] = constant unit
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Pad"(%arg0, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<16x13xf32>, none, none) -> tensor<18x19xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<18x19xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Test PadConstantValuePad_1
|
/// Test PadConstantValuePad_1
|
||||||
func @test_PadConstantValuePad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {
|
func @test_PadConstantValuePad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<*xf32>
|
%0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 0]} : (tensor<16x13xf32>) -> tensor<*xf32>
|
||||||
|
|
|
@ -10,6 +10,16 @@ func @test_should_promote_to_attribute(%arg0 : tensor<?x10xf32>) -> tensor<*xf32
|
||||||
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
|
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_should_promote_to_attribute_1(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
|
%shape = "onnx.Constant"() { value = dense<[6, 7, 42]> : tensor<3xi32>}: () -> tensor<3xi32>
|
||||||
|
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
// CHECK-LABEL: test_should_promote_to_attribute_1
|
||||||
|
// CHECK-NEXT: [[NONE:%.+]] = constant unit
|
||||||
|
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : tensor<*xi64>) -> tensor<*xf32> {
|
func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : tensor<*xi64>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32>
|
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<*xi64>) -> tensor<*xf32>
|
||||||
return %0 : tensor<*xf32>
|
return %0 : tensor<*xf32>
|
||||||
|
@ -30,3 +40,14 @@ func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor<?x10xf3
|
||||||
// CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32>
|
// CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32>
|
||||||
// CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32>
|
// CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_should_promote_to_attribute1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
|
||||||
|
%shape = constant dense<[0, 2, 2, 4]> : tensor<4xi32>
|
||||||
|
%constant_value = constant dense<[0.]> : tensor<1xf32>
|
||||||
|
%0 = "onnx.Pad"(%arg0, %shape, %constant_value) {mode = "constant"} : (tensor<?x?xf32>, tensor<4xi32>, tensor<1xf32>)-> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
// CHECK-LABEL: test_should_promote_to_attribute1
|
||||||
|
// CHECK-NEXT: [[NONE:%.+]] = constant unit
|
||||||
|
// CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32>
|
||||||
|
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
|
@ -64,7 +64,7 @@ OpsWithShapeInference = [
|
||||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
||||||
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
||||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
|
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
|
||||||
'LSTM', 'GRU', 'Split'
|
'LSTM', 'GRU', 'Split', 'Pad'
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
|
@ -77,7 +77,8 @@ OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']
|
||||||
# should proceed. The key is the operation's name and the value is a list of
|
# should proceed. The key is the operation's name and the value is a list of
|
||||||
# tuples, whose first item is the attribute/operand name, and the second item is
|
# tuples, whose first item is the attribute/operand name, and the second item is
|
||||||
# the index at which such operand occurs in the list of the operation's inputs.
|
# the index at which such operand occurs in the list of the operation's inputs.
|
||||||
OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]}
|
OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)],
|
||||||
|
"Pad": [("pads", 1), ("constant_value", 2)]}
|
||||||
|
|
||||||
# Add an Op in this list if the Op needs result type deduction which is required
|
# Add an Op in this list if the Op needs result type deduction which is required
|
||||||
# when writing declarative rewriting rules. Deduced type is always
|
# when writing declarative rewriting rules. Deduced type is always
|
||||||
|
@ -87,7 +88,24 @@ OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]}
|
||||||
# Currenlty, there are only two build methods generated:
|
# Currenlty, there are only two build methods generated:
|
||||||
# - one with operands and attributes having a separate parameter, and
|
# - one with operands and attributes having a separate parameter, and
|
||||||
# - one with operands and attributes having aggregated parameters.
|
# - one with operands and attributes having aggregated parameters.
|
||||||
custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare']
|
custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad']
|
||||||
|
|
||||||
|
|
||||||
|
#a dictionary to add any special definition for an operation
|
||||||
|
custom_definition_misc = dict([ ('Constant',
|
||||||
|
''' let builders = [
|
||||||
|
OpBuilder<"Builder *builder, OperationState &state, Attribute sparse_value, Attribute value", [{
|
||||||
|
if (value) {
|
||||||
|
auto tensorType = value.getType();
|
||||||
|
build(builder, state, tensorType, sparse_value, value);
|
||||||
|
} else {
|
||||||
|
auto tensorType = sparse_value.getType();
|
||||||
|
build(builder, state, tensorType, sparse_value, value);
|
||||||
|
}
|
||||||
|
}]>
|
||||||
|
];'''
|
||||||
|
)])
|
||||||
|
|
||||||
|
|
||||||
SNIPPETS = collect_snippets()
|
SNIPPETS = collect_snippets()
|
||||||
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
|
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
|
||||||
|
@ -254,7 +272,7 @@ def get_operands_or_results(schema, is_input):
|
||||||
# nullable in case it migrates to be an attribute.
|
# nullable in case it migrates to be an attribute.
|
||||||
if schema.name in OpsWithPromotableConstOperands:
|
if schema.name in OpsWithPromotableConstOperands:
|
||||||
idxs = dict(OpsWithPromotableConstOperands[schema.name]).values()
|
idxs = dict(OpsWithPromotableConstOperands[schema.name]).values()
|
||||||
if i in idxs:
|
if i in idxs and not OpSchema.FormalParameterOption.Optional == value.option:
|
||||||
types.append("NoneType")
|
types.append("NoneType")
|
||||||
|
|
||||||
if OpSchema.FormalParameterOption.Optional == value.option:
|
if OpSchema.FormalParameterOption.Optional == value.option:
|
||||||
|
@ -451,6 +469,10 @@ def gen_op_def(schema):
|
||||||
if schema.name in OpsWithPromotableConstOperands:
|
if schema.name in OpsWithPromotableConstOperands:
|
||||||
s = get_promotable_const_operands_func(
|
s = get_promotable_const_operands_func(
|
||||||
s, indent, OpsWithPromotableConstOperands[schema.name])
|
s, indent, OpsWithPromotableConstOperands[schema.name])
|
||||||
|
|
||||||
|
if ( schema.name in custom_definition_misc) :
|
||||||
|
s += custom_definition_misc[schema.name]
|
||||||
|
|
||||||
s += '}\n\n'
|
s += '}\n\n'
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
@ -492,7 +514,7 @@ def gen_op_importer(schema, file):
|
||||||
"/* expected_num_operands = */ {}".format(expected_num_operands))
|
"/* expected_num_operands = */ {}".format(expected_num_operands))
|
||||||
args.append(
|
args.append(
|
||||||
'/* expected_num_results = */ {}'.format(expected_num_results))
|
'/* expected_num_results = */ {}'.format(expected_num_results))
|
||||||
s += inc_indent(indent) + "return {}({});\n".format(
|
s += inc_indent(indent) + " {}({});\n".format(
|
||||||
handler_func, ", ".join(args))
|
handler_func, ", ".join(args))
|
||||||
|
|
||||||
file.write(s)
|
file.write(s)
|
||||||
|
|
Loading…
Reference in New Issue