[MLIR] generate op from onnx document (#366)
* generate op from onnx document * Restore FullGemm * update the op attribute for shape inference and canonicalizer * Update onnx_canonicalization.mlir
This commit is contained in:
parent
d01ac7732f
commit
3f68c5420d
|
@ -1,5 +1,6 @@
|
|||
add_library(builder
|
||||
frontend_dialect_transformer.cpp
|
||||
op_build_table.inc
|
||||
)
|
||||
|
||||
target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
|
|
|
@ -149,6 +149,33 @@ class FrontendGenImpl {
|
|||
}
|
||||
}
|
||||
|
||||
//if c++17 is used, these two def can be combined with 'if constexpr'
|
||||
//leave n there for possible future use
|
||||
//alternative is to use template and pass the outputTypes, inputs and attributes
|
||||
//as parameter
|
||||
|
||||
#define MultipleOuts(name, nIn, nOut)\
|
||||
{ \
|
||||
if (nIn == inputs.size() && nOut == outputTypes.size()) {\
|
||||
auto op = builder_.create<mlir::ONNX##name##Op>(UnknownLoc(), outputTypes, inputs, attributes); \
|
||||
for (int i = 0; i < node.output().size(); i++) { \
|
||||
frontend_symbols_.AddMapping(\
|
||||
legalize_name(node.output()[i]), op.getResult(i));\
|
||||
}\
|
||||
return;\
|
||||
}\
|
||||
}
|
||||
|
||||
#define OneOut(name, nIn, nOut)\
|
||||
{ \
|
||||
if (nIn == inputs.size() && nOut == outputTypes.size()) {\
|
||||
auto op = builder_.create<mlir::ONNX##name##Op>(UnknownLoc(), outputTypes, inputs, attributes); \
|
||||
frontend_symbols_.AddMapping(\
|
||||
legalize_name(node.output()[0]), op.getResult());\
|
||||
return;\
|
||||
}\
|
||||
}
|
||||
|
||||
/*!
|
||||
* Import an onnx input tensor type by determining and recording its type
|
||||
* in a list of input tensor mlir types.
|
||||
|
@ -206,39 +233,23 @@ class FrontendGenImpl {
|
|||
}
|
||||
}
|
||||
|
||||
// Handle ONNX Add Operation by using its representation in the
|
||||
// ONNX Dialect.
|
||||
llvm::StringRef OpName = node.op_type();
|
||||
if (OpName == "Add") {
|
||||
auto op = builder_.create<mlir::ONNXAddOp>(UnknownLoc(),
|
||||
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
|
||||
inputs[1]);
|
||||
frontend_symbols_.AddMapping(
|
||||
legalize_name(node.output()[0]), op.getResult());
|
||||
return;
|
||||
} else if (OpName == "MatMul") {
|
||||
auto op = builder_.create<mlir::ONNXMatMulOp>(UnknownLoc(),
|
||||
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
|
||||
inputs[1]);
|
||||
frontend_symbols_.AddMapping(
|
||||
legalize_name(node.output()[0]), op.getResult());
|
||||
return;
|
||||
} else if (OpName == "Gemm") {
|
||||
if (inputs.size() == 3) {
|
||||
auto op = builder_.create<mlir::ONNXFullGemmOp>(UnknownLoc(),
|
||||
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs[0],
|
||||
inputs[1], inputs[2]);
|
||||
frontend_symbols_.AddMapping(
|
||||
legalize_name(node.output()[0]), op.getResult());
|
||||
} else {
|
||||
auto op = builder_.create<mlir::ONNXGemmOp>(UnknownLoc(),
|
||||
mlir::UnrankedTensorType::get(builder_.getF32Type()), inputs);
|
||||
frontend_symbols_.AddMapping(
|
||||
legalize_name(node.output()[0]), op.getResult());
|
||||
}
|
||||
return;
|
||||
std::vector<mlir::Type> outputTypes;
|
||||
for (auto item : node.output()) {
|
||||
outputTypes.push_back(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||
}
|
||||
|
||||
std::vector<mlir::NamedAttribute> attributes;
|
||||
llvm::StringRef OpName = node.op_type();
|
||||
|
||||
//the following code is generated by gen_doc.py
|
||||
//refer to dialect/onnx/onnx.td for details
|
||||
//when the input or output of then op does not match the specification,
|
||||
//the generic operator is used
|
||||
//one known reeason is the optional input
|
||||
|
||||
#include "src/builder/op_build_table.inc"
|
||||
|
||||
|
||||
// Old way of doing things.
|
||||
mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type());
|
||||
for (auto item : node.output()) {
|
||||
|
|
|
@ -0,0 +1,313 @@
|
|||
if (OpName == "Abs") {
|
||||
OneOut(Abs, 1, 1);
|
||||
}else if (OpName == "Acos") {
|
||||
OneOut(Acos, 1, 1);
|
||||
}else if (OpName == "Acosh") {
|
||||
OneOut(Acosh, 1, 1);
|
||||
}else if (OpName == "Add") {
|
||||
OneOut(Add, 2, 1);
|
||||
}else if (OpName == "And") {
|
||||
OneOut(And, 2, 1);
|
||||
}else if (OpName == "ArgMax") {
|
||||
OneOut(ArgMax, 1, 1);
|
||||
}else if (OpName == "ArgMin") {
|
||||
OneOut(ArgMin, 1, 1);
|
||||
}else if (OpName == "Asin") {
|
||||
OneOut(Asin, 1, 1);
|
||||
}else if (OpName == "Asinh") {
|
||||
OneOut(Asinh, 1, 1);
|
||||
}else if (OpName == "Atan") {
|
||||
OneOut(Atan, 1, 1);
|
||||
}else if (OpName == "Atanh") {
|
||||
OneOut(Atanh, 1, 1);
|
||||
}else if (OpName == "AveragePool") {
|
||||
OneOut(AveragePool, 1, 1);
|
||||
}else if (OpName == "BatchNormalization") {
|
||||
MultipleOuts(BatchNormalization, 5, 5);
|
||||
}else if (OpName == "BitShift") {
|
||||
OneOut(BitShift, 2, 1);
|
||||
}else if (OpName == "Cast") {
|
||||
OneOut(Cast, 1, 1);
|
||||
}else if (OpName == "Ceil") {
|
||||
OneOut(Ceil, 1, 1);
|
||||
}else if (OpName == "Clip") {
|
||||
OneOut(Clip, 3, 1);
|
||||
}else if (OpName == "Compress") {
|
||||
OneOut(Compress, 2, 1);
|
||||
}else if (OpName == "Concat") {
|
||||
OneOut(Concat, 1, 1);
|
||||
}else if (OpName == "ConcatFromSequence") {
|
||||
OneOut(ConcatFromSequence, 1, 1);
|
||||
}else if (OpName == "Constant") {
|
||||
OneOut(Constant, 0, 1);
|
||||
}else if (OpName == "ConstantOfShape") {
|
||||
OneOut(ConstantOfShape, 1, 1);
|
||||
}else if (OpName == "Conv") {
|
||||
OneOut(Conv, 3, 1);
|
||||
}else if (OpName == "ConvInteger") {
|
||||
OneOut(ConvInteger, 4, 1);
|
||||
}else if (OpName == "ConvTranspose") {
|
||||
OneOut(ConvTranspose, 3, 1);
|
||||
}else if (OpName == "Cos") {
|
||||
OneOut(Cos, 1, 1);
|
||||
}else if (OpName == "Cosh") {
|
||||
OneOut(Cosh, 1, 1);
|
||||
}else if (OpName == "CumSum") {
|
||||
OneOut(CumSum, 2, 1);
|
||||
}else if (OpName == "DepthToSpace") {
|
||||
OneOut(DepthToSpace, 1, 1);
|
||||
}else if (OpName == "DequantizeLinear") {
|
||||
OneOut(DequantizeLinear, 3, 1);
|
||||
}else if (OpName == "Det") {
|
||||
OneOut(Det, 1, 1);
|
||||
}else if (OpName == "Div") {
|
||||
OneOut(Div, 2, 1);
|
||||
}else if (OpName == "Dropout") {
|
||||
MultipleOuts(Dropout, 1, 2);
|
||||
}else if (OpName == "DynamicQuantizeLinear") {
|
||||
MultipleOuts(DynamicQuantizeLinear, 1, 3);
|
||||
}else if (OpName == "Elu") {
|
||||
OneOut(Elu, 1, 1);
|
||||
}else if (OpName == "Equal") {
|
||||
OneOut(Equal, 2, 1);
|
||||
}else if (OpName == "Erf") {
|
||||
OneOut(Erf, 1, 1);
|
||||
}else if (OpName == "Exp") {
|
||||
OneOut(Exp, 1, 1);
|
||||
}else if (OpName == "Expand") {
|
||||
OneOut(Expand, 2, 1);
|
||||
}else if (OpName == "EyeLike") {
|
||||
OneOut(EyeLike, 1, 1);
|
||||
}else if (OpName == "Flatten") {
|
||||
OneOut(Flatten, 1, 1);
|
||||
}else if (OpName == "Floor") {
|
||||
OneOut(Floor, 1, 1);
|
||||
}else if (OpName == "GRU") {
|
||||
MultipleOuts(GRU, 6, 2);
|
||||
}else if (OpName == "Gather") {
|
||||
OneOut(Gather, 2, 1);
|
||||
}else if (OpName == "GatherElements") {
|
||||
OneOut(GatherElements, 2, 1);
|
||||
}else if (OpName == "GatherND") {
|
||||
OneOut(GatherND, 2, 1);
|
||||
}else if (OpName == "Gemm") {
|
||||
OneOut(Gemm, 3, 1);
|
||||
}else if (OpName == "GlobalAveragePool") {
|
||||
OneOut(GlobalAveragePool, 1, 1);
|
||||
}else if (OpName == "GlobalLpPool") {
|
||||
OneOut(GlobalLpPool, 1, 1);
|
||||
}else if (OpName == "GlobalMaxPool") {
|
||||
OneOut(GlobalMaxPool, 1, 1);
|
||||
}else if (OpName == "Greater") {
|
||||
OneOut(Greater, 2, 1);
|
||||
}else if (OpName == "HardSigmoid") {
|
||||
OneOut(HardSigmoid, 1, 1);
|
||||
}else if (OpName == "Hardmax") {
|
||||
OneOut(Hardmax, 1, 1);
|
||||
}else if (OpName == "Identity") {
|
||||
OneOut(Identity, 1, 1);
|
||||
}else if (OpName == "If") {
|
||||
OneOut(If, 1, 1);
|
||||
}else if (OpName == "InstanceNormalization") {
|
||||
OneOut(InstanceNormalization, 3, 1);
|
||||
}else if (OpName == "IsInf") {
|
||||
OneOut(IsInf, 1, 1);
|
||||
}else if (OpName == "IsNaN") {
|
||||
OneOut(IsNaN, 1, 1);
|
||||
}else if (OpName == "LRN") {
|
||||
OneOut(LRN, 1, 1);
|
||||
}else if (OpName == "LSTM") {
|
||||
MultipleOuts(LSTM, 8, 3);
|
||||
}else if (OpName == "LeakyRelu") {
|
||||
OneOut(LeakyRelu, 1, 1);
|
||||
}else if (OpName == "Less") {
|
||||
OneOut(Less, 2, 1);
|
||||
}else if (OpName == "Log") {
|
||||
OneOut(Log, 1, 1);
|
||||
}else if (OpName == "LogSoftmax") {
|
||||
OneOut(LogSoftmax, 1, 1);
|
||||
}else if (OpName == "Loop") {
|
||||
OneOut(Loop, 3, 1);
|
||||
}else if (OpName == "LpNormalization") {
|
||||
OneOut(LpNormalization, 1, 1);
|
||||
}else if (OpName == "LpPool") {
|
||||
OneOut(LpPool, 1, 1);
|
||||
}else if (OpName == "MatMul") {
|
||||
OneOut(MatMul, 2, 1);
|
||||
}else if (OpName == "MatMulInteger") {
|
||||
OneOut(MatMulInteger, 4, 1);
|
||||
}else if (OpName == "Max") {
|
||||
OneOut(Max, 1, 1);
|
||||
}else if (OpName == "MaxPool") {
|
||||
MultipleOuts(MaxPool, 1, 2);
|
||||
}else if (OpName == "MaxRoiPool") {
|
||||
OneOut(MaxRoiPool, 2, 1);
|
||||
}else if (OpName == "MaxUnpool") {
|
||||
OneOut(MaxUnpool, 3, 1);
|
||||
}else if (OpName == "Mean") {
|
||||
OneOut(Mean, 1, 1);
|
||||
}else if (OpName == "MeanVarianceNormalization") {
|
||||
OneOut(MeanVarianceNormalization, 1, 1);
|
||||
}else if (OpName == "Min") {
|
||||
OneOut(Min, 1, 1);
|
||||
}else if (OpName == "Mod") {
|
||||
OneOut(Mod, 2, 1);
|
||||
}else if (OpName == "Mul") {
|
||||
OneOut(Mul, 2, 1);
|
||||
}else if (OpName == "Multinomial") {
|
||||
OneOut(Multinomial, 1, 1);
|
||||
}else if (OpName == "Neg") {
|
||||
OneOut(Neg, 1, 1);
|
||||
}else if (OpName == "NonMaxSuppression") {
|
||||
OneOut(NonMaxSuppression, 5, 1);
|
||||
}else if (OpName == "NonZero") {
|
||||
OneOut(NonZero, 1, 1);
|
||||
}else if (OpName == "Not") {
|
||||
OneOut(Not, 1, 1);
|
||||
}else if (OpName == "OneHot") {
|
||||
OneOut(OneHot, 3, 1);
|
||||
}else if (OpName == "Or") {
|
||||
OneOut(Or, 2, 1);
|
||||
}else if (OpName == "PRelu") {
|
||||
OneOut(PRelu, 2, 1);
|
||||
}else if (OpName == "Pad") {
|
||||
OneOut(Pad, 3, 1);
|
||||
}else if (OpName == "Pow") {
|
||||
OneOut(Pow, 2, 1);
|
||||
}else if (OpName == "QLinearConv") {
|
||||
OneOut(QLinearConv, 9, 1);
|
||||
}else if (OpName == "QLinearMatMul") {
|
||||
OneOut(QLinearMatMul, 8, 1);
|
||||
}else if (OpName == "QuantizeLinear") {
|
||||
OneOut(QuantizeLinear, 3, 1);
|
||||
}else if (OpName == "RNN") {
|
||||
MultipleOuts(RNN, 6, 2);
|
||||
}else if (OpName == "RandomNormal") {
|
||||
OneOut(RandomNormal, 0, 1);
|
||||
}else if (OpName == "RandomNormalLike") {
|
||||
OneOut(RandomNormalLike, 1, 1);
|
||||
}else if (OpName == "RandomUniform") {
|
||||
OneOut(RandomUniform, 0, 1);
|
||||
}else if (OpName == "RandomUniformLike") {
|
||||
OneOut(RandomUniformLike, 1, 1);
|
||||
}else if (OpName == "Range") {
|
||||
OneOut(Range, 3, 1);
|
||||
}else if (OpName == "Reciprocal") {
|
||||
OneOut(Reciprocal, 1, 1);
|
||||
}else if (OpName == "ReduceL1") {
|
||||
OneOut(ReduceL1, 1, 1);
|
||||
}else if (OpName == "ReduceL2") {
|
||||
OneOut(ReduceL2, 1, 1);
|
||||
}else if (OpName == "ReduceLogSum") {
|
||||
OneOut(ReduceLogSum, 1, 1);
|
||||
}else if (OpName == "ReduceLogSumExp") {
|
||||
OneOut(ReduceLogSumExp, 1, 1);
|
||||
}else if (OpName == "ReduceMax") {
|
||||
OneOut(ReduceMax, 1, 1);
|
||||
}else if (OpName == "ReduceMean") {
|
||||
OneOut(ReduceMean, 1, 1);
|
||||
}else if (OpName == "ReduceMin") {
|
||||
OneOut(ReduceMin, 1, 1);
|
||||
}else if (OpName == "ReduceProd") {
|
||||
OneOut(ReduceProd, 1, 1);
|
||||
}else if (OpName == "ReduceSum") {
|
||||
OneOut(ReduceSum, 1, 1);
|
||||
}else if (OpName == "ReduceSumSquare") {
|
||||
OneOut(ReduceSumSquare, 1, 1);
|
||||
}else if (OpName == "Relu") {
|
||||
OneOut(Relu, 1, 1);
|
||||
}else if (OpName == "Reshape") {
|
||||
OneOut(Reshape, 2, 1);
|
||||
}else if (OpName == "Resize") {
|
||||
OneOut(Resize, 4, 1);
|
||||
}else if (OpName == "ReverseSequence") {
|
||||
OneOut(ReverseSequence, 2, 1);
|
||||
}else if (OpName == "RoiAlign") {
|
||||
OneOut(RoiAlign, 3, 1);
|
||||
}else if (OpName == "Round") {
|
||||
OneOut(Round, 1, 1);
|
||||
}else if (OpName == "Scan") {
|
||||
OneOut(Scan, 1, 1);
|
||||
}else if (OpName == "Scatter") {
|
||||
OneOut(Scatter, 3, 1);
|
||||
}else if (OpName == "ScatterElements") {
|
||||
OneOut(ScatterElements, 3, 1);
|
||||
}else if (OpName == "ScatterND") {
|
||||
OneOut(ScatterND, 3, 1);
|
||||
}else if (OpName == "Selu") {
|
||||
OneOut(Selu, 1, 1);
|
||||
}else if (OpName == "SequenceAt") {
|
||||
OneOut(SequenceAt, 2, 1);
|
||||
}else if (OpName == "SequenceConstruct") {
|
||||
OneOut(SequenceConstruct, 1, 1);
|
||||
}else if (OpName == "SequenceEmpty") {
|
||||
OneOut(SequenceEmpty, 0, 1);
|
||||
}else if (OpName == "SequenceErase") {
|
||||
OneOut(SequenceErase, 2, 1);
|
||||
}else if (OpName == "SequenceInsert") {
|
||||
OneOut(SequenceInsert, 3, 1);
|
||||
}else if (OpName == "SequenceLength") {
|
||||
OneOut(SequenceLength, 1, 1);
|
||||
}else if (OpName == "Shape") {
|
||||
OneOut(Shape, 1, 1);
|
||||
}else if (OpName == "Shrink") {
|
||||
OneOut(Shrink, 1, 1);
|
||||
}else if (OpName == "Sigmoid") {
|
||||
OneOut(Sigmoid, 1, 1);
|
||||
}else if (OpName == "Sign") {
|
||||
OneOut(Sign, 1, 1);
|
||||
}else if (OpName == "Sin") {
|
||||
OneOut(Sin, 1, 1);
|
||||
}else if (OpName == "Sinh") {
|
||||
OneOut(Sinh, 1, 1);
|
||||
}else if (OpName == "Size") {
|
||||
OneOut(Size, 1, 1);
|
||||
}else if (OpName == "Slice") {
|
||||
OneOut(Slice, 5, 1);
|
||||
}else if (OpName == "Softmax") {
|
||||
OneOut(Softmax, 1, 1);
|
||||
}else if (OpName == "Softplus") {
|
||||
OneOut(Softplus, 1, 1);
|
||||
}else if (OpName == "Softsign") {
|
||||
OneOut(Softsign, 1, 1);
|
||||
}else if (OpName == "SpaceToDepth") {
|
||||
OneOut(SpaceToDepth, 1, 1);
|
||||
}else if (OpName == "Split") {
|
||||
OneOut(Split, 1, 1);
|
||||
}else if (OpName == "SplitToSequence") {
|
||||
OneOut(SplitToSequence, 2, 1);
|
||||
}else if (OpName == "Sqrt") {
|
||||
OneOut(Sqrt, 1, 1);
|
||||
}else if (OpName == "Squeeze") {
|
||||
OneOut(Squeeze, 1, 1);
|
||||
}else if (OpName == "StringNormalizer") {
|
||||
OneOut(StringNormalizer, 1, 1);
|
||||
}else if (OpName == "Sub") {
|
||||
OneOut(Sub, 2, 1);
|
||||
}else if (OpName == "Sum") {
|
||||
OneOut(Sum, 1, 1);
|
||||
}else if (OpName == "Tan") {
|
||||
OneOut(Tan, 1, 1);
|
||||
}else if (OpName == "Tanh") {
|
||||
OneOut(Tanh, 1, 1);
|
||||
}else if (OpName == "TfIdfVectorizer") {
|
||||
OneOut(TfIdfVectorizer, 1, 1);
|
||||
}else if (OpName == "ThresholdedRelu") {
|
||||
OneOut(ThresholdedRelu, 1, 1);
|
||||
}else if (OpName == "Tile") {
|
||||
OneOut(Tile, 2, 1);
|
||||
}else if (OpName == "TopK") {
|
||||
MultipleOuts(TopK, 2, 2);
|
||||
}else if (OpName == "Transpose") {
|
||||
OneOut(Transpose, 1, 1);
|
||||
}else if (OpName == "Unique") {
|
||||
MultipleOuts(Unique, 1, 4);
|
||||
}else if (OpName == "Unsqueeze") {
|
||||
OneOut(Unsqueeze, 1, 1);
|
||||
}else if (OpName == "Upsample") {
|
||||
OneOut(Upsample, 2, 1);
|
||||
}else if (OpName == "Where") {
|
||||
OneOut(Where, 3, 1);
|
||||
}else if (OpName == "Xor") {
|
||||
OneOut(Xor, 2, 1);
|
||||
}
|
|
@ -10,8 +10,9 @@ add_library(
|
|||
dialect/krnl/parser_helper.hpp
|
||||
pass/shape_inference_pass.cpp
|
||||
pass/shape_inference_interface.hpp
|
||||
pass/onnx_combine.cpp
|
||||
pass/passes.hpp)
|
||||
pass/passes.hpp
|
||||
dialect/onnx/onnxop.inc
|
||||
pass/onnx_combine.cpp)
|
||||
|
||||
# Include root src directory.
|
||||
target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})
|
||||
|
|
|
@ -0,0 +1,520 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from collections import defaultdict
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from onnx import defs, FunctionProto, helper, OperatorStatus
|
||||
from onnx.defs import OpSchema, ONNX_DOMAIN, ONNX_ML_DOMAIN
|
||||
from onnx.backend.test.case import collect_snippets
|
||||
from onnx.backend.sample.ops import collect_sample_implementations
|
||||
from typing import Any, Text, Sequence, Dict, List, Type, Set, Tuple
|
||||
|
||||
|
||||
SNIPPETS = collect_snippets()
|
||||
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
|
||||
ONNX_ML = not bool(os.getenv('ONNX_ML') == '0')
|
||||
|
||||
ONNX_ML = False
|
||||
print("ONNX_ML", ONNX_ML)
|
||||
|
||||
|
||||
if ONNX_ML:
|
||||
ext = '-ml.md'
|
||||
else:
|
||||
ext = '.md'
|
||||
|
||||
|
||||
def display_number(v): # type: (int) -> Text
|
||||
if defs.OpSchema.is_infinite(v):
|
||||
return '∞'
|
||||
return Text(v)
|
||||
|
||||
|
||||
def should_render_domain(domain): # type: (Text) -> bool
|
||||
if domain == ONNX_ML_DOMAIN and not ONNX_ML:
|
||||
return False
|
||||
elif ONNX_ML and domain != ONNX_ML_DOMAIN:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def format_name_with_domain(domain, schema_name): # type: (Text, Text) -> Text
|
||||
if domain:
|
||||
return '{}.{}'.format(domain, schema_name)
|
||||
else:
|
||||
return schema_name
|
||||
|
||||
|
||||
def display_attr_type(v): # type: (OpSchema.AttrType) -> Text
|
||||
assert isinstance(v, OpSchema.AttrType)
|
||||
s = Text(v)
|
||||
s = s[s.rfind('.') + 1:].lower()
|
||||
if s[-1] == 's':
|
||||
s = 'list of ' + s
|
||||
return s
|
||||
|
||||
|
||||
def display_domain(domain): # type: (Text) -> Text
|
||||
if domain:
|
||||
return "the '{}' operator set".format(domain)
|
||||
else:
|
||||
return "the default ONNX operator set"
|
||||
|
||||
|
||||
def display_domain_short(domain): # type: (Text) -> Text
|
||||
if domain:
|
||||
return domain
|
||||
else:
|
||||
return 'ai.onnx (default)'
|
||||
|
||||
|
||||
def display_version_link(name, version): # type: (Text, int) -> Text
|
||||
changelog_md = 'Changelog' + ext
|
||||
name_with_ver = '{}-{}'.format(name, version)
|
||||
return '<a href="{}#{}">{}</a>'.format(changelog_md, name_with_ver, name_with_ver)
|
||||
|
||||
|
||||
def display_schema(schema, versions): # type: (OpSchema, Sequence[OpSchema]) -> Text
|
||||
s = ''
|
||||
|
||||
# doc
|
||||
if schema.doc:
|
||||
s += '\n'
|
||||
s += '\n'.join(' ' + line
|
||||
for line in schema.doc.lstrip().splitlines())
|
||||
s += '\n'
|
||||
|
||||
# since version
|
||||
s += '\n#### Version\n'
|
||||
if schema.support_level == OpSchema.SupportType.EXPERIMENTAL:
|
||||
s += '\nNo versioning maintained for experimental ops.'
|
||||
else:
|
||||
s += '\nThis version of the operator has been ' + ('deprecated' if schema.deprecated else 'available') + ' since version {}'.format(schema.since_version)
|
||||
s += ' of {}.\n'.format(display_domain(schema.domain))
|
||||
if len(versions) > 1:
|
||||
# TODO: link to the Changelog.md
|
||||
s += '\nOther versions of this operator: {}\n'.format(
|
||||
', '.join(display_version_link(format_name_with_domain(v.domain, v.name),
|
||||
v.since_version) for v in versions[:-1]))
|
||||
|
||||
# If this schema is deprecated, don't display any of the following sections
|
||||
if schema.deprecated:
|
||||
return s
|
||||
|
||||
# attributes
|
||||
if schema.attributes:
|
||||
s += '\n#### Attributes\n\n'
|
||||
s += '<dl>\n'
|
||||
for _, attr in sorted(schema.attributes.items()):
|
||||
# option holds either required or default value
|
||||
opt = ''
|
||||
if attr.required:
|
||||
opt = 'required'
|
||||
elif attr.default_value.name:
|
||||
default_value = helper.get_attribute_value(attr.default_value)
|
||||
|
||||
def format_value(value): # type: (Any) -> Text
|
||||
if isinstance(value, float):
|
||||
formatted = str(np.round(value, 5))
|
||||
# use default formatting, unless too long.
|
||||
if (len(formatted) > 10):
|
||||
formatted = str("({:e})".format(value))
|
||||
return formatted
|
||||
elif isinstance(value, (bytes, bytearray)) and sys.version_info[0] == 3:
|
||||
return str(value.decode('utf-8'))
|
||||
return str(value)
|
||||
|
||||
if isinstance(default_value, list):
|
||||
default_value = [format_value(val) for val in default_value]
|
||||
else:
|
||||
default_value = format_value(default_value)
|
||||
opt = 'default is {}'.format(default_value)
|
||||
|
||||
s += '<dt><tt>{}</tt> : {}{}</dt>\n'.format(
|
||||
attr.name,
|
||||
display_attr_type(attr.type),
|
||||
' ({})'.format(opt) if opt else '')
|
||||
s += '<dd>{}</dd>\n'.format(attr.description)
|
||||
s += '</dl>\n'
|
||||
|
||||
# inputs
|
||||
s += '\n#### Inputs'
|
||||
if schema.min_input != schema.max_input:
|
||||
s += ' ({} - {})'.format(display_number(schema.min_input),
|
||||
display_number(schema.max_input))
|
||||
s += '\n\n'
|
||||
if schema.inputs:
|
||||
s += '<dl>\n'
|
||||
for input in schema.inputs:
|
||||
option_str = ""
|
||||
if OpSchema.FormalParameterOption.Optional == input.option:
|
||||
option_str = " (optional)"
|
||||
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
||||
if input.isHomogeneous:
|
||||
option_str = " (variadic)"
|
||||
else:
|
||||
option_str = " (variadic, heterogeneous)"
|
||||
s += '<dt><tt>{}</tt>{} : {}</dt>\n'.format(input.name, option_str, input.typeStr)
|
||||
s += '<dd>{}</dd>\n'.format(input.description)
|
||||
s += '</dl>\n'
|
||||
|
||||
# outputs
|
||||
s += '\n#### Outputs'
|
||||
if schema.min_output != schema.max_output:
|
||||
s += ' ({} - {})'.format(display_number(schema.min_output),
|
||||
display_number(schema.max_output))
|
||||
s += '\n\n'
|
||||
|
||||
if schema.outputs:
|
||||
s += '<dl>\n'
|
||||
for output in schema.outputs:
|
||||
option_str = ""
|
||||
if OpSchema.FormalParameterOption.Optional == output.option:
|
||||
option_str = " (optional)"
|
||||
elif OpSchema.FormalParameterOption.Variadic == output.option:
|
||||
if output.isHomogeneous:
|
||||
option_str = " (variadic)"
|
||||
else:
|
||||
option_str = " (variadic, heterogeneous)"
|
||||
s += '<dt><tt>{}</tt>{} : {}</dt>\n'.format(output.name, option_str, output.typeStr)
|
||||
s += '<dd>{}</dd>\n'.format(output.description)
|
||||
s += '</dl>\n'
|
||||
|
||||
# type constraints
|
||||
s += '\n#### Type Constraints'
|
||||
s += '\n\n'
|
||||
if schema.type_constraints:
|
||||
s += '<dl>\n'
|
||||
for type_constraint in schema.type_constraints:
|
||||
allowedTypes = type_constraint.allowed_type_strs
|
||||
if (len(allowedTypes) > 0):
|
||||
allowedTypeStr = allowedTypes[0]
|
||||
for allowedType in allowedTypes[1:]:
|
||||
allowedTypeStr += ', ' + allowedType
|
||||
s += '<dt><tt>{}</tt> : {}</dt>\n'.format(
|
||||
type_constraint.type_param_str, allowedTypeStr)
|
||||
s += '<dd>{}</dd>\n'.format(type_constraint.description)
|
||||
s += '</dl>\n'
|
||||
|
||||
# Function Body
|
||||
if schema.has_function: # type: ignore
|
||||
s += '\n#### Function\n'
|
||||
s += '\nThe Function can be represented as a function.\n'
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def support_level_str(level): # type: (OpSchema.SupportType) -> Text
|
||||
return \
|
||||
"<sub>experimental</sub> " if level == OpSchema.SupportType.EXPERIMENTAL else ""
|
||||
|
||||
def convert_type(tstr) :
|
||||
tfrom = np.array(['bool', 'int8', 'int16', 'int32', 'int64',
|
||||
'unkown', 'float16', 'float', 'double'])
|
||||
tto =np.array(['I1', 'I8', 'I16', 'I32', 'I64',
|
||||
'BF16', 'F16', 'F32', 'F64'])
|
||||
index = -1
|
||||
for i in range(len(tfrom)) :
|
||||
if tfrom[i] in tstr :
|
||||
index = i
|
||||
break
|
||||
if index == -1 :
|
||||
print("error", tstr)
|
||||
return ''
|
||||
else :
|
||||
return tto[i]
|
||||
|
||||
def collect_types(schema, input) :
|
||||
allowedTypeStr=''
|
||||
#first step just ignore the type constraints
|
||||
return allowedTypeStr
|
||||
if input.typeStr :
|
||||
tstr = input.typeStr
|
||||
else :
|
||||
return allwedTypeStr
|
||||
if schema.type_constraints:
|
||||
for type_constraint in schema.type_constraints:
|
||||
if type_constraint.type_param_str != tstr :
|
||||
continue
|
||||
allowedTypes = type_constraint.allowed_type_strs
|
||||
allowedTypeStr=''
|
||||
if (len(allowedTypes) > 0):
|
||||
t = convert_type(allowedTypes[0])
|
||||
if t == '' :
|
||||
return ''
|
||||
allowedTypeStr += t
|
||||
for allowedType in allowedTypes[1:]:
|
||||
t = convert_type(allowedType)
|
||||
if t == '' :
|
||||
return ''
|
||||
if not t in allowedTypeStr :
|
||||
allowedTypeStr += ', '+t
|
||||
|
||||
return allowedTypeStr
|
||||
|
||||
return allowedTypeStr
|
||||
|
||||
def gen_schema(schema) :
|
||||
ShapeInferenceList=['Add', 'MatMul', 'Gemm']
|
||||
CanonicalList=['Add']
|
||||
line_indent = ' '
|
||||
|
||||
#s = 'def ONNX'+schema.name+str(schema.since_version)+'Op:ONNX_Op<"'+schema.name+'", \n'
|
||||
s = 'def ONNX'+schema.name+'Op:ONNX_Op<"'+schema.name+'", \n'
|
||||
s += line_indent+' [NoSideEffect'
|
||||
if schema.name in ShapeInferenceList :
|
||||
s+= ', DeclareOpInterfaceMethods<ShapeInferenceOpInterface>'
|
||||
s += ']> {'
|
||||
|
||||
if schema.name in CanonicalList:
|
||||
s += '\n'+line_indent+'let hasCanonicalizer = 1;'
|
||||
|
||||
#summary
|
||||
s += '\n'+line_indent
|
||||
s += 'let summary = "ONNX '+schema.name+' operation";'
|
||||
|
||||
#description
|
||||
s += '\n'+line_indent
|
||||
s += 'let description = [{'
|
||||
if schema.doc:
|
||||
"""
|
||||
s += '\n'.join(line_indent + line
|
||||
for line in schema.doc.lstrip().splitlines())
|
||||
"""
|
||||
for line in schema.doc.lstrip().splitlines():
|
||||
line = line.replace('}]', '\}\]')
|
||||
s += '\n'+line_indent+' '+'"'+line+'"'
|
||||
else :
|
||||
s += '\n'+line_indent*2 +'no doc for this op from onnx'
|
||||
s += '\n'+line_indent+'}];'
|
||||
|
||||
#input
|
||||
s+= '\n'+line_indent+'let arguments = (ins '
|
||||
if schema.inputs:
|
||||
for input in schema.inputs:
|
||||
if input != schema.inputs[0] :
|
||||
s+= ', '
|
||||
etypes=collect_types(schema, input)
|
||||
|
||||
if OpSchema.FormalParameterOption.Optional == input.option:
|
||||
#TODO: handle optional
|
||||
print("optional ", input.name)
|
||||
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
||||
if input.isHomogeneous:
|
||||
s+= 'Variadic<'
|
||||
else:
|
||||
#TODO handle (variadic, heterogeneous)"
|
||||
print('variadic, heterogeneous', input.name)
|
||||
if etypes == '':
|
||||
s+= 'AnyTensor'
|
||||
else:
|
||||
s+= 'TensorOf<['+etypes+']>'
|
||||
|
||||
if OpSchema.FormalParameterOption.Optional == input.option:
|
||||
#TODO: handle optional
|
||||
t=''
|
||||
elif OpSchema.FormalParameterOption.Variadic == input.option:
|
||||
if input.isHomogeneous:
|
||||
s+= '>'
|
||||
else:
|
||||
#TODO handle (variadic, heterogeneous)"
|
||||
t=''
|
||||
s+=':$'+input.name
|
||||
s+= ');'
|
||||
|
||||
#output
|
||||
s+= '\n'+line_indent+'let results = (outs '
|
||||
if schema.outputs:
|
||||
for output in schema.outputs:
|
||||
if output != schema.outputs[0] :
|
||||
s+= ', '
|
||||
#need to interpret output.typeStr
|
||||
etypes=collect_types(schema, output)
|
||||
if etypes == '':
|
||||
s+= 'AnyTensor'
|
||||
else:
|
||||
s+= 'TensorOf<['+etypes+']>'
|
||||
s+= ');'
|
||||
|
||||
#s+= 'let hasCanonicalizer = 1;'
|
||||
|
||||
s += '\n}\n\n'
|
||||
|
||||
return s
|
||||
|
||||
def main(args): # type: (Type[Args]) -> None
|
||||
with io.open(args.changelog, 'w', newline='') as fout:
|
||||
fout.write('## Operator Changelog\n')
|
||||
fout.write(
|
||||
"*This file is automatically generated from the\n"
|
||||
" [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
|
||||
" Do not modify directly and instead edit operator definitions.*\n")
|
||||
|
||||
# domain -> version -> [schema]
|
||||
dv_index = defaultdict(lambda: defaultdict(list)) # type: Dict[Text, Dict[int, List[OpSchema]]]
|
||||
for schema in defs.get_all_schemas_with_history():
|
||||
dv_index[schema.domain][schema.since_version].append(schema)
|
||||
|
||||
fout.write('\n')
|
||||
|
||||
for domain, versionmap in sorted(dv_index.items()):
|
||||
print("domain", domain)
|
||||
if not should_render_domain(domain):
|
||||
continue
|
||||
|
||||
s = '# {}\n'.format(display_domain_short(domain))
|
||||
|
||||
for version, unsorted_schemas in sorted(versionmap.items()):
|
||||
s += '## Version {} of {}\n'.format(version, display_domain(domain))
|
||||
for schema in sorted(unsorted_schemas, key=lambda s: s.name):
|
||||
name_with_ver = '{}-{}'.format(format_name_with_domain(domain, schema.name),
|
||||
schema.since_version)
|
||||
s += ('### <a name="{}"></a>**{}**' + (' (deprecated)' if schema.deprecated else '') + '</a>\n').format(name_with_ver, name_with_ver)
|
||||
s += display_schema(schema, [schema])
|
||||
s += '\n'
|
||||
|
||||
fout.write(s)
|
||||
|
||||
with io.open(args.output, 'w', newline='', encoding="utf-8") as fout:
|
||||
fout.write('## Operator Schemas\n')
|
||||
fout.write(
|
||||
"*This file is automatically generated from the\n"
|
||||
" [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
|
||||
" Do not modify directly and instead edit operator definitions.*\n")
|
||||
|
||||
# domain -> support level -> name -> [schema]
|
||||
index = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]]
|
||||
for schema in defs.get_all_schemas_with_history():
|
||||
#print("check point 0", schema.name, schema.domain, schema.support_level)
|
||||
#gen_schema(schema)
|
||||
index[schema.domain][int(schema.support_level)][schema.name].append(schema)
|
||||
|
||||
fout.write('\n')
|
||||
|
||||
# Preprocess the Operator Schemas
|
||||
# [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
|
||||
operator_schemas = list() # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]]
|
||||
exsting_ops = set() # type: Set[Text]
|
||||
for domain, _supportmap in sorted(index.items()):
|
||||
if not should_render_domain(domain):
|
||||
continue
|
||||
|
||||
processed_supportmap = list()
|
||||
for _support, _namemap in sorted(_supportmap.items()):
|
||||
processed_namemap = list()
|
||||
for n, unsorted_versions in sorted(_namemap.items()):
|
||||
versions = sorted(unsorted_versions, key=lambda s: s.since_version)
|
||||
schema = versions[-1]
|
||||
#print("check point 2", schema)
|
||||
if schema.name in exsting_ops:
|
||||
continue
|
||||
exsting_ops.add(schema.name)
|
||||
processed_namemap.append((n, schema, versions))
|
||||
processed_supportmap.append((_support, processed_namemap))
|
||||
operator_schemas.append((domain, processed_supportmap))
|
||||
|
||||
# Table of contents
|
||||
for domain, supportmap in operator_schemas:
|
||||
s = '* {}\n'.format(display_domain_short(domain))
|
||||
fout.write(s)
|
||||
function_ops = list()
|
||||
for _, namemap in supportmap:
|
||||
for n, schema, versions in namemap:
|
||||
if schema.has_function: # type: ignore
|
||||
function_ops.append((n, schema, versions))
|
||||
continue
|
||||
s = ' * {}<a href="#{}">{}</a>\n'.format(
|
||||
support_level_str(schema.support_level),
|
||||
format_name_with_domain(domain, n),
|
||||
format_name_with_domain(domain, n))
|
||||
fout.write(s)
|
||||
if len(function_ops):
|
||||
fout.write('\n')
|
||||
fout.write(' **Operators with function registered:**\n')
|
||||
for n, schema, versions in function_ops:
|
||||
s = ' * {}<a href="#{}">{}</a>\n'.format(
|
||||
support_level_str(schema.support_level),
|
||||
format_name_with_domain(domain, n),
|
||||
format_name_with_domain(domain, n))
|
||||
fout.write(s)
|
||||
|
||||
fout.write('\n')
|
||||
tdfile= io.open(args.tdfile, 'w', newline='')
|
||||
fefile=io.open('op_build_table.inc', 'w', newline='')
|
||||
firstfunc = True
|
||||
|
||||
for domain, supportmap in operator_schemas:
|
||||
s = '## {}\n'.format(display_domain_short(domain))
|
||||
fout.write(s)
|
||||
|
||||
for _, namemap in supportmap:
|
||||
for op_type, schema, versions in namemap:
|
||||
# op_type
|
||||
#print("check point 1", schema.name, len(schema.inputs), len(schema.outputs))
|
||||
if firstfunc :
|
||||
fefile.write(' '+'if (OpName == "'+schema.name+'") {\n')
|
||||
firstfunc = False
|
||||
else :
|
||||
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
||||
if len(schema.outputs) > 1 :
|
||||
fefile.write(' '+'MultipleOuts('+schema.name+', '
|
||||
+str(schema.since_version)+', '
|
||||
+str(len(schema.inputs))+', '
|
||||
+str(len(schema.outputs))+');\n')
|
||||
else :
|
||||
fefile.write(' '+'OneOut('+schema.name+', '
|
||||
+str(schema.since_version)+', '
|
||||
+str(len(schema.inputs))+', '
|
||||
+str(len(schema.outputs))+');\n')
|
||||
r = gen_schema(schema)
|
||||
tdfile.write(r)
|
||||
s = ('### {}<a name="{}"></a><a name="{}">**{}**' + (' (deprecated)' if schema.deprecated else '') + '</a>\n').format(
|
||||
support_level_str(schema.support_level),
|
||||
format_name_with_domain(domain, op_type),
|
||||
format_name_with_domain(domain, op_type.lower()),
|
||||
format_name_with_domain(domain, op_type))
|
||||
|
||||
s += display_schema(schema, versions)
|
||||
|
||||
s += '\n\n'
|
||||
|
||||
if op_type in SNIPPETS:
|
||||
s += '#### Examples\n\n'
|
||||
for summary, code in sorted(SNIPPETS[op_type]):
|
||||
s += '<details>\n'
|
||||
s += '<summary>{}</summary>\n\n'.format(summary)
|
||||
s += '```python\n{}\n```\n\n'.format(code)
|
||||
s += '</details>\n'
|
||||
s += '\n\n'
|
||||
if op_type.lower() in SAMPLE_IMPLEMENTATIONS:
|
||||
s += '#### Sample Implementation\n\n'
|
||||
s += '<details>\n'
|
||||
s += '<summary>{}</summary>\n\n'.format(op_type)
|
||||
s += '```python\n{}\n```\n\n'.format(SAMPLE_IMPLEMENTATIONS[op_type.lower()])
|
||||
s += '</details>\n'
|
||||
s += '\n\n'
|
||||
|
||||
fout.write(s)
|
||||
fefile.write(' }')
|
||||
fefile.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
||||
docs_dir = os.path.join(base_dir, 'docs')
|
||||
print(docs_dir)
|
||||
|
||||
class Args(object):
|
||||
output = os.path.join(docs_dir, 'Operators' + ext)
|
||||
changelog = os.path.join(docs_dir, 'Changelog' + ext)
|
||||
tdfile = os.path.join(docs_dir, 'onnxop.inc')
|
||||
print(Args)
|
||||
main(Args)
|
|
@ -39,66 +39,36 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
// ONNX Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// We define an ONNX operation for adding two tensors elementwise.
|
||||
def ONNXAddOp: ONNX_Op<"add",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX add operation";
|
||||
let description = [{
|
||||
//the tablegen code onnxop.in is generated with gen_doc.py
|
||||
//clone and install onnx
|
||||
// git clone --recursive https://github.com/onnx/onnx.git
|
||||
// set up env for anaconda3 and for DLC (BOOSTROOT, cmake, gcc ...)
|
||||
// cd onnx
|
||||
//install onnx
|
||||
// CC=gcc CXX=g++ pip install -e .
|
||||
//run the script
|
||||
// python onnx/defs/gen_doc.py
|
||||
//result is in docs/onnxop.inc
|
||||
//current limitations:
|
||||
// 1. Attributes are not processed
|
||||
// 2. output type inference not implemented except Add
|
||||
// 3. Type Attribute: 'optional' and 'Variadic hetergeneous' are ignored
|
||||
// 4. type of string, complex64 and complex128 for input/output are ignored
|
||||
// 5. unsigned int are treated as signed one
|
||||
|
||||
The "onnx.add" adds two tensors element-wise.
|
||||
include "dialect/onnx/onnxop.inc"
|
||||
|
||||
}];
|
||||
def ONNXFullGemmOp: ONNX_Op<"full_gemm",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX general matrix multiply operation";
|
||||
let description = [{
|
||||
|
||||
// TODO: AnyTensor might be too wide for ONNX and may need to be constrained
|
||||
// to fewer valid types.
|
||||
// In the ONNX spec:
|
||||
// T : tensor(uint32), tensor(uint64),
|
||||
// tensor(int32), tensor(int64),
|
||||
// tensor(float16), tensor(float), tensor(double)
|
||||
//
|
||||
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in);
|
||||
let results = (outs AnyTensor);
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
The "onnx.gemm" generic matrix multiplication with bias.
|
||||
|
||||
def ONNXMatMulOp: ONNX_Op<"matmul",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX matrix multiply operation";
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
The "onnx.mul" multiplies two matrices.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTypeOf<[F32Tensor, F64Tensor]>:$lhs_in,
|
||||
AnyTypeOf<[F32Tensor, F64Tensor]>:$rhs_in);
|
||||
let results = (outs AnyTypeOf<[F32Tensor, F64Tensor]>);
|
||||
}
|
||||
|
||||
def ONNXGemmOp: ONNX_Op<"gemm",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX general matrix multiply operation";
|
||||
let description = [{
|
||||
|
||||
The "onnx.gemm" generic matrix multiplication with bias.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyTensor>:$inputs);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
def ONNXFullGemmOp: ONNX_Op<"full_gemm",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX general matrix multiply operation";
|
||||
let description = [{
|
||||
|
||||
The "onnx.gemm" generic matrix multiplication with bias.
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in);
|
||||
let results = (outs AnyTensor);
|
||||
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
#endif // ONNX_OPS
|
||||
|
|
|
@ -39,9 +39,9 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
|
|||
//===----------------------------------------------------------------------===//
|
||||
// ONNX Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Add
|
||||
|
||||
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXAddOp::inferShapes() {
|
||||
getResult()->setType(getOperand(0)->getType());
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -7,8 +7,8 @@ module {
|
|||
%1 = "frontend.input t2"() : () -> tensor<10x10xf32>
|
||||
%2 = "frontend.input t3"() : () -> tensor<10x10xf32>
|
||||
// CHECK: %{{[0-9]+}} = "onnx.full_gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
%3 = "onnx.matmul"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
%4 = "onnx.add"(%3, %2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
%3 = "onnx.MatMul"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
%4 = "onnx.Add"(%3, %2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
%5 = "frontend.output t4"(%4) : (tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue