[MLIR] Add shape inference pass (#361)
* Add shape inference pass. * Format code. * Enable shape inference pass. * Rename new files to use ONNF convention. * Use include path from src folder. * Replace guards with pragma once. * Format variable names. Shuffle headers. * Fix comment. * Fix comments.
This commit is contained in:
parent
a6cca3cbb7
commit
03be41f7df
|
@ -279,12 +279,15 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
|
|||
return module;
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname) {
|
||||
void ImportFrontendModelFile(std::string model_fname,
|
||||
mlir::MLIRContext& context, mlir::OwningModuleRef& module) {
|
||||
onnx::ModelProto model;
|
||||
std::fstream input(model_fname, std::ios::in | std::ios::binary);
|
||||
|
||||
auto parse_success = model.ParseFromIstream(&input);
|
||||
|
||||
return ImportFrontendModel(model);
|
||||
FrontendGenImpl myONNXGen(context);
|
||||
module = myONNXGen.ImportONNXModel(model);
|
||||
module->dump();
|
||||
}
|
||||
} // namespace onnf
|
||||
|
|
|
@ -40,7 +40,8 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
|
|||
* @param model_fname file name pointing to the onnx model protobuf.
|
||||
* @return MLIR::module generated for the ONNX model.
|
||||
*/
|
||||
mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname);
|
||||
void ImportFrontendModelFile(std::string model_fname,
|
||||
mlir::MLIRContext& context, mlir::OwningModuleRef& module);
|
||||
|
||||
/*!
|
||||
* TODO: Import models into other extension dialects that cover the
|
||||
|
|
|
@ -3,7 +3,10 @@ add_library(
|
|||
ir/knl/knl_ops.cpp
|
||||
ir/knl/knl_ops.hpp
|
||||
dialect/onnx/onnx_ops.cpp
|
||||
dialect/onnx/onnx_ops.hpp)
|
||||
dialect/onnx/onnx_ops.hpp
|
||||
pass/shape_inference_pass.cpp
|
||||
pass/shape_inference_interface.hpp
|
||||
pass/passes.hpp)
|
||||
|
||||
# Include root src directory.
|
||||
target_include_directories(compiler PRIVATE ../..)
|
||||
|
@ -37,8 +40,14 @@ onnf_tablegen(knl.cpp.inc -gen-op-defs)
|
|||
add_public_tablegen_target(gen_kir)
|
||||
add_dependencies(compiler gen_kir)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
|
||||
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
|
||||
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
|
||||
add_public_tablegen_target(gen_shape_inference)
|
||||
add_dependencies(compiler gen_shape_inference)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||
onnf_tablegen(onnx.hpp.inc -gen-op-decls)
|
||||
onnf_tablegen(onnx.cpp.inc -gen-op-defs)
|
||||
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||
add_public_tablegen_target(gen_onnx)
|
||||
add_dependencies(compiler gen_onnx)
|
||||
|
|
|
@ -17,6 +17,11 @@
|
|||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
#ifdef SHAPE_INFERENCE_INTERFACE
|
||||
#else
|
||||
include "pass/shape_inference_interface.td"
|
||||
#endif // SHAPE_INFERENCE_INTERFACE
|
||||
|
||||
def ONNX_Dialect : Dialect {
|
||||
let name = "onnx";
|
||||
let cppNamespace = "";
|
||||
|
@ -35,7 +40,8 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// We define an ONNX operation for adding two tensors elementwise.
|
||||
def ONNXAddOp: ONNX_Op<"add", [NoSideEffect]> {
|
||||
def ONNXAddOp: ONNX_Op<"add",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX add operation";
|
||||
let description = [{
|
||||
|
||||
|
|
|
@ -46,6 +46,12 @@ static void buildONNXAddOp(mlir::Builder* builder, mlir::OperationState& state,
|
|||
state.addOperands({lhs, rhs});
|
||||
}
|
||||
|
||||
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXAddOp::inferShapes() {
|
||||
getResult()->setType(getOperand(0)->getType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -8,14 +8,15 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_ONNX_ONNXOPS_H
|
||||
#define MLIR_DIALECT_ONNX_ONNXOPS_H
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "src/compiler/pass/shape_inference_interface.hpp"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ONNXOpsDialect : public Dialect {
|
||||
|
@ -35,5 +36,3 @@ class ONNXOpsDialect : public Dialect {
|
|||
} // end namespace mlir
|
||||
|
||||
namespace onnf {}
|
||||
|
||||
#endif // MLIR_DIALECT_ONNX_ONNXOPS_H
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
//===- passes.hpp - ONNF Passes Definition --------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file exposes the entry points to create compiler passes for ONNF.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class Pass;
|
||||
|
||||
std::unique_ptr<Pass> createShapeInferencePass();
|
||||
|
||||
// TODO: Add pass for lowering to kernel IR.
|
||||
|
||||
// TODO: Add pass for lowering to LLVM IR.
|
||||
|
||||
} // end namespace mlir
|
|
@ -0,0 +1,21 @@
|
|||
//===- shape_inference_interface.hpp - Definition for ShapeInference --------=//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains the declarations of the shape inference interfaces defined
|
||||
// in ShapeInferenceInterface.td.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Include the auto-generated declarations.
|
||||
#include "src/compiler/shape_inference.hpp.inc"
|
||||
|
||||
} // end namespace mlir
|
|
@ -0,0 +1,32 @@
|
|||
//=- shape_inference_interface.td - Shape Inference Interface -*- tablegen -==//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines the operations of the Shape Inference Op Interface.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef SHAPE_INFERENCE_INTERFACE
|
||||
#else
|
||||
#define SHAPE_INFERENCE_INTERFACE
|
||||
|
||||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
|
||||
let description = [{
|
||||
Interface to access a registered method to infer the return types for an
|
||||
operation that can be used during type inference.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
||||
"void", "inferShapes">
|
||||
];
|
||||
}
|
||||
|
||||
#endif // SHAPE_INFERENCE_INTERFACE
|
|
@ -0,0 +1,98 @@
|
|||
//===----- shape_inference_pass.cpp - Shape Inference ---------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a Function level pass performing propagation of array
|
||||
// shapes through function specialization.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
#include "shape_inference_interface.hpp"
|
||||
|
||||
#include "passes.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
// Include the auto-generated definitions for the shape inference interfaces.
|
||||
#include "src/compiler/shape_inference.cpp.inc"
|
||||
|
||||
namespace {
|
||||
/*!
|
||||
* FunctionPass that performs shape inference by iterating over a list of
|
||||
* candidate operations and propagating the shape information until the list
|
||||
* of operations is empty [credit MLIR authors].
|
||||
*/
|
||||
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
auto f = getFunction();
|
||||
|
||||
// Populate the worklist with the operations that need shape inference:
|
||||
// these are operations that return a dynamic shape.
|
||||
llvm::SmallPtrSet<mlir::Operation*, 16> op_worklist;
|
||||
f.walk([&](mlir::Operation* op) {
|
||||
if (returnsDynamicShape(op))
|
||||
op_worklist.insert(op);
|
||||
});
|
||||
|
||||
// Iterate on the operations in the worklist until all operations have been
|
||||
// inferred or no change happened (fix point).
|
||||
while (!op_worklist.empty()) {
|
||||
// Find the next operation ready for inference, that is an operation
|
||||
// with all operands already resolved (non-generic).
|
||||
auto nextop = llvm::find_if(op_worklist, returnsDynamicShape);
|
||||
if (nextop == op_worklist.end())
|
||||
break;
|
||||
|
||||
Operation* op = *nextop;
|
||||
op_worklist.erase(op);
|
||||
|
||||
// Ask the operation to infer its output shapes.
|
||||
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
||||
shape_op.inferShapes();
|
||||
} else {
|
||||
op->emitError(
|
||||
"unable to infer shape of operation without shape "
|
||||
"inference interface");
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
// If the operation worklist isn't empty, this indicates a failure.
|
||||
if (!op_worklist.empty()) {
|
||||
f.emitError("Shape inference failed, ")
|
||||
<< op_worklist.size() << " operations couldn't be inferred\n";
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* Check if the given operation has a dynamically shaped result.
|
||||
*/
|
||||
static bool returnsDynamicShape(Operation* op) {
|
||||
// TODO: remove this check.
|
||||
// Temporary fix until more ops are supported.
|
||||
// All operations which do not return a ranked tensor type have dynamic
|
||||
// shaped outputs. All those operation need to implement the inferShape()
|
||||
// method.
|
||||
if (op->getName().getStringRef() != "onnx.add")
|
||||
return false;
|
||||
return llvm::any_of(op->getResultTypes(),
|
||||
[](Type result_type) { return !result_type.isa<RankedTensorType>(); });
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/*!
|
||||
* Create a Shape Inference pass.
|
||||
*/
|
||||
std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
|
||||
return std::make_unique<ShapeInferencePass>();
|
||||
}
|
19
src/main.cpp
19
src/main.cpp
|
@ -22,8 +22,18 @@
|
|||
|
||||
#include "src/builder/frontend_dialect_transformer.hpp"
|
||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
#include "src/compiler/pass/passes.hpp"
|
||||
|
||||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Target/LLVMIR.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace onnf;
|
||||
|
@ -48,8 +58,15 @@ int main(int ac, char* av[]) {
|
|||
|
||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
|
||||
string model_filename = vm["onnx-model"].as<string>();
|
||||
auto module = ImportFrontendModelFile(model_filename);
|
||||
ImportFrontendModelFile(model_filename, context, module);
|
||||
|
||||
mlir::PassManager pm(&context);
|
||||
pm.addPass(mlir::createShapeInferencePass());
|
||||
pm.run(*module);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue