Add inference for Identity operation. (#400)
This commit is contained in:
parent
7e3f96e642
commit
0a8af69e94
|
@ -266,7 +266,8 @@ def gen_schema(schema) :
|
||||||
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
||||||
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
||||||
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
||||||
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal']
|
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||||
|
'Identity']
|
||||||
CanonicalList=['Add', 'Identity']
|
CanonicalList=['Add', 'Identity']
|
||||||
line_indent = ' '
|
line_indent = ' '
|
||||||
|
|
||||||
|
|
|
@ -8,8 +8,6 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "llvm/ADT/SetVector.h"
|
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
|
||||||
#include "mlir/IR/Block.h"
|
#include "mlir/IR/Block.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
|
@ -17,6 +15,8 @@
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
|
||||||
#include "onnx_ops.hpp"
|
#include "onnx_ops.hpp"
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ using namespace mlir;
|
||||||
|
|
||||||
/// Dialect creation, the instance will be owned by the context. This is the
|
/// Dialect creation, the instance will be owned by the context. This is the
|
||||||
/// point of registration of custom types and operations for the dialect.
|
/// point of registration of custom types and operations for the dialect.
|
||||||
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
|
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
|
||||||
: mlir::Dialect(getDialectNamespace(), ctx) {
|
: mlir::Dialect(getDialectNamespace(), ctx) {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
|
@ -202,6 +202,14 @@ void ONNXMinOp::inferShapes() {
|
||||||
getResult()->setType(getOperand(0)->getType());
|
getResult()->setType(getOperand(0)->getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Identity
|
||||||
|
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
|
||||||
|
/// shape inference interface.
|
||||||
|
void ONNXIdentityOp::inferShapes() {
|
||||||
|
getResult()->setType(getOperand()->getType());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// MatMul
|
// MatMul
|
||||||
|
|
|
@ -1026,7 +1026,7 @@ def ONNXHardmaxOp:ONNX_Op<"Hardmax",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXIdentityOp:ONNX_Op<"Identity",
|
def ONNXIdentityOp:ONNX_Op<"Identity",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
let summary = "ONNX Identity operation";
|
let summary = "ONNX Identity operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -9,9 +9,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
|
|
||||||
#include "shape_inference_interface.hpp"
|
#include "shape_inference_interface.hpp"
|
||||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||||
|
@ -30,14 +30,14 @@ namespace {
|
||||||
* of operations is empty [credit MLIR authors].
|
* of operations is empty [credit MLIR authors].
|
||||||
*/
|
*/
|
||||||
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
||||||
public:
|
public:
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
auto f = getFunction();
|
auto f = getFunction();
|
||||||
|
|
||||||
// Populate the worklist with the operations that need shape inference:
|
// Populate the worklist with the operations that need shape inference:
|
||||||
// these are operations that return a dynamic shape.
|
// these are operations that return a dynamic shape.
|
||||||
llvm::SmallPtrSet<mlir::Operation*, 16> op_worklist;
|
llvm::SmallPtrSet<mlir::Operation *, 16> op_worklist;
|
||||||
f.walk([&](mlir::Operation* op) {
|
f.walk([&](mlir::Operation *op) {
|
||||||
if (returnsDynamicShape(op))
|
if (returnsDynamicShape(op))
|
||||||
op_worklist.insert(op);
|
op_worklist.insert(op);
|
||||||
});
|
});
|
||||||
|
@ -51,16 +51,15 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
||||||
if (nextop == op_worklist.end())
|
if (nextop == op_worklist.end())
|
||||||
break;
|
break;
|
||||||
|
|
||||||
Operation* op = *nextop;
|
Operation *op = *nextop;
|
||||||
op_worklist.erase(op);
|
op_worklist.erase(op);
|
||||||
|
|
||||||
// Ask the operation to infer its output shapes.
|
// Ask the operation to infer its output shapes.
|
||||||
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
||||||
shape_op.inferShapes();
|
shape_op.inferShapes();
|
||||||
} else {
|
} else {
|
||||||
op->emitError(
|
op->emitError("unable to infer shape of operation without shape "
|
||||||
"unable to infer shape of operation without shape "
|
"inference interface");
|
||||||
"inference interface");
|
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -74,15 +73,16 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
||||||
|
|
||||||
if (auto terminator_op = f.getBody().back().getTerminator()) {
|
if (auto terminator_op = f.getBody().back().getTerminator()) {
|
||||||
auto results = terminator_op->getOperandTypes();
|
auto results = terminator_op->getOperandTypes();
|
||||||
f.setType(FunctionType::get(f.getType().getInputs(),
|
f.setType(FunctionType::get(
|
||||||
std::vector<Type>(results.begin(), results.end()), f.getContext()));
|
f.getType().getInputs(),
|
||||||
|
std::vector<Type>(results.begin(), results.end()), f.getContext()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Check if the given operation has a dynamically shaped result.
|
* Check if the given operation has a dynamically shaped result.
|
||||||
*/
|
*/
|
||||||
static bool returnsDynamicShape(Operation* op) {
|
static bool returnsDynamicShape(Operation *op) {
|
||||||
// TODO: remove this check.
|
// TODO: remove this check.
|
||||||
// Temporary fix until more ops are supported.
|
// Temporary fix until more ops are supported.
|
||||||
// All operations which do not return a ranked tensor type have dynamic
|
// All operations which do not return a ranked tensor type have dynamic
|
||||||
|
@ -109,16 +109,18 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
||||||
op->getName().getStringRef() != "onnx.Sum" &&
|
op->getName().getStringRef() != "onnx.Sum" &&
|
||||||
op->getName().getStringRef() != "onnx.Max" &&
|
op->getName().getStringRef() != "onnx.Max" &&
|
||||||
op->getName().getStringRef() != "onnx.Min" &&
|
op->getName().getStringRef() != "onnx.Min" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Identity" &&
|
||||||
op->getName().getStringRef() != "onnx.MatMul" &&
|
op->getName().getStringRef() != "onnx.MatMul" &&
|
||||||
op->getName().getStringRef() != "onnx.Gemm" &&
|
op->getName().getStringRef() != "onnx.Gemm" &&
|
||||||
op->getName().getStringRef() != "onnx.FullGemm" &&
|
op->getName().getStringRef() != "onnx.FullGemm" &&
|
||||||
op->getName().getStringRef() != "onnx.Reshape")
|
op->getName().getStringRef() != "onnx.Reshape")
|
||||||
return false;
|
return false;
|
||||||
return llvm::any_of(op->getResultTypes(),
|
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||||
[](Type result_type) { return !result_type.isa<RankedTensorType>(); });
|
return !result_type.isa<RankedTensorType>();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Create a Shape Inference pass.
|
* Create a Shape Inference pass.
|
||||||
|
@ -127,5 +129,5 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
|
||||||
return std::make_unique<ShapeInferencePass>();
|
return std::make_unique<ShapeInferencePass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<ShapeInferencePass> pass(
|
static PassRegistration<ShapeInferencePass>
|
||||||
"shape-inference", "Shape inference for frontend dialects.");
|
pass("shape-inference", "Shape inference for frontend dialects.");
|
||||||
|
|
Loading…
Reference in New Issue