Add attributes as operation parameters (#45)
* add attributes of Op into parameters * fix rewrite rule for GemmOp with attributes * use I64Attr instead of I32Attr and modify test cases for the changes in attributes * add output name (prefixed with o_) to Op definition * update shape inference for the new attributes
This commit is contained in:
parent
383a5c31ac
commit
c74f814f64
|
@ -226,7 +226,7 @@ private:
|
||||||
std::string _name;
|
std::string _name;
|
||||||
|
|
||||||
mlir::NamedAttribute operator()(int64_t const &r) {
|
mlir::NamedAttribute operator()(int64_t const &r) {
|
||||||
auto val = _builder.getI32IntegerAttr(r);
|
auto val = _builder.getI64IntegerAttr(r);
|
||||||
return _builder.getNamedAttr(_name, val);
|
return _builder.getNamedAttr(_name, val);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -288,21 +288,12 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<mlir::NamedAttribute> ImportNodeAttributes(
|
std::vector<mlir::NamedAttribute> ImportNodeAttributes(
|
||||||
const onnx::NodeProto &node,
|
const onnx::NodeProto &node) {
|
||||||
std::initializer_list<std::pair<std::string, AttrValueType>>
|
|
||||||
defaultAttrList) {
|
|
||||||
std::vector<mlir::NamedAttribute> attributes;
|
std::vector<mlir::NamedAttribute> attributes;
|
||||||
std::set<std::string> definedAttributeSet;
|
|
||||||
for (int i = 0; i < node.attribute_size(); ++i) {
|
for (int i = 0; i < node.attribute_size(); ++i) {
|
||||||
auto attr = node.attribute(i);
|
auto attr = node.attribute(i);
|
||||||
auto nameValPair = convertAttributeProtoToNameValuePair(attr);
|
auto nameValPair = convertAttributeProtoToNameValuePair(attr);
|
||||||
attributes.push_back(convertNameValuePairToNamedAttribute(nameValPair));
|
attributes.push_back(convertNameValuePairToNamedAttribute(nameValPair));
|
||||||
definedAttributeSet.insert(attr.name());
|
|
||||||
}
|
|
||||||
for (const auto &defaultAttr : defaultAttrList) {
|
|
||||||
if (definedAttributeSet.find(defaultAttr.first) ==
|
|
||||||
definedAttributeSet.end())
|
|
||||||
attributes.push_back(convertNameValuePairToNamedAttribute(defaultAttr));
|
|
||||||
}
|
}
|
||||||
return attributes;
|
return attributes;
|
||||||
}
|
}
|
||||||
|
@ -340,9 +331,7 @@ private:
|
||||||
*/
|
*/
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void
|
void
|
||||||
ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
|
ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut) {
|
||||||
std::initializer_list<std::pair<std::string, AttrValueType>>
|
|
||||||
defaultAttrList) {
|
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (const auto &item : node.input()) {
|
for (const auto &item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
|
@ -356,7 +345,7 @@ private:
|
||||||
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto attributes = ImportNodeAttributes(node, defaultAttrList);
|
auto attributes = ImportNodeAttributes(node);
|
||||||
|
|
||||||
llvm::StringRef OpName = node.op_type();
|
llvm::StringRef OpName = node.op_type();
|
||||||
|
|
||||||
|
@ -372,9 +361,7 @@ private:
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void ImportNodeMultipleOuts(
|
void ImportNodeMultipleOuts(
|
||||||
const onnx::NodeProto &node, int nIn, int nOut,
|
const onnx::NodeProto &node, int nIn, int nOut) {
|
||||||
std::initializer_list<std::pair<std::string, AttrValueType>>
|
|
||||||
defaultAttrList) {
|
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (const auto &item : node.input()) {
|
for (const auto &item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
|
@ -388,7 +375,7 @@ private:
|
||||||
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto attributes = ImportNodeAttributes(node, defaultAttrList);
|
auto attributes = ImportNodeAttributes(node);
|
||||||
|
|
||||||
llvm::StringRef OpName = node.op_type();
|
llvm::StringRef OpName = node.op_type();
|
||||||
|
|
||||||
|
@ -410,9 +397,7 @@ private:
|
||||||
* a specialized function is used
|
* a specialized function is used
|
||||||
*/
|
*/
|
||||||
void
|
void
|
||||||
ImportNodeConv(onnx::NodeProto node, int nIn, int nOut,
|
ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
|
||||||
std::initializer_list<std::pair<std::string, AttrValueType>>
|
|
||||||
defaultAttrList) {
|
|
||||||
// Conv has attribute dilations, kernel_shape, pads, the default value of
|
// Conv has attribute dilations, kernel_shape, pads, the default value of
|
||||||
// which is determined by the shape of first argument. However, since the
|
// which is determined by the shape of first argument. However, since the
|
||||||
// shape is unknown now, these attributes can be not generated auto
|
// shape is unknown now, these attributes can be not generated auto
|
||||||
|
@ -427,25 +412,23 @@ private:
|
||||||
|
|
||||||
if (nOps == 2)
|
if (nOps == 2)
|
||||||
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(
|
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(
|
||||||
node, nOps, nOut, defaultAttrList);
|
node, nOps, nOut);
|
||||||
else
|
else
|
||||||
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, defaultAttrList);
|
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Special handle for MaxPool operations.
|
* Special handle for MaxPool operations.
|
||||||
*/
|
*/
|
||||||
void ImportNodeMaxPool(
|
void ImportNodeMaxPool(
|
||||||
onnx::NodeProto node, int nIn, int nOut,
|
onnx::NodeProto node, int nIn, int nOut) {
|
||||||
std::initializer_list<std::pair<std::string, AttrValueType>>
|
|
||||||
defaultAttrList) {
|
|
||||||
int nOuts = node.output().size();
|
int nOuts = node.output().size();
|
||||||
if (nOuts == 1) {
|
if (nOuts == 1) {
|
||||||
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(
|
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(
|
||||||
node, nIn, nOuts, defaultAttrList);
|
node, nIn, nOuts);
|
||||||
} else {
|
} else {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(
|
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(
|
||||||
node, nIn, nOuts, defaultAttrList);
|
node, nIn, nOuts);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,592 +1,314 @@
|
||||||
if (OpName == "DUMMY") {
|
if (OpName == "DUMMY") {
|
||||||
}else if (OpName == "Abs") {
|
}else if (OpName == "Abs") {
|
||||||
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Acos") {
|
}else if (OpName == "Acos") {
|
||||||
ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Acosh") {
|
}else if (OpName == "Acosh") {
|
||||||
ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Add") {
|
}else if (OpName == "Add") {
|
||||||
ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "And") {
|
}else if (OpName == "And") {
|
||||||
ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "ArgMax") {
|
}else if (OpName == "ArgMax") {
|
||||||
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1);
|
||||||
{"axis", 0}
|
|
||||||
,{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ArgMin") {
|
}else if (OpName == "ArgMin") {
|
||||||
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1);
|
||||||
{"axis", 0}
|
|
||||||
,{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Asin") {
|
}else if (OpName == "Asin") {
|
||||||
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Asinh") {
|
}else if (OpName == "Asinh") {
|
||||||
ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Atan") {
|
}else if (OpName == "Atan") {
|
||||||
ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Atanh") {
|
}else if (OpName == "Atanh") {
|
||||||
ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "AveragePool") {
|
}else if (OpName == "AveragePool") {
|
||||||
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1);
|
||||||
{"auto_pad", "NOTSET"}
|
|
||||||
,{"ceil_mode", 0}
|
|
||||||
,{"count_include_pad", 0}
|
|
||||||
,{"kernel_shape", std::vector<int64_t> {}}
|
|
||||||
});
|
|
||||||
}else if (OpName == "BatchNormalization") {
|
}else if (OpName == "BatchNormalization") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5, {
|
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5);
|
||||||
{"epsilon", (float)1e-05}
|
|
||||||
,{"momentum", (float)0.9}
|
|
||||||
});
|
|
||||||
}else if (OpName == "BitShift") {
|
}else if (OpName == "BitShift") {
|
||||||
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Cast") {
|
}else if (OpName == "Cast") {
|
||||||
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1);
|
||||||
{"to", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Ceil") {
|
}else if (OpName == "Ceil") {
|
||||||
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Clip") {
|
}else if (OpName == "Clip") {
|
||||||
ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Compress") {
|
}else if (OpName == "Compress") {
|
||||||
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Concat") {
|
}else if (OpName == "Concat") {
|
||||||
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1);
|
||||||
{"axis", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ConcatFromSequence") {
|
}else if (OpName == "ConcatFromSequence") {
|
||||||
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1);
|
||||||
{"new_axis", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Constant") {
|
}else if (OpName == "Constant") {
|
||||||
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1, {
|
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "ConstantOfShape") {
|
}else if (OpName == "ConstantOfShape") {
|
||||||
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Conv") {
|
}else if (OpName == "Conv") {
|
||||||
ImportNodeConv(node, 3, 1, {
|
ImportNodeConv(node, 3, 1);
|
||||||
{"auto_pad", "NOTSET"}
|
|
||||||
,{"group", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ConvInteger") {
|
}else if (OpName == "ConvInteger") {
|
||||||
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1, {
|
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1);
|
||||||
{"auto_pad", "NOTSET"}
|
|
||||||
,{"group", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ConvTranspose") {
|
}else if (OpName == "ConvTranspose") {
|
||||||
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1);
|
||||||
{"auto_pad", "NOTSET"}
|
|
||||||
,{"group", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Cos") {
|
}else if (OpName == "Cos") {
|
||||||
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Cosh") {
|
}else if (OpName == "Cosh") {
|
||||||
ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "CumSum") {
|
}else if (OpName == "CumSum") {
|
||||||
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1);
|
||||||
{"exclusive", 0}
|
|
||||||
,{"reverse", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "DepthToSpace") {
|
}else if (OpName == "DepthToSpace") {
|
||||||
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1);
|
||||||
{"mode", "DCR"}
|
|
||||||
});
|
|
||||||
}else if (OpName == "DequantizeLinear") {
|
}else if (OpName == "DequantizeLinear") {
|
||||||
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Det") {
|
}else if (OpName == "Det") {
|
||||||
ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Div") {
|
}else if (OpName == "Div") {
|
||||||
ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Dropout") {
|
}else if (OpName == "Dropout") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2, {
|
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2);
|
||||||
{"ratio", (float)0.5}
|
|
||||||
});
|
|
||||||
}else if (OpName == "DynamicQuantizeLinear") {
|
}else if (OpName == "DynamicQuantizeLinear") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3, {
|
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3);
|
||||||
});
|
|
||||||
}else if (OpName == "Elu") {
|
}else if (OpName == "Elu") {
|
||||||
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1);
|
||||||
{"alpha", (float)1.0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Equal") {
|
}else if (OpName == "Equal") {
|
||||||
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Erf") {
|
}else if (OpName == "Erf") {
|
||||||
ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Exp") {
|
}else if (OpName == "Exp") {
|
||||||
ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Expand") {
|
}else if (OpName == "Expand") {
|
||||||
ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "EyeLike") {
|
}else if (OpName == "EyeLike") {
|
||||||
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1);
|
||||||
{"k", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Flatten") {
|
}else if (OpName == "Flatten") {
|
||||||
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1);
|
||||||
{"axis", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Floor") {
|
}else if (OpName == "Floor") {
|
||||||
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "GRU") {
|
}else if (OpName == "GRU") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2, {
|
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2);
|
||||||
{"direction", "forward"}
|
|
||||||
,{"linear_before_reset", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Gather") {
|
}else if (OpName == "Gather") {
|
||||||
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1);
|
||||||
{"axis", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "GatherElements") {
|
}else if (OpName == "GatherElements") {
|
||||||
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1);
|
||||||
{"axis", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "GatherND") {
|
}else if (OpName == "GatherND") {
|
||||||
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Gemm") {
|
}else if (OpName == "Gemm") {
|
||||||
ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1);
|
||||||
{"alpha", (float)1.0}
|
|
||||||
,{"beta", (float)1.0}
|
|
||||||
,{"transA", 0}
|
|
||||||
,{"transB", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "GlobalAveragePool") {
|
}else if (OpName == "GlobalAveragePool") {
|
||||||
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "GlobalLpPool") {
|
}else if (OpName == "GlobalLpPool") {
|
||||||
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1);
|
||||||
{"p", 2}
|
|
||||||
});
|
|
||||||
}else if (OpName == "GlobalMaxPool") {
|
}else if (OpName == "GlobalMaxPool") {
|
||||||
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Greater") {
|
}else if (OpName == "Greater") {
|
||||||
ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "HardSigmoid") {
|
}else if (OpName == "HardSigmoid") {
|
||||||
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1);
|
||||||
{"alpha", (float)0.2}
|
|
||||||
,{"beta", (float)0.5}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Hardmax") {
|
}else if (OpName == "Hardmax") {
|
||||||
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1);
|
||||||
{"axis", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Identity") {
|
}else if (OpName == "Identity") {
|
||||||
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "If") {
|
}else if (OpName == "If") {
|
||||||
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "InstanceNormalization") {
|
}else if (OpName == "InstanceNormalization") {
|
||||||
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1);
|
||||||
{"epsilon", (float)1e-05}
|
|
||||||
});
|
|
||||||
}else if (OpName == "IsInf") {
|
}else if (OpName == "IsInf") {
|
||||||
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1);
|
||||||
{"detect_negative", 1}
|
|
||||||
,{"detect_positive", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "IsNaN") {
|
}else if (OpName == "IsNaN") {
|
||||||
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "LRN") {
|
}else if (OpName == "LRN") {
|
||||||
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1);
|
||||||
{"alpha", (float)0.0001}
|
|
||||||
,{"beta", (float)0.75}
|
|
||||||
,{"bias", (float)1.0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "LSTM") {
|
}else if (OpName == "LSTM") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3, {
|
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3);
|
||||||
{"direction", "forward"}
|
|
||||||
,{"input_forget", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "LeakyRelu") {
|
}else if (OpName == "LeakyRelu") {
|
||||||
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1);
|
||||||
{"alpha", (float)0.01}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Less") {
|
}else if (OpName == "Less") {
|
||||||
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Log") {
|
}else if (OpName == "Log") {
|
||||||
ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "LogSoftmax") {
|
}else if (OpName == "LogSoftmax") {
|
||||||
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1);
|
||||||
{"axis", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Loop") {
|
}else if (OpName == "Loop") {
|
||||||
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "LpNormalization") {
|
}else if (OpName == "LpNormalization") {
|
||||||
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1);
|
||||||
{"axis", -1}
|
|
||||||
,{"p", 2}
|
|
||||||
});
|
|
||||||
}else if (OpName == "LpPool") {
|
}else if (OpName == "LpPool") {
|
||||||
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1);
|
||||||
{"auto_pad", "NOTSET"}
|
|
||||||
,{"p", 2}
|
|
||||||
});
|
|
||||||
}else if (OpName == "MatMul") {
|
}else if (OpName == "MatMul") {
|
||||||
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "MatMulInteger") {
|
}else if (OpName == "MatMulInteger") {
|
||||||
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1, {
|
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Max") {
|
}else if (OpName == "Max") {
|
||||||
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "MaxPool") {
|
}else if (OpName == "MaxPool") {
|
||||||
ImportNodeMaxPool(node, 1, 2, {
|
ImportNodeMaxPool(node, 1, 2);
|
||||||
{"auto_pad", "NOTSET"}
|
|
||||||
,{"ceil_mode", 0}
|
|
||||||
,{"kernel_shape", std::vector<int64_t> {}}
|
|
||||||
,{"storage_order", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "MaxRoiPool") {
|
}else if (OpName == "MaxRoiPool") {
|
||||||
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1);
|
||||||
{"spatial_scale", (float)1.0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "MaxUnpool") {
|
}else if (OpName == "MaxUnpool") {
|
||||||
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Mean") {
|
}else if (OpName == "Mean") {
|
||||||
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "MeanVarianceNormalization") {
|
}else if (OpName == "MeanVarianceNormalization") {
|
||||||
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1);
|
||||||
{"axes", std::vector<int64_t>{0, 2, 3}}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Min") {
|
}else if (OpName == "Min") {
|
||||||
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Mod") {
|
}else if (OpName == "Mod") {
|
||||||
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1);
|
||||||
{"fmod", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Mul") {
|
}else if (OpName == "Mul") {
|
||||||
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Multinomial") {
|
}else if (OpName == "Multinomial") {
|
||||||
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1);
|
||||||
{"dtype", 6}
|
|
||||||
,{"sample_size", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Neg") {
|
}else if (OpName == "Neg") {
|
||||||
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "NonMaxSuppression") {
|
}else if (OpName == "NonMaxSuppression") {
|
||||||
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1, {
|
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1);
|
||||||
{"center_point_box", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "NonZero") {
|
}else if (OpName == "NonZero") {
|
||||||
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Not") {
|
}else if (OpName == "Not") {
|
||||||
ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "OneHot") {
|
}else if (OpName == "OneHot") {
|
||||||
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1);
|
||||||
{"axis", -1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Or") {
|
}else if (OpName == "Or") {
|
||||||
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "PRelu") {
|
}else if (OpName == "PRelu") {
|
||||||
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Pad") {
|
}else if (OpName == "Pad") {
|
||||||
ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1);
|
||||||
{"mode", "constant"}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Pow") {
|
}else if (OpName == "Pow") {
|
||||||
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "QLinearConv") {
|
}else if (OpName == "QLinearConv") {
|
||||||
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1, {
|
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1);
|
||||||
{"auto_pad", "NOTSET"}
|
|
||||||
,{"group", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "QLinearMatMul") {
|
}else if (OpName == "QLinearMatMul") {
|
||||||
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1, {
|
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "QuantizeLinear") {
|
}else if (OpName == "QuantizeLinear") {
|
||||||
ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "RNN") {
|
}else if (OpName == "RNN") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2, {
|
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2);
|
||||||
{"activation_alpha", std::vector<float> {}}
|
|
||||||
,{"activation_beta", std::vector<float> {}}
|
|
||||||
,{"activations", std::vector<std::string>{"Tanh", "Tanh"}}
|
|
||||||
,{"direction", "forward"}
|
|
||||||
});
|
|
||||||
}else if (OpName == "RandomNormal") {
|
}else if (OpName == "RandomNormal") {
|
||||||
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1, {
|
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1);
|
||||||
{"dtype", 1}
|
|
||||||
,{"mean", (float)0.0}
|
|
||||||
,{"scale", (float)1.0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "RandomNormalLike") {
|
}else if (OpName == "RandomNormalLike") {
|
||||||
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1);
|
||||||
{"mean", (float)0.0}
|
|
||||||
,{"scale", (float)1.0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "RandomUniform") {
|
}else if (OpName == "RandomUniform") {
|
||||||
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1, {
|
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1);
|
||||||
{"dtype", 1}
|
|
||||||
,{"high", (float)1.0}
|
|
||||||
,{"low", (float)0.0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "RandomUniformLike") {
|
}else if (OpName == "RandomUniformLike") {
|
||||||
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1);
|
||||||
{"high", (float)1.0}
|
|
||||||
,{"low", (float)0.0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Range") {
|
}else if (OpName == "Range") {
|
||||||
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Reciprocal") {
|
}else if (OpName == "Reciprocal") {
|
||||||
ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "ReduceL1") {
|
}else if (OpName == "ReduceL1") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1);
|
||||||
{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ReduceL2") {
|
}else if (OpName == "ReduceL2") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1);
|
||||||
{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ReduceLogSum") {
|
}else if (OpName == "ReduceLogSum") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1);
|
||||||
{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ReduceLogSumExp") {
|
}else if (OpName == "ReduceLogSumExp") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1);
|
||||||
{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ReduceMax") {
|
}else if (OpName == "ReduceMax") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1);
|
||||||
{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ReduceMean") {
|
}else if (OpName == "ReduceMean") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1);
|
||||||
{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ReduceMin") {
|
}else if (OpName == "ReduceMin") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1);
|
||||||
{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ReduceProd") {
|
}else if (OpName == "ReduceProd") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1);
|
||||||
{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ReduceSum") {
|
}else if (OpName == "ReduceSum") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1);
|
||||||
{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ReduceSumSquare") {
|
}else if (OpName == "ReduceSumSquare") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1);
|
||||||
{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Relu") {
|
}else if (OpName == "Relu") {
|
||||||
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Reshape") {
|
}else if (OpName == "Reshape") {
|
||||||
ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Resize") {
|
}else if (OpName == "Resize") {
|
||||||
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1, {
|
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1);
|
||||||
{"coordinate_transformation_mode", "half_pixel"}
|
|
||||||
,{"cubic_coeff_a", (float)-0.75}
|
|
||||||
,{"exclude_outside", 0}
|
|
||||||
,{"extrapolation_value", (float)0.0}
|
|
||||||
,{"mode", "nearest"}
|
|
||||||
,{"nearest_mode", "round_prefer_floor"}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ReverseSequence") {
|
}else if (OpName == "ReverseSequence") {
|
||||||
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1);
|
||||||
{"batch_axis", 1}
|
|
||||||
,{"time_axis", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "RoiAlign") {
|
}else if (OpName == "RoiAlign") {
|
||||||
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1);
|
||||||
{"mode", "avg"}
|
|
||||||
,{"output_height", 1}
|
|
||||||
,{"output_width", 1}
|
|
||||||
,{"sampling_ratio", 0}
|
|
||||||
,{"spatial_scale", (float)1.0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Round") {
|
}else if (OpName == "Round") {
|
||||||
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Scan") {
|
}else if (OpName == "Scan") {
|
||||||
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Scatter") {
|
}else if (OpName == "Scatter") {
|
||||||
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1);
|
||||||
{"axis", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ScatterElements") {
|
}else if (OpName == "ScatterElements") {
|
||||||
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1);
|
||||||
{"axis", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "ScatterND") {
|
}else if (OpName == "ScatterND") {
|
||||||
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Selu") {
|
}else if (OpName == "Selu") {
|
||||||
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1);
|
||||||
{"alpha", (float)1.67326}
|
|
||||||
,{"gamma", (float)1.0507}
|
|
||||||
});
|
|
||||||
}else if (OpName == "SequenceAt") {
|
}else if (OpName == "SequenceAt") {
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "SequenceConstruct") {
|
}else if (OpName == "SequenceConstruct") {
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "SequenceEmpty") {
|
}else if (OpName == "SequenceEmpty") {
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1, {
|
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "SequenceErase") {
|
}else if (OpName == "SequenceErase") {
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "SequenceInsert") {
|
}else if (OpName == "SequenceInsert") {
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "SequenceLength") {
|
}else if (OpName == "SequenceLength") {
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Shape") {
|
}else if (OpName == "Shape") {
|
||||||
ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Shrink") {
|
}else if (OpName == "Shrink") {
|
||||||
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1);
|
||||||
{"bias", (float)0.0}
|
|
||||||
,{"lambd", (float)0.5}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Sigmoid") {
|
}else if (OpName == "Sigmoid") {
|
||||||
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Sign") {
|
}else if (OpName == "Sign") {
|
||||||
ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Sin") {
|
}else if (OpName == "Sin") {
|
||||||
ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Sinh") {
|
}else if (OpName == "Sinh") {
|
||||||
ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Size") {
|
}else if (OpName == "Size") {
|
||||||
ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Slice") {
|
}else if (OpName == "Slice") {
|
||||||
ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1, {
|
ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Softmax") {
|
}else if (OpName == "Softmax") {
|
||||||
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1);
|
||||||
{"axis", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Softplus") {
|
}else if (OpName == "Softplus") {
|
||||||
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Softsign") {
|
}else if (OpName == "Softsign") {
|
||||||
ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "SpaceToDepth") {
|
}else if (OpName == "SpaceToDepth") {
|
||||||
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Split") {
|
}else if (OpName == "Split") {
|
||||||
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1);
|
||||||
{"axis", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "SplitToSequence") {
|
}else if (OpName == "SplitToSequence") {
|
||||||
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1);
|
||||||
{"axis", 0}
|
|
||||||
,{"keepdims", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Sqrt") {
|
}else if (OpName == "Sqrt") {
|
||||||
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Squeeze") {
|
}else if (OpName == "Squeeze") {
|
||||||
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "StringNormalizer") {
|
}else if (OpName == "StringNormalizer") {
|
||||||
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1);
|
||||||
{"case_change_action", "NONE"}
|
|
||||||
,{"is_case_sensitive", 0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Sub") {
|
}else if (OpName == "Sub") {
|
||||||
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Sum") {
|
}else if (OpName == "Sum") {
|
||||||
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Tan") {
|
}else if (OpName == "Tan") {
|
||||||
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Tanh") {
|
}else if (OpName == "Tanh") {
|
||||||
ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "TfIdfVectorizer") {
|
}else if (OpName == "TfIdfVectorizer") {
|
||||||
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "ThresholdedRelu") {
|
}else if (OpName == "ThresholdedRelu") {
|
||||||
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1);
|
||||||
{"alpha", (float)1.0}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Tile") {
|
}else if (OpName == "Tile") {
|
||||||
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "TopK") {
|
}else if (OpName == "TopK") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2, {
|
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2);
|
||||||
{"axis", -1}
|
|
||||||
,{"largest", 1}
|
|
||||||
,{"sorted", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Transpose") {
|
}else if (OpName == "Transpose") {
|
||||||
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Unique") {
|
}else if (OpName == "Unique") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4, {
|
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4);
|
||||||
{"sorted", 1}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Unsqueeze") {
|
}else if (OpName == "Unsqueeze") {
|
||||||
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1);
|
||||||
{"axes", std::vector<int64_t> {}}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Upsample") {
|
}else if (OpName == "Upsample") {
|
||||||
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1);
|
||||||
{"mode", "nearest"}
|
|
||||||
});
|
|
||||||
}else if (OpName == "Where") {
|
}else if (OpName == "Where") {
|
||||||
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1);
|
||||||
});
|
|
||||||
}else if (OpName == "Xor") {
|
}else if (OpName == "Xor") {
|
||||||
ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1);
|
||||||
});
|
|
||||||
}
|
}
|
|
@ -270,6 +270,12 @@ def gen_schema(schema) :
|
||||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
|
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
|
||||||
'Softplus', 'Softsign']
|
'Softplus', 'Softsign']
|
||||||
CanonicalList=['Add', 'Identity']
|
CanonicalList=['Add', 'Identity']
|
||||||
|
manual_code = dict([
|
||||||
|
('DummyExample', ' let extraClassDeclaration = [{ \n'+
|
||||||
|
' static StringRef getPermAttrName() { return "perm"; }\n'+
|
||||||
|
' }];\n')
|
||||||
|
])
|
||||||
|
skip_attr_gen = []
|
||||||
line_indent = ' '
|
line_indent = ' '
|
||||||
|
|
||||||
#s = 'def ONNX'+schema.name+str(schema.since_version)+'Op:ONNX_Op<"'+schema.name+'", \n'
|
#s = 'def ONNX'+schema.name+str(schema.since_version)+'Op:ONNX_Op<"'+schema.name+'", \n'
|
||||||
|
@ -303,21 +309,23 @@ def gen_schema(schema) :
|
||||||
|
|
||||||
#input
|
#input
|
||||||
s+= '\n'+line_indent+'let arguments = (ins '
|
s+= '\n'+line_indent+'let arguments = (ins '
|
||||||
|
isfirst = True
|
||||||
if schema.inputs:
|
if schema.inputs:
|
||||||
|
isfirst = False
|
||||||
for input in schema.inputs:
|
for input in schema.inputs:
|
||||||
if input != schema.inputs[0] :
|
if input != schema.inputs[0] :
|
||||||
s+= ', '
|
s+= ',\n '
|
||||||
etypes=collect_types(schema, input)
|
etypes=collect_types(schema, input)
|
||||||
|
|
||||||
if OpSchema.FormalParameterOption.Optional == input.option:
|
if OpSchema.FormalParameterOption.Optional == input.option:
|
||||||
#TODO: handle optional
|
#TODO: handle optional
|
||||||
print("optional ", input.name)
|
print("warning: optional input for"+schema.name+' '+input.name)
|
||||||
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
||||||
if input.isHomogeneous:
|
if input.isHomogeneous:
|
||||||
s+= 'Variadic<'
|
s+= 'Variadic<'
|
||||||
else:
|
else:
|
||||||
#TODO handle (variadic, heterogeneous)"
|
#TODO handle (variadic, heterogeneous)"
|
||||||
print('variadic, heterogeneous', input.name)
|
print("warning: (variadic, heterogeneous) for"+schema.name+' '+input.name)
|
||||||
if etypes == '':
|
if etypes == '':
|
||||||
s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
|
s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
|
||||||
else:
|
else:
|
||||||
|
@ -333,6 +341,8 @@ def gen_schema(schema) :
|
||||||
#TODO handle (variadic, heterogeneous)"
|
#TODO handle (variadic, heterogeneous)"
|
||||||
t=''
|
t=''
|
||||||
s+=':$'+input.name
|
s+=':$'+input.name
|
||||||
|
if not schema.name in skip_attr_gen :
|
||||||
|
s += gen_attr_ins(schema, isfirst)
|
||||||
s+= ');'
|
s+= ');'
|
||||||
|
|
||||||
#output
|
#output
|
||||||
|
@ -347,11 +357,15 @@ def gen_schema(schema) :
|
||||||
s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
|
s+= 'AnyTypeOf<[AnyMemRef, AnyTensor]>'
|
||||||
else:
|
else:
|
||||||
s+= 'TensorOf<['+etypes+']>'
|
s+= 'TensorOf<['+etypes+']>'
|
||||||
s+= ');'
|
s += ':$o_'+output.name
|
||||||
|
s+= ');\n'
|
||||||
|
|
||||||
#s+= 'let hasCanonicalizer = 1;'
|
#s+= 'let hasCanonicalizer = 1;'
|
||||||
|
#add special code
|
||||||
|
if schema.name in manual_code :
|
||||||
|
s += manual_code[schema.name]
|
||||||
|
|
||||||
s += '\n}\n\n'
|
s += '}\n\n'
|
||||||
|
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
@ -369,44 +383,91 @@ def gen_code(schema,fefile) :
|
||||||
("MaxPool", "ImportNodeMaxPool"),
|
("MaxPool", "ImportNodeMaxPool"),
|
||||||
#("Transpose", "ImportNodeTranspose")
|
#("Transpose", "ImportNodeTranspose")
|
||||||
])
|
])
|
||||||
list_str = 'std::vector'
|
|
||||||
empty_ints = list_str+'<int> {}'
|
|
||||||
empty_floats = list_str+'<float> {}'
|
|
||||||
special_default = dict([
|
|
||||||
("AveragePool "+"kernel_shape", empty_ints),
|
|
||||||
("MaxPool "+"kernel_shape", empty_ints),
|
|
||||||
("Cast "+"to", '0'),
|
|
||||||
("Concat "+"axis", '0'),
|
|
||||||
("Unsqueeze "+"axes", empty_ints),
|
|
||||||
("RNN "+"activation_alpha", empty_floats),
|
|
||||||
("RNN "+"activation_beta", empty_floats)
|
|
||||||
])
|
|
||||||
line_indent = ' '
|
line_indent = ' '
|
||||||
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
||||||
op_type_str='mlir::ONNX'+schema.name+'Op'
|
op_type_str='mlir::ONNX'+schema.name+'Op'
|
||||||
if schema.name in special_handler :
|
if schema.name in special_handler :
|
||||||
fefile.write(' '+special_handler[schema.name]+'(node, '
|
fefile.write(' '+special_handler[schema.name]+'(node, '
|
||||||
+str(len(schema.inputs))
|
+str(len(schema.inputs))
|
||||||
+', ' +str(len(schema.outputs))+', {\n')
|
+', ' +str(len(schema.outputs)))
|
||||||
elif len(schema.outputs) > 1 :
|
elif len(schema.outputs) > 1 :
|
||||||
fefile.write(' '+'ImportNodeMultipleOuts<'+op_type_str+'>(node, '
|
fefile.write(' '+'ImportNodeMultipleOuts<'+op_type_str+'>(node, '
|
||||||
+str(len(schema.inputs))
|
+str(len(schema.inputs))
|
||||||
+', ' +str(len(schema.outputs))+', {\n')
|
+', ' +str(len(schema.outputs)))
|
||||||
else :
|
else :
|
||||||
fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, '
|
fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, '
|
||||||
+str(len(schema.inputs))
|
+str(len(schema.inputs))
|
||||||
+', ' +str(len(schema.outputs))+', {\n')
|
+', ' +str(len(schema.outputs)))
|
||||||
|
fefile.write(');\n')
|
||||||
|
|
||||||
|
def gen_attr_ins(schema, isfirst) :
|
||||||
|
special_defaults = dict([
|
||||||
|
("AveragePool "+"kernel_shape", ('ints', '{}')),
|
||||||
|
("MaxPool "+"kernel_shape", ('ints', '{}')),
|
||||||
|
("Cast "+"to", ('int', '0')),
|
||||||
|
("Concat "+"axis", ('int', '0')),
|
||||||
|
("Conv "+"group", ('int', '1')),
|
||||||
|
("Unsqueeze "+"axes", ('ints', '{}')),
|
||||||
|
("RNN "+"activation_alpha", ('floats', '{}')),
|
||||||
|
("RNN "+"activation_beta", ('floats', '{}')),
|
||||||
|
])
|
||||||
|
|
||||||
|
def get_attr_type_basic(attr_type) :
|
||||||
|
if attr_type == 'int' :
|
||||||
|
mytype = 'I64Attr'
|
||||||
|
elif attr_type == 'float' :
|
||||||
|
mytype = 'F32Attr'
|
||||||
|
elif attr_type == 'ints' :
|
||||||
|
mytype = 'I64ArrayAttr'
|
||||||
|
elif attr_type == 'floats' :
|
||||||
|
mytype = 'F32ArrayAttr'
|
||||||
|
elif attr_type == "string" :
|
||||||
|
mytype = 'StrAttr'
|
||||||
|
elif attr_type == "strings" :
|
||||||
|
mytype = 'StrArrayAttr'
|
||||||
|
else :
|
||||||
|
mytype ='AnyAttr'
|
||||||
|
#TODO: tensor and sparse tensor
|
||||||
|
return mytype
|
||||||
|
|
||||||
|
def get_attr_type_optional(attr_type) :
|
||||||
|
mytype = 'OptionalAttr<'
|
||||||
|
mytype += get_attr_type_basic(attr_type)
|
||||||
|
mytype += '>'
|
||||||
|
return mytype
|
||||||
|
|
||||||
|
def get_attr_type_with_default(attr_type, attr_default) :
|
||||||
|
mytype = 'DefaultValuedAttr<'
|
||||||
|
mytype += get_attr_type_basic(attr_type)
|
||||||
|
mytype += ', "'+attr_default+'">'
|
||||||
|
return mytype
|
||||||
|
|
||||||
|
attr_line = ''
|
||||||
if schema.attributes:
|
if schema.attributes:
|
||||||
first_attr = True
|
|
||||||
for _, attr in sorted(schema.attributes.items()):
|
for _, attr in sorted(schema.attributes.items()):
|
||||||
#only generate default attr list
|
#attr_line = line_indent+line_indent+line_indent+line_indent
|
||||||
if schema.name+' '+attr.name in special_default:
|
if not isfirst:
|
||||||
attr_value = special_default[schema.name+' '+attr.name]
|
attr_line += ',\n '
|
||||||
elif attr.default_value.name:
|
else :
|
||||||
default_value = helper.get_attribute_value(attr.default_value)
|
isfirst = False
|
||||||
|
|
||||||
|
if schema.name+' '+attr.name in special_defaults:
|
||||||
|
(attr_type_str, attr_default_str) = special_defaults[schema.name+' '+attr.name]
|
||||||
|
attr_line += get_attr_type_with_default(attr_type_str, attr_default_str)
|
||||||
|
attr_line += ':$'+attr.name
|
||||||
|
elif attr.required:
|
||||||
|
s = Text(attr.type)
|
||||||
|
attr_type_str = s[s.rfind('.') + 1:].lower()
|
||||||
|
attr_line += get_attr_type_basic(attr_type_str)
|
||||||
|
attr_line += ':$'+attr.name
|
||||||
|
|
||||||
|
# option holds either required or default value
|
||||||
|
elif attr.default_value.name:
|
||||||
|
s = Text(attr.type)
|
||||||
|
attr_type_str = s[s.rfind('.') + 1:].lower()
|
||||||
|
|
||||||
|
default_value = helper.get_attribute_value(attr.default_value)
|
||||||
def format_value(value): # type: (Any) -> Text
|
def format_value(value): # type: (Any) -> Text
|
||||||
if isinstance(value, float):
|
if isinstance(value, float):
|
||||||
formatted = str(np.round(value, 5))
|
formatted = str(np.round(value, 5))
|
||||||
|
@ -419,66 +480,25 @@ def gen_code(schema,fefile) :
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
if isinstance(default_value, list):
|
if isinstance(default_value, list):
|
||||||
|
|
||||||
value = default_value[0]
|
|
||||||
default_value = [format_value(val) for val in default_value]
|
default_value = [format_value(val) for val in default_value]
|
||||||
attr_option_str = '{}'.format(default_value)
|
attr_option_str = '{}'.format(default_value)
|
||||||
attr_option_str = attr_option_str.replace('[', '{', 1)
|
attr_option_str = attr_option_str.replace('[', '{', 1)
|
||||||
attr_option_str = attr_option_str.replace(']', '}', 1)
|
attr_option_str = attr_option_str.replace(']', '}', 1)
|
||||||
# TODO the list type is homogenous or htergeneous?
|
if attr_type_str == 'strings' :
|
||||||
|
attr_option_str = attr_option_str.replace("'", '\\"')
|
||||||
if isinstance(value, float) :
|
|
||||||
attr_type_str = list_str+'<float>'
|
|
||||||
attr_option_str = attr_option_str.replace("'", '')
|
|
||||||
elif isinstance(value, int) :
|
|
||||||
attr_type_str = list_str+'<int>'
|
|
||||||
attr_option_str = attr_option_str.replace("'", '')
|
|
||||||
elif isinstance(value, str) :
|
|
||||||
attr_type_str = list_str+'<std::string>'
|
|
||||||
attr_option_str = attr_option_str.replace("'", '"')
|
|
||||||
elif isinstance(value, (bytes, bytearray)) :
|
|
||||||
attr_type_str = list_str+'<std::string>'
|
|
||||||
attr_option_str = attr_option_str.replace("'", '"')
|
|
||||||
else :
|
else :
|
||||||
attr_type_str = '"unknowns"'
|
attr_option_str = attr_option_str.replace("'", '')
|
||||||
else:
|
else:
|
||||||
if isinstance(default_value, float) :
|
|
||||||
attr_type_str = '(float)'
|
|
||||||
attr_option_str = default_value
|
|
||||||
elif isinstance(default_value, int) :
|
|
||||||
attr_option_str = default_value
|
|
||||||
attr_type_str=''
|
|
||||||
elif isinstance(default_value, str) :
|
|
||||||
attr_type_str = '"str"'
|
|
||||||
elif isinstance(default_value, (bytes, bytearray)) :
|
|
||||||
attr_type_str = '"str"'
|
|
||||||
else :
|
|
||||||
attr_type_str = '"unknown"'
|
|
||||||
default_value = format_value(default_value)
|
default_value = format_value(default_value)
|
||||||
if attr_type_str == '"str"' :
|
attr_option_str = default_value
|
||||||
attr_option_str = '"'+default_value+'"'
|
attr_line += get_attr_type_with_default(attr_type_str, attr_option_str)
|
||||||
attr_type_str=''
|
attr_line += ':$'+attr.name
|
||||||
else :
|
|
||||||
attr_option_str = default_value
|
|
||||||
attr_value = attr_type_str+attr_option_str
|
|
||||||
else:
|
else:
|
||||||
#no default value
|
s = Text(attr.type)
|
||||||
continue
|
attr_type_str = s[s.rfind('.') + 1:].lower()
|
||||||
|
attr_line += get_attr_type_optional(attr_type_str)
|
||||||
attr_line = line_indent+line_indent+line_indent+line_indent
|
attr_line += ':$'+attr.name
|
||||||
if not first_attr:
|
return attr_line
|
||||||
attr_line += ',{'
|
|
||||||
else :
|
|
||||||
attr_line += ' {'
|
|
||||||
first_attr = False
|
|
||||||
|
|
||||||
attr_line += '"'+attr.name+'", '
|
|
||||||
attr_line += attr_value
|
|
||||||
attr_line += '}\n'
|
|
||||||
fefile.write(attr_line)
|
|
||||||
fefile.write(line_indent+line_indent+line_indent+'});\n')
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main(args): # type: (Type[Args]) -> None
|
def main(args): # type: (Type[Args]) -> None
|
||||||
with io.open(args.changelog, 'w', newline='') as fout:
|
with io.open(args.changelog, 'w', newline='') as fout:
|
||||||
|
@ -496,7 +516,6 @@ def main(args): # type: (Type[Args]) -> None
|
||||||
fout.write('\n')
|
fout.write('\n')
|
||||||
|
|
||||||
for domain, versionmap in sorted(dv_index.items()):
|
for domain, versionmap in sorted(dv_index.items()):
|
||||||
print("domain", domain)
|
|
||||||
if not should_render_domain(domain):
|
if not should_render_domain(domain):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -633,6 +652,6 @@ if __name__ == '__main__':
|
||||||
class Args(object):
|
class Args(object):
|
||||||
output = os.path.join(docs_dir, 'Operators' + ext)
|
output = os.path.join(docs_dir, 'Operators' + ext)
|
||||||
changelog = os.path.join(docs_dir, 'Changelog' + ext)
|
changelog = os.path.join(docs_dir, 'Changelog' + ext)
|
||||||
tdfile = os.path.join(docs_dir, 'onnxop.inc')
|
tdfile = os.path.join(base_dir, 'onnxop.inc')
|
||||||
print(Args)
|
print(Args)
|
||||||
main(Args)
|
main(Args)
|
||||||
|
|
|
@ -99,8 +99,14 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
|
||||||
|
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$lhs_in, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rhs_in);
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
|
||||||
|
DefaultValuedAttr<F32Attr, "1.0">:$alpha,
|
||||||
|
DefaultValuedAttr<F32Attr, "1.0">:$beta,
|
||||||
|
DefaultValuedAttr<I64Attr, "0">:$transA,
|
||||||
|
DefaultValuedAttr<I64Attr, "0">:$transB);
|
||||||
|
|
||||||
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
|
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
|
||||||
|
@ -110,10 +116,15 @@ def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
|
||||||
"The convolution operator consumes an input tensor and a filter, and"
|
"The convolution operator consumes an input tensor and a filter, and"
|
||||||
"computes the output."
|
"computes the output."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$W,
|
||||||
|
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
|
||||||
let verifier = [{ return ::verify(*this); }];
|
OptionalAttr<I64ArrayAttr>:$dilations,
|
||||||
|
DefaultValuedAttr<I64Attr, "1">:$group,
|
||||||
|
OptionalAttr<I64ArrayAttr>:$kernel_shape,
|
||||||
|
OptionalAttr<I64ArrayAttr>:$pads,
|
||||||
|
OptionalAttr<I64ArrayAttr>:$strides);
|
||||||
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||||
|
@ -123,8 +134,15 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||||
"ONNX MaxPool operation with a single output."
|
"ONNX MaxPool operation with a single output."
|
||||||
"See ONNXMaxPoolOp for a full description of the MaxPool semantics."
|
"See ONNXMaxPoolOp for a full description of the MaxPool semantics."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad,
|
||||||
|
DefaultValuedAttr<I64Attr, "0">:$ceil_mode,
|
||||||
|
OptionalAttr<I64ArrayAttr>:$dilations,
|
||||||
|
DefaultValuedAttr<I64ArrayAttr, "{}">:$kernel_shape,
|
||||||
|
OptionalAttr<I64ArrayAttr>:$pads,
|
||||||
|
DefaultValuedAttr<I64Attr, "0">:$storage_order,
|
||||||
|
OptionalAttr<I64ArrayAttr>:$strides);
|
||||||
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // ONNX_OPS
|
#endif // ONNX_OPS
|
||||||
|
|
|
@ -435,8 +435,10 @@ void ONNXTransposeOp::inferShapes() {
|
||||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||||
SmallVector<int64_t, 2> dims;
|
SmallVector<int64_t, 2> dims;
|
||||||
|
|
||||||
if (auto permutation = getAttrOfType<ArrayAttr>(
|
//if (auto permutation = getAttrOfType<ArrayAttr>(
|
||||||
ONNXTransposeOp::getPermAttrName())) {
|
// ONNXTransposeOp::getPermAttrName())) {
|
||||||
|
auto permutation = ONNXTransposeOp::permAttr();
|
||||||
|
if (permutation) {
|
||||||
// Perform transposition according to perm attribute.
|
// Perform transposition according to perm attribute.
|
||||||
for (auto perm : permutation.getValue())
|
for (auto perm : permutation.getValue())
|
||||||
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
|
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
|
||||||
|
@ -449,20 +451,6 @@ void ONNXTransposeOp::inferShapes() {
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult verify(ONNXTransposeOp op) {
|
|
||||||
auto module = op.getParentOfType<ModuleOp>();
|
|
||||||
if (!module)
|
|
||||||
op.emitError("Expected to belong to a module.");
|
|
||||||
|
|
||||||
if (auto permutation = op.getAttrOfType<ArrayAttr>(
|
|
||||||
ONNXTransposeOp::getPermAttrName())) {
|
|
||||||
for (auto perm : permutation.getValue())
|
|
||||||
if (perm.cast<IntegerAttr>().getInt() < 0)
|
|
||||||
op.emitError("Cannot tranpose, permuation contains negative index.");
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
@ -491,11 +479,9 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
emitError("Weight size not compatible with data size.");
|
emitError("Weight size not compatible with data size.");
|
||||||
|
|
||||||
// Required attribute auto_pad defaults to NOTSET.
|
// Required attribute auto_pad defaults to NOTSET.
|
||||||
auto autoPad = getAttrOfType<StringAttr>(
|
auto autoPad = auto_pad();
|
||||||
ONNXConvOp::getAutoPadAttrName()).getValue();
|
|
||||||
// Group is a required attribute and should have default value of 1.
|
// Group is a required attribute and should have default value of 1.
|
||||||
int64_t group = getAttrOfType<IntegerAttr>(
|
int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
|
||||||
ONNXConvOp::getGroupAttrName()).getInt();
|
|
||||||
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
||||||
if (dataShape[1] != (weightShape[1] * group))
|
if (dataShape[1] != (weightShape[1] * group))
|
||||||
emitError("Channel dimension mismatch.");
|
emitError("Channel dimension mismatch.");
|
||||||
|
@ -527,8 +513,7 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
// Use kernel_shape attribute if present otherwise use size from weight
|
// Use kernel_shape attribute if present otherwise use size from weight
|
||||||
// argument.
|
// argument.
|
||||||
SmallVector<int64_t, 2> kernelDims;
|
SmallVector<int64_t, 2> kernelDims;
|
||||||
if (auto kernelShape = getAttrOfType<ArrayAttr>(
|
if (auto kernelShape = kernel_shapeAttr()) {
|
||||||
ONNXConvOp::getKernelShapeAttrName())) {
|
|
||||||
if (kernelShape.getValue().size() != nDims)
|
if (kernelShape.getValue().size() != nDims)
|
||||||
emitError("kernel_shape length incompatible with spatial dimensions.");
|
emitError("kernel_shape length incompatible with spatial dimensions.");
|
||||||
for (int i = 0; i < nDims; ++i)
|
for (int i = 0; i < nDims; ++i)
|
||||||
|
@ -550,8 +535,7 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
//
|
//
|
||||||
// From a dimensionality perspective the kernel size becomes the dilated
|
// From a dimensionality perspective the kernel size becomes the dilated
|
||||||
// kernel size.
|
// kernel size.
|
||||||
if (auto dilations = getAttrOfType<ArrayAttr>(
|
if (auto dilations = dilationsAttr()) {
|
||||||
ONNXConvOp::getDilationsAttrName())) {
|
|
||||||
if (dilations.getValue().size() != nDims)
|
if (dilations.getValue().size() != nDims)
|
||||||
emitError("dilations length incompatible with spatial dimensions.");
|
emitError("dilations length incompatible with spatial dimensions.");
|
||||||
for (int i = 0; i < nDims; ++i)
|
for (int i = 0; i < nDims; ++i)
|
||||||
|
@ -567,8 +551,7 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
if (autoPad == "NOTSET") {
|
if (autoPad == "NOTSET") {
|
||||||
// Use pads to to determine the padding. If attribute is not
|
// Use pads to to determine the padding. If attribute is not
|
||||||
// present then pads is considered to be all zeros (no padding).
|
// present then pads is considered to be all zeros (no padding).
|
||||||
if (auto pads = getAttrOfType<ArrayAttr>(
|
if (auto pads = padsAttr()) {
|
||||||
ONNXConvOp::getPadsAttrName())) {
|
|
||||||
// pads consists of two entries for each spatial axis.
|
// pads consists of two entries for each spatial axis.
|
||||||
if (pads.getValue().size() != 2 * nDims)
|
if (pads.getValue().size() != 2 * nDims)
|
||||||
emitError("pads size is not twice the spatial size.");
|
emitError("pads size is not twice the spatial size.");
|
||||||
|
@ -599,13 +582,12 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Strides
|
// Strides
|
||||||
if (auto strides = getAttrOfType<ArrayAttr>(
|
if (auto strides = ONNXConvNoBiasOp::stridesAttr()) {
|
||||||
ONNXConvOp::getStridesAttrName())) {
|
|
||||||
if (strides.getValue().size() != nDims)
|
if (strides.getValue().size() != nDims)
|
||||||
emitError("strides length incompatible with spatial dimensions.");
|
emitError("strides length incompatible with spatial dimensions.");
|
||||||
for (int i = 0; i < nDims; ++i) {
|
for (int i = 0; i < nDims; ++i) {
|
||||||
int64_t stride =
|
int64_t stride =
|
||||||
(strides.getValue()[i]).cast<IntegerAttr>().getInt();
|
strides.getValue()[i].cast<IntegerAttr>().getInt();
|
||||||
outSpatialDims[i] = floor(outSpatialDims[i] / stride);
|
outSpatialDims[i] = floor(outSpatialDims[i] / stride);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -617,28 +599,6 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult verify(ONNXConvNoBiasOp op) {
|
|
||||||
auto module = op.getParentOfType<ModuleOp>();
|
|
||||||
if (!module)
|
|
||||||
op.emitError("expected to belong to a module");
|
|
||||||
|
|
||||||
auto autoPadAttr = op.getAttrOfType<StringAttr>(
|
|
||||||
ONNXConvOp::getAutoPadAttrName());
|
|
||||||
if (!autoPadAttr)
|
|
||||||
op.emitError("auto_pad attribute not specified.");
|
|
||||||
if (autoPadAttr.getValue() != "NOTSET")
|
|
||||||
if (auto pads = op.getAttrOfType<ArrayAttr>(
|
|
||||||
ONNXConvOp::getPadsAttrName()))
|
|
||||||
op.emitError("auto_pad and pads are both set.");
|
|
||||||
|
|
||||||
auto groupAttr =
|
|
||||||
op.getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName());
|
|
||||||
if (!groupAttr)
|
|
||||||
op.emitError("group attribute not specified.");
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -30,9 +30,14 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
||||||
// Pattern-Match and Rewrite
|
// Pattern-Match and Rewrite
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def GemmAlpha : NativeCodeCall<"$_builder.getF32FloatAttr(1.0)">;
|
||||||
|
def GemmBeta : NativeCodeCall<"$_builder.getF32FloatAttr(1.0)">;
|
||||||
|
def GemmTransA : NativeCodeCall<"$_builder.getI64IntegerAttr(0)">;
|
||||||
|
def GemmTransB : NativeCodeCall<"$_builder.getI64IntegerAttr(0)">;
|
||||||
|
|
||||||
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z)
|
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z)
|
||||||
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
|
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
|
||||||
(ONNXGemmOp $m1, $m2, $m3),
|
(ONNXGemmOp $m1, $m2, $m3, (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
|
||||||
[(HasOneUse $res)]>;
|
[(HasOneUse $res)]>;
|
||||||
|
|
||||||
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
|
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
|
func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
|
||||||
// CHECK-LABEL: test_matmul_add_simplification
|
// CHECK-LABEL: test_matmul_add_simplification
|
||||||
// CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
// CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
%0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
%0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
%1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
%1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
"std.return"(%1) : (tensor<10x10xf32>) -> ()
|
"std.return"(%1) : (tensor<10x10xf32>) -> ()
|
||||||
|
|
|
@ -579,7 +579,7 @@ func @test_add_with_broadcasting(%arg0 : tensor<?xf32>, %arg1 : tensor<?x10xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
|
func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.Softmax"(%arg0) {axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.Softmax"(%arg0) {axis=1:i64} : (tensor<10x10xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_softmax
|
// CHECK-LABEL: test_softmax
|
||||||
|
|
|
@ -32,120 +32,120 @@ func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||||
/// Default and required attributes.
|
/// Default and required attributes.
|
||||||
|
|
||||||
func @test_conv_no_bias_1(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
func @test_conv_no_bias_1(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: test_conv_no_bias_1
|
// CHECK-LABEL: test_conv_no_bias_1
|
||||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x27x58xf32>
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x27x58xf32>
|
||||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xf32>
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xf32>
|
||||||
|
|
||||||
/// kernel_shape attribute.
|
/// kernel_shape attribute.
|
||||||
|
|
||||||
func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: test_conv_no_bias_2
|
// CHECK-LABEL: test_conv_no_bias_2
|
||||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x25x56xf32>
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x25x56xf32>
|
||||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xf32>
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xf32>
|
||||||
|
|
||||||
/// pads attribute.
|
/// pads attribute.
|
||||||
/// Use pads to make output size equal to input size by adding K - 1 to the result.
|
/// Use pads to make output size equal to input size by adding K - 1 to the result.
|
||||||
|
|
||||||
func @test_conv_no_bias_3(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
|
func @test_conv_no_bias_3(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: test_conv_no_bias_3
|
// CHECK-LABEL: test_conv_no_bias_3
|
||||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
|
||||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
|
||||||
|
|
||||||
/// auto_pad set to SAME_UPPER and SAME_LOWER.
|
/// auto_pad set to SAME_UPPER and SAME_LOWER.
|
||||||
|
|
||||||
func @test_conv_no_bias_4(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
|
func @test_conv_no_bias_4(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: test_conv_no_bias_4
|
// CHECK-LABEL: test_conv_no_bias_4
|
||||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
|
||||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
|
||||||
|
|
||||||
func @test_conv_no_bias_5(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
|
func @test_conv_no_bias_5(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: test_conv_no_bias_5
|
// CHECK-LABEL: test_conv_no_bias_5
|
||||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32>
|
||||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
|
||||||
|
|
||||||
/// auto_pad set to VALID.
|
/// auto_pad set to VALID.
|
||||||
|
|
||||||
func @test_conv_no_bias_6(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
|
func @test_conv_no_bias_6(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: test_conv_no_bias_6
|
// CHECK-LABEL: test_conv_no_bias_6
|
||||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x27x55xf32>
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x27x55xf32>
|
||||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xf32>
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xf32>
|
||||||
|
|
||||||
/// With strides attribute.
|
/// With strides attribute.
|
||||||
|
|
||||||
func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: test_conv_no_bias_7
|
// CHECK-LABEL: test_conv_no_bias_7
|
||||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x14x20xf32>
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x14x20xf32>
|
||||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xf32>
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xf32>
|
||||||
|
|
||||||
/// auto_pad set to SAME_UPPER with strides attribute.
|
/// auto_pad set to SAME_UPPER with strides attribute.
|
||||||
/// The auto_pad will pas as if stride is equal to 1.
|
/// The auto_pad will pas as if stride is equal to 1.
|
||||||
|
|
||||||
func @test_conv_no_bias_8(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
func @test_conv_no_bias_8(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: test_conv_no_bias_8
|
// CHECK-LABEL: test_conv_no_bias_8
|
||||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x16x22xf32>
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x16x22xf32>
|
||||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xf32>
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xf32>
|
||||||
|
|
||||||
/// dilations attribute.
|
/// dilations attribute.
|
||||||
|
|
||||||
func @test_conv_no_bias_9(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
func @test_conv_no_bias_9(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: test_conv_no_bias_9
|
// CHECK-LABEL: test_conv_no_bias_9
|
||||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x20x42xf32>
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x20x42xf32>
|
||||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x20x42xf32>
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x20x42xf32>
|
||||||
|
|
||||||
/// dilations attribute with stride.
|
/// dilations attribute with stride.
|
||||||
|
|
||||||
func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i32, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: test_conv_no_bias_10
|
// CHECK-LABEL: test_conv_no_bias_10
|
||||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i32, strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x10x21xf32>
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x10x21xf32>
|
||||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x10x21xf32>
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x10x21xf32>
|
||||||
|
|
||||||
/// dilations attribute with auto_pad set to SAME_UPPER.
|
/// dilations attribute with auto_pad set to SAME_UPPER.
|
||||||
|
|
||||||
func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i32, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: test_conv_no_bias_11
|
// CHECK-LABEL: test_conv_no_bias_11
|
||||||
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", dilations = [2, 3], group = 1 : i32} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32>
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", dilations = [2, 3], group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32>
|
||||||
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32>
|
||||||
|
|
Loading…
Reference in New Issue