modify build cc code,add AddOp in test
This commit is contained in:
parent
4a9b201c4e
commit
975a47f7b2
15
BUILD
15
BUILD
|
@ -179,6 +179,7 @@ cc_library(
|
||||||
"@llvm-project//mlir:MlirTableGenMain",
|
"@llvm-project//mlir:MlirTableGenMain",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//llvm:TableGen",
|
"@llvm-project//llvm:TableGen",
|
||||||
"@llvm-project//llvm:config",
|
"@llvm-project//llvm:config",
|
||||||
|
@ -193,9 +194,19 @@ cc_library(
|
||||||
# "@llvm-project//mlir:AllPassesAndDialects",
|
# "@llvm-project//mlir:AllPassesAndDialects",
|
||||||
# "@llvm-project//mlir:IR",
|
# "@llvm-project//mlir:IR",
|
||||||
# "@llvm-project//mlir:MlirOptLib",
|
# "@llvm-project//mlir:MlirOptLib",
|
||||||
# "@llvm-project//mlir:Pass",
|
|
||||||
# "@llvm-project//mlir:Support",
|
# "@llvm-project//mlir:Support",
|
||||||
# "@llvm-project//mlir:MlirJitRunner",
|
# "@llvm-project//mlir:MlirJitRunner",
|
||||||
|
|
||||||
|
# "@llvm-project//mlir:Analysis",
|
||||||
|
# "@llvm-project//mlir:ControlFlowInterfaces",
|
||||||
|
# "@llvm-project//mlir:InferTypeOpInterface",
|
||||||
|
# "@llvm-project//mlir:MemRefDialect",
|
||||||
|
# "@llvm-project//mlir:Shape",
|
||||||
|
# "@llvm-project//mlir:SideEffects",
|
||||||
|
# "@llvm-project//mlir:StandardOps",
|
||||||
|
# "@llvm-project//mlir:TensorDialect",
|
||||||
|
# "@llvm-project//mlir:TransformUtils",
|
||||||
|
# "@llvm-project//mlir:Transforms",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -215,8 +226,6 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
gentbl_cc_library(
|
gentbl_cc_library(
|
||||||
name = "hlo_ops_base_inc_gen",
|
name = "hlo_ops_base_inc_gen",
|
||||||
strip_include_prefix = "include",
|
strip_include_prefix = "include",
|
||||||
|
|
|
@ -2,12 +2,18 @@
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
builder::Builder builder;
|
builder::Builder builder;
|
||||||
|
|
||||||
builder::Shape shape({100, 100});
|
builder::Shape shape({100, 100});
|
||||||
auto pType = builder::PrimitiveType::F32();
|
auto pType = builder::PrimitiveType::F32();
|
||||||
builder::Type type(shape, pType);
|
builder::Type type(shape, pType);
|
||||||
|
|
||||||
builder::Tensor tensor(std::vector<int>(100));
|
builder::Tensor tensor(std::vector<int>(100));
|
||||||
auto op = builder::mhlo::ConstOp::build(builder, type, tensor);
|
|
||||||
|
auto op1 = builder::mhlo::ConstOp::build(builder, type, tensor);
|
||||||
|
auto op2 = builder::mhlo::ConstOp::build(builder, type, tensor);
|
||||||
|
auto op3 = builder::mhlo::AddOp::build(builder, type, op1, op2);
|
||||||
|
|
||||||
|
builder.DumpModule();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -134,6 +134,9 @@ class Type::Impl {
|
||||||
: shape_(shape), primitiveType_(primitiveType) {}
|
: shape_(shape), primitiveType_(primitiveType) {}
|
||||||
Shape GetShape() { return shape_; }
|
Shape GetShape() { return shape_; }
|
||||||
PrimitiveType GetType() { return primitiveType_; }
|
PrimitiveType GetType() { return primitiveType_; }
|
||||||
|
mlir::Type GetMlirType(mlir::MLIRContext *context) {
|
||||||
|
return mlir::NoneType::get(context);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Shape shape_;
|
Shape shape_;
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
namespace builder {
|
namespace builder {
|
||||||
|
|
||||||
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
|
Builder::Builder() : impl_(std::make_shared<Impl>()) {}
|
||||||
void Builder::DumpModule() {}
|
void Builder::DumpModule() { impl_->DumpModule(); }
|
||||||
|
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
#include "Builder.h"
|
#include "Builder.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
|
// #include "llvm/Support/InitLLVM.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Block.h"
|
#include "mlir/IR/Block.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
@ -12,12 +13,49 @@
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
// #include "mlir/InitAllDialects.h"
|
||||||
|
// #include "mlir/InitAllPasses.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||||
|
|
||||||
namespace builder {
|
namespace builder {
|
||||||
|
|
||||||
class Builder::Impl {
|
class Builder::Impl {
|
||||||
public:
|
public:
|
||||||
Impl() : builder_(&context_) {
|
Impl() : builder_(&context_) {
|
||||||
|
|
||||||
|
// llvm::InitLLVM y(argc, argv);
|
||||||
|
// llvm::InitializeNativeTarget();
|
||||||
|
// llvm::InitializeNativeTargetAsmPrinter();
|
||||||
|
// llvm::InitializeNativeTargetAsmParser();
|
||||||
|
// mlir::initializeLLVMPasses();
|
||||||
|
// Register any command line options.
|
||||||
|
// registerAsmPrinterCLOptions();
|
||||||
|
// registerMLIRContextCLOptions();
|
||||||
|
// registerPassManagerCLOptions();
|
||||||
|
// registerDefaultTimingManagerCLOptions();
|
||||||
|
// DebugCounter::registerCLOptions();
|
||||||
|
|
||||||
|
// mlir::registerAllPasses();
|
||||||
|
mlir::mhlo::registerAllMhloPasses();
|
||||||
|
mlir::lmhlo::registerAllLmhloPasses();
|
||||||
|
mlir::disc_ral::registerAllDiscRalPasses();
|
||||||
|
|
||||||
|
mlir::DialectRegistry registry;
|
||||||
|
// mlir::registerAllToLLVMIRTranslations(registry);
|
||||||
|
// mlir::registerAllDialects(registry);
|
||||||
|
registry.insert<mlir::mhlo::MhloDialect>();
|
||||||
|
// registry.insert<mlir::chlo::HloClientDialect>();
|
||||||
|
// registry.insert<mlir::lmhlo::LmhloDialect>();
|
||||||
|
// registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
|
||||||
|
// registry.insert<mlir::disc_ral::RalDialect>();
|
||||||
|
context_.appendDialectRegistry(registry);
|
||||||
|
context_.loadAllAvailableDialects();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
|
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_));
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Type, 4> arg_types;
|
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||||
|
@ -34,6 +72,7 @@ class Builder::Impl {
|
||||||
mlir::Location GetLoc() { return builder_.getUnknownLoc(); }
|
mlir::Location GetLoc() { return builder_.getUnknownLoc(); }
|
||||||
mlir::OpBuilder GetBuilder() { return builder_; }
|
mlir::OpBuilder GetBuilder() { return builder_; }
|
||||||
mlir::MLIRContext* GetContext() { return &context_; }
|
mlir::MLIRContext* GetContext() { return &context_; }
|
||||||
|
void DumpModule() { module_.dump(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mlir::MLIRContext context_;
|
mlir::MLIRContext context_;
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
#include "Op.h"
|
||||||
|
|
||||||
|
#include "OpImpl.h"
|
||||||
|
|
||||||
|
namespace builder {
|
||||||
|
|
||||||
|
Op::Op() : impl_(std::make_shared<Impl>()) {}
|
||||||
|
|
||||||
|
} // namespace builder
|
|
@ -7,6 +7,7 @@
|
||||||
namespace builder {
|
namespace builder {
|
||||||
class Op {
|
class Op {
|
||||||
public:
|
public:
|
||||||
|
Op();
|
||||||
class Impl;
|
class Impl;
|
||||||
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
std::shared_ptr<Impl> GetImpl() { return impl_; }
|
||||||
|
|
||||||
|
|
|
@ -1447,19 +1447,28 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
||||||
auto op = GetOp();
|
auto op = GetOp();
|
||||||
auto operands = op.getOperands();
|
auto operands = op.getOperands();
|
||||||
auto attrs = op.getAttributes();
|
auto attrs = op.getAttributes();
|
||||||
|
auto numResults = op.getNumResults();
|
||||||
|
auto numOperands = op.getNumOperands();
|
||||||
SmallVector<std::string, 4> newAttrs;
|
SmallVector<std::string, 4> newAttrs;
|
||||||
|
|
||||||
// if (attrType == "::mlir::DenseIntElementsAttr") {
|
// if (attrType == "::mlir::DenseIntElementsAttr") {
|
||||||
// body << " // BBBBBBBB getStorageType:" << a.attr.getStorageType().str()
|
// body << " // BBBBBBBB getStorageType:" <<
|
||||||
|
// a.attr.getStorageType().str()
|
||||||
// << "\n";
|
// << "\n";
|
||||||
// body << " // BBBBBBBB getReturnType:" << a.attr.getReturnType().str()
|
// body << " // BBBBBBBB getReturnType:" << a.attr.getReturnType().str()
|
||||||
// << "\n";
|
// << "\n";
|
||||||
// }
|
// }
|
||||||
|
|
||||||
body << " auto b = builder.GetImpl();\n";
|
// for(const auto& a : op.getArgs()){
|
||||||
body << " auto loc = b->GetLoc();\n";
|
// body << " // BBBBBBBB argument.is:"
|
||||||
body << " auto opBuilder = b->GetBuilder();\n";
|
// << a.is<tblgen::NamedTypeConstraint *>() << "\n";
|
||||||
body << " auto ctx = b->GetContext();\n";
|
// }
|
||||||
|
// // if (argument.is<tblgen::NamedTypeConstraint *>())
|
||||||
|
|
||||||
|
body << " auto bPtr = builder.GetImpl();\n";
|
||||||
|
body << " auto loc = bPtr->GetLoc();\n";
|
||||||
|
body << " auto opBuilder = bPtr->GetBuilder();\n";
|
||||||
|
body << " auto ctx = bPtr->GetContext();\n";
|
||||||
for (auto a : attrs) {
|
for (auto a : attrs) {
|
||||||
std::string attrType = a.attr.getStorageType().str();
|
std::string attrType = a.attr.getStorageType().str();
|
||||||
auto typePair = typeMapMLIR.find(attrType);
|
auto typePair = typeMapMLIR.find(attrType);
|
||||||
|
@ -1469,11 +1478,24 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
||||||
newAttrs.emplace_back(mlirName);
|
newAttrs.emplace_back(mlirName);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int index = 0;
|
for (int i = 0; i < numResults; ++i) {
|
||||||
for (auto v : operands) {
|
const auto &result = op.getResult(i);
|
||||||
|
std::string resultName = std::string(result.name);
|
||||||
|
if (resultName.empty())
|
||||||
|
resultName = std::string(formatv("resultType{0}", i));
|
||||||
|
bool isVec = result.isVariadic();
|
||||||
|
if (isVec) {
|
||||||
|
body << " std::vector<mlir::Type> " << resultName << "_v;\n";
|
||||||
|
body << " for(auto r : " << resultName << "){\n " << resultName
|
||||||
|
<< "_v.push_back(r.GetImpl()->GetMlirType(ctx));\n }"
|
||||||
|
<< "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < numOperands; i++) {
|
||||||
|
const auto &v = op.getOperand(i);
|
||||||
std::string name =
|
std::string name =
|
||||||
v.name.empty() ? "odsArg" + std::to_string(index) : v.name.str();
|
v.name.empty() ? "odsArg" + std::to_string(i) : v.name.str();
|
||||||
index++;
|
|
||||||
if (v.isVariadic()) {
|
if (v.isVariadic()) {
|
||||||
body << " std::vector<mlir::Value> " << name << "_v;\n";
|
body << " std::vector<mlir::Value> " << name << "_v;\n";
|
||||||
body << " for(auto v : " << name << "){\n " << name
|
body << " for(auto v : " << name << "){\n " << name
|
||||||
|
@ -1488,19 +1510,44 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
||||||
<< "::" << op.getCppClassName() << ">(\n";
|
<< "::" << op.getCppClassName() << ">(\n";
|
||||||
body << " loc";
|
body << " loc";
|
||||||
|
|
||||||
index = 0;
|
for (int i = 0; i < numResults; ++i) {
|
||||||
std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) {
|
const auto &result = op.getResult(i);
|
||||||
std::string name =
|
std::string resultName = std::string(result.name);
|
||||||
v.name.empty() ? "odsArg_" + std::to_string(index) : v.name.str();
|
if (resultName.empty())
|
||||||
index++;
|
resultName = std::string(formatv("resultType{0}", i));
|
||||||
|
bool isVec = result.isVariadic();
|
||||||
|
if (isVec) {
|
||||||
|
body << ",\n mlir::TypeRange(" << resultName << "_v)";
|
||||||
|
} else {
|
||||||
|
body << ",\n " << resultName << ".GetImpl()->GetMlirType(ctx)";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int operandIndex = 0;
|
||||||
|
int attributeIndex = 0;
|
||||||
|
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||||
|
auto argument = op.getArg(i);
|
||||||
|
// if true Operands else Attribute
|
||||||
|
if (argument.is<tblgen::NamedTypeConstraint *>()) {
|
||||||
|
const auto &v = op.getOperand(operandIndex);
|
||||||
|
std::string name = v.name.empty()
|
||||||
|
? "odsArg_" + std::to_string(operandIndex)
|
||||||
|
: v.name.str();
|
||||||
if (v.isVariadic()) {
|
if (v.isVariadic()) {
|
||||||
body << ",\n " << name << "_v";
|
body << ",\n mlir::ValueRange(" << name << "_v)";
|
||||||
} else {
|
} else {
|
||||||
body << ",\n " << name << ".GetImpl()->GetResult()";
|
body << ",\n " << name << ".GetImpl()->GetResult()";
|
||||||
}
|
}
|
||||||
});
|
operandIndex++;
|
||||||
std::for_each(newAttrs.begin(), newAttrs.end(),
|
} else {
|
||||||
[&](std::string &n) { body << ",\n " << n; });
|
body << ",\n " << newAttrs[attributeIndex];
|
||||||
|
attributeIndex++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const NamedRegion ®ion : op.getRegions())
|
||||||
|
if (region.isVariadic())
|
||||||
|
body << ",\n " << llvm::formatv("{0}Count", region.name).str();
|
||||||
|
|
||||||
body << "\n );\n";
|
body << "\n );\n";
|
||||||
body << " builder::mhlo::" << op.getCppClassName() << " builderOp;\n";
|
body << " builder::mhlo::" << op.getCppClassName() << " builderOp;\n";
|
||||||
body << " auto opImpl = builderOp.GetImpl();\n";
|
body << " auto opImpl = builderOp.GetImpl();\n";
|
||||||
|
@ -1520,11 +1567,13 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
||||||
// body << " " << builderOpState
|
// body << " " << builderOpState
|
||||||
// << ".addAttribute(\"operand_segment_sizes\", "
|
// << ".addAttribute(\"operand_segment_sizes\", "
|
||||||
// "odsBuilder.getI32VectorAttr({";
|
// "odsBuilder.getI32VectorAttr({";
|
||||||
// interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
|
// interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i)
|
||||||
|
// {
|
||||||
// if (op.getOperand(i).isOptional())
|
// if (op.getOperand(i).isOptional())
|
||||||
// body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
|
// body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
|
||||||
// else if (op.getOperand(i).isVariadic())
|
// else if (op.getOperand(i).isVariadic())
|
||||||
// body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
|
// body << "static_cast<int32_t>(" << getArgumentName(op, i) <<
|
||||||
|
// ".size())";
|
||||||
// else
|
// else
|
||||||
// body << "1";
|
// body << "1";
|
||||||
// });
|
// });
|
||||||
|
@ -1553,14 +1602,17 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
||||||
// // here given we use function arguments. So we need to strip the
|
// // here given we use function arguments. So we need to strip the
|
||||||
// // wrapping quotes.
|
// // wrapping quotes.
|
||||||
// if (StringRef(builderTemplate).contains("\"$0\""))
|
// if (StringRef(builderTemplate).contains("\"$0\""))
|
||||||
// builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
|
// builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"",
|
||||||
|
// "$0");
|
||||||
|
|
||||||
// std::string value =
|
// std::string value =
|
||||||
// std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
|
// std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
|
||||||
// body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
|
// body << formatv(" {0}.addAttribute(\"{1}\", {2});\n",
|
||||||
|
// builderOpState,
|
||||||
// namedAttr.name, value);
|
// namedAttr.name, value);
|
||||||
// } else {
|
// } else {
|
||||||
// body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
|
// body << formatv(" {0}.addAttribute(\"{1}\", {1});\n",
|
||||||
|
// builderOpState,
|
||||||
// namedAttr.name);
|
// namedAttr.name);
|
||||||
// }
|
// }
|
||||||
// if (emitNotNullCheck) {
|
// if (emitNotNullCheck) {
|
||||||
|
|
|
@ -314,7 +314,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const {
|
||||||
os << "class " << className;
|
os << "class " << className;
|
||||||
for (const auto &trait : traitsVec)
|
for (const auto &trait : traitsVec)
|
||||||
os << ", " << trait;
|
os << ", " << trait;
|
||||||
os << " : public Op {\npublic:\n";
|
os << " : public builder::Op {\npublic:\n";
|
||||||
// << " using Op::Op;\n"
|
// << " using Op::Op;\n"
|
||||||
// << " using Op::print;\n"
|
// << " using Op::print;\n"
|
||||||
// << " using Adaptor = " << className << "Adaptor;\n";
|
// << " using Adaptor = " << className << "Adaptor;\n";
|
||||||
|
|
Loading…
Reference in New Issue