[MLIR] Add shape inference pass (#361)

* Add shape inference pass.

* Format code.

* Enable shape inference pass.

* Rename new files to use ONNF convention.

* Use include path from src folder.

* Replace guards with pragma once.

* Format variable names. Shuffle headers.

* Fix comment.

* Fix comments.
Authored by GHEORGHE-TEOD BERCEA on 2019-11-07 11:42:40 -05:00; committed by Doru Bercea
parent a6cca3cbb7
commit 03be41f7df
11 changed files with 228 additions and 12 deletions

--- a/src/builder/frontend_dialect_transformer.cpp
+++ b/src/builder/frontend_dialect_transformer.cpp

@@ -279,12 +279,15 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
   return module;
 }
-mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname) {
+void ImportFrontendModelFile(std::string model_fname,
+    mlir::MLIRContext& context, mlir::OwningModuleRef& module) {
   onnx::ModelProto model;
   std::fstream input(model_fname, std::ios::in | std::ios::binary);
   auto parse_success = model.ParseFromIstream(&input);
-  return ImportFrontendModel(model);
+  FrontendGenImpl myONNXGen(context);
+  module = myONNXGen.ImportONNXModel(model);
+  module->dump();
 }
 } // namespace onnf
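The refactor turns the module from a return value into an out-parameter, so the caller owns both the MLIRContext and the module with matching lifetimes. A minimal caller sketch (hypothetical model path; the real call site is the driver change at the bottom of this commit):

  #include "src/builder/frontend_dialect_transformer.hpp"

  mlir::MLIRContext context;     // must outlive the module built against it
  mlir::OwningModuleRef module;  // filled in by the importer
  onnf::ImportFrontendModelFile("model.onnx", context, module);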

--- a/src/builder/frontend_dialect_transformer.hpp
+++ b/src/builder/frontend_dialect_transformer.hpp

@@ -40,7 +40,8 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
  * @param model_fname file name pointing to the onnx model protobuf.
  * @return MLIR::module generated for the ONNX model.
  */
-mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname);
+void ImportFrontendModelFile(std::string model_fname,
+    mlir::MLIRContext& context, mlir::OwningModuleRef& module);
 /*!
  * TODO: Import models into other extension dialects that cover the

--- a/src/compiler/CMakeLists.txt
+++ b/src/compiler/CMakeLists.txt

@@ -3,7 +3,10 @@ add_library(
   ir/knl/knl_ops.cpp
   ir/knl/knl_ops.hpp
   dialect/onnx/onnx_ops.cpp
-  dialect/onnx/onnx_ops.hpp)
+  dialect/onnx/onnx_ops.hpp
+  pass/shape_inference_pass.cpp
+  pass/shape_inference_interface.hpp
+  pass/passes.hpp)
 # Include root src directory.
 target_include_directories(compiler PRIVATE ../..)
@@ -37,8 +40,14 @@ onnf_tablegen(knl.cpp.inc -gen-op-defs)
 add_public_tablegen_target(gen_kir)
 add_dependencies(compiler gen_kir)
+set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
+onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
+onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(gen_shape_inference)
+add_dependencies(compiler gen_shape_inference)
 set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
-onnf_tablegen(onnx.hpp.inc -gen-op-decls)
-onnf_tablegen(onnx.cpp.inc -gen-op-defs)
+onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
+onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
 add_public_tablegen_target(gen_onnx)
 add_dependencies(compiler gen_onnx)

--- a/src/compiler/dialect/onnx/onnx.td
+++ b/src/compiler/dialect/onnx/onnx.td

@@ -17,6 +17,11 @@
 include "mlir/IR/OpBase.td"
 #endif // OP_BASE
+#ifdef SHAPE_INFERENCE_INTERFACE
+#else
+include "pass/shape_inference_interface.td"
+#endif // SHAPE_INFERENCE_INTERFACE
 def ONNX_Dialect : Dialect {
   let name = "onnx";
   let cppNamespace = "";
@@ -35,7 +40,8 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
 //===----------------------------------------------------------------------===//
 // We define an ONNX operation for adding two tensors elementwise.
-def ONNXAddOp: ONNX_Op<"add", [NoSideEffect]> {
+def ONNXAddOp: ONNX_Op<"add",
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX add operation";
   let description = [{
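DeclareOpInterfaceMethods<ShapeInferenceOpInterface> asks ODS to declare the interface's methods directly on the generated op class, which the hand-written body in onnx_ops.cpp then defines. Roughly, the generated onnx.hpp.inc gains a declaration like this (a sketch of tablegen output, not verbatim):

  class ONNXAddOp : public mlir::Op<ONNXAddOp /*, traits... */> {
  public:
    // Declared because of DeclareOpInterfaceMethods<ShapeInferenceOpInterface>;
    // the definition is supplied by hand in onnx_ops.cpp below.
    void inferShapes();
  };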

--- a/src/compiler/dialect/onnx/onnx_ops.cpp
+++ b/src/compiler/dialect/onnx/onnx_ops.cpp

@@ -46,6 +46,12 @@ static void buildONNXAddOp(mlir::Builder* builder, mlir::OperationState& state,
   state.addOperands({lhs, rhs});
 }
+/// Infer the output shape of the ONNXAddOp. This method is required by the
+/// shape inference interface.
+void ONNXAddOp::inferShapes() {
+  getResult()->setType(getOperand(0)->getType());
+}
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
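Copying the first operand's type is only sound because ONNX add is elementwise over same-shaped tensors; once broadcasting or mixed ranked/unranked operands are modeled, the result shape has to be computed. A slightly more defensive variant, as a sketch against the pointer-based Value API of this MLIR vintage (not part of the commit):

  void ONNXAddOp::inferShapes() {
    // Propagate whichever operand type is already ranked; if neither is
    // ranked there is nothing to propagate yet.
    auto lhs_type = getOperand(0)->getType();
    auto rhs_type = getOperand(1)->getType();
    if (lhs_type.isa<mlir::RankedTensorType>())
      getResult()->setType(lhs_type);
    else if (rhs_type.isa<mlir::RankedTensorType>())
      getResult()->setType(rhs_type);
  }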

--- a/src/compiler/dialect/onnx/onnx_ops.hpp
+++ b/src/compiler/dialect/onnx/onnx_ops.hpp

@@ -8,14 +8,15 @@
 //
 //===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_ONNX_ONNXOPS_H
-#define MLIR_DIALECT_ONNX_ONNXOPS_H
+#pragma once
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
+#include "src/compiler/pass/shape_inference_interface.hpp"
 namespace mlir {
 class ONNXOpsDialect : public Dialect {
@@ -35,5 +36,3 @@ class ONNXOpsDialect : public Dialect {
 } // end namespace mlir
 namespace onnf {}
-#endif // MLIR_DIALECT_ONNX_ONNXOPS_H

--- /dev/null
+++ b/src/compiler/pass/passes.hpp

@@ -0,0 +1,24 @@
+//===- passes.hpp - ONNF Passes Definition --------------------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file exposes the entry points to create compiler passes for ONNF.
+//
+//===----------------------------------------------------------------------===//
+#pragma once
+#include <memory>
+namespace mlir {
+class Pass;
+std::unique_ptr<Pass> createShapeInferencePass();
+// TODO: Add pass for lowering to kernel IR.
+// TODO: Add pass for lowering to LLVM IR.
+} // end namespace mlir
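The factory function is meant to be handed straight to a pass manager; the driver change at the end of this commit does exactly that. In isolation (a sketch, assuming an already-initialized context and module):

  mlir::PassManager pm(&context);
  pm.addPass(mlir::createShapeInferencePass());
  pm.run(*module);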

--- /dev/null
+++ b/src/compiler/pass/shape_inference_interface.hpp

@@ -0,0 +1,21 @@
+//===- shape_inference_interface.hpp - Definition for ShapeInference --------=//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file contains the declarations of the shape inference interfaces defined
+// in shape_inference_interface.td.
+//
+//===----------------------------------------------------------------------===//
+#pragma once
+#include "mlir/IR/OpDefinition.h"
+namespace mlir {
+/// Include the auto-generated declarations.
+#include "src/compiler/shape_inference.hpp.inc"
+} // end namespace mlir

--- /dev/null
+++ b/src/compiler/pass/shape_inference_interface.td

@@ -0,0 +1,32 @@
+//=- shape_inference_interface.td - Shape Inference Interface -*- tablegen -==//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// Defines the operations of the Shape Inference Op Interface.
+//
+//===----------------------------------------------------------------------===//
+#ifdef SHAPE_INFERENCE_INTERFACE
+#else
+#define SHAPE_INFERENCE_INTERFACE
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
+  let description = [{
+    Interface to access a registered method to infer the return types for an
+    operation that can be used during type inference.
+  }];
+  let methods = [
+    InterfaceMethod<"Infer and set the output shape for the current operation.",
+                    "void", "inferShapes">
+  ];
+}
+#endif // SHAPE_INFERENCE_INTERFACE
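From this definition, -gen-op-interface-decls/-defs (wired up in the CMake change above) produce shape_inference.hpp.inc/.cpp.inc, defining a ShapeInference class that any operation can be dyn_cast to. Approximately (a sketch of the generated shape; the real traits plumbing is more involved):

  namespace mlir {
  class ShapeInference
      : public OpInterface<ShapeInference, detail::ShapeInferenceInterfaceTraits> {
  public:
    using OpInterface<ShapeInference,
                      detail::ShapeInferenceInterfaceTraits>::OpInterface;
    // Dispatches to the implementation registered by the concrete op.
    void inferShapes();
  };
  } // end namespace mlir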

--- /dev/null
+++ b/src/compiler/pass/shape_inference_pass.cpp

@@ -0,0 +1,98 @@
+//===----- shape_inference_pass.cpp - Shape Inference ---------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file implements a Function level pass performing propagation of array
+// shapes through function specialization.
+//
+//===----------------------------------------------------------------------===//
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Pass/Pass.h"
+#include "src/compiler/dialect/onnx/onnx_ops.hpp"
+#include "shape_inference_interface.hpp"
+#include "passes.hpp"
+using namespace mlir;
+// Include the auto-generated definitions for the shape inference interfaces.
+#include "src/compiler/shape_inference.cpp.inc"
+namespace {
+/*!
+ * FunctionPass that performs shape inference by iterating over a list of
+ * candidate operations and propagating the shape information until the list
+ * of operations is empty [credit MLIR authors].
+ */
+class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
+ public:
+  void runOnFunction() override {
+    auto f = getFunction();
+    // Populate the worklist with the operations that need shape inference:
+    // these are operations that return a dynamic shape.
+    llvm::SmallPtrSet<mlir::Operation*, 16> op_worklist;
+    f.walk([&](mlir::Operation* op) {
+      if (returnsDynamicShape(op))
+        op_worklist.insert(op);
+    });
+    // Iterate on the operations in the worklist until all operations have been
+    // inferred or no change happened (fixed point).
+    while (!op_worklist.empty()) {
+      // Find the next operation ready for inference, that is, an operation
+      // with all operands already resolved (non-generic).
+      auto nextop = llvm::find_if(op_worklist, returnsDynamicShape);
+      if (nextop == op_worklist.end())
+        break;
+      Operation* op = *nextop;
+      op_worklist.erase(op);
+      // Ask the operation to infer its output shapes.
+      if (auto shape_op = dyn_cast<ShapeInference>(op)) {
+        shape_op.inferShapes();
+      } else {
+        op->emitError(
+            "unable to infer shape of operation without shape "
+            "inference interface");
+        return signalPassFailure();
+      }
+    }
+    // If the operation worklist isn't empty, this indicates a failure.
+    if (!op_worklist.empty()) {
+      f.emitError("Shape inference failed, ")
+          << op_worklist.size() << " operations couldn't be inferred\n";
+      signalPassFailure();
+    }
+  }
+  /*!
+   * Check if the given operation has a dynamically shaped result.
+   */
+  static bool returnsDynamicShape(Operation* op) {
+    // TODO: remove this check.
+    // Temporary fix until more ops are supported.
+    // All operations which do not return a ranked tensor type have
+    // dynamically shaped outputs. All such operations need to implement
+    // the inferShapes() method.
+    if (op->getName().getStringRef() != "onnx.add")
+      return false;
+    return llvm::any_of(op->getResultTypes(),
+        [](Type result_type) { return !result_type.isa<RankedTensorType>(); });
+  }
+};
+} // end anonymous namespace
+/*!
+ * Create a Shape Inference pass.
+ */
+std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
+  return std::make_unique<ShapeInferencePass>();
+}
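Here "dynamic" concretely means "not yet a RankedTensorType": an unranked tensor<*xf32> result keeps an op on the worklist, while a ranked tensor<10x10xf32> result does not. A self-contained sketch of that predicate at the type level (hypothetical helper, not in the commit):

  #include <cassert>
  #include "mlir/IR/MLIRContext.h"
  #include "mlir/IR/StandardTypes.h"

  // True for result types the pass would still queue for shape inference.
  static bool needsInference(mlir::Type type) {
    return !type.isa<mlir::RankedTensorType>();
  }

  void demo(mlir::MLIRContext* ctx) {
    auto f32 = mlir::FloatType::getF32(ctx);
    auto unranked = mlir::UnrankedTensorType::get(f32);        // tensor<*xf32>
    auto ranked = mlir::RankedTensorType::get({10, 10}, f32);  // tensor<10x10xf32>
    assert(needsInference(unranked) && !needsInference(ranked));
  }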


@@ -22,8 +22,18 @@
 #include "src/builder/frontend_dialect_transformer.hpp"
 #include "src/compiler/dialect/onnx/onnx_ops.hpp"
+#include "src/compiler/pass/passes.hpp"
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
+#include "mlir/Transforms/Passes.h"
 using namespace std;
 using namespace onnf;
@@ -48,8 +58,15 @@ int main(int ac, char* av[]) {
   mlir::registerDialect<mlir::ONNXOpsDialect>();
+  mlir::MLIRContext context;
+  mlir::OwningModuleRef module;
   string model_filename = vm["onnx-model"].as<string>();
-  auto module = ImportFrontendModelFile(model_filename);
+  ImportFrontendModelFile(model_filename, context, module);
+  mlir::PassManager pm(&context);
+  pm.addPass(mlir::createShapeInferencePass());
+  pm.run(*module);
   return 0;
 }
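pm.run returns a LogicalResult that the new driver code discards; a stricter variant would propagate the failure out of main and dump the module only after the passes ran (a sketch, not part of the commit):

  if (mlir::failed(pm.run(*module)))
    return 1;
  module->dump();  // now shows the inferred, ranked result types
  return 0;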