ONNX Dialect creation and registration with MLIR (#358)
* Create and register ONNX Dialect with one Add operation. * Fix file formatting. * Change name from ONNX to SGIR. * Use ONNX dialect. Change SGIR to frontend placeholder dialect. * Add comments. * Type clean-up.
This commit is contained in:
parent
b5a35c9138
commit
958a2fbffc
|
@ -1,10 +1,5 @@
|
||||||
|
|
||||||
add_executable(onnf main.cpp)
|
add_executable(onnf main.cpp)
|
||||||
|
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
||||||
target_include_directories(onnf PRIVATE ..)
|
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
||||||
|
target_link_libraries(onnf builder compiler ${Boost_LIBRARIES})
|
||||||
target_link_libraries(onnf
|
|
||||||
builder
|
|
||||||
${Boost_LIBRARIES}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
add_library(builder
|
add_library(builder
|
||||||
sgir.cpp
|
frontend_dialect_transformer.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
target_link_libraries(builder onnx ${MLIRLIBS} curses)
|
target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR})
|
||||||
|
target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR})
|
||||||
|
target_link_libraries(builder compiler onnx ${MLIRLIBS} curses)
|
||||||
target_include_directories(builder
|
target_include_directories(builder
|
||||||
PRIVATE
|
PRIVATE
|
||||||
${CMAKE_SOURCE_DIR}/third_party/onnx
|
${CMAKE_SOURCE_DIR}/third_party/onnx
|
||||||
|
|
|
@ -1,9 +1,17 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===- frontend_dialect_transformer.cpp - MLIR Operations -----------------===//
|
||||||
//
|
//
|
||||||
// Copyright 2019 The IBM Research Authors.
|
// Copyright 2019 The IBM Research Authors.
|
||||||
//
|
//
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
//
|
//
|
||||||
|
// This file transforms the input to available MLIR dialects that can represent
|
||||||
|
// the operations of the model. Models use the ONNX dialect and any other
|
||||||
|
// extension dialects that comprise the the operations not supported or covered
|
||||||
|
// by the ONNX specification.
|
||||||
|
//
|
||||||
|
// A `frontend` placeholder dialect is used to encode operations that are not
|
||||||
|
// covered by any existing dialects.
|
||||||
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
@ -26,7 +34,8 @@
|
||||||
#include "llvm/ADT/ScopedHashTable.h"
|
#include "llvm/ADT/ScopedHashTable.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
#include "sgir.hpp"
|
#include "frontend_dialect_transformer.hpp"
|
||||||
|
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||||
|
|
||||||
namespace onnf {
|
namespace onnf {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -84,14 +93,14 @@ struct OnnxOnnfSymbolMapping {
|
||||||
std::map<std::string, mlir::Value*> onnx_name2onnf_tensor;
|
std::map<std::string, mlir::Value*> onnx_name2onnf_tensor;
|
||||||
};
|
};
|
||||||
|
|
||||||
class SGIRGenImpl {
|
class FrontendGenImpl {
|
||||||
public:
|
public:
|
||||||
SGIRGenImpl(mlir::MLIRContext& context)
|
FrontendGenImpl(mlir::MLIRContext& context)
|
||||||
: context_(context), builder_(&context) {
|
: context_(context), builder_(&context) {
|
||||||
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::ModuleOp ImportModel(onnx::ModelProto model) {
|
mlir::ModuleOp ImportONNXModel(onnx::ModelProto model) {
|
||||||
ImportGraph(model.graph());
|
ImportGraph(model.graph());
|
||||||
return module_;
|
return module_;
|
||||||
}
|
}
|
||||||
|
@ -101,7 +110,7 @@ class SGIRGenImpl {
|
||||||
mlir::ModuleOp module_;
|
mlir::ModuleOp module_;
|
||||||
mlir::OpBuilder builder_;
|
mlir::OpBuilder builder_;
|
||||||
// mapping between string name and symbol
|
// mapping between string name and symbol
|
||||||
OnnxOnnfSymbolMapping sgir_symbols_;
|
OnnxOnnfSymbolMapping frontend_symbols_;
|
||||||
|
|
||||||
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
|
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
|
||||||
|
|
||||||
|
@ -127,17 +136,17 @@ class SGIRGenImpl {
|
||||||
dims.push_back(-1);
|
dims.push_back(-1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!sgir_symbols_.ContainKey(input_tensor_legalized_name)) {
|
if (!frontend_symbols_.ContainKey(input_tensor_legalized_name)) {
|
||||||
mlir::Type elementType =
|
mlir::Type elementType =
|
||||||
TypeConvert(input.type().tensor_type().elem_type());
|
TypeConvert(input.type().tensor_type().elem_type());
|
||||||
llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size());
|
llvm::ArrayRef<int64_t> llvmdimsAR(dims.data(), dims.size());
|
||||||
auto dataType = mlir::RankedTensorType::get(llvmdimsAR, elementType);
|
auto dataType = mlir::RankedTensorType::get(llvmdimsAR, elementType);
|
||||||
mlir::OperationState result(
|
mlir::OperationState result(
|
||||||
UnknownLoc(), "sgir.input " + input_tensor_legalized_name);
|
UnknownLoc(), "frontend.input " + input_tensor_legalized_name);
|
||||||
result.addTypes(dataType);
|
result.addTypes(dataType);
|
||||||
auto op = builder_.createOperation(result);
|
auto op = builder_.createOperation(result);
|
||||||
auto value = op->getResult(0);
|
auto value = op->getResult(0);
|
||||||
sgir_symbols_.AddMapping(input_tensor_legalized_name, value);
|
frontend_symbols_.AddMapping(input_tensor_legalized_name, value);
|
||||||
} else {
|
} else {
|
||||||
// TODO Should not happen
|
// TODO Should not happen
|
||||||
}
|
}
|
||||||
|
@ -146,11 +155,23 @@ class SGIRGenImpl {
|
||||||
void ImportNode(onnx::NodeProto node) {
|
void ImportNode(onnx::NodeProto node) {
|
||||||
std::vector<mlir::Value*> inputs;
|
std::vector<mlir::Value*> inputs;
|
||||||
for (auto item : node.input()) {
|
for (auto item : node.input()) {
|
||||||
if (sgir_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
inputs.push_back(sgir_symbols_.GetTensorByOnnxName(item));
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mlir::OperationState result(UnknownLoc(), "SGIR." + node.op_type());
|
|
||||||
|
// Handle ONNX Add Operation by using its representation in the
|
||||||
|
// ONNX Dialect.
|
||||||
|
llvm::StringRef OpName = node.op_type();
|
||||||
|
if (OpName == "Add") {
|
||||||
|
auto op =
|
||||||
|
builder_.create<mlir::ONNXAddOp>(UnknownLoc(), inputs[0], inputs[1]);
|
||||||
|
frontend_symbols_.AddMapping(legalize_name(node.output()[0]), op.getResult());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Old way of doing things.
|
||||||
|
mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type());
|
||||||
for (auto item : node.output()) {
|
for (auto item : node.output()) {
|
||||||
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||||
}
|
}
|
||||||
|
@ -158,17 +179,17 @@ class SGIRGenImpl {
|
||||||
auto op = builder_.createOperation(result);
|
auto op = builder_.createOperation(result);
|
||||||
for (int i = 0; i < node.output().size(); i++) {
|
for (int i = 0; i < node.output().size(); i++) {
|
||||||
auto r = op->getResult(i);
|
auto r = op->getResult(i);
|
||||||
sgir_symbols_.AddMapping(legalize_name(node.output()[i]), r);
|
frontend_symbols_.AddMapping(legalize_name(node.output()[i]), r);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO more info from node: attributes
|
// TODO more info from node: attributes
|
||||||
}
|
}
|
||||||
|
|
||||||
void ImportOutputTensor(onnx::ValueInfoProto& output) {
|
void ImportOutputTensor(onnx::ValueInfoProto& output) {
|
||||||
if (sgir_symbols_.ContainKey(legalize_name(output.name()))) {
|
if (frontend_symbols_.ContainKey(legalize_name(output.name()))) {
|
||||||
mlir::OperationState result(UnknownLoc(), "sgir.output " + output.name());
|
mlir::OperationState result(UnknownLoc(), "frontend.output " + output.name());
|
||||||
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||||
result.addOperands(sgir_symbols_.GetTensorByOnnxName(output.name()));
|
result.addOperands(frontend_symbols_.GetTensorByOnnxName(output.name()));
|
||||||
builder_.createOperation(result);
|
builder_.createOperation(result);
|
||||||
} else {
|
} else {
|
||||||
// TODO: Why not in the symbol table? something is wrong
|
// TODO: Why not in the symbol table? something is wrong
|
||||||
|
@ -209,27 +230,28 @@ class SGIRGenImpl {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}; // SGIRGenImpl class
|
}; // FrontendGenImpl class
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace onnf
|
} // namespace onnf
|
||||||
|
|
||||||
namespace onnf {
|
namespace onnf {
|
||||||
|
|
||||||
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model) {
|
mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
SGIRGenImpl mySGIRGen(context);
|
FrontendGenImpl myONNXGen(context);
|
||||||
auto module = mySGIRGen.ImportModel(model);
|
auto module = myONNXGen.ImportONNXModel(model);
|
||||||
module.dump();
|
module.dump();
|
||||||
|
|
||||||
return module;
|
return module;
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::OwningModuleRef SGIRImportModelFile(std::string model_fname) {
|
mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname) {
|
||||||
onnx::ModelProto model;
|
onnx::ModelProto model;
|
||||||
std::fstream input(model_fname, std::ios::in | std::ios::binary);
|
std::fstream input(model_fname, std::ios::in | std::ios::binary);
|
||||||
|
|
||||||
auto parse_success = model.ParseFromIstream(&input);
|
auto parse_success = model.ParseFromIstream(&input);
|
||||||
return SGIRImportModel(model);
|
|
||||||
|
return ImportFrontendModel(model);
|
||||||
}
|
}
|
||||||
} // namespace onnf
|
} // namespace onnf
|
|
@ -0,0 +1,49 @@
|
||||||
|
//===- frontend_dialect_transformer.hpp - MLIR Operations -----------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnx/onnx_pb.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class MLIRContext;
|
||||||
|
class OwningModuleRef;
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Import a model into one of ONNF's frontend models.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace onnf {
|
||||||
|
/*!
|
||||||
|
* Import an ONNX model into ONNF's ONNX Dialect.
|
||||||
|
* @param model onnx model.
|
||||||
|
* @return MLIR::module generated for the ONNX model.
|
||||||
|
*/
|
||||||
|
mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Import an ONNX model file into ONNF's ONNX Dialect.
|
||||||
|
* @param model_fname file name pointing to the onnx model protobuf.
|
||||||
|
* @return MLIR::module generated for the ONNX model.
|
||||||
|
*/
|
||||||
|
mlir::OwningModuleRef ImportFrontendModelFile(std::string model_fname);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* TODO: Import models into other extension dialects that cover the
|
||||||
|
* operations specific to other frameworks such as Tensorflow or Pytorch.
|
||||||
|
*/
|
||||||
|
} // namespace onnf
|
|
@ -1,40 +0,0 @@
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
//
|
|
||||||
// Copyright 2019 The IBM Research Authors.
|
|
||||||
//
|
|
||||||
// =============================================================================
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <fstream>
|
|
||||||
#include <functional>
|
|
||||||
#include <map>
|
|
||||||
#include <memory>
|
|
||||||
#include <sstream>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "onnx/onnx_pb.h"
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
class MLIRContext;
|
|
||||||
class OwningModuleRef;
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
namespace onnf {
|
|
||||||
/*!
|
|
||||||
* Import an ONNX Model into SGIR.
|
|
||||||
* @param model onnx model.
|
|
||||||
* @return MLIR::module generated for the ONNX model.
|
|
||||||
*/
|
|
||||||
mlir::OwningModuleRef SGIRImportModel(onnx::ModelProto model);
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* Import an ONNX Model file into SGIR.
|
|
||||||
* @param model_fname file name pointing to the onnx model protobuf.
|
|
||||||
* @return MLIR::module generated for the ONNX model.
|
|
||||||
*/
|
|
||||||
mlir::OwningModuleRef SGIRImportModelFile(std::string model_fname);
|
|
||||||
} // namespace onnf
|
|
|
@ -1,11 +1,14 @@
|
||||||
add_library(
|
add_library(
|
||||||
compiler
|
compiler
|
||||||
ir/knl/knl_ops.cpp
|
ir/knl/knl_ops.cpp
|
||||||
ir/knl/knl_ops.hpp)
|
ir/knl/knl_ops.hpp
|
||||||
|
dialect/onnx/onnx_ops.cpp
|
||||||
|
dialect/onnx/onnx_ops.hpp)
|
||||||
|
|
||||||
# Include root src directory.
|
# Include root src directory.
|
||||||
target_include_directories(compiler PRIVATE ../..)
|
target_include_directories(compiler PRIVATE ../..)
|
||||||
target_include_directories(compiler PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
|
target_include_directories(compiler PRIVATE ${CMAKE_SOURCE_DIR})
|
||||||
|
target_include_directories(compiler PRIVATE ${CMAKE_BINARY_DIR})
|
||||||
|
|
||||||
find_package(Boost 1.54.0
|
find_package(Boost 1.54.0
|
||||||
COMPONENTS graph
|
COMPONENTS graph
|
||||||
|
@ -26,3 +29,9 @@ onnf_tablegen(knl.hpp.inc -gen-op-decls)
|
||||||
onnf_tablegen(knl.cpp.inc -gen-op-defs)
|
onnf_tablegen(knl.cpp.inc -gen-op-defs)
|
||||||
add_public_tablegen_target(gen_kir)
|
add_public_tablegen_target(gen_kir)
|
||||||
add_dependencies(compiler gen_kir)
|
add_dependencies(compiler gen_kir)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||||
|
onnf_tablegen(onnx.hpp.inc -gen-op-decls)
|
||||||
|
onnf_tablegen(onnx.cpp.inc -gen-op-defs)
|
||||||
|
add_public_tablegen_target(gen_onnx)
|
||||||
|
add_dependencies(compiler gen_onnx)
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
//===- ONNXOps.td -- ONNX operation definitions ---------*- tablegen -*----===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// Defines MLIR ONNX operations.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifdef ONNX_OPS
|
||||||
|
#else
|
||||||
|
#define ONNX_OPS
|
||||||
|
|
||||||
|
#ifdef OP_BASE
|
||||||
|
#else
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
def ONNX_Dialect : Dialect {
|
||||||
|
let name = "onnx";
|
||||||
|
let cppNamespace = "";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base class for ONNX dialect operations. This operation inherits from the base
|
||||||
|
// `Op` class in OpBase.td, and provides:
|
||||||
|
// * The parent dialect of the operation.
|
||||||
|
// * The mnemonic for the operation, or the name without the dialect prefix.
|
||||||
|
// * A list of traits for the operation.
|
||||||
|
class ONNX_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
|
Op<ONNX_Dialect, mnemonic, traits>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNX Operations
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// We define an ONNX operation for adding two tensors elementwise.
|
||||||
|
def ONNXAddOp: ONNX_Op<"add", [NoSideEffect]> {
|
||||||
|
let summary = "ONNX add operation";
|
||||||
|
let description = [{
|
||||||
|
|
||||||
|
The "onnx.add" adds two tensors element-wise.
|
||||||
|
|
||||||
|
}];
|
||||||
|
|
||||||
|
// TODO: AnyTensor might be too wide for ONNX and may need to be constrained
|
||||||
|
// to fewer valid types.
|
||||||
|
// In the ONNX spec:
|
||||||
|
// T : tensor(uint32), tensor(uint64),
|
||||||
|
// tensor(int32), tensor(int64),
|
||||||
|
// tensor(float16), tensor(float), tensor(double)
|
||||||
|
//
|
||||||
|
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in);
|
||||||
|
let results = (outs AnyTensor);
|
||||||
|
|
||||||
|
// Build an ONNX Add operation using two input operands.
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
||||||
|
buildONNXAddOp(b, result, lhs, rhs);
|
||||||
|
}]
|
||||||
|
>];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // ONNX_OPS
|
|
@ -0,0 +1,54 @@
|
||||||
|
//===- onnx_ops.cpp - MLIR ONNX Operations --------------------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file defines ONNX operations in the MLIR operation set.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/Function.h"
|
||||||
|
#include "mlir/IR/IntegerSet.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "onnx_ops.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNXOpsDialect
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Dialect creation, the instance will be owned by the context. This is the
|
||||||
|
/// point of registration of custom types and operations for the dialect.
|
||||||
|
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
|
||||||
|
: mlir::Dialect(getDialectNamespace(), ctx) {
|
||||||
|
addOperations<
|
||||||
|
#define GET_OP_LIST
|
||||||
|
#include "src/compiler/onnx.cpp.inc"
|
||||||
|
>();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ONNX Operations
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static void buildONNXAddOp(mlir::Builder* builder, mlir::OperationState& state,
|
||||||
|
mlir::Value* lhs, mlir::Value* rhs) {
|
||||||
|
state.addTypes(UnrankedTensorType::get(builder->getF32Type()));
|
||||||
|
state.addOperands({lhs, rhs});
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TableGen'd op method definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "src/compiler/onnx.cpp.inc"
|
|
@ -0,0 +1,39 @@
|
||||||
|
//===- onnx_ops.hpp - MLIR ONNX Operations --------------------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file defines ONNX operations in the MLIR operation set.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_DIALECT_ONNX_ONNXOPS_H
|
||||||
|
#define MLIR_DIALECT_ONNX_ONNXOPS_H
|
||||||
|
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
class ONNXOpsDialect : public Dialect {
|
||||||
|
public:
|
||||||
|
ONNXOpsDialect(MLIRContext* context);
|
||||||
|
|
||||||
|
/// Provide a utility accessor to the dialect namespace. This is used by
|
||||||
|
/// several utilities for casting between dialects.
|
||||||
|
static StringRef getDialectNamespace() { return "onnx"; }
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Include the auto-generated header file containing the declarations of the
|
||||||
|
/// ONNX operations.
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "src/compiler/onnx.hpp.inc"
|
||||||
|
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
namespace onnf {}
|
||||||
|
|
||||||
|
#endif // MLIR_DIALECT_ONNX_ONNXOPS_H
|
|
@ -15,7 +15,7 @@ KnlOpsDialect::KnlOpsDialect(MLIRContext* context)
|
||||||
: Dialect(getDialectNamespace(), context) {
|
: Dialect(getDialectNamespace(), context) {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "knl.cpp.inc"
|
#include "src/compiler/knl.cpp.inc"
|
||||||
>();
|
>();
|
||||||
}
|
}
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -13,7 +13,7 @@ class KnlOpsDialect : public Dialect {
|
||||||
};
|
};
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "knl.hpp.inc"
|
#include "src/compiler/knl.hpp.inc"
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
namespace onnf {}
|
namespace onnf {}
|
||||||
|
|
|
@ -20,7 +20,8 @@
|
||||||
|
|
||||||
#include <boost/program_options.hpp>
|
#include <boost/program_options.hpp>
|
||||||
|
|
||||||
#include "src/builder/sgir.hpp"
|
#include "src/builder/frontend_dialect_transformer.hpp"
|
||||||
|
#include "src/compiler/dialect/onnx/onnx_ops.hpp"
|
||||||
|
|
||||||
#include "mlir/IR/Module.h"
|
#include "mlir/IR/Module.h"
|
||||||
|
|
||||||
|
@ -45,8 +46,10 @@ int main(int ac, char* av[]) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mlir::registerDialect<mlir::ONNXOpsDialect>();
|
||||||
|
|
||||||
string model_filename = vm["onnx-model"].as<string>();
|
string model_filename = vm["onnx-model"].as<string>();
|
||||||
auto module = SGIRImportModelFile(model_filename);
|
auto module = ImportFrontendModelFile(model_filename);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue