[MLIR] Add shape inference pass (#361)
* Add shape inference pass. * Format code. * Enable shape inference pass. * Rename new files to use ONNF convention. * Use include path from src folder. * Replace guards with pragma once. * Format variable names. Shuffle headers. * Fix comment. * Fix comments.
This commit is contained in:
parent
a6cca3cbb7
commit
03be41f7df
|
@ -279,12 +279,15 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
|
||||||
return module;
|
return module;
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname) {
|
void ImportFrontendModelFile(std::string model_fname,
|
||||||
|
mlir::MLIRContext& context, mlir::OwningModuleRef& module) {
|
||||||
onnx::ModelProto model;
|
onnx::ModelProto model;
|
||||||
std::fstream input(model_fname, std::ios::in | std::ios::binary);
|
std::fstream input(model_fname, std::ios::in | std::ios::binary);
|
||||||
|
|
||||||
auto parse_success = model.ParseFromIstream(&input);
|
auto parse_success = model.ParseFromIstream(&input);
|
||||||
|
|
||||||
return ImportFrontendModel(model);
|
FrontendGenImpl myONNXGen(context);
|
||||||
|
module = myONNXGen.ImportONNXModel(model);
|
||||||
|
module->dump();
|
||||||
}
|
}
|
||||||
} // namespace onnf
|
} // namespace onnf
|
||||||
|
|
|
@ -40,7 +40,8 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
|
||||||
* @param model_fname file name pointing to the onnx model protobuf.
|
* @param model_fname file name pointing to the onnx model protobuf.
|
||||||
* @return MLIR::module generated for the ONNX model.
|
* @return MLIR::module generated for the ONNX model.
|
||||||
*/
|
*/
|
||||||
mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname);
|
void ImportFrontendModelFile(std::string model_fname,
|
||||||
|
mlir::MLIRContext& context, mlir::OwningModuleRef& module);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* TODO: Import models into other extension dialects that cover the
|
* TODO: Import models into other extension dialects that cover the
|
||||||
|
|
|
@ -3,7 +3,10 @@ add_library(
|
||||||
ir/knl/knl_ops.cpp
|
ir/knl/knl_ops.cpp
|
||||||
ir/knl/knl_ops.hpp
|
ir/knl/knl_ops.hpp
|
||||||
dialect/onnx/onnx_ops.cpp
|
dialect/onnx/onnx_ops.cpp
|
||||||
dialect/onnx/onnx_ops.hpp)
|
dialect/onnx/onnx_ops.hpp
|
||||||
|
pass/shape_inference_pass.cpp
|
||||||
|
pass/shape_inference_interface.hpp
|
||||||
|
pass/passes.hpp)
|
||||||
|
|
||||||
# Include root src directory.
|
# Include root src directory.
|
||||||
target_include_directories(compiler PRIVATE ../..)
|
target_include_directories(compiler PRIVATE ../..)
|
||||||
|
@ -37,8 +40,14 @@ onnf_tablegen(knl.cpp.inc -gen-op-defs)
|
||||||
add_public_tablegen_target(gen_kir)
|
add_public_tablegen_target(gen_kir)
|
||||||
add_dependencies(compiler gen_kir)
|
add_dependencies(compiler gen_kir)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
|
||||||
|
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
|
||||||
|
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
|
||||||
|
add_public_tablegen_target(gen_shape_inference)
|
||||||
|
add_dependencies(compiler gen_shape_inference)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||||
onnf_tablegen(onnx.hpp.inc -gen-op-decls)
|
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
onnf_tablegen(onnx.cpp.inc -gen-op-defs)
|
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
add_public_tablegen_target(gen_onnx)
|
add_public_tablegen_target(gen_onnx)
|
||||||
add_dependencies(compiler gen_onnx)
|
add_dependencies(compiler gen_onnx)
|
||||||
|
|
|
@ -17,6 +17,11 @@
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
#endif // OP_BASE
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
#ifdef SHAPE_INFERENCE_INTERFACE
|
||||||
|
#else
|
||||||
|
include "pass/shape_inference_interface.td"
|
||||||
|
#endif // SHAPE_INFERENCE_INTERFACE
|
||||||
|
|
||||||
def ONNX_Dialect : Dialect {
|
def ONNX_Dialect : Dialect {
|
||||||
let name = "onnx";
|
let name = "onnx";
|
||||||
let cppNamespace = "";
|
let cppNamespace = "";
|
||||||
|
@ -35,7 +40,8 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// We define an ONNX operation for adding two tensors elementwise.
|
// We define an ONNX operation for adding two tensors elementwise.
|
||||||
def ONNXAddOp: ONNX_Op<"add", [NoSideEffect]> {
|
def ONNXAddOp: ONNX_Op<"add",
|
||||||
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX add operation";
|
let summary = "ONNX add operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,12 @@ static void buildONNXAddOp(mlir::Builder* builder, mlir::OperationState& state,
|
||||||
state.addOperands({lhs, rhs});
|
state.addOperands({lhs, rhs});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
||||||
|
/// shape inference interface.
|
||||||
|
void ONNXAddOp::inferShapes() {
|
||||||
|
getResult()->setType(getOperand(0)->getType());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -8,14 +8,15 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef MLIR_DIALECT_ONNX_ONNXOPS_H
|
#pragma once
|
||||||
#define MLIR_DIALECT_ONNX_ONNXOPS_H
|
|
||||||
|
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
|
#include "src/compiler/pass/shape_inference_interface.hpp"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
class ONNXOpsDialect : public Dialect {
|
class ONNXOpsDialect : public Dialect {
|
||||||
|
@ -35,5 +36,3 @@ class ONNXOpsDialect : public Dialect {
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
namespace onnf {}
|
namespace onnf {}
|
||||||
|
|
||||||
#endif // MLIR_DIALECT_ONNX_ONNXOPS_H
|
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
//===- passes.hpp - ONNF Passes Definition --------------------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file exposes the entry points to create compiler passes for ONNF.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class Pass;
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createShapeInferencePass();
|
||||||
|
|
||||||
|
// TODO: Add pass for lowering to kernel IR.
|
||||||
|
|
||||||
|
// TODO: Add pass for lowering to LLVM IR.
|
||||||
|
|
||||||
|
} // end namespace mlir
|
|
@ -0,0 +1,21 @@
|
||||||
|
//===- shape_inference_interface.hpp - Definition for ShapeInference --------=//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file contains the declarations of the shape inference interfaces defined
|
||||||
|
// in ShapeInferenceInterface.td.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
/// Include the auto-generated declarations.
|
||||||
|
#include "src/compiler/shape_inference.hpp.inc"
|
||||||
|
|
||||||
|
} // end namespace mlir
|
|
@ -0,0 +1,32 @@
|
||||||
|
//=- shape_inference_interface.td - Shape Inference Interface -*- tablegen -==//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// Defines the operations of the Shape Inference Op Interface.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifdef SHAPE_INFERENCE_INTERFACE
|
||||||
|
#else
|
||||||
|
#define SHAPE_INFERENCE_INTERFACE
|
||||||
|
|
||||||
|
#ifdef OP_BASE
|
||||||
|
#else
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
|
||||||
|
let description = [{
|
||||||
|
Interface to access a registered method to infer the return types for an
|
||||||
|
operation that can be used during type inference.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let methods = [
|
||||||
|
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
||||||
|
"void", "inferShapes">
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHAPE_INFERENCE_INTERFACE
|
|
@ -0,0 +1,98 @@
|
||||||
|
//===----- shape_inference_pass.cpp - Shape Inference ---------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file implements a Function level pass performing propagation of array
|
||||||
|
// shapes through function specialization.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||||
|
#include "shape_inference_interface.hpp"
|
||||||
|
|
||||||
|
#include "passes.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
// Include the auto-generated definitions for the shape inference interfaces.
|
||||||
|
#include "src/compiler/shape_inference.cpp.inc"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/*!
|
||||||
|
* FunctionPass that performs shape inference by iterating over a list of
|
||||||
|
* candidate operations and propagating the shape information until the list
|
||||||
|
* of operations is empty [credit MLIR authors].
|
||||||
|
*/
|
||||||
|
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
||||||
|
public:
|
||||||
|
void runOnFunction() override {
|
||||||
|
auto f = getFunction();
|
||||||
|
|
||||||
|
// Populate the worklist with the operations that need shape inference:
|
||||||
|
// these are operations that return a dynamic shape.
|
||||||
|
llvm::SmallPtrSet<mlir::Operation*, 16> op_worklist;
|
||||||
|
f.walk([&](mlir::Operation* op) {
|
||||||
|
if (returnsDynamicShape(op))
|
||||||
|
op_worklist.insert(op);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Iterate on the operations in the worklist until all operations have been
|
||||||
|
// inferred or no change happened (fix point).
|
||||||
|
while (!op_worklist.empty()) {
|
||||||
|
// Find the next operation ready for inference, that is an operation
|
||||||
|
// with all operands already resolved (non-generic).
|
||||||
|
auto nextop = llvm::find_if(op_worklist, returnsDynamicShape);
|
||||||
|
if (nextop == op_worklist.end())
|
||||||
|
break;
|
||||||
|
|
||||||
|
Operation* op = *nextop;
|
||||||
|
op_worklist.erase(op);
|
||||||
|
|
||||||
|
// Ask the operation to infer its output shapes.
|
||||||
|
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
||||||
|
shape_op.inferShapes();
|
||||||
|
} else {
|
||||||
|
op->emitError(
|
||||||
|
"unable to infer shape of operation without shape "
|
||||||
|
"inference interface");
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the operation worklist isn't empty, this indicates a failure.
|
||||||
|
if (!op_worklist.empty()) {
|
||||||
|
f.emitError("Shape inference failed, ")
|
||||||
|
<< op_worklist.size() << " operations couldn't be inferred\n";
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Check if the given operation has a dynamically shaped result.
|
||||||
|
*/
|
||||||
|
static bool returnsDynamicShape(Operation* op) {
|
||||||
|
// TODO: remove this check.
|
||||||
|
// Temporary fix until more ops are supported.
|
||||||
|
// All operations which do not return a ranked tensor type have dynamic
|
||||||
|
// shaped outputs. All those operation need to implement the inferShape()
|
||||||
|
// method.
|
||||||
|
if (op->getName().getStringRef() != "onnx.add")
|
||||||
|
return false;
|
||||||
|
return llvm::any_of(op->getResultTypes(),
|
||||||
|
[](Type result_type) { return !result_type.isa<RankedTensorType>(); });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Create a Shape Inference pass.
|
||||||
|
*/
|
||||||
|
std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
|
||||||
|
return std::make_unique<ShapeInferencePass>();
|
||||||
|
}
|
19
src/main.cpp
19
src/main.cpp
|
@ -22,8 +22,18 @@
|
||||||
|
|
||||||
#include "src/builder/frontend_dialect_transformer.hpp"
|
#include "src/builder/frontend_dialect_transformer.hpp"
|
||||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||||
|
#include "src/compiler/pass/passes.hpp"
|
||||||
|
|
||||||
|
#include "mlir/Analysis/Verifier.h"
|
||||||
|
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||||
|
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
#include "mlir/IR/Module.h"
|
#include "mlir/IR/Module.h"
|
||||||
|
#include "mlir/Parser.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Pass/PassManager.h"
|
||||||
|
#include "mlir/Target/LLVMIR.h"
|
||||||
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace onnf;
|
using namespace onnf;
|
||||||
|
@ -48,8 +58,15 @@ int main(int ac, char* av[]) {
|
||||||
|
|
||||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||||
|
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
mlir::OwningModuleRef module;
|
||||||
|
|
||||||
string model_filename = vm["onnx-model"].as<string>();
|
string model_filename = vm["onnx-model"].as<string>();
|
||||||
auto module = ImportFrontendModelFile(model_filename);
|
ImportFrontendModelFile(model_filename, context, module);
|
||||||
|
|
||||||
|
mlir::PassManager pm(&context);
|
||||||
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
|
pm.run(*module);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue