add integer attribute convert
This commit is contained in:
		
							parent
							
								
									9d5166684c
								
							
						
					
					
						commit
						95fc37ffa8
					
				
							
								
								
									
										2
									
								
								BUILD
								
								
								
								
							
							
						
						
									
										2
									
								
								BUILD
								
								
								
								
							|  | @ -129,6 +129,8 @@ cc_binary( | |||
|         "tools/mlir-tblgen-builder/*.cpp", | ||||
|         "tools/mlir-tblgen-builder/TableGen/*.h", | ||||
|         "tools/mlir-tblgen-builder/TableGen/*.cpp", | ||||
|         "tools/mlir-tblgen-builder/Builder/*.h", | ||||
|         "tools/mlir-tblgen-builder/Builder/*.cpp", | ||||
|     ]), | ||||
|     deps = [ | ||||
|         "@llvm-project//mlir:MlirTableGenMain", | ||||
|  |  | |||
|  | @ -14,8 +14,25 @@ limitations under the License. | |||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "iostream" | ||||
| 
 | ||||
| #include "mlir/Support/MlirOptMain.h" | ||||
| #include "llvm/Support/CommandLine.h" | ||||
| #include "llvm/Support/DynamicLibrary.h" | ||||
| #include "llvm/Support/FileUtilities.h" | ||||
| #include "llvm/Support/InitLLVM.h" | ||||
| #include "llvm/Support/Regex.h" | ||||
| #include "llvm/Support/SourceMgr.h" | ||||
| #include "llvm/Support/StringSaver.h" | ||||
| #include "llvm/Support/ToolOutputFile.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" | ||||
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||||
| #include "mlir/ExecutionEngine/ExecutionEngine.h" | ||||
| #include "mlir/ExecutionEngine/JitRunner.h" | ||||
| #include "mlir/ExecutionEngine/OptUtils.h" | ||||
| #include "mlir/IR/AsmState.h" | ||||
| #include "mlir/IR/Attributes.h" | ||||
| #include "mlir/IR/BuiltinOps.h" | ||||
|  | @ -23,62 +40,31 @@ limitations under the License. | |||
| #include "mlir/IR/Dialect.h" | ||||
| #include "mlir/IR/Location.h" | ||||
| #include "mlir/IR/MLIRContext.h" | ||||
| #include "mlir/InitAllDialects.h" | ||||
| #include "mlir/InitAllPasses.h" | ||||
| #include "mlir/Parser.h" | ||||
| #include "mlir/Pass/Pass.h" | ||||
| #include "mlir/Pass/PassManager.h" | ||||
| #include "mlir/Support/DebugCounter.h" | ||||
| #include "mlir/Support/FileUtilities.h" | ||||
| #include "mlir/Support/MlirOptMain.h" | ||||
| #include "mlir/Support/Timing.h" | ||||
| #include "mlir/Support/ToolUtilities.h" | ||||
| #include "llvm/Support/CommandLine.h" | ||||
| #include "llvm/Support/FileUtilities.h" | ||||
| #include "llvm/Support/InitLLVM.h" | ||||
| #include "llvm/Support/Regex.h" | ||||
| #include "llvm/Support/SourceMgr.h" | ||||
| #include "llvm/Support/StringSaver.h" | ||||
| #include "llvm/Support/ToolOutputFile.h" | ||||
| #include "llvm/Support/DynamicLibrary.h" | ||||
| 
 | ||||
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||||
| #include "mlir/ExecutionEngine/JitRunner.h" | ||||
| #include "mlir/ExecutionEngine/OptUtils.h" | ||||
| #include "mlir/Target/LLVMIR/Dialect/All.h" | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| #include "mlir/InitAllDialects.h" | ||||
| #include "mlir/InitAllPasses.h" | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" | ||||
| 
 | ||||
| 
 | ||||
| #include "mlir/ExecutionEngine/ExecutionEngine.h" | ||||
| // #include "mlir/IR/BuiltinTypes.h"
 | ||||
| 
 | ||||
| 
 | ||||
| #include "llvm/Support/TargetSelect.h" | ||||
| #include "llvm/ADT/STLExtras.h" | ||||
| #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" | ||||
| #include "llvm/ExecutionEngine/Orc/Mangling.h" | ||||
| 
 | ||||
| 
 | ||||
| #include "llvm/ADT/STLExtras.h" | ||||
| #include "llvm/IR/IRBuilder.h" | ||||
| #include "llvm/IR/LLVMContext.h" | ||||
| #include "llvm/IR/LegacyPassNameParser.h" | ||||
| 
 | ||||
| #include "llvm/Support/TargetSelect.h" | ||||
| 
 | ||||
| using namespace mlir; | ||||
| using namespace llvm; | ||||
| 
 | ||||
| namespace utils{ | ||||
| namespace utils { | ||||
| template <typename T, int N> | ||||
| struct MemRefDescriptor { | ||||
|   T *allocated; | ||||
|  | @ -87,11 +73,9 @@ struct MemRefDescriptor { | |||
|   int64_t sizes[N]; | ||||
|   int64_t strides[N]; | ||||
| }; | ||||
| } | ||||
| 
 | ||||
| }  // namespace utils
 | ||||
| 
 | ||||
| int main(int argc, char **argv) { | ||||
| 
 | ||||
|   llvm::InitLLVM y(argc, argv); | ||||
|   llvm::InitializeNativeTarget(); | ||||
|   llvm::InitializeNativeTargetAsmPrinter(); | ||||
|  | @ -105,7 +89,6 @@ int main(int argc, char **argv) { | |||
|   // registerDefaultTimingManagerCLOptions();
 | ||||
|   DebugCounter::registerCLOptions(); | ||||
| 
 | ||||
| 
 | ||||
|   mlir::registerAllPasses(); | ||||
|   mlir::mhlo::registerAllMhloPasses(); | ||||
|   mlir::lmhlo::registerAllLmhloPasses(); | ||||
|  | @ -120,24 +103,22 @@ int main(int argc, char **argv) { | |||
|   registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>(); | ||||
|   registry.insert<mlir::disc_ral::RalDialect>(); | ||||
| 
 | ||||
|    | ||||
|   // failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
 | ||||
|   //                                 registry,
 | ||||
|   //                                 /*preloadDialectsInContext=*/false));
 | ||||
|   // return 0;
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|   std::string errorMessage; | ||||
| 
 | ||||
|   // auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir", &errorMessage);
 | ||||
|   auto file = mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage); | ||||
|   std::cout<<"errorMessage:" <<errorMessage <<std::endl; | ||||
|   // auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir",
 | ||||
|   // &errorMessage);
 | ||||
|   auto file = | ||||
|       mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage); | ||||
|   std::cout << "errorMessage:" << errorMessage << std::endl; | ||||
| 
 | ||||
|   SourceMgr sourceMgr; | ||||
|   sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); | ||||
| 
 | ||||
| 
 | ||||
|   SmallVector<const llvm::PassInfo *, 4> passes; | ||||
|   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); | ||||
|   auto tmOrError = tmBuilderOrError->createTargetMachine(); | ||||
|  | @ -145,8 +126,6 @@ int main(int argc, char **argv) { | |||
|   auto transformer = mlir::makeLLVMPassesTransformer( | ||||
|       passes, 0, /*targetMachine=*/tmOrError->get(), 0); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|   MLIRContext context(registry); | ||||
|   context.loadAllAvailableDialects(); | ||||
|   OwningModuleRef module(parseSourceFile(sourceMgr, &context)); | ||||
|  | @ -157,8 +136,7 @@ int main(int argc, char **argv) { | |||
|     return failure(); | ||||
|   }; | ||||
| 
 | ||||
| 
 | ||||
| // RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \ | ||||
|   // RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \ | ||||
| // RUN: -hlo-legalize-to-lhlo -buffer-hoisting \ | ||||
| // RUN: -buffer-deallocation -canonicalize -cse \ | ||||
| // RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \ | ||||
|  | @ -192,9 +170,7 @@ int main(int argc, char **argv) { | |||
|   pm.run(*module); | ||||
|   module->dump(); | ||||
| 
 | ||||
| 
 | ||||
|    | ||||
|   std::cout<<"DEBUG load module success"<<std::endl; | ||||
|   std::cout << "DEBUG load module success" << std::endl; | ||||
| 
 | ||||
|   llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default; | ||||
| 
 | ||||
|  | @ -206,12 +182,10 @@ int main(int argc, char **argv) { | |||
|   // Use absolute library path so that gdb can find the symbol table.
 | ||||
| 
 | ||||
|   std::list<std::string> sharedlib; | ||||
|   sharedlib.push_back("/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git"); | ||||
|   sharedlib.push_back( | ||||
|       "/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git"); | ||||
| 
 | ||||
|   transform( | ||||
|       sharedlib, | ||||
|       std::back_inserter(libPaths), | ||||
|       [](std::string libPath) { | ||||
|   transform(sharedlib, std::back_inserter(libPaths), [](std::string libPath) { | ||||
|     SmallString<256> absPath(libPath.begin(), libPath.end()); | ||||
|     cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath))); | ||||
|     return absPath; | ||||
|  | @ -226,7 +200,6 @@ int main(int argc, char **argv) { | |||
|   llvm::StringMap<void *> exportSymbols; | ||||
|   SmallVector<MlirRunnerDestroyFn> destroyFns; | ||||
| 
 | ||||
| 
 | ||||
|   // Handle libraries that do support mlir-runner init/destroy callbacks.
 | ||||
|   for (auto &libPath : libPaths) { | ||||
|     auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str()); | ||||
|  | @ -255,17 +228,12 @@ int main(int argc, char **argv) { | |||
|     return symbolMap; | ||||
|   }; | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|   auto expectedEngine = mlir::ExecutionEngine::create( | ||||
|       module.get(), nullptr, transformer, jitCodeGenOptLevel, | ||||
|       executionEngineLibs); | ||||
|   auto expectedEngine = | ||||
|       mlir::ExecutionEngine::create(module.get(), nullptr, transformer, | ||||
|                                     jitCodeGenOptLevel, executionEngineLibs); | ||||
|   // if (!expectedEngine)
 | ||||
|   //   return expectedEngine.takeError();
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|   auto engine = std::move(*expectedEngine); | ||||
|   engine->registerSymbols(runtimeSymbolMap); | ||||
| 
 | ||||
|  | @ -277,16 +245,15 @@ int main(int argc, char **argv) { | |||
|   // if (options.dumpObjectFile)
 | ||||
|   // engine->dumpToObjectFile("a.o");
 | ||||
| 
 | ||||
| 
 | ||||
|   float rawdata[6] = {0,1,2,3,4,5}; | ||||
|   float rawdata[6] = {0, 1, 2, 3, 4, 5}; | ||||
|   int64_t dims = 1; | ||||
|   utils::MemRefDescriptor<float,1> a{rawdata,rawdata,0,{6},{1}}; | ||||
|   utils::MemRefDescriptor<float,1> b{rawdata,rawdata,0,{6},{1}}; | ||||
|   utils::MemRefDescriptor<float,1> result_memref; | ||||
|   utils::MemRefDescriptor<float, 1> a{rawdata, rawdata, 0, {6}, {1}}; | ||||
|   utils::MemRefDescriptor<float, 1> b{rawdata, rawdata, 0, {6}, {1}}; | ||||
|   utils::MemRefDescriptor<float, 1> result_memref; | ||||
| 
 | ||||
|   struct memref_type{ | ||||
|   struct memref_type { | ||||
|     int64_t res_size = 6; | ||||
|     utils::MemRefDescriptor<float,1> *memref; | ||||
|     utils::MemRefDescriptor<float, 1> *memref; | ||||
|   } result; | ||||
|   result.memref = &result_memref; | ||||
| 
 | ||||
|  | @ -299,23 +266,23 @@ int main(int argc, char **argv) { | |||
|   } data; | ||||
| 
 | ||||
|   data.data1_size = &dims; | ||||
|   void * a_ptr = &a; | ||||
|   void *a_ptr = &a; | ||||
|   data.data1 = &a_ptr; | ||||
|   data.data2_size = &dims; | ||||
|   void * b_ptr = &b; | ||||
|   void *b_ptr = &b; | ||||
|   data.data2 = &b_ptr; | ||||
|   void * result_ptr = &result; | ||||
|   void *result_ptr = &result; | ||||
|   data.res = &result; | ||||
| 
 | ||||
|   void (*fptr)(void **) = *expectedFPtr; | ||||
|   (*fptr)((void **)&data); | ||||
| 
 | ||||
|   std::cout<<"result: "<<result.memref->allocated[0]<<std::endl; | ||||
|   std::cout<<"result: "<<result.memref->allocated[1]<<std::endl; | ||||
|   std::cout<<"result: "<<result.memref->allocated[2]<<std::endl; | ||||
|   std::cout<<"result: "<<result.memref->allocated[3]<<std::endl; | ||||
|   std::cout<<"result: "<<result.memref->allocated[4]<<std::endl; | ||||
|   std::cout<<"result: "<<result.memref->allocated[5]<<std::endl; | ||||
|   std::cout << "result: " << result.memref->allocated[0] << std::endl; | ||||
|   std::cout << "result: " << result.memref->allocated[1] << std::endl; | ||||
|   std::cout << "result: " << result.memref->allocated[2] << std::endl; | ||||
|   std::cout << "result: " << result.memref->allocated[3] << std::endl; | ||||
|   std::cout << "result: " << result.memref->allocated[4] << std::endl; | ||||
|   std::cout << "result: " << result.memref->allocated[5] << std::endl; | ||||
| 
 | ||||
|   // Run all dynamic library destroy callbacks to prepare for the shutdown.
 | ||||
|   llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); }); | ||||
|  |  | |||
|  | @ -1,10 +0,0 @@ | |||
| #ifndef BUILDER_ARRAY_ | ||||
| #define BUILDER_ARRAY_ | ||||
| 
 | ||||
| #include "iostream" | ||||
| 
 | ||||
| namespace builder { | ||||
| class Array {} | ||||
| }  // namespace builder
 | ||||
| 
 | ||||
| #endif | ||||
|  | @ -0,0 +1,22 @@ | |||
| #include "Attribute.h" | ||||
| 
 | ||||
| #include "AttributeImpl.h" | ||||
| #include "Shape.h" | ||||
| #include "Tensor.h" | ||||
| #include "TensorImpl.h" | ||||
| #include "llvm/Support/Casting.h" | ||||
| #include "memory.h" | ||||
| // #include "mlir/Dialect/StandardOps/Ops.h"
 | ||||
| // #include "mlir/IR/Attributes.h"
 | ||||
| // #include "mlir/IR/Operation.h"
 | ||||
| // #include "mlir/IR/StandardTypes.h"
 | ||||
| // #include "mlir/IR/Types.h"
 | ||||
| // #include "mlir/IR/Value.h"
 | ||||
| 
 | ||||
| namespace builder { | ||||
| 
 | ||||
| Integer::Integer(int value) : impl_(std::make_shared<Integer::Impl>(value)) {} | ||||
| Integer::Integer(int64_t value) | ||||
|     : impl_(std::make_shared<Integer::Impl>(value)) {} | ||||
| 
 | ||||
| }  // namespace builder
 | ||||
|  | @ -0,0 +1,50 @@ | |||
| #ifndef BUILDER_ATTRIBUTE_H_ | ||||
| #define BUILDER_ATTRIBUTE_H_ | ||||
| 
 | ||||
| #include <iostream> | ||||
| #include <memory> | ||||
| 
 | ||||
| namespace builder { | ||||
| 
 | ||||
| class Integer { | ||||
|  public: | ||||
|   Integer(int value); | ||||
|   Integer(int64_t value); | ||||
|   class Impl; | ||||
| 
 | ||||
|  private: | ||||
|   std::shared_ptr<Impl> impl_; | ||||
| }; | ||||
| 
 | ||||
| class Array { | ||||
|  public: | ||||
|   class Impl; | ||||
|   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||
| 
 | ||||
|  private: | ||||
|   std::shared_ptr<Impl> impl_; | ||||
| }; | ||||
| 
 | ||||
| class Type { | ||||
|  public: | ||||
|   class Impl; | ||||
|   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||
| 
 | ||||
|  private: | ||||
|   std::shared_ptr<Impl> impl_; | ||||
| }; | ||||
| 
 | ||||
| // template <typename T>
 | ||||
| // class DenseIntElementsAttr {
 | ||||
| //  public:
 | ||||
| //   explicit DenseIntElementsAttr(Tensor);
 | ||||
| //   class Impl;
 | ||||
| //   std::shared_ptr<Impl> GetImpl() { return impl_; }
 | ||||
| 
 | ||||
| //  private:
 | ||||
| //   std::shared_ptr<Impl> impl_;
 | ||||
| // };
 | ||||
| 
 | ||||
| }  // namespace builder
 | ||||
| 
 | ||||
| #endif | ||||
|  | @ -0,0 +1,58 @@ | |||
| #ifndef BUILDER_TYPEIMPL_ | ||||
| #define BUILDER_TYPEIMPL_ | ||||
| 
 | ||||
| #include "Builder.h" | ||||
| #include "llvm/Support/Casting.h" | ||||
| #include "mlir/IR/Attributes.h" | ||||
| #include "mlir/IR/Builders.h" | ||||
| #include "mlir/IR/BuiltinTypes.h" | ||||
| #include "mlir/IR/MLIRContext.h" | ||||
| #include "mlir/IR/Operation.h" | ||||
| #include "mlir/IR/Types.h" | ||||
| #include "mlir/IR/Value.h" | ||||
| 
 | ||||
| namespace builder { | ||||
| 
 | ||||
| class Integer::Impl { | ||||
|  public: | ||||
|   Impl(){}; | ||||
| 
 | ||||
|   Impl(int value) { | ||||
|     GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr { | ||||
|       auto int32_type = mlir::IntegerType::get(context, 32); | ||||
|       return mlir::IntegerAttr::get(int32_type, value); | ||||
|     }; | ||||
|   }; | ||||
|   Impl(int64_t value) | ||||
|       : GetAttr([=](mlir::MLIRContext *context) -> mlir::IntegerAttr { | ||||
|           auto int32_type = mlir::IntegerType::get(context, 64); | ||||
|           return mlir::IntegerAttr::get(int32_type, value); | ||||
|         }){}; | ||||
|   std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr; | ||||
| 
 | ||||
|  private: | ||||
| }; | ||||
| 
 | ||||
| class Array::Impl { | ||||
|  public: | ||||
|   Impl() = default; | ||||
| 
 | ||||
|  private: | ||||
| }; | ||||
| class Type::Impl { | ||||
|  public: | ||||
|   Impl() = default; | ||||
| 
 | ||||
|  private: | ||||
| }; | ||||
| 
 | ||||
| // class DenseIntElementsAttr::Impl {
 | ||||
| //  public:
 | ||||
| //   Impl() = default;
 | ||||
| 
 | ||||
| //  private:
 | ||||
| // };
 | ||||
| 
 | ||||
| }  // namespace builder
 | ||||
| 
 | ||||
| #endif | ||||
|  | @ -11,7 +11,7 @@ | |||
| 
 | ||||
| namespace builder { | ||||
| 
 | ||||
| Builder::Builder() : _impl(std::make_shared<Impl>()) {} | ||||
| Builder::Builder() : impl_(std::make_shared<Impl>()) {} | ||||
| void Builder::DumpModule() {} | ||||
| 
 | ||||
| }  // namespace builder
 | ||||
|  | @ -11,10 +11,10 @@ class Builder { | |||
| 
 | ||||
|   Builder(); | ||||
|   void DumpModule(); | ||||
|   std::shared_ptr<Impl> GetImpl() { return _impl; } | ||||
|   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||
| 
 | ||||
|  private: | ||||
|   std::shared_ptr<Impl> _impl; | ||||
|   std::shared_ptr<Impl> impl_; | ||||
| }; | ||||
| }  // namespace builder
 | ||||
| 
 | ||||
|  | @ -0,0 +1,17 @@ | |||
| #ifndef BUILDER_OP_ | ||||
| #define BUILDER_OP_ | ||||
| 
 | ||||
| #include <iostream> | ||||
| #include <memory> | ||||
| 
 | ||||
| namespace builder { | ||||
| class Op { | ||||
|   class Impl; | ||||
|   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||
| 
 | ||||
|  private: | ||||
|   std::shared_ptr<Impl> impl_; | ||||
| } | ||||
| }  // namespace builder
 | ||||
| 
 | ||||
| #endif | ||||
|  | @ -0,0 +1,28 @@ | |||
| #include "PrimitiveType.h" | ||||
| 
 | ||||
| #include "Shape.h" | ||||
| #include "Tensor.h" | ||||
| #include "llvm/Support/Casting.h" | ||||
| // #include "mlir/Dialect/StandardOps/Ops.h"
 | ||||
| // #include "mlir/IR/Attributes.h"
 | ||||
| // #include "mlir/IR/Operation.h"
 | ||||
| // #include "mlir/IR/StandardTypes.h"
 | ||||
| #include "mlir/IR/Types.h" | ||||
| // #include "mlir/IR/Value.h"
 | ||||
| 
 | ||||
| namespace builder { | ||||
| 
 | ||||
| // constexpr mlir::Type GetMlirType(const PrimitiveType primitiveType) {
 | ||||
| // switch (primitiveType)
 | ||||
| // {
 | ||||
| // case PrimitiveType::BF16 :
 | ||||
| //   return 
 | ||||
| //   break;
 | ||||
| 
 | ||||
| // default:
 | ||||
| //   break;
 | ||||
| // }
 | ||||
| 
 | ||||
| // }
 | ||||
| 
 | ||||
| }  // namespace builder
 | ||||
|  | @ -0,0 +1,29 @@ | |||
| #ifndef BUILDER_PRIMITIVE_TYPE_ | ||||
| #define BUILDER_PRIMITIVE_TYPE_ | ||||
| 
 | ||||
| namespace builder { | ||||
| 
 | ||||
| enum class PrimitiveType { | ||||
|   PRIMITIVE_TYPE_INVALID = 0, | ||||
|   PRED = 1, | ||||
|   S8 = 2, | ||||
|   S16 = 3, | ||||
|   S32 = 4, | ||||
|   S64 = 5, | ||||
|   U8 = 6, | ||||
|   U16 = 7, | ||||
|   U32 = 8, | ||||
|   U64 = 9, | ||||
|   F16 = 10, | ||||
|   F32 = 11, | ||||
|   BF16 = 16, | ||||
|   F64 = 12, | ||||
|   C64 = 15, | ||||
|   C128 = 18, | ||||
|   TUPLE = 13, | ||||
|   OPAQUE_TYPE = 14, | ||||
|   TOKEN = 17 | ||||
| }; | ||||
| 
 | ||||
| } | ||||
| #endif | ||||
|  | @ -0,0 +1,17 @@ | |||
| #ifndef BUILDER_SHAPE_ | ||||
| #define BUILDER_SHAPE_ | ||||
| 
 | ||||
| #include <iostream> | ||||
| #include <memory> | ||||
| 
 | ||||
| namespace builder { | ||||
| class Shape { | ||||
|   class Impl; | ||||
|   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||
| 
 | ||||
|  private: | ||||
|   std::shared_ptr<Impl> impl_; | ||||
| }; | ||||
| }  // namespace builder
 | ||||
| 
 | ||||
| #endif | ||||
|  | @ -0,0 +1,24 @@ | |||
| #ifndef BUILDER_SHAPEIMPL_ | ||||
| #define BUILDER_SHAPEIMPL_ | ||||
| 
 | ||||
| #include "Builder.h" | ||||
| #include "llvm/Support/Casting.h" | ||||
| #include "mlir/IR/Attributes.h" | ||||
| #include "mlir/IR/Builders.h" | ||||
| #include "mlir/IR/MLIRContext.h" | ||||
| #include "mlir/IR/Operation.h" | ||||
| #include "mlir/IR/Types.h" | ||||
| #include "mlir/IR/Value.h" | ||||
| 
 | ||||
| namespace builder { | ||||
| 
 | ||||
| class Shape::Impl { | ||||
|  public: | ||||
|   Impl() = default; | ||||
| 
 | ||||
|  private: | ||||
| }; | ||||
| 
 | ||||
| }  // namespace builder
 | ||||
| 
 | ||||
| #endif | ||||
|  | @ -0,0 +1,22 @@ | |||
| #include "Tensor.h" | ||||
| 
 | ||||
| #include "Shape.h" | ||||
| #include "TensorImpl.h" | ||||
| #include "llvm/Support/Casting.h" | ||||
| // #include "mlir/Dialect/StandardOps/Ops.h"
 | ||||
| // #include "mlir/IR/Attributes.h"
 | ||||
| // #include "mlir/IR/Operation.h"
 | ||||
| // #include "mlir/IR/StandardTypes.h"
 | ||||
| // #include "mlir/IR/Types.h"
 | ||||
| // #include "mlir/IR/Value.h"
 | ||||
| 
 | ||||
| namespace builder { | ||||
| 
 | ||||
| // Tensor::Tensor() : impl_(std::make_shared<Impl>()) {}
 | ||||
| Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType) | ||||
|     : impl_(std::make_shared<Impl>(shape, primitiveType)) {} | ||||
| 
 | ||||
| Shape Tensor::GetShape() { return impl_->GetShape(); } | ||||
| PrimitiveType Tensor::GetType() { return impl_->GetType(); } | ||||
| 
 | ||||
| }  // namespace builder
 | ||||
|  | @ -0,0 +1,25 @@ | |||
| #ifndef BUILDER_TENSOR_ | ||||
| #define BUILDER_TENSOR_ | ||||
| 
 | ||||
| #include <iostream> | ||||
| #include <memory> | ||||
| 
 | ||||
| #include "PrimitiveType.h" | ||||
| #include "Shape.h" | ||||
| 
 | ||||
| namespace builder { | ||||
| 
 | ||||
| class Tensor { | ||||
|  public: | ||||
|   Tensor(Shape& shape, PrimitiveType& primitiveType); | ||||
|   class Impl; | ||||
|   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||
|   Shape GetShape(); | ||||
|   PrimitiveType GetType(); | ||||
| 
 | ||||
|  private: | ||||
|   std::shared_ptr<Impl> impl_; | ||||
| }; | ||||
| }  // namespace builder
 | ||||
| 
 | ||||
| #endif | ||||
|  | @ -0,0 +1,32 @@ | |||
| #ifndef BUILDER_TENSORIMPL_ | ||||
| #define BUILDER_TENSORIMPL_ | ||||
| 
 | ||||
| #include "Builder.h" | ||||
| #include "PrimitiveType.h" | ||||
| #include "Shape.h" | ||||
| #include "llvm/Support/Casting.h" | ||||
| #include "mlir/IR/Attributes.h" | ||||
| #include "mlir/IR/Builders.h" | ||||
| #include "mlir/IR/MLIRContext.h" | ||||
| #include "mlir/IR/Operation.h" | ||||
| #include "mlir/IR/Types.h" | ||||
| #include "mlir/IR/Value.h" | ||||
| 
 | ||||
| namespace builder { | ||||
| 
 | ||||
| class Tensor::Impl { | ||||
|  public: | ||||
|   Impl() = default; | ||||
|   Impl(Shape& shape, PrimitiveType& primitiveType) | ||||
|       : shape_(shape), primitiveType_(primitiveType) {} | ||||
|   Shape GetShape() { return shape_; } | ||||
|   PrimitiveType GetType() { return primitiveType_; } | ||||
| 
 | ||||
|  private: | ||||
|   Shape shape_; | ||||
|   PrimitiveType primitiveType_; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace builder
 | ||||
| 
 | ||||
| #endif | ||||
|  | @ -29,7 +29,7 @@ | |||
| #include "llvm/TableGen/Record.h" | ||||
| #include "llvm/TableGen/TableGenBackend.h" | ||||
| 
 | ||||
| #include "iostream" | ||||
| #include <iostream> | ||||
| 
 | ||||
| #define DEBUG_TYPE "mlir-tblgen-opdefgen" | ||||
| 
 | ||||
|  | @ -42,28 +42,117 @@ static const char *const generatedArgName = "odsArg"; | |||
| static const char *const odsBuilder = "odsBuilder"; | ||||
| static const char *const builderOpState = "odsState"; | ||||
| 
 | ||||
| static const std::map<std::string, std::string> typeMapMLIR = { | ||||
|     {"::mlir::StringAttr", "std::string"}, | ||||
|     {"::mlir::IntegerAttr", "int"}, | ||||
|     {"::mlir::DenseIntElementsAttr", "std::vector<int>"}, | ||||
|     {"::mlir::mhlo::ChannelHandle", "ChannelHandle"}, | ||||
|     {"::mlir::FloatAttr", "float"}, | ||||
|     {"::mlir::BoolAttr", "bool"}, | ||||
|     {"::mlir::ElementsAttr", "::builder::Array"}, | ||||
|     {"::mlir::DenseElementsAttr", "::builder::Tensor"}, | ||||
|     // current only support string array.
 | ||||
|     {"::mlir::ArrayAttr", "std::vector<std::string>"}, | ||||
|     {"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"}, | ||||
|     {"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"}, | ||||
|     {"::mlir::mhlo::GatherDimensionNumbers", | ||||
|      "::builder::GatherDimensionNumbers"}, | ||||
|     {"::mlir::mhlo::ScatterDimensionNumbers", | ||||
|      "::builder::ScatterDimensionNumbers"}, | ||||
| namespace { | ||||
| struct mlirTypeWrap { | ||||
|   std::string Name; | ||||
|   std::string (*ConvertToMlir)(std::string &, OpMethodBody &); | ||||
| }; | ||||
| 
 | ||||
| static const std::map<std::string, mlirTypeWrap> typeMapMLIR = { | ||||
|     {"::mlir::StringAttr", | ||||
|      {"std::string", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         body << "  mlir::StringAttr " << var << "_mlir = mlir::StringAttr::get(" | ||||
|              << var << ", ctx);\n"; | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
|     {"::mlir::IntegerAttr", | ||||
|      {"builder::Integer", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         // body << "  mlir::IntegerAttr " << var << "_mlir = mlir::IntegerAttr::get("
 | ||||
|         //      << var << ", ctx);\n";
 | ||||
|         body << "  mlir::IntegerAttr " << var << "_mlir = " << var | ||||
|              << ".GetAttr(ctx);\n"; | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
|     {"::mlir::DenseIntElementsAttr", | ||||
|      {"std::vector<int>", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         body << "  mlir::DenseIntElementsAttr " << var | ||||
|              << "_mlir = mlir::DenseIntElementsAttr::get(" | ||||
|              << "mlir::VectorType::get(" << var | ||||
|              << ".size(), opBuilder->getIntegerType(32))," << var << ");\n"; | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
|     {"::mlir::mhlo::ChannelHandle", | ||||
|      {"ChannelHandle", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
|     {"::mlir::FloatAttr", | ||||
|      {"float", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         body << "  mlir::FloatAttr " << var << "_mlir = mlir::FloatAttr::get(" | ||||
|              << var << ", ctx);\n"; | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
|     {"::mlir::BoolAttr", | ||||
|      {"bool", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         body << "  mlir::BoolAttr " << var << "_mlir = mlir::BoolAttr::get(" | ||||
|              << var << ", ctx);\n"; | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
|     {"::mlir::ElementsAttr", | ||||
|      {"::builder::Array", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
|     {"::mlir::DenseElementsAttr", | ||||
|      {"::builder::Tensor", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
|     // current only support string array.
 | ||||
|     {"::mlir::ArrayAttr", | ||||
|      {"std::vector<std::string>", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
|     {"::mlir::mhlo::ConvDimensionNumbers", | ||||
|      {"::builder::ConvDimensionNumbers", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
|     {"::mlir::mhlo::DotDimensionNumbers", | ||||
|      {"::builder::DotDimensionNumbers", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
|     {"::mlir::mhlo::GatherDimensionNumbers", | ||||
|      {"::builder::GatherDimensionNumbers", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
|     {"::mlir::mhlo::ScatterDimensionNumbers", | ||||
|      {"::builder::ScatterDimensionNumbers", | ||||
|       [](std::string &var, OpMethodBody &body) -> std::string { | ||||
|         return var + "_mlir"; | ||||
|       }}}, | ||||
| }; | ||||
| 
 | ||||
| // static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
 | ||||
| //     {"::mlir::StringAttr", "std::string"},
 | ||||
| //     {"::mlir::IntegerAttr", "int"},
 | ||||
| //     {"::mlir::DenseIntElementsAttr", "std::vector<int>"},
 | ||||
| //     {"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
 | ||||
| //     {"::mlir::FloatAttr", "float"},
 | ||||
| //     {"::mlir::BoolAttr", "bool"},
 | ||||
| //     {"::mlir::ElementsAttr", "::builder::Array"},
 | ||||
| //     {"::mlir::DenseElementsAttr", "::builder::Tensor"},
 | ||||
| //     // current only support string array.
 | ||||
| //     {"::mlir::ArrayAttr", "std::vector<std::string>"},
 | ||||
| //     {"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
 | ||||
| //     {"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
 | ||||
| //     {"::mlir::mhlo::GatherDimensionNumbers",
 | ||||
| //      "::builder::GatherDimensionNumbers"},
 | ||||
| //     {"::mlir::mhlo::ScatterDimensionNumbers",
 | ||||
| //      "::builder::ScatterDimensionNumbers"},
 | ||||
| // };
 | ||||
| 
 | ||||
| StringRef typeConvertFromMLIR(StringRef type) { | ||||
|   auto re = typeMapMLIR.find(type.str()); | ||||
|   if (re != typeMapMLIR.end()) return StringRef(re->second); | ||||
|   if (re != typeMapMLIR.end()) return StringRef(re->second.Name); | ||||
|   return type; | ||||
| } | ||||
| 
 | ||||
|  | @ -75,6 +164,7 @@ StringRef getReturnType(const Attribute &att) { | |||
|   auto type = att.getStorageType(); | ||||
|   return typeConvertFromMLIR(type); | ||||
| } | ||||
| }  // namespace
 | ||||
| 
 | ||||
| // The logic to calculate the actual value range for a declared operand/result
 | ||||
| // of an op with variadic operands/results. Note that this logic is not for
 | ||||
|  | @ -312,14 +402,6 @@ static std::string getArgumentName(const Operator &op, int index) { | |||
|     return std::string(formatv("{0}_{1}", generatedArgName, index)); | ||||
| } | ||||
| 
 | ||||
| // Returns true if we can use unwrapped value for the given `attr` in builders.
 | ||||
| static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) { | ||||
|   return getReturnType(attr) != getStorageType(attr) && | ||||
|          // We need to wrap the raw value into an attribute in the builder impl
 | ||||
|          // so we need to make sure that the attribute specifies how to do that.
 | ||||
|          !attr.getConstBuilderTemplate().empty(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Op emitter
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  | @ -427,7 +509,7 @@ private: | |||
|                       AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); | ||||
| 
 | ||||
|   // Adds op arguments and regions into operation state for build() methods.
 | ||||
|   void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,llvm::SmallVector<OpMethodParameter, 4> paramList, | ||||
|   void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, | ||||
|                                               bool isRawValueAttr = false); | ||||
| 
 | ||||
|   // Generates canonicalizer declaration for the operation.
 | ||||
|  | @ -479,7 +561,9 @@ private: | |||
|   // Generate the type inference interface methods.
 | ||||
|   void genTypeInterfaceMethods(); | ||||
| 
 | ||||
| private: | ||||
|   Operator GetOp() { return op; } | ||||
| 
 | ||||
|  private: | ||||
|   // The TableGen record for this op.
 | ||||
|   // TODO: OpEmitter should not have a Record directly,
 | ||||
|   // it should rather go through the Operator for better abstraction.
 | ||||
|  | @ -1077,27 +1161,6 @@ void OpEmitter::genNamedSuccessorGetters() { | |||
|   } | ||||
| } | ||||
| 
 | ||||
| static bool canGenerateUnwrappedBuilder(Operator &op) { | ||||
|   // If this op does not have native attributes at all, return directly to avoid
 | ||||
|   // redefining builders.
 | ||||
|   if (op.getNumNativeAttributes() == 0) | ||||
|     return false; | ||||
| 
 | ||||
|   bool canGenerate = false; | ||||
|   // We are generating builders that take raw values for attributes. We need to
 | ||||
|   // make sure the native attributes have a meaningful "unwrapped" value type
 | ||||
|   // different from the wrapped mlir::Attribute type to avoid redefining
 | ||||
|   // builders. This checks for the op has at least one such native attribute.
 | ||||
|   for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) { | ||||
|     NamedAttribute &namedAttr = op.getAttribute(i); | ||||
|     if (canUseUnwrappedRawValue(namedAttr.attr)) { | ||||
|       canGenerate = true; | ||||
|       break; | ||||
|     } | ||||
|   } | ||||
|   return canGenerate; | ||||
| } | ||||
| 
 | ||||
| static bool canInferType(Operator &op) { | ||||
|   return op.getTrait("::mlir::InferTypeOpInterface::Trait") && | ||||
|          op.getNumRegions() == 0; | ||||
|  | @ -1106,18 +1169,14 @@ static bool canInferType(Operator &op) { | |||
| void OpEmitter::genSeparateArgParamBuilder() { | ||||
|   SmallVector<AttrParamKind, 2> attrBuilderType; | ||||
|   attrBuilderType.push_back(AttrParamKind::WrappedAttr); | ||||
|   // if (canGenerateUnwrappedBuilder(op))
 | ||||
|   //   attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
 | ||||
| 
 | ||||
|   // Emit with separate builders with or without unwrapped attributes and/or
 | ||||
|   // inferring result type.
 | ||||
|   auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, | ||||
|                   bool inferType) { | ||||
|     llvm::SmallVector<OpMethodParameter, 4> paramList; | ||||
|     llvm::SmallVector<OpMethodParameter, 4> paramList2; | ||||
|     llvm::SmallVector<std::string, 4> resultNames; | ||||
|     buildParamList(paramList, resultNames, paramKind, attrType); | ||||
|     buildParamList(paramList2, resultNames, paramKind, attrType); | ||||
| 
 | ||||
|     auto *m = opClass.addMethodAndPrune( | ||||
|         "::builder::Op", "build", OpMethod::MP_Static, std::move(paramList)); | ||||
|  | @ -1126,10 +1185,10 @@ void OpEmitter::genSeparateArgParamBuilder() { | |||
|       return; | ||||
|     auto &body = m->body(); | ||||
|     genCodeForAddingArgAndRegionForBuilder( | ||||
|         body, paramList2, attrType == AttrParamKind::UnwrappedValue); | ||||
|         body, attrType == AttrParamKind::UnwrappedValue); | ||||
| 
 | ||||
|     // Push all result types to the operation state
 | ||||
| //"BBBBBBBBBBBB"
 | ||||
| 
 | ||||
|   //   if (inferType) {
 | ||||
|   //     // Generate builder that infers type too.
 | ||||
|   //     // TODO: Subsume this with general checking if type can be
 | ||||
|  | @ -1306,29 +1365,6 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList, | |||
|   int numAttrs = 0; | ||||
| 
 | ||||
|   int defaultValuedAttrStartIndex = op.getNumArgs(); | ||||
|   if (attrParamKind == AttrParamKind::UnwrappedValue) { | ||||
|     // Calculate the start index from which we can attach default values in the
 | ||||
|     // builder declaration.
 | ||||
|     for (int i = op.getNumArgs() - 1; i >= 0; --i) { | ||||
|       auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>(); | ||||
|       if (!namedAttr || !namedAttr->attr.hasDefaultValue()) | ||||
|         break; | ||||
| 
 | ||||
|       if (!canUseUnwrappedRawValue(namedAttr->attr)) | ||||
|         break; | ||||
| 
 | ||||
|       // Creating an APInt requires us to provide bitwidth, value, and
 | ||||
|       // signedness, which is complicated compared to others. Similarly
 | ||||
|       // for APFloat.
 | ||||
|       // TODO: Adjust the 'returnType' field of such attributes
 | ||||
|       // to support them.
 | ||||
|       StringRef retType = getReturnType(namedAttr->attr); | ||||
|       if (retType == "::llvm::APInt" || retType == "::llvm::APFloat") | ||||
|         break; | ||||
| 
 | ||||
|       defaultValuedAttrStartIndex = i; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   for (int i = 0, e = op.getNumArgs(); i < e; ++i) { | ||||
|     auto argument = op.getArg(i); | ||||
|  | @ -1357,10 +1393,10 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList, | |||
|         type = getStorageType(attr); | ||||
|         break; | ||||
|       case AttrParamKind::UnwrappedValue: | ||||
|         if (canUseUnwrappedRawValue(attr)) | ||||
|           type = getReturnType(attr); | ||||
|         else | ||||
|           type = getStorageType(attr); | ||||
|         // if (canUseUnwrappedRawValue(attr))
 | ||||
|         //   type = getReturnType(attr);
 | ||||
|         // else
 | ||||
|         //   type = getStorageType(attr);
 | ||||
|         break; | ||||
|       } | ||||
| 
 | ||||
|  | @ -1394,41 +1430,72 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList, | |||
|                              llvm::formatv("{0}Count", region.name).str()); | ||||
| } | ||||
| 
 | ||||
| void OpEmitter::genCodeForAddingArgAndRegionForBuilder( | ||||
|     OpMethodBody &body, llvm::SmallVector<OpMethodParameter, 4> paramList, | ||||
| void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, | ||||
|                                                        bool isRawValueAttr) { | ||||
|   auto op = GetOp(); | ||||
|   auto operands = op.getOperands(); | ||||
|   auto attrs = op.getAttributes(); | ||||
|   SmallVector<std::string, 4> newAttrs; | ||||
| 
 | ||||
|   body << "//    AAAAAA  \n"; | ||||
|   // for(auto p : paramList)
 | ||||
|   // {
 | ||||
|   // body <<"====  type:"<< p.getType() << "  name:"<< p.getName() <<  "\n";
 | ||||
|   // for(auto o : operands){
 | ||||
|   //   body << "  operands name: "<<o.name<<"  getCPPClassName:"<<o.constraint.getCPPClassName()<<"\n";
 | ||||
|   // }
 | ||||
|   // for(auto a : attrs){
 | ||||
|   //   body << "  attrs name: "<<a.name<<"  type: "<< a.attr.getStorageType().str()<<"\n";
 | ||||
|   //   std::string attrType = a.attr.getStorageType().str();
 | ||||
|   //   auto ff = typeMapMLIR.find(attrType);
 | ||||
|   //   if(ff != typeMapMLIR.end()){
 | ||||
|   //     body << "//    BBBBBB  \n";
 | ||||
|   //     std::string attrName = a.name.str();
 | ||||
|   //     ff->second.ConvertToMlir(attrName,body);
 | ||||
|   //   }
 | ||||
|   // }
 | ||||
| 
 | ||||
|   body << "  auto builder = builder.GetImpl();\n"; | ||||
|   body << "  auto loc = builder->GetLoc();\n"; | ||||
|   body << "  auto opBuilder = builder->GetBuilder();\n"; | ||||
| 
 | ||||
|   // body << "//    AAAAAA  \n";
 | ||||
|   // for(auto p : paramList){
 | ||||
|   //   body << "  AAA type"<<p.getType()<<"  name"<<p.getName()<<"\n";
 | ||||
|   // }
 | ||||
| 
 | ||||
|   body << "  auto b = builder.GetImpl();\n"; | ||||
|   body << "  auto loc = b->GetLoc();\n"; | ||||
|   body << "  auto opBuilder = b->GetBuilder();\n"; | ||||
|   body << "  auto ctx = b->GetContext();\n"; | ||||
|   for (auto a : attrs) { | ||||
|     std::string attrType = a.attr.getStorageType().str(); | ||||
| 
 | ||||
| if(attrType == "::mlir::DenseIntElementsAttr") | ||||
| { | ||||
|   body << "  //  BBBBBBBB  getStorageType:"<< a.attr.getStorageType().str() <<"\n"; | ||||
|   body << "  //  BBBBBBBB  getReturnType:"<< a.attr.getReturnType().str() <<"\n"; | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     auto typePair = typeMapMLIR.find(attrType); | ||||
|     if (typePair != typeMapMLIR.end()) { | ||||
|       std::string attrName = a.name.str(); | ||||
|       auto mlirName = typePair->second.ConvertToMlir(attrName, body); | ||||
|       newAttrs.emplace_back(mlirName); | ||||
|     } | ||||
|   } | ||||
|   body << "  mlir::" << op.getDialectName() << "::" << op.getCppClassName() | ||||
|        << " currentOp =\n"; | ||||
|   body << "    opBuilder.create<mlir::" << op.getDialectName() | ||||
|        << "::" << op.getDialectName() << ">(\n"; | ||||
|   if (paramList.size() > 1) { | ||||
|     body << "      loc,\n"; | ||||
|     std::for_each(paramList.begin() + 1, paramList.end() - 1, | ||||
|                   [&](OpMethodParameter &p) { | ||||
|                     body << "      " << p.getName() << ",\n"; | ||||
|   body << "      loc"; | ||||
|   std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) { | ||||
|     body << ",\n      " << v.name << ".getResult()"; | ||||
|   }); | ||||
|     body << "      " << paramList.back().getName() << "\n"; | ||||
|   } else { | ||||
|     body << "      loc\n"; | ||||
|   } | ||||
|   body << "    );\n"; | ||||
|   std::for_each(newAttrs.begin(), newAttrs.end(), | ||||
|                 [&](std::string &n) { body << ",\n      " << n; }); | ||||
|   body << "\n    );\n"; | ||||
|   body << "  builder::" << op.getCppClassName() << " builderOp;\n"; | ||||
|   body << "  auto opImpl = builderOp.GetImpl();\n"; | ||||
|   body << "  opImpl.SetOperation(currentOp.getOperation());\n"; | ||||
|   body << "  return builderOp;\n"; | ||||
| 
 | ||||
|    | ||||
| 
 | ||||
|   // // Push all operands to the result.
 | ||||
|   // for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
 | ||||
|   //   std::string argName = getArgumentName(op, i);
 | ||||
|  | @ -2079,184 +2146,6 @@ void OpEmitter::genOpAsmInterface() { | |||
|   } | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // OpOperandAdaptor emitter
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| namespace { | ||||
| // Helper class to emit Op operand adaptors to an output stream.  Operand
 | ||||
| // adaptors are wrappers around ArrayRef<Value> that provide named operand
 | ||||
| // getters identical to those defined in the Op.
 | ||||
| class OpOperandAdaptorEmitter { | ||||
| public: | ||||
|   static void emitDecl(const Operator &op, raw_ostream &os); | ||||
|   static void emitDef(const Operator &op, raw_ostream &os); | ||||
| 
 | ||||
| private: | ||||
|   explicit OpOperandAdaptorEmitter(const Operator &op); | ||||
| 
 | ||||
|   // Add verification function. This generates a verify method for the adaptor
 | ||||
|   // which verifies all the op-independent attribute constraints.
 | ||||
|   void addVerification(); | ||||
| 
 | ||||
|   const Operator &op; | ||||
|   Class adaptor; | ||||
| }; | ||||
| } // end namespace
 | ||||
| 
 | ||||
| OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) | ||||
|     : op(op), adaptor(op.getAdaptorName()) { | ||||
|   adaptor.newField("::mlir::ValueRange", "odsOperands"); | ||||
|   adaptor.newField("::mlir::DictionaryAttr", "odsAttrs"); | ||||
|   adaptor.newField("::mlir::RegionRange", "odsRegions"); | ||||
|   const auto *attrSizedOperands = | ||||
|       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); | ||||
|   { | ||||
|     SmallVector<OpMethodParameter, 2> paramList; | ||||
|     paramList.emplace_back("::mlir::ValueRange", "values"); | ||||
|     paramList.emplace_back("::mlir::DictionaryAttr", "attrs", | ||||
|                            attrSizedOperands ? "" : "nullptr"); | ||||
|     paramList.emplace_back("::mlir::RegionRange", "regions", "{}"); | ||||
|     auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList)); | ||||
| 
 | ||||
|     constructor->addMemberInitializer("odsOperands", "values"); | ||||
|     constructor->addMemberInitializer("odsAttrs", "attrs"); | ||||
|     constructor->addMemberInitializer("odsRegions", "regions"); | ||||
|   } | ||||
| 
 | ||||
|   { | ||||
|     auto *constructor = adaptor.addConstructorAndPrune( | ||||
|         llvm::formatv("{0}&", op.getCppClassName()).str(), "op"); | ||||
|     constructor->addMemberInitializer("odsOperands", "op->getOperands()"); | ||||
|     constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()"); | ||||
|     constructor->addMemberInitializer("odsRegions", "op->getRegions()"); | ||||
|   } | ||||
| 
 | ||||
|   { | ||||
|     auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands"); | ||||
|     m->body() << "  return odsOperands;"; | ||||
|   } | ||||
|   std::string sizeAttrInit = | ||||
|       formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes"); | ||||
|   generateNamedOperandGetters(op, adaptor, sizeAttrInit, | ||||
|                               /*rangeType=*/"::mlir::ValueRange", | ||||
|                               /*rangeBeginCall=*/"odsOperands.begin()", | ||||
|                               /*rangeSizeCall=*/"odsOperands.size()", | ||||
|                               /*getOperandCallPattern=*/"odsOperands[{0}]"); | ||||
| 
 | ||||
|   FmtContext fctx; | ||||
|   fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); | ||||
| 
 | ||||
|   auto emitAttr = [&](StringRef name, Attribute attr) { | ||||
|     auto &body = adaptor.addMethodAndPrune(getStorageType(attr), name)->body(); | ||||
|     body << "  assert(odsAttrs && \"no attributes when constructing adapter\");" | ||||
|          << "\n  " << getStorageType(attr) << " attr = " | ||||
|          << "odsAttrs.get(\"" << name << "\")."; | ||||
|     if (attr.hasDefaultValue() || attr.isOptional()) | ||||
|       body << "dyn_cast_or_null<"; | ||||
|     else | ||||
|       body << "cast<"; | ||||
|     body << getStorageType(attr) << ">();\n"; | ||||
| 
 | ||||
|     if (attr.hasDefaultValue()) { | ||||
|       // Use the default value if attribute is not set.
 | ||||
|       // TODO: this is inefficient, we are recreating the attribute for every
 | ||||
|       // call. This should be set instead.
 | ||||
|       std::string defaultValue = std::string( | ||||
|           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); | ||||
|       body << "  if (!attr)\n    attr = " << defaultValue << ";\n"; | ||||
|     } | ||||
|     body << "  return attr;\n"; | ||||
|   }; | ||||
| 
 | ||||
|   { | ||||
|     auto *m = | ||||
|         adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes"); | ||||
|     m->body() << "  return odsAttrs;"; | ||||
|   } | ||||
|   for (auto &namedAttr : op.getAttributes()) { | ||||
|     const auto &name = namedAttr.name; | ||||
|     const auto &attr = namedAttr.attr; | ||||
|     if (!attr.isDerivedAttr()) | ||||
|       emitAttr(name, attr); | ||||
|   } | ||||
| 
 | ||||
|   unsigned numRegions = op.getNumRegions(); | ||||
|   if (numRegions > 0) { | ||||
|     auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions"); | ||||
|     m->body() << "  return odsRegions;"; | ||||
|   } | ||||
|   for (unsigned i = 0; i < numRegions; ++i) { | ||||
|     const auto ®ion = op.getRegion(i); | ||||
|     if (region.name.empty()) | ||||
|       continue; | ||||
| 
 | ||||
|     // Generate the accessors for a variadic region.
 | ||||
|     if (region.isVariadic()) { | ||||
|       auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", region.name); | ||||
|       m->body() << formatv("  return odsRegions.drop_front({0});", i); | ||||
|       continue; | ||||
|     } | ||||
| 
 | ||||
|     auto *m = adaptor.addMethodAndPrune("::mlir::Region &", region.name); | ||||
|     m->body() << formatv("  return *odsRegions[{0}];", i); | ||||
|   } | ||||
| 
 | ||||
|   // Add verification function.
 | ||||
|   addVerification(); | ||||
| } | ||||
| 
 | ||||
| void OpOperandAdaptorEmitter::addVerification() { | ||||
|   auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify", | ||||
|                                            "::mlir::Location", "loc"); | ||||
|   auto &body = method->body(); | ||||
| 
 | ||||
|   const char *checkAttrSizedValueSegmentsCode = R"( | ||||
|   { | ||||
|     auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>(); | ||||
|     auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements(); | ||||
|     if (numElements != {1}) | ||||
|       return emitError(loc, "'{0}' attribute for specifying {2} segments " | ||||
|                        "must have {1} elements, but got ") << numElements; | ||||
|   } | ||||
|   )"; | ||||
| 
 | ||||
|   // Verify a few traits first so that we can use
 | ||||
|   // getODSOperands()/getODSResults() in the rest of the verifier.
 | ||||
|   for (auto &trait : op.getTraits()) { | ||||
|     if (auto *t = dyn_cast<tblgen::NativeTrait>(&trait)) { | ||||
|       if (t->getFullyQualifiedTraitName() == | ||||
|           "::mlir::OpTrait::AttrSizedOperandSegments") { | ||||
|         body << formatv(checkAttrSizedValueSegmentsCode, | ||||
|                         "operand_segment_sizes", op.getNumOperands(), | ||||
|                         "operand"); | ||||
|       } else if (t->getFullyQualifiedTraitName() == | ||||
|                  "::mlir::OpTrait::AttrSizedResultSegments") { | ||||
|         body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes", | ||||
|                         op.getNumResults(), "result"); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   FmtContext verifyCtx; | ||||
|   populateSubstitutions(op, "odsAttrs.get", "getODSOperands", | ||||
|                         "<no results should be generated>", verifyCtx); | ||||
|   genAttributeVerifier(op, "odsAttrs.get", | ||||
|                        Twine("emitError(loc, \"'") + op.getOperationName() + | ||||
|                            "' op \"", | ||||
|                        /*emitVerificationRequiringOp*/ false, verifyCtx, body); | ||||
| 
 | ||||
|   body << "  return ::mlir::success();"; | ||||
| } | ||||
| 
 | ||||
| void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) { | ||||
|   OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os); | ||||
| } | ||||
| 
 | ||||
| void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) { | ||||
|   OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os); | ||||
| } | ||||
| 
 | ||||
| // Emits the opcode enum and op classes.
 | ||||
| static void emitOpClasses(const RecordKeeper &recordKeeper, | ||||
|                           const std::vector<Record *> &defs, raw_ostream &os, | ||||
|  | @ -2286,11 +2175,9 @@ static void emitOpClasses(const RecordKeeper &recordKeeper, | |||
|     NamespaceEmitter emitter(os, op.getCppNamespace()); | ||||
|     if (emitDecl) { | ||||
|       // os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
 | ||||
|       // OpOperandAdaptorEmitter::emitDecl(op, os);
 | ||||
|       OpEmitter::emitDecl(op, os, staticVerifierEmitter); | ||||
|     } else { | ||||
|       // os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
 | ||||
|       // OpOperandAdaptorEmitter::emitDef(op, os);
 | ||||
|       OpEmitter::emitDef(op, os, staticVerifierEmitter); | ||||
|     } | ||||
|   } | ||||
|  |  | |||
|  | @ -1,16 +0,0 @@ | |||
| #ifndef BUILDER_OP_ | ||||
| #define BUILDER_OP_ | ||||
| 
 | ||||
| #include "iostream" | ||||
| 
 | ||||
| namespace builder { | ||||
| class Op { | ||||
|   class Impl; | ||||
|   std::shared_ptr<Impl> GetImpl() { return _impl; } | ||||
| 
 | ||||
|  private: | ||||
|   std::shared_ptr<Impl> _impl; | ||||
| } | ||||
| }  // namespace builder
 | ||||
| 
 | ||||
| #endif | ||||
|  | @ -1,16 +0,0 @@ | |||
| #ifndef BUILDER_TENSOR_ | ||||
| #define BUILDER_TENSOR_ | ||||
| 
 | ||||
| 
 | ||||
| #include "iostream" | ||||
| 
 | ||||
| namespace builder{ | ||||
|   class Tensor{ | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| #endif | ||||
|  | @ -1,10 +0,0 @@ | |||
| #ifndef BUILDER_TYPE_ | ||||
| #define BUILDER_TYPE_ | ||||
| 
 | ||||
| #include "iostream" | ||||
| 
 | ||||
| namespace builder { | ||||
| class Type {} | ||||
| }  // namespace builder
 | ||||
| 
 | ||||
| #endif | ||||
		Loading…
	
		Reference in New Issue