diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp
index 42e7085..7413768 100644
--- a/src/builder/frontend_dialect_transformer.cpp
+++ b/src/builder/frontend_dialect_transformer.cpp
@@ -279,12 +279,15 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
   return module;
 }
 
-mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname) {
+void ImportFrontendModelFile(std::string model_fname,
+    mlir::MLIRContext& context, mlir::OwningModuleRef& module) {
   onnx::ModelProto model;
   std::fstream input(model_fname, std::ios::in | std::ios::binary);
   auto parse_success = model.ParseFromIstream(&input);
 
-  return ImportFrontendModel(model);
+  FrontendGenImpl myONNXGen(context);
+  module = myONNXGen.ImportONNXModel(model);
+  module->dump();
 }
 
 }  // namespace onnf
diff --git a/src/builder/frontend_dialect_transformer.hpp b/src/builder/frontend_dialect_transformer.hpp
index bc59708..f12512c 100644
--- a/src/builder/frontend_dialect_transformer.hpp
+++ b/src/builder/frontend_dialect_transformer.hpp
@@ -40,7 +40,8 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
  * @param model_fname file name pointing to the onnx model protobuf.
  * @return MLIR::module generated for the ONNX model.
  */
-mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname);
+void ImportFrontendModelFile(std::string model_fname,
+    mlir::MLIRContext& context, mlir::OwningModuleRef& module);
 
 /*!
  *  TODO: Import models into other extension dialects that cover the
diff --git a/src/compiler/CMakeLists.txt b/src/compiler/CMakeLists.txt
index b7e0aae..e7e069c 100644
--- a/src/compiler/CMakeLists.txt
+++ b/src/compiler/CMakeLists.txt
@@ -3,7 +3,10 @@ add_library(
         ir/knl/knl_ops.cpp
         ir/knl/knl_ops.hpp
         dialect/onnx/onnx_ops.cpp
-        dialect/onnx/onnx_ops.hpp)
+        dialect/onnx/onnx_ops.hpp
+        pass/shape_inference_pass.cpp
+        pass/shape_inference_interface.hpp
+        pass/passes.hpp)
 
 # Include root src directory.
 target_include_directories(compiler PRIVATE ../..)
@@ -37,8 +40,14 @@ onnf_tablegen(knl.cpp.inc -gen-op-defs)
 add_public_tablegen_target(gen_kir)
 add_dependencies(compiler gen_kir)
 
+set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
+onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
+onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(gen_shape_inference)
+add_dependencies(compiler gen_shape_inference)
+
 set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
-onnf_tablegen(onnx.hpp.inc -gen-op-decls)
-onnf_tablegen(onnx.cpp.inc -gen-op-defs)
+onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
+onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
 add_public_tablegen_target(gen_onnx)
 add_dependencies(compiler gen_onnx)
diff --git a/src/compiler/dialect/onnx/onnx.td b/src/compiler/dialect/onnx/onnx.td
index b005122..c4ee81b 100644
--- a/src/compiler/dialect/onnx/onnx.td
+++ b/src/compiler/dialect/onnx/onnx.td
@@ -17,6 +17,11 @@ include "mlir/IR/OpBase.td"
 #endif // OP_BASE
 
+#ifdef SHAPE_INFERENCE_INTERFACE
+#else
+include "pass/shape_inference_interface.td"
+#endif // SHAPE_INFERENCE_INTERFACE
+
 def ONNX_Dialect : Dialect {
   let name = "onnx";
   let cppNamespace = "";
 }
@@ -35,7 +40,8 @@ class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<ONNX_Dialect, mnemonic, traits>;
 
 //===----------------------------------------------------------------------===//
 // We define an ONNX operation for adding two tensors elementwise.
-def ONNXAddOp: ONNX_Op<"add", [NoSideEffect]> {
+def ONNXAddOp: ONNX_Op<"add",
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX add operation";
   let description = [{
diff --git a/src/compiler/dialect/onnx/onnx_ops.cpp b/src/compiler/dialect/onnx/onnx_ops.cpp
index 8488455..627f804 100644
--- a/src/compiler/dialect/onnx/onnx_ops.cpp
+++ b/src/compiler/dialect/onnx/onnx_ops.cpp
@@ -46,6 +46,12 @@ static void buildONNXAddOp(mlir::Builder* builder, mlir::OperationState& state,
   state.addOperands({lhs, rhs});
 }
 
+/// Infer the output shape of the ONNXAddOp. This method is required by the
+/// shape inference interface.
+void ONNXAddOp::inferShapes() {
+  getResult()->setType(getOperand(0)->getType());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/src/compiler/dialect/onnx/onnx_ops.hpp b/src/compiler/dialect/onnx/onnx_ops.hpp
index 8d12280..deab78a 100644
--- a/src/compiler/dialect/onnx/onnx_ops.hpp
+++ b/src/compiler/dialect/onnx/onnx_ops.hpp
@@ -8,14 +8,15 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_ONNX_ONNXOPS_H
-#define MLIR_DIALECT_ONNX_ONNXOPS_H
+#pragma once
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 
+#include "src/compiler/pass/shape_inference_interface.hpp"
+
 namespace mlir {
 
 class ONNXOpsDialect : public Dialect {
@@ -35,5 +36,3 @@ class ONNXOpsDialect : public Dialect {
 }  // end namespace mlir
 
 namespace onnf {}
-
-#endif // MLIR_DIALECT_ONNX_ONNXOPS_H
diff --git a/src/compiler/pass/passes.hpp b/src/compiler/pass/passes.hpp
new file mode 100644
index 0000000..995da61
--- /dev/null
+++ b/src/compiler/pass/passes.hpp
@@ -0,0 +1,24 @@
+//===- passes.hpp - ONNF Passes Definition --------------------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file exposes the entry points to create compiler passes for ONNF.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+std::unique_ptr<Pass> createShapeInferencePass();
+
+// TODO: Add pass for lowering to kernel IR.
+
+// TODO: Add pass for lowering to LLVM IR.
+
+}  // end namespace mlir
diff --git a/src/compiler/pass/shape_inference_interface.hpp b/src/compiler/pass/shape_inference_interface.hpp
new file mode 100644
index 0000000..8f4feb1
--- /dev/null
+++ b/src/compiler/pass/shape_inference_interface.hpp
@@ -0,0 +1,21 @@
+//===- shape_inference_interface.hpp - Definition for ShapeInference --------=//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file contains the declarations of the shape inference interfaces
+// defined in shape_inference_interface.td.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+
+/// Include the auto-generated declarations.
+#include "src/compiler/shape_inference.hpp.inc" + +} // end namespace mlir diff --git a/src/compiler/pass/shape_inference_interface.td b/src/compiler/pass/shape_inference_interface.td new file mode 100644 index 0000000..e191c26 --- /dev/null +++ b/src/compiler/pass/shape_inference_interface.td @@ -0,0 +1,32 @@ +//=- shape_inference_interface.td - Shape Inference Interface -*- tablegen -==// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// Defines the operations of the Shape Inference Op Interface. +// +//===----------------------------------------------------------------------===// + +#ifdef SHAPE_INFERENCE_INTERFACE +#else +#define SHAPE_INFERENCE_INTERFACE + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. + }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "inferShapes"> + ]; +} + +#endif // SHAPE_INFERENCE_INTERFACE diff --git a/src/compiler/pass/shape_inference_pass.cpp b/src/compiler/pass/shape_inference_pass.cpp new file mode 100644 index 0000000..cf556df --- /dev/null +++ b/src/compiler/pass/shape_inference_pass.cpp @@ -0,0 +1,98 @@ +//===----- shape_inference_pass.cpp - Shape Inference ---------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements a Function level pass performing propagation of array +// shapes through function specialization. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Pass/Pass.h" + +#include "src/compiler/dialect/onnx/onnx_ops.hpp" +#include "shape_inference_interface.hpp" + +#include "passes.hpp" + +using namespace mlir; + +// Include the auto-generated definitions for the shape inference interfaces. +#include "src/compiler/shape_inference.cpp.inc" + +namespace { +/*! + * FunctionPass that performs shape inference by iterating over a list of + * candidate operations and propagating the shape information until the list + * of operations is empty [credit MLIR authors]. + */ +class ShapeInferencePass : public mlir::FunctionPass { + public: + void runOnFunction() override { + auto f = getFunction(); + + // Populate the worklist with the operations that need shape inference: + // these are operations that return a dynamic shape. + llvm::SmallPtrSet op_worklist; + f.walk([&](mlir::Operation* op) { + if (returnsDynamicShape(op)) + op_worklist.insert(op); + }); + + // Iterate on the operations in the worklist until all operations have been + // inferred or no change happened (fix point). + while (!op_worklist.empty()) { + // Find the next operation ready for inference, that is an operation + // with all operands already resolved (non-generic). + auto nextop = llvm::find_if(op_worklist, returnsDynamicShape); + if (nextop == op_worklist.end()) + break; + + Operation* op = *nextop; + op_worklist.erase(op); + + // Ask the operation to infer its output shapes. 
+      if (auto shape_op = dyn_cast<ShapeInference>(op)) {
+        shape_op.inferShapes();
+      } else {
+        op->emitError(
+            "unable to infer shape of operation without shape "
+            "inference interface");
+        return signalPassFailure();
+      }
+    }
+
+    // If the operation worklist isn't empty, this indicates a failure.
+    if (!op_worklist.empty()) {
+      f.emitError("Shape inference failed, ")
+          << op_worklist.size() << " operations couldn't be inferred\n";
+      signalPassFailure();
+    }
+  }
+
+  /*!
+   * Check if the given operation has a dynamically shaped result.
+   */
+  static bool returnsDynamicShape(Operation* op) {
+    // TODO: remove this check.
+    // Temporary fix until more ops are supported.
+    // All operations which do not return a ranked tensor type have
+    // dynamically shaped outputs, and all such operations need to implement
+    // the inferShapes() method.
+    if (op->getName().getStringRef() != "onnx.add")
+      return false;
+    return llvm::any_of(op->getResultTypes(), [](Type result_type) {
+      return !result_type.isa<RankedTensorType>();
+    });
+  }
+};
+}  // end anonymous namespace
+
+/*!
+ * Create a Shape Inference pass.
+ */
+std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
+  return std::make_unique<ShapeInferencePass>();
+}
diff --git a/src/main.cpp b/src/main.cpp
index 0f4b4f0..bcbcd4b 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -22,8 +22,18 @@
 
 #include "src/builder/frontend_dialect_transformer.hpp"
 #include "src/compiler/dialect/onnx/onnx_ops.hpp"
+#include "src/compiler/pass/passes.hpp"
 
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
+#include "mlir/Transforms/Passes.h"
 
 using namespace std;
 using namespace onnf;
@@ -48,8 +58,15 @@ int main(int ac, char* av[]) {
 
   mlir::registerDialect<mlir::ONNXOpsDialect>();
 
+  mlir::MLIRContext context;
+  mlir::OwningModuleRef module;
+
   string model_filename = vm["onnx-model"].as<string>();
-  auto module = ImportFrontendModelFile(model_filename);
+  ImportFrontendModelFile(model_filename, context, module);
+
+  mlir::PassManager pm(&context);
+  pm.addPass(mlir::createShapeInferencePass());
+  pm.run(*module);
 
   return 0;
 }
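
Illustrative example (not part of the patch): `ONNXAddOp::inferShapes()` above simply copies the type of the first operand onto the result, so the pass returned by `mlir::createShapeInferencePass()` would rewrite an `onnx.add` whose result type is still unranked roughly as sketched below; the operand names and shapes are hypothetical.

    // Before shape inference: the result type is an unranked (dynamic) tensor.
    %0 = "onnx.add"(%arg0, %arg1)
        : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32>

    // After shape inference: the result adopts the first operand's type.
    %0 = "onnx.add"(%arg0, %arg1)
        : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>

Note that `returnsDynamicShape()` currently whitelists `onnx.add` only, so any other operation with an unranked result is skipped until it, too, implements the `ShapeInference` interface.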