From 0a8af69e944274c8c0b413fd5eed6def13b206bb Mon Sep 17 00:00:00 2001
From: GHEORGHE-TEOD BERCEA
Date: Mon, 16 Dec 2019 18:45:39 -0500
Subject: [PATCH] Add inference for Identity operation. (#400)

---
 src/compiler/dialect/onnx/gen_doc.py       |  3 +-
 src/compiler/dialect/onnx/onnx_ops.cpp     | 14 +++++++--
 src/compiler/dialect/onnx/onnxop.inc       |  2 +-
 src/compiler/pass/shape_inference_pass.cpp | 34 ++++++++++++----------
 4 files changed, 32 insertions(+), 21 deletions(-)

diff --git a/src/compiler/dialect/onnx/gen_doc.py b/src/compiler/dialect/onnx/gen_doc.py
index 8d3e728..709055e 100644
--- a/src/compiler/dialect/onnx/gen_doc.py
+++ b/src/compiler/dialect/onnx/gen_doc.py
@@ -266,7 +266,8 @@ def gen_schema(schema) :
     ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
                    'Add', 'Mul', 'Div', 'Sub', 'And', 'Or',
                    'Xor', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
-                   'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal']
+                   'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
+                   'Identity']
     CanonicalList=['Add', 'Identity']
 
     line_indent = '  '
diff --git a/src/compiler/dialect/onnx/onnx_ops.cpp b/src/compiler/dialect/onnx/onnx_ops.cpp
index 1fb5fea..eb7c60e 100644
--- a/src/compiler/dialect/onnx/onnx_ops.cpp
+++ b/src/compiler/dialect/onnx/onnx_ops.cpp
@@ -8,8 +8,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/SmallBitVector.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
@@ -17,6 +15,8 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallBitVector.h"
 
 #include "onnx_ops.hpp"
 
@@ -28,7 +28,7 @@ using namespace mlir;
 
 /// Dialect creation, the instance will be owned by the context. This is the
 /// point of registration of custom types and operations for the dialect.
-ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
+ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
     : mlir::Dialect(getDialectNamespace(), ctx) {
   addOperations<
 #define GET_OP_LIST
@@ -202,6 +202,14 @@ void ONNXMinOp::inferShapes() {
   getResult()->setType(getOperand(0)->getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Identity
+/// Infer the output shape of the ONNXIdentityOp. This method is required by
+/// the shape inference interface.
+void ONNXIdentityOp::inferShapes() {
+  getResult()->setType(getOperand()->getType());
+}
+
 //===----------------------------------------------------------------------===//
 // MatMul
 
diff --git a/src/compiler/dialect/onnx/onnxop.inc b/src/compiler/dialect/onnx/onnxop.inc
index 1ac969e..b0434ba 100644
--- a/src/compiler/dialect/onnx/onnxop.inc
+++ b/src/compiler/dialect/onnx/onnxop.inc
@@ -1026,7 +1026,7 @@ def ONNXHardmaxOp:ONNX_Op<"Hardmax",
 }
 
 def ONNXIdentityOp:ONNX_Op<"Identity",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let hasCanonicalizer = 1;
   let summary = "ONNX Identity operation";
   let description = [{
diff --git a/src/compiler/pass/shape_inference_pass.cpp b/src/compiler/pass/shape_inference_pass.cpp
index acddc98..0182d92 100644
--- a/src/compiler/pass/shape_inference_pass.cpp
+++ b/src/compiler/pass/shape_inference_pass.cpp
@@ -9,9 +9,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Pass/Pass.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/Pass/Pass.h"
 
 #include "shape_inference_interface.hpp"
 #include "src/compiler/dialect/onnx/onnx_ops.hpp"
@@ -30,14 +30,14 @@ namespace {
  * of operations is empty [credit MLIR authors].
  */
 class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
- public:
+public:
   void runOnFunction() override {
     auto f = getFunction();
 
     // Populate the worklist with the operations that need shape inference:
     // these are operations that return a dynamic shape.
-    llvm::SmallPtrSet<mlir::Operation*, 16> op_worklist;
-    f.walk([&](mlir::Operation* op) {
+    llvm::SmallPtrSet<mlir::Operation *, 16> op_worklist;
+    f.walk([&](mlir::Operation *op) {
       if (returnsDynamicShape(op))
         op_worklist.insert(op);
     });
@@ -51,16 +51,15 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
       if (nextop == op_worklist.end())
         break;
 
-      Operation* op = *nextop;
+      Operation *op = *nextop;
       op_worklist.erase(op);
 
       // Ask the operation to infer its output shapes.
       if (auto shape_op = dyn_cast<ShapeInference>(op)) {
         shape_op.inferShapes();
       } else {
-        op->emitError(
-            "unable to infer shape of operation without shape "
-            "inference interface");
+        op->emitError("unable to infer shape of operation without shape "
+                      "inference interface");
         return signalPassFailure();
       }
     }
@@ -74,15 +73,16 @@
 
     if (auto terminator_op = f.getBody().back().getTerminator()) {
       auto results = terminator_op->getOperandTypes();
-      f.setType(FunctionType::get(f.getType().getInputs(),
-          std::vector<Type>(results.begin(), results.end()), f.getContext()));
+      f.setType(FunctionType::get(
+          f.getType().getInputs(),
+          std::vector<Type>(results.begin(), results.end()), f.getContext()));
     }
   }
 
   /*!
   * Check if the given operation has a dynamically shaped result.
   */
-  static bool returnsDynamicShape(Operation* op) {
+  static bool returnsDynamicShape(Operation *op) {
     // TODO: remove this check.
     // Temporary fix until more ops are supported.
     // All operations which do not return a ranked tensor type have dynamic
@@ -109,16 +109,18 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
         op->getName().getStringRef() != "onnx.Sum" &&
         op->getName().getStringRef() != "onnx.Max" &&
         op->getName().getStringRef() != "onnx.Min" &&
+        op->getName().getStringRef() != "onnx.Identity" &&
         op->getName().getStringRef() != "onnx.MatMul" &&
         op->getName().getStringRef() != "onnx.Gemm" &&
         op->getName().getStringRef() != "onnx.FullGemm" &&
         op->getName().getStringRef() != "onnx.Reshape")
       return false;
-    return llvm::any_of(op->getResultTypes(),
-        [](Type result_type) { return !result_type.isa<RankedTensorType>(); });
+    return llvm::any_of(op->getResultTypes(), [](Type result_type) {
+      return !result_type.isa<RankedTensorType>();
+    });
   }
 };
-}  // end anonymous namespace
+} // end anonymous namespace
 
 /*!
  * Create a Shape Inference pass.
  */
@@ -127,5 +129,5 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
   return std::make_unique<ShapeInferencePass>();
 }
 
-static PassRegistration<ShapeInferencePass> pass(
-    "shape-inference", "Shape inference for frontend dialects.");
+static PassRegistration<ShapeInferencePass>
+    pass("shape-inference", "Shape inference for frontend dialects.");
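
Note on the interface referenced above: the DeclareOpInterfaceMethods<ShapeInferenceOpInterface>
trait assumes an ODS interface declared along these lines. This is a minimal sketch following
MLIR's standard OpInterface pattern, consistent with the shape_inference_interface.hpp include
and the ShapeInference / inferShapes names used by the pass; the exact description text and
file layout in this repository may differ.

    // Hypothetical sketch of a shape_inference_interface.td definition.
    def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
      let description = [{
        Interface for querying an operation's method for inferring and
        setting its output shapes.
      }];

      let methods = [
        InterfaceMethod<"Infer and set the output shape for the current operation.",
                        "void", "inferShapes">
      ];
    }

Declaring this interface on ONNXIdentityOp makes ODS emit an inferShapes() declaration on the
generated op class, which is why the patch only needs to supply the method body in onnx_ops.cpp,
and why dyn_cast<ShapeInference>(op) in the pass succeeds for every op carrying the trait.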