Avoid hardcoding ops with shape inference in ShapeInferencePass (#165)

* Avoid hardcoding ops with shape inference in ShapeInferencePass

* Minimize the changes

* Clang-format

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Tung D. Le 2020-06-12 16:42:05 +09:00 committed by GitHub
parent e0ae583da0
commit 60c648ae39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 1 addition and 70 deletions

View File

@ -59,7 +59,7 @@ public:
if (dynamicOperations != 0) { if (dynamicOperations != 0) {
f.emitError("Shape inference failed, ") f.emitError("Shape inference failed, ")
<< dynamicOperations << " operations couldn't be inferred\n"; << dynamicOperations << " operations couldn't be inferred\n";
signalPassFailure(); return signalPassFailure();
} }
if (auto terminator_op = f.getBody().back().getTerminator()) { if (auto terminator_op = f.getBody().back().getTerminator()) {
@ -73,75 +73,6 @@ public:
* Check if the given operation has a dynamically shaped result. * Check if the given operation has a dynamically shaped result.
*/ */
static bool returnsDynamicShape(Operation *op) { static bool returnsDynamicShape(Operation *op) {
// TODO: remove this check.
// Temporary fix until more ops are supported.
// All operations which do not return a ranked tensor type have dynamic
// shaped outputs. All those operation need to implement the inferShape()
// method.
if (op->getName().getStringRef() != "onnx.Exp" &&
op->getName().getStringRef() != "onnx.Atan" &&
op->getName().getStringRef() != "onnx.Tan" &&
op->getName().getStringRef() != "onnx.Tanh" &&
op->getName().getStringRef() != "onnx.Sin" &&
op->getName().getStringRef() != "onnx.Sinh" &&
op->getName().getStringRef() != "onnx.Cosh" &&
op->getName().getStringRef() != "onnx.Cos" &&
op->getName().getStringRef() != "onnx.Log" &&
op->getName().getStringRef() != "onnx.Sigmoid" &&
op->getName().getStringRef() != "onnx.HardSigmoid" &&
op->getName().getStringRef() != "onnx.Elu" &&
op->getName().getStringRef() != "onnx.Relu" &&
op->getName().getStringRef() != "onnx.LeakyRelu" &&
op->getName().getStringRef() != "onnx.Selu" &&
op->getName().getStringRef() != "onnx.Reciprocal" &&
op->getName().getStringRef() != "onnx.Softplus" &&
op->getName().getStringRef() != "onnx.Softsign" &&
op->getName().getStringRef() != "onnx.Sign" &&
op->getName().getStringRef() != "onnx.Mul" &&
op->getName().getStringRef() != "onnx.Add" &&
op->getName().getStringRef() != "onnx.Div" &&
op->getName().getStringRef() != "onnx.Sub" &&
op->getName().getStringRef() != "onnx.And" &&
op->getName().getStringRef() != "onnx.Or" &&
op->getName().getStringRef() != "onnx.Xor" &&
op->getName().getStringRef() != "onnx.Sum" &&
op->getName().getStringRef() != "onnx.Max" &&
op->getName().getStringRef() != "onnx.AveragePool" &&
op->getName().getStringRef() != "onnx.MaxPoolSingleOut" &&
op->getName().getStringRef() != "onnx.Min" &&
op->getName().getStringRef() != "onnx.Identity" &&
op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose" &&
op->getName().getStringRef() != "onnx.ReduceMax" &&
op->getName().getStringRef() != "onnx.ReduceMin" &&
op->getName().getStringRef() != "onnx.ReduceProd" &&
op->getName().getStringRef() != "onnx.ReduceSum" &&
op->getName().getStringRef() != "onnx.Softmax" &&
op->getName().getStringRef() != "onnx.Sqrt" &&
op->getName().getStringRef() != "onnx.Conv" &&
op->getName().getStringRef() != "onnx.Pad" &&
op->getName().getStringRef() != "onnx.PadConstantPad" &&
op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
op->getName().getStringRef() != "onnx.Abs" &&
op->getName().getStringRef() != "onnx.Constant" &&
op->getName().getStringRef() != "onnx.Concat" &&
op->getName().getStringRef() != "onnx.Split" &&
op->getName().getStringRef() != "onnx.Neg" &&
op->getName().getStringRef() != "onnx.RNN" &&
op->getName().getStringRef() != "onnx.LSTM" &&
op->getName().getStringRef() != "onnx.GRU" &&
op->getName().getStringRef() != "onnx.Unsqueeze" &&
op->getName().getStringRef() != "onnx.Cast" &&
op->getName().getStringRef() != "onnx.ConvTranspose" &&
op->getName().getStringRef() != "onnx.Flatten" &&
op->getName().getStringRef() != "onnx.DynamicQuantizeLinear" &&
op->getName().getStringRef() != "onnx.QuantizeLinear" &&
op->getName().getStringRef() != "onnx.DequantizeLinear" &&
op->getName().getStringRef() != "onnx.ConvInteger")
return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) { return llvm::any_of(op->getResultTypes(), [](Type result_type) {
return !result_type.isa<NoneType>() && return !result_type.isa<NoneType>() &&
!result_type.isa<RankedTensorType>(); !result_type.isa<RankedTensorType>();