Add inference for Identity operation. (#400)

This commit is contained in:
GHEORGHE-TEOD BERCEA 2019-12-16 18:45:39 -05:00 committed by Tian Jin
parent 7e3f96e642
commit 0a8af69e94
4 changed files with 32 additions and 21 deletions

View File

@ -266,7 +266,8 @@ def gen_schema(schema) :
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu', ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal'] 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity']
CanonicalList=['Add', 'Identity'] CanonicalList=['Add', 'Identity']
line_indent = ' ' line_indent = ' '

View File

@ -8,8 +8,6 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
@ -17,6 +15,8 @@
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "onnx_ops.hpp" #include "onnx_ops.hpp"
@ -28,7 +28,7 @@ using namespace mlir;
/// Dialect creation, the instance will be owned by the context. This is the /// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect. /// point of registration of custom types and operations for the dialect.
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx) ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
: mlir::Dialect(getDialectNamespace(), ctx) { : mlir::Dialect(getDialectNamespace(), ctx) {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
@ -202,6 +202,14 @@ void ONNXMinOp::inferShapes() {
getResult()->setType(getOperand(0)->getType()); getResult()->setType(getOperand(0)->getType());
} }
//===----------------------------------------------------------------------===//
// Identity
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
/// shape inference interface.
void ONNXIdentityOp::inferShapes() {
getResult()->setType(getOperand()->getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MatMul // MatMul

View File

@ -1026,7 +1026,7 @@ def ONNXHardmaxOp:ONNX_Op<"Hardmax",
} }
def ONNXIdentityOp:ONNX_Op<"Identity", def ONNXIdentityOp:ONNX_Op<"Identity",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
let summary = "ONNX Identity operation"; let summary = "ONNX Identity operation";
let description = [{ let description = [{

View File

@ -9,9 +9,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "mlir/Pass/Pass.h"
#include "shape_inference_interface.hpp" #include "shape_inference_interface.hpp"
#include "src/compiler/dialect/onnx/onnx_ops.hpp" #include "src/compiler/dialect/onnx/onnx_ops.hpp"
@ -30,14 +30,14 @@ namespace {
* of operations is empty [credit MLIR authors]. * of operations is empty [credit MLIR authors].
*/ */
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
public: public:
void runOnFunction() override { void runOnFunction() override {
auto f = getFunction(); auto f = getFunction();
// Populate the worklist with the operations that need shape inference: // Populate the worklist with the operations that need shape inference:
// these are operations that return a dynamic shape. // these are operations that return a dynamic shape.
llvm::SmallPtrSet<mlir::Operation*, 16> op_worklist; llvm::SmallPtrSet<mlir::Operation *, 16> op_worklist;
f.walk([&](mlir::Operation* op) { f.walk([&](mlir::Operation *op) {
if (returnsDynamicShape(op)) if (returnsDynamicShape(op))
op_worklist.insert(op); op_worklist.insert(op);
}); });
@ -51,16 +51,15 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
if (nextop == op_worklist.end()) if (nextop == op_worklist.end())
break; break;
Operation* op = *nextop; Operation *op = *nextop;
op_worklist.erase(op); op_worklist.erase(op);
// Ask the operation to infer its output shapes. // Ask the operation to infer its output shapes.
if (auto shape_op = dyn_cast<ShapeInference>(op)) { if (auto shape_op = dyn_cast<ShapeInference>(op)) {
shape_op.inferShapes(); shape_op.inferShapes();
} else { } else {
op->emitError( op->emitError("unable to infer shape of operation without shape "
"unable to infer shape of operation without shape " "inference interface");
"inference interface");
return signalPassFailure(); return signalPassFailure();
} }
} }
@ -74,15 +73,16 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
if (auto terminator_op = f.getBody().back().getTerminator()) { if (auto terminator_op = f.getBody().back().getTerminator()) {
auto results = terminator_op->getOperandTypes(); auto results = terminator_op->getOperandTypes();
f.setType(FunctionType::get(f.getType().getInputs(), f.setType(FunctionType::get(
std::vector<Type>(results.begin(), results.end()), f.getContext())); f.getType().getInputs(),
std::vector<Type>(results.begin(), results.end()), f.getContext()));
} }
} }
/*! /*!
* Check if the given operation has a dynamically shaped result. * Check if the given operation has a dynamically shaped result.
*/ */
static bool returnsDynamicShape(Operation* op) { static bool returnsDynamicShape(Operation *op) {
// TODO: remove this check. // TODO: remove this check.
// Temporary fix until more ops are supported. // Temporary fix until more ops are supported.
// All operations which do not return a ranked tensor type have dynamic // All operations which do not return a ranked tensor type have dynamic
@ -109,16 +109,18 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
op->getName().getStringRef() != "onnx.Sum" && op->getName().getStringRef() != "onnx.Sum" &&
op->getName().getStringRef() != "onnx.Max" && op->getName().getStringRef() != "onnx.Max" &&
op->getName().getStringRef() != "onnx.Min" && op->getName().getStringRef() != "onnx.Min" &&
op->getName().getStringRef() != "onnx.Identity" &&
op->getName().getStringRef() != "onnx.MatMul" && op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" && op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.FullGemm" && op->getName().getStringRef() != "onnx.FullGemm" &&
op->getName().getStringRef() != "onnx.Reshape") op->getName().getStringRef() != "onnx.Reshape")
return false; return false;
return llvm::any_of(op->getResultTypes(), return llvm::any_of(op->getResultTypes(), [](Type result_type) {
[](Type result_type) { return !result_type.isa<RankedTensorType>(); }); return !result_type.isa<RankedTensorType>();
});
} }
}; };
} // end anonymous namespace } // end anonymous namespace
/*! /*!
* Create a Shape Inference pass. * Create a Shape Inference pass.
@ -127,5 +129,5 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
return std::make_unique<ShapeInferencePass>(); return std::make_unique<ShapeInferencePass>();
} }
static PassRegistration<ShapeInferencePass> pass( static PassRegistration<ShapeInferencePass>
"shape-inference", "Shape inference for frontend dialects."); pass("shape-inference", "Shape inference for frontend dialects.");