ONNX Dialect creation and registration with MLIR (#358)

* Create and register ONNX Dialect with one Add operation.

* Fix file formatting.

* Change name from ONNX to SGIR.

* Use ONNX dialect. Change SGIR to frontend placeholder dialect.

* Add comments.

* Type clean-up.
This commit is contained in:
GHEORGHE-TEOD BERCEA 2019-11-01 17:09:48 -04:00 committed by Doru Bercea
parent b5a35c9138
commit 958a2fbffc
12 changed files with 276 additions and 79 deletions

View File

@ -1,10 +1,5 @@
add_executable(onnf main.cpp) add_executable(onnf main.cpp)
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
target_include_directories(onnf PRIVATE ..) target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
target_link_libraries(onnf builder compiler ${Boost_LIBRARIES})
target_link_libraries(onnf
builder
${Boost_LIBRARIES}
)

View File

@ -1,8 +1,10 @@
add_library(builder add_library(builder
sgir.cpp frontend_dialect_transformer.cpp
) )
target_link_libraries(builder onnx ${MLIRLIBS} curses) target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR})
target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR})
target_link_libraries(builder compiler onnx ${MLIRLIBS} curses)
target_include_directories(builder target_include_directories(builder
PRIVATE PRIVATE
${CMAKE_SOURCE_DIR}/third_party/onnx ${CMAKE_SOURCE_DIR}/third_party/onnx

View File

@ -1,9 +1,17 @@
//===----------------------------------------------------------------------===// //===- frontend_dialect_transformer.cpp - MLIR Operations -----------------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
// ============================================================================= // =============================================================================
// //
// This file transforms the input to available MLIR dialects that can represent
// the operations of the model. Models use the ONNX dialect and any other
// extension dialects that comprise the the operations not supported or covered
// by the ONNX specification.
//
// A `frontend` placeholder dialect is used to encode operations that are not
// covered by any existing dialects.
//
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <numeric> #include <numeric>
@ -26,7 +34,8 @@
#include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "sgir.hpp" #include "frontend_dialect_transformer.hpp"
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
namespace onnf { namespace onnf {
namespace { namespace {
@ -84,14 +93,14 @@ struct OnnxOnnfSymbolMapping {
std::map<std::string, mlir::Value*> onnx_name2onnf_tensor; std::map<std::string, mlir::Value*> onnx_name2onnf_tensor;
}; };
class SGIRGenImpl { class FrontendGenImpl {
public: public:
SGIRGenImpl(mlir::MLIRContext& context) FrontendGenImpl(mlir::MLIRContext& context)
: context_(context), builder_(&context) { : context_(context), builder_(&context) {
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
} }
mlir::ModuleOp ImportModel(onnx::ModelProto model) { mlir::ModuleOp ImportONNXModel(onnx::ModelProto model) {
ImportGraph(model.graph()); ImportGraph(model.graph());
return module_; return module_;
} }
@ -101,7 +110,7 @@ class SGIRGenImpl {
mlir::ModuleOp module_; mlir::ModuleOp module_;
mlir::OpBuilder builder_; mlir::OpBuilder builder_;
// mapping between string name and symbol // mapping between string name and symbol
OnnxOnnfSymbolMapping sgir_symbols_; OnnxOnnfSymbolMapping frontend_symbols_;
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); } mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
@ -127,17 +136,17 @@ class SGIRGenImpl {
dims.push_back(-1); dims.push_back(-1);
} }
} }
if (!sgir_symbols_.ContainKey(input_tensor_legalized_name)) { if (!frontend_symbols_.ContainKey(input_tensor_legalized_name)) {
mlir::Type elementType = mlir::Type elementType =
TypeConvert(input.type().tensor_type().elem_type()); TypeConvert(input.type().tensor_type().elem_type());
llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size()); llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size());
auto dataType = mlir::RankedTensorType::get(llvmdimsAR, elementType); auto dataType = mlir::RankedTensorType::get(llvmdimsAR, elementType);
mlir::OperationState result( mlir::OperationState result(
UnknownLoc(), "sgir.input " + input_tensor_legalized_name); UnknownLoc(), "frontend.input " + input_tensor_legalized_name);
result.addTypes(dataType); result.addTypes(dataType);
auto op = builder_.createOperation(result); auto op = builder_.createOperation(result);
auto value = op->getResult(0); auto value = op->getResult(0);
sgir_symbols_.AddMapping(input_tensor_legalized_name, value); frontend_symbols_.AddMapping(input_tensor_legalized_name, value);
} else { } else {
// TODO Should not happen // TODO Should not happen
} }
@ -146,11 +155,23 @@ class SGIRGenImpl {
void ImportNode(onnx::NodeProto node) { void ImportNode(onnx::NodeProto node) {
std::vector<mlir::Value*> inputs; std::vector<mlir::Value*> inputs;
for (auto item : node.input()) { for (auto item : node.input()) {
if (sgir_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(sgir_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
} }
} }
mlir::OperationState result(UnknownLoc(), "SGIR." + node.op_type());
// Handle ONNX Add Operation by using its representation in the
// ONNX Dialect.
llvm::StringRef OpName = node.op_type();
if (OpName == "Add") {
auto op =
builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
frontend_symbols_.AddMapping(legalize_name(node.output()[0]), op.getResult());
return;
}
// Old way of doing things.
mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type());
for (auto item : node.output()) { for (auto item : node.output()) {
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type())); result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
} }
@ -158,17 +179,17 @@ class SGIRGenImpl {
auto op = builder_.createOperation(result); auto op = builder_.createOperation(result);
for (int i = 0; i < node.output().size(); i++) { for (int i = 0; i < node.output().size(); i++) {
auto r = op->getResult(i); auto r = op->getResult(i);
sgir_symbols_.AddMapping(legalize_name(node.output()[i]), r); frontend_symbols_.AddMapping(legalize_name(node.output()[i]), r);
} }
// TODO more info from node: attributes // TODO more info from node: attributes
} }
void ImportOutputTensor(onnx::ValueInfoProto& output) { void ImportOutputTensor(onnx::ValueInfoProto& output) {
if (sgir_symbols_.ContainKey(legalize_name(output.name()))) { if (frontend_symbols_.ContainKey(legalize_name(output.name()))) {
mlir::OperationState result(UnknownLoc(), "sgir.output " + output.name()); mlir::OperationState result(UnknownLoc(), "frontend.output " + output.name());
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type())); result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
result.addOperands(sgir_symbols_.GetTensorByOnnxName(output.name())); result.addOperands(frontend_symbols_.GetTensorByOnnxName(output.name()));
builder_.createOperation(result); builder_.createOperation(result);
} else { } else {
// TODO: Why not in the symbol table? something is wrong // TODO: Why not in the symbol table? something is wrong
@ -209,27 +230,28 @@ class SGIRGenImpl {
} }
} }
}; // SGIRGenImpl class }; // FrontendGenImpl class
} // namespace } // namespace
} // namespace onnf } // namespace onnf
namespace onnf { namespace onnf {
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model) { mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
mlir::MLIRContext context; mlir::MLIRContext context;
SGIRGenImpl mySGIRGen(context); FrontendGenImpl myONNXGen(context);
auto module = mySGIRGen.ImportModel(model); auto module = myONNXGen.ImportONNXModel(model);
module.dump(); module.dump();
return module; return module;
} }
mlir::OwningModuleRef SGIRImportModelFile(std::string model_fname) { mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname) {
onnx::ModelProto model; onnx::ModelProto model;
std::fstream input(model_fname, std::ios::in | std::ios::binary); std::fstream input(model_fname, std::ios::in | std::ios::binary);
auto parse_success = model.ParseFromIstream(&input); auto parse_success = model.ParseFromIstream(&input);
return SGIRImportModel(model);
return ImportFrontendModel(model);
} }
} // namespace onnf } // namespace onnf

View File

@ -0,0 +1,49 @@
//===- frontend_dialect_transformer.hpp - MLIR Operations -----------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
//===----------------------------------------------------------------------===//
#pragma once
#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "onnx/onnx_pb.h"
namespace mlir {
class MLIRContext;
class OwningModuleRef;
} // namespace mlir
//===----------------------------------------------------------------------===//
// Import a model into one of ONNF's frontend models.
//===----------------------------------------------------------------------===//
namespace onnf {
/*!
* Import an ONNX model into ONNF's ONNX Dialect.
* @param model onnx model.
* @return MLIR::module generated for the ONNX model.
*/
mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
/*!
* Import an ONNX model file into ONNF's ONNX Dialect.
* @param model_fname file name pointing to the onnx model protobuf.
* @return MLIR::module generated for the ONNX model.
*/
mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname);
/*!
* TODO: Import models into other extension dialects that cover the
* operations specific to other frameworks such as Tensorflow or Pytorch.
*/
} // namespace onnf

View File

@ -1,40 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
//===----------------------------------------------------------------------===//
#pragma once
#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "onnx/onnx_pb.h"
namespace mlir {
class MLIRContext;
class OwningModuleRef;
} // namespace mlir
namespace onnf {
/*!
* Import an ONNX Model into SGIR.
* @param model onnx model.
* @return MLIR::module generated for the ONNX model.
*/
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model);
/*!
* Import an ONNX Model file into SGIR.
* @param model_fname file name pointing to the onnx model protobuf.
* @return MLIR::module generated for the ONNX model.
*/
mlir::OwningModuleRef SGIRImportModelFile(std::string model_fname);
} // namespace onnf

View File

@ -1,11 +1,14 @@
add_library( add_library(
compiler compiler
ir/knl/knl_ops.cpp ir/knl/knl_ops.cpp
ir/knl/knl_ops.hpp) ir/knl/knl_ops.hpp
dialect/onnx/onnx_ops.cpp
dialect/onnx/onnx_ops.hpp)
# Include root src directory. # Include root src directory.
target_include_directories(compiler PRIVATE ../..) target_include_directories(compiler PRIVATE ../..)
target_include_directories(compiler PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) target_include_directories(compiler PRIVATE ${CMAKE_SOURCE_DIR})
target_include_directories(compiler PRIVATE ${CMAKE_BINARY_DIR})
find_package(Boost 1.54.0 find_package(Boost 1.54.0
COMPONENTS graph COMPONENTS graph
@ -26,3 +29,9 @@ onnf_tablegen(knl.hpp.inc -gen-op-decls)
onnf_tablegen(knl.cpp.inc -gen-op-defs) onnf_tablegen(knl.cpp.inc -gen-op-defs)
add_public_tablegen_target(gen_kir) add_public_tablegen_target(gen_kir)
add_dependencies(compiler gen_kir) add_dependencies(compiler gen_kir)
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
onnf_tablegen(onnx.hpp.inc -gen-op-decls)
onnf_tablegen(onnx.cpp.inc -gen-op-defs)
add_public_tablegen_target(gen_onnx)
add_dependencies(compiler gen_onnx)

View File

@ -0,0 +1,64 @@
//===- ONNXOps.td -- ONNX operation definitions ---------*- tablegen -*----===//
//
// Copyright 2019 The IBM Research Authors
//
// =============================================================================
//
// Defines MLIR ONNX operations.
//
//===----------------------------------------------------------------------===//
#ifdef ONNX_OPS
#else
#define ONNX_OPS
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def ONNX_Dialect : Dialect {
let name = "onnx";
let cppNamespace = "";
}
// Base class for ONNX dialect operations. This operation inherits from the base
// `Op` class in OpBase.td, and provides:
// * The parent dialect of the operation.
// * The mnemonic for the operation, or the name without the dialect prefix.
// * A list of traits for the operation.
class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
Op<ONNX_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// ONNX Operations
//===----------------------------------------------------------------------===//
// We define an ONNX operation for adding two tensors elementwise.
def ONNXAddOp: ONNX_Op<"add", [NoSideEffect]> {
let summary = "ONNX add operation";
let description = [{
The "onnx.add" adds two tensors element-wise.
}];
// TODO: AnyTensor might be too wide for ONNX and may need to be constrained
// to fewer valid types.
// In the ONNX spec:
// T : tensor(uint32), tensor(uint64),
// tensor(int32), tensor(int64),
// tensor(float16), tensor(float), tensor(double)
//
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in);
let results = (outs AnyTensor);
// Build an ONNX Add operation using two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildONNXAddOp(b, result, lhs, rhs);
}]
>];
}
#endif // ONNX_OPS

