diff --git a/BUILD b/BUILD index 1ab7320..4459d8b 100644 --- a/BUILD +++ b/BUILD @@ -179,6 +179,7 @@ cc_library( "@llvm-project//mlir:MlirTableGenMain", "@llvm-project//mlir:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@llvm-project//llvm:Support", "@llvm-project//llvm:TableGen", "@llvm-project//llvm:config", @@ -193,9 +194,19 @@ cc_library( # "@llvm-project//mlir:AllPassesAndDialects", # "@llvm-project//mlir:IR", # "@llvm-project//mlir:MlirOptLib", - # "@llvm-project//mlir:Pass", # "@llvm-project//mlir:Support", # "@llvm-project//mlir:MlirJitRunner", + + # "@llvm-project//mlir:Analysis", + # "@llvm-project//mlir:ControlFlowInterfaces", + # "@llvm-project//mlir:InferTypeOpInterface", + # "@llvm-project//mlir:MemRefDialect", + # "@llvm-project//mlir:Shape", + # "@llvm-project//mlir:SideEffects", + # "@llvm-project//mlir:StandardOps", + # "@llvm-project//mlir:TensorDialect", + # "@llvm-project//mlir:TransformUtils", + # "@llvm-project//mlir:Transforms", ], ) @@ -215,8 +226,6 @@ cc_test( ], ) - - gentbl_cc_library( name = "hlo_ops_base_inc_gen", strip_include_prefix = "include", diff --git a/tests/mlir-tblgen-builder/test_basic.cpp b/tests/mlir-tblgen-builder/test_basic.cpp index f87ac3d..2604215 100644 --- a/tests/mlir-tblgen-builder/test_basic.cpp +++ b/tests/mlir-tblgen-builder/test_basic.cpp @@ -2,12 +2,18 @@ int main() { builder::Builder builder; + builder::Shape shape({100, 100}); auto pType = builder::PrimitiveType::F32(); builder::Type type(shape, pType); builder::Tensor tensor(std::vector(100)); - auto op = builder::mhlo::ConstOp::build(builder, type, tensor); + + auto op1 = builder::mhlo::ConstOp::build(builder, type, tensor); + auto op2 = builder::mhlo::ConstOp::build(builder, type, tensor); + auto op3 = builder::mhlo::AddOp::build(builder, type, op1, op2); + + builder.DumpModule(); return 0; } diff --git a/tools/mlir-tblgen-builder/Builder/AttributeImpl.h b/tools/mlir-tblgen-builder/Builder/AttributeImpl.h index 34a5560..928dd9d 100644 --- a/tools/mlir-tblgen-builder/Builder/AttributeImpl.h +++ b/tools/mlir-tblgen-builder/Builder/AttributeImpl.h @@ -134,6 +134,9 @@ class Type::Impl { : shape_(shape), primitiveType_(primitiveType) {} Shape GetShape() { return shape_; } PrimitiveType GetType() { return primitiveType_; } + mlir::Type GetMlirType(mlir::MLIRContext *context) { + return mlir::NoneType::get(context); + } private: Shape shape_; diff --git a/tools/mlir-tblgen-builder/Builder/Builder.cpp b/tools/mlir-tblgen-builder/Builder/Builder.cpp index 139357d..3fbff45 100644 --- a/tools/mlir-tblgen-builder/Builder/Builder.cpp +++ b/tools/mlir-tblgen-builder/Builder/Builder.cpp @@ -19,7 +19,7 @@ namespace builder { Builder::Builder() : impl_(std::make_shared()) {} -void Builder::DumpModule() {} +void Builder::DumpModule() { impl_->DumpModule(); } } // namespace builder diff --git a/tools/mlir-tblgen-builder/Builder/BuilderImpl.h b/tools/mlir-tblgen-builder/Builder/BuilderImpl.h index e0ec3cc..4c1b8c0 100644 --- a/tools/mlir-tblgen-builder/Builder/BuilderImpl.h +++ b/tools/mlir-tblgen-builder/Builder/BuilderImpl.h @@ -3,6 +3,7 @@ #include "Builder.h" #include "llvm/Support/Casting.h" +// #include "llvm/Support/InitLLVM.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" @@ -12,12 +13,49 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +// #include "mlir/InitAllDialects.h" +// #include "mlir/InitAllPasses.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" namespace builder { class Builder::Impl { public: Impl() : builder_(&context_) { + + // llvm::InitLLVM y(argc, argv); + // llvm::InitializeNativeTarget(); + // llvm::InitializeNativeTargetAsmPrinter(); + // llvm::InitializeNativeTargetAsmParser(); + // mlir::initializeLLVMPasses(); + // Register any command line options. + // registerAsmPrinterCLOptions(); + // registerMLIRContextCLOptions(); + // registerPassManagerCLOptions(); + // registerDefaultTimingManagerCLOptions(); + // DebugCounter::registerCLOptions(); + + // mlir::registerAllPasses(); + mlir::mhlo::registerAllMhloPasses(); + mlir::lmhlo::registerAllLmhloPasses(); + mlir::disc_ral::registerAllDiscRalPasses(); + + mlir::DialectRegistry registry; + // mlir::registerAllToLLVMIRTranslations(registry); + // mlir::registerAllDialects(registry); + registry.insert(); + // registry.insert(); + // registry.insert(); + // registry.insert(); + // registry.insert(); + context_.appendDialectRegistry(registry); + context_.loadAllAvailableDialects(); + + + + module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_)); llvm::SmallVector arg_types; @@ -34,6 +72,7 @@ class Builder::Impl { mlir::Location GetLoc() { return builder_.getUnknownLoc(); } mlir::OpBuilder GetBuilder() { return builder_; } mlir::MLIRContext* GetContext() { return &context_; } + void DumpModule() { module_.dump(); } private: mlir::MLIRContext context_; diff --git a/tools/mlir-tblgen-builder/Builder/Op.cpp b/tools/mlir-tblgen-builder/Builder/Op.cpp new file mode 100644 index 0000000..e351ecc --- /dev/null +++ b/tools/mlir-tblgen-builder/Builder/Op.cpp @@ -0,0 +1,9 @@ +#include "Op.h" + +#include "OpImpl.h" + +namespace builder { + +Op::Op() : impl_(std::make_shared()) {} + +} // namespace builder \ No newline at end of file diff --git a/tools/mlir-tblgen-builder/Builder/Op.h b/tools/mlir-tblgen-builder/Builder/Op.h index 3bf8293..4ce2fdf 100644 --- a/tools/mlir-tblgen-builder/Builder/Op.h +++ b/tools/mlir-tblgen-builder/Builder/Op.h @@ -7,6 +7,7 @@ namespace builder { class Op { public: + Op(); class Impl; std::shared_ptr GetImpl() { return impl_; } diff --git a/tools/mlir-tblgen-builder/BuilderDefinitionsGen.cpp b/tools/mlir-tblgen-builder/BuilderDefinitionsGen.cpp index 41e85d3..eab212e 100644 --- a/tools/mlir-tblgen-builder/BuilderDefinitionsGen.cpp +++ b/tools/mlir-tblgen-builder/BuilderDefinitionsGen.cpp @@ -1447,19 +1447,28 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, auto op = GetOp(); auto operands = op.getOperands(); auto attrs = op.getAttributes(); + auto numResults = op.getNumResults(); + auto numOperands = op.getNumOperands(); SmallVector newAttrs; // if (attrType == "::mlir::DenseIntElementsAttr") { - // body << " // BBBBBBBB getStorageType:" << a.attr.getStorageType().str() + // body << " // BBBBBBBB getStorageType:" << + // a.attr.getStorageType().str() // << "\n"; // body << " // BBBBBBBB getReturnType:" << a.attr.getReturnType().str() // << "\n"; // } - body << " auto b = builder.GetImpl();\n"; - body << " auto loc = b->GetLoc();\n"; - body << " auto opBuilder = b->GetBuilder();\n"; - body << " auto ctx = b->GetContext();\n"; + // for(const auto& a : op.getArgs()){ + // body << " // BBBBBBBB argument.is:" + // << a.is() << "\n"; + // } + // // if (argument.is()) + + body << " auto bPtr = builder.GetImpl();\n"; + body << " auto loc = bPtr->GetLoc();\n"; + body << " auto opBuilder = bPtr->GetBuilder();\n"; + body << " auto ctx = bPtr->GetContext();\n"; for (auto a : attrs) { std::string attrType = a.attr.getStorageType().str(); auto typePair = typeMapMLIR.find(attrType); @@ -1469,11 +1478,24 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, newAttrs.emplace_back(mlirName); } } - int index = 0; - for (auto v : operands) { + for (int i = 0; i < numResults; ++i) { + const auto &result = op.getResult(i); + std::string resultName = std::string(result.name); + if (resultName.empty()) + resultName = std::string(formatv("resultType{0}", i)); + bool isVec = result.isVariadic(); + if (isVec) { + body << " std::vector " << resultName << "_v;\n"; + body << " for(auto r : " << resultName << "){\n " << resultName + << "_v.push_back(r.GetImpl()->GetMlirType(ctx));\n }" + << "\n"; + } + } + + for (int i = 0; i < numOperands; i++) { + const auto &v = op.getOperand(i); std::string name = - v.name.empty() ? "odsArg" + std::to_string(index) : v.name.str(); - index++; + v.name.empty() ? "odsArg" + std::to_string(i) : v.name.str(); if (v.isVariadic()) { body << " std::vector " << name << "_v;\n"; body << " for(auto v : " << name << "){\n " << name @@ -1488,19 +1510,44 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, << "::" << op.getCppClassName() << ">(\n"; body << " loc"; - index = 0; - std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) { - std::string name = - v.name.empty() ? "odsArg_" + std::to_string(index) : v.name.str(); - index++; - if (v.isVariadic()) { - body << ",\n " << name << "_v"; + for (int i = 0; i < numResults; ++i) { + const auto &result = op.getResult(i); + std::string resultName = std::string(result.name); + if (resultName.empty()) + resultName = std::string(formatv("resultType{0}", i)); + bool isVec = result.isVariadic(); + if (isVec) { + body << ",\n mlir::TypeRange(" << resultName << "_v)"; } else { - body << ",\n " << name << ".GetImpl()->GetResult()"; + body << ",\n " << resultName << ".GetImpl()->GetMlirType(ctx)"; } - }); - std::for_each(newAttrs.begin(), newAttrs.end(), - [&](std::string &n) { body << ",\n " << n; }); + } + + int operandIndex = 0; + int attributeIndex = 0; + for (int i = 0, e = op.getNumArgs(); i < e; ++i) { + auto argument = op.getArg(i); + // if true Operands else Attribute + if (argument.is()) { + const auto &v = op.getOperand(operandIndex); + std::string name = v.name.empty() + ? "odsArg_" + std::to_string(operandIndex) + : v.name.str(); + if (v.isVariadic()) { + body << ",\n mlir::ValueRange(" << name << "_v)"; + } else { + body << ",\n " << name << ".GetImpl()->GetResult()"; + } + operandIndex++; + } else { + body << ",\n " << newAttrs[attributeIndex]; + attributeIndex++; + } + } + for (const NamedRegion ®ion : op.getRegions()) + if (region.isVariadic()) + body << ",\n " << llvm::formatv("{0}Count", region.name).str(); + body << "\n );\n"; body << " builder::mhlo::" << op.getCppClassName() << " builderOp;\n"; body << " auto opImpl = builderOp.GetImpl();\n"; @@ -1520,11 +1567,13 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, // body << " " << builderOpState // << ".addAttribute(\"operand_segment_sizes\", " // "odsBuilder.getI32VectorAttr({"; - // interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { + // interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) + // { // if (op.getOperand(i).isOptional()) // body << "(" << getArgumentName(op, i) << " ? 1 : 0)"; // else if (op.getOperand(i).isVariadic()) - // body << "static_cast(" << getArgumentName(op, i) << ".size())"; + // body << "static_cast(" << getArgumentName(op, i) << + // ".size())"; // else // body << "1"; // }); @@ -1553,14 +1602,17 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, // // here given we use function arguments. So we need to strip the // // wrapping quotes. // if (StringRef(builderTemplate).contains("\"$0\"")) - // builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); + // builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", + // "$0"); // std::string value = // std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); - // body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState, + // body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", + // builderOpState, // namedAttr.name, value); // } else { - // body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState, + // body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", + // builderOpState, // namedAttr.name); // } // if (emitNotNullCheck) { diff --git a/tools/mlir-tblgen-builder/TableGen/OpClass.cpp b/tools/mlir-tblgen-builder/TableGen/OpClass.cpp index 2579219..ee7aa5a 100644 --- a/tools/mlir-tblgen-builder/TableGen/OpClass.cpp +++ b/tools/mlir-tblgen-builder/TableGen/OpClass.cpp @@ -314,7 +314,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const { os << "class " << className; for (const auto &trait : traitsVec) os << ", " << trait; - os << " : public Op {\npublic:\n"; + os << " : public builder::Op {\npublic:\n"; // << " using Op::Op;\n" // << " using Op::print;\n" // << " using Adaptor = " << className << "Adaptor;\n";