Avoid hardcoding ops with shape inference in ShapeInferencePass (#165)
* Avoid hardcoding ops with shape inference in ShapeInferencePass * Minimize the changes * Clang-format Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
e0ae583da0
commit
60c648ae39
|
@ -59,7 +59,7 @@ public:
|
|||
if (dynamicOperations != 0) {
|
||||
f.emitError("Shape inference failed, ")
|
||||
<< dynamicOperations << " operations couldn't be inferred\n";
|
||||
signalPassFailure();
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
if (auto terminator_op = f.getBody().back().getTerminator()) {
|
||||
|
@ -73,75 +73,6 @@ public:
|
|||
* Check if the given operation has a dynamically shaped result.
|
||||
*/
|
||||
static bool returnsDynamicShape(Operation *op) {
|
||||
// TODO: remove this check.
|
||||
// Temporary fix until more ops are supported.
|
||||
// All operations which do not return a ranked tensor type have dynamic
|
||||
// shaped outputs. All those operation need to implement the inferShape()
|
||||
// method.
|
||||
if (op->getName().getStringRef() != "onnx.Exp" &&
|
||||
op->getName().getStringRef() != "onnx.Atan" &&
|
||||
op->getName().getStringRef() != "onnx.Tan" &&
|
||||
op->getName().getStringRef() != "onnx.Tanh" &&
|
||||
op->getName().getStringRef() != "onnx.Sin" &&
|
||||
op->getName().getStringRef() != "onnx.Sinh" &&
|
||||
op->getName().getStringRef() != "onnx.Cosh" &&
|
||||
op->getName().getStringRef() != "onnx.Cos" &&
|
||||
op->getName().getStringRef() != "onnx.Log" &&
|
||||
op->getName().getStringRef() != "onnx.Sigmoid" &&
|
||||
op->getName().getStringRef() != "onnx.HardSigmoid" &&
|
||||
op->getName().getStringRef() != "onnx.Elu" &&
|
||||
op->getName().getStringRef() != "onnx.Relu" &&
|
||||
op->getName().getStringRef() != "onnx.LeakyRelu" &&
|
||||
op->getName().getStringRef() != "onnx.Selu" &&
|
||||
op->getName().getStringRef() != "onnx.Reciprocal" &&
|
||||
op->getName().getStringRef() != "onnx.Softplus" &&
|
||||
op->getName().getStringRef() != "onnx.Softsign" &&
|
||||
op->getName().getStringRef() != "onnx.Sign" &&
|
||||
op->getName().getStringRef() != "onnx.Mul" &&
|
||||
op->getName().getStringRef() != "onnx.Add" &&
|
||||
op->getName().getStringRef() != "onnx.Div" &&
|
||||
op->getName().getStringRef() != "onnx.Sub" &&
|
||||
op->getName().getStringRef() != "onnx.And" &&
|
||||
op->getName().getStringRef() != "onnx.Or" &&
|
||||
op->getName().getStringRef() != "onnx.Xor" &&
|
||||
op->getName().getStringRef() != "onnx.Sum" &&
|
||||
op->getName().getStringRef() != "onnx.Max" &&
|
||||
op->getName().getStringRef() != "onnx.AveragePool" &&
|
||||
op->getName().getStringRef() != "onnx.MaxPoolSingleOut" &&
|
||||
op->getName().getStringRef() != "onnx.Min" &&
|
||||
op->getName().getStringRef() != "onnx.Identity" &&
|
||||
op->getName().getStringRef() != "onnx.MatMul" &&
|
||||
op->getName().getStringRef() != "onnx.Gemm" &&
|
||||
op->getName().getStringRef() != "onnx.Reshape" &&
|
||||
op->getName().getStringRef() != "onnx.Transpose" &&
|
||||
op->getName().getStringRef() != "onnx.ReduceMax" &&
|
||||
op->getName().getStringRef() != "onnx.ReduceMin" &&
|
||||
op->getName().getStringRef() != "onnx.ReduceProd" &&
|
||||
op->getName().getStringRef() != "onnx.ReduceSum" &&
|
||||
op->getName().getStringRef() != "onnx.Softmax" &&
|
||||
op->getName().getStringRef() != "onnx.Sqrt" &&
|
||||
op->getName().getStringRef() != "onnx.Conv" &&
|
||||
op->getName().getStringRef() != "onnx.Pad" &&
|
||||
op->getName().getStringRef() != "onnx.PadConstantPad" &&
|
||||
op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
|
||||
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
|
||||
op->getName().getStringRef() != "onnx.Abs" &&
|
||||
op->getName().getStringRef() != "onnx.Constant" &&
|
||||
op->getName().getStringRef() != "onnx.Concat" &&
|
||||
op->getName().getStringRef() != "onnx.Split" &&
|
||||
op->getName().getStringRef() != "onnx.Neg" &&
|
||||
op->getName().getStringRef() != "onnx.RNN" &&
|
||||
op->getName().getStringRef() != "onnx.LSTM" &&
|
||||
op->getName().getStringRef() != "onnx.GRU" &&
|
||||
op->getName().getStringRef() != "onnx.Unsqueeze" &&
|
||||
op->getName().getStringRef() != "onnx.Cast" &&
|
||||
op->getName().getStringRef() != "onnx.ConvTranspose" &&
|
||||
op->getName().getStringRef() != "onnx.Flatten" &&
|
||||
op->getName().getStringRef() != "onnx.DynamicQuantizeLinear" &&
|
||||
op->getName().getStringRef() != "onnx.QuantizeLinear" &&
|
||||
op->getName().getStringRef() != "onnx.DequantizeLinear" &&
|
||||
op->getName().getStringRef() != "onnx.ConvInteger")
|
||||
return false;
|
||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||
return !result_type.isa<NoneType>() &&
|
||||
!result_type.isa<RankedTensorType>();
|
||||
|
|
Loading…
Reference in New Issue