diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index 1ea68bb..b0322ff 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -59,7 +59,7 @@ public: if (dynamicOperations != 0) { f.emitError("Shape inference failed, ") << dynamicOperations << " operations couldn't be inferred\n"; - signalPassFailure(); + return signalPassFailure(); } if (auto terminator_op = f.getBody().back().getTerminator()) { @@ -73,75 +73,6 @@ public: * Check if the given operation has a dynamically shaped result. */ static bool returnsDynamicShape(Operation *op) { - // TODO: remove this check. - // Temporary fix until more ops are supported. - // All operations which do not return a ranked tensor type have dynamic - // shaped outputs. All those operation need to implement the inferShape() - // method. - if (op->getName().getStringRef() != "onnx.Exp" && - op->getName().getStringRef() != "onnx.Atan" && - op->getName().getStringRef() != "onnx.Tan" && - op->getName().getStringRef() != "onnx.Tanh" && - op->getName().getStringRef() != "onnx.Sin" && - op->getName().getStringRef() != "onnx.Sinh" && - op->getName().getStringRef() != "onnx.Cosh" && - op->getName().getStringRef() != "onnx.Cos" && - op->getName().getStringRef() != "onnx.Log" && - op->getName().getStringRef() != "onnx.Sigmoid" && - op->getName().getStringRef() != "onnx.HardSigmoid" && - op->getName().getStringRef() != "onnx.Elu" && - op->getName().getStringRef() != "onnx.Relu" && - op->getName().getStringRef() != "onnx.LeakyRelu" && - op->getName().getStringRef() != "onnx.Selu" && - op->getName().getStringRef() != "onnx.Reciprocal" && - op->getName().getStringRef() != "onnx.Softplus" && - op->getName().getStringRef() != "onnx.Softsign" && - op->getName().getStringRef() != "onnx.Sign" && - op->getName().getStringRef() != "onnx.Mul" && - op->getName().getStringRef() != "onnx.Add" && - op->getName().getStringRef() != "onnx.Div" && - 
op->getName().getStringRef() != "onnx.Sub" && - op->getName().getStringRef() != "onnx.And" && - op->getName().getStringRef() != "onnx.Or" && - op->getName().getStringRef() != "onnx.Xor" && - op->getName().getStringRef() != "onnx.Sum" && - op->getName().getStringRef() != "onnx.Max" && - op->getName().getStringRef() != "onnx.AveragePool" && - op->getName().getStringRef() != "onnx.MaxPoolSingleOut" && - op->getName().getStringRef() != "onnx.Min" && - op->getName().getStringRef() != "onnx.Identity" && - op->getName().getStringRef() != "onnx.MatMul" && - op->getName().getStringRef() != "onnx.Gemm" && - op->getName().getStringRef() != "onnx.Reshape" && - op->getName().getStringRef() != "onnx.Transpose" && - op->getName().getStringRef() != "onnx.ReduceMax" && - op->getName().getStringRef() != "onnx.ReduceMin" && - op->getName().getStringRef() != "onnx.ReduceProd" && - op->getName().getStringRef() != "onnx.ReduceSum" && - op->getName().getStringRef() != "onnx.Softmax" && - op->getName().getStringRef() != "onnx.Sqrt" && - op->getName().getStringRef() != "onnx.Conv" && - op->getName().getStringRef() != "onnx.Pad" && - op->getName().getStringRef() != "onnx.PadConstantPad" && - op->getName().getStringRef() != "onnx.PadConstantValuePad" && - op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" && - op->getName().getStringRef() != "onnx.Abs" && - op->getName().getStringRef() != "onnx.Constant" && - op->getName().getStringRef() != "onnx.Concat" && - op->getName().getStringRef() != "onnx.Split" && - op->getName().getStringRef() != "onnx.Neg" && - op->getName().getStringRef() != "onnx.RNN" && - op->getName().getStringRef() != "onnx.LSTM" && - op->getName().getStringRef() != "onnx.GRU" && - op->getName().getStringRef() != "onnx.Unsqueeze" && - op->getName().getStringRef() != "onnx.Cast" && - op->getName().getStringRef() != "onnx.ConvTranspose" && - op->getName().getStringRef() != "onnx.Flatten" && - op->getName().getStringRef() != "onnx.DynamicQuantizeLinear" && - 
op->getName().getStringRef() != "onnx.QuantizeLinear" && - op->getName().getStringRef() != "onnx.DequantizeLinear" && - op->getName().getStringRef() != "onnx.ConvInteger") - return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { return !result_type.isa<RankedTensorType>() && !result_type.isa<NoneType>();