#ifndef BUILDER_BUILDERIMPL_ #define BUILDER_BUILDERIMPL_ #include "Attribute.h" #include "AttributeImpl.h" #include "Builder.h" #include "OpImpl.h" #include "llvm/Support/Casting.h" // #include "llvm/Support/InitLLVM.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" namespace builder { class Builder::Impl { public: Impl() : builder_(&context_) { // llvm::InitLLVM y(argc, argv); // llvm::InitializeNativeTarget(); // llvm::InitializeNativeTargetAsmPrinter(); // llvm::InitializeNativeTargetAsmParser(); // mlir::initializeLLVMPasses(); // Register any command line options. // registerAsmPrinterCLOptions(); // registerMLIRContextCLOptions(); // registerPassManagerCLOptions(); // registerDefaultTimingManagerCLOptions(); // DebugCounter::registerCLOptions(); mlir::registerAllPasses(); mlir::mhlo::registerAllMhloPasses(); // mlir::lmhlo::registerAllLmhloPasses(); // mlir::disc_ral::registerAllDiscRalPasses(); mlir::DialectRegistry registry; // mlir::registerAllToLLVMIRTranslations(registry); mlir::registerAllDialects(registry); registry.insert(); // registry.insert(); // registry.insert(); // registry.insert(); // registry.insert(); context_.appendDialectRegistry(registry); context_.loadAllAvailableDialects(); module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_)); llvm::SmallVector arg_types; // Create the main function. mlir::FunctionType funcType = builder_.getFunctionType(arg_types, {}); main_func_ = mlir::FuncOp::create(builder_.getUnknownLoc(), "main", funcType, /* attrs = */ {}); entry_block_ = main_func_.addEntryBlock(); builder_.setInsertionPointToStart(entry_block_); module_.push_back(main_func_); } mlir::Location GetLoc() { return builder_.getUnknownLoc(); } mlir::OpBuilder GetBuilder() { return builder_; } mlir::MLIRContext* GetContext() { return &context_; } void DumpModule() { module_.dump(); } builder::Op CreateInput(const builder::Type& type) { mlir::BlockArgument arg = entry_block_->addArgument(type.GetImpl()->GetMlirType(&context_)); builder::Op op; op.GetImpl()->SetValue(arg); return op; } void SetOutput(const std::vector& outputs) { llvm::SmallVector arg_types; int arg_num = entry_block_->getNumArguments(); for (int i = 0; i < arg_num; ++i) { arg_types.push_back(entry_block_->getArgument(i).getType()); } llvm::SmallVector ret_types; llvm::SmallVector ret_vals; for (auto& out : outputs) { mlir::Value v = out.GetImpl()->GetResult(); ret_types.push_back(v.getType()); ret_vals.push_back(v); } // return all output tensors. builder_.create(builder_.getUnknownLoc(), ret_vals); // Update main function input/output type mlir::FunctionType funcType = builder_.getFunctionType(arg_types, ret_types); main_func_.setType(funcType); } private: mlir::MLIRContext context_; mlir::ModuleOp module_; mlir::OpBuilder builder_; mlir::FuncOp main_func_; mlir::Block* entry_block_; }; } // namespace builder #endif