diff --git a/BUILD b/BUILD index 9a45be0..1059dbb 100644 --- a/BUILD +++ b/BUILD @@ -133,7 +133,7 @@ cc_binary( deps = [ "@llvm-project//mlir:MlirTableGenMain", "@llvm-project//mlir:Support", - # "@llvm-project//mlir:TableGen", + "@llvm-project//mlir:IR", "@llvm-project//llvm:Support", "@llvm-project//llvm:TableGen", "@llvm-project//llvm:config", diff --git a/tools/mlir-tblgen-builder/Array.h b/tools/mlir-tblgen-builder/Array.h new file mode 100644 index 0000000..fec5c69 --- /dev/null +++ b/tools/mlir-tblgen-builder/Array.h @@ -0,0 +1,10 @@ +#ifndef BUILDER_ARRAY_ +#define BUILDER_ARRAY_ + +#include "iostream" + +namespace builder { +class Array {} +} // namespace builder + +#endif \ No newline at end of file diff --git a/tools/mlir-tblgen-builder/Builder.cpp b/tools/mlir-tblgen-builder/Builder.cpp new file mode 100644 index 0000000..534b452 --- /dev/null +++ b/tools/mlir-tblgen-builder/Builder.cpp @@ -0,0 +1,17 @@ +#include "Builder.h" + +#include "BuilderImpl.h" +#include "llvm/Support/Casting.h" +// #include "mlir/Dialect/StandardOps/Ops.h" +// #include "mlir/IR/Attributes.h" +// #include "mlir/IR/Operation.h" +// #include "mlir/IR/StandardTypes.h" +// #include "mlir/IR/Types.h" +// #include "mlir/IR/Value.h" + +namespace builder { + +Builder::Builder() : _impl(std::make_shared()) {} +void Builder::DumpModule() {} + +} // namespace builder \ No newline at end of file diff --git a/tools/mlir-tblgen-builder/Builder.h b/tools/mlir-tblgen-builder/Builder.h new file mode 100644 index 0000000..c649a6e --- /dev/null +++ b/tools/mlir-tblgen-builder/Builder.h @@ -0,0 +1,21 @@ +#ifndef BUILDER_BUILDER_ +#define BUILDER_BUILDER_ + +#include + +namespace builder { + +class Builder { + public: + class Impl; + + Builder(); + void DumpModule(); + std::shared_ptr GetImpl() { return _impl; } + + private: + std::shared_ptr _impl; +}; +} // namespace builder + +#endif \ No newline at end of file diff --git a/tools/mlir-tblgen-builder/BuilderDefinitionsGen.cpp b/tools/mlir-tblgen-builder/BuilderDefinitionsGen.cpp index c2cc5bf..79de5d7 100644 --- a/tools/mlir-tblgen-builder/BuilderDefinitionsGen.cpp +++ b/tools/mlir-tblgen-builder/BuilderDefinitionsGen.cpp @@ -29,6 +29,8 @@ #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" +#include "iostream" + #define DEBUG_TYPE "mlir-tblgen-opdefgen" using namespace llvm; @@ -40,6 +42,40 @@ static const char *const generatedArgName = "odsArg"; static const char *const odsBuilder = "odsBuilder"; static const char *const builderOpState = "odsState"; +static const std::map typeMapMLIR = { + {"::mlir::StringAttr", "std::string"}, + {"::mlir::IntegerAttr", "int"}, + {"::mlir::DenseIntElementsAttr", "std::vector"}, + {"::mlir::mhlo::ChannelHandle", "ChannelHandle"}, + {"::mlir::FloatAttr", "float"}, + {"::mlir::BoolAttr", "bool"}, + {"::mlir::ElementsAttr", "::builder::Array"}, + {"::mlir::DenseElementsAttr", "::builder::Tensor"}, + // current only support string array. + {"::mlir::ArrayAttr", "std::vector"}, + {"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"}, + {"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"}, + {"::mlir::mhlo::GatherDimensionNumbers", + "::builder::GatherDimensionNumbers"}, + {"::mlir::mhlo::ScatterDimensionNumbers", + "::builder::ScatterDimensionNumbers"}, +}; + +StringRef typeConvertFromMLIR(StringRef type) { + auto re = typeMapMLIR.find(type.str()); + if (re != typeMapMLIR.end()) return StringRef(re->second); + return type; +} + +StringRef getStorageType(const Attribute &att) { + auto type = att.getStorageType(); + return typeConvertFromMLIR(type); +} +StringRef getReturnType(const Attribute &att) { + auto type = att.getStorageType(); + return typeConvertFromMLIR(type); +} + // The logic to calculate the actual value range for a declared operand/result // of an op with variadic operands/results. Note that this logic is not for // general use; it assumes all variadic operands/results must have the same @@ -173,11 +209,11 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( : uniqueOutputLabel(getUniqueName(records)) { llvm::Optional namespaceEmitter; if (!emitDecl) { - os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); + // os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace()); } - emitTypeConstraintMethods(opDefs, os, emitDecl); + // emitTypeConstraintMethods(opDefs, os, emitDecl); } std::string StaticVerifierFunctionEmitter::getUniqueName( @@ -278,7 +314,7 @@ static std::string getArgumentName(const Operator &op, int index) { // Returns true if we can use unwrapped value for the given `attr` in builders. static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) { - return attr.getReturnType() != attr.getStorageType() && + return getReturnType(attr) != getStorageType(attr) && // We need to wrap the raw value into an attribute in the builder impl // so we need to make sure that the attribute specifies how to do that. !attr.getConstBuilderTemplate().empty(); @@ -346,12 +382,12 @@ private: // Generates the build() method that takes each operand/attribute as a // stand-alone parameter. The generated build() method uses first operand's // type as all results' types. - void genUseOperandAsResultTypeSeparateParamBuilder(); + // void genUseOperandAsResultTypeSeparateParamBuilder(); // Generates the build() method that takes all operands/attributes // collectively as one parameter. The generated build() method uses first // operand's type as all results' types. - void genUseOperandAsResultTypeCollectiveParamBuilder(); + // void genUseOperandAsResultTypeCollectiveParamBuilder(); // Generates the build() method that takes aggregate operands/attributes // parameters. This build() method uses inferred types as result types. @@ -361,11 +397,11 @@ private: // Generates the build() method that takes each operand/attribute as a // stand-alone parameter. The generated build() method uses first attribute's // type as all result's types. - void genUseAttrAsResultTypeBuilder(); + // void genUseAttrAsResultTypeBuilder(); // Generates the build() method that takes all result types collectively as // one parameter. Similarly for operands and attributes. - void genCollectiveParamBuilder(); + // void genCollectiveParamBuilder(); // The kind of parameter to generate for result types in builders. enum class TypeParamKind { @@ -391,7 +427,7 @@ private: AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); // Adds op arguments and regions into operation state for build() methods. - void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, + void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,llvm::SmallVector paramList, bool isRawValueAttr = false); // Generates canonicalizer declaration for the operation. @@ -588,24 +624,24 @@ OpEmitter::OpEmitter(const Operator &op, // methods in the generated file. // genOpAsmInterface(); genOpNameGetter(); - genNamedOperandGetters(); - genNamedOperandSetters(); - genNamedResultGetters(); - genNamedRegionGetters(); + // genNamedOperandGetters(); + // genNamedOperandSetters(); + // genNamedResultGetters(); + // genNamedRegionGetters(); genNamedSuccessorGetters(); genAttrGetters(); genAttrSetters(); - genOptionalAttrRemovers(); + // genOptionalAttrRemovers(); genBuilder(); - genParser(); - genPrinter(); - genVerifier(); - genCanonicalizerDecls(); - genFolderDecls(); - genTypeInterfaceMethods(); - genOpInterfaceMethods(); - generateOpFormat(op, opClass); - genSideEffectInterfaceMethods(); + // genParser(); + // genPrinter(); + // genVerifier(); + // genCanonicalizerDecls(); + // genFolderDecls(); + // genTypeInterfaceMethods(); + // genOpInterfaceMethods(); + // generateOpFormat(op, opClass); + // genSideEffectInterfaceMethods(); } void OpEmitter::emitDecl( @@ -631,7 +667,7 @@ void OpEmitter::genAttrGetters() { Dialect opDialect = op.getDialect(); // Emit the derived attribute body. auto emitDerivedAttr = [&](StringRef name, Attribute attr) { - auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); + auto *method = opClass.addMethodAndPrune(getReturnType(attr), name); if (!method) return; auto &body = method->body(); @@ -640,7 +676,7 @@ void OpEmitter::genAttrGetters() { // Emit with return type specified. auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) { - auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); + auto *method = opClass.addMethodAndPrune(getReturnType(attr), name); auto &body = method->body(); body << " auto attr = " << name << "Attr();\n"; if (attr.hasDefaultValue()) { @@ -664,7 +700,7 @@ void OpEmitter::genAttrGetters() { // the string interface for better compile time verification. auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { auto *method = - opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str()); + opClass.addMethodAndPrune(getStorageType(attr), (name + "Attr").str()); if (!method) return; auto &body = method->body(); @@ -673,7 +709,7 @@ void OpEmitter::genAttrGetters() { body << "dyn_cast_or_null<"; else body << "cast<"; - body << attr.getStorageType() << ">();"; + body << getStorageType(attr) << ">();"; }; for (auto &namedAttr : op.getAttributes()) { @@ -759,7 +795,7 @@ void OpEmitter::genAttrSetters() { // for better compile time verification. auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(), - attr.getStorageType(), "attr"); + getStorageType(attr), "attr"); if (!method) return; auto &body = method->body(); @@ -1070,256 +1106,89 @@ static bool canInferType(Operator &op) { void OpEmitter::genSeparateArgParamBuilder() { SmallVector attrBuilderType; attrBuilderType.push_back(AttrParamKind::WrappedAttr); - if (canGenerateUnwrappedBuilder(op)) - attrBuilderType.push_back(AttrParamKind::UnwrappedValue); + // if (canGenerateUnwrappedBuilder(op)) + // attrBuilderType.push_back(AttrParamKind::UnwrappedValue); // Emit with separate builders with or without unwrapped attributes and/or // inferring result type. auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, bool inferType) { llvm::SmallVector paramList; + llvm::SmallVector paramList2; llvm::SmallVector resultNames; buildParamList(paramList, resultNames, paramKind, attrType); + buildParamList(paramList2, resultNames, paramKind, attrType); - auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, - std::move(paramList)); + auto *m = opClass.addMethodAndPrune( + "::builder::Op", "build", OpMethod::MP_Static, std::move(paramList)); // If the builder is redundant, skip generating the method. if (!m) return; auto &body = m->body(); genCodeForAddingArgAndRegionForBuilder( - body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); + body, paramList2, attrType == AttrParamKind::UnwrappedValue); // Push all result types to the operation state +//"BBBBBBBBBBBB" + // if (inferType) { + // // Generate builder that infers type too. + // // TODO: Subsume this with general checking if type can be + // // inferred automatically. + // // TODO: Expand to handle regions. + // body << formatv(R"( + // ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; + // if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(), + // {1}.location, {1}.operands, + // {1}.attributes.getDictionary({1}.getContext()), + // /*regions=*/{{}, inferredReturnTypes))) + // {1}.addTypes(inferredReturnTypes); + // else + // ::llvm::report_fatal_error("Failed to infer result type(s).");)", + // opClass.getClassName(), builderOpState); + // return; + // } - if (inferType) { - // Generate builder that infers type too. - // TODO: Subsume this with general checking if type can be - // inferred automatically. - // TODO: Expand to handle regions. - body << formatv(R"( - ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; - if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(), - {1}.location, {1}.operands, - {1}.attributes.getDictionary({1}.getContext()), - /*regions=*/{{}, inferredReturnTypes))) - {1}.addTypes(inferredReturnTypes); - else - ::llvm::report_fatal_error("Failed to infer result type(s).");)", - opClass.getClassName(), builderOpState); - return; - } + // switch (paramKind) { + // case TypeParamKind::None: + // return; + // case TypeParamKind::Separate: + // for (int i = 0, e = op.getNumResults(); i < e; ++i) { + // if (op.getResult(i).isOptional()) + // body << " if (" << resultNames[i] << ")\n "; + // body << " " << builderOpState << ".addTypes(" << resultNames[i] + // << ");\n"; + // } + // return; + // case TypeParamKind::Collective: { + // int numResults = op.getNumResults(); + // int numVariadicResults = op.getNumVariableLengthResults(); + // int numNonVariadicResults = numResults - numVariadicResults; + // bool hasVariadicResult = numVariadicResults != 0; - switch (paramKind) { - case TypeParamKind::None: - return; - case TypeParamKind::Separate: - for (int i = 0, e = op.getNumResults(); i < e; ++i) { - if (op.getResult(i).isOptional()) - body << " if (" << resultNames[i] << ")\n "; - body << " " << builderOpState << ".addTypes(" << resultNames[i] - << ");\n"; - } - return; - case TypeParamKind::Collective: { - int numResults = op.getNumResults(); - int numVariadicResults = op.getNumVariableLengthResults(); - int numNonVariadicResults = numResults - numVariadicResults; - bool hasVariadicResult = numVariadicResults != 0; - - // Avoid emitting "resultTypes.size() >= 0u" which is always true. - if (!(hasVariadicResult && numNonVariadicResults == 0)) - body << " " - << "assert(resultTypes.size() " - << (hasVariadicResult ? ">=" : "==") << " " - << numNonVariadicResults - << "u && \"mismatched number of results\");\n"; - body << " " << builderOpState << ".addTypes(resultTypes);\n"; - } - return; - } - llvm_unreachable("unhandled TypeParamKind"); + // // Avoid emitting "resultTypes.size() >= 0u" which is always true. + // if (!(hasVariadicResult && numNonVariadicResults == 0)) + // body << " " + // << "assert(resultTypes.size() " + // << (hasVariadicResult ? ">=" : "==") << " " + // << numNonVariadicResults + // << "u && \"mismatched number of results\");\n"; + // body << " " << builderOpState << ".addTypes(resultTypes);\n"; + // } + // return; + // } + // llvm_unreachable("unhandled TypeParamKind"); }; // Some of the build methods generated here may be ambiguous, but TableGen's // ambiguous function detection will elide those ones. for (auto attrType : attrBuilderType) { emit(attrType, TypeParamKind::Separate, /*inferType=*/false); - if (canInferType(op)) - emit(attrType, TypeParamKind::None, /*inferType=*/true); - emit(attrType, TypeParamKind::Collective, /*inferType=*/false); + // if (canInferType(op)) + // emit(attrType, TypeParamKind::None, /*inferType=*/true); + // emit(attrType, TypeParamKind::Collective, /*inferType=*/false); } } -void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { - int numResults = op.getNumResults(); - - // Signature - llvm::SmallVector paramList; - paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); - paramList.emplace_back("::mlir::OperationState &", builderOpState); - paramList.emplace_back("::mlir::ValueRange", "operands"); - // Provide default value for `attributes` when its the last parameter - StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; - paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", - "attributes", attributesDefaultValue); - if (op.getNumVariadicRegions()) - paramList.emplace_back("unsigned", "numRegions"); - - auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, - std::move(paramList)); - // If the builder is redundant, skip generating the method - if (!m) - return; - auto &body = m->body(); - - // Operands - body << " " << builderOpState << ".addOperands(operands);\n"; - - // Attributes - body << " " << builderOpState << ".addAttributes(attributes);\n"; - - // Create the correct number of regions - if (int numRegions = op.getNumRegions()) { - body << llvm::formatv( - " for (unsigned i = 0; i != {0}; ++i)\n", - (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); - body << " (void)" << builderOpState << ".addRegion();\n"; - } - - // Result types - SmallVector resultTypes(numResults, "operands[0].getType()"); - body << " " << builderOpState << ".addTypes({" - << llvm::join(resultTypes, ", ") << "});\n\n"; -} - -void OpEmitter::genInferredTypeCollectiveParamBuilder() { - // TODO: Expand to support regions. - SmallVector paramList; - paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); - paramList.emplace_back("::mlir::OperationState &", builderOpState); - paramList.emplace_back("::mlir::ValueRange", "operands"); - paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", - "attributes", "{}"); - auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, - std::move(paramList)); - // If the builder is redundant, skip generating the method - if (!m) - return; - auto &body = m->body(); - - int numResults = op.getNumResults(); - int numVariadicResults = op.getNumVariableLengthResults(); - int numNonVariadicResults = numResults - numVariadicResults; - - int numOperands = op.getNumOperands(); - int numVariadicOperands = op.getNumVariableLengthOperands(); - int numNonVariadicOperands = numOperands - numVariadicOperands; - - // Operands - if (numVariadicOperands == 0 || numNonVariadicOperands != 0) - body << " assert(operands.size()" - << (numVariadicOperands != 0 ? " >= " : " == ") - << numNonVariadicOperands - << "u && \"mismatched number of parameters\");\n"; - body << " " << builderOpState << ".addOperands(operands);\n"; - body << " " << builderOpState << ".addAttributes(attributes);\n"; - - // Create the correct number of regions - if (int numRegions = op.getNumRegions()) { - body << llvm::formatv( - " for (unsigned i = 0; i != {0}; ++i)\n", - (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); - body << " (void)" << builderOpState << ".addRegion();\n"; - } - - // Result types - body << formatv(R"( - ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes; - if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(), - {1}.location, operands, - {1}.attributes.getDictionary({1}.getContext()), - /*regions=*/{{}, inferredReturnTypes))) {{)", - opClass.getClassName(), builderOpState); - if (numVariadicResults == 0 || numNonVariadicResults != 0) - body << " assert(inferredReturnTypes.size()" - << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults - << "u && \"mismatched number of return types\");\n"; - body << " " << builderOpState << ".addTypes(inferredReturnTypes);"; - - body << formatv(R"( - } else - ::llvm::report_fatal_error("Failed to infer result type(s).");)", - opClass.getClassName(), builderOpState); -} - -void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { - llvm::SmallVector paramList; - llvm::SmallVector resultNames; - buildParamList(paramList, resultNames, TypeParamKind::None); - - auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, - std::move(paramList)); - // If the builder is redundant, skip generating the method - if (!m) - return; - auto &body = m->body(); - genCodeForAddingArgAndRegionForBuilder(body); - - auto numResults = op.getNumResults(); - if (numResults == 0) - return; - - // Push all result types to the operation state - const char *index = op.getOperand(0).isVariadic() ? ".front()" : ""; - std::string resultType = - formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str(); - body << " " << builderOpState << ".addTypes({" << resultType; - for (int i = 1; i != numResults; ++i) - body << ", " << resultType; - body << "});\n\n"; -} - -void OpEmitter::genUseAttrAsResultTypeBuilder() { - SmallVector paramList; - paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); - paramList.emplace_back("::mlir::OperationState &", builderOpState); - paramList.emplace_back("::mlir::ValueRange", "operands"); - paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", - "attributes", "{}"); - auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, - std::move(paramList)); - // If the builder is redundant, skip generating the method - if (!m) - return; - - auto &body = m->body(); - - // Push all result types to the operation state - std::string resultType; - const auto &namedAttr = op.getAttribute(0); - - body << " for (auto attr : attributes) {\n"; - body << " if (attr.first != \"" << namedAttr.name << "\") continue;\n"; - if (namedAttr.attr.isTypeAttr()) { - resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()"; - } else { - resultType = "attr.second.getType()"; - } - - // Operands - body << " " << builderOpState << ".addOperands(operands);\n"; - - // Attributes - body << " " << builderOpState << ".addAttributes(attributes);\n"; - - // Result types - SmallVector resultTypes(op.getNumResults(), resultType); - body << " " << builderOpState << ".addTypes({" - << llvm::join(resultTypes, ", ") << "});\n"; - body << " }\n"; -} - /// Returns a signature of the builder. Updates the context `fctx` to enable /// replacement of $_builder and $_state in the body. static std::string getBuilderSignature(const Builder &builder) { @@ -1352,21 +1221,21 @@ static std::string getBuilderSignature(const Builder &builder) { void OpEmitter::genBuilder() { // Handle custom builders if provided. - for (const Builder &builder : op.getBuilders()) { - std::string paramStr = getBuilderSignature(builder); + // for (const Builder &builder : op.getBuilders()) { + // std::string paramStr = getBuilderSignature(builder); - Optional body = builder.getBody(); - OpMethod::Property properties = - body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration; - auto *method = - opClass.addMethodAndPrune("void", "build", properties, paramStr); + // Optional body = builder.getBody(); + // OpMethod::Property properties = + // body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration; + // auto *method = + // opClass.addMethodAndPrune("void", "build", properties, paramStr); - FmtContext fctx; - fctx.withBuilder(odsBuilder); - fctx.addSubst("_state", builderOpState); - if (body) - method->body() << tgfmt(*body, &fctx); - } + // FmtContext fctx; + // fctx.withBuilder(odsBuilder); + // fctx.addSubst("_state", builderOpState); + // if (body) + // method->body() << tgfmt(*body, &fctx); + // } // Generate default builders that requires all result type, operands, and // attributes as parameters. @@ -1378,78 +1247,18 @@ void OpEmitter::genBuilder() { genSeparateArgParamBuilder(); // 2. one having an aggregated parameter for all result types / operands / // attributes, and - genCollectiveParamBuilder(); - // 3. one having a stand-alone parameter for each operand and attribute, - // use the first operand or attribute's type as all result types - // to facilitate different call patterns. - if (op.getNumVariableLengthResults() == 0) { - if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { - genUseOperandAsResultTypeSeparateParamBuilder(); - genUseOperandAsResultTypeCollectiveParamBuilder(); - } - if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType")) - genUseAttrAsResultTypeBuilder(); - } -} - -void OpEmitter::genCollectiveParamBuilder() { - int numResults = op.getNumResults(); - int numVariadicResults = op.getNumVariableLengthResults(); - int numNonVariadicResults = numResults - numVariadicResults; - - int numOperands = op.getNumOperands(); - int numVariadicOperands = op.getNumVariableLengthOperands(); - int numNonVariadicOperands = numOperands - numVariadicOperands; - - SmallVector paramList; - paramList.emplace_back("::mlir::OpBuilder &", ""); - paramList.emplace_back("::mlir::OperationState &", builderOpState); - paramList.emplace_back("::mlir::TypeRange", "resultTypes"); - paramList.emplace_back("::mlir::ValueRange", "operands"); - // Provide default value for `attributes` when its the last parameter - StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; - paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", - "attributes", attributesDefaultValue); - if (op.getNumVariadicRegions()) - paramList.emplace_back("unsigned", "numRegions"); - - auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, - std::move(paramList)); - // If the builder is redundant, skip generating the method - if (!m) - return; - auto &body = m->body(); - - // Operands - if (numVariadicOperands == 0 || numNonVariadicOperands != 0) - body << " assert(operands.size()" - << (numVariadicOperands != 0 ? " >= " : " == ") - << numNonVariadicOperands - << "u && \"mismatched number of parameters\");\n"; - body << " " << builderOpState << ".addOperands(operands);\n"; - - // Attributes - body << " " << builderOpState << ".addAttributes(attributes);\n"; - - // Create the correct number of regions - if (int numRegions = op.getNumRegions()) { - body << llvm::formatv( - " for (unsigned i = 0; i != {0}; ++i)\n", - (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); - body << " (void)" << builderOpState << ".addRegion();\n"; - } - - // Result types - if (numVariadicResults == 0 || numNonVariadicResults != 0) - body << " assert(resultTypes.size()" - << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults - << "u && \"mismatched number of return types\");\n"; - body << " " << builderOpState << ".addTypes(resultTypes);\n"; - - // Generate builder that infers type too. - // TODO: Expand to handle regions and successors. - if (canInferType(op) && op.getNumSuccessors() == 0) - genInferredTypeCollectiveParamBuilder(); + // genCollectiveParamBuilder(); + // // 3. one having a stand-alone parameter for each operand and attribute, + // // use the first operand or attribute's type as all result types + // // to facilitate different call patterns. + // if (op.getNumVariableLengthResults() == 0) { + // if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { + // genUseOperandAsResultTypeSeparateParamBuilder(); + // genUseOperandAsResultTypeCollectiveParamBuilder(); + // } + // if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType")) + // genUseAttrAsResultTypeBuilder(); + // } } void OpEmitter::buildParamList(SmallVectorImpl ¶mList, @@ -1460,8 +1269,9 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, auto numResults = op.getNumResults(); resultTypeNames.reserve(numResults); - paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); - paramList.emplace_back("::mlir::OperationState &", builderOpState); + // paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); + paramList.emplace_back("::builder::Builder &", "builder"); + // paramList.emplace_back("::mlir::OperationState &", builderOpState); switch (typeParamKind) { case TypeParamKind::None: @@ -1475,7 +1285,7 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, resultName = std::string(formatv("resultType{0}", i)); StringRef type = - result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type"; + result.isVariadic() ? "std::vector<::builder::Type>" : "::builder::Type"; OpMethodParameter::Property properties = OpMethodParameter::PP_None; if (result.isOptional()) properties = OpMethodParameter::PP_Optional; @@ -1512,7 +1322,7 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, // for APFloat. // TODO: Adjust the 'returnType' field of such attributes // to support them. - StringRef retType = namedAttr->attr.getReturnType(); + StringRef retType = getReturnType(namedAttr->attr); if (retType == "::llvm::APInt" || retType == "::llvm::APFloat") break; @@ -1525,7 +1335,7 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, if (argument.is()) { const auto &operand = op.getOperand(numOperands); StringRef type = - operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value"; + operand.isVariadic() ? "std::vector<::builder::Op>" : "::builder::Op"; OpMethodParameter::Property properties = OpMethodParameter::PP_None; if (operand.isOptional()) properties = OpMethodParameter::PP_Optional; @@ -1544,13 +1354,13 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, StringRef type; switch (attrParamKind) { case AttrParamKind::WrappedAttr: - type = attr.getStorageType(); + type = getStorageType(attr); break; case AttrParamKind::UnwrappedValue: if (canUseUnwrappedRawValue(attr)) - type = attr.getReturnType(); + type = getReturnType(attr); else - type = attr.getStorageType(); + type = getStorageType(attr); break; } @@ -1558,7 +1368,7 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, // Attach default value if requested and possible. if (attrParamKind == AttrParamKind::UnwrappedValue && i >= defaultValuedAttrStartIndex) { - bool isString = attr.getReturnType() == "::llvm::StringRef"; + bool isString = getReturnType(attr) == "::llvm::StringRef"; if (isString) defaultValue.append("\""); defaultValue += attr.getDefaultValue(); @@ -1584,84 +1394,117 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, llvm::formatv("{0}Count", region.name).str()); } -void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, - bool isRawValueAttr) { - // Push all operands to the result. - for (int i = 0, e = op.getNumOperands(); i < e; ++i) { - std::string argName = getArgumentName(op, i); - if (op.getOperand(i).isOptional()) - body << " if (" << argName << ")\n "; - body << " " << builderOpState << ".addOperands(" << argName << ");\n"; +void OpEmitter::genCodeForAddingArgAndRegionForBuilder( + OpMethodBody &body, llvm::SmallVector paramList, + bool isRawValueAttr) { + + body << "// AAAAAA \n"; + // for(auto p : paramList) + // { + // body <<"==== type:"<< p.getType() << " name:"<< p.getName() << "\n"; + // } + + body << " auto builder = builder.GetImpl();\n"; + body << " auto loc = builder->GetLoc();\n"; + body << " auto opBuilder = builder->GetBuilder();\n"; + body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName() + << " currentOp =\n"; + body << " opBuilder.create(\n"; + if (paramList.size() > 1) { + body << " loc,\n"; + std::for_each(paramList.begin() + 1, paramList.end() - 1, + [&](OpMethodParameter &p) { + body << " " << p.getName() << ",\n"; + }); + body << " " << paramList.back().getName() << "\n"; + } else { + body << " loc\n"; } + body << " );\n"; + body << " builder::" << op.getCppClassName() << " builderOp;\n"; + body << " auto opImpl = builderOp.GetImpl();\n"; + body << " opImpl.SetOperation(currentOp.getOperation());\n"; + body << " return builderOp;\n"; - // If the operation has the operand segment size attribute, add it here. - if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - body << " " << builderOpState - << ".addAttribute(\"operand_segment_sizes\", " - "odsBuilder.getI32VectorAttr({"; - interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { - if (op.getOperand(i).isOptional()) - body << "(" << getArgumentName(op, i) << " ? 1 : 0)"; - else if (op.getOperand(i).isVariadic()) - body << "static_cast(" << getArgumentName(op, i) << ".size())"; - else - body << "1"; - }); - body << "}));\n"; - } + - // Push all attributes to the result. - for (const auto &namedAttr : op.getAttributes()) { - auto &attr = namedAttr.attr; - if (!attr.isDerivedAttr()) { - bool emitNotNullCheck = attr.isOptional(); - if (emitNotNullCheck) { - body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; - } - if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { - // If this is a raw value, then we need to wrap it in an Attribute - // instance. - FmtContext fctx; - fctx.withBuilder("odsBuilder"); + // // Push all operands to the result. + // for (int i = 0, e = op.getNumOperands(); i < e; ++i) { + // std::string argName = getArgumentName(op, i); + // if (op.getOperand(i).isOptional()) + // body << " if (" << argName << ")\n "; + // body << " " << builderOpState << ".addOperands(" << argName << ");\n"; + // } - std::string builderTemplate = - std::string(attr.getConstBuilderTemplate()); + // // If the operation has the operand segment size attribute, add it here. + // if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + // body << " " << builderOpState + // << ".addAttribute(\"operand_segment_sizes\", " + // "odsBuilder.getI32VectorAttr({"; + // interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { + // if (op.getOperand(i).isOptional()) + // body << "(" << getArgumentName(op, i) << " ? 1 : 0)"; + // else if (op.getOperand(i).isVariadic()) + // body << "static_cast(" << getArgumentName(op, i) << ".size())"; + // else + // body << "1"; + // }); + // body << "}));\n"; + // } - // For StringAttr, its constant builder call will wrap the input in - // quotes, which is correct for normal string literals, but incorrect - // here given we use function arguments. So we need to strip the - // wrapping quotes. - if (StringRef(builderTemplate).contains("\"$0\"")) - builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); + // // Push all attributes to the result. + // for (const auto &namedAttr : op.getAttributes()) { + // auto &attr = namedAttr.attr; + // if (!attr.isDerivedAttr()) { + // bool emitNotNullCheck = attr.isOptional(); + // if (emitNotNullCheck) { + // body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; + // } + // if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { + // // If this is a raw value, then we need to wrap it in an Attribute + // // instance. + // FmtContext fctx; + // fctx.withBuilder("odsBuilder"); - std::string value = - std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); - body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState, - namedAttr.name, value); - } else { - body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState, - namedAttr.name); - } - if (emitNotNullCheck) { - body << " }\n"; - } - } - } + // std::string builderTemplate = + // std::string(attr.getConstBuilderTemplate()); - // Create the correct number of regions. - for (const NamedRegion ®ion : op.getRegions()) { - if (region.isVariadic()) - body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ", - region.name); + // // For StringAttr, its constant builder call will wrap the input in + // // quotes, which is correct for normal string literals, but incorrect + // // here given we use function arguments. So we need to strip the + // // wrapping quotes. + // if (StringRef(builderTemplate).contains("\"$0\"")) + // builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); - body << " (void)" << builderOpState << ".addRegion();\n"; - } + // std::string value = + // std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); + // body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState, + // namedAttr.name, value); + // } else { + // body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState, + // namedAttr.name); + // } + // if (emitNotNullCheck) { + // body << " }\n"; + // } + // } + // } - // Push all successors to the result. - for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) { - body << formatv(" {0}.addSuccessors({1});\n", builderOpState, - namedSuccessor.name); - } + // // Create the correct number of regions. + // for (const NamedRegion ®ion : op.getRegions()) { + // if (region.isVariadic()) + // body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ", + // region.name); + + // body << " (void)" << builderOpState << ".addRegion();\n"; + // } + + // // Push all successors to the result. + // for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) { + // body << formatv(" {0}.addSuccessors({1});\n", builderOpState, + // namedSuccessor.name); + // } } void OpEmitter::genCanonicalizerDecls() { @@ -2200,7 +2043,7 @@ void OpEmitter::genOpNameGetter() { auto *method = opClass.addMethodAndPrune( "std::string", "getOperationName", OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr)); - method->body() << " return std::string(\"" << op.getOperationName() + method->body() << " return std::string(\"" << op.getOperationName() << "\");"; } @@ -2305,15 +2148,15 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); auto emitAttr = [&](StringRef name, Attribute attr) { - auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body(); + auto &body = adaptor.addMethodAndPrune(getStorageType(attr), name)->body(); body << " assert(odsAttrs && \"no attributes when constructing adapter\");" - << "\n " << attr.getStorageType() << " attr = " + << "\n " << getStorageType(attr) << " attr = " << "odsAttrs.get(\"" << name << "\")."; if (attr.hasDefaultValue() || attr.isOptional()) body << "dyn_cast_or_null<"; else body << "cast<"; - body << attr.getStorageType() << ">();\n"; + body << getStorageType(attr) << ">();\n"; if (attr.hasDefaultValue()) { // Use the default value if attribute is not set. @@ -2442,11 +2285,11 @@ static void emitOpClasses(const RecordKeeper &recordKeeper, Operator op(*def); NamespaceEmitter emitter(os, op.getCppNamespace()); if (emitDecl) { - os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); + // os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); // OpOperandAdaptorEmitter::emitDecl(op, os); OpEmitter::emitDecl(op, os, staticVerifierEmitter); } else { - os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); + // os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); // OpOperandAdaptorEmitter::emitDef(op, os); OpEmitter::emitDef(op, os, staticVerifierEmitter); } @@ -2456,12 +2299,18 @@ static void emitOpClasses(const RecordKeeper &recordKeeper, // Emits a comma-separated list of the ops. static void emitOpList(const std::vector &defs, raw_ostream &os) { IfDefScope scope("GET_OP_LIST", os); - interleave( // TODO: We are constructing the Operator wrapper instance just for // getting it's qualified class name here. Reduce the overhead by having a // lightweight version of Operator class just for that purpose. - defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); }, + defs, + [&os](Record *def) { + SmallVector namespaces; + std::string className = Operator(def).getQualCppClassName(); + llvm::SplitString(StringRef(className), namespaces, StringRef("::")); + if (namespaces.begin() != namespaces.end()) + os << "::builder::mhlo::" << namespaces.back().str(); + }, [&os]() { os << ",\n"; }); } diff --git a/tools/mlir-tblgen-builder/BuilderImpl.h b/tools/mlir-tblgen-builder/BuilderImpl.h new file mode 100644 index 0000000..943b36f --- /dev/null +++ b/tools/mlir-tblgen-builder/BuilderImpl.h @@ -0,0 +1,30 @@ +#ifndef BUILDER_BUILDERIMPL_ +#define BUILDER_BUILDERIMPL_ + +#include "Builder.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" + +namespace builder { + +class Builder::Impl { + public: + Impl() {} + // mlir::Location GetLoc() { return mlir_loc_; } + // mlir::OpBuilder GetBuilder() { return mlir_builder_; } + mlir::MLIRContext *GetContext() { return &mlir_context_; } + + private: + // mlir::Location mlir_loc_; + // mlir::OpBuilder mlir_builder_; + mlir::MLIRContext mlir_context_; +}; + +} // namespace builder + +#endif \ No newline at end of file diff --git a/tools/mlir-tblgen-builder/Op.h b/tools/mlir-tblgen-builder/Op.h new file mode 100644 index 0000000..85daa61 --- /dev/null +++ b/tools/mlir-tblgen-builder/Op.h @@ -0,0 +1,16 @@ +#ifndef BUILDER_OP_ +#define BUILDER_OP_ + +#include "iostream" + +namespace builder { +class Op { + class Impl; + std::shared_ptr GetImpl() { return _impl; } + + private: + std::shared_ptr _impl; +} +} // namespace builder + +#endif \ No newline at end of file diff --git a/tools/mlir-tblgen-builder/OpImpl.h b/tools/mlir-tblgen-builder/OpImpl.h new file mode 100644 index 0000000..93af7aa --- /dev/null +++ b/tools/mlir-tblgen-builder/OpImpl.h @@ -0,0 +1,26 @@ +#ifndef BUILDER_OPIMPL_ +#define BUILDER_OPIMPL_ + +#include "Builder.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" + +namespace builder { + +class Op::Impl { + public: + Impl() = default; + void SetOperation(Operation *Op) { op_ = Op; } + + private: + Operation *op_; +}; + +} // namespace builder + +#endif \ No newline at end of file diff --git a/tools/mlir-tblgen-builder/TableGen/CodeGenHelpers.h b/tools/mlir-tblgen-builder/TableGen/CodeGenHelpers.h index 3da4758..a28bc9e 100644 --- a/tools/mlir-tblgen-builder/TableGen/CodeGenHelpers.h +++ b/tools/mlir-tblgen-builder/TableGen/CodeGenHelpers.h @@ -49,14 +49,24 @@ public: ~NamespaceEmitter() { for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; + if (ns.equals("mlir")) + os << "} // namespace " + << "builder" + << "\n\n"; + else + os << "} // namespace " << ns << "\n\n"; } private: void emitNamespaceStarts(raw_ostream &os, StringRef cppNamespace) { llvm::SplitString(cppNamespace, namespaces, "::"); for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; + if (ns.equals("mlir")) + os << "namespace " + << "builder" + << " {\n"; + else + os << "namespace " << ns << " {\n"; } raw_ostream &os; SmallVector namespaces; diff --git a/tools/mlir-tblgen-builder/TableGen/OpClass.cpp b/tools/mlir-tblgen-builder/TableGen/OpClass.cpp index 2fffe5a..2579219 100644 --- a/tools/mlir-tblgen-builder/TableGen/OpClass.cpp +++ b/tools/mlir-tblgen-builder/TableGen/OpClass.cpp @@ -209,7 +209,7 @@ void OpMethod::writeDeclTo(raw_ostream &os) const { else { os << " {\n"; methodBody.writeTo(os); - os << "}"; + os << " }"; } } @@ -261,7 +261,7 @@ void Class::newField(StringRef type, StringRef name, StringRef defaultValue) { } void Class::writeDeclTo(raw_ostream &os) const { bool hasPrivateMethod = false; - os << "class " << className << " {\n"; + os << "class " << className << "\n"; os << "public:\n"; forAllMethods([&](const OpMethod &method) { @@ -311,10 +311,10 @@ void OpClass::addTrait(Twine trait) { } void OpClass::writeDeclTo(raw_ostream &os) const { - os << "class " << className << " : public ::mlir::Op<" << className; + os << "class " << className; for (const auto &trait : traitsVec) os << ", " << trait; - os << "> {\npublic:\n"; + os << " : public Op {\npublic:\n"; // << " using Op::Op;\n" // << " using Op::print;\n" // << " using Adaptor = " << className << "Adaptor;\n"; @@ -330,8 +330,8 @@ void OpClass::writeDeclTo(raw_ostream &os) const { }); // TODO: Add line control markers to make errors easier to debug. - if (!extraClassDeclaration.empty()) - os << extraClassDeclaration << "\n"; + // if (!extraClassDeclaration.empty()) + // os << extraClassDeclaration << "\n"; if (hasPrivateMethod) { os << "\nprivate:\n"; diff --git a/tools/mlir-tblgen-builder/TableGen/OpClass.h b/tools/mlir-tblgen-builder/TableGen/OpClass.h index 243e7fa..fa66e26 100644 --- a/tools/mlir-tblgen-builder/TableGen/OpClass.h +++ b/tools/mlir-tblgen-builder/TableGen/OpClass.h @@ -61,6 +61,8 @@ public: void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); } const std::string &getType() const { return type; } + const std::string &getName() const { return name; } + bool hasDefaultValue() const { return !defaultValue.empty(); } private: diff --git a/tools/mlir-tblgen-builder/Tensor.h b/tools/mlir-tblgen-builder/Tensor.h new file mode 100644 index 0000000..1661b30 --- /dev/null +++ b/tools/mlir-tblgen-builder/Tensor.h @@ -0,0 +1,16 @@ +#ifndef BUILDER_TENSOR_ +#define BUILDER_TENSOR_ + + +#include "iostream" + +namespace builder{ + class Tensor{ + } +} + + + + + +#endif \ No newline at end of file diff --git a/tools/mlir-tblgen-builder/Type.h b/tools/mlir-tblgen-builder/Type.h new file mode 100644 index 0000000..3b24e54 --- /dev/null +++ b/tools/mlir-tblgen-builder/Type.h @@ -0,0 +1,10 @@ +#ifndef BUILDER_TYPE_ +#define BUILDER_TYPE_ + +#include "iostream" + +namespace builder { +class Type {} +} // namespace builder + +#endif \ No newline at end of file