add integer attribute convert
This commit is contained in:
parent
9d5166684c
commit
95fc37ffa8
2
BUILD
2
BUILD
|
@ -129,6 +129,8 @@ cc_binary(
|
||||||
"tools/mlir-tblgen-builder/*.cpp",
|
"tools/mlir-tblgen-builder/*.cpp",
|
||||||
"tools/mlir-tblgen-builder/TableGen/*.h",
|
"tools/mlir-tblgen-builder/TableGen/*.h",
|
||||||
"tools/mlir-tblgen-builder/TableGen/*.cpp",
|
"tools/mlir-tblgen-builder/TableGen/*.cpp",
|
||||||
|
"tools/mlir-tblgen-builder/Builder/*.h",
|
||||||
|
"tools/mlir-tblgen-builder/Builder/*.cpp",
|
||||||
]),
|
]),
|
||||||
deps = [
|
deps = [
|
||||||
"@llvm-project//mlir:MlirTableGenMain",
|
"@llvm-project//mlir:MlirTableGenMain",
|
||||||
|
|
|
@ -14,8 +14,25 @@ limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "iostream"
|
#include "iostream"
|
||||||
|
#include "llvm/Support/CommandLine.h"
|
||||||
#include "mlir/Support/MlirOptMain.h"
|
#include "llvm/Support/DynamicLibrary.h"
|
||||||
|
#include "llvm/Support/FileUtilities.h"
|
||||||
|
#include "llvm/Support/InitLLVM.h"
|
||||||
|
#include "llvm/Support/Regex.h"
|
||||||
|
#include "llvm/Support/SourceMgr.h"
|
||||||
|
#include "llvm/Support/StringSaver.h"
|
||||||
|
#include "llvm/Support/ToolOutputFile.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||||
|
#include "mlir/ExecutionEngine/JitRunner.h"
|
||||||
|
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||||
#include "mlir/IR/AsmState.h"
|
#include "mlir/IR/AsmState.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
@ -23,57 +40,26 @@ limitations under the License.
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/InitAllDialects.h"
|
||||||
|
#include "mlir/InitAllPasses.h"
|
||||||
#include "mlir/Parser.h"
|
#include "mlir/Parser.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Support/DebugCounter.h"
|
#include "mlir/Support/DebugCounter.h"
|
||||||
#include "mlir/Support/FileUtilities.h"
|
#include "mlir/Support/FileUtilities.h"
|
||||||
|
#include "mlir/Support/MlirOptMain.h"
|
||||||
#include "mlir/Support/Timing.h"
|
#include "mlir/Support/Timing.h"
|
||||||
#include "mlir/Support/ToolUtilities.h"
|
#include "mlir/Support/ToolUtilities.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
|
||||||
#include "llvm/Support/FileUtilities.h"
|
|
||||||
#include "llvm/Support/InitLLVM.h"
|
|
||||||
#include "llvm/Support/Regex.h"
|
|
||||||
#include "llvm/Support/SourceMgr.h"
|
|
||||||
#include "llvm/Support/StringSaver.h"
|
|
||||||
#include "llvm/Support/ToolOutputFile.h"
|
|
||||||
#include "llvm/Support/DynamicLibrary.h"
|
|
||||||
|
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
||||||
#include "mlir/ExecutionEngine/JitRunner.h"
|
|
||||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
|
||||||
#include "mlir/Target/LLVMIR/Dialect/All.h"
|
#include "mlir/Target/LLVMIR/Dialect/All.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#include "mlir/InitAllDialects.h"
|
|
||||||
#include "mlir/InitAllPasses.h"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
|
||||||
|
|
||||||
|
|
||||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
|
||||||
// #include "mlir/IR/BuiltinTypes.h"
|
// #include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/Support/TargetSelect.h"
|
|
||||||
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
|
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
|
||||||
#include "llvm/ExecutionEngine/Orc/Mangling.h"
|
#include "llvm/ExecutionEngine/Orc/Mangling.h"
|
||||||
|
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/IR/IRBuilder.h"
|
#include "llvm/IR/IRBuilder.h"
|
||||||
#include "llvm/IR/LLVMContext.h"
|
#include "llvm/IR/LLVMContext.h"
|
||||||
#include "llvm/IR/LegacyPassNameParser.h"
|
#include "llvm/IR/LegacyPassNameParser.h"
|
||||||
|
#include "llvm/Support/TargetSelect.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
|
@ -87,11 +73,9 @@ struct MemRefDescriptor {
|
||||||
int64_t sizes[N];
|
int64_t sizes[N];
|
||||||
int64_t strides[N];
|
int64_t strides[N];
|
||||||
};
|
};
|
||||||
}
|
} // namespace utils
|
||||||
|
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
|
|
||||||
llvm::InitLLVM y(argc, argv);
|
llvm::InitLLVM y(argc, argv);
|
||||||
llvm::InitializeNativeTarget();
|
llvm::InitializeNativeTarget();
|
||||||
llvm::InitializeNativeTargetAsmPrinter();
|
llvm::InitializeNativeTargetAsmPrinter();
|
||||||
|
@ -105,7 +89,6 @@ int main(int argc, char **argv) {
|
||||||
// registerDefaultTimingManagerCLOptions();
|
// registerDefaultTimingManagerCLOptions();
|
||||||
DebugCounter::registerCLOptions();
|
DebugCounter::registerCLOptions();
|
||||||
|
|
||||||
|
|
||||||
mlir::registerAllPasses();
|
mlir::registerAllPasses();
|
||||||
mlir::mhlo::registerAllMhloPasses();
|
mlir::mhlo::registerAllMhloPasses();
|
||||||
mlir::lmhlo::registerAllLmhloPasses();
|
mlir::lmhlo::registerAllLmhloPasses();
|
||||||
|
@ -120,24 +103,22 @@ int main(int argc, char **argv) {
|
||||||
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
|
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
|
||||||
registry.insert<mlir::disc_ral::RalDialect>();
|
registry.insert<mlir::disc_ral::RalDialect>();
|
||||||
|
|
||||||
|
|
||||||
// failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
|
// failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
|
||||||
// registry,
|
// registry,
|
||||||
// /*preloadDialectsInContext=*/false));
|
// /*preloadDialectsInContext=*/false));
|
||||||
// return 0;
|
// return 0;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
std::string errorMessage;
|
std::string errorMessage;
|
||||||
|
|
||||||
// auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir", &errorMessage);
|
// auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir",
|
||||||
auto file = mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage);
|
// &errorMessage);
|
||||||
|
auto file =
|
||||||
|
mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage);
|
||||||
std::cout << "errorMessage:" << errorMessage << std::endl;
|
std::cout << "errorMessage:" << errorMessage << std::endl;
|
||||||
|
|
||||||
SourceMgr sourceMgr;
|
SourceMgr sourceMgr;
|
||||||
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
|
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
|
||||||
|
|
||||||
|
|
||||||
SmallVector<const llvm::PassInfo *, 4> passes;
|
SmallVector<const llvm::PassInfo *, 4> passes;
|
||||||
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
|
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
|
||||||
auto tmOrError = tmBuilderOrError->createTargetMachine();
|
auto tmOrError = tmBuilderOrError->createTargetMachine();
|
||||||
|
@ -145,8 +126,6 @@ int main(int argc, char **argv) {
|
||||||
auto transformer = mlir::makeLLVMPassesTransformer(
|
auto transformer = mlir::makeLLVMPassesTransformer(
|
||||||
passes, 0, /*targetMachine=*/tmOrError->get(), 0);
|
passes, 0, /*targetMachine=*/tmOrError->get(), 0);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
MLIRContext context(registry);
|
MLIRContext context(registry);
|
||||||
context.loadAllAvailableDialects();
|
context.loadAllAvailableDialects();
|
||||||
OwningModuleRef module(parseSourceFile(sourceMgr, &context));
|
OwningModuleRef module(parseSourceFile(sourceMgr, &context));
|
||||||
|
@ -157,7 +136,6 @@ int main(int argc, char **argv) {
|
||||||
return failure();
|
return failure();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
|
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
|
||||||
// RUN: -hlo-legalize-to-lhlo -buffer-hoisting \
|
// RUN: -hlo-legalize-to-lhlo -buffer-hoisting \
|
||||||
// RUN: -buffer-deallocation -canonicalize -cse \
|
// RUN: -buffer-deallocation -canonicalize -cse \
|
||||||
|
@ -192,8 +170,6 @@ int main(int argc, char **argv) {
|
||||||
pm.run(*module);
|
pm.run(*module);
|
||||||
module->dump();
|
module->dump();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
std::cout << "DEBUG load module success" << std::endl;
|
std::cout << "DEBUG load module success" << std::endl;
|
||||||
|
|
||||||
llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default;
|
llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default;
|
||||||
|
@ -206,12 +182,10 @@ int main(int argc, char **argv) {
|
||||||
// Use absolute library path so that gdb can find the symbol table.
|
// Use absolute library path so that gdb can find the symbol table.
|
||||||
|
|
||||||
std::list<std::string> sharedlib;
|
std::list<std::string> sharedlib;
|
||||||
sharedlib.push_back("/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git");
|
sharedlib.push_back(
|
||||||
|
"/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git");
|
||||||
|
|
||||||
transform(
|
transform(sharedlib, std::back_inserter(libPaths), [](std::string libPath) {
|
||||||
sharedlib,
|
|
||||||
std::back_inserter(libPaths),
|
|
||||||
[](std::string libPath) {
|
|
||||||
SmallString<256> absPath(libPath.begin(), libPath.end());
|
SmallString<256> absPath(libPath.begin(), libPath.end());
|
||||||
cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
|
cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
|
||||||
return absPath;
|
return absPath;
|
||||||
|
@ -226,7 +200,6 @@ int main(int argc, char **argv) {
|
||||||
llvm::StringMap<void *> exportSymbols;
|
llvm::StringMap<void *> exportSymbols;
|
||||||
SmallVector<MlirRunnerDestroyFn> destroyFns;
|
SmallVector<MlirRunnerDestroyFn> destroyFns;
|
||||||
|
|
||||||
|
|
||||||
// Handle libraries that do support mlir-runner init/destroy callbacks.
|
// Handle libraries that do support mlir-runner init/destroy callbacks.
|
||||||
for (auto &libPath : libPaths) {
|
for (auto &libPath : libPaths) {
|
||||||
auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
|
auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
|
||||||
|
@ -255,17 +228,12 @@ int main(int argc, char **argv) {
|
||||||
return symbolMap;
|
return symbolMap;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto expectedEngine =
|
||||||
|
mlir::ExecutionEngine::create(module.get(), nullptr, transformer,
|
||||||
|
jitCodeGenOptLevel, executionEngineLibs);
|
||||||
auto expectedEngine = mlir::ExecutionEngine::create(
|
|
||||||
module.get(), nullptr, transformer, jitCodeGenOptLevel,
|
|
||||||
executionEngineLibs);
|
|
||||||
// if (!expectedEngine)
|
// if (!expectedEngine)
|
||||||
// return expectedEngine.takeError();
|
// return expectedEngine.takeError();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
auto engine = std::move(*expectedEngine);
|
auto engine = std::move(*expectedEngine);
|
||||||
engine->registerSymbols(runtimeSymbolMap);
|
engine->registerSymbols(runtimeSymbolMap);
|
||||||
|
|
||||||
|
@ -277,7 +245,6 @@ int main(int argc, char **argv) {
|
||||||
// if (options.dumpObjectFile)
|
// if (options.dumpObjectFile)
|
||||||
// engine->dumpToObjectFile("a.o");
|
// engine->dumpToObjectFile("a.o");
|
||||||
|
|
||||||
|
|
||||||
float rawdata[6] = {0, 1, 2, 3, 4, 5};
|
float rawdata[6] = {0, 1, 2, 3, 4, 5};
|
||||||
int64_t dims = 1;
|
int64_t dims = 1;
|
||||||
utils::MemRefDescriptor<float, 1> a{rawdata, rawdata, 0, {6}, {1}};
|
utils::MemRefDescriptor<float, 1> a{rawdata, rawdata, 0, {6}, {1}};
|
||||||
|
|
|
@ -1,10 +0,0 @@
|
||||||
#ifndef BUILDER_ARRAY_
|
|
||||||
#define BUILDER_ARRAY_
|
|
||||||
|
|
||||||
#include "iostream"
|
|
||||||
|
|
||||||
namespace builder {
|
|
||||||
class Array {}
|
|
||||||
} // namespace builder
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
#include "Attribute.h"
|
||||||
|
|
||||||
|
#include "AttributeImpl.h"
|
||||||
|
#include "Shape.h"
|
||||||
|
#include "Tensor.h"
|
||||||
|
#include "TensorImpl.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "memory.h"
|
||||||
|
// #include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
|
// #include "mlir/IR/Attributes.h"
|
||||||
|
// #include "mlir/IR/Operation.h"
|
||||||
|
// #include "mlir/IR/StandardTypes.h"
|
||||||
|
// #include "mlir/IR/Types.h"
|
||||||
|
// #include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
namespace builder {
|
||||||
|
|
||||||
|
Integer::Integer(int value) : impl_(std::make_shared<Integer::Impl>(value)) {}
|
||||||
|
Integer::Integer(int64_t value)
|
||||||
|
: impl_(std::make_shared<Integer::Impl>(value)) {}
|
||||||
|
|
||||||
|
} // namespace builder
|
|
@ -0,0 +1,50 @@
|
||||||
|
#ifndef BUILDER_ATTRIBUTE_H_
|
||||||
|
#define BUILDER_ATTRIBUTE_H_
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace builder {
|
||||||
|
|
||||||
|
class Integer {
|
||||||
|
public:
|
||||||
|
Integer(int value);
|
||||||
|
Integer(int64_t value);
|
||||||
|
class Impl;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Array {
|
||||||
|
public:
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Type {
|
||||||
|
public:
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// template <typename T>
|
||||||
|
// class DenseIntElementsAttr {
|
||||||
|
// public:
|
||||||
|
// explicit DenseIntElementsAttr(Tensor);
|
||||||
|
// class Impl;
|
||||||
|
// std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
// private:
|
||||||
|
// std::shared_ptr<Impl> impl_;
|
||||||
|
// };
|
||||||
|
|
||||||
|
} // namespace builder
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,58 @@
|
||||||
|
#ifndef BUILDER_TYPEIMPL_
|
||||||
|
#define BUILDER_TYPEIMPL_
|
||||||
|
|
||||||
|
#include "Builder.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
namespace builder {
|
||||||
|
|
||||||
|
class Integer::Impl {
|
||||||
|
public:
|
||||||
|
Impl(){};
|
||||||
|
|
||||||
|
Impl(int value) {
|
||||||
|
GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
|
||||||
|
auto int32_type = mlir::IntegerType::get(context, 32);
|
||||||
|
return mlir::IntegerAttr::get(int32_type, value);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
Impl(int64_t value)
|
||||||
|
: GetAttr([=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
|
||||||
|
auto int32_type = mlir::IntegerType::get(context, 64);
|
||||||
|
return mlir::IntegerAttr::get(int32_type, value);
|
||||||
|
}){};
|
||||||
|
std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr;
|
||||||
|
|
||||||
|
private:
|
||||||
|
};
|
||||||
|
|
||||||
|
class Array::Impl {
|
||||||
|
public:
|
||||||
|
Impl() = default;
|
||||||
|
|
||||||
|
private:
|
||||||
|
};
|
||||||
|
class Type::Impl {
|
||||||
|
public:
|
||||||
|
Impl() = default;
|
||||||
|
|
||||||
|
private:
|
||||||
|
};
|
||||||
|
|
||||||
|
// class DenseIntElementsAttr::Impl {
|
||||||
|
// public:
|
||||||
|
// Impl() = default;
|
||||||
|
|
||||||
|
// private:
|
||||||
|
// };
|
||||||
|
|
||||||
|
} // namespace builder
|
||||||
|
|
||||||
|
#endif
|
|
@ -11,7 +11,7 @@
|
||||||
|
|
||||||
namespace builder {
|
namespace builder {
|
||||||
|
|
||||||
Builder::Builder() : _impl(std::make_shared<Impl>()) {}
|
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
|
||||||
void Builder::DumpModule() {}
|
void Builder::DumpModule() {}
|
||||||
|
|
||||||
} // namespace builder
|
} // namespace builder
|
|
@ -11,10 +11,10 @@ class Builder {
|
||||||
|
|
||||||
Builder();
|
Builder();
|
||||||
void DumpModule();
|
void DumpModule();
|
||||||
std::shared_ptr<Impl> GetImpl() { return _impl; }
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Impl> _impl;
|
std::shared_ptr<Impl> impl_;
|
||||||
};
|
};
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
#ifndef BUILDER_OP_
|
||||||
|
#define BUILDER_OP_
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace builder {
|
||||||
|
class Op {
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
}
|
||||||
|
} // namespace builder
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,28 @@
|
||||||
|
#include "PrimitiveType.h"
|
||||||
|
|
||||||
|
#include "Shape.h"
|
||||||
|
#include "Tensor.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
// #include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
|
// #include "mlir/IR/Attributes.h"
|
||||||
|
// #include "mlir/IR/Operation.h"
|
||||||
|
// #include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
|
// #include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
namespace builder {
|
||||||
|
|
||||||
|
// constexpr mlir::Type GetMlirType(const PrimitiveType primitiveType) {
|
||||||
|
// switch (primitiveType)
|
||||||
|
// {
|
||||||
|
// case PrimitiveType::BF16 :
|
||||||
|
// return
|
||||||
|
// break;
|
||||||
|
|
||||||
|
// default:
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// }
|
||||||
|
|
||||||
|
} // namespace builder
|
|
@ -0,0 +1,29 @@
|
||||||
|
#ifndef BUILDER_PRIMITIVE_TYPE_
|
||||||
|
#define BUILDER_PRIMITIVE_TYPE_
|
||||||
|
|
||||||
|
namespace builder {
|
||||||
|
|
||||||
|
enum class PrimitiveType {
|
||||||
|
PRIMITIVE_TYPE_INVALID = 0,
|
||||||
|
PRED = 1,
|
||||||
|
S8 = 2,
|
||||||
|
S16 = 3,
|
||||||
|
S32 = 4,
|
||||||
|
S64 = 5,
|
||||||
|
U8 = 6,
|
||||||
|
U16 = 7,
|
||||||
|
U32 = 8,
|
||||||
|
U64 = 9,
|
||||||
|
F16 = 10,
|
||||||
|
F32 = 11,
|
||||||
|
BF16 = 16,
|
||||||
|
F64 = 12,
|
||||||
|
C64 = 15,
|
||||||
|
C128 = 18,
|
||||||
|
TUPLE = 13,
|
||||||
|
OPAQUE_TYPE = 14,
|
||||||
|
TOKEN = 17
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
#endif
|
|
@ -0,0 +1,17 @@
|
||||||
|
#ifndef BUILDER_SHAPE_
|
||||||
|
#define BUILDER_SHAPE_
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace builder {
|
||||||
|
class Shape {
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
} // namespace builder
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,24 @@
|
||||||
|
#ifndef BUILDER_SHAPEIMPL_
|
||||||
|
#define BUILDER_SHAPEIMPL_
|
||||||
|
|
||||||
|
#include "Builder.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
namespace builder {
|
||||||
|
|
||||||
|
class Shape::Impl {
|
||||||
|
public:
|
||||||
|
Impl() = default;
|
||||||
|
|
||||||
|
private:
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace builder
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,22 @@
|
||||||
|
#include "Tensor.h"
|
||||||
|
|
||||||
|
#include "Shape.h"
|
||||||
|
#include "TensorImpl.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
// #include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
|
// #include "mlir/IR/Attributes.h"
|
||||||
|
// #include "mlir/IR/Operation.h"
|
||||||
|
// #include "mlir/IR/StandardTypes.h"
|
||||||
|
// #include "mlir/IR/Types.h"
|
||||||
|
// #include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
namespace builder {
|
||||||
|
|
||||||
|
// Tensor::Tensor() : impl_(std::make_shared<Impl>()) {}
|
||||||
|
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
|
||||||
|
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
||||||
|
|
||||||
|
Shape Tensor::GetShape() { return impl_->GetShape(); }
|
||||||
|
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
|
||||||
|
|
||||||
|
} // namespace builder
|
|
@ -0,0 +1,25 @@
|
||||||
|
#ifndef BUILDER_TENSOR_
|
||||||
|
#define BUILDER_TENSOR_
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "PrimitiveType.h"
|
||||||
|
#include "Shape.h"
|
||||||
|
|
||||||
|
namespace builder {
|
||||||
|
|
||||||
|
class Tensor {
|
||||||
|
public:
|
||||||
|
Tensor(Shape& shape, PrimitiveType& primitiveType);
|
||||||
|
class Impl;
|
||||||
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
Shape GetShape();
|
||||||
|
PrimitiveType GetType();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
} // namespace builder
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,32 @@
|
||||||
|
#ifndef BUILDER_TENSORIMPL_
|
||||||
|
#define BUILDER_TENSORIMPL_
|
||||||
|
|
||||||
|
#include "Builder.h"
|
||||||
|
#include "PrimitiveType.h"
|
||||||
|
#include "Shape.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
namespace builder {
|
||||||
|
|
||||||
|
class Tensor::Impl {
|
||||||
|
public:
|
||||||
|
Impl() = default;
|
||||||
|
Impl(Shape& shape, PrimitiveType& primitiveType)
|
||||||
|
: shape_(shape), primitiveType_(primitiveType) {}
|
||||||
|
Shape GetShape() { return shape_; }
|
||||||
|
PrimitiveType GetType() { return primitiveType_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
Shape shape_;
|
||||||
|
PrimitiveType primitiveType_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace builder
|
||||||
|
|
||||||
|
#endif
|
|
@ -29,7 +29,7 @@
|
||||||
#include "llvm/TableGen/Record.h"
|
#include "llvm/TableGen/Record.h"
|
||||||
#include "llvm/TableGen/TableGenBackend.h"
|
#include "llvm/TableGen/TableGenBackend.h"
|
||||||
|
|
||||||
#include "iostream"
|
#include <iostream>
|
||||||
|
|
||||||
#define DEBUG_TYPE "mlir-tblgen-opdefgen"
|
#define DEBUG_TYPE "mlir-tblgen-opdefgen"
|
||||||
|
|
||||||
|
@ -42,28 +42,117 @@ static const char *const generatedArgName = "odsArg";
|
||||||
static const char *const odsBuilder = "odsBuilder";
|
static const char *const odsBuilder = "odsBuilder";
|
||||||
static const char *const builderOpState = "odsState";
|
static const char *const builderOpState = "odsState";
|
||||||
|
|
||||||
static const std::map<std::string, std::string> typeMapMLIR = {
|
namespace {
|
||||||
{"::mlir::StringAttr", "std::string"},
|
struct mlirTypeWrap {
|
||||||
{"::mlir::IntegerAttr", "int"},
|
std::string Name;
|
||||||
{"::mlir::DenseIntElementsAttr", "std::vector<int>"},
|
std::string (*ConvertToMlir)(std::string &, OpMethodBody &);
|
||||||
{"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
|
|
||||||
{"::mlir::FloatAttr", "float"},
|
|
||||||
{"::mlir::BoolAttr", "bool"},
|
|
||||||
{"::mlir::ElementsAttr", "::builder::Array"},
|
|
||||||
{"::mlir::DenseElementsAttr", "::builder::Tensor"},
|
|
||||||
// current only support string array.
|
|
||||||
{"::mlir::ArrayAttr", "std::vector<std::string>"},
|
|
||||||
{"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
|
|
||||||
{"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
|
|
||||||
{"::mlir::mhlo::GatherDimensionNumbers",
|
|
||||||
"::builder::GatherDimensionNumbers"},
|
|
||||||
{"::mlir::mhlo::ScatterDimensionNumbers",
|
|
||||||
"::builder::ScatterDimensionNumbers"},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
|
||||||
|
{"::mlir::StringAttr",
|
||||||
|
{"std::string",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::StringAttr " << var << "_mlir = mlir::StringAttr::get("
|
||||||
|
<< var << ", ctx);\n";
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::IntegerAttr",
|
||||||
|
{"builder::Integer",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
// body << " mlir::IntegerAttr " << var << "_mlir = mlir::IntegerAttr::get("
|
||||||
|
// << var << ", ctx);\n";
|
||||||
|
body << " mlir::IntegerAttr " << var << "_mlir = " << var
|
||||||
|
<< ".GetAttr(ctx);\n";
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::DenseIntElementsAttr",
|
||||||
|
{"std::vector<int>",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::DenseIntElementsAttr " << var
|
||||||
|
<< "_mlir = mlir::DenseIntElementsAttr::get("
|
||||||
|
<< "mlir::VectorType::get(" << var
|
||||||
|
<< ".size(), opBuilder->getIntegerType(32))," << var << ");\n";
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::mhlo::ChannelHandle",
|
||||||
|
{"ChannelHandle",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::FloatAttr",
|
||||||
|
{"float",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::FloatAttr " << var << "_mlir = mlir::FloatAttr::get("
|
||||||
|
<< var << ", ctx);\n";
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::BoolAttr",
|
||||||
|
{"bool",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
body << " mlir::BoolAttr " << var << "_mlir = mlir::BoolAttr::get("
|
||||||
|
<< var << ", ctx);\n";
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::ElementsAttr",
|
||||||
|
{"::builder::Array",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::DenseElementsAttr",
|
||||||
|
{"::builder::Tensor",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
// current only support string array.
|
||||||
|
{"::mlir::ArrayAttr",
|
||||||
|
{"std::vector<std::string>",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::mhlo::ConvDimensionNumbers",
|
||||||
|
{"::builder::ConvDimensionNumbers",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::mhlo::DotDimensionNumbers",
|
||||||
|
{"::builder::DotDimensionNumbers",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::mhlo::GatherDimensionNumbers",
|
||||||
|
{"::builder::GatherDimensionNumbers",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
{"::mlir::mhlo::ScatterDimensionNumbers",
|
||||||
|
{"::builder::ScatterDimensionNumbers",
|
||||||
|
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||||
|
return var + "_mlir";
|
||||||
|
}}},
|
||||||
|
};
|
||||||
|
|
||||||
|
// static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
|
||||||
|
// {"::mlir::StringAttr", "std::string"},
|
||||||
|
// {"::mlir::IntegerAttr", "int"},
|
||||||
|
// {"::mlir::DenseIntElementsAttr", "std::vector<int>"},
|
||||||
|
// {"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
|
||||||
|
// {"::mlir::FloatAttr", "float"},
|
||||||
|
// {"::mlir::BoolAttr", "bool"},
|
||||||
|
// {"::mlir::ElementsAttr", "::builder::Array"},
|
||||||
|
// {"::mlir::DenseElementsAttr", "::builder::Tensor"},
|
||||||
|
// // current only support string array.
|
||||||
|
// {"::mlir::ArrayAttr", "std::vector<std::string>"},
|
||||||
|
// {"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
|
||||||
|
// {"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
|
||||||
|
// {"::mlir::mhlo::GatherDimensionNumbers",
|
||||||
|
// "::builder::GatherDimensionNumbers"},
|
||||||
|
// {"::mlir::mhlo::ScatterDimensionNumbers",
|
||||||
|
// "::builder::ScatterDimensionNumbers"},
|
||||||
|
// };
|
||||||
|
|
||||||
StringRef typeConvertFromMLIR(StringRef type) {
|
StringRef typeConvertFromMLIR(StringRef type) {
|
||||||
auto re = typeMapMLIR.find(type.str());
|
auto re = typeMapMLIR.find(type.str());
|
||||||
if (re != typeMapMLIR.end()) return StringRef(re->second);
|
if (re != typeMapMLIR.end()) return StringRef(re->second.Name);
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,6 +164,7 @@ StringRef getReturnType(const Attribute &att) {
|
||||||
auto type = att.getStorageType();
|
auto type = att.getStorageType();
|
||||||
return typeConvertFromMLIR(type);
|
return typeConvertFromMLIR(type);
|
||||||
}
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// The logic to calculate the actual value range for a declared operand/result
|
// The logic to calculate the actual value range for a declared operand/result
|
||||||
// of an op with variadic operands/results. Note that this logic is not for
|
// of an op with variadic operands/results. Note that this logic is not for
|
||||||
|
@ -312,14 +402,6 @@ static std::string getArgumentName(const Operator &op, int index) {
|
||||||
return std::string(formatv("{0}_{1}", generatedArgName, index));
|
return std::string(formatv("{0}_{1}", generatedArgName, index));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if we can use unwrapped value for the given `attr` in builders.
|
|
||||||
static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
|
|
||||||
return getReturnType(attr) != getStorageType(attr) &&
|
|
||||||
// We need to wrap the raw value into an attribute in the builder impl
|
|
||||||
// so we need to make sure that the attribute specifies how to do that.
|
|
||||||
!attr.getConstBuilderTemplate().empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Op emitter
|
// Op emitter
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -427,7 +509,7 @@ private:
|
||||||
AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
|
AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
|
||||||
|
|
||||||
// Adds op arguments and regions into operation state for build() methods.
|
// Adds op arguments and regions into operation state for build() methods.
|
||||||
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,llvm::SmallVector<OpMethodParameter, 4> paramList,
|
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
||||||
bool isRawValueAttr = false);
|
bool isRawValueAttr = false);
|
||||||
|
|
||||||
// Generates canonicalizer declaration for the operation.
|
// Generates canonicalizer declaration for the operation.
|
||||||
|
@ -479,6 +561,8 @@ private:
|
||||||
// Generate the type inference interface methods.
|
// Generate the type inference interface methods.
|
||||||
void genTypeInterfaceMethods();
|
void genTypeInterfaceMethods();
|
||||||
|
|
||||||
|
Operator GetOp() { return op; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// The TableGen record for this op.
|
// The TableGen record for this op.
|
||||||
// TODO: OpEmitter should not have a Record directly,
|
// TODO: OpEmitter should not have a Record directly,
|
||||||
|
@ -1077,27 +1161,6 @@ void OpEmitter::genNamedSuccessorGetters() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool canGenerateUnwrappedBuilder(Operator &op) {
|
|
||||||
// If this op does not have native attributes at all, return directly to avoid
|
|
||||||
// redefining builders.
|
|
||||||
if (op.getNumNativeAttributes() == 0)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
bool canGenerate = false;
|
|
||||||
// We are generating builders that take raw values for attributes. We need to
|
|
||||||
// make sure the native attributes have a meaningful "unwrapped" value type
|
|
||||||
// different from the wrapped mlir::Attribute type to avoid redefining
|
|
||||||
// builders. This checks for the op has at least one such native attribute.
|
|
||||||
for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
|
|
||||||
NamedAttribute &namedAttr = op.getAttribute(i);
|
|
||||||
if (canUseUnwrappedRawValue(namedAttr.attr)) {
|
|
||||||
canGenerate = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return canGenerate;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool canInferType(Operator &op) {
|
static bool canInferType(Operator &op) {
|
||||||
return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
|
return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
|
||||||
op.getNumRegions() == 0;
|
op.getNumRegions() == 0;
|
||||||
|
@ -1106,18 +1169,14 @@ static bool canInferType(Operator &op) {
|
||||||
void OpEmitter::genSeparateArgParamBuilder() {
|
void OpEmitter::genSeparateArgParamBuilder() {
|
||||||
SmallVector<AttrParamKind, 2> attrBuilderType;
|
SmallVector<AttrParamKind, 2> attrBuilderType;
|
||||||
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
|
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
|
||||||
// if (canGenerateUnwrappedBuilder(op))
|
|
||||||
// attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
|
|
||||||
|
|
||||||
// Emit with separate builders with or without unwrapped attributes and/or
|
// Emit with separate builders with or without unwrapped attributes and/or
|
||||||
// inferring result type.
|
// inferring result type.
|
||||||
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
|
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
|
||||||
bool inferType) {
|
bool inferType) {
|
||||||
llvm::SmallVector<OpMethodParameter, 4> paramList;
|
llvm::SmallVector<OpMethodParameter, 4> paramList;
|
||||||
llvm::SmallVector<OpMethodParameter, 4> paramList2;
|
|
||||||
llvm::SmallVector<std::string, 4> resultNames;
|
llvm::SmallVector<std::string, 4> resultNames;
|
||||||
buildParamList(paramList, resultNames, paramKind, attrType);
|
buildParamList(paramList, resultNames, paramKind, attrType);
|
||||||
buildParamList(paramList2, resultNames, paramKind, attrType);
|
|
||||||
|
|
||||||
auto *m = opClass.addMethodAndPrune(
|
auto *m = opClass.addMethodAndPrune(
|
||||||
"::builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
|
"::builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
|
||||||
|
@ -1126,10 +1185,10 @@ void OpEmitter::genSeparateArgParamBuilder() {
|
||||||
return;
|
return;
|
||||||
auto &body = m->body();
|
auto &body = m->body();
|
||||||
genCodeForAddingArgAndRegionForBuilder(
|
genCodeForAddingArgAndRegionForBuilder(
|
||||||
body, paramList2, attrType == AttrParamKind::UnwrappedValue);
|
body, attrType == AttrParamKind::UnwrappedValue);
|
||||||
|
|
||||||
// Push all result types to the operation state
|
// Push all result types to the operation state
|
||||||
//"BBBBBBBBBBBB"
|
|
||||||
// if (inferType) {
|
// if (inferType) {
|
||||||
// // Generate builder that infers type too.
|
// // Generate builder that infers type too.
|
||||||
// // TODO: Subsume this with general checking if type can be
|
// // TODO: Subsume this with general checking if type can be
|
||||||
|
@ -1306,29 +1365,6 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
||||||
int numAttrs = 0;
|
int numAttrs = 0;
|
||||||
|
|
||||||
int defaultValuedAttrStartIndex = op.getNumArgs();
|
int defaultValuedAttrStartIndex = op.getNumArgs();
|
||||||
if (attrParamKind == AttrParamKind::UnwrappedValue) {
|
|
||||||
// Calculate the start index from which we can attach default values in the
|
|
||||||
// builder declaration.
|
|
||||||
for (int i = op.getNumArgs() - 1; i >= 0; --i) {
|
|
||||||
auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
|
|
||||||
if (!namedAttr || !namedAttr->attr.hasDefaultValue())
|
|
||||||
break;
|
|
||||||
|
|
||||||
if (!canUseUnwrappedRawValue(namedAttr->attr))
|
|
||||||
break;
|
|
||||||
|
|
||||||
// Creating an APInt requires us to provide bitwidth, value, and
|
|
||||||
// signedness, which is complicated compared to others. Similarly
|
|
||||||
// for APFloat.
|
|
||||||
// TODO: Adjust the 'returnType' field of such attributes
|
|
||||||
// to support them.
|
|
||||||
StringRef retType = getReturnType(namedAttr->attr);
|
|
||||||
if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
|
|
||||||
break;
|
|
||||||
|
|
||||||
defaultValuedAttrStartIndex = i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||||
auto argument = op.getArg(i);
|
auto argument = op.getArg(i);
|
||||||
|
@ -1357,10 +1393,10 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
||||||
type = getStorageType(attr);
|
type = getStorageType(attr);
|
||||||
break;
|
break;
|
||||||
case AttrParamKind::UnwrappedValue:
|
case AttrParamKind::UnwrappedValue:
|
||||||
if (canUseUnwrappedRawValue(attr))
|
// if (canUseUnwrappedRawValue(attr))
|
||||||
type = getReturnType(attr);
|
// type = getReturnType(attr);
|
||||||
else
|
// else
|
||||||
type = getStorageType(attr);
|
// type = getStorageType(attr);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1394,41 +1430,72 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
||||||
llvm::formatv("{0}Count", region.name).str());
|
llvm::formatv("{0}Count", region.name).str());
|
||||||
}
|
}
|
||||||
|
|
||||||
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
|
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
||||||
OpMethodBody &body, llvm::SmallVector<OpMethodParameter, 4> paramList,
|
|
||||||
bool isRawValueAttr) {
|
bool isRawValueAttr) {
|
||||||
|
auto op = GetOp();
|
||||||
|
auto operands = op.getOperands();
|
||||||
|
auto attrs = op.getAttributes();
|
||||||
|
SmallVector<std::string, 4> newAttrs;
|
||||||
|
|
||||||
body << "// AAAAAA \n";
|
// for(auto o : operands){
|
||||||
// for(auto p : paramList)
|
// body << " operands name: "<<o.name<<" getCPPClassName:"<<o.constraint.getCPPClassName()<<"\n";
|
||||||
// {
|
// }
|
||||||
// body <<"==== type:"<< p.getType() << " name:"<< p.getName() << "\n";
|
// for(auto a : attrs){
|
||||||
|
// body << " attrs name: "<<a.name<<" type: "<< a.attr.getStorageType().str()<<"\n";
|
||||||
|
// std::string attrType = a.attr.getStorageType().str();
|
||||||
|
// auto ff = typeMapMLIR.find(attrType);
|
||||||
|
// if(ff != typeMapMLIR.end()){
|
||||||
|
// body << "// BBBBBB \n";
|
||||||
|
// std::string attrName = a.name.str();
|
||||||
|
// ff->second.ConvertToMlir(attrName,body);
|
||||||
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
body << " auto builder = builder.GetImpl();\n";
|
|
||||||
body << " auto loc = builder->GetLoc();\n";
|
// body << "// AAAAAA \n";
|
||||||
body << " auto opBuilder = builder->GetBuilder();\n";
|
// for(auto p : paramList){
|
||||||
|
// body << " AAA type"<<p.getType()<<" name"<<p.getName()<<"\n";
|
||||||
|
// }
|
||||||
|
|
||||||
|
body << " auto b = builder.GetImpl();\n";
|
||||||
|
body << " auto loc = b->GetLoc();\n";
|
||||||
|
body << " auto opBuilder = b->GetBuilder();\n";
|
||||||
|
body << " auto ctx = b->GetContext();\n";
|
||||||
|
for (auto a : attrs) {
|
||||||
|
std::string attrType = a.attr.getStorageType().str();
|
||||||
|
|
||||||
|
if(attrType == "::mlir::DenseIntElementsAttr")
|
||||||
|
{
|
||||||
|
body << " // BBBBBBBB getStorageType:"<< a.attr.getStorageType().str() <<"\n";
|
||||||
|
body << " // BBBBBBBB getReturnType:"<< a.attr.getReturnType().str() <<"\n";
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
auto typePair = typeMapMLIR.find(attrType);
|
||||||
|
if (typePair != typeMapMLIR.end()) {
|
||||||
|
std::string attrName = a.name.str();
|
||||||
|
auto mlirName = typePair->second.ConvertToMlir(attrName, body);
|
||||||
|
newAttrs.emplace_back(mlirName);
|
||||||
|
}
|
||||||
|
}
|
||||||
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
|
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
|
||||||
<< " currentOp =\n";
|
<< " currentOp =\n";
|
||||||
body << " opBuilder.create<mlir::" << op.getDialectName()
|
body << " opBuilder.create<mlir::" << op.getDialectName()
|
||||||
<< "::" << op.getDialectName() << ">(\n";
|
<< "::" << op.getDialectName() << ">(\n";
|
||||||
if (paramList.size() > 1) {
|
body << " loc";
|
||||||
body << " loc,\n";
|
std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) {
|
||||||
std::for_each(paramList.begin() + 1, paramList.end() - 1,
|
body << ",\n " << v.name << ".getResult()";
|
||||||
[&](OpMethodParameter &p) {
|
|
||||||
body << " " << p.getName() << ",\n";
|
|
||||||
});
|
});
|
||||||
body << " " << paramList.back().getName() << "\n";
|
std::for_each(newAttrs.begin(), newAttrs.end(),
|
||||||
} else {
|
[&](std::string &n) { body << ",\n " << n; });
|
||||||
body << " loc\n";
|
body << "\n );\n";
|
||||||
}
|
|
||||||
body << " );\n";
|
|
||||||
body << " builder::" << op.getCppClassName() << " builderOp;\n";
|
body << " builder::" << op.getCppClassName() << " builderOp;\n";
|
||||||
body << " auto opImpl = builderOp.GetImpl();\n";
|
body << " auto opImpl = builderOp.GetImpl();\n";
|
||||||
body << " opImpl.SetOperation(currentOp.getOperation());\n";
|
body << " opImpl.SetOperation(currentOp.getOperation());\n";
|
||||||
body << " return builderOp;\n";
|
body << " return builderOp;\n";
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// // Push all operands to the result.
|
// // Push all operands to the result.
|
||||||
// for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
// for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
||||||
// std::string argName = getArgumentName(op, i);
|
// std::string argName = getArgumentName(op, i);
|
||||||
|
@ -2079,184 +2146,6 @@ void OpEmitter::genOpAsmInterface() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// OpOperandAdaptor emitter
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
// Helper class to emit Op operand adaptors to an output stream. Operand
|
|
||||||
// adaptors are wrappers around ArrayRef<Value> that provide named operand
|
|
||||||
// getters identical to those defined in the Op.
|
|
||||||
class OpOperandAdaptorEmitter {
|
|
||||||
public:
|
|
||||||
static void emitDecl(const Operator &op, raw_ostream &os);
|
|
||||||
static void emitDef(const Operator &op, raw_ostream &os);
|
|
||||||
|
|
||||||
private:
|
|
||||||
explicit OpOperandAdaptorEmitter(const Operator &op);
|
|
||||||
|
|
||||||
// Add verification function. This generates a verify method for the adaptor
|
|
||||||
// which verifies all the op-independent attribute constraints.
|
|
||||||
void addVerification();
|
|
||||||
|
|
||||||
const Operator &op;
|
|
||||||
Class adaptor;
|
|
||||||
};
|
|
||||||
} // end namespace
|
|
||||||
|
|
||||||
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
|
|
||||||
: op(op), adaptor(op.getAdaptorName()) {
|
|
||||||
adaptor.newField("::mlir::ValueRange", "odsOperands");
|
|
||||||
adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
|
|
||||||
adaptor.newField("::mlir::RegionRange", "odsRegions");
|
|
||||||
const auto *attrSizedOperands =
|
|
||||||
op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
|
|
||||||
{
|
|
||||||
SmallVector<OpMethodParameter, 2> paramList;
|
|
||||||
paramList.emplace_back("::mlir::ValueRange", "values");
|
|
||||||
paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
|
|
||||||
attrSizedOperands ? "" : "nullptr");
|
|
||||||
paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
|
|
||||||
auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList));
|
|
||||||
|
|
||||||
constructor->addMemberInitializer("odsOperands", "values");
|
|
||||||
constructor->addMemberInitializer("odsAttrs", "attrs");
|
|
||||||
constructor->addMemberInitializer("odsRegions", "regions");
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
auto *constructor = adaptor.addConstructorAndPrune(
|
|
||||||
llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
|
|
||||||
constructor->addMemberInitializer("odsOperands", "op->getOperands()");
|
|
||||||
constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
|
|
||||||
constructor->addMemberInitializer("odsRegions", "op->getRegions()");
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands");
|
|
||||||
m->body() << " return odsOperands;";
|
|
||||||
}
|
|
||||||
std::string sizeAttrInit =
|
|
||||||
formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
|
|
||||||
generateNamedOperandGetters(op, adaptor, sizeAttrInit,
|
|
||||||
/*rangeType=*/"::mlir::ValueRange",
|
|
||||||
/*rangeBeginCall=*/"odsOperands.begin()",
|
|
||||||
/*rangeSizeCall=*/"odsOperands.size()",
|
|
||||||
/*getOperandCallPattern=*/"odsOperands[{0}]");
|
|
||||||
|
|
||||||
FmtContext fctx;
|
|
||||||
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
|
|
||||||
|
|
||||||
auto emitAttr = [&](StringRef name, Attribute attr) {
|
|
||||||
auto &body = adaptor.addMethodAndPrune(getStorageType(attr), name)->body();
|
|
||||||
body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
|
|
||||||
<< "\n " << getStorageType(attr) << " attr = "
|
|
||||||
<< "odsAttrs.get(\"" << name << "\").";
|
|
||||||
if (attr.hasDefaultValue() || attr.isOptional())
|
|
||||||
body << "dyn_cast_or_null<";
|
|
||||||
else
|
|
||||||
body << "cast<";
|
|
||||||
body << getStorageType(attr) << ">();\n";
|
|
||||||
|
|
||||||
if (attr.hasDefaultValue()) {
|
|
||||||
// Use the default value if attribute is not set.
|
|
||||||
// TODO: this is inefficient, we are recreating the attribute for every
|
|
||||||
// call. This should be set instead.
|
|
||||||
std::string defaultValue = std::string(
|
|
||||||
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
|
|
||||||
body << " if (!attr)\n attr = " << defaultValue << ";\n";
|
|
||||||
}
|
|
||||||
body << " return attr;\n";
|
|
||||||
};
|
|
||||||
|
|
||||||
{
|
|
||||||
auto *m =
|
|
||||||
adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes");
|
|
||||||
m->body() << " return odsAttrs;";
|
|
||||||
}
|
|
||||||
for (auto &namedAttr : op.getAttributes()) {
|
|
||||||
const auto &name = namedAttr.name;
|
|
||||||
const auto &attr = namedAttr.attr;
|
|
||||||
if (!attr.isDerivedAttr())
|
|
||||||
emitAttr(name, attr);
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned numRegions = op.getNumRegions();
|
|
||||||
if (numRegions > 0) {
|
|
||||||
auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions");
|
|
||||||
m->body() << " return odsRegions;";
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < numRegions; ++i) {
|
|
||||||
const auto ®ion = op.getRegion(i);
|
|
||||||
if (region.name.empty())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// Generate the accessors for a variadic region.
|
|
||||||
if (region.isVariadic()) {
|
|
||||||
auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", region.name);
|
|
||||||
m->body() << formatv(" return odsRegions.drop_front({0});", i);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto *m = adaptor.addMethodAndPrune("::mlir::Region &", region.name);
|
|
||||||
m->body() << formatv(" return *odsRegions[{0}];", i);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add verification function.
|
|
||||||
addVerification();
|
|
||||||
}
|
|
||||||
|
|
||||||
void OpOperandAdaptorEmitter::addVerification() {
|
|
||||||
auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
|
|
||||||
"::mlir::Location", "loc");
|
|
||||||
auto &body = method->body();
|
|
||||||
|
|
||||||
const char *checkAttrSizedValueSegmentsCode = R"(
|
|
||||||
{
|
|
||||||
auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
|
|
||||||
auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
|
|
||||||
if (numElements != {1})
|
|
||||||
return emitError(loc, "'{0}' attribute for specifying {2} segments "
|
|
||||||
"must have {1} elements, but got ") << numElements;
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
|
|
||||||
// Verify a few traits first so that we can use
|
|
||||||
// getODSOperands()/getODSResults() in the rest of the verifier.
|
|
||||||
for (auto &trait : op.getTraits()) {
|
|
||||||
if (auto *t = dyn_cast<tblgen::NativeTrait>(&trait)) {
|
|
||||||
if (t->getFullyQualifiedTraitName() ==
|
|
||||||
"::mlir::OpTrait::AttrSizedOperandSegments") {
|
|
||||||
body << formatv(checkAttrSizedValueSegmentsCode,
|
|
||||||
"operand_segment_sizes", op.getNumOperands(),
|
|
||||||
"operand");
|
|
||||||
} else if (t->getFullyQualifiedTraitName() ==
|
|
||||||
"::mlir::OpTrait::AttrSizedResultSegments") {
|
|
||||||
body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
|
|
||||||
op.getNumResults(), "result");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
FmtContext verifyCtx;
|
|
||||||
populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
|
|
||||||
"<no results should be generated>", verifyCtx);
|
|
||||||
genAttributeVerifier(op, "odsAttrs.get",
|
|
||||||
Twine("emitError(loc, \"'") + op.getOperationName() +
|
|
||||||
"' op \"",
|
|
||||||
/*emitVerificationRequiringOp*/ false, verifyCtx, body);
|
|
||||||
|
|
||||||
body << " return ::mlir::success();";
|
|
||||||
}
|
|
||||||
|
|
||||||
void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
|
|
||||||
OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
|
|
||||||
}
|
|
||||||
|
|
||||||
void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
|
|
||||||
OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emits the opcode enum and op classes.
|
// Emits the opcode enum and op classes.
|
||||||
static void emitOpClasses(const RecordKeeper &recordKeeper,
|
static void emitOpClasses(const RecordKeeper &recordKeeper,
|
||||||
const std::vector<Record *> &defs, raw_ostream &os,
|
const std::vector<Record *> &defs, raw_ostream &os,
|
||||||
|
@ -2286,11 +2175,9 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
|
||||||
NamespaceEmitter emitter(os, op.getCppNamespace());
|
NamespaceEmitter emitter(os, op.getCppNamespace());
|
||||||
if (emitDecl) {
|
if (emitDecl) {
|
||||||
// os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
|
// os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
|
||||||
// OpOperandAdaptorEmitter::emitDecl(op, os);
|
|
||||||
OpEmitter::emitDecl(op, os, staticVerifierEmitter);
|
OpEmitter::emitDecl(op, os, staticVerifierEmitter);
|
||||||
} else {
|
} else {
|
||||||
// os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
|
// os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
|
||||||
// OpOperandAdaptorEmitter::emitDef(op, os);
|
|
||||||
OpEmitter::emitDef(op, os, staticVerifierEmitter);
|
OpEmitter::emitDef(op, os, staticVerifierEmitter);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +0,0 @@
|
||||||
#ifndef BUILDER_OP_
|
|
||||||
#define BUILDER_OP_
|
|
||||||
|
|
||||||
#include "iostream"
|
|
||||||
|
|
||||||
namespace builder {
|
|
||||||
class Op {
|
|
||||||
class Impl;
|
|
||||||
std::shared_ptr<Impl> GetImpl() { return _impl; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::shared_ptr<Impl> _impl;
|
|
||||||
}
|
|
||||||
} // namespace builder
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -1,16 +0,0 @@
|
||||||
#ifndef BUILDER_TENSOR_
|
|
||||||
#define BUILDER_TENSOR_
|
|
||||||
|
|
||||||
|
|
||||||
#include "iostream"
|
|
||||||
|
|
||||||
namespace builder{
|
|
||||||
class Tensor{
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -1,10 +0,0 @@
|
||||||
#ifndef BUILDER_TYPE_
|
|
||||||
#define BUILDER_TYPE_
|
|
||||||
|
|
||||||
#include "iostream"
|
|
||||||
|
|
||||||
namespace builder {
|
|
||||||
class Type {}
|
|
||||||
} // namespace builder
|
|
||||||
|
|
||||||
#endif
|
|
Loading…
Reference in New Issue