modify build cc code,add AddOp in test
This commit is contained in:
		
							parent
							
								
									4a9b201c4e
								
							
						
					
					
						commit
						975a47f7b2
					
				
							
								
								
									
										15
									
								
								BUILD
								
								
								
								
							
							
						
						
									
										15
									
								
								BUILD
								
								
								
								
							|  | @ -179,6 +179,7 @@ cc_library( | |||
|         "@llvm-project//mlir:MlirTableGenMain", | ||||
|         "@llvm-project//mlir:Support", | ||||
|         "@llvm-project//mlir:IR", | ||||
|         "@llvm-project//mlir:Pass", | ||||
|         "@llvm-project//llvm:Support", | ||||
|         "@llvm-project//llvm:TableGen", | ||||
|         "@llvm-project//llvm:config", | ||||
|  | @ -193,9 +194,19 @@ cc_library( | |||
|         # "@llvm-project//mlir:AllPassesAndDialects", | ||||
|         # "@llvm-project//mlir:IR", | ||||
|         # "@llvm-project//mlir:MlirOptLib", | ||||
|         # "@llvm-project//mlir:Pass", | ||||
|         # "@llvm-project//mlir:Support", | ||||
|         # "@llvm-project//mlir:MlirJitRunner", | ||||
| 
 | ||||
|         # "@llvm-project//mlir:Analysis", | ||||
|         # "@llvm-project//mlir:ControlFlowInterfaces", | ||||
|         # "@llvm-project//mlir:InferTypeOpInterface", | ||||
|         # "@llvm-project//mlir:MemRefDialect", | ||||
|         # "@llvm-project//mlir:Shape", | ||||
|         # "@llvm-project//mlir:SideEffects", | ||||
|         # "@llvm-project//mlir:StandardOps", | ||||
|         # "@llvm-project//mlir:TensorDialect", | ||||
|         # "@llvm-project//mlir:TransformUtils", | ||||
|         # "@llvm-project//mlir:Transforms", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
|  | @ -215,8 +226,6 @@ cc_test( | |||
|     ], | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| gentbl_cc_library( | ||||
|     name = "hlo_ops_base_inc_gen", | ||||
|     strip_include_prefix = "include", | ||||
|  |  | |||
|  | @ -2,12 +2,18 @@ | |||
| 
 | ||||
| int main() { | ||||
|   builder::Builder builder; | ||||
| 
 | ||||
|   builder::Shape shape({100, 100}); | ||||
|   auto pType = builder::PrimitiveType::F32(); | ||||
|   builder::Type type(shape, pType); | ||||
| 
 | ||||
|   builder::Tensor tensor(std::vector<int>(100)); | ||||
|   auto op = builder::mhlo::ConstOp::build(builder, type, tensor); | ||||
| 
 | ||||
|   auto op1 = builder::mhlo::ConstOp::build(builder, type, tensor); | ||||
|   auto op2 = builder::mhlo::ConstOp::build(builder, type, tensor); | ||||
|   auto op3 = builder::mhlo::AddOp::build(builder, type, op1, op2); | ||||
| 
 | ||||
|   builder.DumpModule(); | ||||
|   return 0; | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -134,6 +134,9 @@ class Type::Impl { | |||
|       : shape_(shape), primitiveType_(primitiveType) {} | ||||
|   Shape GetShape() { return shape_; } | ||||
|   PrimitiveType GetType() { return primitiveType_; } | ||||
|   mlir::Type GetMlirType(mlir::MLIRContext *context) { | ||||
|     return mlir::NoneType::get(context); | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   Shape shape_; | ||||
|  |  | |||
|  | @ -19,7 +19,7 @@ | |||
| namespace builder { | ||||
| 
 | ||||
| Builder::Builder() : impl_(std::make_shared<Impl>()) {} | ||||
| void Builder::DumpModule() {} | ||||
| void Builder::DumpModule() { impl_->DumpModule(); } | ||||
| 
 | ||||
| }  // namespace builder
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -3,6 +3,7 @@ | |||
| 
 | ||||
| #include "Builder.h" | ||||
| #include "llvm/Support/Casting.h" | ||||
| // #include "llvm/Support/InitLLVM.h"
 | ||||
| #include "mlir/IR/Attributes.h" | ||||
| #include "mlir/IR/Block.h" | ||||
| #include "mlir/IR/Builders.h" | ||||
|  | @ -12,12 +13,49 @@ | |||
| #include "mlir/IR/Operation.h" | ||||
| #include "mlir/IR/Types.h" | ||||
| #include "mlir/IR/Value.h" | ||||
| // #include "mlir/InitAllDialects.h"
 | ||||
| // #include "mlir/InitAllPasses.h"
 | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" | ||||
| 
 | ||||
| namespace builder { | ||||
| 
 | ||||
| class Builder::Impl { | ||||
|  public: | ||||
|   Impl() : builder_(&context_) { | ||||
| 
 | ||||
|     // llvm::InitLLVM y(argc, argv);
 | ||||
|     // llvm::InitializeNativeTarget();
 | ||||
|     // llvm::InitializeNativeTargetAsmPrinter();
 | ||||
|     // llvm::InitializeNativeTargetAsmParser();
 | ||||
|     // mlir::initializeLLVMPasses();
 | ||||
|     // Register any command line options.
 | ||||
|     // registerAsmPrinterCLOptions();
 | ||||
|     // registerMLIRContextCLOptions();
 | ||||
|     // registerPassManagerCLOptions();
 | ||||
|     // registerDefaultTimingManagerCLOptions();
 | ||||
|     // DebugCounter::registerCLOptions();
 | ||||
| 
 | ||||
|     // mlir::registerAllPasses();
 | ||||
|     mlir::mhlo::registerAllMhloPasses(); | ||||
|     mlir::lmhlo::registerAllLmhloPasses(); | ||||
|     mlir::disc_ral::registerAllDiscRalPasses(); | ||||
| 
 | ||||
|     mlir::DialectRegistry registry; | ||||
|     // mlir::registerAllToLLVMIRTranslations(registry);
 | ||||
|     // mlir::registerAllDialects(registry);
 | ||||
|     registry.insert<mlir::mhlo::MhloDialect>(); | ||||
|     // registry.insert<mlir::chlo::HloClientDialect>();
 | ||||
|     // registry.insert<mlir::lmhlo::LmhloDialect>();
 | ||||
|     // registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
 | ||||
|     // registry.insert<mlir::disc_ral::RalDialect>();
 | ||||
|     context_.appendDialectRegistry(registry); | ||||
|     context_.loadAllAvailableDialects(); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_)); | ||||
| 
 | ||||
|     llvm::SmallVector<mlir::Type, 4> arg_types; | ||||
|  | @ -34,6 +72,7 @@ class Builder::Impl { | |||
|   mlir::Location GetLoc() { return builder_.getUnknownLoc(); } | ||||
|   mlir::OpBuilder GetBuilder() { return builder_; } | ||||
|   mlir::MLIRContext* GetContext() { return &context_; } | ||||
|   void DumpModule() { module_.dump(); } | ||||
| 
 | ||||
|  private: | ||||
|   mlir::MLIRContext context_; | ||||
|  |  | |||
|  | @ -0,0 +1,9 @@ | |||
| #include "Op.h" | ||||
| 
 | ||||
| #include "OpImpl.h" | ||||
| 
 | ||||
| namespace builder { | ||||
| 
 | ||||
| Op::Op() : impl_(std::make_shared<Impl>()) {} | ||||
| 
 | ||||
| }  // namespace builder
 | ||||
|  | @ -7,6 +7,7 @@ | |||
| namespace builder { | ||||
| class Op { | ||||
|  public: | ||||
|   Op(); | ||||
|   class Impl; | ||||
|   std::shared_ptr<Impl> GetImpl() { return impl_; } | ||||
| 
 | ||||
|  |  | |||
|  | @ -1447,19 +1447,28 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, | |||
|   auto op = GetOp(); | ||||
|   auto operands = op.getOperands(); | ||||
|   auto attrs = op.getAttributes(); | ||||
|   auto numResults = op.getNumResults(); | ||||
|   auto numOperands = op.getNumOperands(); | ||||
|   SmallVector<std::string, 4> newAttrs; | ||||
| 
 | ||||
|   // if (attrType == "::mlir::DenseIntElementsAttr") {
 | ||||
|   //   body << "  //  BBBBBBBB  getStorageType:" << a.attr.getStorageType().str()
 | ||||
|   //   body << "  //  BBBBBBBB  getStorageType:" <<
 | ||||
|   //   a.attr.getStorageType().str()
 | ||||
|   //        << "\n";
 | ||||
|   //   body << "  //  BBBBBBBB  getReturnType:" << a.attr.getReturnType().str()
 | ||||
|   //        << "\n";
 | ||||
|   // }
 | ||||
| 
 | ||||
|   body << "  auto b = builder.GetImpl();\n"; | ||||
|   body << "  auto loc = b->GetLoc();\n"; | ||||
|   body << "  auto opBuilder = b->GetBuilder();\n"; | ||||
|   body << "  auto ctx = b->GetContext();\n"; | ||||
|   // for(const auto& a : op.getArgs()){
 | ||||
|   //   body << "  //  BBBBBBBB argument.is:"
 | ||||
|   //        << a.is<tblgen::NamedTypeConstraint *>() << "\n";
 | ||||
|   // }
 | ||||
|   // // if (argument.is<tblgen::NamedTypeConstraint *>())
 | ||||
| 
 | ||||
|   body << "  auto bPtr = builder.GetImpl();\n"; | ||||
|   body << "  auto loc = bPtr->GetLoc();\n"; | ||||
|   body << "  auto opBuilder = bPtr->GetBuilder();\n"; | ||||
|   body << "  auto ctx = bPtr->GetContext();\n"; | ||||
|   for (auto a : attrs) { | ||||
|     std::string attrType = a.attr.getStorageType().str(); | ||||
|     auto typePair = typeMapMLIR.find(attrType); | ||||
|  | @ -1469,11 +1478,24 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, | |||
|       newAttrs.emplace_back(mlirName); | ||||
|     } | ||||
|   } | ||||
|   int index = 0; | ||||
|   for (auto v : operands) { | ||||
|   for (int i = 0; i < numResults; ++i) { | ||||
|     const auto &result = op.getResult(i); | ||||
|     std::string resultName = std::string(result.name); | ||||
|     if (resultName.empty()) | ||||
|       resultName = std::string(formatv("resultType{0}", i)); | ||||
|     bool isVec = result.isVariadic(); | ||||
|     if (isVec) { | ||||
|       body << "  std::vector<mlir::Type> " << resultName << "_v;\n"; | ||||
|       body << "  for(auto r : " << resultName << "){\n    " << resultName | ||||
|            << "_v.push_back(r.GetImpl()->GetMlirType(ctx));\n  }" | ||||
|            << "\n"; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   for (int i = 0; i < numOperands; i++) { | ||||
|     const auto &v = op.getOperand(i); | ||||
|     std::string name = | ||||
|         v.name.empty() ? "odsArg" + std::to_string(index) : v.name.str(); | ||||
|     index++; | ||||
|         v.name.empty() ? "odsArg" + std::to_string(i) : v.name.str(); | ||||
|     if (v.isVariadic()) { | ||||
|       body << "  std::vector<mlir::Value> " << name << "_v;\n"; | ||||
|       body << "  for(auto v : " << name << "){\n    " << name | ||||
|  | @ -1488,19 +1510,44 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, | |||
|        << "::" << op.getCppClassName() << ">(\n"; | ||||
|   body << "      loc"; | ||||
| 
 | ||||
|   index = 0; | ||||
|   std::for_each(operands.begin(), operands.end(), [&](NamedTypeConstraint &v) { | ||||
|     std::string name = | ||||
|         v.name.empty() ? "odsArg_" + std::to_string(index) : v.name.str(); | ||||
|     index++; | ||||
|   for (int i = 0; i < numResults; ++i) { | ||||
|     const auto &result = op.getResult(i); | ||||
|     std::string resultName = std::string(result.name); | ||||
|     if (resultName.empty()) | ||||
|       resultName = std::string(formatv("resultType{0}", i)); | ||||
|     bool isVec = result.isVariadic(); | ||||
|     if (isVec) { | ||||
|       body << ",\n      mlir::TypeRange(" << resultName << "_v)"; | ||||
|     } else { | ||||
|       body << ",\n      " << resultName << ".GetImpl()->GetMlirType(ctx)"; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   int operandIndex = 0; | ||||
|   int attributeIndex = 0; | ||||
|   for (int i = 0, e = op.getNumArgs(); i < e; ++i) { | ||||
|     auto argument = op.getArg(i); | ||||
|     // if true Operands else Attribute
 | ||||
|     if (argument.is<tblgen::NamedTypeConstraint *>()) { | ||||
|       const auto &v = op.getOperand(operandIndex); | ||||
|       std::string name = v.name.empty() | ||||
|                              ? "odsArg_" + std::to_string(operandIndex) | ||||
|                              : v.name.str(); | ||||
|       if (v.isVariadic()) { | ||||
|       body << ",\n      " << name << "_v"; | ||||
|         body << ",\n      mlir::ValueRange(" << name << "_v)"; | ||||
|       } else { | ||||
|         body << ",\n      " << name << ".GetImpl()->GetResult()"; | ||||
|       } | ||||
|   }); | ||||
|   std::for_each(newAttrs.begin(), newAttrs.end(), | ||||
|                 [&](std::string &n) { body << ",\n      " << n; }); | ||||
|       operandIndex++; | ||||
|     } else { | ||||
|       body << ",\n      " << newAttrs[attributeIndex]; | ||||
|       attributeIndex++; | ||||
|     } | ||||
|   } | ||||
|   for (const NamedRegion ®ion : op.getRegions()) | ||||
|     if (region.isVariadic()) | ||||
|       body << ",\n      " << llvm::formatv("{0}Count", region.name).str(); | ||||
| 
 | ||||
|   body << "\n    );\n"; | ||||
|   body << "  builder::mhlo::" << op.getCppClassName() << " builderOp;\n"; | ||||
|   body << "  auto opImpl = builderOp.GetImpl();\n"; | ||||
|  | @ -1520,11 +1567,13 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, | |||
|   //   body << "  " << builderOpState
 | ||||
|   //        << ".addAttribute(\"operand_segment_sizes\", "
 | ||||
|   //           "odsBuilder.getI32VectorAttr({";
 | ||||
|   //   interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
 | ||||
|   //   interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i)
 | ||||
|   //   {
 | ||||
|   //     if (op.getOperand(i).isOptional())
 | ||||
|   //       body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
 | ||||
|   //     else if (op.getOperand(i).isVariadic())
 | ||||
|   //       body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
 | ||||
|   //       body << "static_cast<int32_t>(" << getArgumentName(op, i) <<
 | ||||
|   //       ".size())";
 | ||||
|   //     else
 | ||||
|   //       body << "1";
 | ||||
|   //   });
 | ||||
|  | @ -1553,14 +1602,17 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, | |||
|   //       // here given we use function arguments. So we need to strip the
 | ||||
|   //       // wrapping quotes.
 | ||||
|   //       if (StringRef(builderTemplate).contains("\"$0\""))
 | ||||
|   //         builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
 | ||||
|   //         builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"",
 | ||||
|   //         "$0");
 | ||||
| 
 | ||||
|   //       std::string value =
 | ||||
|   //           std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
 | ||||
|   //       body << formatv("  {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
 | ||||
|   //       body << formatv("  {0}.addAttribute(\"{1}\", {2});\n",
 | ||||
|   //       builderOpState,
 | ||||
|   //                       namedAttr.name, value);
 | ||||
|   //     } else {
 | ||||
|   //       body << formatv("  {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
 | ||||
|   //       body << formatv("  {0}.addAttribute(\"{1}\", {1});\n",
 | ||||
|   //       builderOpState,
 | ||||
|   //                       namedAttr.name);
 | ||||
|   //     }
 | ||||
|   //     if (emitNotNullCheck) {
 | ||||
|  |  | |||
|  | @ -314,7 +314,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const { | |||
|   os << "class " << className; | ||||
|   for (const auto &trait : traitsVec) | ||||
|     os << ", " << trait; | ||||
|   os << " : public Op {\npublic:\n"; | ||||
|   os << " : public builder::Op {\npublic:\n"; | ||||
|     //  << "  using Op::Op;\n"
 | ||||
|     //  << "  using Op::print;\n"
 | ||||
|     //  << "  using Adaptor = " << className << "Adaptor;\n";
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue