Avoid hardcoding ops with shape inference in ShapeInferencePass (#165)
* Avoid hardcoding ops with shape inference in ShapeInferencePass * Minimize the changes * Clang-format Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
e0ae583da0
commit
60c648ae39
|
@ -59,7 +59,7 @@ public:
|
||||||
if (dynamicOperations != 0) {
|
if (dynamicOperations != 0) {
|
||||||
f.emitError("Shape inference failed, ")
|
f.emitError("Shape inference failed, ")
|
||||||
<< dynamicOperations << " operations couldn't be inferred\n";
|
<< dynamicOperations << " operations couldn't be inferred\n";
|
||||||
signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto terminator_op = f.getBody().back().getTerminator()) {
|
if (auto terminator_op = f.getBody().back().getTerminator()) {
|
||||||
|
@ -73,75 +73,6 @@ public:
|
||||||
* Check if the given operation has a dynamically shaped result.
|
* Check if the given operation has a dynamically shaped result.
|
||||||
*/
|
*/
|
||||||
static bool returnsDynamicShape(Operation *op) {
|
static bool returnsDynamicShape(Operation *op) {
|
||||||
// TODO: remove this check.
|
|
||||||
// Temporary fix until more ops are supported.
|
|
||||||
// All operations which do not return a ranked tensor type have dynamic
|
|
||||||
// shaped outputs. All those operation need to implement the inferShape()
|
|
||||||
// method.
|
|
||||||
if (op->getName().getStringRef() != "onnx.Exp" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Atan" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Tan" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Tanh" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Sin" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Sinh" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Cosh" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Cos" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Log" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Sigmoid" &&
|
|
||||||
op->getName().getStringRef() != "onnx.HardSigmoid" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Elu" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Relu" &&
|
|
||||||
op->getName().getStringRef() != "onnx.LeakyRelu" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Selu" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Reciprocal" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Softplus" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Softsign" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Sign" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Mul" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Add" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Div" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Sub" &&
|
|
||||||
op->getName().getStringRef() != "onnx.And" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Or" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Xor" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Sum" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Max" &&
|
|
||||||
op->getName().getStringRef() != "onnx.AveragePool" &&
|
|
||||||
op->getName().getStringRef() != "onnx.MaxPoolSingleOut" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Min" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Identity" &&
|
|
||||||
op->getName().getStringRef() != "onnx.MatMul" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Gemm" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Reshape" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Transpose" &&
|
|
||||||
op->getName().getStringRef() != "onnx.ReduceMax" &&
|
|
||||||
op->getName().getStringRef() != "onnx.ReduceMin" &&
|
|
||||||
op->getName().getStringRef() != "onnx.ReduceProd" &&
|
|
||||||
op->getName().getStringRef() != "onnx.ReduceSum" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Softmax" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Sqrt" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Conv" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Pad" &&
|
|
||||||
op->getName().getStringRef() != "onnx.PadConstantPad" &&
|
|
||||||
op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
|
|
||||||
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Abs" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Constant" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Concat" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Split" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Neg" &&
|
|
||||||
op->getName().getStringRef() != "onnx.RNN" &&
|
|
||||||
op->getName().getStringRef() != "onnx.LSTM" &&
|
|
||||||
op->getName().getStringRef() != "onnx.GRU" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Unsqueeze" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Cast" &&
|
|
||||||
op->getName().getStringRef() != "onnx.ConvTranspose" &&
|
|
||||||
op->getName().getStringRef() != "onnx.Flatten" &&
|
|
||||||
op->getName().getStringRef() != "onnx.DynamicQuantizeLinear" &&
|
|
||||||
op->getName().getStringRef() != "onnx.QuantizeLinear" &&
|
|
||||||
op->getName().getStringRef() != "onnx.DequantizeLinear" &&
|
|
||||||
op->getName().getStringRef() != "onnx.ConvInteger")
|
|
||||||
return false;
|
|
||||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||||
return !result_type.isa<NoneType>() &&
|
return !result_type.isa<NoneType>() &&
|
||||||
!result_type.isa<RankedTensorType>();
|
!result_type.isa<RankedTensorType>();
|
||||||
|
|
Loading…
Reference in New Issue