add integer attribute convert

This commit is contained in:
colin.liang 2021-08-11 10:46:07 +08:00
parent 9d5166684c
commit 95fc37ffa8
22 changed files with 562 additions and 434 deletions

2
BUILD
View File

@ -129,6 +129,8 @@ cc_binary(
"tools/mlir-tblgen-builder/*.cpp", "tools/mlir-tblgen-builder/*.cpp",
"tools/mlir-tblgen-builder/TableGen/*.h", "tools/mlir-tblgen-builder/TableGen/*.h",
"tools/mlir-tblgen-builder/TableGen/*.cpp", "tools/mlir-tblgen-builder/TableGen/*.cpp",
"tools/mlir-tblgen-builder/Builder/*.h",
"tools/mlir-tblgen-builder/Builder/*.cpp",
]), ]),
deps = [ deps = [
"@llvm-project//mlir:MlirTableGenMain", "@llvm-project//mlir:MlirTableGenMain",

View File

@ -14,8 +14,25 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "iostream" #include "iostream"
#include "llvm/Support/CommandLine.h"
#include "mlir/Support/MlirOptMain.h" #include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/JitRunner.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/AsmState.h" #include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
@ -23,62 +40,31 @@ limitations under the License.
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Parser.h" #include "mlir/Parser.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Support/DebugCounter.h" #include "mlir/Support/DebugCounter.h"
#include "mlir/Support/FileUtilities.h" #include "mlir/Support/FileUtilities.h"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/Support/Timing.h" #include "mlir/Support/Timing.h"
#include "mlir/Support/ToolUtilities.h" #include "mlir/Support/ToolUtilities.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/DynamicLibrary.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/JitRunner.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/Target/LLVMIR/Dialect/All.h" #include "mlir/Target/LLVMIR/Dialect/All.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
// #include "mlir/IR/BuiltinTypes.h" // #include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/Mangling.h" #include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h" #include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h" #include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/Support/TargetSelect.h"
using namespace mlir; using namespace mlir;
using namespace llvm; using namespace llvm;
namespace utils{ namespace utils {
template <typename T, int N> template <typename T, int N>
struct MemRefDescriptor { struct MemRefDescriptor {
T *allocated; T *allocated;
@ -87,11 +73,9 @@ struct MemRefDescriptor {
int64_t sizes[N]; int64_t sizes[N];
int64_t strides[N]; int64_t strides[N];
}; };
} } // namespace utils
int main(int argc, char **argv) { int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv); llvm::InitLLVM y(argc, argv);
llvm::InitializeNativeTarget(); llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter(); llvm::InitializeNativeTargetAsmPrinter();
@ -105,7 +89,6 @@ int main(int argc, char **argv) {
// registerDefaultTimingManagerCLOptions(); // registerDefaultTimingManagerCLOptions();
DebugCounter::registerCLOptions(); DebugCounter::registerCLOptions();
mlir::registerAllPasses(); mlir::registerAllPasses();
mlir::mhlo::registerAllMhloPasses(); mlir::mhlo::registerAllMhloPasses();
mlir::lmhlo::registerAllLmhloPasses(); mlir::lmhlo::registerAllLmhloPasses();
@ -120,24 +103,22 @@ int main(int argc, char **argv) {
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>(); registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
registry.insert<mlir::disc_ral::RalDialect>(); registry.insert<mlir::disc_ral::RalDialect>();
// failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", // failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
// registry, // registry,
// /*preloadDialectsInContext=*/false)); // /*preloadDialectsInContext=*/false));
// return 0; // return 0;
std::string errorMessage; std::string errorMessage;
// auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir", &errorMessage); // auto file = mlir::openInputFile("/root/mlir-hlo/bazel-bin/a.mlir",
auto file = mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage); // &errorMessage);
std::cout<<"errorMessage:" <<errorMessage <<std::endl; auto file =
mlir::openInputFile("/root/mlir-hlo/tests/test.mlir", &errorMessage);
std::cout << "errorMessage:" << errorMessage << std::endl;
SourceMgr sourceMgr; SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
SmallVector<const llvm::PassInfo *, 4> passes; SmallVector<const llvm::PassInfo *, 4> passes;
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
auto tmOrError = tmBuilderOrError->createTargetMachine(); auto tmOrError = tmBuilderOrError->createTargetMachine();
@ -145,8 +126,6 @@ int main(int argc, char **argv) {
auto transformer = mlir::makeLLVMPassesTransformer( auto transformer = mlir::makeLLVMPassesTransformer(
passes, 0, /*targetMachine=*/tmOrError->get(), 0); passes, 0, /*targetMachine=*/tmOrError->get(), 0);
MLIRContext context(registry); MLIRContext context(registry);
context.loadAllAvailableDialects(); context.loadAllAvailableDialects();
OwningModuleRef module(parseSourceFile(sourceMgr, &context)); OwningModuleRef module(parseSourceFile(sourceMgr, &context));
@ -157,8 +136,7 @@ int main(int argc, char **argv) {
return failure(); return failure();
}; };
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo \
// RUN: -hlo-legalize-to-lhlo -buffer-hoisting \ // RUN: -hlo-legalize-to-lhlo -buffer-hoisting \
// RUN: -buffer-deallocation -canonicalize -cse \ // RUN: -buffer-deallocation -canonicalize -cse \
// RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \ // RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
@ -192,9 +170,7 @@ int main(int argc, char **argv) {
pm.run(*module); pm.run(*module);
module->dump(); module->dump();
std::cout << "DEBUG load module success" << std::endl;
std::cout<<"DEBUG load module success"<<std::endl;
llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default; llvm::CodeGenOpt::Level jitCodeGenOptLevel = llvm::CodeGenOpt::Default;
@ -206,12 +182,10 @@ int main(int argc, char **argv) {
// Use absolute library path so that gdb can find the symbol table. // Use absolute library path so that gdb can find the symbol table.
std::list<std::string> sharedlib; std::list<std::string> sharedlib;
sharedlib.push_back("/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git"); sharedlib.push_back(
"/root/mlir-hlo/llvm-build/lib/libmlir_runner_utils.so.13git");
transform( transform(sharedlib, std::back_inserter(libPaths), [](std::string libPath) {
sharedlib,
std::back_inserter(libPaths),
[](std::string libPath) {
SmallString<256> absPath(libPath.begin(), libPath.end()); SmallString<256> absPath(libPath.begin(), libPath.end());
cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath))); cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
return absPath; return absPath;
@ -226,7 +200,6 @@ int main(int argc, char **argv) {
llvm::StringMap<void *> exportSymbols; llvm::StringMap<void *> exportSymbols;
SmallVector<MlirRunnerDestroyFn> destroyFns; SmallVector<MlirRunnerDestroyFn> destroyFns;
// Handle libraries that do support mlir-runner init/destroy callbacks. // Handle libraries that do support mlir-runner init/destroy callbacks.
for (auto &libPath : libPaths) { for (auto &libPath : libPaths) {
auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str()); auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
@ -255,17 +228,12 @@ int main(int argc, char **argv) {
return symbolMap; return symbolMap;
}; };
auto expectedEngine =
mlir::ExecutionEngine::create(module.get(), nullptr, transformer,
jitCodeGenOptLevel, executionEngineLibs);
auto expectedEngine = mlir::ExecutionEngine::create(
module.get(), nullptr, transformer, jitCodeGenOptLevel,
executionEngineLibs);
// if (!expectedEngine) // if (!expectedEngine)
// return expectedEngine.takeError(); // return expectedEngine.takeError();
auto engine = std::move(*expectedEngine); auto engine = std::move(*expectedEngine);
engine->registerSymbols(runtimeSymbolMap); engine->registerSymbols(runtimeSymbolMap);
@ -277,16 +245,15 @@ int main(int argc, char **argv) {
// if (options.dumpObjectFile) // if (options.dumpObjectFile)
// engine->dumpToObjectFile("a.o"); // engine->dumpToObjectFile("a.o");
float rawdata[6] = {0, 1, 2, 3, 4, 5};
float rawdata[6] = {0,1,2,3,4,5};
int64_t dims = 1; int64_t dims = 1;
utils::MemRefDescriptor<float,1> a{rawdata,rawdata,0,{6},{1}}; utils::MemRefDescriptor<float, 1> a{rawdata, rawdata, 0, {6}, {1}};
utils::MemRefDescriptor<float,1> b{rawdata,rawdata,0,{6},{1}}; utils::MemRefDescriptor<float, 1> b{rawdata, rawdata, 0, {6}, {1}};
utils::MemRefDescriptor<float,1> result_memref; utils::MemRefDescriptor<float, 1> result_memref;
struct memref_type{ struct memref_type {
int64_t res_size = 6; int64_t res_size = 6;
utils::MemRefDescriptor<float,1> *memref; utils::MemRefDescriptor<float, 1> *memref;
} result; } result;
result.memref = &result_memref; result.memref = &result_memref;
@ -299,23 +266,23 @@ int main(int argc, char **argv) {
} data; } data;
data.data1_size = &dims; data.data1_size = &dims;
void * a_ptr = &a; void *a_ptr = &a;
data.data1 = &a_ptr; data.data1 = &a_ptr;
data.data2_size = &dims; data.data2_size = &dims;
void * b_ptr = &b; void *b_ptr = &b;
data.data2 = &b_ptr; data.data2 = &b_ptr;
void * result_ptr = &result; void *result_ptr = &result;
data.res = &result; data.res = &result;
void (*fptr)(void **) = *expectedFPtr; void (*fptr)(void **) = *expectedFPtr;
(*fptr)((void **)&data); (*fptr)((void **)&data);
std::cout<<"result: "<<result.memref->allocated[0]<<std::endl; std::cout << "result: " << result.memref->allocated[0] << std::endl;
std::cout<<"result: "<<result.memref->allocated[1]<<std::endl; std::cout << "result: " << result.memref->allocated[1] << std::endl;
std::cout<<"result: "<<result.memref->allocated[2]<<std::endl; std::cout << "result: " << result.memref->allocated[2] << std::endl;
std::cout<<"result: "<<result.memref->allocated[3]<<std::endl; std::cout << "result: " << result.memref->allocated[3] << std::endl;
std::cout<<"result: "<<result.memref->allocated[4]<<std::endl; std::cout << "result: " << result.memref->allocated[4] << std::endl;
std::cout<<"result: "<<result.memref->allocated[5]<<std::endl; std::cout << "result: " << result.memref->allocated[5] << std::endl;
// Run all dynamic library destroy callbacks to prepare for the shutdown. // Run all dynamic library destroy callbacks to prepare for the shutdown.
llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); }); llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });

View File

@ -1,10 +0,0 @@
#ifndef BUILDER_ARRAY_
#define BUILDER_ARRAY_
#include "iostream"
namespace builder {
class Array {}
} // namespace builder
#endif

View File

@ -0,0 +1,22 @@
#include "Attribute.h"
#include "AttributeImpl.h"
#include "Shape.h"
#include "Tensor.h"
#include "TensorImpl.h"
#include "llvm/Support/Casting.h"
#include "memory.h"
// #include "mlir/Dialect/StandardOps/Ops.h"
// #include "mlir/IR/Attributes.h"
// #include "mlir/IR/Operation.h"
// #include "mlir/IR/StandardTypes.h"
// #include "mlir/IR/Types.h"
// #include "mlir/IR/Value.h"
namespace builder {
Integer::Integer(int value) : impl_(std::make_shared<Integer::Impl>(value)) {}
Integer::Integer(int64_t value)
: impl_(std::make_shared<Integer::Impl>(value)) {}
} // namespace builder

View File

@ -0,0 +1,50 @@
#ifndef BUILDER_ATTRIBUTE_H_
#define BUILDER_ATTRIBUTE_H_
#include <iostream>
#include <memory>
namespace builder {
class Integer {
public:
Integer(int value);
Integer(int64_t value);
class Impl;
private:
std::shared_ptr<Impl> impl_;
};
class Array {
public:
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
class Type {
public:
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
// template <typename T>
// class DenseIntElementsAttr {
// public:
// explicit DenseIntElementsAttr(Tensor);
// class Impl;
// std::shared_ptr<Impl> GetImpl() { return impl_; }
// private:
// std::shared_ptr<Impl> impl_;
// };
} // namespace builder
#endif

View File

@ -0,0 +1,58 @@
#ifndef BUILDER_TYPEIMPL_
#define BUILDER_TYPEIMPL_
#include "Builder.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace builder {
class Integer::Impl {
public:
Impl(){};
Impl(int value) {
GetAttr = [=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
auto int32_type = mlir::IntegerType::get(context, 32);
return mlir::IntegerAttr::get(int32_type, value);
};
};
Impl(int64_t value)
: GetAttr([=](mlir::MLIRContext *context) -> mlir::IntegerAttr {
auto int32_type = mlir::IntegerType::get(context, 64);
return mlir::IntegerAttr::get(int32_type, value);
}){};
std::function<mlir::IntegerAttr(mlir::MLIRContext *context)> GetAttr;
private:
};
class Array::Impl {
public:
Impl() = default;
private:
};
class Type::Impl {
public:
Impl() = default;
private:
};
// class DenseIntElementsAttr::Impl {
// public:
// Impl() = default;
// private:
// };
} // namespace builder
#endif

View File

@ -11,7 +11,7 @@
namespace builder { namespace builder {
Builder::Builder() : _impl(std::make_shared<Impl>()) {} Builder::Builder() : impl_(std::make_shared<Impl>()) {}
void Builder::DumpModule() {} void Builder::DumpModule() {}
} // namespace builder } // namespace builder

View File

@ -11,10 +11,10 @@ class Builder {
Builder(); Builder();
void DumpModule(); void DumpModule();
std::shared_ptr<Impl> GetImpl() { return _impl; } std::shared_ptr<Impl> GetImpl() { return impl_; }
private: private:
std::shared_ptr<Impl> _impl; std::shared_ptr<Impl> impl_;
}; };
} // namespace builder } // namespace builder

View File

@ -0,0 +1,17 @@
#ifndef BUILDER_OP_
#define BUILDER_OP_
#include <iostream>
#include <memory>
namespace builder {
class Op {
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
}
} // namespace builder
#endif

View File

@ -0,0 +1,28 @@
#include "PrimitiveType.h"
#include "Shape.h"
#include "Tensor.h"
#include "llvm/Support/Casting.h"
// #include "mlir/Dialect/StandardOps/Ops.h"
// #include "mlir/IR/Attributes.h"
// #include "mlir/IR/Operation.h"
// #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
// #include "mlir/IR/Value.h"
namespace builder {
// constexpr mlir::Type GetMlirType(const PrimitiveType primitiveType) {
// switch (primitiveType)
// {
// case PrimitiveType::BF16 :
// return
// break;
// default:
// break;
// }
// }
} // namespace builder

View File

@ -0,0 +1,29 @@
#ifndef BUILDER_PRIMITIVE_TYPE_
#define BUILDER_PRIMITIVE_TYPE_
namespace builder {
enum class PrimitiveType {
PRIMITIVE_TYPE_INVALID = 0,
PRED = 1,
S8 = 2,
S16 = 3,
S32 = 4,
S64 = 5,
U8 = 6,
U16 = 7,
U32 = 8,
U64 = 9,
F16 = 10,
F32 = 11,
BF16 = 16,
F64 = 12,
C64 = 15,
C128 = 18,
TUPLE = 13,
OPAQUE_TYPE = 14,
TOKEN = 17
};
}
#endif

View File

@ -0,0 +1,17 @@
#ifndef BUILDER_SHAPE_
#define BUILDER_SHAPE_
#include <iostream>
#include <memory>
namespace builder {
class Shape {
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
private:
std::shared_ptr<Impl> impl_;
};
} // namespace builder
#endif

View File

@ -0,0 +1,24 @@
#ifndef BUILDER_SHAPEIMPL_
#define BUILDER_SHAPEIMPL_
#include "Builder.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace builder {
class Shape::Impl {
public:
Impl() = default;
private:
};
} // namespace builder
#endif

View File

@ -0,0 +1,22 @@
#include "Tensor.h"
#include "Shape.h"
#include "TensorImpl.h"
#include "llvm/Support/Casting.h"
// #include "mlir/Dialect/StandardOps/Ops.h"
// #include "mlir/IR/Attributes.h"
// #include "mlir/IR/Operation.h"
// #include "mlir/IR/StandardTypes.h"
// #include "mlir/IR/Types.h"
// #include "mlir/IR/Value.h"
namespace builder {
// Tensor::Tensor() : impl_(std::make_shared<Impl>()) {}
Tensor::Tensor(Shape& shape, PrimitiveType& primitiveType)
: impl_(std::make_shared<Impl>(shape, primitiveType)) {}
Shape Tensor::GetShape() { return impl_->GetShape(); }
PrimitiveType Tensor::GetType() { return impl_->GetType(); }
} // namespace builder

View File

@ -0,0 +1,25 @@
#ifndef BUILDER_TENSOR_
#define BUILDER_TENSOR_
#include <iostream>
#include <memory>
#include "PrimitiveType.h"
#include "Shape.h"
namespace builder {
class Tensor {
public:
Tensor(Shape& shape, PrimitiveType& primitiveType);
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }
Shape GetShape();
PrimitiveType GetType();
private:
std::shared_ptr<Impl> impl_;
};
} // namespace builder
#endif

View File

@ -0,0 +1,32 @@
#ifndef BUILDER_TENSORIMPL_
#define BUILDER_TENSORIMPL_
#include "Builder.h"
#include "PrimitiveType.h"
#include "Shape.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace builder {
class Tensor::Impl {
public:
Impl() = default;
Impl(Shape& shape, PrimitiveType& primitiveType)
: shape_(shape), primitiveType_(primitiveType) {}
Shape GetShape() { return shape_; }
PrimitiveType GetType() { return primitiveType_; }
private:
Shape shape_;
PrimitiveType primitiveType_;
};
} // namespace builder
#endif

View File

@ -29,7 +29,7 @@
#include "llvm/TableGen/Record.h" #include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h" #include "llvm/TableGen/TableGenBackend.h"
#include "iostream" #include <iostream>
#define DEBUG_TYPE "mlir-tblgen-opdefgen" #define DEBUG_TYPE "mlir-tblgen-opdefgen"
@ -42,28 +42,117 @@ static const char *const generatedArgName = "odsArg";
static const char *const odsBuilder = "odsBuilder"; static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState"; static const char *const builderOpState = "odsState";
static const std::map<std::string, std::string> typeMapMLIR = { namespace {
{"::mlir::StringAttr", "std::string"}, struct mlirTypeWrap {
{"::mlir::IntegerAttr", "int"}, std::string Name;
{"::mlir::DenseIntElementsAttr", "std::vector<int>"}, std::string (*ConvertToMlir)(std::string &, OpMethodBody &);
{"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
{"::mlir::FloatAttr", "float"},
{"::mlir::BoolAttr", "bool"},
{"::mlir::ElementsAttr", "::builder::Array"},
{"::mlir::DenseElementsAttr", "::builder::Tensor"},
// current only support string array.
{"::mlir::ArrayAttr", "std::vector<std::string>"},
{"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
{"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
{"::mlir::mhlo::GatherDimensionNumbers",
"::builder::GatherDimensionNumbers"},
{"::mlir::mhlo::ScatterDimensionNumbers",
"::builder::ScatterDimensionNumbers"},
}; };
static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
{"::mlir::StringAttr",
{"std::string",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::StringAttr " << var << "_mlir = mlir::StringAttr::get("
<< var << ", ctx);\n";
return var + "_mlir";
}}},
{"::mlir::IntegerAttr",
{"builder::Integer",
[](std::string &var, OpMethodBody &body) -> std::string {
// body << " mlir::IntegerAttr " << var << "_mlir = mlir::IntegerAttr::get("
// << var << ", ctx);\n";
body << " mlir::IntegerAttr " << var << "_mlir = " << var
<< ".GetAttr(ctx);\n";
return var + "_mlir";
}}},
{"::mlir::DenseIntElementsAttr",
{"std::vector<int>",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::DenseIntElementsAttr " << var
<< "_mlir = mlir::DenseIntElementsAttr::get("
<< "mlir::VectorType::get(" << var
<< ".size(), opBuilder->getIntegerType(32))," << var << ");\n";
return var + "_mlir";
}}},
{"::mlir::mhlo::ChannelHandle",
{"ChannelHandle",
[](std::string &var, OpMethodBody &body) -> std::string {
return var + "_mlir";
}}},
{"::mlir::FloatAttr",
{"float",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::FloatAttr " << var << "_mlir = mlir::FloatAttr::get("
<< var << ", ctx);\n";
return var + "_mlir";
}}},
{"::mlir::BoolAttr",
{"bool",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::BoolAttr " << var << "_mlir = mlir::BoolAttr::get("
<< var << ", ctx);\n";
return var + "_mlir";
}}},
{"::mlir::ElementsAttr",
{"::builder::Array",
[](std::string &var, OpMethodBody &body) -> std::string {
return var + "_mlir";
}}},
{"::mlir::DenseElementsAttr",
{"::builder::Tensor",
[](std::string &var, OpMethodBody &body) -> std::string {
return var + "_mlir";
}}},
// current only support string array.
{"::mlir::ArrayAttr",
{"std::vector<std::string>",
[](std::string &var, OpMethodBody &body) -> std::string {
return var + "_mlir";
}}},
{"::mlir::mhlo::ConvDimensionNumbers",
{"::builder::ConvDimensionNumbers",
[](std::string &var, OpMethodBody &body) -> std::string {
return var + "_mlir";
}}},
{"::mlir::mhlo::DotDimensionNumbers",
{"::builder::DotDimensionNumbers",
[](std::string &var, OpMethodBody &body) -> std::string {
return var + "_mlir";
}}},
{"::mlir::mhlo::GatherDimensionNumbers",
{"::builder::GatherDimensionNumbers",
[](std::string &var, OpMethodBody &body) -> std::string {
return var + "_mlir";
}}},
{"::mlir::mhlo::ScatterDimensionNumbers",
{"::builder::ScatterDimensionNumbers",
[](std::string &var, OpMethodBody &body) -> std::string {
return var + "_mlir";
}}},
};
// static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
// {"::mlir::StringAttr", "std::string"},
// {"::mlir::IntegerAttr", "int"},
// {"::mlir::DenseIntElementsAttr", "std::vector<int>"},
// {"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
// {"::mlir::FloatAttr", "float"},
// {"::mlir::BoolAttr", "bool"},
// {"::mlir::ElementsAttr", "::builder::Array"},
// {"::mlir::DenseElementsAttr", "::builder::Tensor"},
// // current only support string array.
// {"::mlir::ArrayAttr", "std::vector<std::string>"},
// {"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
// {"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
// {"::mlir::mhlo::GatherDimensionNumbers",
// "::builder::GatherDimensionNumbers"},
// {"::mlir::mhlo::ScatterDimensionNumbers",
// "::builder::ScatterDimensionNumbers"},
// };
StringRef typeConvertFromMLIR(StringRef type) { StringRef typeConvertFromMLIR(StringRef type) {
auto re = typeMapMLIR.find(type.str()); auto re = typeMapMLIR.find(type.str());
if (re != typeMapMLIR.end()) return StringRef(re->second); if (re != typeMapMLIR.end()) return StringRef(re->second.Name);
return type; return type;
} }
@ -75,6 +164,7 @@ StringRef getReturnType(const Attribute &att) {
auto type = att.getStorageType(); auto type = att.getStorageType();
return typeConvertFromMLIR(type); return typeConvertFromMLIR(type);
} }
} // namespace
// The logic to calculate the actual value range for a declared operand/result // The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for // of an op with variadic operands/results. Note that this logic is not for
@ -312,14 +402,6 @@ static std::string getArgumentName(const Operator &op, int index) {
return std::string(formatv("{0}_{1}", generatedArgName, index)); return std::string(formatv("{0}_{1}", generatedArgName, index));
} }
// Returns true if we can use unwrapped value for the given `attr` in builders.
static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
return getReturnType(attr) != getStorageType(attr) &&
// We need to wrap the raw value into an attribute in the builder impl
// so we need to make sure that the attribute specifies how to do that.
!attr.getConstBuilderTemplate().empty();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Op emitter // Op emitter
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -427,7 +509,7 @@ private:
AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
// Adds op arguments and regions into operation state for build() methods. // Adds op arguments and regions into operation state for build() methods.
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,llvm::SmallVector<OpMethodParameter, 4> paramList, void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
bool isRawValueAttr = false); bool isRawValueAttr = false);
// Generates canonicalizer declaration for the operation. // Generates canonicalizer declaration for the operation.
@ -479,7 +561,9 @@ private:
// Generate the type inference interface methods. // Generate the type inference interface methods.
void genTypeInterfaceMethods(); void genTypeInterfaceMethods();
private: Operator GetOp() { return op; }
private:
// The TableGen record for this op. // The TableGen record for this op.
// TODO: OpEmitter should not have a Record directly, // TODO: OpEmitter should not have a Record directly,
// it should rather go through the Operator for better abstraction. // it should rather go through the Operator for better abstraction.
@ -1077,27 +1161,6 @@ void OpEmitter::genNamedSuccessorGetters() {
} }
} }
static bool canGenerateUnwrappedBuilder(Operator &op) {
// If this op does not have native attributes at all, return directly to avoid
// redefining builders.
if (op.getNumNativeAttributes() == 0)
return false;
bool canGenerate = false;
// We are generating builders that take raw values for attributes. We need to
// make sure the native attributes have a meaningful "unwrapped" value type
// different from the wrapped mlir::Attribute type to avoid redefining
// builders. This checks for the op has at least one such native attribute.
for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
NamedAttribute &namedAttr = op.getAttribute(i);
if (canUseUnwrappedRawValue(namedAttr.attr)) {
canGenerate = true;
break;
}
}
return canGenerate;
}
static bool canInferType(Operator &op) { static bool canInferType(Operator &op) {
return op.getTrait("::mlir::InferTypeOpInterface::Trait") && return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
op.getNumRegions() == 0; op.getNumRegions() == 0;
@ -1106,18 +1169,14 @@ static bool canInferType(Operator &op) {
void OpEmitter::genSeparateArgParamBuilder() { void OpEmitter::genSeparateArgParamBuilder() {
SmallVector<AttrParamKind, 2> attrBuilderType; SmallVector<AttrParamKind, 2> attrBuilderType;
attrBuilderType.push_back(AttrParamKind::WrappedAttr); attrBuilderType.push_back(AttrParamKind::WrappedAttr);
// if (canGenerateUnwrappedBuilder(op))
// attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
// Emit with separate builders with or without unwrapped attributes and/or // Emit with separate builders with or without unwrapped attributes and/or
// inferring result type. // inferring result type.
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
bool inferType) { bool inferType) {
llvm::SmallVector<OpMethodParameter, 4> paramList; llvm::SmallVector<OpMethodParameter, 4> paramList;
llvm::SmallVector<OpMethodParameter, 4> paramList2;
llvm::SmallVector<std::string, 4> resultNames; llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, paramKind, attrType); buildParamList(paramList, resultNames, paramKind, attrType);
buildParamList(paramList2, resultNames, paramKind, attrType);
auto *m = opClass.addMethodAndPrune( auto *m = opClass.addMethodAndPrune(
"::builder::Op", "build", OpMethod::MP_Static, std::move(paramList)); "::builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
@ -1126,10 +1185,10 @@ void OpEmitter::genSeparateArgParamBuilder() {
return; return;
auto &body = m->body(); auto &body = m->body();
genCodeForAddingArgAndRegionForBuilder( genCodeForAddingArgAndRegionForBuilder(
body, paramList2, attrType == AttrParamKind::UnwrappedValue); body, attrType == AttrParamKind::UnwrappedValue);
// Push all result types to the operation state // Push all result types to the operation state
//"BBBBBBBBBBBB"
// if (inferType) { // if (inferType) {
// // Generate builder that infers type too. // // Generate builder that infers type too.
// // TODO: Subsume this with general checking if type can be // // TODO: Subsume this with general checking if type can be
@ -1306,29 +1365,6 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
int numAttrs = 0; int numAttrs = 0;
int defaultValuedAttrStartIndex = op.getNumArgs(); int defaultValuedAttrStartIndex = op.getNumArgs();
if (attrParamKind == AttrParamKind::UnwrappedValue) {
// Calculate the start index from which we can attach default values in the
// builder declaration.
for (int i = op.getNumArgs() - 1; i >= 0; --i) {
auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
if (!namedAttr || !namedAttr->attr.hasDefaultValue())
break;
if (!canUseUnwrappedRawValue(namedAttr->attr))
break;
// Creating an APInt requires us to provide bitwidth, value, and
// signedness, which is complicated compared to others. Similarly
// for APFloat.
// TODO: Adjust the 'returnType' field of such attributes
// to support them.
StringRef retType = getReturnType(namedAttr->attr);
if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
break;
defaultValuedAttrStartIndex = i;
}
}
for (int i = 0, e = op.getNumArgs(); i < e; ++i) { for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i); auto argument = op.getArg(i);
@ -1357,10 +1393,10 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
type = getStorageType(attr); type = getStorageType(attr);
break; break;
case AttrParamKind::UnwrappedValue: case AttrParamKind::UnwrappedValue:
if (canUseUnwrappedRawValue(attr)) // if (canUseUnwrappedRawValue(attr))
type = getReturnType(attr); // type = getReturnType(attr);
else // else
type = getStorageType(attr); // type = getStorageType(attr);
break; break;
} }
@ -1394,41 +1430,72 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
llvm::formatv("{0}Count", region.name).str()); llvm::formatv("{0}Count", region.name).str());
} }
void OpEmitter::genCodeForAddingArgAndRegionForBuilder( void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
OpMethodBody &body, llvm::SmallVector<OpMethodParameter, 4> paramList,
bool isRawValueAttr) { bool isRawValueAttr) {
auto op = GetOp();
auto operands = op.getOperands();
auto attrs = op.getAttributes();
SmallVector<std::string, 4> newAttrs;
body << "// AAAAAA \n"; // for(auto o : operands){
// for(auto p : paramList) // body << " operands name: "<<o.name<<" getCPPClassName:"<<o.constraint.getCPPClassName()<<"\n";
// { // }
// body <<"==== type:"<< p.getType() << " name:"<< p.getName() << "\n"; // for(auto a : attrs){
// body << " attrs name: "<<a.name<<" type: "<< a.attr.getStorageType().str()<<"\n";
// std::string attrType = a.attr.getStorageType().str();
// auto ff = typeMapMLIR.find(attrType);
// if(ff != typeMapMLIR.end()){
// body << "// BBBBBB \n";
// std::string attrName = a.name.str();
// ff->second.ConvertToMlir(attrName,body);
// }
// } // }
body << " auto builder = builder.GetImpl();\n";
body << " auto loc = builder->GetLoc();\n"; // body << "// AAAAAA \n";
body << " auto opBuilder = builder->GetBuilder();\n"; // for(auto p : paramList){
// body << " AAA type"<<p.getType()<<" name"<<p.getName()<<"\n";
// }
body << " auto b = builder.GetImpl();\n";
body << " auto loc = b->GetLoc();\n";
body << " auto opBuilder = b->GetBuilder();\n";
body << " auto ctx = b->GetContext();\n";
for (auto a : attrs) {
std::string attrType = a.attr.getStorageType().str();
if(attrType == "::mlir::DenseIntElementsAttr")
{
body << " // BBBBBBBB getStorageType:"<< a.attr.getStorageType().str() <<"\n";
body << " // BBBBBBBB getReturnType:"<< a.attr.getReturnType().str() <<"\n";
}
auto typePair = typeMapMLIR.find(attrType);
if (typePair != typeMapMLIR.end()) {
std::string attrName = a.name.str();
auto mlirName = typePair->second.ConvertToMlir(attrName, body);
newAttrs.emplace_back(mlirName);
}
}
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName() body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
<< " currentOp =\n"; << " currentOp =\n";
body << " opBuilder.create<mlir::" << op.getDialectName() body << " opBuilder.create<mlir::" << op.getDialectName()
<< "::" << op.getDialectName() << ">(\n"; << "::" << op.getDialectName() << ">(\n";
if (paramList.size() > 1) { body << " loc";
body << " loc,\n"; std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) {
std::for_each(paramList.begin() + 1, paramList.end() - 1, body << ",\n " << v.name << ".getResult()";
[&](OpMethodParameter &p) {
body << " " << p.getName() << ",\n";
}); });
body << " " << paramList.back().getName() << "\n"; std::for_each(newAttrs.begin(), newAttrs.end(),
} else { [&](std::string &n) { body << ",\n " << n; });
body << " loc\n"; body << "\n );\n";
}
body << " );\n";
body << " builder::" << op.getCppClassName() << " builderOp;\n"; body << " builder::" << op.getCppClassName() << " builderOp;\n";
body << " auto opImpl = builderOp.GetImpl();\n"; body << " auto opImpl = builderOp.GetImpl();\n";
body << " opImpl.SetOperation(currentOp.getOperation());\n"; body << " opImpl.SetOperation(currentOp.getOperation());\n";
body << " return builderOp;\n"; body << " return builderOp;\n";
// // Push all operands to the result. // // Push all operands to the result.
// for (int i = 0, e = op.getNumOperands(); i < e; ++i) { // for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
// std::string argName = getArgumentName(op, i); // std::string argName = getArgumentName(op, i);
@ -2079,184 +2146,6 @@ void OpEmitter::genOpAsmInterface() {
} }
} }
//===----------------------------------------------------------------------===//
// OpOperandAdaptor emitter
//===----------------------------------------------------------------------===//
namespace {
// Helper class to emit Op operand adaptors to an output stream. Operand
// adaptors are wrappers around ArrayRef<Value> that provide named operand
// getters identical to those defined in the Op.
class OpOperandAdaptorEmitter {
public:
static void emitDecl(const Operator &op, raw_ostream &os);
static void emitDef(const Operator &op, raw_ostream &os);
private:
explicit OpOperandAdaptorEmitter(const Operator &op);
// Add verification function. This generates a verify method for the adaptor
// which verifies all the op-independent attribute constraints.
void addVerification();
const Operator &op;
Class adaptor;
};
} // end namespace
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
: op(op), adaptor(op.getAdaptorName()) {
adaptor.newField("::mlir::ValueRange", "odsOperands");
adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
adaptor.newField("::mlir::RegionRange", "odsRegions");
const auto *attrSizedOperands =
op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
{
SmallVector<OpMethodParameter, 2> paramList;
paramList.emplace_back("::mlir::ValueRange", "values");
paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
attrSizedOperands ? "" : "nullptr");
paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList));
constructor->addMemberInitializer("odsOperands", "values");
constructor->addMemberInitializer("odsAttrs", "attrs");
constructor->addMemberInitializer("odsRegions", "regions");
}
{
auto *constructor = adaptor.addConstructorAndPrune(
llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
constructor->addMemberInitializer("odsOperands", "op->getOperands()");
constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
constructor->addMemberInitializer("odsRegions", "op->getRegions()");
}
{
auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands");
m->body() << " return odsOperands;";
}
std::string sizeAttrInit =
formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
generateNamedOperandGetters(op, adaptor, sizeAttrInit,
/*rangeType=*/"::mlir::ValueRange",
/*rangeBeginCall=*/"odsOperands.begin()",
/*rangeSizeCall=*/"odsOperands.size()",
/*getOperandCallPattern=*/"odsOperands[{0}]");
FmtContext fctx;
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
auto emitAttr = [&](StringRef name, Attribute attr) {
auto &body = adaptor.addMethodAndPrune(getStorageType(attr), name)->body();
body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
<< "\n " << getStorageType(attr) << " attr = "
<< "odsAttrs.get(\"" << name << "\").";
if (attr.hasDefaultValue() || attr.isOptional())
body << "dyn_cast_or_null<";
else
body << "cast<";
body << getStorageType(attr) << ">();\n";
if (attr.hasDefaultValue()) {
// Use the default value if attribute is not set.
// TODO: this is inefficient, we are recreating the attribute for every
// call. This should be set instead.
std::string defaultValue = std::string(
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
body << " if (!attr)\n attr = " << defaultValue << ";\n";
}
body << " return attr;\n";
};
{
auto *m =
adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes");
m->body() << " return odsAttrs;";
}
for (auto &namedAttr : op.getAttributes()) {
const auto &name = namedAttr.name;
const auto &attr = namedAttr.attr;
if (!attr.isDerivedAttr())
emitAttr(name, attr);
}
unsigned numRegions = op.getNumRegions();
if (numRegions > 0) {
auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions");
m->body() << " return odsRegions;";
}
for (unsigned i = 0; i < numRegions; ++i) {
const auto &region = op.getRegion(i);
if (region.name.empty())
continue;
// Generate the accessors for a variadic region.
if (region.isVariadic()) {
auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", region.name);
m->body() << formatv(" return odsRegions.drop_front({0});", i);
continue;
}
auto *m = adaptor.addMethodAndPrune("::mlir::Region &", region.name);
m->body() << formatv(" return *odsRegions[{0}];", i);
}
// Add verification function.
addVerification();
}
void OpOperandAdaptorEmitter::addVerification() {
auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
"::mlir::Location", "loc");
auto &body = method->body();
const char *checkAttrSizedValueSegmentsCode = R"(
{
auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
if (numElements != {1})
return emitError(loc, "'{0}' attribute for specifying {2} segments "
"must have {1} elements, but got ") << numElements;
}
)";
// Verify a few traits first so that we can use
// getODSOperands()/getODSResults() in the rest of the verifier.
for (auto &trait : op.getTraits()) {
if (auto *t = dyn_cast<tblgen::NativeTrait>(&trait)) {
if (t->getFullyQualifiedTraitName() ==
"::mlir::OpTrait::AttrSizedOperandSegments") {
body << formatv(checkAttrSizedValueSegmentsCode,
"operand_segment_sizes", op.getNumOperands(),
"operand");
} else if (t->getFullyQualifiedTraitName() ==
"::mlir::OpTrait::AttrSizedResultSegments") {
body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
op.getNumResults(), "result");
}
}
}
FmtContext verifyCtx;
populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
"<no results should be generated>", verifyCtx);
genAttributeVerifier(op, "odsAttrs.get",
Twine("emitError(loc, \"'") + op.getOperationName() +
"' op \"",
/*emitVerificationRequiringOp*/ false, verifyCtx, body);
body << " return ::mlir::success();";
}
void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
}
void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
}
// Emits the opcode enum and op classes. // Emits the opcode enum and op classes.
static void emitOpClasses(const RecordKeeper &recordKeeper, static void emitOpClasses(const RecordKeeper &recordKeeper,
const std::vector<Record *> &defs, raw_ostream &os, const std::vector<Record *> &defs, raw_ostream &os,
@ -2286,11 +2175,9 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
NamespaceEmitter emitter(os, op.getCppNamespace()); NamespaceEmitter emitter(os, op.getCppNamespace());
if (emitDecl) { if (emitDecl) {
// os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); // os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
// OpOperandAdaptorEmitter::emitDecl(op, os);
OpEmitter::emitDecl(op, os, staticVerifierEmitter); OpEmitter::emitDecl(op, os, staticVerifierEmitter);
} else { } else {
// os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); // os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
// OpOperandAdaptorEmitter::emitDef(op, os);
OpEmitter::emitDef(op, os, staticVerifierEmitter); OpEmitter::emitDef(op, os, staticVerifierEmitter);
} }
} }

View File

@ -1,16 +0,0 @@
#ifndef BUILDER_OP_
#define BUILDER_OP_
#include "iostream"
namespace builder {
class Op {
class Impl;
std::shared_ptr<Impl> GetImpl() { return _impl; }
private:
std::shared_ptr<Impl> _impl;
}
} // namespace builder
#endif

View File

@ -1,16 +0,0 @@
#ifndef BUILDER_TENSOR_
#define BUILDER_TENSOR_
#include "iostream"
namespace builder{
class Tensor{
}
}
#endif

View File

@ -1,10 +0,0 @@
#ifndef BUILDER_TYPE_
#define BUILDER_TYPE_
#include "iostream"
namespace builder {
class Type {}
} // namespace builder
#endif