View File

@ -0,0 +1,54 @@
//===- onnx_ops.cpp - MLIR ONNX Operations --------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file defines ONNX operations in the MLIR operation set.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "onnx_ops.hpp"
using namespace mlir;
//===----------------------------------------------------------------------===//
// ONNXOpsDialect
//===----------------------------------------------------------------------===//
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
: mlir::Dialect(getDialectNamespace(), ctx) {
addOperations<
#define GET_OP_LIST
#include "src/compiler/onnx.cpp.inc"
>();
}
//===----------------------------------------------------------------------===//
// ONNX Operations
//===----------------------------------------------------------------------===//
static void buildONNXAddOp(mlir::Builder* builder, mlir::OperationState& state,
mlir::Value* lhs, mlir::Value* rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF32Type()));
state.addOperands({lhs, rhs});
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "src/compiler/onnx.cpp.inc"

View File

@ -0,0 +1,39 @@
//===- onnx_ops.hpp - MLIR ONNX Operations --------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file defines ONNX operations in the MLIR operation set.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_ONNX_ONNXOPS_H
#define MLIR_DIALECT_ONNX_ONNXOPS_H
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
class ONNXOpsDialect : public Dialect {
public:
ONNXOpsDialect(MLIRContext* context);
/// Provide a utility accessor to the dialect namespace. This is used by
/// several utilities for casting between dialects.
static StringRef getDialectNamespace() { return "onnx"; }
};
/// Include the auto-generated header file containing the declarations of the
/// ONNX operations.
#define GET_OP_CLASSES
#include "src/compiler/onnx.hpp.inc"
} // end namespace mlir
namespace onnf {}
#endif // MLIR_DIALECT_ONNX_ONNXOPS_H

View File

@ -15,7 +15,7 @@ KnlOpsDialect::KnlOpsDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) { : Dialect(getDialectNamespace(), context) {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "knl.cpp.inc" #include "src/compiler/knl.cpp.inc"
>(); >();
} }
} // namespace mlir } // namespace mlir

View File

@ -13,7 +13,7 @@ class KnlOpsDialect : public Dialect {
}; };
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "knl.hpp.inc" #include "src/compiler/knl.hpp.inc"
} // namespace mlir } // namespace mlir
namespace onnf {} namespace onnf {}

View File

@ -20,7 +20,8 @@
#include <boost/program_options.hpp> #include <boost/program_options.hpp>
#include "src/builder/sgir.hpp" #include "src/builder/frontend_dialect_transformer.hpp"
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
#include "mlir/IR/Module.h" #include "mlir/IR/Module.h"
@ -45,8 +46,10 @@ int main(int ac, char* av[]) {
return 0; return 0;
} }
mlir::registerDialect<mlir::ONNXOpsDialect>();
string model_filename = vm["onnx-model"].as<string>(); string model_filename = vm["onnx-model"].as<string>();
auto module = SGIRImportModelFile(model_filename); auto module = ImportFrontendModelFile(model_filename);
return 0; return 0;
} }