add integer attribute convert
This commit is contained in:
parent
9d5166684c
commit
95fc37ffa8
2
BUILD
2
BUILD
|
@ -129,6 +129,8 @@ cc_binary(
|
|||
"tools/mlir-tblgen-builder/*.cpp",
|
||||
"tools/mlir-tblgen-builder/TableGen/*.h",
|
||||
"tools/mlir-tblgen-builder/TableGen/*.cpp",
|
||||
"tools/mlir-tblgen-builder/Builder/*.h",
|
||||
"tools/mlir-tblgen-builder/Builder/*.cpp",
|
||||
]),
|
||||
deps = [
|
||||
"@llvm-project//mlir:MlirTableGenMain",
|
||||
|
|
|
@ -14,8 +14,25 @@ limitations under the License.
|
|||
==============================================================================*/
|
||||
|
||||
#include "iostream"
|
||||
|
||||
#include "mlir/Support/MlirOptMain.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/DynamicLibrary.h"
|
||||
#include "llvm/Support/FileUtilities.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/StringSaver.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/JitRunner.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
@ -23,62 +40,31 @@ limitations under the License.
|
|||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/DebugCounter.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
#include "mlir/Support/MlirOptMain.h"
|
||||
#include "mlir/Support/Timing.h"
|
||||
#include "mlir/Support/ToolUtilities.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FileUtilities.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/StringSaver.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "llvm/Support/DynamicLibrary.h"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/ExecutionEngine/JitRunner.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/All.h"
|
||||
|
||||
|
||||
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
|
||||
|
||||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
|
||||
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
// #include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
|
||||
#include "llvm/ExecutionEngine/Orc/Mangling.h"
|
||||
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/LegacyPassNameParser.h"
|
||||
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace llvm;
|
||||
|
||||
namespace utils{
|
||||
namespace utils {
|
||||
template <typename T, int N>
|
||||
struct MemRefDescriptor {
|
||||
T *allocated;
|
||||
|
@ -87,11 +73,9 @@ struct MemRefDescriptor {
|
|||
int64_t sizes[N];
|
||||
int64_t strides[N];
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
|
@ -105,7 +89,6 @@ int main(int argc, char **argv) {
|
|||
// registerDefaultTimingManagerCLOptions();
|
||||
DebugCounter::registerCLOptions();
|
||||
|
||||
|
||||
mlir::registerAllPasses();
|
||||
mlir::mhlo::registerAllMhloPasses();
|
||||
mlir::lmhlo::registerAllLmhloPasses();
|
||||
|
@ -120,24 +103,22 @@ int main(int argc, char **argv) {
|
|||
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
|
||||
registry.insert<mlir::disc_ral::RalDialect>();
|
||||
|
||||
|
||||
// failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
|
||||
// registry,
|
||||
// /*preloadDialectsInContext=*/false));
|
||||
// return 0;
|
||||
|
||||
|
||||
|
||||
std::string errorMessage;
|
||||
|
||||
// auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir", &errorMessage);
|
||||
auto file = mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage);
|
||||
std::cout<<"errorMessage:" <<errorMessage <<std::endl;
|
||||
// auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir",
|
||||
// &errorMessage);
|
||||
auto file =
|
||||
mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage);
|
||||
std::cout << "errorMessage:" << errorMessage << std::endl;
|
||||
|
||||
SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
|
||||
|
||||
|
||||
SmallVector<const llvm::PassInfo *, 4> passes;
|
||||
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
|
||||
auto tmOrError = tmBuilderOrError->createTargetMachine();
|
||||
|
@ -145,8 +126,6 @@ int main(int argc, char **argv) {
|
|||
auto transformer = mlir::makeLLVMPassesTransformer(
|
||||
passes, 0, /*targetMachine=*/tmOrError->get(), 0);
|
||||
|
||||
|
||||
|
||||
MLIRContext context(registry);
|
||||
context.loadAllAvailableDialects();
|
||||
OwningModuleRef module(parseSourceFile(sourceMgr, &context));
|
||||
|
@ -157,8 +136,7 @@ int main(int argc, char **argv) {
|
|||
return failure();
|
||||
};
|
||||
|
||||
|
||||
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
|
||||
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
|
||||
// RUN: -hlo-legalize-to-lhlo -buffer-hoisting \
|
||||
// RUN: -buffer-deallocation -canonicalize -cse \
|
||||
// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
|
||||
|
@ -192,9 +170,7 @@ int main(int argc, char **argv) {
|
|||
pm.run(*module);
|
||||
module->dump();
|
||||
|
||||
|
||||
|
||||
std::cout<<"DEBUG load module success"<<std::endl;
|
||||
std::cout << "DEBUG load module success" << std::endl;
|
||||
|
||||
llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default;
|
||||
|
||||
|
@ -206,12 +182,10 @@ int main(int argc, char **argv) {
|
|||
// Use absolute library path so that gdb can find the symbol table.
|
||||
|
||||
std::list<std::string> sharedlib;
|
||||
sharedlib.push_back("/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git");
|
||||
sharedlib.push_back(
|
||||
"/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git");
|
||||
|
||||
transform(
|
||||
sharedlib,
|
||||
std::back_inserter(libPaths),
|
||||
[](std::string libPath) {
|
||||
transform(sharedlib, std::back_inserter(libPaths), [](std::string libPath) {
|
||||
SmallString<256> absPath(libPath.begin(), libPath.end());
|
||||
cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
|
||||
return absPath;
|
||||
|
@ -226,7 +200,6 @@ int main(int argc, char **argv) {
|
|||
llvm::StringMap<void *> exportSymbols;
|
||||
SmallVector<MlirRunnerDestroyFn> destroyFns;
|
||||
|
||||
|
||||
// Handle libraries that do support mlir-runner init/destroy callbacks.
|
||||
for (auto &libPath : libPaths) {
|
||||
auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
|
||||
|
@ -255,17 +228,12 @@ int main(int argc, char **argv) {
|
|||
return symbolMap;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
auto expectedEngine = mlir::ExecutionEngine::create(
|
||||
module.get(), nullptr, transformer, jitCodeGenOptLevel,
|
||||
executionEngineLibs);
|
||||
auto expectedEngine =
|
||||
mlir::ExecutionEngine::create(module.get(), nullptr, transformer,
|
||||
jitCodeGenOptLevel, executionEngineLibs);
|
||||
// if (!expectedEngine)
|
||||
// return expectedEngine.takeError();
|
||||
|
||||
|
||||
|
||||
auto engine = std::move(*expectedEngine);
|
||||
engine->registerSymbols(runtimeSymbolMap);
|
||||
|
||||
|
@ -277,16 +245,15 @@ int main(int argc, char **argv) {
|
|||
// if (options.dumpObjectFile)
|
||||
// engine->dumpToObjectFile("a.o");
|
||||
|
||||
|
||||
float rawdata[6] = {0,1,2,3,4,5};
|
||||
float rawdata[6] = {0, 1, 2, 3, 4, 5};
|
||||
int64_t dims = 1;
|
||||
utils::MemRefDescriptor<float,1> a{rawdata,rawdata,0,{6},{1}};
|
||||
utils::MemRefDescriptor<float,1> b{rawdata,rawdata,0,{6},{1}};
|
||||
utils::MemRefDescriptor<float,1> result_memref;
|
||||
utils::MemRefDescriptor<float, 1> a{rawdata, rawdata, 0, {6}, {1}};
|
||||
utils::MemRefDescriptor<float, 1> b{rawdata, rawdata, 0, {6}, {1}};
|
||||
utils::MemRefDescriptor<float, 1> result_memref;
|
||||
|
||||
struct memref_type{
|
||||
struct memref_type {
|
||||
int64_t res_size = 6;
|
||||
utils::MemRefDescriptor<float,1> *memref;
|
||||
utils::MemRefDescriptor<float, 1> *memref;
|
||||
} result;
|
||||
result.memref = &result_memref;
|
||||
|
||||
|
@ -299,23 +266,23 @@ int main(int argc, char **argv) {
|
|||
} data;
|
||||
|
||||
data.data1_size = &dims;
|
||||
void * a_ptr = &a;
|
||||
void *a_ptr = &a;
|
||||
data.data1 = &a_ptr;
|
||||
data.data2_size = &dims;
|
||||
void * b_ptr = &b;
|
||||
void *b_ptr = &b;
|
||||
data.data2 = &b_ptr;
|
||||
void * result_ptr = &result;
|
||||
void *result_ptr = &result;
|
||||
data.res = &result;
|
||||
|
||||
void (*fptr)(void **) = *expectedFPtr;
|
||||
(*fptr)((void **)&data);
|
||||
|
||||
std::cout<<"result: "<<result.memref->allocated[0]<<std::endl;
|
||||
std::cout<<"result: "<<result.memref->allocated[1]<<std::endl;
|
||||
std::cout<<"result: "<<result.memref->allocated[2]<<std::endl;
|
||||
std::cout<<"result: "<<result.memref->allocated[3]<<std::endl;
|
||||
std::cout<<"result: "<<result.memref->allocated[4]<<std::endl;
|
||||
std::cout<<"result: "<<result.memref->allocated[5]<<std::endl;
|
||||
std::cout << "result: " << result.memref->allocated[0] << std::endl;
|
||||
std::cout << "result: " << result.memref->allocated[1] << std::endl;
|
||||
std::cout << "result: " << result.memref->allocated[2] << std::endl;
|
||||
std::cout << "result: " << result.memref->allocated[3] << std::endl;
|
||||
std::cout << "result: " << result.memref->allocated[4] << std::endl;
|
||||
std::cout << "result: " << result.memref->allocated[5] << std::endl;
|
||||
|
||||
// Run all dynamic library destroy callbacks to prepare for the shutdown.
|
||||
llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
#ifndef BUILDER_ARRAY_
|
||||
#define BUILDER_ARRAY_
|
||||
|
||||
#include "iostream"
|
||||
|
||||
namespace builder {
|
||||
class Array {}
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -0,0 +1,22 @@
|
|||
#include "Attribute.h"
|
||||
|
||||
#include "AttributeImpl.h"
|
||||
#include "Shape.h"
|
||||
#include "Tensor.h"
|
||||
#include "TensorImpl.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "memory.h"
|
||||
// #include "mlir/Dialect/StandardOps/Ops.h"
|
||||
// #include "mlir/IR/Attributes.h"
|
||||
// #include "mlir/IR/Operation.h"
|
||||
// #include "mlir/IR/StandardTypes.h"
|
||||
// #include "mlir/IR/Types.h"
|
||||
// #include "mlir/IR/Value.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
Integer::Integer(int value) : impl_(std::make_shared<Integer::Impl>(value)) {}
|
||||
Integer::Integer(int64_t value)
|
||||
: impl_(std::make_shared<Integer::Impl>(value)) {}
|
||||
|
||||
} // namespace builder
|
|
@ -0,0 +1,50 @@
|
|||
#ifndef BUILDER_ATTRIBUTE_H_
|
||||
#define BUILDER_ATTRIBUTE_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
namespace builder {
|
||||
|
||||
class Integer {
|
||||
public:
|
||||
Integer(int value);
|
||||
Integer(int64_t value);
|
||||
class Impl;
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
class Array {
|
||||
public:
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
class Type {
|
||||
public:
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
// template <typename T>
|
||||
// class DenseIntElementsAttr {
|
||||
// public:
|
||||
// explicit DenseIntElementsAttr(Tensor);
|
||||
// class Impl;
|
||||
// std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
// private:
|
||||
// std::shared_ptr<Impl> impl_;
|
||||
// };
|
||||
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -0,0 +1,58 @@
|
|||
#ifndef BUILDER_TYPEIMPL_
|
||||
#define BUILDER_TYPEIMPL_
|
||||
|
||||
#include "Builder.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
class Integer::Impl {
|
||||
public:
|
||||
Impl(){};
|
||||
|
||||
Impl(int value) {
|
||||
GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
|
||||
auto int32_type = mlir::IntegerType::get(context, 32);
|
||||
return mlir::IntegerAttr::get(int32_type, value);
|
||||
};
|
||||
};
|
||||
Impl(int64_t value)
|
||||
: GetAttr([=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
|
||||
auto int32_type = mlir::IntegerType::get(context, 64);
|
||||
return mlir::IntegerAttr::get(int32_type, value);
|
||||
}){};
|
||||
std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
class Array::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
|
||||
private:
|
||||
};
|
||||
class Type::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
// class DenseIntElementsAttr::Impl {
|
||||
// public:
|
||||
// Impl() = default;
|
||||
|
||||
// private:
|
||||
// };
|
||||
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
namespace builder {
|
||||
|
||||
Builder::Builder() : _impl(std::make_shared<Impl>()) {}
|
||||
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
|
||||
void Builder::DumpModule() {}
|
||||
|
||||
} // namespace builder
|
|
@ -11,10 +11,10 @@ class Builder {
|
|||
|
||||
Builder();
|
||||
void DumpModule();
|
||||
std::shared_ptr<Impl> GetImpl() { return _impl; }
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> _impl;
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
} // namespace builder
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
#ifndef BUILDER_OP_
|
||||
#define BUILDER_OP_
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
namespace builder {
|
||||
class Op {
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
}
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -0,0 +1,28 @@
|
|||
#include "PrimitiveType.h"
|
||||
|
||||
#include "Shape.h"
|
||||
#include "Tensor.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
// #include "mlir/Dialect/StandardOps/Ops.h"
|
||||
// #include "mlir/IR/Attributes.h"
|
||||
// #include "mlir/IR/Operation.h"
|
||||
// #include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
// #include "mlir/IR/Value.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
// constexpr mlir::Type GetMlirType(const PrimitiveType primitiveType) {
|
||||
// switch (primitiveType)
|
||||
// {
|
||||
// case PrimitiveType::BF16 :
|
||||
// return
|
||||
// break;
|
||||
|
||||
// default:
|
||||
// break;
|
||||
// }
|
||||
|
||||
// }
|
||||
|
||||
} // namespace builder
|
|
@ -0,0 +1,29 @@
|
|||
#ifndef BUILDER_PRIMITIVE_TYPE_
|
||||
#define BUILDER_PRIMITIVE_TYPE_
|
||||
|
||||
namespace builder {
|
||||
|
||||
enum class PrimitiveType {
|
||||
PRIMITIVE_TYPE_INVALID = 0,
|
||||
PRED = 1,
|
||||
S8 = 2,
|
||||
S16 = 3,
|
||||
S32 = 4,
|
||||
S64 = 5,
|
||||
U8 = 6,
|
||||
U16 = 7,
|
||||
U32 = 8,
|
||||
U64 = 9,
|
||||
F16 = 10,
|
||||
F32 = 11,
|
||||
BF16 = 16,
|
||||
F64 = 12,
|
||||
C64 = 15,
|
||||
C128 = 18,
|
||||
TUPLE = 13,
|
||||
OPAQUE_TYPE = 14,
|
||||
TOKEN = 17
|
||||
};
|
||||
|
||||
}
|
||||
#endif
|
|
@ -0,0 +1,17 @@
|
|||
#ifndef BUILDER_SHAPE_
|
||||
#define BUILDER_SHAPE_
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
namespace builder {
|
||||
class Shape {
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -0,0 +1,24 @@
|
|||
#ifndef BUILDER_SHAPEIMPL_
|
||||
#define BUILDER_SHAPEIMPL_
|
||||
|
||||
#include "Builder.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
class Shape::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -0,0 +1,22 @@
|
|||
#include "Tensor.h"
|
||||
|
||||
#include "Shape.h"
|
||||
#include "TensorImpl.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
// #include "mlir/Dialect/StandardOps/Ops.h"
|
||||
// #include "mlir/IR/Attributes.h"
|
||||
// #include "mlir/IR/Operation.h"
|
||||
// #include "mlir/IR/StandardTypes.h"
|
||||
// #include "mlir/IR/Types.h"
|
||||
// #include "mlir/IR/Value.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
// Tensor::Tensor() : impl_(std::make_shared<Impl>()) {}
|
||||
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
|
||||
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
|
||||
|
||||
Shape Tensor::GetShape() { return impl_->GetShape(); }
|
||||
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
|
||||
|
||||
} // namespace builder
|
|
@ -0,0 +1,25 @@
|
|||
#ifndef BUILDER_TENSOR_
|
||||
#define BUILDER_TENSOR_
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
#include "PrimitiveType.h"
|
||||
#include "Shape.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
class Tensor {
|
||||
public:
|
||||
Tensor(Shape& shape, PrimitiveType& primitiveType);
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
Shape GetShape();
|
||||
PrimitiveType GetType();
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -0,0 +1,32 @@
|
|||
#ifndef BUILDER_TENSORIMPL_
|
||||
#define BUILDER_TENSORIMPL_
|
||||
|
||||
#include "Builder.h"
|
||||
#include "PrimitiveType.h"
|
||||
#include "Shape.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
class Tensor::Impl {
|
||||
public:
|
||||
Impl() = default;
|
||||
Impl(Shape& shape, PrimitiveType& primitiveType)
|
||||
: shape_(shape), primitiveType_(primitiveType) {}
|
||||
Shape GetShape() { return shape_; }
|
||||
PrimitiveType GetType() { return primitiveType_; }
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
PrimitiveType primitiveType_;
|
||||
};
|
||||
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -29,7 +29,7 @@
|
|||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
|
||||
#include "iostream"
|
||||
#include <iostream>
|
||||
|
||||
#define DEBUG_TYPE "mlir-tblgen-opdefgen"
|
||||
|
||||
|
@ -42,28 +42,117 @@ static const char *const generatedArgName = "odsArg";
|
|||
static const char *const odsBuilder = "odsBuilder";
|
||||
static const char *const builderOpState = "odsState";
|
||||
|
||||
static const std::map<std::string, std::string> typeMapMLIR = {
|
||||
{"::mlir::StringAttr", "std::string"},
|
||||
{"::mlir::IntegerAttr", "int"},
|
||||
{"::mlir::DenseIntElementsAttr", "std::vector<int>"},
|
||||
{"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
|
||||
{"::mlir::FloatAttr", "float"},
|
||||
{"::mlir::BoolAttr", "bool"},
|
||||
{"::mlir::ElementsAttr", "::builder::Array"},
|
||||
{"::mlir::DenseElementsAttr", "::builder::Tensor"},
|
||||
// current only support string array.
|
||||
{"::mlir::ArrayAttr", "std::vector<std::string>"},
|
||||
{"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
|
||||
{"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
|
||||
{"::mlir::mhlo::GatherDimensionNumbers",
|
||||
"::builder::GatherDimensionNumbers"},
|
||||
{"::mlir::mhlo::ScatterDimensionNumbers",
|
||||
"::builder::ScatterDimensionNumbers"},
|
||||
namespace {
|
||||
struct mlirTypeWrap {
|
||||
std::string Name;
|
||||
std::string (*ConvertToMlir)(std::string &, OpMethodBody &);
|
||||
};
|
||||
|
||||
static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
|
||||
{"::mlir::StringAttr",
|
||||
{"std::string",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::StringAttr " << var << "_mlir = mlir::StringAttr::get("
|
||||
<< var << ", ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::IntegerAttr",
|
||||
{"builder::Integer",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
// body << " mlir::IntegerAttr " << var << "_mlir = mlir::IntegerAttr::get("
|
||||
// << var << ", ctx);\n";
|
||||
body << " mlir::IntegerAttr " << var << "_mlir = " << var
|
||||
<< ".GetAttr(ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::DenseIntElementsAttr",
|
||||
{"std::vector<int>",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::DenseIntElementsAttr " << var
|
||||
<< "_mlir = mlir::DenseIntElementsAttr::get("
|
||||
<< "mlir::VectorType::get(" << var
|
||||
<< ".size(), opBuilder->getIntegerType(32))," << var << ");\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::mhlo::ChannelHandle",
|
||||
{"ChannelHandle",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::FloatAttr",
|
||||
{"float",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::FloatAttr " << var << "_mlir = mlir::FloatAttr::get("
|
||||
<< var << ", ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::BoolAttr",
|
||||
{"bool",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
body << " mlir::BoolAttr " << var << "_mlir = mlir::BoolAttr::get("
|
||||
<< var << ", ctx);\n";
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::ElementsAttr",
|
||||
{"::builder::Array",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::DenseElementsAttr",
|
||||
{"::builder::Tensor",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
// current only support string array.
|
||||
{"::mlir::ArrayAttr",
|
||||
{"std::vector<std::string>",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::mhlo::ConvDimensionNumbers",
|
||||
{"::builder::ConvDimensionNumbers",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::mhlo::DotDimensionNumbers",
|
||||
{"::builder::DotDimensionNumbers",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::mhlo::GatherDimensionNumbers",
|
||||
{"::builder::GatherDimensionNumbers",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
{"::mlir::mhlo::ScatterDimensionNumbers",
|
||||
{"::builder::ScatterDimensionNumbers",
|
||||
[](std::string &var, OpMethodBody &body) -> std::string {
|
||||
return var + "_mlir";
|
||||
}}},
|
||||
};
|
||||
|
||||
// static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
|
||||
// {"::mlir::StringAttr", "std::string"},
|
||||
// {"::mlir::IntegerAttr", "int"},
|
||||
// {"::mlir::DenseIntElementsAttr", "std::vector<int>"},
|
||||
// {"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
|
||||
// {"::mlir::FloatAttr", "float"},
|
||||
// {"::mlir::BoolAttr", "bool"},
|
||||
// {"::mlir::ElementsAttr", "::builder::Array"},
|
||||
// {"::mlir::DenseElementsAttr", "::builder::Tensor"},
|
||||
// // current only support string array.
|
||||
// {"::mlir::ArrayAttr", "std::vector<std::string>"},
|
||||
// {"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
|
||||
// {"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
|
||||
// {"::mlir::mhlo::GatherDimensionNumbers",
|
||||
// "::builder::GatherDimensionNumbers"},
|
||||
// {"::mlir::mhlo::ScatterDimensionNumbers",
|
||||
// "::builder::ScatterDimensionNumbers"},
|
||||
// };
|
||||
|
||||
StringRef typeConvertFromMLIR(StringRef type) {
|
||||
auto re = typeMapMLIR.find(type.str());
|
||||
if (re != typeMapMLIR.end()) return StringRef(re->second);
|
||||
if (re != typeMapMLIR.end()) return StringRef(re->second.Name);
|
||||
return type;
|
||||
}
|
||||
|
||||
|
@ -75,6 +164,7 @@ StringRef getReturnType(const Attribute &att) {
|
|||
auto type = att.getStorageType();
|
||||
return typeConvertFromMLIR(type);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// The logic to calculate the actual value range for a declared operand/result
|
||||
// of an op with variadic operands/results. Note that this logic is not for
|
||||
|
@ -312,14 +402,6 @@ static std::string getArgumentName(const Operator &op, int index) {
|
|||
return std::string(formatv("{0}_{1}", generatedArgName, index));
|
||||
}
|
||||
|
||||
// Returns true if we can use unwrapped value for the given `attr` in builders.
|
||||
static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
|
||||
return getReturnType(attr) != getStorageType(attr) &&
|
||||
// We need to wrap the raw value into an attribute in the builder impl
|
||||
// so we need to make sure that the attribute specifies how to do that.
|
||||
!attr.getConstBuilderTemplate().empty();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op emitter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -427,7 +509,7 @@ private:
|
|||
AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
|
||||
|
||||
// Adds op arguments and regions into operation state for build() methods.
|
||||
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,llvm::SmallVector<OpMethodParameter, 4> paramList,
|
||||
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
||||
bool isRawValueAttr = false);
|
||||
|
||||
// Generates canonicalizer declaration for the operation.
|
||||
|
@ -479,7 +561,9 @@ private:
|
|||
// Generate the type inference interface methods.
|
||||
void genTypeInterfaceMethods();
|
||||
|
||||
private:
|
||||
Operator GetOp() { return op; }
|
||||
|
||||
private:
|
||||
// The TableGen record for this op.
|
||||
// TODO: OpEmitter should not have a Record directly,
|
||||
// it should rather go through the Operator for better abstraction.
|
||||
|
@ -1077,27 +1161,6 @@ void OpEmitter::genNamedSuccessorGetters() {
|
|||
}
|
||||
}
|
||||
|
||||
static bool canGenerateUnwrappedBuilder(Operator &op) {
|
||||
// If this op does not have native attributes at all, return directly to avoid
|
||||
// redefining builders.
|
||||
if (op.getNumNativeAttributes() == 0)
|
||||
return false;
|
||||
|
||||
bool canGenerate = false;
|
||||
// We are generating builders that take raw values for attributes. We need to
|
||||
// make sure the native attributes have a meaningful "unwrapped" value type
|
||||
// different from the wrapped mlir::Attribute type to avoid redefining
|
||||
// builders. This checks for the op has at least one such native attribute.
|
||||
for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
|
||||
NamedAttribute &namedAttr = op.getAttribute(i);
|
||||
if (canUseUnwrappedRawValue(namedAttr.attr)) {
|
||||
canGenerate = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return canGenerate;
|
||||
}
|
||||
|
||||
static bool canInferType(Operator &op) {
|
||||
return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
|
||||
op.getNumRegions() == 0;
|
||||
|
@ -1106,18 +1169,14 @@ static bool canInferType(Operator &op) {
|
|||
void OpEmitter::genSeparateArgParamBuilder() {
|
||||
SmallVector<AttrParamKind, 2> attrBuilderType;
|
||||
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
|
||||
// if (canGenerateUnwrappedBuilder(op))
|
||||
// attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
|
||||
|
||||
// Emit with separate builders with or without unwrapped attributes and/or
|
||||
// inferring result type.
|
||||
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
|
||||
bool inferType) {
|
||||
llvm::SmallVector<OpMethodParameter, 4> paramList;
|
||||
llvm::SmallVector<OpMethodParameter, 4> paramList2;
|
||||
llvm::SmallVector<std::string, 4> resultNames;
|
||||
buildParamList(paramList, resultNames, paramKind, attrType);
|
||||
buildParamList(paramList2, resultNames, paramKind, attrType);
|
||||
|
||||
auto *m = opClass.addMethodAndPrune(
|
||||
"::builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
|
||||
|
@ -1126,10 +1185,10 @@ void OpEmitter::genSeparateArgParamBuilder() {
|
|||
return;
|
||||
auto &body = m->body();
|
||||
genCodeForAddingArgAndRegionForBuilder(
|
||||
body, paramList2, attrType == AttrParamKind::UnwrappedValue);
|
||||
body, attrType == AttrParamKind::UnwrappedValue);
|
||||
|
||||
// Push all result types to the operation state
|
||||
//"BBBBBBBBBBBB"
|
||||
|
||||
// if (inferType) {
|
||||
// // Generate builder that infers type too.
|
||||
// // TODO: Subsume this with general checking if type can be
|
||||
|
@ -1306,29 +1365,6 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|||
int numAttrs = 0;
|
||||
|
||||
int defaultValuedAttrStartIndex = op.getNumArgs();
|
||||
if (attrParamKind == AttrParamKind::UnwrappedValue) {
|
||||
// Calculate the start index from which we can attach default values in the
|
||||
// builder declaration.
|
||||
for (int i = op.getNumArgs() - 1; i >= 0; --i) {
|
||||
auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
|
||||
if (!namedAttr || !namedAttr->attr.hasDefaultValue())
|
||||
break;
|
||||
|
||||
if (!canUseUnwrappedRawValue(namedAttr->attr))
|
||||
break;
|
||||
|
||||
// Creating an APInt requires us to provide bitwidth, value, and
|
||||
// signedness, which is complicated compared to others. Similarly
|
||||
// for APFloat.
|
||||
// TODO: Adjust the 'returnType' field of such attributes
|
||||
// to support them.
|
||||
StringRef retType = getReturnType(namedAttr->attr);
|
||||
if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
|
||||
break;
|
||||
|
||||
defaultValuedAttrStartIndex = i;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||
auto argument = op.getArg(i);
|
||||
|
@ -1357,10 +1393,10 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|||
type = getStorageType(attr);
|
||||
break;
|
||||
case AttrParamKind::UnwrappedValue:
|
||||
if (canUseUnwrappedRawValue(attr))
|
||||
type = getReturnType(attr);
|
||||
else
|
||||
type = getStorageType(attr);
|
||||
// if (canUseUnwrappedRawValue(attr))
|
||||
// type = getReturnType(attr);
|
||||
// else
|
||||
// type = getStorageType(attr);
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -1394,41 +1430,72 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|||
llvm::formatv("{0}Count", region.name).str());
|
||||
}
|
||||
|
||||
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
|
||||
OpMethodBody &body, llvm::SmallVector<OpMethodParameter, 4> paramList,
|
||||
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
||||
bool isRawValueAttr) {
|
||||
auto op = GetOp();
|
||||
auto operands = op.getOperands();
|
||||
auto attrs = op.getAttributes();
|
||||
SmallVector<std::string, 4> newAttrs;
|
||||
|
||||
body << "// AAAAAA \n";
|
||||
// for(auto p : paramList)
|
||||
// {
|
||||
// body <<"==== type:"<< p.getType() << " name:"<< p.getName() << "\n";
|
||||
// for(auto o : operands){
|
||||
// body << " operands name: "<<o.name<<" getCPPClassName:"<<o.constraint.getCPPClassName()<<"\n";
|
||||
// }
|
||||
// for(auto a : attrs){
|
||||
// body << " attrs name: "<<a.name<<" type: "<< a.attr.getStorageType().str()<<"\n";
|
||||
// std::string attrType = a.attr.getStorageType().str();
|
||||
// auto ff = typeMapMLIR.find(attrType);
|
||||
// if(ff != typeMapMLIR.end()){
|
||||
// body << "// BBBBBB \n";
|
||||
// std::string attrName = a.name.str();
|
||||
// ff->second.ConvertToMlir(attrName,body);
|
||||
// }
|
||||
// }
|
||||
|
||||
body << " auto builder = builder.GetImpl();\n";
|
||||
body << " auto loc = builder->GetLoc();\n";
|
||||
body << " auto opBuilder = builder->GetBuilder();\n";
|
||||
|
||||
// body << "// AAAAAA \n";
|
||||
// for(auto p : paramList){
|
||||
// body << " AAA type"<<p.getType()<<" name"<<p.getName()<<"\n";
|
||||
// }
|
||||
|
||||
body << " auto b = builder.GetImpl();\n";
|
||||
body << " auto loc = b->GetLoc();\n";
|
||||
body << " auto opBuilder = b->GetBuilder();\n";
|
||||
body << " auto ctx = b->GetContext();\n";
|
||||
for (auto a : attrs) {
|
||||
std::string attrType = a.attr.getStorageType().str();
|
||||
|
||||
if(attrType == "::mlir::DenseIntElementsAttr")
|
||||
{
|
||||
body << " // BBBBBBBB getStorageType:"<< a.attr.getStorageType().str() <<"\n";
|
||||
body << " // BBBBBBBB getReturnType:"<< a.attr.getReturnType().str() <<"\n";
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
auto typePair = typeMapMLIR.find(attrType);
|
||||
if (typePair != typeMapMLIR.end()) {
|
||||
std::string attrName = a.name.str();
|
||||
auto mlirName = typePair->second.ConvertToMlir(attrName, body);
|
||||
newAttrs.emplace_back(mlirName);
|
||||
}
|
||||
}
|
||||
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
|
||||
<< " currentOp =\n";
|
||||
body << " opBuilder.create<mlir::" << op.getDialectName()
|
||||
<< "::" << op.getDialectName() << ">(\n";
|
||||
if (paramList.size() > 1) {
|
||||
body << " loc,\n";
|
||||
std::for_each(paramList.begin() + 1, paramList.end() - 1,
|
||||
[&](OpMethodParameter &p) {
|
||||
body << " " << p.getName() << ",\n";
|
||||
body << " loc";
|
||||
std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) {
|
||||
body << ",\n " << v.name << ".getResult()";
|
||||
});
|
||||
body << " " << paramList.back().getName() << "\n";
|
||||
} else {
|
||||
body << " loc\n";
|
||||
}
|
||||
body << " );\n";
|
||||
std::for_each(newAttrs.begin(), newAttrs.end(),
|
||||
[&](std::string &n) { body << ",\n " << n; });
|
||||
body << "\n );\n";
|
||||
body << " builder::" << op.getCppClassName() << " builderOp;\n";
|
||||
body << " auto opImpl = builderOp.GetImpl();\n";
|
||||
body << " opImpl.SetOperation(currentOp.getOperation());\n";
|
||||
body << " return builderOp;\n";
|
||||
|
||||
|
||||
|
||||
// // Push all operands to the result.
|
||||
// for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
||||
// std::string argName = getArgumentName(op, i);
|
||||
|
@ -2079,184 +2146,6 @@ void OpEmitter::genOpAsmInterface() {
|
|||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpOperandAdaptor emitter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// Helper class to emit Op operand adaptors to an output stream. Operand
|
||||
// adaptors are wrappers around ArrayRef<Value> that provide named operand
|
||||
// getters identical to those defined in the Op.
|
||||
class OpOperandAdaptorEmitter {
|
||||
public:
|
||||
static void emitDecl(const Operator &op, raw_ostream &os);
|
||||
static void emitDef(const Operator &op, raw_ostream &os);
|
||||
|
||||
private:
|
||||
explicit OpOperandAdaptorEmitter(const Operator &op);
|
||||
|
||||
// Add verification function. This generates a verify method for the adaptor
|
||||
// which verifies all the op-independent attribute constraints.
|
||||
void addVerification();
|
||||
|
||||
const Operator &op;
|
||||
Class adaptor;
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
|
||||
: op(op), adaptor(op.getAdaptorName()) {
|
||||
adaptor.newField("::mlir::ValueRange", "odsOperands");
|
||||
adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
|
||||
adaptor.newField("::mlir::RegionRange", "odsRegions");
|
||||
const auto *attrSizedOperands =
|
||||
op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
|
||||
{
|
||||
SmallVector<OpMethodParameter, 2> paramList;
|
||||
paramList.emplace_back("::mlir::ValueRange", "values");
|
||||
paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
|
||||
attrSizedOperands ? "" : "nullptr");
|
||||
paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
|
||||
auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList));
|
||||
|
||||
constructor->addMemberInitializer("odsOperands", "values");
|
||||
constructor->addMemberInitializer("odsAttrs", "attrs");
|
||||
constructor->addMemberInitializer("odsRegions", "regions");
|
||||
}
|
||||
|
||||
{
|
||||
auto *constructor = adaptor.addConstructorAndPrune(
|
||||
llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
|
||||
constructor->addMemberInitializer("odsOperands", "op->getOperands()");
|
||||
constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
|
||||
constructor->addMemberInitializer("odsRegions", "op->getRegions()");
|
||||
}
|
||||
|
||||
{
|
||||
auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands");
|
||||
m->body() << " return odsOperands;";
|
||||
}
|
||||
std::string sizeAttrInit =
|
||||
formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
|
||||
generateNamedOperandGetters(op, adaptor, sizeAttrInit,
|
||||
/*rangeType=*/"::mlir::ValueRange",
|
||||
/*rangeBeginCall=*/"odsOperands.begin()",
|
||||
/*rangeSizeCall=*/"odsOperands.size()",
|
||||
/*getOperandCallPattern=*/"odsOperands[{0}]");
|
||||
|
||||
FmtContext fctx;
|
||||
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
|
||||
|
||||
auto emitAttr = [&](StringRef name, Attribute attr) {
|
||||
auto &body = adaptor.addMethodAndPrune(getStorageType(attr), name)->body();
|
||||
body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
|
||||
<< "\n " << getStorageType(attr) << " attr = "
|
||||
<< "odsAttrs.get(\"" << name << "\").";
|
||||
if (attr.hasDefaultValue() || attr.isOptional())
|
||||
body << "dyn_cast_or_null<";
|
||||
else
|
||||
body << "cast<";
|
||||
body << getStorageType(attr) << ">();\n";
|
||||
|
||||
if (attr.hasDefaultValue()) {
|
||||
// Use the default value if attribute is not set.
|
||||
// TODO: this is inefficient, we are recreating the attribute for every
|
||||
// call. This should be set instead.
|
||||
std::string defaultValue = std::string(
|
||||
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
|
||||
body << " if (!attr)\n attr = " << defaultValue << ";\n";
|
||||
}
|
||||
body << " return attr;\n";
|
||||
};
|
||||
|
||||
{
|
||||
auto *m =
|
||||
adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes");
|
||||
m->body() << " return odsAttrs;";
|
||||
}
|
||||
for (auto &namedAttr : op.getAttributes()) {
|
||||
const auto &name = namedAttr.name;
|
||||
const auto &attr = namedAttr.attr;
|
||||
if (!attr.isDerivedAttr())
|
||||
emitAttr(name, attr);
|
||||
}
|
||||
|
||||
unsigned numRegions = op.getNumRegions();
|
||||
if (numRegions > 0) {
|
||||
auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions");
|
||||
m->body() << " return odsRegions;";
|
||||
}
|
||||
for (unsigned i = 0; i < numRegions; ++i) {
|
||||
const auto ®ion = op.getRegion(i);
|
||||
if (region.name.empty())
|
||||
continue;
|
||||
|
||||
// Generate the accessors for a variadic region.
|
||||
if (region.isVariadic()) {
|
||||
auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", region.name);
|
||||
m->body() << formatv(" return odsRegions.drop_front({0});", i);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto *m = adaptor.addMethodAndPrune("::mlir::Region &", region.name);
|
||||
m->body() << formatv(" return *odsRegions[{0}];", i);
|
||||
}
|
||||
|
||||
// Add verification function.
|
||||
addVerification();
|
||||
}
|
||||
|
||||
void OpOperandAdaptorEmitter::addVerification() {
|
||||
auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
|
||||
"::mlir::Location", "loc");
|
||||
auto &body = method->body();
|
||||
|
||||
const char *checkAttrSizedValueSegmentsCode = R"(
|
||||
{
|
||||
auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
|
||||
auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
|
||||
if (numElements != {1})
|
||||
return emitError(loc, "'{0}' attribute for specifying {2} segments "
|
||||
"must have {1} elements, but got ") << numElements;
|
||||
}
|
||||
)";
|
||||
|
||||
// Verify a few traits first so that we can use
|
||||
// getODSOperands()/getODSResults() in the rest of the verifier.
|
||||
for (auto &trait : op.getTraits()) {
|
||||
if (auto *t = dyn_cast<tblgen::NativeTrait>(&trait)) {
|
||||
if (t->getFullyQualifiedTraitName() ==
|
||||
"::mlir::OpTrait::AttrSizedOperandSegments") {
|
||||
body << formatv(checkAttrSizedValueSegmentsCode,
|
||||
"operand_segment_sizes", op.getNumOperands(),
|
||||
"operand");
|
||||
} else if (t->getFullyQualifiedTraitName() ==
|
||||
"::mlir::OpTrait::AttrSizedResultSegments") {
|
||||
body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
|
||||
op.getNumResults(), "result");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FmtContext verifyCtx;
|
||||
populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
|
||||
"<no results should be generated>", verifyCtx);
|
||||
genAttributeVerifier(op, "odsAttrs.get",
|
||||
Twine("emitError(loc, \"'") + op.getOperationName() +
|
||||
"' op \"",
|
||||
/*emitVerificationRequiringOp*/ false, verifyCtx, body);
|
||||
|
||||
body << " return ::mlir::success();";
|
||||
}
|
||||
|
||||
void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
|
||||
OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
|
||||
}
|
||||
|
||||
void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
|
||||
OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
|
||||
}
|
||||
|
||||
// Emits the opcode enum and op classes.
|
||||
static void emitOpClasses(const RecordKeeper &recordKeeper,
|
||||
const std::vector<Record *> &defs, raw_ostream &os,
|
||||
|
@ -2286,11 +2175,9 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
|
|||
NamespaceEmitter emitter(os, op.getCppNamespace());
|
||||
if (emitDecl) {
|
||||
// os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
|
||||
// OpOperandAdaptorEmitter::emitDecl(op, os);
|
||||
OpEmitter::emitDecl(op, os, staticVerifierEmitter);
|
||||
} else {
|
||||
// os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
|
||||
// OpOperandAdaptorEmitter::emitDef(op, os);
|
||||
OpEmitter::emitDef(op, os, staticVerifierEmitter);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
#ifndef BUILDER_OP_
|
||||
#define BUILDER_OP_
|
||||
|
||||
#include "iostream"
|
||||
|
||||
namespace builder {
|
||||
class Op {
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return _impl; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> _impl;
|
||||
}
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
|
@ -1,16 +0,0 @@
|
|||
#ifndef BUILDER_TENSOR_
|
||||
#define BUILDER_TENSOR_
|
||||
|
||||
|
||||
#include "iostream"
|
||||
|
||||
namespace builder{
|
||||
class Tensor{
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#endif
|
|
@ -1,10 +0,0 @@
|
|||
#ifndef BUILDER_TYPE_
|
||||
#define BUILDER_TYPE_
|
||||
|
||||
#include "iostream"
|
||||
|
||||
namespace builder {
|
||||
class Type {}
|
||||
} // namespace builder
|
||||
|
||||
#endif
|
Loading…
Reference in New Issue