modify build cc code,add AddOp in test
This commit is contained in:
parent
4a9b201c4e
commit
975a47f7b2
15
BUILD
15
BUILD
|
@ -179,6 +179,7 @@ cc_library(
|
|||
"@llvm-project//mlir:MlirTableGenMain",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:TableGen",
|
||||
"@llvm-project//llvm:config",
|
||||
|
@ -193,9 +194,19 @@ cc_library(
|
|||
# "@llvm-project//mlir:AllPassesAndDialects",
|
||||
# "@llvm-project//mlir:IR",
|
||||
# "@llvm-project//mlir:MlirOptLib",
|
||||
# "@llvm-project//mlir:Pass",
|
||||
# "@llvm-project//mlir:Support",
|
||||
# "@llvm-project//mlir:MlirJitRunner",
|
||||
|
||||
# "@llvm-project//mlir:Analysis",
|
||||
# "@llvm-project//mlir:ControlFlowInterfaces",
|
||||
# "@llvm-project//mlir:InferTypeOpInterface",
|
||||
# "@llvm-project//mlir:MemRefDialect",
|
||||
# "@llvm-project//mlir:Shape",
|
||||
# "@llvm-project//mlir:SideEffects",
|
||||
# "@llvm-project//mlir:StandardOps",
|
||||
# "@llvm-project//mlir:TensorDialect",
|
||||
# "@llvm-project//mlir:TransformUtils",
|
||||
# "@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -215,8 +226,6 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
|
||||
|
||||
gentbl_cc_library(
|
||||
name = "hlo_ops_base_inc_gen",
|
||||
strip_include_prefix = "include",
|
||||
|
|
|
@ -2,12 +2,18 @@
|
|||
|
||||
int main() {
|
||||
builder::Builder builder;
|
||||
|
||||
builder::Shape shape({100, 100});
|
||||
auto pType = builder::PrimitiveType::F32();
|
||||
builder::Type type(shape, pType);
|
||||
|
||||
builder::Tensor tensor(std::vector<int>(100));
|
||||
auto op = builder::mhlo::ConstOp::build(builder, type, tensor);
|
||||
|
||||
auto op1 = builder::mhlo::ConstOp::build(builder, type, tensor);
|
||||
auto op2 = builder::mhlo::ConstOp::build(builder, type, tensor);
|
||||
auto op3 = builder::mhlo::AddOp::build(builder, type, op1, op2);
|
||||
|
||||
builder.DumpModule();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -134,6 +134,9 @@ class Type::Impl {
|
|||
: shape_(shape), primitiveType_(primitiveType) {}
|
||||
Shape GetShape() { return shape_; }
|
||||
PrimitiveType GetType() { return primitiveType_; }
|
||||
mlir::Type GetMlirType(mlir::MLIRContext *context) {
|
||||
return mlir::NoneType::get(context);
|
||||
}
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
namespace builder {
|
||||
|
||||
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
|
||||
void Builder::DumpModule() {}
|
||||
void Builder::DumpModule() { impl_->DumpModule(); }
|
||||
|
||||
} // namespace builder
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "Builder.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
// #include "llvm/Support/InitLLVM.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -12,12 +13,49 @@
|
|||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
// #include "mlir/InitAllDialects.h"
|
||||
// #include "mlir/InitAllPasses.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
class Builder::Impl {
|
||||
public:
|
||||
Impl() : builder_(&context_) {
|
||||
|
||||
// llvm::InitLLVM y(argc, argv);
|
||||
// llvm::InitializeNativeTarget();
|
||||
// llvm::InitializeNativeTargetAsmPrinter();
|
||||
// llvm::InitializeNativeTargetAsmParser();
|
||||
// mlir::initializeLLVMPasses();
|
||||
// Register any command line options.
|
||||
// registerAsmPrinterCLOptions();
|
||||
// registerMLIRContextCLOptions();
|
||||
// registerPassManagerCLOptions();
|
||||
// registerDefaultTimingManagerCLOptions();
|
||||
// DebugCounter::registerCLOptions();
|
||||
|
||||
// mlir::registerAllPasses();
|
||||
mlir::mhlo::registerAllMhloPasses();
|
||||
mlir::lmhlo::registerAllLmhloPasses();
|
||||
mlir::disc_ral::registerAllDiscRalPasses();
|
||||
|
||||
mlir::DialectRegistry registry;
|
||||
// mlir::registerAllToLLVMIRTranslations(registry);
|
||||
// mlir::registerAllDialects(registry);
|
||||
registry.insert<mlir::mhlo::MhloDialect>();
|
||||
// registry.insert<mlir::chlo::HloClientDialect>();
|
||||
// registry.insert<mlir::lmhlo::LmhloDialect>();
|
||||
// registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
|
||||
// registry.insert<mlir::disc_ral::RalDialect>();
|
||||
context_.appendDialectRegistry(registry);
|
||||
context_.loadAllAvailableDialects();
|
||||
|
||||
|
||||
|
||||
|
||||
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
|
||||
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||
|
@ -34,6 +72,7 @@ class Builder::Impl {
|
|||
mlir::Location GetLoc() { return builder_.getUnknownLoc(); }
|
||||
mlir::OpBuilder GetBuilder() { return builder_; }
|
||||
mlir::MLIRContext* GetContext() { return &context_; }
|
||||
void DumpModule() { module_.dump(); }
|
||||
|
||||
private:
|
||||
mlir::MLIRContext context_;
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
#include "Op.h"
|
||||
|
||||
#include "OpImpl.h"
|
||||
|
||||
namespace builder {
|
||||
|
||||
Op::Op() : impl_(std::make_shared<Impl>()) {}
|
||||
|
||||
} // namespace builder
|
|
@ -7,6 +7,7 @@
|
|||
namespace builder {
|
||||
class Op {
|
||||
public:
|
||||
Op();
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||
|
||||
|
|
|
@ -1447,19 +1447,28 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
|||
auto op = GetOp();
|
||||
auto operands = op.getOperands();
|
||||
auto attrs = op.getAttributes();
|
||||
auto numResults = op.getNumResults();
|
||||
auto numOperands = op.getNumOperands();
|
||||
SmallVector<std::string, 4> newAttrs;
|
||||
|
||||
// if (attrType == "::mlir::DenseIntElementsAttr") {
|
||||
// body << " // BBBBBBBB getStorageType:" << a.attr.getStorageType().str()
|
||||
// body << " // BBBBBBBB getStorageType:" <<
|
||||
// a.attr.getStorageType().str()
|
||||
// << "\n";
|
||||
// body << " // BBBBBBBB getReturnType:" << a.attr.getReturnType().str()
|
||||
// << "\n";
|
||||
// }
|
||||
|
||||
body << " auto b = builder.GetImpl();\n";
|
||||
body << " auto loc = b->GetLoc();\n";
|
||||
body << " auto opBuilder = b->GetBuilder();\n";
|
||||
body << " auto ctx = b->GetContext();\n";
|
||||
// for(const auto& a : op.getArgs()){
|
||||
// body << " // BBBBBBBB argument.is:"
|
||||
// << a.is<tblgen::NamedTypeConstraint *>() << "\n";
|
||||
// }
|
||||
// // if (argument.is<tblgen::NamedTypeConstraint *>())
|
||||
|
||||
body << " auto bPtr = builder.GetImpl();\n";
|
||||
body << " auto loc = bPtr->GetLoc();\n";
|
||||
body << " auto opBuilder = bPtr->GetBuilder();\n";
|
||||
body << " auto ctx = bPtr->GetContext();\n";
|
||||
for (auto a : attrs) {
|
||||
std::string attrType = a.attr.getStorageType().str();
|
||||
auto typePair = typeMapMLIR.find(attrType);
|
||||
|
@ -1469,11 +1478,24 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
|||
newAttrs.emplace_back(mlirName);
|
||||
}
|
||||
}
|
||||
int index = 0;
|
||||
for (auto v : operands) {
|
||||
for (int i = 0; i < numResults; ++i) {
|
||||
const auto &result = op.getResult(i);
|
||||
std::string resultName = std::string(result.name);
|
||||
if (resultName.empty())
|
||||
resultName = std::string(formatv("resultType{0}", i));
|
||||
bool isVec = result.isVariadic();
|
||||
if (isVec) {
|
||||
body << " std::vector<mlir::Type> " << resultName << "_v;\n";
|
||||
body << " for(auto r : " << resultName << "){\n " << resultName
|
||||
<< "_v.push_back(r.GetImpl()->GetMlirType(ctx));\n }"
|
||||
<< "\n";
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < numOperands; i++) {
|
||||
const auto &v = op.getOperand(i);
|
||||
std::string name =
|
||||
v.name.empty() ? "odsArg" + std::to_string(index) : v.name.str();
|
||||
index++;
|
||||
v.name.empty() ? "odsArg" + std::to_string(i) : v.name.str();
|
||||
if (v.isVariadic()) {
|
||||
body << " std::vector<mlir::Value> " << name << "_v;\n";
|
||||
body << " for(auto v : " << name << "){\n " << name
|
||||
|
@ -1488,19 +1510,44 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
|||
<< "::" << op.getCppClassName() << ">(\n";
|
||||
body << " loc";
|
||||
|
||||
index = 0;
|
||||
std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) {
|
||||
std::string name =
|
||||
v.name.empty() ? "odsArg_" + std::to_string(index) : v.name.str();
|
||||
index++;
|
||||
if (v.isVariadic()) {
|
||||
body << ",\n " << name << "_v";
|
||||
for (int i = 0; i < numResults; ++i) {
|
||||
const auto &result = op.getResult(i);
|
||||
std::string resultName = std::string(result.name);
|
||||
if (resultName.empty())
|
||||
resultName = std::string(formatv("resultType{0}", i));
|
||||
bool isVec = result.isVariadic();
|
||||
if (isVec) {
|
||||
body << ",\n mlir::TypeRange(" << resultName << "_v)";
|
||||
} else {
|
||||
body << ",\n " << name << ".GetImpl()->GetResult()";
|
||||
body << ",\n " << resultName << ".GetImpl()->GetMlirType(ctx)";
|
||||
}
|
||||
});
|
||||
std::for_each(newAttrs.begin(), newAttrs.end(),
|
||||
[&](std::string &n) { body << ",\n " << n; });
|
||||
}
|
||||
|
||||
int operandIndex = 0;
|
||||
int attributeIndex = 0;
|
||||
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||
auto argument = op.getArg(i);
|
||||
// if true Operands else Attribute
|
||||
if (argument.is<tblgen::NamedTypeConstraint *>()) {
|
||||
const auto &v = op.getOperand(operandIndex);
|
||||
std::string name = v.name.empty()
|
||||
? "odsArg_" + std::to_string(operandIndex)
|
||||
: v.name.str();
|
||||
if (v.isVariadic()) {
|
||||
body << ",\n mlir::ValueRange(" << name << "_v)";
|
||||
} else {
|
||||
body << ",\n " << name << ".GetImpl()->GetResult()";
|
||||
}
|
||||
operandIndex++;
|
||||
} else {
|
||||
body << ",\n " << newAttrs[attributeIndex];
|
||||
attributeIndex++;
|
||||
}
|
||||
}
|
||||
for (const NamedRegion ®ion : op.getRegions())
|
||||
if (region.isVariadic())
|
||||
body << ",\n " << llvm::formatv("{0}Count", region.name).str();
|
||||
|
||||
body << "\n );\n";
|
||||
body << " builder::mhlo::" << op.getCppClassName() << " builderOp;\n";
|
||||
body << " auto opImpl = builderOp.GetImpl();\n";
|
||||
|
@ -1520,11 +1567,13 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
|||
// body << " " << builderOpState
|
||||
// << ".addAttribute(\"operand_segment_sizes\", "
|
||||
// "odsBuilder.getI32VectorAttr({";
|
||||
// interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
|
||||
// interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i)
|
||||
// {
|
||||
// if (op.getOperand(i).isOptional())
|
||||
// body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
|
||||
// else if (op.getOperand(i).isVariadic())
|
||||
// body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
|
||||
// body << "static_cast<int32_t>(" << getArgumentName(op, i) <<
|
||||
// ".size())";
|
||||
// else
|
||||
// body << "1";
|
||||
// });
|
||||
|
@ -1553,14 +1602,17 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
|||
// // here given we use function arguments. So we need to strip the
|
||||
// // wrapping quotes.
|
||||
// if (StringRef(builderTemplate).contains("\"$0\""))
|
||||
// builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
|
||||
// builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"",
|
||||
// "$0");
|
||||
|
||||
// std::string value =
|
||||
// std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
|
||||
// body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
|
||||
// body << formatv(" {0}.addAttribute(\"{1}\", {2});\n",
|
||||
// builderOpState,
|
||||
// namedAttr.name, value);
|
||||
// } else {
|
||||
// body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
|
||||
// body << formatv(" {0}.addAttribute(\"{1}\", {1});\n",
|
||||
// builderOpState,
|
||||
// namedAttr.name);
|
||||
// }
|
||||
// if (emitNotNullCheck) {
|
||||
|
|
|
@ -314,7 +314,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const {
|
|||
os << "class " << className;
|
||||
for (const auto &trait : traitsVec)
|
||||
os << ", " << trait;
|
||||
os << " : public Op {\npublic:\n";
|
||||
os << " : public builder::Op {\npublic:\n";
|
||||
// << " using Op::Op;\n"
|
||||
// << " using Op::print;\n"
|
||||
// << " using Adaptor = " << className << "Adaptor;\n";
|
||||
|
|
Loading…
Reference in New Issue