|
|
|
@ -29,6 +29,8 @@
|
|
|
|
|
#include "llvm/TableGen/Record.h"
|
|
|
|
|
#include "llvm/TableGen/TableGenBackend.h"
|
|
|
|
|
|
|
|
|
|
#include "iostream"
|
|
|
|
|
|
|
|
|
|
#define DEBUG_TYPE "mlir-tblgen-opdefgen"
|
|
|
|
|
|
|
|
|
|
using namespace llvm;
|
|
|
|
@ -40,6 +42,40 @@ static const char *const generatedArgName = "odsArg";
|
|
|
|
|
static const char *const odsBuilder = "odsBuilder";
|
|
|
|
|
static const char *const builderOpState = "odsState";
|
|
|
|
|
|
|
|
|
|
static const std::map<std::string, std::string> typeMapMLIR = {
|
|
|
|
|
{"::mlir::StringAttr", "std::string"},
|
|
|
|
|
{"::mlir::IntegerAttr", "int"},
|
|
|
|
|
{"::mlir::DenseIntElementsAttr", "std::vector<int>"},
|
|
|
|
|
{"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
|
|
|
|
|
{"::mlir::FloatAttr", "float"},
|
|
|
|
|
{"::mlir::BoolAttr", "bool"},
|
|
|
|
|
{"::mlir::ElementsAttr", "::builder::Array"},
|
|
|
|
|
{"::mlir::DenseElementsAttr", "::builder::Tensor"},
|
|
|
|
|
// current only support string array.
|
|
|
|
|
{"::mlir::ArrayAttr", "std::vector<std::string>"},
|
|
|
|
|
{"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
|
|
|
|
|
{"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
|
|
|
|
|
{"::mlir::mhlo::GatherDimensionNumbers",
|
|
|
|
|
"::builder::GatherDimensionNumbers"},
|
|
|
|
|
{"::mlir::mhlo::ScatterDimensionNumbers",
|
|
|
|
|
"::builder::ScatterDimensionNumbers"},
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
StringRef typeConvertFromMLIR(StringRef type) {
|
|
|
|
|
auto re = typeMapMLIR.find(type.str());
|
|
|
|
|
if (re != typeMapMLIR.end()) return StringRef(re->second);
|
|
|
|
|
return type;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
StringRef getStorageType(const Attribute &att) {
|
|
|
|
|
auto type = att.getStorageType();
|
|
|
|
|
return typeConvertFromMLIR(type);
|
|
|
|
|
}
|
|
|
|
|
StringRef getReturnType(const Attribute &att) {
|
|
|
|
|
auto type = att.getStorageType();
|
|
|
|
|
return typeConvertFromMLIR(type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// The logic to calculate the actual value range for a declared operand/result
|
|
|
|
|
// of an op with variadic operands/results. Note that this logic is not for
|
|
|
|
|
// general use; it assumes all variadic operands/results must have the same
|
|
|
|
@ -173,11 +209,11 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
|
|
|
|
|
: uniqueOutputLabel(getUniqueName(records)) {
|
|
|
|
|
llvm::Optional<NamespaceEmitter> namespaceEmitter;
|
|
|
|
|
if (!emitDecl) {
|
|
|
|
|
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
|
|
|
|
|
// os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
|
|
|
|
|
namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
emitTypeConstraintMethods(opDefs, os, emitDecl);
|
|
|
|
|
// emitTypeConstraintMethods(opDefs, os, emitDecl);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string StaticVerifierFunctionEmitter::getUniqueName(
|
|
|
|
@ -278,7 +314,7 @@ static std::string getArgumentName(const Operator &op, int index) {
|
|
|
|
|
|
|
|
|
|
// Returns true if we can use unwrapped value for the given `attr` in builders.
|
|
|
|
|
static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
|
|
|
|
|
return attr.getReturnType() != attr.getStorageType() &&
|
|
|
|
|
return getReturnType(attr) != getStorageType(attr) &&
|
|
|
|
|
// We need to wrap the raw value into an attribute in the builder impl
|
|
|
|
|
// so we need to make sure that the attribute specifies how to do that.
|
|
|
|
|
!attr.getConstBuilderTemplate().empty();
|
|
|
|
@ -346,12 +382,12 @@ private:
|
|
|
|
|
// Generates the build() method that takes each operand/attribute as a
|
|
|
|
|
// stand-alone parameter. The generated build() method uses first operand's
|
|
|
|
|
// type as all results' types.
|
|
|
|
|
void genUseOperandAsResultTypeSeparateParamBuilder();
|
|
|
|
|
// void genUseOperandAsResultTypeSeparateParamBuilder();
|
|
|
|
|
|
|
|
|
|
// Generates the build() method that takes all operands/attributes
|
|
|
|
|
// collectively as one parameter. The generated build() method uses first
|
|
|
|
|
// operand's type as all results' types.
|
|
|
|
|
void genUseOperandAsResultTypeCollectiveParamBuilder();
|
|
|
|
|
// void genUseOperandAsResultTypeCollectiveParamBuilder();
|
|
|
|
|
|
|
|
|
|
// Generates the build() method that takes aggregate operands/attributes
|
|
|
|
|
// parameters. This build() method uses inferred types as result types.
|
|
|
|
@ -361,11 +397,11 @@ private:
|
|
|
|
|
// Generates the build() method that takes each operand/attribute as a
|
|
|
|
|
// stand-alone parameter. The generated build() method uses first attribute's
|
|
|
|
|
// type as all result's types.
|
|
|
|
|
void genUseAttrAsResultTypeBuilder();
|
|
|
|
|
// void genUseAttrAsResultTypeBuilder();
|
|
|
|
|
|
|
|
|
|
// Generates the build() method that takes all result types collectively as
|
|
|
|
|
// one parameter. Similarly for operands and attributes.
|
|
|
|
|
void genCollectiveParamBuilder();
|
|
|
|
|
// void genCollectiveParamBuilder();
|
|
|
|
|
|
|
|
|
|
// The kind of parameter to generate for result types in builders.
|
|
|
|
|
enum class TypeParamKind {
|
|
|
|
@ -391,7 +427,7 @@ private:
|
|
|
|
|
AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
|
|
|
|
|
|
|
|
|
|
// Adds op arguments and regions into operation state for build() methods.
|
|
|
|
|
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
|
|
|
|
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,llvm::SmallVector<OpMethodParameter, 4> paramList,
|
|
|
|
|
bool isRawValueAttr = false);
|
|
|
|
|
|
|
|
|
|
// Generates canonicalizer declaration for the operation.
|
|
|
|
@ -588,24 +624,24 @@ OpEmitter::OpEmitter(const Operator &op,
|
|
|
|
|
// methods in the generated file.
|
|
|
|
|
// genOpAsmInterface();
|
|
|
|
|
genOpNameGetter();
|
|
|
|
|
genNamedOperandGetters();
|
|
|
|
|
genNamedOperandSetters();
|
|
|
|
|
genNamedResultGetters();
|
|
|
|
|
genNamedRegionGetters();
|
|
|
|
|
// genNamedOperandGetters();
|
|
|
|
|
// genNamedOperandSetters();
|
|
|
|
|
// genNamedResultGetters();
|
|
|
|
|
// genNamedRegionGetters();
|
|
|
|
|
genNamedSuccessorGetters();
|
|
|
|
|
genAttrGetters();
|
|
|
|
|
genAttrSetters();
|
|
|
|
|
genOptionalAttrRemovers();
|
|
|
|
|
// genOptionalAttrRemovers();
|
|
|
|
|
genBuilder();
|
|
|
|
|
genParser();
|
|
|
|
|
genPrinter();
|
|
|
|
|
genVerifier();
|
|
|
|
|
genCanonicalizerDecls();
|
|
|
|
|
genFolderDecls();
|
|
|
|
|
genTypeInterfaceMethods();
|
|
|
|
|
genOpInterfaceMethods();
|
|
|
|
|
generateOpFormat(op, opClass);
|
|
|
|
|
genSideEffectInterfaceMethods();
|
|
|
|
|
// genParser();
|
|
|
|
|
// genPrinter();
|
|
|
|
|
// genVerifier();
|
|
|
|
|
// genCanonicalizerDecls();
|
|
|
|
|
// genFolderDecls();
|
|
|
|
|
// genTypeInterfaceMethods();
|
|
|
|
|
// genOpInterfaceMethods();
|
|
|
|
|
// generateOpFormat(op, opClass);
|
|
|
|
|
// genSideEffectInterfaceMethods();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpEmitter::emitDecl(
|
|
|
|
@ -631,7 +667,7 @@ void OpEmitter::genAttrGetters() {
|
|
|
|
|
Dialect opDialect = op.getDialect();
|
|
|
|
|
// Emit the derived attribute body.
|
|
|
|
|
auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
|
|
|
|
|
auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
|
|
|
|
|
auto *method = opClass.addMethodAndPrune(getReturnType(attr), name);
|
|
|
|
|
if (!method)
|
|
|
|
|
return;
|
|
|
|
|
auto &body = method->body();
|
|
|
|
@ -640,7 +676,7 @@ void OpEmitter::genAttrGetters() {
|
|
|
|
|
|
|
|
|
|
// Emit with return type specified.
|
|
|
|
|
auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
|
|
|
|
|
auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
|
|
|
|
|
auto *method = opClass.addMethodAndPrune(getReturnType(attr), name);
|
|
|
|
|
auto &body = method->body();
|
|
|
|
|
body << " auto attr = " << name << "Attr();\n";
|
|
|
|
|
if (attr.hasDefaultValue()) {
|
|
|
|
@ -664,7 +700,7 @@ void OpEmitter::genAttrGetters() {
|
|
|
|
|
// the string interface for better compile time verification.
|
|
|
|
|
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
|
|
|
|
|
auto *method =
|
|
|
|
|
opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
|
|
|
|
|
opClass.addMethodAndPrune(getStorageType(attr), (name + "Attr").str());
|
|
|
|
|
if (!method)
|
|
|
|
|
return;
|
|
|
|
|
auto &body = method->body();
|
|
|
|
@ -673,7 +709,7 @@ void OpEmitter::genAttrGetters() {
|
|
|
|
|
body << "dyn_cast_or_null<";
|
|
|
|
|
else
|
|
|
|
|
body << "cast<";
|
|
|
|
|
body << attr.getStorageType() << ">();";
|
|
|
|
|
body << getStorageType(attr) << ">();";
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
for (auto &namedAttr : op.getAttributes()) {
|
|
|
|
@ -759,7 +795,7 @@ void OpEmitter::genAttrSetters() {
|
|
|
|
|
// for better compile time verification.
|
|
|
|
|
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
|
|
|
|
|
auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
|
|
|
|
|
attr.getStorageType(), "attr");
|
|
|
|
|
getStorageType(attr), "attr");
|
|
|
|
|
if (!method)
|
|
|
|
|
return;
|
|
|
|
|
auto &body = method->body();
|
|
|
|
@ -1070,256 +1106,89 @@ static bool canInferType(Operator &op) {
|
|
|
|
|
void OpEmitter::genSeparateArgParamBuilder() {
|
|
|
|
|
SmallVector<AttrParamKind, 2> attrBuilderType;
|
|
|
|
|
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
|
|
|
|
|
if (canGenerateUnwrappedBuilder(op))
|
|
|
|
|
attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
|
|
|
|
|
// if (canGenerateUnwrappedBuilder(op))
|
|
|
|
|
// attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
|
|
|
|
|
|
|
|
|
|
// Emit with separate builders with or without unwrapped attributes and/or
|
|
|
|
|
// inferring result type.
|
|
|
|
|
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
|
|
|
|
|
bool inferType) {
|
|
|
|
|
llvm::SmallVector<OpMethodParameter, 4> paramList;
|
|
|
|
|
llvm::SmallVector<OpMethodParameter, 4> paramList2;
|
|
|
|
|
llvm::SmallVector<std::string, 4> resultNames;
|
|
|
|
|
buildParamList(paramList, resultNames, paramKind, attrType);
|
|
|
|
|
buildParamList(paramList2, resultNames, paramKind, attrType);
|
|
|
|
|
|
|
|
|
|
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
|
|
|
|
|
std::move(paramList));
|
|
|
|
|
auto *m = opClass.addMethodAndPrune(
|
|
|
|
|
"::builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
|
|
|
|
|
// If the builder is redundant, skip generating the method.
|
|
|
|
|
if (!m)
|
|
|
|
|
return;
|
|
|
|
|
auto &body = m->body();
|
|
|
|
|
genCodeForAddingArgAndRegionForBuilder(
|
|
|
|
|
body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);
|
|
|
|
|
body, paramList2, attrType == AttrParamKind::UnwrappedValue);
|
|
|
|
|
|
|
|
|
|
// Push all result types to the operation state
|
|
|
|
|
//"BBBBBBBBBBBB"
|
|
|
|
|
// if (inferType) {
|
|
|
|
|
// // Generate builder that infers type too.
|
|
|
|
|
// // TODO: Subsume this with general checking if type can be
|
|
|
|
|
// // inferred automatically.
|
|
|
|
|
// // TODO: Expand to handle regions.
|
|
|
|
|
// body << formatv(R"(
|
|
|
|
|
// ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
|
|
|
|
|
// if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
|
|
|
|
|
// {1}.location, {1}.operands,
|
|
|
|
|
// {1}.attributes.getDictionary({1}.getContext()),
|
|
|
|
|
// /*regions=*/{{}, inferredReturnTypes)))
|
|
|
|
|
// {1}.addTypes(inferredReturnTypes);
|
|
|
|
|
// else
|
|
|
|
|
// ::llvm::report_fatal_error("Failed to infer result type(s).");)",
|
|
|
|
|
// opClass.getClassName(), builderOpState);
|
|
|
|
|
// return;
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
if (inferType) {
|
|
|
|
|
// Generate builder that infers type too.
|
|
|
|
|
// TODO: Subsume this with general checking if type can be
|
|
|
|
|
// inferred automatically.
|
|
|
|
|
// TODO: Expand to handle regions.
|
|
|
|
|
body << formatv(R"(
|
|
|
|
|
::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
|
|
|
|
|
if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
|
|
|
|
|
{1}.location, {1}.operands,
|
|
|
|
|
{1}.attributes.getDictionary({1}.getContext()),
|
|
|
|
|
/*regions=*/{{}, inferredReturnTypes)))
|
|
|
|
|
{1}.addTypes(inferredReturnTypes);
|
|
|
|
|
else
|
|
|
|
|
::llvm::report_fatal_error("Failed to infer result type(s).");)",
|
|
|
|
|
opClass.getClassName(), builderOpState);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
// switch (paramKind) {
|
|
|
|
|
// case TypeParamKind::None:
|
|
|
|
|
// return;
|
|
|
|
|
// case TypeParamKind::Separate:
|
|
|
|
|
// for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
|
|
|
|
// if (op.getResult(i).isOptional())
|
|
|
|
|
// body << " if (" << resultNames[i] << ")\n ";
|
|
|
|
|
// body << " " << builderOpState << ".addTypes(" << resultNames[i]
|
|
|
|
|
// << ");\n";
|
|
|
|
|
// }
|
|
|
|
|
// return;
|
|
|
|
|
// case TypeParamKind::Collective: {
|
|
|
|
|
// int numResults = op.getNumResults();
|
|
|
|
|
// int numVariadicResults = op.getNumVariableLengthResults();
|
|
|
|
|
// int numNonVariadicResults = numResults - numVariadicResults;
|
|
|
|
|
// bool hasVariadicResult = numVariadicResults != 0;
|
|
|
|
|
|
|
|
|
|
switch (paramKind) {
|
|
|
|
|
case TypeParamKind::None:
|
|
|
|
|
return;
|
|
|
|
|
case TypeParamKind::Separate:
|
|
|
|
|
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
|
|
|
|
if (op.getResult(i).isOptional())
|
|
|
|
|
body << " if (" << resultNames[i] << ")\n ";
|
|
|
|
|
body << " " << builderOpState << ".addTypes(" << resultNames[i]
|
|
|
|
|
<< ");\n";
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
case TypeParamKind::Collective: {
|
|
|
|
|
int numResults = op.getNumResults();
|
|
|
|
|
int numVariadicResults = op.getNumVariableLengthResults();
|
|
|
|
|
int numNonVariadicResults = numResults - numVariadicResults;
|
|
|
|
|
bool hasVariadicResult = numVariadicResults != 0;
|
|
|
|
|
|
|
|
|
|
// Avoid emitting "resultTypes.size() >= 0u" which is always true.
|
|
|
|
|
if (!(hasVariadicResult && numNonVariadicResults == 0))
|
|
|
|
|
body << " "
|
|
|
|
|
<< "assert(resultTypes.size() "
|
|
|
|
|
<< (hasVariadicResult ? ">=" : "==") << " "
|
|
|
|
|
<< numNonVariadicResults
|
|
|
|
|
<< "u && \"mismatched number of results\");\n";
|
|
|
|
|
body << " " << builderOpState << ".addTypes(resultTypes);\n";
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
llvm_unreachable("unhandled TypeParamKind");
|
|
|
|
|
// // Avoid emitting "resultTypes.size() >= 0u" which is always true.
|
|
|
|
|
// if (!(hasVariadicResult && numNonVariadicResults == 0))
|
|
|
|
|
// body << " "
|
|
|
|
|
// << "assert(resultTypes.size() "
|
|
|
|
|
// << (hasVariadicResult ? ">=" : "==") << " "
|
|
|
|
|
// << numNonVariadicResults
|
|
|
|
|
// << "u && \"mismatched number of results\");\n";
|
|
|
|
|
// body << " " << builderOpState << ".addTypes(resultTypes);\n";
|
|
|
|
|
// }
|
|
|
|
|
// return;
|
|
|
|
|
// }
|
|
|
|
|
// llvm_unreachable("unhandled TypeParamKind");
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Some of the build methods generated here may be ambiguous, but TableGen's
|
|
|
|
|
// ambiguous function detection will elide those ones.
|
|
|
|
|
for (auto attrType : attrBuilderType) {
|
|
|
|
|
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
|
|
|
|
|
if (canInferType(op))
|
|
|
|
|
emit(attrType, TypeParamKind::None, /*inferType=*/true);
|
|
|
|
|
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
|
|
|
|
|
// if (canInferType(op))
|
|
|
|
|
// emit(attrType, TypeParamKind::None, /*inferType=*/true);
|
|
|
|
|
// emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
|
|
|
|
|
int numResults = op.getNumResults();
|
|
|
|
|
|
|
|
|
|
// Signature
|
|
|
|
|
llvm::SmallVector<OpMethodParameter, 4> paramList;
|
|
|
|
|
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
|
|
|
|
|
paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
|
|
|
|
paramList.emplace_back("::mlir::ValueRange", "operands");
|
|
|
|
|
// Provide default value for `attributes` when its the last parameter
|
|
|
|
|
StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
|
|
|
|
|
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
|
|
|
|
|
"attributes", attributesDefaultValue);
|
|
|
|
|
if (op.getNumVariadicRegions())
|
|
|
|
|
paramList.emplace_back("unsigned", "numRegions");
|
|
|
|
|
|
|
|
|
|
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
|
|
|
|
|
std::move(paramList));
|
|
|
|
|
// If the builder is redundant, skip generating the method
|
|
|
|
|
if (!m)
|
|
|
|
|
return;
|
|
|
|
|
auto &body = m->body();
|
|
|
|
|
|
|
|
|
|
// Operands
|
|
|
|
|
body << " " << builderOpState << ".addOperands(operands);\n";
|
|
|
|
|
|
|
|
|
|
// Attributes
|
|
|
|
|
body << " " << builderOpState << ".addAttributes(attributes);\n";
|
|
|
|
|
|
|
|
|
|
// Create the correct number of regions
|
|
|
|
|
if (int numRegions = op.getNumRegions()) {
|
|
|
|
|
body << llvm::formatv(
|
|
|
|
|
" for (unsigned i = 0; i != {0}; ++i)\n",
|
|
|
|
|
(op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
|
|
|
|
|
body << " (void)" << builderOpState << ".addRegion();\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Result types
|
|
|
|
|
SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
|
|
|
|
|
body << " " << builderOpState << ".addTypes({"
|
|
|
|
|
<< llvm::join(resultTypes, ", ") << "});\n\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
|
|
|
|
|
// TODO: Expand to support regions.
|
|
|
|
|
SmallVector<OpMethodParameter, 4> paramList;
|
|
|
|
|
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
|
|
|
|
|
paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
|
|
|
|
paramList.emplace_back("::mlir::ValueRange", "operands");
|
|
|
|
|
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
|
|
|
|
|
"attributes", "{}");
|
|
|
|
|
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
|
|
|
|
|
std::move(paramList));
|
|
|
|
|
// If the builder is redundant, skip generating the method
|
|
|
|
|
if (!m)
|
|
|
|
|
return;
|
|
|
|
|
auto &body = m->body();
|
|
|
|
|
|
|
|
|
|
int numResults = op.getNumResults();
|
|
|
|
|
int numVariadicResults = op.getNumVariableLengthResults();
|
|
|
|
|
int numNonVariadicResults = numResults - numVariadicResults;
|
|
|
|
|
|
|
|
|
|
int numOperands = op.getNumOperands();
|
|
|
|
|
int numVariadicOperands = op.getNumVariableLengthOperands();
|
|
|
|
|
int numNonVariadicOperands = numOperands - numVariadicOperands;
|
|
|
|
|
|
|
|
|
|
// Operands
|
|
|
|
|
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
|
|
|
|
|
body << " assert(operands.size()"
|
|
|
|
|
<< (numVariadicOperands != 0 ? " >= " : " == ")
|
|
|
|
|
<< numNonVariadicOperands
|
|
|
|
|
<< "u && \"mismatched number of parameters\");\n";
|
|
|
|
|
body << " " << builderOpState << ".addOperands(operands);\n";
|
|
|
|
|
body << " " << builderOpState << ".addAttributes(attributes);\n";
|
|
|
|
|
|
|
|
|
|
// Create the correct number of regions
|
|
|
|
|
if (int numRegions = op.getNumRegions()) {
|
|
|
|
|
body << llvm::formatv(
|
|
|
|
|
" for (unsigned i = 0; i != {0}; ++i)\n",
|
|
|
|
|
(op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
|
|
|
|
|
body << " (void)" << builderOpState << ".addRegion();\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Result types
|
|
|
|
|
body << formatv(R"(
|
|
|
|
|
::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
|
|
|
|
|
if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
|
|
|
|
|
{1}.location, operands,
|
|
|
|
|
{1}.attributes.getDictionary({1}.getContext()),
|
|
|
|
|
/*regions=*/{{}, inferredReturnTypes))) {{)",
|
|
|
|
|
opClass.getClassName(), builderOpState);
|
|
|
|
|
if (numVariadicResults == 0 || numNonVariadicResults != 0)
|
|
|
|
|
body << " assert(inferredReturnTypes.size()"
|
|
|
|
|
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
|
|
|
|
|
<< "u && \"mismatched number of return types\");\n";
|
|
|
|
|
body << " " << builderOpState << ".addTypes(inferredReturnTypes);";
|
|
|
|
|
|
|
|
|
|
body << formatv(R"(
|
|
|
|
|
} else
|
|
|
|
|
::llvm::report_fatal_error("Failed to infer result type(s).");)",
|
|
|
|
|
opClass.getClassName(), builderOpState);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
|
|
|
|
|
llvm::SmallVector<OpMethodParameter, 4> paramList;
|
|
|
|
|
llvm::SmallVector<std::string, 4> resultNames;
|
|
|
|
|
buildParamList(paramList, resultNames, TypeParamKind::None);
|
|
|
|
|
|
|
|
|
|
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
|
|
|
|
|
std::move(paramList));
|
|
|
|
|
// If the builder is redundant, skip generating the method
|
|
|
|
|
if (!m)
|
|
|
|
|
return;
|
|
|
|
|
auto &body = m->body();
|
|
|
|
|
genCodeForAddingArgAndRegionForBuilder(body);
|
|
|
|
|
|
|
|
|
|
auto numResults = op.getNumResults();
|
|
|
|
|
if (numResults == 0)
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
// Push all result types to the operation state
|
|
|
|
|
const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
|
|
|
|
|
std::string resultType =
|
|
|
|
|
formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
|
|
|
|
|
body << " " << builderOpState << ".addTypes({" << resultType;
|
|
|
|
|
for (int i = 1; i != numResults; ++i)
|
|
|
|
|
body << ", " << resultType;
|
|
|
|
|
body << "});\n\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpEmitter::genUseAttrAsResultTypeBuilder() {
|
|
|
|
|
SmallVector<OpMethodParameter, 4> paramList;
|
|
|
|
|
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
|
|
|
|
|
paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
|
|
|
|
paramList.emplace_back("::mlir::ValueRange", "operands");
|
|
|
|
|
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
|
|
|
|
|
"attributes", "{}");
|
|
|
|
|
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
|
|
|
|
|
std::move(paramList));
|
|
|
|
|
// If the builder is redundant, skip generating the method
|
|
|
|
|
if (!m)
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
|
|
auto &body = m->body();
|
|
|
|
|
|
|
|
|
|
// Push all result types to the operation state
|
|
|
|
|
std::string resultType;
|
|
|
|
|
const auto &namedAttr = op.getAttribute(0);
|
|
|
|
|
|
|
|
|
|
body << " for (auto attr : attributes) {\n";
|
|
|
|
|
body << " if (attr.first != \"" << namedAttr.name << "\") continue;\n";
|
|
|
|
|
if (namedAttr.attr.isTypeAttr()) {
|
|
|
|
|
resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()";
|
|
|
|
|
} else {
|
|
|
|
|
resultType = "attr.second.getType()";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Operands
|
|
|
|
|
body << " " << builderOpState << ".addOperands(operands);\n";
|
|
|
|
|
|
|
|
|
|
// Attributes
|
|
|
|
|
body << " " << builderOpState << ".addAttributes(attributes);\n";
|
|
|
|
|
|
|
|
|
|
// Result types
|
|
|
|
|
SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
|
|
|
|
|
body << " " << builderOpState << ".addTypes({"
|
|
|
|
|
<< llvm::join(resultTypes, ", ") << "});\n";
|
|
|
|
|
body << " }\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Returns a signature of the builder. Updates the context `fctx` to enable
|
|
|
|
|
/// replacement of $_builder and $_state in the body.
|
|
|
|
|
static std::string getBuilderSignature(const Builder &builder) {
|
|
|
|
@ -1352,21 +1221,21 @@ static std::string getBuilderSignature(const Builder &builder) {
|
|
|
|
|
|
|
|
|
|
void OpEmitter::genBuilder() {
|
|
|
|
|
// Handle custom builders if provided.
|
|
|
|
|
for (const Builder &builder : op.getBuilders()) {
|
|
|
|
|
std::string paramStr = getBuilderSignature(builder);
|
|
|
|
|
// for (const Builder &builder : op.getBuilders()) {
|
|
|
|
|
// std::string paramStr = getBuilderSignature(builder);
|
|
|
|
|
|
|
|
|
|
Optional<StringRef> body = builder.getBody();
|
|
|
|
|
OpMethod::Property properties =
|
|
|
|
|
body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
|
|
|
|
|
auto *method =
|
|
|
|
|
opClass.addMethodAndPrune("void", "build", properties, paramStr);
|
|
|
|
|
// Optional<StringRef> body = builder.getBody();
|
|
|
|
|
// OpMethod::Property properties =
|
|
|
|
|
// body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
|
|
|
|
|
// auto *method =
|
|
|
|
|
// opClass.addMethodAndPrune("void", "build", properties, paramStr);
|
|
|
|
|
|
|
|
|
|
FmtContext fctx;
|
|
|
|
|
fctx.withBuilder(odsBuilder);
|
|
|
|
|
fctx.addSubst("_state", builderOpState);
|
|
|
|
|
if (body)
|
|
|
|
|
method->body() << tgfmt(*body, &fctx);
|
|
|
|
|
}
|
|
|
|
|
// FmtContext fctx;
|
|
|
|
|
// fctx.withBuilder(odsBuilder);
|
|
|
|
|
// fctx.addSubst("_state", builderOpState);
|
|
|
|
|
// if (body)
|
|
|
|
|
// method->body() << tgfmt(*body, &fctx);
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
// Generate default builders that requires all result type, operands, and
|
|
|
|
|
// attributes as parameters.
|
|
|
|
@ -1378,78 +1247,18 @@ void OpEmitter::genBuilder() {
|
|
|
|
|
genSeparateArgParamBuilder();
|
|
|
|
|
// 2. one having an aggregated parameter for all result types / operands /
|
|
|
|
|
// attributes, and
|
|
|
|
|
genCollectiveParamBuilder();
|
|
|
|
|
// 3. one having a stand-alone parameter for each operand and attribute,
|
|
|
|
|
// use the first operand or attribute's type as all result types
|
|
|
|
|
// to facilitate different call patterns.
|
|
|
|
|
if (op.getNumVariableLengthResults() == 0) {
|
|
|
|
|
if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
|
|
|
|
|
genUseOperandAsResultTypeSeparateParamBuilder();
|
|
|
|
|
genUseOperandAsResultTypeCollectiveParamBuilder();
|
|
|
|
|
}
|
|
|
|
|
if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
|
|
|
|
|
genUseAttrAsResultTypeBuilder();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpEmitter::genCollectiveParamBuilder() {
|
|
|
|
|
int numResults = op.getNumResults();
|
|
|
|
|
int numVariadicResults = op.getNumVariableLengthResults();
|
|
|
|
|
int numNonVariadicResults = numResults - numVariadicResults;
|
|
|
|
|
|
|
|
|
|
int numOperands = op.getNumOperands();
|
|
|
|
|
int numVariadicOperands = op.getNumVariableLengthOperands();
|
|
|
|
|
int numNonVariadicOperands = numOperands - numVariadicOperands;
|
|
|
|
|
|
|
|
|
|
SmallVector<OpMethodParameter, 4> paramList;
|
|
|
|
|
paramList.emplace_back("::mlir::OpBuilder &", "");
|
|
|
|
|
paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
|
|
|
|
paramList.emplace_back("::mlir::TypeRange", "resultTypes");
|
|
|
|
|
paramList.emplace_back("::mlir::ValueRange", "operands");
|
|
|
|
|
// Provide default value for `attributes` when its the last parameter
|
|
|
|
|
StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
|
|
|
|
|
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
|
|
|
|
|
"attributes", attributesDefaultValue);
|
|
|
|
|
if (op.getNumVariadicRegions())
|
|
|
|
|
paramList.emplace_back("unsigned", "numRegions");
|
|
|
|
|
|
|
|
|
|
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
|
|
|
|
|
std::move(paramList));
|
|
|
|
|
// If the builder is redundant, skip generating the method
|
|
|
|
|
if (!m)
|
|
|
|
|
return;
|
|
|
|
|
auto &body = m->body();
|
|
|
|
|
|
|
|
|
|
// Operands
|
|
|
|
|
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
|
|
|
|
|
body << " assert(operands.size()"
|
|
|
|
|
<< (numVariadicOperands != 0 ? " >= " : " == ")
|
|
|
|
|
<< numNonVariadicOperands
|
|
|
|
|
<< "u && \"mismatched number of parameters\");\n";
|
|
|
|
|
body << " " << builderOpState << ".addOperands(operands);\n";
|
|
|
|
|
|
|
|
|
|
// Attributes
|
|
|
|
|
body << " " << builderOpState << ".addAttributes(attributes);\n";
|
|
|
|
|
|
|
|
|
|
// Create the correct number of regions
|
|
|
|
|
if (int numRegions = op.getNumRegions()) {
|
|
|
|
|
body << llvm::formatv(
|
|
|
|
|
" for (unsigned i = 0; i != {0}; ++i)\n",
|
|
|
|
|
(op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
|
|
|
|
|
body << " (void)" << builderOpState << ".addRegion();\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Result types
|
|
|
|
|
if (numVariadicResults == 0 || numNonVariadicResults != 0)
|
|
|
|
|
body << " assert(resultTypes.size()"
|
|
|
|
|
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
|
|
|
|
|
<< "u && \"mismatched number of return types\");\n";
|
|
|
|
|
body << " " << builderOpState << ".addTypes(resultTypes);\n";
|
|
|
|
|
|
|
|
|
|
// Generate builder that infers type too.
|
|
|
|
|
// TODO: Expand to handle regions and successors.
|
|
|
|
|
if (canInferType(op) && op.getNumSuccessors() == 0)
|
|
|
|
|
genInferredTypeCollectiveParamBuilder();
|
|
|
|
|
// genCollectiveParamBuilder();
|
|
|
|
|
// // 3. one having a stand-alone parameter for each operand and attribute,
|
|
|
|
|
// // use the first operand or attribute's type as all result types
|
|
|
|
|
// // to facilitate different call patterns.
|
|
|
|
|
// if (op.getNumVariableLengthResults() == 0) {
|
|
|
|
|
// if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
|
|
|
|
|
// genUseOperandAsResultTypeSeparateParamBuilder();
|
|
|
|
|
// genUseOperandAsResultTypeCollectiveParamBuilder();
|
|
|
|
|
// }
|
|
|
|
|
// if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
|
|
|
|
|
// genUseAttrAsResultTypeBuilder();
|
|
|
|
|
// }
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|
|
|
@ -1460,8 +1269,9 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|
|
|
|
auto numResults = op.getNumResults();
|
|
|
|
|
resultTypeNames.reserve(numResults);
|
|
|
|
|
|
|
|
|
|
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
|
|
|
|
|
paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
|
|
|
|
// paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
|
|
|
|
|
paramList.emplace_back("::builder::Builder &", "builder");
|
|
|
|
|
// paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
|
|
|
|
|
|
|
|
|
switch (typeParamKind) {
|
|
|
|
|
case TypeParamKind::None:
|
|
|
|
@ -1475,7 +1285,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|
|
|
|
resultName = std::string(formatv("resultType{0}", i));
|
|
|
|
|
|
|
|
|
|
StringRef type =
|
|
|
|
|
result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
|
|
|
|
|
result.isVariadic() ? "std::vector<::builder::Type>" : "::builder::Type";
|
|
|
|
|
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
|
|
|
|
|
if (result.isOptional())
|
|
|
|
|
properties = OpMethodParameter::PP_Optional;
|
|
|
|
@ -1512,7 +1322,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|
|
|
|
// for APFloat.
|
|
|
|
|
// TODO: Adjust the 'returnType' field of such attributes
|
|
|
|
|
// to support them.
|
|
|
|
|
StringRef retType = namedAttr->attr.getReturnType();
|
|
|
|
|
StringRef retType = getReturnType(namedAttr->attr);
|
|
|
|
|
if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
|
@ -1525,7 +1335,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|
|
|
|
if (argument.is<tblgen::NamedTypeConstraint *>()) {
|
|
|
|
|
const auto &operand = op.getOperand(numOperands);
|
|
|
|
|
StringRef type =
|
|
|
|
|
operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value";
|
|
|
|
|
operand.isVariadic() ? "std::vector<::builder::Op>" : "::builder::Op";
|
|
|
|
|
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
|
|
|
|
|
if (operand.isOptional())
|
|
|
|
|
properties = OpMethodParameter::PP_Optional;
|
|
|
|
@ -1544,13 +1354,13 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|
|
|
|
StringRef type;
|
|
|
|
|
switch (attrParamKind) {
|
|
|
|
|
case AttrParamKind::WrappedAttr:
|
|
|
|
|
type = attr.getStorageType();
|
|
|
|
|
type = getStorageType(attr);
|
|
|
|
|
break;
|
|
|
|
|
case AttrParamKind::UnwrappedValue:
|
|
|
|
|
if (canUseUnwrappedRawValue(attr))
|
|
|
|
|
type = attr.getReturnType();
|
|
|
|
|
type = getReturnType(attr);
|
|
|
|
|
else
|
|
|
|
|
type = attr.getStorageType();
|
|
|
|
|
type = getStorageType(attr);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1558,7 +1368,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|
|
|
|
// Attach default value if requested and possible.
|
|
|
|
|
if (attrParamKind == AttrParamKind::UnwrappedValue &&
|
|
|
|
|
i >= defaultValuedAttrStartIndex) {
|
|
|
|
|
bool isString = attr.getReturnType() == "::llvm::StringRef";
|
|
|
|
|
bool isString = getReturnType(attr) == "::llvm::StringRef";
|
|
|
|
|
if (isString)
|
|
|
|
|
defaultValue.append("\"");
|
|
|
|
|
defaultValue += attr.getDefaultValue();
|
|
|
|
@ -1584,84 +1394,117 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|
|
|
|
llvm::formatv("{0}Count", region.name).str());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
|
|
|
|
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
|
|
|
|
|
OpMethodBody &body, llvm::SmallVector<OpMethodParameter, 4> paramList,
|
|
|
|
|
bool isRawValueAttr) {
|
|
|
|
|
// Push all operands to the result.
|
|
|
|
|
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
|
|
|
|
std::string argName = getArgumentName(op, i);
|
|
|
|
|
if (op.getOperand(i).isOptional())
|
|
|
|
|
body << " if (" << argName << ")\n ";
|
|
|
|
|
body << " " << builderOpState << ".addOperands(" << argName << ");\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If the operation has the operand segment size attribute, add it here.
|
|
|
|
|
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
|
|
|
|
|
body << " " << builderOpState
|
|
|
|
|
<< ".addAttribute(\"operand_segment_sizes\", "
|
|
|
|
|
"odsBuilder.getI32VectorAttr({";
|
|
|
|
|
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
|
|
|
|
|
if (op.getOperand(i).isOptional())
|
|
|
|
|
body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
|
|
|
|
|
else if (op.getOperand(i).isVariadic())
|
|
|
|
|
body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
|
|
|
|
|
else
|
|
|
|
|
body << "1";
|
|
|
|
|
body << "// AAAAAA \n";
|
|
|
|
|
// for(auto p : paramList)
|
|
|
|
|
// {
|
|
|
|
|
// body <<"==== type:"<< p.getType() << " name:"<< p.getName() << "\n";
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
body << " auto builder = builder.GetImpl();\n";
|
|
|
|
|
body << " auto loc = builder->GetLoc();\n";
|
|
|
|
|
body << " auto opBuilder = builder->GetBuilder();\n";
|
|
|
|
|
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
|
|
|
|
|
<< " currentOp =\n";
|
|
|
|
|
body << " opBuilder.create<mlir::" << op.getDialectName()
|
|
|
|
|
<< "::" << op.getDialectName() << ">(\n";
|
|
|
|
|
if (paramList.size() > 1) {
|
|
|
|
|
body << " loc,\n";
|
|
|
|
|
std::for_each(paramList.begin() + 1, paramList.end() - 1,
|
|
|
|
|
[&](OpMethodParameter &p) {
|
|
|
|
|
body << " " << p.getName() << ",\n";
|
|
|
|
|
});
|
|
|
|
|
body << "}));\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Push all attributes to the result.
|
|
|
|
|
for (const auto &namedAttr : op.getAttributes()) {
|
|
|
|
|
auto &attr = namedAttr.attr;
|
|
|
|
|
if (!attr.isDerivedAttr()) {
|
|
|
|
|
bool emitNotNullCheck = attr.isOptional();
|
|
|
|
|
if (emitNotNullCheck) {
|
|
|
|
|
body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
|
|
|
|
|
}
|
|
|
|
|
if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
|
|
|
|
|
// If this is a raw value, then we need to wrap it in an Attribute
|
|
|
|
|
// instance.
|
|
|
|
|
FmtContext fctx;
|
|
|
|
|
fctx.withBuilder("odsBuilder");
|
|
|
|
|
|
|
|
|
|
std::string builderTemplate =
|
|
|
|
|
std::string(attr.getConstBuilderTemplate());
|
|
|
|
|
|
|
|
|
|
// For StringAttr, its constant builder call will wrap the input in
|
|
|
|
|
// quotes, which is correct for normal string literals, but incorrect
|
|
|
|
|
// here given we use function arguments. So we need to strip the
|
|
|
|
|
// wrapping quotes.
|
|
|
|
|
if (StringRef(builderTemplate).contains("\"$0\""))
|
|
|
|
|
builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
|
|
|
|
|
|
|
|
|
|
std::string value =
|
|
|
|
|
std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
|
|
|
|
|
body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
|
|
|
|
|
namedAttr.name, value);
|
|
|
|
|
body << " " << paramList.back().getName() << "\n";
|
|
|
|
|
} else {
|
|
|
|
|
body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
|
|
|
|
|
namedAttr.name);
|
|
|
|
|
}
|
|
|
|
|
if (emitNotNullCheck) {
|
|
|
|
|
body << " }\n";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
body << " loc\n";
|
|
|
|
|
}
|
|
|
|
|
body << " );\n";
|
|
|
|
|
body << " builder::" << op.getCppClassName() << " builderOp;\n";
|
|
|
|
|
body << " auto opImpl = builderOp.GetImpl();\n";
|
|
|
|
|
body << " opImpl.SetOperation(currentOp.getOperation());\n";
|
|
|
|
|
body << " return builderOp;\n";
|
|
|
|
|
|
|
|
|
|
// Create the correct number of regions.
|
|
|
|
|
for (const NamedRegion ®ion : op.getRegions()) {
|
|
|
|
|
if (region.isVariadic())
|
|
|
|
|
body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
|
|
|
|
|
region.name);
|
|
|
|
|
|
|
|
|
|
body << " (void)" << builderOpState << ".addRegion();\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Push all successors to the result.
|
|
|
|
|
for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
|
|
|
|
|
body << formatv(" {0}.addSuccessors({1});\n", builderOpState,
|
|
|
|
|
namedSuccessor.name);
|
|
|
|
|
}
|
|
|
|
|
// // Push all operands to the result.
|
|
|
|
|
// for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
|
|
|
|
// std::string argName = getArgumentName(op, i);
|
|
|
|
|
// if (op.getOperand(i).isOptional())
|
|
|
|
|
// body << " if (" << argName << ")\n ";
|
|
|
|
|
// body << " " << builderOpState << ".addOperands(" << argName << ");\n";
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
// // If the operation has the operand segment size attribute, add it here.
|
|
|
|
|
// if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
|
|
|
|
|
// body << " " << builderOpState
|
|
|
|
|
// << ".addAttribute(\"operand_segment_sizes\", "
|
|
|
|
|
// "odsBuilder.getI32VectorAttr({";
|
|
|
|
|
// interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
|
|
|
|
|
// if (op.getOperand(i).isOptional())
|
|
|
|
|
// body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
|
|
|
|
|
// else if (op.getOperand(i).isVariadic())
|
|
|
|
|
// body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
|
|
|
|
|
// else
|
|
|
|
|
// body << "1";
|
|
|
|
|
// });
|
|
|
|
|
// body << "}));\n";
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
// // Push all attributes to the result.
|
|
|
|
|
// for (const auto &namedAttr : op.getAttributes()) {
|
|
|
|
|
// auto &attr = namedAttr.attr;
|
|
|
|
|
// if (!attr.isDerivedAttr()) {
|
|
|
|
|
// bool emitNotNullCheck = attr.isOptional();
|
|
|
|
|
// if (emitNotNullCheck) {
|
|
|
|
|
// body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
|
|
|
|
|
// }
|
|
|
|
|
// if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
|
|
|
|
|
// // If this is a raw value, then we need to wrap it in an Attribute
|
|
|
|
|
// // instance.
|
|
|
|
|
// FmtContext fctx;
|
|
|
|
|
// fctx.withBuilder("odsBuilder");
|
|
|
|
|
|
|
|
|
|
// std::string builderTemplate =
|
|
|
|
|
// std::string(attr.getConstBuilderTemplate());
|
|
|
|
|
|
|
|
|
|
// // For StringAttr, its constant builder call will wrap the input in
|
|
|
|
|
// // quotes, which is correct for normal string literals, but incorrect
|
|
|
|
|
// // here given we use function arguments. So we need to strip the
|
|
|
|
|
// // wrapping quotes.
|
|
|
|
|
// if (StringRef(builderTemplate).contains("\"$0\""))
|
|
|
|
|
// builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
|
|
|
|
|
|
|
|
|
|
// std::string value =
|
|
|
|
|
// std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
|
|
|
|
|
// body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
|
|
|
|
|
// namedAttr.name, value);
|
|
|
|
|
// } else {
|
|
|
|
|
// body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
|
|
|
|
|
// namedAttr.name);
|
|
|
|
|
// }
|
|
|
|
|
// if (emitNotNullCheck) {
|
|
|
|
|
// body << " }\n";
|
|
|
|
|
// }
|
|
|
|
|
// }
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
// // Create the correct number of regions.
|
|
|
|
|
// for (const NamedRegion ®ion : op.getRegions()) {
|
|
|
|
|
// if (region.isVariadic())
|
|
|
|
|
// body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
|
|
|
|
|
// region.name);
|
|
|
|
|
|
|
|
|
|
// body << " (void)" << builderOpState << ".addRegion();\n";
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
// // Push all successors to the result.
|
|
|
|
|
// for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
|
|
|
|
|
// body << formatv(" {0}.addSuccessors({1});\n", builderOpState,
|
|
|
|
|
// namedSuccessor.name);
|
|
|
|
|
// }
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpEmitter::genCanonicalizerDecls() {
|
|
|
|
@ -2305,15 +2148,15 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
|
|
|
|
|
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
|
|
|
|
|
|
|
|
|
|
auto emitAttr = [&](StringRef name, Attribute attr) {
|
|
|
|
|
auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body();
|
|
|
|
|
auto &body = adaptor.addMethodAndPrune(getStorageType(attr), name)->body();
|
|
|
|
|
body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
|
|
|
|
|
<< "\n " << attr.getStorageType() << " attr = "
|
|
|
|
|
<< "\n " << getStorageType(attr) << " attr = "
|
|
|
|
|
<< "odsAttrs.get(\"" << name << "\").";
|
|
|
|
|
if (attr.hasDefaultValue() || attr.isOptional())
|
|
|
|
|
body << "dyn_cast_or_null<";
|
|
|
|
|
else
|
|
|
|
|
body << "cast<";
|
|
|
|
|
body << attr.getStorageType() << ">();\n";
|
|
|
|
|
body << getStorageType(attr) << ">();\n";
|
|
|
|
|
|
|
|
|
|
if (attr.hasDefaultValue()) {
|
|
|
|
|
// Use the default value if attribute is not set.
|
|
|
|
@ -2442,11 +2285,11 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
|
|
|
|
|
Operator op(*def);
|
|
|
|
|
NamespaceEmitter emitter(os, op.getCppNamespace());
|
|
|
|
|
if (emitDecl) {
|
|
|
|
|
os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
|
|
|
|
|
// os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
|
|
|
|
|
// OpOperandAdaptorEmitter::emitDecl(op, os);
|
|
|
|
|
OpEmitter::emitDecl(op, os, staticVerifierEmitter);
|
|
|
|
|
} else {
|
|
|
|
|
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
|
|
|
|
|
// os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
|
|
|
|
|
// OpOperandAdaptorEmitter::emitDef(op, os);
|
|
|
|
|
OpEmitter::emitDef(op, os, staticVerifierEmitter);
|
|
|
|
|
}
|
|
|
|
@ -2456,12 +2299,18 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
|
|
|
|
|
// Emits a comma-separated list of the ops.
|
|
|
|
|
static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
|
|
|
|
|
IfDefScope scope("GET_OP_LIST", os);
|
|
|
|
|
|
|
|
|
|
interleave(
|
|
|
|
|
// TODO: We are constructing the Operator wrapper instance just for
|
|
|
|
|
// getting it's qualified class name here. Reduce the overhead by having a
|
|
|
|
|
// lightweight version of Operator class just for that purpose.
|
|
|
|
|
defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
|
|
|
|
|
defs,
|
|
|
|
|
[&os](Record *def) {
|
|
|
|
|
SmallVector<StringRef, 4> namespaces;
|
|
|
|
|
std::string className = Operator(def).getQualCppClassName();
|
|
|
|
|
llvm::SplitString(StringRef(className), namespaces, StringRef("::"));
|
|
|
|
|
if (namespaces.begin() != namespaces.end())
|
|
|
|
|
os << "::builder::mhlo::" << namespaces.back().str();
|
|
|
|
|
},
|
|
|
|
|
[&os]() { os << ",\n"; });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|