add integer attribute convert
This commit is contained in:
		
							parent
							
								
									9d5166684c
								
							
						
					
					
						commit
						95fc37ffa8
					
				
							
								
								
									
										2
									
								
								BUILD
								
								
								
								
							
							
						
						
									
										2
									
								
								BUILD
								
								
								
								
							|  | @ -129,6 +129,8 @@ cc_binary( | ||||||
|         "tools/mlir-tblgen-builder/*.cpp", |         "tools/mlir-tblgen-builder/*.cpp", | ||||||
|         "tools/mlir-tblgen-builder/TableGen/*.h", |         "tools/mlir-tblgen-builder/TableGen/*.h", | ||||||
|         "tools/mlir-tblgen-builder/TableGen/*.cpp", |         "tools/mlir-tblgen-builder/TableGen/*.cpp", | ||||||
|  |         "tools/mlir-tblgen-builder/Builder/*.h", | ||||||
|  |         "tools/mlir-tblgen-builder/Builder/*.cpp", | ||||||
|     ]), |     ]), | ||||||
|     deps = [ |     deps = [ | ||||||
|         "@llvm-project//mlir:MlirTableGenMain", |         "@llvm-project//mlir:MlirTableGenMain", | ||||||
|  |  | ||||||
|  | @ -14,8 +14,25 @@ limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| #include "iostream" | #include "iostream" | ||||||
| 
 | #include "llvm/Support/CommandLine.h" | ||||||
| #include "mlir/Support/MlirOptMain.h" | #include "llvm/Support/DynamicLibrary.h" | ||||||
|  | #include "llvm/Support/FileUtilities.h" | ||||||
|  | #include "llvm/Support/InitLLVM.h" | ||||||
|  | #include "llvm/Support/Regex.h" | ||||||
|  | #include "llvm/Support/SourceMgr.h" | ||||||
|  | #include "llvm/Support/StringSaver.h" | ||||||
|  | #include "llvm/Support/ToolOutputFile.h" | ||||||
|  | #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" | ||||||
|  | #include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h" | ||||||
|  | #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||||||
|  | #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" | ||||||
|  | #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" | ||||||
|  | #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" | ||||||
|  | #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" | ||||||
|  | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||||||
|  | #include "mlir/ExecutionEngine/ExecutionEngine.h" | ||||||
|  | #include "mlir/ExecutionEngine/JitRunner.h" | ||||||
|  | #include "mlir/ExecutionEngine/OptUtils.h" | ||||||
| #include "mlir/IR/AsmState.h" | #include "mlir/IR/AsmState.h" | ||||||
| #include "mlir/IR/Attributes.h" | #include "mlir/IR/Attributes.h" | ||||||
| #include "mlir/IR/BuiltinOps.h" | #include "mlir/IR/BuiltinOps.h" | ||||||
|  | @ -23,62 +40,31 @@ limitations under the License. | ||||||
| #include "mlir/IR/Dialect.h" | #include "mlir/IR/Dialect.h" | ||||||
| #include "mlir/IR/Location.h" | #include "mlir/IR/Location.h" | ||||||
| #include "mlir/IR/MLIRContext.h" | #include "mlir/IR/MLIRContext.h" | ||||||
|  | #include "mlir/InitAllDialects.h" | ||||||
|  | #include "mlir/InitAllPasses.h" | ||||||
| #include "mlir/Parser.h" | #include "mlir/Parser.h" | ||||||
| #include "mlir/Pass/Pass.h" | #include "mlir/Pass/Pass.h" | ||||||
| #include "mlir/Pass/PassManager.h" | #include "mlir/Pass/PassManager.h" | ||||||
| #include "mlir/Support/DebugCounter.h" | #include "mlir/Support/DebugCounter.h" | ||||||
| #include "mlir/Support/FileUtilities.h" | #include "mlir/Support/FileUtilities.h" | ||||||
|  | #include "mlir/Support/MlirOptMain.h" | ||||||
| #include "mlir/Support/Timing.h" | #include "mlir/Support/Timing.h" | ||||||
| #include "mlir/Support/ToolUtilities.h" | #include "mlir/Support/ToolUtilities.h" | ||||||
| #include "llvm/Support/CommandLine.h" |  | ||||||
| #include "llvm/Support/FileUtilities.h" |  | ||||||
| #include "llvm/Support/InitLLVM.h" |  | ||||||
| #include "llvm/Support/Regex.h" |  | ||||||
| #include "llvm/Support/SourceMgr.h" |  | ||||||
| #include "llvm/Support/StringSaver.h" |  | ||||||
| #include "llvm/Support/ToolOutputFile.h" |  | ||||||
| #include "llvm/Support/DynamicLibrary.h" |  | ||||||
| 
 |  | ||||||
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |  | ||||||
| #include "mlir/ExecutionEngine/JitRunner.h" |  | ||||||
| #include "mlir/ExecutionEngine/OptUtils.h" |  | ||||||
| #include "mlir/Target/LLVMIR/Dialect/All.h" | #include "mlir/Target/LLVMIR/Dialect/All.h" | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| #include "mlir/InitAllDialects.h" |  | ||||||
| #include "mlir/InitAllPasses.h" |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" |  | ||||||
| #include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h" |  | ||||||
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |  | ||||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" |  | ||||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" |  | ||||||
| #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" |  | ||||||
| #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| #include "mlir/ExecutionEngine/ExecutionEngine.h" |  | ||||||
| // #include "mlir/IR/BuiltinTypes.h"
 | // #include "mlir/IR/BuiltinTypes.h"
 | ||||||
| 
 | 
 | ||||||
| 
 | #include "llvm/ADT/STLExtras.h" | ||||||
| #include "llvm/Support/TargetSelect.h" |  | ||||||
| #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" | #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" | ||||||
| #include "llvm/ExecutionEngine/Orc/Mangling.h" | #include "llvm/ExecutionEngine/Orc/Mangling.h" | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| #include "llvm/ADT/STLExtras.h" |  | ||||||
| #include "llvm/IR/IRBuilder.h" | #include "llvm/IR/IRBuilder.h" | ||||||
| #include "llvm/IR/LLVMContext.h" | #include "llvm/IR/LLVMContext.h" | ||||||
| #include "llvm/IR/LegacyPassNameParser.h" | #include "llvm/IR/LegacyPassNameParser.h" | ||||||
| 
 | #include "llvm/Support/TargetSelect.h" | ||||||
| 
 | 
 | ||||||
| using namespace mlir; | using namespace mlir; | ||||||
| using namespace llvm; | using namespace llvm; | ||||||
| 
 | 
 | ||||||
| namespace utils{ | namespace utils { | ||||||
| template <typename T, int N> | template <typename T, int N> | ||||||
| struct MemRefDescriptor { | struct MemRefDescriptor { | ||||||
|   T *allocated; |   T *allocated; | ||||||
|  | @ -87,11 +73,9 @@ struct MemRefDescriptor { | ||||||
|   int64_t sizes[N]; |   int64_t sizes[N]; | ||||||
|   int64_t strides[N]; |   int64_t strides[N]; | ||||||
| }; | }; | ||||||
| } | }  // namespace utils
 | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| int main(int argc, char **argv) { | int main(int argc, char **argv) { | ||||||
| 
 |  | ||||||
|   llvm::InitLLVM y(argc, argv); |   llvm::InitLLVM y(argc, argv); | ||||||
|   llvm::InitializeNativeTarget(); |   llvm::InitializeNativeTarget(); | ||||||
|   llvm::InitializeNativeTargetAsmPrinter(); |   llvm::InitializeNativeTargetAsmPrinter(); | ||||||
|  | @ -105,7 +89,6 @@ int main(int argc, char **argv) { | ||||||
|   // registerDefaultTimingManagerCLOptions();
 |   // registerDefaultTimingManagerCLOptions();
 | ||||||
|   DebugCounter::registerCLOptions(); |   DebugCounter::registerCLOptions(); | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|   mlir::registerAllPasses(); |   mlir::registerAllPasses(); | ||||||
|   mlir::mhlo::registerAllMhloPasses(); |   mlir::mhlo::registerAllMhloPasses(); | ||||||
|   mlir::lmhlo::registerAllLmhloPasses(); |   mlir::lmhlo::registerAllLmhloPasses(); | ||||||
|  | @ -120,24 +103,22 @@ int main(int argc, char **argv) { | ||||||
|   registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>(); |   registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>(); | ||||||
|   registry.insert<mlir::disc_ral::RalDialect>(); |   registry.insert<mlir::disc_ral::RalDialect>(); | ||||||
| 
 | 
 | ||||||
|    |  | ||||||
|   // failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
 |   // failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
 | ||||||
|   //                                 registry,
 |   //                                 registry,
 | ||||||
|   //                                 /*preloadDialectsInContext=*/false));
 |   //                                 /*preloadDialectsInContext=*/false));
 | ||||||
|   // return 0;
 |   // return 0;
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|   std::string errorMessage; |   std::string errorMessage; | ||||||
| 
 | 
 | ||||||
|   // auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir", &errorMessage);
 |   // auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir",
 | ||||||
|   auto file = mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage); |   // &errorMessage);
 | ||||||
|   std::cout<<"errorMessage:" <<errorMessage <<std::endl; |   auto file = | ||||||
|  |       mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage); | ||||||
|  |   std::cout << "errorMessage:" << errorMessage << std::endl; | ||||||
| 
 | 
 | ||||||
|   SourceMgr sourceMgr; |   SourceMgr sourceMgr; | ||||||
|   sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); |   sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|   SmallVector<const llvm::PassInfo *, 4> passes; |   SmallVector<const llvm::PassInfo *, 4> passes; | ||||||
|   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); |   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); | ||||||
|   auto tmOrError = tmBuilderOrError->createTargetMachine(); |   auto tmOrError = tmBuilderOrError->createTargetMachine(); | ||||||
|  | @ -145,8 +126,6 @@ int main(int argc, char **argv) { | ||||||
|   auto transformer = mlir::makeLLVMPassesTransformer( |   auto transformer = mlir::makeLLVMPassesTransformer( | ||||||
|       passes, 0, /*targetMachine=*/tmOrError->get(), 0); |       passes, 0, /*targetMachine=*/tmOrError->get(), 0); | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|   MLIRContext context(registry); |   MLIRContext context(registry); | ||||||
|   context.loadAllAvailableDialects(); |   context.loadAllAvailableDialects(); | ||||||
|   OwningModuleRef module(parseSourceFile(sourceMgr, &context)); |   OwningModuleRef module(parseSourceFile(sourceMgr, &context)); | ||||||
|  | @ -157,8 +136,7 @@ int main(int argc, char **argv) { | ||||||
|     return failure(); |     return failure(); | ||||||
|   }; |   }; | ||||||
| 
 | 
 | ||||||
| 
 |   // RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \ | ||||||
| // RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \ |  | ||||||
| // RUN: -hlo-legalize-to-lhlo -buffer-hoisting \ | // RUN: -hlo-legalize-to-lhlo -buffer-hoisting \ | ||||||
| // RUN: -buffer-deallocation -canonicalize -cse \ | // RUN: -buffer-deallocation -canonicalize -cse \ | ||||||
| // RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \ | // RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \ | ||||||
|  | @ -172,7 +150,7 @@ int main(int argc, char **argv) { | ||||||
| 
 | 
 | ||||||
|   pm.addPass(mlir::mhlo::createChloLegalizeToHloPass()); |   pm.addPass(mlir::mhlo::createChloLegalizeToHloPass()); | ||||||
|   pm.addPass(mlir::mhlo::createLegalizeToLhloPass()); |   pm.addPass(mlir::mhlo::createLegalizeToLhloPass()); | ||||||
|    | 
 | ||||||
|   pm.addPass(mlir::createBufferHoistingPass()); |   pm.addPass(mlir::createBufferHoistingPass()); | ||||||
|   pm.addPass(mlir::createBufferDeallocationPass()); |   pm.addPass(mlir::createBufferDeallocationPass()); | ||||||
| 
 | 
 | ||||||
|  | @ -188,13 +166,11 @@ int main(int argc, char **argv) { | ||||||
|   pm.addPass(mlir::createCanonicalizerPass()); |   pm.addPass(mlir::createCanonicalizerPass()); | ||||||
|   pm.addPass(mlir::createCSEPass()); |   pm.addPass(mlir::createCSEPass()); | ||||||
|   pm.addPass(mlir::createLowerToLLVMPass()); |   pm.addPass(mlir::createLowerToLLVMPass()); | ||||||
|    | 
 | ||||||
|   pm.run(*module); |   pm.run(*module); | ||||||
|   module->dump(); |   module->dump(); | ||||||
| 
 | 
 | ||||||
| 
 |   std::cout << "DEBUG load module success" << std::endl; | ||||||
|    |  | ||||||
|   std::cout<<"DEBUG load module success"<<std::endl; |  | ||||||
| 
 | 
 | ||||||
|   llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default; |   llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default; | ||||||
| 
 | 
 | ||||||
|  | @ -206,16 +182,14 @@ int main(int argc, char **argv) { | ||||||
|   // Use absolute library path so that gdb can find the symbol table.
 |   // Use absolute library path so that gdb can find the symbol table.
 | ||||||
| 
 | 
 | ||||||
|   std::list<std::string> sharedlib; |   std::list<std::string> sharedlib; | ||||||
|   sharedlib.push_back("/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git"); |   sharedlib.push_back( | ||||||
|  |       "/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git"); | ||||||
| 
 | 
 | ||||||
|   transform( |   transform(sharedlib, std::back_inserter(libPaths), [](std::string libPath) { | ||||||
|       sharedlib, |     SmallString<256> absPath(libPath.begin(), libPath.end()); | ||||||
|       std::back_inserter(libPaths), |     cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath))); | ||||||
|       [](std::string libPath) { |     return absPath; | ||||||
|         SmallString<256> absPath(libPath.begin(), libPath.end()); |   }); | ||||||
|         cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath))); |  | ||||||
|         return absPath; |  | ||||||
|       }); |  | ||||||
| 
 | 
 | ||||||
|   // Libraries that we'll pass to the ExecutionEngine for loading.
 |   // Libraries that we'll pass to the ExecutionEngine for loading.
 | ||||||
|   SmallVector<StringRef, 4> executionEngineLibs; |   SmallVector<StringRef, 4> executionEngineLibs; | ||||||
|  | @ -226,7 +200,6 @@ int main(int argc, char **argv) { | ||||||
|   llvm::StringMap<void *> exportSymbols; |   llvm::StringMap<void *> exportSymbols; | ||||||
|   SmallVector<MlirRunnerDestroyFn> destroyFns; |   SmallVector<MlirRunnerDestroyFn> destroyFns; | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|   // Handle libraries that do support mlir-runner init/destroy callbacks.
 |   // Handle libraries that do support mlir-runner init/destroy callbacks.
 | ||||||
|   for (auto &libPath : libPaths) { |   for (auto &libPath : libPaths) { | ||||||
|     auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str()); |     auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str()); | ||||||
|  | @ -255,17 +228,12 @@ int main(int argc, char **argv) { | ||||||
|     return symbolMap; |     return symbolMap; | ||||||
|   }; |   }; | ||||||
| 
 | 
 | ||||||
| 
 |   auto expectedEngine = | ||||||
| 
 |       mlir::ExecutionEngine::create(module.get(), nullptr, transformer, | ||||||
| 
 |                                     jitCodeGenOptLevel, executionEngineLibs); | ||||||
|   auto expectedEngine = mlir::ExecutionEngine::create( |  | ||||||
|       module.get(), nullptr, transformer, jitCodeGenOptLevel, |  | ||||||
|       executionEngineLibs); |  | ||||||
|   // if (!expectedEngine)
 |   // if (!expectedEngine)
 | ||||||
|   //   return expectedEngine.takeError();
 |   //   return expectedEngine.takeError();
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|   auto engine = std::move(*expectedEngine); |   auto engine = std::move(*expectedEngine); | ||||||
|   engine->registerSymbols(runtimeSymbolMap); |   engine->registerSymbols(runtimeSymbolMap); | ||||||
| 
 | 
 | ||||||
|  | @ -277,16 +245,15 @@ int main(int argc, char **argv) { | ||||||
|   // if (options.dumpObjectFile)
 |   // if (options.dumpObjectFile)
 | ||||||
|   // engine->dumpToObjectFile("a.o");
 |   // engine->dumpToObjectFile("a.o");
 | ||||||
| 
 | 
 | ||||||
| 
 |   float rawdata[6] = {0, 1, 2, 3, 4, 5}; | ||||||
|   float rawdata[6] = {0,1,2,3,4,5}; |  | ||||||
|   int64_t dims = 1; |   int64_t dims = 1; | ||||||
|   utils::MemRefDescriptor<float,1> a{rawdata,rawdata,0,{6},{1}}; |   utils::MemRefDescriptor<float, 1> a{rawdata, rawdata, 0, {6}, {1}}; | ||||||
|   utils::MemRefDescriptor<float,1> b{rawdata,rawdata,0,{6},{1}}; |   utils::MemRefDescriptor<float, 1> b{rawdata, rawdata, 0, {6}, {1}}; | ||||||
|   utils::MemRefDescriptor<float,1> result_memref; |   utils::MemRefDescriptor<float, 1> result_memref; | ||||||
| 
 | 
 | ||||||
|   struct memref_type{ |   struct memref_type { | ||||||
|     int64_t res_size = 6; |     int64_t res_size = 6; | ||||||
|     utils::MemRefDescriptor<float,1> *memref; |     utils::MemRefDescriptor<float, 1> *memref; | ||||||
|   } result; |   } result; | ||||||
|   result.memref = &result_memref; |   result.memref = &result_memref; | ||||||
| 
 | 
 | ||||||
|  | @ -299,23 +266,23 @@ int main(int argc, char **argv) { | ||||||
|   } data; |   } data; | ||||||
| 
 | 
 | ||||||
|   data.data1_size = &dims; |   data.data1_size = &dims; | ||||||
|   void * a_ptr = &a; |   void *a_ptr = &a; | ||||||
|   data.data1 = &a_ptr; |   data.data1 = &a_ptr; | ||||||
|   data.data2_size = &dims; |   data.data2_size = &dims; | ||||||
|   void * b_ptr = &b; |   void *b_ptr = &b; | ||||||
|   data.data2 = &b_ptr; |   data.data2 = &b_ptr; | ||||||
|   void * result_ptr = &result; |   void *result_ptr = &result; | ||||||
|   data.res = &result; |   data.res = &result; | ||||||
| 
 | 
 | ||||||
|   void (*fptr)(void **) = *expectedFPtr; |   void (*fptr)(void **) = *expectedFPtr; | ||||||
|   (*fptr)((void **)&data); |   (*fptr)((void **)&data); | ||||||
| 
 | 
 | ||||||
|   std::cout<<"result: "<<result.memref->allocated[0]<<std::endl; |   std::cout << "result: " << result.memref->allocated[0] << std::endl; | ||||||
|   std::cout<<"result: "<<result.memref->allocated[1]<<std::endl; |   std::cout << "result: " << result.memref->allocated[1] << std::endl; | ||||||
|   std::cout<<"result: "<<result.memref->allocated[2]<<std::endl; |   std::cout << "result: " << result.memref->allocated[2] << std::endl; | ||||||
|   std::cout<<"result: "<<result.memref->allocated[3]<<std::endl; |   std::cout << "result: " << result.memref->allocated[3] << std::endl; | ||||||
|   std::cout<<"result: "<<result.memref->allocated[4]<<std::endl; |   std::cout << "result: " << result.memref->allocated[4] << std::endl; | ||||||
|   std::cout<<"result: "<<result.memref->allocated[5]<<std::endl; |   std::cout << "result: " << result.memref->allocated[5] << std::endl; | ||||||
| 
 | 
 | ||||||
|   // Run all dynamic library destroy callbacks to prepare for the shutdown.
 |   // Run all dynamic library destroy callbacks to prepare for the shutdown.
 | ||||||
|   llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); }); |   llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); }); | ||||||
|  |  | ||||||
|  | @ -1,10 +0,0 @@ | ||||||
| #ifndef BUILDER_ARRAY_ |  | ||||||
| #define BUILDER_ARRAY_ |  | ||||||
| 
 |  | ||||||
| #include "iostream" |  | ||||||
| 
 |  | ||||||
| namespace builder { |  | ||||||
| class Array {} |  | ||||||
| }  // namespace builder
 |  | ||||||
| 
 |  | ||||||
| #endif |  | ||||||
|  | @ -0,0 +1,22 @@ | ||||||
|  | #include "Attribute.h" | ||||||
|  | 
 | ||||||
|  | #include "AttributeImpl.h" | ||||||
|  | #include "Shape.h" | ||||||
|  | #include "Tensor.h" | ||||||
|  | #include "TensorImpl.h" | ||||||
|  | #include "llvm/Support/Casting.h" | ||||||
|  | #include "memory.h" | ||||||
|  | // #include "mlir/Dialect/StandardOps/Ops.h"
 | ||||||
|  | // #include "mlir/IR/Attributes.h"
 | ||||||
|  | // #include "mlir/IR/Operation.h"
 | ||||||
|  | // #include "mlir/IR/StandardTypes.h"
 | ||||||
|  | // #include "mlir/IR/Types.h"
 | ||||||
|  | // #include "mlir/IR/Value.h"
 | ||||||
|  | 
 | ||||||
|  | namespace builder { | ||||||
|  | 
 | ||||||
|  | Integer::Integer(int value) : impl_(std::make_shared<Integer::Impl>(value)) {} | ||||||
|  | Integer::Integer(int64_t value) | ||||||
|  |     : impl_(std::make_shared<Integer::Impl>(value)) {} | ||||||
|  | 
 | ||||||
|  | }  // namespace builder
 | ||||||
|  | @ -0,0 +1,50 @@ | ||||||
|  | #ifndef BUILDER_ATTRIBUTE_H_ | ||||||
|  | #define BUILDER_ATTRIBUTE_H_ | ||||||
|  | 
 | ||||||
|  | #include <iostream> | ||||||
|  | #include <memory> | ||||||
|  | 
 | ||||||
|  | namespace builder { | ||||||
|  | 
 | ||||||
|  | class Integer { | ||||||
|  |  public: | ||||||
|  |   Integer(int value); | ||||||
|  |   Integer(int64_t value); | ||||||
|  |   class Impl; | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   std::shared_ptr<Impl> impl_; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | class Array { | ||||||
|  |  public: | ||||||
|  |   class Impl; | ||||||
|  |   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   std::shared_ptr<Impl> impl_; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | class Type { | ||||||
|  |  public: | ||||||
|  |   class Impl; | ||||||
|  |   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   std::shared_ptr<Impl> impl_; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | // template <typename T>
 | ||||||
|  | // class DenseIntElementsAttr {
 | ||||||
|  | //  public:
 | ||||||
|  | //   explicit DenseIntElementsAttr(Tensor);
 | ||||||
|  | //   class Impl;
 | ||||||
|  | //   std::shared_ptr<Impl> GetImpl() { return impl_; }
 | ||||||
|  | 
 | ||||||
|  | //  private:
 | ||||||
|  | //   std::shared_ptr<Impl> impl_;
 | ||||||
|  | // };
 | ||||||
|  | 
 | ||||||
|  | }  // namespace builder
 | ||||||
|  | 
 | ||||||
|  | #endif | ||||||
|  | @ -0,0 +1,58 @@ | ||||||
|  | #ifndef BUILDER_TYPEIMPL_ | ||||||
|  | #define BUILDER_TYPEIMPL_ | ||||||
|  | 
 | ||||||
|  | #include "Builder.h" | ||||||
|  | #include "llvm/Support/Casting.h" | ||||||
|  | #include "mlir/IR/Attributes.h" | ||||||
|  | #include "mlir/IR/Builders.h" | ||||||
|  | #include "mlir/IR/BuiltinTypes.h" | ||||||
|  | #include "mlir/IR/MLIRContext.h" | ||||||
|  | #include "mlir/IR/Operation.h" | ||||||
|  | #include "mlir/IR/Types.h" | ||||||
|  | #include "mlir/IR/Value.h" | ||||||
|  | 
 | ||||||
|  | namespace builder { | ||||||
|  | 
 | ||||||
|  | class Integer::Impl { | ||||||
|  |  public: | ||||||
|  |   Impl(){}; | ||||||
|  | 
 | ||||||
|  |   Impl(int value) { | ||||||
|  |     GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr { | ||||||
|  |       auto int32_type = mlir::IntegerType::get(context, 32); | ||||||
|  |       return mlir::IntegerAttr::get(int32_type, value); | ||||||
|  |     }; | ||||||
|  |   }; | ||||||
|  |   Impl(int64_t value) | ||||||
|  |       : GetAttr([=](mlir::MLIRContext *context) -> mlir::IntegerAttr { | ||||||
|  |           auto int32_type = mlir::IntegerType::get(context, 64); | ||||||
|  |           return mlir::IntegerAttr::get(int32_type, value); | ||||||
|  |         }){}; | ||||||
|  |   std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr; | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | class Array::Impl { | ||||||
|  |  public: | ||||||
|  |   Impl() = default; | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  | }; | ||||||
|  | class Type::Impl { | ||||||
|  |  public: | ||||||
|  |   Impl() = default; | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | // class DenseIntElementsAttr::Impl {
 | ||||||
|  | //  public:
 | ||||||
|  | //   Impl() = default;
 | ||||||
|  | 
 | ||||||
|  | //  private:
 | ||||||
|  | // };
 | ||||||
|  | 
 | ||||||
|  | }  // namespace builder
 | ||||||
|  | 
 | ||||||
|  | #endif | ||||||
|  | @ -11,7 +11,7 @@ | ||||||
| 
 | 
 | ||||||
| namespace builder { | namespace builder { | ||||||
| 
 | 
 | ||||||
| Builder::Builder() : _impl(std::make_shared<Impl>()) {} | Builder::Builder() : impl_(std::make_shared<Impl>()) {} | ||||||
| void Builder::DumpModule() {} | void Builder::DumpModule() {} | ||||||
| 
 | 
 | ||||||
| }  // namespace builder
 | }  // namespace builder
 | ||||||
|  | @ -11,10 +11,10 @@ class Builder { | ||||||
| 
 | 
 | ||||||
|   Builder(); |   Builder(); | ||||||
|   void DumpModule(); |   void DumpModule(); | ||||||
|   std::shared_ptr<Impl> GetImpl() { return _impl; } |   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|   std::shared_ptr<Impl> _impl; |   std::shared_ptr<Impl> impl_; | ||||||
| }; | }; | ||||||
| }  // namespace builder
 | }  // namespace builder
 | ||||||
| 
 | 
 | ||||||
|  | @ -0,0 +1,17 @@ | ||||||
|  | #ifndef BUILDER_OP_ | ||||||
|  | #define BUILDER_OP_ | ||||||
|  | 
 | ||||||
|  | #include <iostream> | ||||||
|  | #include <memory> | ||||||
|  | 
 | ||||||
|  | namespace builder { | ||||||
|  | class Op { | ||||||
|  |   class Impl; | ||||||
|  |   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   std::shared_ptr<Impl> impl_; | ||||||
|  | } | ||||||
|  | }  // namespace builder
 | ||||||
|  | 
 | ||||||
|  | #endif | ||||||
|  | @ -0,0 +1,28 @@ | ||||||
|  | #include "PrimitiveType.h" | ||||||
|  | 
 | ||||||
|  | #include "Shape.h" | ||||||
|  | #include "Tensor.h" | ||||||
|  | #include "llvm/Support/Casting.h" | ||||||
|  | // #include "mlir/Dialect/StandardOps/Ops.h"
 | ||||||
|  | // #include "mlir/IR/Attributes.h"
 | ||||||
|  | // #include "mlir/IR/Operation.h"
 | ||||||
|  | // #include "mlir/IR/StandardTypes.h"
 | ||||||
|  | #include "mlir/IR/Types.h" | ||||||
|  | // #include "mlir/IR/Value.h"
 | ||||||
|  | 
 | ||||||
|  | namespace builder { | ||||||
|  | 
 | ||||||
|  | // constexpr mlir::Type GetMlirType(const PrimitiveType primitiveType) {
 | ||||||
|  | // switch (primitiveType)
 | ||||||
|  | // {
 | ||||||
|  | // case PrimitiveType::BF16 :
 | ||||||
|  | //   return 
 | ||||||
|  | //   break;
 | ||||||
|  | 
 | ||||||
|  | // default:
 | ||||||
|  | //   break;
 | ||||||
|  | // }
 | ||||||
|  | 
 | ||||||
|  | // }
 | ||||||
|  | 
 | ||||||
|  | }  // namespace builder
 | ||||||
|  | @ -0,0 +1,29 @@ | ||||||
|  | #ifndef BUILDER_PRIMITIVE_TYPE_ | ||||||
|  | #define BUILDER_PRIMITIVE_TYPE_ | ||||||
|  | 
 | ||||||
|  | namespace builder { | ||||||
|  | 
 | ||||||
|  | enum class PrimitiveType { | ||||||
|  |   PRIMITIVE_TYPE_INVALID = 0, | ||||||
|  |   PRED = 1, | ||||||
|  |   S8 = 2, | ||||||
|  |   S16 = 3, | ||||||
|  |   S32 = 4, | ||||||
|  |   S64 = 5, | ||||||
|  |   U8 = 6, | ||||||
|  |   U16 = 7, | ||||||
|  |   U32 = 8, | ||||||
|  |   U64 = 9, | ||||||
|  |   F16 = 10, | ||||||
|  |   F32 = 11, | ||||||
|  |   BF16 = 16, | ||||||
|  |   F64 = 12, | ||||||
|  |   C64 = 15, | ||||||
|  |   C128 = 18, | ||||||
|  |   TUPLE = 13, | ||||||
|  |   OPAQUE_TYPE = 14, | ||||||
|  |   TOKEN = 17 | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | #endif | ||||||
|  | @ -0,0 +1,17 @@ | ||||||
|  | #ifndef BUILDER_SHAPE_ | ||||||
|  | #define BUILDER_SHAPE_ | ||||||
|  | 
 | ||||||
|  | #include <iostream> | ||||||
|  | #include <memory> | ||||||
|  | 
 | ||||||
|  | namespace builder { | ||||||
|  | class Shape { | ||||||
|  |   class Impl; | ||||||
|  |   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   std::shared_ptr<Impl> impl_; | ||||||
|  | }; | ||||||
|  | }  // namespace builder
 | ||||||
|  | 
 | ||||||
|  | #endif | ||||||
|  | @ -0,0 +1,24 @@ | ||||||
|  | #ifndef BUILDER_SHAPEIMPL_ | ||||||
|  | #define BUILDER_SHAPEIMPL_ | ||||||
|  | 
 | ||||||
|  | #include "Builder.h" | ||||||
|  | #include "llvm/Support/Casting.h" | ||||||
|  | #include "mlir/IR/Attributes.h" | ||||||
|  | #include "mlir/IR/Builders.h" | ||||||
|  | #include "mlir/IR/MLIRContext.h" | ||||||
|  | #include "mlir/IR/Operation.h" | ||||||
|  | #include "mlir/IR/Types.h" | ||||||
|  | #include "mlir/IR/Value.h" | ||||||
|  | 
 | ||||||
|  | namespace builder { | ||||||
|  | 
 | ||||||
|  | class Shape::Impl { | ||||||
|  |  public: | ||||||
|  |   Impl() = default; | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace builder
 | ||||||
|  | 
 | ||||||
|  | #endif | ||||||
|  | @ -0,0 +1,22 @@ | ||||||
|  | #include "Tensor.h" | ||||||
|  | 
 | ||||||
|  | #include "Shape.h" | ||||||
|  | #include "TensorImpl.h" | ||||||
|  | #include "llvm/Support/Casting.h" | ||||||
|  | // #include "mlir/Dialect/StandardOps/Ops.h"
 | ||||||
|  | // #include "mlir/IR/Attributes.h"
 | ||||||
|  | // #include "mlir/IR/Operation.h"
 | ||||||
|  | // #include "mlir/IR/StandardTypes.h"
 | ||||||
|  | // #include "mlir/IR/Types.h"
 | ||||||
|  | // #include "mlir/IR/Value.h"
 | ||||||
|  | 
 | ||||||
|  | namespace builder { | ||||||
|  | 
 | ||||||
|  | // Tensor::Tensor() : impl_(std::make_shared<Impl>()) {}
 | ||||||
|  | Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType) | ||||||
|  |     : impl_(std::make_shared<Impl>(shape, primitiveType)) {} | ||||||
|  | 
 | ||||||
|  | Shape Tensor::GetShape() { return impl_->GetShape(); } | ||||||
|  | PrimitiveType Tensor::GetType() { return impl_->GetType(); } | ||||||
|  | 
 | ||||||
|  | }  // namespace builder
 | ||||||
|  | @ -0,0 +1,25 @@ | ||||||
|  | #ifndef BUILDER_TENSOR_ | ||||||
|  | #define BUILDER_TENSOR_ | ||||||
|  | 
 | ||||||
|  | #include <iostream> | ||||||
|  | #include <memory> | ||||||
|  | 
 | ||||||
|  | #include "PrimitiveType.h" | ||||||
|  | #include "Shape.h" | ||||||
|  | 
 | ||||||
|  | namespace builder { | ||||||
|  | 
 | ||||||
|  | class Tensor { | ||||||
|  |  public: | ||||||
|  |   Tensor(Shape& shape, PrimitiveType& primitiveType); | ||||||
|  |   class Impl; | ||||||
|  |   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||||
|  |   Shape GetShape(); | ||||||
|  |   PrimitiveType GetType(); | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   std::shared_ptr<Impl> impl_; | ||||||
|  | }; | ||||||
|  | }  // namespace builder
 | ||||||
|  | 
 | ||||||
|  | #endif | ||||||
|  | @ -0,0 +1,32 @@ | ||||||
|  | #ifndef BUILDER_TENSORIMPL_ | ||||||
|  | #define BUILDER_TENSORIMPL_ | ||||||
|  | 
 | ||||||
|  | #include "Builder.h" | ||||||
|  | #include "PrimitiveType.h" | ||||||
|  | #include "Shape.h" | ||||||
|  | #include "llvm/Support/Casting.h" | ||||||
|  | #include "mlir/IR/Attributes.h" | ||||||
|  | #include "mlir/IR/Builders.h" | ||||||
|  | #include "mlir/IR/MLIRContext.h" | ||||||
|  | #include "mlir/IR/Operation.h" | ||||||
|  | #include "mlir/IR/Types.h" | ||||||
|  | #include "mlir/IR/Value.h" | ||||||
|  | 
 | ||||||
|  | namespace builder { | ||||||
|  | 
 | ||||||
|  | class Tensor::Impl { | ||||||
|  |  public: | ||||||
|  |   Impl() = default; | ||||||
|  |   Impl(Shape& shape, PrimitiveType& primitiveType) | ||||||
|  |       : shape_(shape), primitiveType_(primitiveType) {} | ||||||
|  |   Shape GetShape() { return shape_; } | ||||||
|  |   PrimitiveType GetType() { return primitiveType_; } | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   Shape shape_; | ||||||
|  |   PrimitiveType primitiveType_; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace builder
 | ||||||
|  | 
 | ||||||
|  | #endif | ||||||
|  | @ -29,7 +29,7 @@ | ||||||
| #include "llvm/TableGen/Record.h" | #include "llvm/TableGen/Record.h" | ||||||
| #include "llvm/TableGen/TableGenBackend.h" | #include "llvm/TableGen/TableGenBackend.h" | ||||||
| 
 | 
 | ||||||
| #include "iostream" | #include <iostream> | ||||||
| 
 | 
 | ||||||
| #define DEBUG_TYPE "mlir-tblgen-opdefgen" | #define DEBUG_TYPE "mlir-tblgen-opdefgen" | ||||||
| 
 | 
 | ||||||
|  | @ -42,28 +42,117 @@ static const char *const generatedArgName = "odsArg"; | ||||||
| static const char *const odsBuilder = "odsBuilder"; | static const char *const odsBuilder = "odsBuilder"; | ||||||
| static const char *const builderOpState = "odsState"; | static const char *const builderOpState = "odsState"; | ||||||
| 
 | 
 | ||||||
| static const std::map<std::string, std::string> typeMapMLIR = { | namespace { | ||||||
|     {"::mlir::StringAttr", "std::string"}, | struct mlirTypeWrap { | ||||||
|     {"::mlir::IntegerAttr", "int"}, |   std::string Name; | ||||||
|     {"::mlir::DenseIntElementsAttr", "std::vector<int>"}, |   std::string (*ConvertToMlir)(std::string &, OpMethodBody &); | ||||||
|     {"::mlir::mhlo::ChannelHandle", "ChannelHandle"}, |  | ||||||
|     {"::mlir::FloatAttr", "float"}, |  | ||||||
|     {"::mlir::BoolAttr", "bool"}, |  | ||||||
|     {"::mlir::ElementsAttr", "::builder::Array"}, |  | ||||||
|     {"::mlir::DenseElementsAttr", "::builder::Tensor"}, |  | ||||||
|     // current only support string array.
 |  | ||||||
|     {"::mlir::ArrayAttr", "std::vector<std::string>"}, |  | ||||||
|     {"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"}, |  | ||||||
|     {"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"}, |  | ||||||
|     {"::mlir::mhlo::GatherDimensionNumbers", |  | ||||||
|      "::builder::GatherDimensionNumbers"}, |  | ||||||
|     {"::mlir::mhlo::ScatterDimensionNumbers", |  | ||||||
|      "::builder::ScatterDimensionNumbers"}, |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | static const std::map<std::string, mlirTypeWrap> typeMapMLIR = { | ||||||
|  |     {"::mlir::StringAttr", | ||||||
|  |      {"std::string", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         body << "  mlir::StringAttr " << var << "_mlir = mlir::StringAttr::get(" | ||||||
|  |              << var << ", ctx);\n"; | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  |     {"::mlir::IntegerAttr", | ||||||
|  |      {"builder::Integer", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         // body << "  mlir::IntegerAttr " << var << "_mlir = mlir::IntegerAttr::get("
 | ||||||
|  |         //      << var << ", ctx);\n";
 | ||||||
|  |         body << "  mlir::IntegerAttr " << var << "_mlir = " << var | ||||||
|  |              << ".GetAttr(ctx);\n"; | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  |     {"::mlir::DenseIntElementsAttr", | ||||||
|  |      {"std::vector<int>", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         body << "  mlir::DenseIntElementsAttr " << var | ||||||
|  |              << "_mlir = mlir::DenseIntElementsAttr::get(" | ||||||
|  |              << "mlir::VectorType::get(" << var | ||||||
|  |              << ".size(), opBuilder->getIntegerType(32))," << var << ");\n"; | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  |     {"::mlir::mhlo::ChannelHandle", | ||||||
|  |      {"ChannelHandle", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  |     {"::mlir::FloatAttr", | ||||||
|  |      {"float", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         body << "  mlir::FloatAttr " << var << "_mlir = mlir::FloatAttr::get(" | ||||||
|  |              << var << ", ctx);\n"; | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  |     {"::mlir::BoolAttr", | ||||||
|  |      {"bool", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         body << "  mlir::BoolAttr " << var << "_mlir = mlir::BoolAttr::get(" | ||||||
|  |              << var << ", ctx);\n"; | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  |     {"::mlir::ElementsAttr", | ||||||
|  |      {"::builder::Array", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  |     {"::mlir::DenseElementsAttr", | ||||||
|  |      {"::builder::Tensor", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  |     // current only support string array.
 | ||||||
|  |     {"::mlir::ArrayAttr", | ||||||
|  |      {"std::vector<std::string>", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  |     {"::mlir::mhlo::ConvDimensionNumbers", | ||||||
|  |      {"::builder::ConvDimensionNumbers", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  |     {"::mlir::mhlo::DotDimensionNumbers", | ||||||
|  |      {"::builder::DotDimensionNumbers", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  |     {"::mlir::mhlo::GatherDimensionNumbers", | ||||||
|  |      {"::builder::GatherDimensionNumbers", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  |     {"::mlir::mhlo::ScatterDimensionNumbers", | ||||||
|  |      {"::builder::ScatterDimensionNumbers", | ||||||
|  |       [](std::string &var, OpMethodBody &body) -> std::string { | ||||||
|  |         return var + "_mlir"; | ||||||
|  |       }}}, | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | // static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
 | ||||||
|  | //     {"::mlir::StringAttr", "std::string"},
 | ||||||
|  | //     {"::mlir::IntegerAttr", "int"},
 | ||||||
|  | //     {"::mlir::DenseIntElementsAttr", "std::vector<int>"},
 | ||||||
|  | //     {"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
 | ||||||
|  | //     {"::mlir::FloatAttr", "float"},
 | ||||||
|  | //     {"::mlir::BoolAttr", "bool"},
 | ||||||
|  | //     {"::mlir::ElementsAttr", "::builder::Array"},
 | ||||||
|  | //     {"::mlir::DenseElementsAttr", "::builder::Tensor"},
 | ||||||
|  | //     // current only support string array.
 | ||||||
|  | //     {"::mlir::ArrayAttr", "std::vector<std::string>"},
 | ||||||
|  | //     {"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
 | ||||||
|  | //     {"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
 | ||||||
|  | //     {"::mlir::mhlo::GatherDimensionNumbers",
 | ||||||
|  | //      "::builder::GatherDimensionNumbers"},
 | ||||||
|  | //     {"::mlir::mhlo::ScatterDimensionNumbers",
 | ||||||
|  | //      "::builder::ScatterDimensionNumbers"},
 | ||||||
|  | // };
 | ||||||
|  | 
 | ||||||
| StringRef typeConvertFromMLIR(StringRef type) { | StringRef typeConvertFromMLIR(StringRef type) { | ||||||
|   auto re = typeMapMLIR.find(type.str()); |   auto re = typeMapMLIR.find(type.str()); | ||||||
|   if (re != typeMapMLIR.end()) return StringRef(re->second); |   if (re != typeMapMLIR.end()) return StringRef(re->second.Name); | ||||||
|   return type; |   return type; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -75,6 +164,7 @@ StringRef getReturnType(const Attribute &att) { | ||||||
|   auto type = att.getStorageType(); |   auto type = att.getStorageType(); | ||||||
|   return typeConvertFromMLIR(type); |   return typeConvertFromMLIR(type); | ||||||
| } | } | ||||||
|  | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| // The logic to calculate the actual value range for a declared operand/result
 | // The logic to calculate the actual value range for a declared operand/result
 | ||||||
| // of an op with variadic operands/results. Note that this logic is not for
 | // of an op with variadic operands/results. Note that this logic is not for
 | ||||||
|  | @ -312,14 +402,6 @@ static std::string getArgumentName(const Operator &op, int index) { | ||||||
|     return std::string(formatv("{0}_{1}", generatedArgName, index)); |     return std::string(formatv("{0}_{1}", generatedArgName, index)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Returns true if we can use unwrapped value for the given `attr` in builders.
 |  | ||||||
| static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) { |  | ||||||
|   return getReturnType(attr) != getStorageType(attr) && |  | ||||||
|          // We need to wrap the raw value into an attribute in the builder impl
 |  | ||||||
|          // so we need to make sure that the attribute specifies how to do that.
 |  | ||||||
|          !attr.getConstBuilderTemplate().empty(); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| // Op emitter
 | // Op emitter
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
|  | @ -427,7 +509,7 @@ private: | ||||||
|                       AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); |                       AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); | ||||||
| 
 | 
 | ||||||
|   // Adds op arguments and regions into operation state for build() methods.
 |   // Adds op arguments and regions into operation state for build() methods.
 | ||||||
|   void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,llvm::SmallVector<OpMethodParameter, 4> paramList, |   void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, | ||||||
|                                               bool isRawValueAttr = false); |                                               bool isRawValueAttr = false); | ||||||
| 
 | 
 | ||||||
|   // Generates canonicalizer declaration for the operation.
 |   // Generates canonicalizer declaration for the operation.
 | ||||||
|  | @ -479,7 +561,9 @@ private: | ||||||
|   // Generate the type inference interface methods.
 |   // Generate the type inference interface methods.
 | ||||||
|   void genTypeInterfaceMethods(); |   void genTypeInterfaceMethods(); | ||||||
| 
 | 
 | ||||||
| private: |   Operator GetOp() { return op; } | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|   // The TableGen record for this op.
 |   // The TableGen record for this op.
 | ||||||
|   // TODO: OpEmitter should not have a Record directly,
 |   // TODO: OpEmitter should not have a Record directly,
 | ||||||
|   // it should rather go through the Operator for better abstraction.
 |   // it should rather go through the Operator for better abstraction.
 | ||||||
|  | @ -1077,27 +1161,6 @@ void OpEmitter::genNamedSuccessorGetters() { | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| static bool canGenerateUnwrappedBuilder(Operator &op) { |  | ||||||
|   // If this op does not have native attributes at all, return directly to avoid
 |  | ||||||
|   // redefining builders.
 |  | ||||||
|   if (op.getNumNativeAttributes() == 0) |  | ||||||
|     return false; |  | ||||||
| 
 |  | ||||||
|   bool canGenerate = false; |  | ||||||
|   // We are generating builders that take raw values for attributes. We need to
 |  | ||||||
|   // make sure the native attributes have a meaningful "unwrapped" value type
 |  | ||||||
|   // different from the wrapped mlir::Attribute type to avoid redefining
 |  | ||||||
|   // builders. This checks for the op has at least one such native attribute.
 |  | ||||||
|   for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) { |  | ||||||
|     NamedAttribute &namedAttr = op.getAttribute(i); |  | ||||||
|     if (canUseUnwrappedRawValue(namedAttr.attr)) { |  | ||||||
|       canGenerate = true; |  | ||||||
|       break; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|   return canGenerate; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| static bool canInferType(Operator &op) { | static bool canInferType(Operator &op) { | ||||||
|   return op.getTrait("::mlir::InferTypeOpInterface::Trait") && |   return op.getTrait("::mlir::InferTypeOpInterface::Trait") && | ||||||
|          op.getNumRegions() == 0; |          op.getNumRegions() == 0; | ||||||
|  | @ -1106,18 +1169,14 @@ static bool canInferType(Operator &op) { | ||||||
| void OpEmitter::genSeparateArgParamBuilder() { | void OpEmitter::genSeparateArgParamBuilder() { | ||||||
|   SmallVector<AttrParamKind, 2> attrBuilderType; |   SmallVector<AttrParamKind, 2> attrBuilderType; | ||||||
|   attrBuilderType.push_back(AttrParamKind::WrappedAttr); |   attrBuilderType.push_back(AttrParamKind::WrappedAttr); | ||||||
|   // if (canGenerateUnwrappedBuilder(op))
 |  | ||||||
|   //   attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
 |  | ||||||
| 
 | 
 | ||||||
|   // Emit with separate builders with or without unwrapped attributes and/or
 |   // Emit with separate builders with or without unwrapped attributes and/or
 | ||||||
|   // inferring result type.
 |   // inferring result type.
 | ||||||
|   auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, |   auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, | ||||||
|                   bool inferType) { |                   bool inferType) { | ||||||
|     llvm::SmallVector<OpMethodParameter, 4> paramList; |     llvm::SmallVector<OpMethodParameter, 4> paramList; | ||||||
|     llvm::SmallVector<OpMethodParameter, 4> paramList2; |  | ||||||
|     llvm::SmallVector<std::string, 4> resultNames; |     llvm::SmallVector<std::string, 4> resultNames; | ||||||
|     buildParamList(paramList, resultNames, paramKind, attrType); |     buildParamList(paramList, resultNames, paramKind, attrType); | ||||||
|     buildParamList(paramList2, resultNames, paramKind, attrType); |  | ||||||
| 
 | 
 | ||||||
|     auto *m = opClass.addMethodAndPrune( |     auto *m = opClass.addMethodAndPrune( | ||||||
|         "::builder::Op", "build", OpMethod::MP_Static, std::move(paramList)); |         "::builder::Op", "build", OpMethod::MP_Static, std::move(paramList)); | ||||||
|  | @ -1126,10 +1185,10 @@ void OpEmitter::genSeparateArgParamBuilder() { | ||||||
|       return; |       return; | ||||||
|     auto &body = m->body(); |     auto &body = m->body(); | ||||||
|     genCodeForAddingArgAndRegionForBuilder( |     genCodeForAddingArgAndRegionForBuilder( | ||||||
|         body, paramList2, attrType == AttrParamKind::UnwrappedValue); |         body, attrType == AttrParamKind::UnwrappedValue); | ||||||
| 
 | 
 | ||||||
|     // Push all result types to the operation state
 |     // Push all result types to the operation state
 | ||||||
| //"BBBBBBBBBBBB"
 | 
 | ||||||
|   //   if (inferType) {
 |   //   if (inferType) {
 | ||||||
|   //     // Generate builder that infers type too.
 |   //     // Generate builder that infers type too.
 | ||||||
|   //     // TODO: Subsume this with general checking if type can be
 |   //     // TODO: Subsume this with general checking if type can be
 | ||||||
|  | @ -1306,29 +1365,6 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList, | ||||||
|   int numAttrs = 0; |   int numAttrs = 0; | ||||||
| 
 | 
 | ||||||
|   int defaultValuedAttrStartIndex = op.getNumArgs(); |   int defaultValuedAttrStartIndex = op.getNumArgs(); | ||||||
|   if (attrParamKind == AttrParamKind::UnwrappedValue) { |  | ||||||
|     // Calculate the start index from which we can attach default values in the
 |  | ||||||
|     // builder declaration.
 |  | ||||||
|     for (int i = op.getNumArgs() - 1; i >= 0; --i) { |  | ||||||
|       auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>(); |  | ||||||
|       if (!namedAttr || !namedAttr->attr.hasDefaultValue()) |  | ||||||
|         break; |  | ||||||
| 
 |  | ||||||
|       if (!canUseUnwrappedRawValue(namedAttr->attr)) |  | ||||||
|         break; |  | ||||||
| 
 |  | ||||||
|       // Creating an APInt requires us to provide bitwidth, value, and
 |  | ||||||
|       // signedness, which is complicated compared to others. Similarly
 |  | ||||||
|       // for APFloat.
 |  | ||||||
|       // TODO: Adjust the 'returnType' field of such attributes
 |  | ||||||
|       // to support them.
 |  | ||||||
|       StringRef retType = getReturnType(namedAttr->attr); |  | ||||||
|       if (retType == "::llvm::APInt" || retType == "::llvm::APFloat") |  | ||||||
|         break; |  | ||||||
| 
 |  | ||||||
|       defaultValuedAttrStartIndex = i; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |   for (int i = 0, e = op.getNumArgs(); i < e; ++i) { | ||||||
|     auto argument = op.getArg(i); |     auto argument = op.getArg(i); | ||||||
|  | @ -1357,10 +1393,10 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList, | ||||||
|         type = getStorageType(attr); |         type = getStorageType(attr); | ||||||
|         break; |         break; | ||||||
|       case AttrParamKind::UnwrappedValue: |       case AttrParamKind::UnwrappedValue: | ||||||
|         if (canUseUnwrappedRawValue(attr)) |         // if (canUseUnwrappedRawValue(attr))
 | ||||||
|           type = getReturnType(attr); |         //   type = getReturnType(attr);
 | ||||||
|         else |         // else
 | ||||||
|           type = getStorageType(attr); |         //   type = getStorageType(attr);
 | ||||||
|         break; |         break; | ||||||
|       } |       } | ||||||
| 
 | 
 | ||||||
|  | @ -1394,41 +1430,72 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList, | ||||||
|                              llvm::formatv("{0}Count", region.name).str()); |                              llvm::formatv("{0}Count", region.name).str()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void OpEmitter::genCodeForAddingArgAndRegionForBuilder( | void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, | ||||||
|     OpMethodBody &body, llvm::SmallVector<OpMethodParameter, 4> paramList, |                                                        bool isRawValueAttr) { | ||||||
|     bool isRawValueAttr) { |   auto op = GetOp(); | ||||||
|  |   auto operands = op.getOperands(); | ||||||
|  |   auto attrs = op.getAttributes(); | ||||||
|  |   SmallVector<std::string, 4> newAttrs; | ||||||
| 
 | 
 | ||||||
|   body << "//    AAAAAA  \n"; |   // for(auto o : operands){
 | ||||||
|   // for(auto p : paramList)
 |   //   body << "  operands name: "<<o.name<<"  getCPPClassName:"<<o.constraint.getCPPClassName()<<"\n";
 | ||||||
|   // {
 |   // }
 | ||||||
|   // body <<"====  type:"<< p.getType() << "  name:"<< p.getName() <<  "\n";
 |   // for(auto a : attrs){
 | ||||||
|  |   //   body << "  attrs name: "<<a.name<<"  type: "<< a.attr.getStorageType().str()<<"\n";
 | ||||||
|  |   //   std::string attrType = a.attr.getStorageType().str();
 | ||||||
|  |   //   auto ff = typeMapMLIR.find(attrType);
 | ||||||
|  |   //   if(ff != typeMapMLIR.end()){
 | ||||||
|  |   //     body << "//    BBBBBB  \n";
 | ||||||
|  |   //     std::string attrName = a.name.str();
 | ||||||
|  |   //     ff->second.ConvertToMlir(attrName,body);
 | ||||||
|  |   //   }
 | ||||||
|   // }
 |   // }
 | ||||||
| 
 | 
 | ||||||
|   body << "  auto builder = builder.GetImpl();\n"; | 
 | ||||||
|   body << "  auto loc = builder->GetLoc();\n"; |   // body << "//    AAAAAA  \n";
 | ||||||
|   body << "  auto opBuilder = builder->GetBuilder();\n"; |   // for(auto p : paramList){
 | ||||||
|  |   //   body << "  AAA type"<<p.getType()<<"  name"<<p.getName()<<"\n";
 | ||||||
|  |   // }
 | ||||||
|  | 
 | ||||||
|  |   body << "  auto b = builder.GetImpl();\n"; | ||||||
|  |   body << "  auto loc = b->GetLoc();\n"; | ||||||
|  |   body << "  auto opBuilder = b->GetBuilder();\n"; | ||||||
|  |   body << "  auto ctx = b->GetContext();\n"; | ||||||
|  |   for (auto a : attrs) { | ||||||
|  |     std::string attrType = a.attr.getStorageType().str(); | ||||||
|  | 
 | ||||||
|  | if(attrType == "::mlir::DenseIntElementsAttr") | ||||||
|  | { | ||||||
|  |   body << "  //  BBBBBBBB  getStorageType:"<< a.attr.getStorageType().str() <<"\n"; | ||||||
|  |   body << "  //  BBBBBBBB  getReturnType:"<< a.attr.getReturnType().str() <<"\n"; | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     auto typePair = typeMapMLIR.find(attrType); | ||||||
|  |     if (typePair != typeMapMLIR.end()) { | ||||||
|  |       std::string attrName = a.name.str(); | ||||||
|  |       auto mlirName = typePair->second.ConvertToMlir(attrName, body); | ||||||
|  |       newAttrs.emplace_back(mlirName); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|   body << "  mlir::" << op.getDialectName() << "::" << op.getCppClassName() |   body << "  mlir::" << op.getDialectName() << "::" << op.getCppClassName() | ||||||
|        << " currentOp =\n"; |        << " currentOp =\n"; | ||||||
|   body << "    opBuilder.create<mlir::" << op.getDialectName() |   body << "    opBuilder.create<mlir::" << op.getDialectName() | ||||||
|        << "::" << op.getDialectName() << ">(\n"; |        << "::" << op.getDialectName() << ">(\n"; | ||||||
|   if (paramList.size() > 1) { |   body << "      loc"; | ||||||
|     body << "      loc,\n"; |   std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) { | ||||||
|     std::for_each(paramList.begin() + 1, paramList.end() - 1, |     body << ",\n      " << v.name << ".getResult()"; | ||||||
|                   [&](OpMethodParameter &p) { |   }); | ||||||
|                     body << "      " << p.getName() << ",\n"; |   std::for_each(newAttrs.begin(), newAttrs.end(), | ||||||
|                   }); |                 [&](std::string &n) { body << ",\n      " << n; }); | ||||||
|     body << "      " << paramList.back().getName() << "\n"; |   body << "\n    );\n"; | ||||||
|   } else { |  | ||||||
|     body << "      loc\n"; |  | ||||||
|   } |  | ||||||
|   body << "    );\n"; |  | ||||||
|   body << "  builder::" << op.getCppClassName() << " builderOp;\n"; |   body << "  builder::" << op.getCppClassName() << " builderOp;\n"; | ||||||
|   body << "  auto opImpl = builderOp.GetImpl();\n"; |   body << "  auto opImpl = builderOp.GetImpl();\n"; | ||||||
|   body << "  opImpl.SetOperation(currentOp.getOperation());\n"; |   body << "  opImpl.SetOperation(currentOp.getOperation());\n"; | ||||||
|   body << "  return builderOp;\n"; |   body << "  return builderOp;\n"; | ||||||
| 
 | 
 | ||||||
|    |  | ||||||
| 
 |  | ||||||
|   // // Push all operands to the result.
 |   // // Push all operands to the result.
 | ||||||
|   // for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
 |   // for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
 | ||||||
|   //   std::string argName = getArgumentName(op, i);
 |   //   std::string argName = getArgumentName(op, i);
 | ||||||
|  | @ -2079,184 +2146,6 @@ void OpEmitter::genOpAsmInterface() { | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| //===----------------------------------------------------------------------===//
 |  | ||||||
| // OpOperandAdaptor emitter
 |  | ||||||
| //===----------------------------------------------------------------------===//
 |  | ||||||
| 
 |  | ||||||
| namespace { |  | ||||||
| // Helper class to emit Op operand adaptors to an output stream.  Operand
 |  | ||||||
| // adaptors are wrappers around ArrayRef<Value> that provide named operand
 |  | ||||||
| // getters identical to those defined in the Op.
 |  | ||||||
| class OpOperandAdaptorEmitter { |  | ||||||
| public: |  | ||||||
|   static void emitDecl(const Operator &op, raw_ostream &os); |  | ||||||
|   static void emitDef(const Operator &op, raw_ostream &os); |  | ||||||
| 
 |  | ||||||
| private: |  | ||||||
|   explicit OpOperandAdaptorEmitter(const Operator &op); |  | ||||||
| 
 |  | ||||||
|   // Add verification function. This generates a verify method for the adaptor
 |  | ||||||
|   // which verifies all the op-independent attribute constraints.
 |  | ||||||
|   void addVerification(); |  | ||||||
| 
 |  | ||||||
|   const Operator &op; |  | ||||||
|   Class adaptor; |  | ||||||
| }; |  | ||||||
| } // end namespace
 |  | ||||||
| 
 |  | ||||||
| OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) |  | ||||||
|     : op(op), adaptor(op.getAdaptorName()) { |  | ||||||
|   adaptor.newField("::mlir::ValueRange", "odsOperands"); |  | ||||||
|   adaptor.newField("::mlir::DictionaryAttr", "odsAttrs"); |  | ||||||
|   adaptor.newField("::mlir::RegionRange", "odsRegions"); |  | ||||||
|   const auto *attrSizedOperands = |  | ||||||
|       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); |  | ||||||
|   { |  | ||||||
|     SmallVector<OpMethodParameter, 2> paramList; |  | ||||||
|     paramList.emplace_back("::mlir::ValueRange", "values"); |  | ||||||
|     paramList.emplace_back("::mlir::DictionaryAttr", "attrs", |  | ||||||
|                            attrSizedOperands ? "" : "nullptr"); |  | ||||||
|     paramList.emplace_back("::mlir::RegionRange", "regions", "{}"); |  | ||||||
|     auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList)); |  | ||||||
| 
 |  | ||||||
|     constructor->addMemberInitializer("odsOperands", "values"); |  | ||||||
|     constructor->addMemberInitializer("odsAttrs", "attrs"); |  | ||||||
|     constructor->addMemberInitializer("odsRegions", "regions"); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   { |  | ||||||
|     auto *constructor = adaptor.addConstructorAndPrune( |  | ||||||
|         llvm::formatv("{0}&", op.getCppClassName()).str(), "op"); |  | ||||||
|     constructor->addMemberInitializer("odsOperands", "op->getOperands()"); |  | ||||||
|     constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()"); |  | ||||||
|     constructor->addMemberInitializer("odsRegions", "op->getRegions()"); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   { |  | ||||||
|     auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands"); |  | ||||||
|     m->body() << "  return odsOperands;"; |  | ||||||
|   } |  | ||||||
|   std::string sizeAttrInit = |  | ||||||
|       formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes"); |  | ||||||
|   generateNamedOperandGetters(op, adaptor, sizeAttrInit, |  | ||||||
|                               /*rangeType=*/"::mlir::ValueRange", |  | ||||||
|                               /*rangeBeginCall=*/"odsOperands.begin()", |  | ||||||
|                               /*rangeSizeCall=*/"odsOperands.size()", |  | ||||||
|                               /*getOperandCallPattern=*/"odsOperands[{0}]"); |  | ||||||
| 
 |  | ||||||
|   FmtContext fctx; |  | ||||||
|   fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); |  | ||||||
| 
 |  | ||||||
|   auto emitAttr = [&](StringRef name, Attribute attr) { |  | ||||||
|     auto &body = adaptor.addMethodAndPrune(getStorageType(attr), name)->body(); |  | ||||||
|     body << "  assert(odsAttrs && \"no attributes when constructing adapter\");" |  | ||||||
|          << "\n  " << getStorageType(attr) << " attr = " |  | ||||||
|          << "odsAttrs.get(\"" << name << "\")."; |  | ||||||
|     if (attr.hasDefaultValue() || attr.isOptional()) |  | ||||||
|       body << "dyn_cast_or_null<"; |  | ||||||
|     else |  | ||||||
|       body << "cast<"; |  | ||||||
|     body << getStorageType(attr) << ">();\n"; |  | ||||||
| 
 |  | ||||||
|     if (attr.hasDefaultValue()) { |  | ||||||
|       // Use the default value if attribute is not set.
 |  | ||||||
|       // TODO: this is inefficient, we are recreating the attribute for every
 |  | ||||||
|       // call. This should be set instead.
 |  | ||||||
|       std::string defaultValue = std::string( |  | ||||||
|           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); |  | ||||||
|       body << "  if (!attr)\n    attr = " << defaultValue << ";\n"; |  | ||||||
|     } |  | ||||||
|     body << "  return attr;\n"; |  | ||||||
|   }; |  | ||||||
| 
 |  | ||||||
|   { |  | ||||||
|     auto *m = |  | ||||||
|         adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes"); |  | ||||||
|     m->body() << "  return odsAttrs;"; |  | ||||||
|   } |  | ||||||
|   for (auto &namedAttr : op.getAttributes()) { |  | ||||||
|     const auto &name = namedAttr.name; |  | ||||||
|     const auto &attr = namedAttr.attr; |  | ||||||
|     if (!attr.isDerivedAttr()) |  | ||||||
|       emitAttr(name, attr); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   unsigned numRegions = op.getNumRegions(); |  | ||||||
|   if (numRegions > 0) { |  | ||||||
|     auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions"); |  | ||||||
|     m->body() << "  return odsRegions;"; |  | ||||||
|   } |  | ||||||
|   for (unsigned i = 0; i < numRegions; ++i) { |  | ||||||
|     const auto ®ion = op.getRegion(i); |  | ||||||
|     if (region.name.empty()) |  | ||||||
|       continue; |  | ||||||
| 
 |  | ||||||
|     // Generate the accessors for a variadic region.
 |  | ||||||
|     if (region.isVariadic()) { |  | ||||||
|       auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", region.name); |  | ||||||
|       m->body() << formatv("  return odsRegions.drop_front({0});", i); |  | ||||||
|       continue; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     auto *m = adaptor.addMethodAndPrune("::mlir::Region &", region.name); |  | ||||||
|     m->body() << formatv("  return *odsRegions[{0}];", i); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   // Add verification function.
 |  | ||||||
|   addVerification(); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| void OpOperandAdaptorEmitter::addVerification() { |  | ||||||
|   auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify", |  | ||||||
|                                            "::mlir::Location", "loc"); |  | ||||||
|   auto &body = method->body(); |  | ||||||
| 
 |  | ||||||
|   const char *checkAttrSizedValueSegmentsCode = R"( |  | ||||||
|   { |  | ||||||
|     auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>(); |  | ||||||
|     auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements(); |  | ||||||
|     if (numElements != {1}) |  | ||||||
|       return emitError(loc, "'{0}' attribute for specifying {2} segments " |  | ||||||
|                        "must have {1} elements, but got ") << numElements; |  | ||||||
|   } |  | ||||||
|   )"; |  | ||||||
| 
 |  | ||||||
|   // Verify a few traits first so that we can use
 |  | ||||||
|   // getODSOperands()/getODSResults() in the rest of the verifier.
 |  | ||||||
|   for (auto &trait : op.getTraits()) { |  | ||||||
|     if (auto *t = dyn_cast<tblgen::NativeTrait>(&trait)) { |  | ||||||
|       if (t->getFullyQualifiedTraitName() == |  | ||||||
|           "::mlir::OpTrait::AttrSizedOperandSegments") { |  | ||||||
|         body << formatv(checkAttrSizedValueSegmentsCode, |  | ||||||
|                         "operand_segment_sizes", op.getNumOperands(), |  | ||||||
|                         "operand"); |  | ||||||
|       } else if (t->getFullyQualifiedTraitName() == |  | ||||||
|                  "::mlir::OpTrait::AttrSizedResultSegments") { |  | ||||||
|         body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes", |  | ||||||
|                         op.getNumResults(), "result"); |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   FmtContext verifyCtx; |  | ||||||
|   populateSubstitutions(op, "odsAttrs.get", "getODSOperands", |  | ||||||
|                         "<no results should be generated>", verifyCtx); |  | ||||||
|   genAttributeVerifier(op, "odsAttrs.get", |  | ||||||
|                        Twine("emitError(loc, \"'") + op.getOperationName() + |  | ||||||
|                            "' op \"", |  | ||||||
|                        /*emitVerificationRequiringOp*/ false, verifyCtx, body); |  | ||||||
| 
 |  | ||||||
|   body << "  return ::mlir::success();"; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) { |  | ||||||
|   OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) { |  | ||||||
|   OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Emits the opcode enum and op classes.
 | // Emits the opcode enum and op classes.
 | ||||||
| static void emitOpClasses(const RecordKeeper &recordKeeper, | static void emitOpClasses(const RecordKeeper &recordKeeper, | ||||||
|                           const std::vector<Record *> &defs, raw_ostream &os, |                           const std::vector<Record *> &defs, raw_ostream &os, | ||||||
|  | @ -2286,11 +2175,9 @@ static void emitOpClasses(const RecordKeeper &recordKeeper, | ||||||
|     NamespaceEmitter emitter(os, op.getCppNamespace()); |     NamespaceEmitter emitter(os, op.getCppNamespace()); | ||||||
|     if (emitDecl) { |     if (emitDecl) { | ||||||
|       // os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
 |       // os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
 | ||||||
|       // OpOperandAdaptorEmitter::emitDecl(op, os);
 |  | ||||||
|       OpEmitter::emitDecl(op, os, staticVerifierEmitter); |       OpEmitter::emitDecl(op, os, staticVerifierEmitter); | ||||||
|     } else { |     } else { | ||||||
|       // os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
 |       // os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
 | ||||||
|       // OpOperandAdaptorEmitter::emitDef(op, os);
 |  | ||||||
|       OpEmitter::emitDef(op, os, staticVerifierEmitter); |       OpEmitter::emitDef(op, os, staticVerifierEmitter); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | @ -1,16 +0,0 @@ | ||||||
| #ifndef BUILDER_OP_ |  | ||||||
| #define BUILDER_OP_ |  | ||||||
| 
 |  | ||||||
| #include "iostream" |  | ||||||
| 
 |  | ||||||
| namespace builder { |  | ||||||
| class Op { |  | ||||||
|   class Impl; |  | ||||||
|   std::shared_ptr<Impl> GetImpl() { return _impl; } |  | ||||||
| 
 |  | ||||||
|  private: |  | ||||||
|   std::shared_ptr<Impl> _impl; |  | ||||||
| } |  | ||||||
| }  // namespace builder
 |  | ||||||
| 
 |  | ||||||
| #endif |  | ||||||
|  | @ -1,16 +0,0 @@ | ||||||
| #ifndef BUILDER_TENSOR_ |  | ||||||
| #define BUILDER_TENSOR_ |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| #include "iostream" |  | ||||||
| 
 |  | ||||||
| namespace builder{ |  | ||||||
|   class Tensor{ |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| #endif |  | ||||||
|  | @ -1,10 +0,0 @@ | ||||||
| #ifndef BUILDER_TYPE_ |  | ||||||
| #define BUILDER_TYPE_ |  | ||||||
| 
 |  | ||||||
| #include "iostream" |  | ||||||
| 
 |  | ||||||
| namespace builder { |  | ||||||
| class Type {} |  | ||||||
| }  // namespace builder
 |  | ||||||
| 
 |  | ||||||
| #endif |  | ||||||
		Loading…
	
		Reference in New Issue