modify build cc code,add AddOp in test

This commit is contained in:
colin.liang 2021-08-16 15:35:37 +08:00
parent 4a9b201c4e
commit 975a47f7b2
9 changed files with 150 additions and 31 deletions

15
BUILD
View File

@ -179,6 +179,7 @@ cc_library(
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
@ -193,9 +194,19 @@ cc_library(
# "@llvm-project//mlir:AllPassesAndDialects",
# "@llvm-project//mlir:IR",
# "@llvm-project//mlir:MlirOptLib",
# "@llvm-project//mlir:Pass",
# "@llvm-project//mlir:Support",
# "@llvm-project//mlir:MlirJitRunner",
# "@llvm-project//mlir:Analysis",
# "@llvm-project//mlir:ControlFlowInterfaces",
# "@llvm-project//mlir:InferTypeOpInterface",
# "@llvm-project//mlir:MemRefDialect",
# "@llvm-project//mlir:Shape",
# "@llvm-project//mlir:SideEffects",
# "@llvm-project//mlir:StandardOps",
# "@llvm-project//mlir:TensorDialect",
# "@llvm-project//mlir:TransformUtils",
# "@llvm-project//mlir:Transforms",
],
)
@ -215,8 +226,6 @@ cc_test(
],
)
gentbl_cc_library(
name = "hlo_ops_base_inc_gen",
strip_include_prefix = "include",

View File

@ -2,12 +2,18 @@
int main() {
builder::Builder builder;
builder::Shape shape({100, 100});
auto pType = builder::PrimitiveType::F32();
builder::Type type(shape, pType);
builder::Tensor tensor(std::vector<int>(100));
auto op = builder::mhlo::ConstOp::build(builder, type, tensor);
auto op1 = builder::mhlo::ConstOp::build(builder, type, tensor);
auto op2 = builder::mhlo::ConstOp::build(builder, type, tensor);
auto op3 = builder::mhlo::AddOp::build(builder, type, op1, op2);
builder.DumpModule();
return 0;
}

View File

@ -134,6 +134,9 @@ class Type::Impl {
: shape_(shape), primitiveType_(primitiveType) {}
Shape GetShape() { return shape_; }
PrimitiveType GetType() { return primitiveType_; }
mlir::Type GetMlirType(mlir::MLIRContext *context) {
return mlir::NoneType::get(context);
}
private:
Shape shape_;

View File

@ -19,7 +19,7 @@
namespace builder {
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
void Builder::DumpModule() {}
void Builder::DumpModule() { impl_->DumpModule(); }
} // namespace builder

View File

@ -3,6 +3,7 @@
#include "Builder.h"
#include "llvm/Support/Casting.h"
// #include "llvm/Support/InitLLVM.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
@ -12,12 +13,49 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
// #include "mlir/InitAllDialects.h"
// #include "mlir/InitAllPasses.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
namespace builder {
class Builder::Impl {
public:
Impl() : builder_(&context_) {
// llvm::InitLLVM y(argc, argv);
// llvm::InitializeNativeTarget();
// llvm::InitializeNativeTargetAsmPrinter();
// llvm::InitializeNativeTargetAsmParser();
// mlir::initializeLLVMPasses();
// Register any command line options.
// registerAsmPrinterCLOptions();
// registerMLIRContextCLOptions();
// registerPassManagerCLOptions();
// registerDefaultTimingManagerCLOptions();
// DebugCounter::registerCLOptions();
// mlir::registerAllPasses();
mlir::mhlo::registerAllMhloPasses();
mlir::lmhlo::registerAllLmhloPasses();
mlir::disc_ral::registerAllDiscRalPasses();
mlir::DialectRegistry registry;
// mlir::registerAllToLLVMIRTranslations(registry);
// mlir::registerAllDialects(registry);
registry.insert<mlir::mhlo::MhloDialect>();
// registry.insert<mlir::chlo::HloClientDialect>();
// registry.insert<mlir::lmhlo::LmhloDialect>();
// registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
// registry.insert<mlir::disc_ral::RalDialect>();
context_.appendDialectRegistry(registry);
context_.loadAllAvailableDialects();
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
llvm::SmallVector<mlir::Type, 4> arg_types;
@ -34,6 +72,7 @@ class Builder::Impl {
mlir::Location GetLoc() { return builder_.getUnknownLoc(); }
mlir::OpBuilder GetBuilder() { return builder_; }
mlir::MLIRContext* GetContext() { return &context_; }
void DumpModule() { module_.dump(); }
private:
mlir::MLIRContext context_;

View File

@ -0,0 +1,9 @@
#include "Op.h"
#include "OpImpl.h"
namespace builder {
Op::Op() : impl_(std::make_shared<Impl>()) {}
} // namespace builder

View File

@ -7,6 +7,7 @@
namespace builder {
class Op {
public:
Op();
class Impl;
std::shared_ptr<Impl> GetImpl() { return impl_; }

View File

@ -1447,19 +1447,28 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
auto op = GetOp();
auto operands = op.getOperands();
auto attrs = op.getAttributes();
auto numResults = op.getNumResults();
auto numOperands = op.getNumOperands();
SmallVector<std::string, 4> newAttrs;
// if (attrType == "::mlir::DenseIntElementsAttr") {
// body << " // BBBBBBBB getStorageType:" << a.attr.getStorageType().str()
// body << " // BBBBBBBB getStorageType:" <<
// a.attr.getStorageType().str()
// << "\n";
// body << " // BBBBBBBB getReturnType:" << a.attr.getReturnType().str()
// << "\n";
// }
body << " auto b = builder.GetImpl();\n";
body << " auto loc = b->GetLoc();\n";
body << " auto opBuilder = b->GetBuilder();\n";
body << " auto ctx = b->GetContext();\n";
// for(const auto& a : op.getArgs()){
// body << " // BBBBBBBB argument.is:"
// << a.is<tblgen::NamedTypeConstraint *>() << "\n";
// }
// // if (argument.is<tblgen::NamedTypeConstraint *>())
body << " auto bPtr = builder.GetImpl();\n";
body << " auto loc = bPtr->GetLoc();\n";
body << " auto opBuilder = bPtr->GetBuilder();\n";
body << " auto ctx = bPtr->GetContext();\n";
for (auto a : attrs) {
std::string attrType = a.attr.getStorageType().str();
auto typePair = typeMapMLIR.find(attrType);
@ -1469,11 +1478,24 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
newAttrs.emplace_back(mlirName);
}
}
int index = 0;
for (auto v : operands) {
for (int i = 0; i < numResults; ++i) {
const auto &result = op.getResult(i);
std::string resultName = std::string(result.name);
if (resultName.empty())
resultName = std::string(formatv("resultType{0}", i));
bool isVec = result.isVariadic();
if (isVec) {
body << " std::vector<mlir::Type> " << resultName << "_v;\n";
body << " for(auto r : " << resultName << "){\n " << resultName
<< "_v.push_back(r.GetImpl()->GetMlirType(ctx));\n }"
<< "\n";
}
}
for (int i = 0; i < numOperands; i++) {
const auto &v = op.getOperand(i);
std::string name =
v.name.empty() ? "odsArg" + std::to_string(index) : v.name.str();
index++;
v.name.empty() ? "odsArg" + std::to_string(i) : v.name.str();
if (v.isVariadic()) {
body << " std::vector<mlir::Value> " << name << "_v;\n";
body << " for(auto v : " << name << "){\n " << name
@ -1488,19 +1510,44 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
<< "::" << op.getCppClassName() << ">(\n";
body << " loc";
index = 0;
std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) {
std::string name =
v.name.empty() ? "odsArg_" + std::to_string(index) : v.name.str();
index++;
for (int i = 0; i < numResults; ++i) {
const auto &result = op.getResult(i);
std::string resultName = std::string(result.name);
if (resultName.empty())
resultName = std::string(formatv("resultType{0}", i));
bool isVec = result.isVariadic();
if (isVec) {
body << ",\n mlir::TypeRange(" << resultName << "_v)";
} else {
body << ",\n " << resultName << ".GetImpl()->GetMlirType(ctx)";
}
}
int operandIndex = 0;
int attributeIndex = 0;
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
// if true Operands else Attribute
if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &v = op.getOperand(operandIndex);
std::string name = v.name.empty()
? "odsArg_" + std::to_string(operandIndex)
: v.name.str();
if (v.isVariadic()) {
body << ",\n " << name << "_v";
body << ",\n mlir::ValueRange(" << name << "_v)";
} else {
body << ",\n " << name << ".GetImpl()->GetResult()";
}
});
std::for_each(newAttrs.begin(), newAttrs.end(),
[&](std::string &n) { body << ",\n " << n; });
operandIndex++;
} else {
body << ",\n " << newAttrs[attributeIndex];
attributeIndex++;
}
}
for (const NamedRegion &region : op.getRegions())
if (region.isVariadic())
body << ",\n " << llvm::formatv("{0}Count", region.name).str();
body << "\n );\n";
body << " builder::mhlo::" << op.getCppClassName() << " builderOp;\n";
body << " auto opImpl = builderOp.GetImpl();\n";
@ -1520,11 +1567,13 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
// body << " " << builderOpState
// << ".addAttribute(\"operand_segment_sizes\", "
// "odsBuilder.getI32VectorAttr({";
// interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
// interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i)
// {
// if (op.getOperand(i).isOptional())
// body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
// else if (op.getOperand(i).isVariadic())
// body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
// body << "static_cast<int32_t>(" << getArgumentName(op, i) <<
// ".size())";
// else
// body << "1";
// });
@ -1553,14 +1602,17 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
// // here given we use function arguments. So we need to strip the
// // wrapping quotes.
// if (StringRef(builderTemplate).contains("\"$0\""))
// builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
// builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"",
// "$0");
// std::string value =
// std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
// body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
// body << formatv(" {0}.addAttribute(\"{1}\", {2});\n",
// builderOpState,
// namedAttr.name, value);
// } else {
// body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
// body << formatv(" {0}.addAttribute(\"{1}\", {1});\n",
// builderOpState,
// namedAttr.name);
// }
// if (emitNotNullCheck) {

View File

@ -314,7 +314,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const {
os << "class " << className;
for (const auto &trait : traitsVec)
os << ", " << trait;
os << " : public Op {\npublic:\n";
os << " : public builder::Op {\npublic:\n";
// << " using Op::Op;\n"
// << " using Op::print;\n"
// << " using Adaptor = " << className << "Adaptor;\n";