//===----- shape_inference_pass.cpp - Shape Inference ---------------------===// // // Copyright 2019 The IBM Research Authors. // // ============================================================================= // // This file implements a Function level pass performing propagation of array // shapes through function specialization. // //===----------------------------------------------------------------------===// #include "mlir/Pass/Pass.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/raw_ostream.h" #include "shape_inference_interface.hpp" #include "src/dialect/onnx/onnx_ops.hpp" #include "passes.hpp" using namespace mlir; // Include the auto-generated definitions for the shape inference interfaces. #include "src/shape_inference.cpp.inc" namespace { /*! * FunctionPass that performs shape inference by iterating over a list of * candidate operations and propagating the shape information until the list * of operations is empty [credit MLIR authors]. */ class ShapeInferencePass : public mlir::FunctionPass { public: void runOnFunction() override { auto f = getFunction(); // Populate the worklist with the operations that need shape inference: // these are operations that return a dynamic shape. llvm::SmallPtrSet op_worklist; f.walk([&](mlir::Operation *op) { if (returnsDynamicShape(op)) op_worklist.insert(op); }); // Iterate on the operations in the worklist until all operations have been // inferred or no change happened (fix point). while (!op_worklist.empty()) { // Find the next operation ready for inference, that is an operation // with all operands already resolved (non-generic). auto nextop = llvm::find_if(op_worklist, returnsDynamicShape); if (nextop == op_worklist.end()) break; Operation *op = *nextop; op_worklist.erase(op); // Ask the operation to infer its output shapes. if (auto shape_op = dyn_cast(op)) { shape_op.inferShapes(); } else { op->emitError("unable to infer shape of operation without shape " "inference interface"); return signalPassFailure(); } } // If the operation worklist isn't empty, this indicates a failure. if (!op_worklist.empty()) { f.emitError("Shape inference failed, ") << op_worklist.size() << " operations couldn't be inferred\n"; signalPassFailure(); } if (auto terminator_op = f.getBody().back().getTerminator()) { auto results = terminator_op->getOperandTypes(); f.setType(FunctionType::get( f.getType().getInputs(), std::vector(results.begin(), results.end()), f.getContext())); } } /*! * Check if the given operation has a dynamically shaped result. */ static bool returnsDynamicShape(Operation *op) { // TODO: remove this check. // Temporary fix until more ops are supported. // All operations which do not return a ranked tensor type have dynamic // shaped outputs. All those operation need to implement the inferShape() // method. if (op->getName().getStringRef() != "onnx.Exp" && op->getName().getStringRef() != "onnx.Tanh" && op->getName().getStringRef() != "onnx.Sinh" && op->getName().getStringRef() != "onnx.Cosh" && op->getName().getStringRef() != "onnx.Cos" && op->getName().getStringRef() != "onnx.Log" && op->getName().getStringRef() != "onnx.Sigmoid" && op->getName().getStringRef() != "onnx.HardSigmoid" && op->getName().getStringRef() != "onnx.Elu" && op->getName().getStringRef() != "onnx.Relu" && op->getName().getStringRef() != "onnx.LeakyRelu" && op->getName().getStringRef() != "onnx.Selu" && op->getName().getStringRef() != "onnx.Reciprocal" && op->getName().getStringRef() != "onnx.Softplus" && op->getName().getStringRef() != "onnx.Softsign" && op->getName().getStringRef() != "onnx.Mul" && op->getName().getStringRef() != "onnx.Add" && op->getName().getStringRef() != "onnx.Div" && op->getName().getStringRef() != "onnx.Sub" && op->getName().getStringRef() != "onnx.And" && op->getName().getStringRef() != "onnx.Or" && op->getName().getStringRef() != "onnx.Xor" && op->getName().getStringRef() != "onnx.Sum" && op->getName().getStringRef() != "onnx.Max" && op->getName().getStringRef() != "onnx.Min" && op->getName().getStringRef() != "onnx.Identity" && op->getName().getStringRef() != "onnx.MatMul" && op->getName().getStringRef() != "onnx.Gemm" && op->getName().getStringRef() != "onnx.GemmNoBias" && op->getName().getStringRef() != "onnx.Reshape" && op->getName().getStringRef() != "onnx.Transpose" && op->getName().getStringRef() != "onnx.Softmax" && op->getName().getStringRef() != "onnx.ConvNoBias") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { return !result_type.isa(); }); } }; } // end anonymous namespace /*! * Create a Shape Inference pass. */ std::unique_ptr mlir::createShapeInferencePass() { return std::make_unique(); } static PassRegistration pass("shape-inference", "Shape inference for frontend dialects.");