ONNX Dialect creation and registration with MLIR (#358)
* Create and register ONNX Dialect with one Add operation. * Fix file formatting. * Change name from ONNX to SGIR. * Use ONNX dialect. Change SGIR to frontend placeholder dialect. * Add comments. * Type clean-up.
This commit is contained in:
parent
b5a35c9138
commit
958a2fbffc
|
@ -1,10 +1,5 @@
|
|||
|
||||
add_executable(onnf main.cpp)
|
||||
|
||||
target_include_directories(onnf PRIVATE ..)
|
||||
|
||||
target_link_libraries(onnf
|
||||
builder
|
||||
${Boost_LIBRARIES}
|
||||
)
|
||||
|
||||
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
||||
target_link_libraries(onnf builder compiler ${Boost_LIBRARIES})
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
add_library(builder
|
||||
sgir.cpp
|
||||
frontend_dialect_transformer.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(builder onnx ${MLIRLIBS} curses)
|
||||
target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR})
|
||||
target_link_libraries(builder compiler onnx ${MLIRLIBS} curses)
|
||||
target_include_directories(builder
|
||||
PRIVATE
|
||||
${CMAKE_SOURCE_DIR}/third_party/onnx
|
||||
|
|
|
@ -1,9 +1,17 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//===- frontend_dialect_transformer.cpp - MLIR Operations -----------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file transforms the input to available MLIR dialects that can represent
|
||||
// the operations of the model. Models use the ONNX dialect and any other
|
||||
// extension dialects that comprise the the operations not supported or covered
|
||||
// by the ONNX specification.
|
||||
//
|
||||
// A `frontend` placeholder dialect is used to encode operations that are not
|
||||
// covered by any existing dialects.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <numeric>
|
||||
|
@ -26,7 +34,8 @@
|
|||
#include "llvm/ADT/ScopedHashTable.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "sgir.hpp"
|
||||
#include "frontend_dialect_transformer.hpp"
|
||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
|
||||
namespace onnf {
|
||||
namespace {
|
||||
|
@ -84,14 +93,14 @@ struct OnnxOnnfSymbolMapping {
|
|||
std::map<std::string, mlir::Value*> onnx_name2onnf_tensor;
|
||||
};
|
||||
|
||||
class SGIRGenImpl {
|
||||
class FrontendGenImpl {
|
||||
public:
|
||||
SGIRGenImpl(mlir::MLIRContext& context)
|
||||
FrontendGenImpl(mlir::MLIRContext& context)
|
||||
: context_(context), builder_(&context) {
|
||||
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
}
|
||||
|
||||
mlir::ModuleOp ImportModel(onnx::ModelProto model) {
|
||||
mlir::ModuleOp ImportONNXModel(onnx::ModelProto model) {
|
||||
ImportGraph(model.graph());
|
||||
return module_;
|
||||
}
|
||||
|
@ -101,7 +110,7 @@ class SGIRGenImpl {
|
|||
mlir::ModuleOp module_;
|
||||
mlir::OpBuilder builder_;
|
||||
// mapping between string name and symbol
|
||||
OnnxOnnfSymbolMapping sgir_symbols_;
|
||||
OnnxOnnfSymbolMapping frontend_symbols_;
|
||||
|
||||
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
|
||||
|
||||
|
@ -127,17 +136,17 @@ class SGIRGenImpl {
|
|||
dims.push_back(-1);
|
||||
}
|
||||
}
|
||||
if (!sgir_symbols_.ContainKey(input_tensor_legalized_name)) {
|
||||
if (!frontend_symbols_.ContainKey(input_tensor_legalized_name)) {
|
||||
mlir::Type elementType =
|
||||
TypeConvert(input.type().tensor_type().elem_type());
|
||||
llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size());
|
||||
auto dataType = mlir::RankedTensorType::get(llvmdimsAR, elementType);
|
||||
mlir::OperationState result(
|
||||
UnknownLoc(), "sgir.input " + input_tensor_legalized_name);
|
||||
UnknownLoc(), "frontend.input " + input_tensor_legalized_name);
|
||||
result.addTypes(dataType);
|
||||
auto op = builder_.createOperation(result);
|
||||
auto value = op->getResult(0);
|
||||
sgir_symbols_.AddMapping(input_tensor_legalized_name, value);
|
||||
frontend_symbols_.AddMapping(input_tensor_legalized_name, value);
|
||||
} else {
|
||||
// TODO Should not happen
|
||||
}
|
||||
|
@ -146,11 +155,23 @@ class SGIRGenImpl {
|
|||
void ImportNode(onnx::NodeProto node) {
|
||||
std::vector<mlir::Value*> inputs;
|
||||
for (auto item : node.input()) {
|
||||
if (sgir_symbols_.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(sgir_symbols_.GetTensorByOnnxName(item));
|
||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||
}
|
||||
}
|
||||
mlir::OperationState result(UnknownLoc(), "SGIR." + node.op_type());
|
||||
|
||||
// Handle ONNX Add Operation by using its representation in the
|
||||
// ONNX Dialect.
|
||||
llvm::StringRef OpName = node.op_type();
|
||||
if (OpName == "Add") {
|
||||
auto op =
|
||||
builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
|
||||
frontend_symbols_.AddMapping(legalize_name(node.output()[0]), op.getResult());
|
||||
return;
|
||||
}
|
||||
|
||||
// Old way of doing things.
|
||||
mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type());
|
||||
for (auto item : node.output()) {
|
||||
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||
}
|
||||
|
@ -158,17 +179,17 @@ class SGIRGenImpl {
|
|||
auto op = builder_.createOperation(result);
|
||||
for (int i = 0; i < node.output().size(); i++) {
|
||||
auto r = op->getResult(i);
|
||||
sgir_symbols_.AddMapping(legalize_name(node.output()[i]), r);
|
||||
frontend_symbols_.AddMapping(legalize_name(node.output()[i]), r);
|
||||
}
|
||||
|
||||
// TODO more info from node: attributes
|
||||
}
|
||||
|
||||
void ImportOutputTensor(onnx::ValueInfoProto& output) {
|
||||
if (sgir_symbols_.ContainKey(legalize_name(output.name()))) {
|
||||
mlir::OperationState result(UnknownLoc(), "sgir.output " + output.name());
|
||||
if (frontend_symbols_.ContainKey(legalize_name(output.name()))) {
|
||||
mlir::OperationState result(UnknownLoc(), "frontend.output " + output.name());
|
||||
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||
result.addOperands(sgir_symbols_.GetTensorByOnnxName(output.name()));
|
||||
result.addOperands(frontend_symbols_.GetTensorByOnnxName(output.name()));
|
||||
builder_.createOperation(result);
|
||||
} else {
|
||||
// TODO: Why not in the symbol table? something is wrong
|
||||
|
@ -209,27 +230,28 @@ class SGIRGenImpl {
|
|||
}
|
||||
}
|
||||
|
||||
}; // SGIRGenImpl class
|
||||
}; // FrontendGenImpl class
|
||||
|
||||
} // namespace
|
||||
} // namespace onnf
|
||||
|
||||
namespace onnf {
|
||||
|
||||
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model) {
|
||||
mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
|
||||
mlir::MLIRContext context;
|
||||
SGIRGenImpl mySGIRGen(context);
|
||||
auto module = mySGIRGen.ImportModel(model);
|
||||
FrontendGenImpl myONNXGen(context);
|
||||
auto module = myONNXGen.ImportONNXModel(model);
|
||||
module.dump();
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef SGIRImportModelFile(std::string model_fname) {
|
||||
mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname) {
|
||||
onnx::ModelProto model;
|
||||
std::fstream input(model_fname, std::ios::in | std::ios::binary);
|
||||
|
||||
auto parse_success = model.ParseFromIstream(&input);
|
||||
return SGIRImportModel(model);
|
||||
|
||||
return ImportFrontendModel(model);
|
||||
}
|
||||
} // namespace onnf
|
|
@ -0,0 +1,49 @@
|
|||
//===- frontend_dialect_transformer.hpp - MLIR Operations -----------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "onnx/onnx_pb.h"
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class OwningModuleRef;
|
||||
} // namespace mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Import a model into one of ONNF's frontend models.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace onnf {
|
||||
/*!
|
||||
* Import an ONNX model into ONNF's ONNX Dialect.
|
||||
* @param model onnx model.
|
||||
* @return MLIR::module generated for the ONNX model.
|
||||
*/
|
||||
mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
|
||||
|
||||
/*!
|
||||
* Import an ONNX model file into ONNF's ONNX Dialect.
|
||||
* @param model_fname file name pointing to the onnx model protobuf.
|
||||
* @return MLIR::module generated for the ONNX model.
|
||||
*/
|
||||
mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname);
|
||||
|
||||
/*!
|
||||
* TODO: Import models into other extension dialects that cover the
|
||||
* operations specific to other frameworks such as Tensorflow or Pytorch.
|
||||
*/
|
||||
} // namespace onnf
|
|
@ -1,40 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "onnx/onnx_pb.h"
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class OwningModuleRef;
|
||||
} // namespace mlir
|
||||
|
||||
namespace onnf {
|
||||
/*!
|
||||
* Import an ONNX Model into SGIR.
|
||||
* @param model onnx model.
|
||||
* @return MLIR::module generated for the ONNX model.
|
||||
*/
|
||||
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model);
|
||||
|
||||
/*!
|
||||
* Import an ONNX Model file into SGIR.
|
||||
* @param model_fname file name pointing to the onnx model protobuf.
|
||||
* @return MLIR::module generated for the ONNX model.
|
||||
*/
|
||||
mlir::OwningModuleRef SGIRImportModelFile(std::string model_fname);
|
||||
} // namespace onnf
|
|
@ -1,11 +1,14 @@
|
|||
add_library(
|
||||
compiler
|
||||
ir/knl/knl_ops.cpp
|
||||
ir/knl/knl_ops.hpp)
|
||||
ir/knl/knl_ops.hpp
|
||||
dialect/onnx/onnx_ops.cpp
|
||||
dialect/onnx/onnx_ops.hpp)
|
||||
|
||||
# Include root src directory.
|
||||
target_include_directories(compiler PRIVATE ../..)
|
||||
target_include_directories(compiler PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
|
||||
target_include_directories(compiler PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_include_directories(compiler PRIVATE ${CMAKE_BINARY_DIR})
|
||||
|
||||
find_package(Boost 1.54.0
|
||||
COMPONENTS graph
|
||||
|
@ -26,3 +29,9 @@ onnf_tablegen(knl.hpp.inc -gen-op-decls)
|
|||
onnf_tablegen(knl.cpp.inc -gen-op-defs)
|
||||
add_public_tablegen_target(gen_kir)
|
||||
add_dependencies(compiler gen_kir)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||
onnf_tablegen(onnx.hpp.inc -gen-op-decls)
|
||||
onnf_tablegen(onnx.cpp.inc -gen-op-defs)
|
||||
add_public_tablegen_target(gen_onnx)
|
||||
add_dependencies(compiler gen_onnx)
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
//===- ONNXOps.td -- ONNX operation definitions ---------*- tablegen -*----===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines MLIR ONNX operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef ONNX_OPS
|
||||
#else
|
||||
#define ONNX_OPS
|
||||
|
||||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def ONNX_Dialect : Dialect {
|
||||
let name = "onnx";
|
||||
let cppNamespace = "";
|
||||
}
|
||||
|
||||
// Base class for ONNX dialect operations. This operation inherits from the base
|
||||
// `Op` class in OpBase.td, and provides:
|
||||
// * The parent dialect of the operation.
|
||||
// * The mnemonic for the operation, or the name without the dialect prefix.
|
||||
// * A list of traits for the operation.
|
||||
class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<ONNX_Dialect, mnemonic, traits>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNX Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// We define an ONNX operation for adding two tensors elementwise.
|
||||
def ONNXAddOp: ONNX_Op<"add", [NoSideEffect]> {
|
||||
let summary = "ONNX add operation";
|
||||
let description = [{
|
||||
|
||||
The "onnx.add" adds two tensors element-wise.
|
||||
|
||||
}];
|
||||
|
||||
// TODO: AnyTensor might be too wide for ONNX and may need to be constrained
|
||||
// to fewer valid types.
|
||||
// In the ONNX spec:
|
||||
// T : tensor(uint32), tensor(uint64),
|
||||
// tensor(int32), tensor(int64),
|
||||
// tensor(float16), tensor(float), tensor(double)
|
||||
//
|
||||
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in);
|
||||
let results = (outs AnyTensor);
|
||||
|
||||
// Build an ONNX Add operation using two input operands.
|
||||
let builders = [
|
||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
||||
buildONNXAddOp(b, result, lhs, rhs);
|
||||
}]
|
||||
>];
|
||||
}
|
||||
|
||||
#endif // ONNX_OPS
|
|
@ -0,0 +1,54 @@
|
|||
//===- onnx_ops.cpp - MLIR ONNX Operations --------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file defines ONNX operations in the MLIR operation set.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "onnx_ops.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXOpsDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Dialect creation, the instance will be owned by the context. This is the
|
||||
/// point of registration of custom types and operations for the dialect.
|
||||
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
|
||||
: mlir::Dialect(getDialectNamespace(), ctx) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "src/compiler/onnx.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNX Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void buildONNXAddOp(mlir::Builder* builder, mlir::OperationState& state,
|
||||
mlir::Value* lhs, mlir::Value* rhs) {
|
||||
state.addTypes(UnrankedTensorType::get(builder->getF32Type()));
|
||||
state.addOperands({lhs, rhs});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/compiler/onnx.cpp.inc"
|
|
@ -0,0 +1,39 @@
|
|||
//===- onnx_ops.hpp - MLIR ONNX Operations --------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file defines ONNX operations in the MLIR operation set.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_ONNX_ONNXOPS_H
|
||||
#define MLIR_DIALECT_ONNX_ONNXOPS_H
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ONNXOpsDialect : public Dialect {
|
||||
public:
|
||||
ONNXOpsDialect(MLIRContext* context);
|
||||
|
||||
/// Provide a utility accessor to the dialect namespace. This is used by
|
||||
/// several utilities for casting between dialects.
|
||||
static StringRef getDialectNamespace() { return "onnx"; }
|
||||
};
|
||||
|
||||
/// Include the auto-generated header file containing the declarations of the
|
||||
/// ONNX operations.
|
||||
#define GET_OP_CLASSES
|
||||
#include "src/compiler/onnx.hpp.inc"
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
namespace onnf {}
|
||||
|
||||
#endif // MLIR_DIALECT_ONNX_ONNXOPS_H
|
|
@ -15,7 +15,7 @@ KnlOpsDialect::KnlOpsDialect(MLIRContext* context)
|
|||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "knl.cpp.inc"
|
||||
#include "src/compiler/knl.cpp.inc"
|
||||
>();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -13,7 +13,7 @@ class KnlOpsDialect : public Dialect {
|
|||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "knl.hpp.inc"
|
||||
#include "src/compiler/knl.hpp.inc"
|
||||
} // namespace mlir
|
||||
|
||||
namespace onnf {}
|
||||
|
|
|
@ -20,7 +20,8 @@
|
|||
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "src/builder/sgir.hpp"
|
||||
#include "src/builder/frontend_dialect_transformer.hpp"
|
||||
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||
|
||||
#include "mlir/IR/Module.h"
|
||||
|
||||
|
@ -45,8 +46,10 @@ int main(int ac, char* av[]) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||
|
||||
string model_filename = vm["onnx-model"].as<string>();
|
||||
auto module = SGIRImportModelFile(model_filename);
|
||||
auto module = ImportFrontendModelFile(model_filename);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue