diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7c16c0b..7085399 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,10 +1,5 @@ add_executable(onnf main.cpp) - -target_include_directories(onnf PRIVATE ..) - -target_link_libraries(onnf - builder - ${Boost_LIBRARIES} - ) - +target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR}) +target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR}) +target_link_libraries(onnf builder compiler ${Boost_LIBRARIES}) diff --git a/src/builder/CMakeLists.txt b/src/builder/CMakeLists.txt index 8627ae1..a014dd7 100644 --- a/src/builder/CMakeLists.txt +++ b/src/builder/CMakeLists.txt @@ -1,8 +1,10 @@ add_library(builder - sgir.cpp + frontend_dialect_transformer.cpp ) -target_link_libraries(builder onnx ${MLIRLIBS} curses) +target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR}) +target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR}) +target_link_libraries(builder compiler onnx ${MLIRLIBS} curses) target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR}/third_party/onnx diff --git a/src/builder/sgir.cpp b/src/builder/frontend_dialect_transformer.cpp similarity index 73% rename from src/builder/sgir.cpp rename to src/builder/frontend_dialect_transformer.cpp index 2a63548..a957897 100644 --- a/src/builder/sgir.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -1,9 +1,17 @@ -//===----------------------------------------------------------------------===// +//===- frontend_dialect_transformer.cpp - MLIR Operations -----------------===// // -// Copyright 2019 The IBM Research Authors. +// Copyright 2019 The IBM Research Authors. // // ============================================================================= // +// This file transforms the input to available MLIR dialects that can represent +// the operations of the model. Models use the ONNX dialect and any other +// extension dialects that comprise the the operations not supported or covered +// by the ONNX specification. +// +// A `frontend` placeholder dialect is used to encode operations that are not +// covered by any existing dialects. +// //===----------------------------------------------------------------------===// #include @@ -26,7 +34,8 @@ #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/raw_ostream.h" -#include "sgir.hpp" +#include "frontend_dialect_transformer.hpp" +#include "src/compiler/dialect/onnx/onnx_ops.hpp" namespace onnf { namespace { @@ -84,14 +93,14 @@ struct OnnxOnnfSymbolMapping { std::map onnx_name2onnf_tensor; }; -class SGIRGenImpl { +class FrontendGenImpl { public: - SGIRGenImpl(mlir::MLIRContext& context) + FrontendGenImpl(mlir::MLIRContext& context) : context_(context), builder_(&context) { module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); } - mlir::ModuleOp ImportModel(onnx::ModelProto model) { + mlir::ModuleOp ImportONNXModel(onnx::ModelProto model) { ImportGraph(model.graph()); return module_; } @@ -101,7 +110,7 @@ class SGIRGenImpl { mlir::ModuleOp module_; mlir::OpBuilder builder_; // mapping between string name and symbol - OnnxOnnfSymbolMapping sgir_symbols_; + OnnxOnnfSymbolMapping frontend_symbols_; mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); } @@ -127,17 +136,17 @@ class SGIRGenImpl { dims.push_back(-1); } } - if (!sgir_symbols_.ContainKey(input_tensor_legalized_name)) { + if (!frontend_symbols_.ContainKey(input_tensor_legalized_name)) { mlir::Type elementType = TypeConvert(input.type().tensor_type().elem_type()); llvm::ArrayRef llvmdimsAR(dims.data(), dims.size()); auto dataType = mlir::RankedTensorType::get(llvmdimsAR, elementType); mlir::OperationState result( - UnknownLoc(), "sgir.input " + input_tensor_legalized_name); + UnknownLoc(), "frontend.input " + input_tensor_legalized_name); result.addTypes(dataType); auto op = builder_.createOperation(result); auto value = op->getResult(0); - sgir_symbols_.AddMapping(input_tensor_legalized_name, value); + frontend_symbols_.AddMapping(input_tensor_legalized_name, value); } else { // TODO Should not happen } @@ -146,11 +155,23 @@ class SGIRGenImpl { void ImportNode(onnx::NodeProto node) { std::vector inputs; for (auto item : node.input()) { - if (sgir_symbols_.ContainKey(legalize_name(item))) { - inputs.push_back(sgir_symbols_.GetTensorByOnnxName(item)); + if (frontend_symbols_.ContainKey(legalize_name(item))) { + inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); } } - mlir::OperationState result(UnknownLoc(), "SGIR." + node.op_type()); + + // Handle ONNX Add Operation by using its representation in the + // ONNX Dialect. + llvm::StringRef OpName = node.op_type(); + if (OpName == "Add") { + auto op = + builder_.create(UnknownLoc(), inputs[0], inputs[1]); + frontend_symbols_.AddMapping(legalize_name(node.output()[0]), op.getResult()); + return; + } + + // Old way of doing things. + mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type()); for (auto item : node.output()) { result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type())); } @@ -158,17 +179,17 @@ class SGIRGenImpl { auto op = builder_.createOperation(result); for (int i = 0; i < node.output().size(); i++) { auto r = op->getResult(i); - sgir_symbols_.AddMapping(legalize_name(node.output()[i]), r); + frontend_symbols_.AddMapping(legalize_name(node.output()[i]), r); } // TODO more info from node: attributes } void ImportOutputTensor(onnx::ValueInfoProto& output) { - if (sgir_symbols_.ContainKey(legalize_name(output.name()))) { - mlir::OperationState result(UnknownLoc(), "sgir.output " + output.name()); + if (frontend_symbols_.ContainKey(legalize_name(output.name()))) { + mlir::OperationState result(UnknownLoc(), "frontend.output " + output.name()); result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type())); - result.addOperands(sgir_symbols_.GetTensorByOnnxName(output.name())); + result.addOperands(frontend_symbols_.GetTensorByOnnxName(output.name())); builder_.createOperation(result); } else { // TODO: Why not in the symbol table? something is wrong @@ -209,27 +230,28 @@ class SGIRGenImpl { } } -}; // SGIRGenImpl class +}; // FrontendGenImpl class } // namespace } // namespace onnf namespace onnf { -mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model) { +mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) { mlir::MLIRContext context; - SGIRGenImpl mySGIRGen(context); - auto module = mySGIRGen.ImportModel(model); + FrontendGenImpl myONNXGen(context); + auto module = myONNXGen.ImportONNXModel(model); module.dump(); return module; } -mlir::OwningModuleRef SGIRImportModelFile(std::string model_fname) { +mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname) { onnx::ModelProto model; std::fstream input(model_fname, std::ios::in | std::ios::binary); auto parse_success = model.ParseFromIstream(&input); - return SGIRImportModel(model); + + return ImportFrontendModel(model); } } // namespace onnf diff --git a/src/builder/frontend_dialect_transformer.hpp b/src/builder/frontend_dialect_transformer.hpp new file mode 100644 index 0000000..bc59708 --- /dev/null +++ b/src/builder/frontend_dialect_transformer.hpp @@ -0,0 +1,49 @@ +//===- frontend_dialect_transformer.hpp - MLIR Operations -----------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "onnx/onnx_pb.h" + +namespace mlir { +class MLIRContext; +class OwningModuleRef; +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Import a model into one of ONNF's frontend models. +//===----------------------------------------------------------------------===// + +namespace onnf { +/*! + * Import an ONNX model into ONNF's ONNX Dialect. + * @param model onnx model. + * @return MLIR::module generated for the ONNX model. + */ +mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model); + +/*! + * Import an ONNX model file into ONNF's ONNX Dialect. + * @param model_fname file name pointing to the onnx model protobuf. + * @return MLIR::module generated for the ONNX model. + */ +mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname); + +/*! + * TODO: Import models into other extension dialects that cover the + * operations specific to other frameworks such as Tensorflow or Pytorch. + */ +} // namespace onnf diff --git a/src/builder/sgir.hpp b/src/builder/sgir.hpp deleted file mode 100644 index 382b3d2..0000000 --- a/src/builder/sgir.hpp +++ /dev/null @@ -1,40 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright 2019 The IBM Research Authors. -// -// ============================================================================= -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "onnx/onnx_pb.h" - -namespace mlir { -class MLIRContext; -class OwningModuleRef; -} // namespace mlir - -namespace onnf { -/*! - * Import an ONNX Model into SGIR. - * @param model onnx model. - * @return MLIR::module generated for the ONNX model. - */ -mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model); - -/*! - * Import an ONNX Model file into SGIR. - * @param model_fname file name pointing to the onnx model protobuf. - * @return MLIR::module generated for the ONNX model. - */ -mlir::OwningModuleRef SGIRImportModelFile(std::string model_fname); -} // namespace onnf diff --git a/src/compiler/CMakeLists.txt b/src/compiler/CMakeLists.txt index 2ab4e9b..26a2c30 100644 --- a/src/compiler/CMakeLists.txt +++ b/src/compiler/CMakeLists.txt @@ -1,11 +1,14 @@ add_library( compiler ir/knl/knl_ops.cpp - ir/knl/knl_ops.hpp) + ir/knl/knl_ops.hpp + dialect/onnx/onnx_ops.cpp + dialect/onnx/onnx_ops.hpp) # Include root src directory. target_include_directories(compiler PRIVATE ../..) -target_include_directories(compiler PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) +target_include_directories(compiler PRIVATE ${CMAKE_SOURCE_DIR}) +target_include_directories(compiler PRIVATE ${CMAKE_BINARY_DIR}) find_package(Boost 1.54.0 COMPONENTS graph @@ -26,3 +29,9 @@ onnf_tablegen(knl.hpp.inc -gen-op-decls) onnf_tablegen(knl.cpp.inc -gen-op-defs) add_public_tablegen_target(gen_kir) add_dependencies(compiler gen_kir) + +set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td) +onnf_tablegen(onnx.hpp.inc -gen-op-decls) +onnf_tablegen(onnx.cpp.inc -gen-op-defs) +add_public_tablegen_target(gen_onnx) +add_dependencies(compiler gen_onnx) diff --git a/src/compiler/dialect/onnx/onnx.td b/src/compiler/dialect/onnx/onnx.td new file mode 100644 index 0000000..b005122 --- /dev/null +++ b/src/compiler/dialect/onnx/onnx.td @@ -0,0 +1,64 @@ +//===- ONNXOps.td -- ONNX operation definitions ---------*- tablegen -*----===// +// +// Copyright 2019 The IBM Research Authors +// +// ============================================================================= +// +// Defines MLIR ONNX operations. +// +//===----------------------------------------------------------------------===// + +#ifdef ONNX_OPS +#else +#define ONNX_OPS + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +def ONNX_Dialect : Dialect { + let name = "onnx"; + let cppNamespace = ""; +} + +// Base class for ONNX dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class ONNX_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// ONNX Operations +//===----------------------------------------------------------------------===// + +// We define an ONNX operation for adding two tensors elementwise. +def ONNXAddOp: ONNX_Op<"add", [NoSideEffect]> { + let summary = "ONNX add operation"; + let description = [{ + + The "onnx.add" adds two tensors element-wise. + + }]; + + // TODO: AnyTensor might be too wide for ONNX and may need to be constrained + // to fewer valid types. + // In the ONNX spec: + // T : tensor(uint32), tensor(uint64), + // tensor(int32), tensor(int64), + // tensor(float16), tensor(float), tensor(double) + // + let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in); + let results = (outs AnyTensor); + + // Build an ONNX Add operation using two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{ + buildONNXAddOp(b, result, lhs, rhs); + }] + >]; +} + +#endif // ONNX_OPS diff --git a/src/compiler/dialect/onnx/onnx_ops.cpp b/src/compiler/dialect/onnx/onnx_ops.cpp new file mode 100644 index 0000000..8488455 --- /dev/null +++ b/src/compiler/dialect/onnx/onnx_ops.cpp @@ -0,0 +1,54 @@ +//===- onnx_ops.cpp - MLIR ONNX Operations --------------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file defines ONNX operations in the MLIR operation set. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallBitVector.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" + +#include "onnx_ops.hpp" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ONNXOpsDialect +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx) + : mlir::Dialect(getDialectNamespace(), ctx) { + addOperations< +#define GET_OP_LIST +#include "src/compiler/onnx.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// ONNX Operations +//===----------------------------------------------------------------------===// + +static void buildONNXAddOp(mlir::Builder* builder, mlir::OperationState& state, + mlir::Value* lhs, mlir::Value* rhs) { + state.addTypes(UnrankedTensorType::get(builder->getF32Type())); + state.addOperands({lhs, rhs}); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "src/compiler/onnx.cpp.inc" diff --git a/src/compiler/dialect/onnx/onnx_ops.hpp b/src/compiler/dialect/onnx/onnx_ops.hpp new file mode 100644 index 0000000..8d12280 --- /dev/null +++ b/src/compiler/dialect/onnx/onnx_ops.hpp @@ -0,0 +1,39 @@ +//===- onnx_ops.hpp - MLIR ONNX Operations --------------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file defines ONNX operations in the MLIR operation set. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ONNX_ONNXOPS_H +#define MLIR_DIALECT_ONNX_ONNXOPS_H + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { + +class ONNXOpsDialect : public Dialect { + public: + ONNXOpsDialect(MLIRContext* context); + + /// Provide a utility accessor to the dialect namespace. This is used by + /// several utilities for casting between dialects. + static StringRef getDialectNamespace() { return "onnx"; } +}; + +/// Include the auto-generated header file containing the declarations of the +/// ONNX operations. +#define GET_OP_CLASSES +#include "src/compiler/onnx.hpp.inc" + +} // end namespace mlir + +namespace onnf {} + +#endif // MLIR_DIALECT_ONNX_ONNXOPS_H diff --git a/src/compiler/ir/knl/knl_ops.cpp b/src/compiler/ir/knl/knl_ops.cpp index 1d865c8..10ad6eb 100644 --- a/src/compiler/ir/knl/knl_ops.cpp +++ b/src/compiler/ir/knl/knl_ops.cpp @@ -15,7 +15,7 @@ KnlOpsDialect::KnlOpsDialect(MLIRContext* context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST -#include "knl.cpp.inc" +#include "src/compiler/knl.cpp.inc" >(); } } // namespace mlir diff --git a/src/compiler/ir/knl/knl_ops.hpp b/src/compiler/ir/knl/knl_ops.hpp index 6ec7a86..aaa55be 100644 --- a/src/compiler/ir/knl/knl_ops.hpp +++ b/src/compiler/ir/knl/knl_ops.hpp @@ -13,7 +13,7 @@ class KnlOpsDialect : public Dialect { }; #define GET_OP_CLASSES -#include "knl.hpp.inc" +#include "src/compiler/knl.hpp.inc" } // namespace mlir namespace onnf {} diff --git a/src/main.cpp b/src/main.cpp index ca92bc0..0f4b4f0 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -20,7 +20,8 @@ #include -#include "src/builder/sgir.hpp" +#include "src/builder/frontend_dialect_transformer.hpp" +#include "src/compiler/dialect/onnx/onnx_ops.hpp" #include "mlir/IR/Module.h" @@ -45,8 +46,10 @@ int main(int ac, char* av[]) { return 0; } + mlir::registerDialect(); + string model_filename = vm["onnx-model"].as(); - auto module = SGIRImportModelFile(model_filename); + auto module = ImportFrontendModelFile(model_filename); return 0; }