modify code gen of build function

This commit is contained in:
colin.liang 2021-08-04 20:24:07 +08:00
parent 898eb732de
commit 9d5166684c
13 changed files with 437 additions and 430 deletions

2
BUILD
View File

@ -133,7 +133,7 @@ cc_binary(
deps = [
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
# "@llvm-project//mlir:TableGen",
"@llvm-project//mlir:IR",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",

View File

@ -0,0 +1,10 @@
#ifndef BUILDER_ARRAY_
#define BUILDER_ARRAY_
#include "iostream"
namespace builder {
class Array {}
} // namespace builder
#endif

View File

@ -0,0 +1,17 @@
#include "Builder.h"
#include "BuilderImpl.h"
#include "llvm/Support/Casting.h"
// #include "mlir/Dialect/StandardOps/Ops.h"
// #include "mlir/IR/Attributes.h"
// #include "mlir/IR/Operation.h"
// #include "mlir/IR/StandardTypes.h"
// #include "mlir/IR/Types.h"
// #include "mlir/IR/Value.h"
namespace builder {
Builder::Builder() : _impl(std::make_shared<Impl>()) {}
void Builder::DumpModule() {}
} // namespace builder

View File

@ -0,0 +1,21 @@
#ifndef BUILDER_BUILDER_
#define BUILDER_BUILDER_
#include <memory>
namespace builder {
class Builder {
public:
class Impl;
Builder();
void DumpModule();
std::shared_ptr<Impl> GetImpl() { return _impl; }
private:
std::shared_ptr<Impl> _impl;
};
} // namespace builder
#endif

View File

@ -29,6 +29,8 @@
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "iostream"
#define DEBUG_TYPE "mlir-tblgen-opdefgen"
using namespace llvm;
@ -40,6 +42,40 @@ static const char *const generatedArgName = "odsArg";
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";
static const std::map<std::string, std::string> typeMapMLIR = {
{"::mlir::StringAttr", "std::string"},
{"::mlir::IntegerAttr", "int"},
{"::mlir::DenseIntElementsAttr", "std::vector<int>"},
{"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
{"::mlir::FloatAttr", "float"},
{"::mlir::BoolAttr", "bool"},
{"::mlir::ElementsAttr", "::builder::Array"},
{"::mlir::DenseElementsAttr", "::builder::Tensor"},
// current only support string array.
{"::mlir::ArrayAttr", "std::vector<std::string>"},
{"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
{"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
{"::mlir::mhlo::GatherDimensionNumbers",
"::builder::GatherDimensionNumbers"},
{"::mlir::mhlo::ScatterDimensionNumbers",
"::builder::ScatterDimensionNumbers"},
};
StringRef typeConvertFromMLIR(StringRef type) {
auto re = typeMapMLIR.find(type.str());
if (re != typeMapMLIR.end()) return StringRef(re->second);
return type;
}
StringRef getStorageType(const Attribute &att) {
auto type = att.getStorageType();
return typeConvertFromMLIR(type);
}
StringRef getReturnType(const Attribute &att) {
auto type = att.getStorageType();
return typeConvertFromMLIR(type);
}
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for
// general use; it assumes all variadic operands/results must have the same
@ -173,11 +209,11 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
: uniqueOutputLabel(getUniqueName(records)) {
llvm::Optional<NamespaceEmitter> namespaceEmitter;
if (!emitDecl) {
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
// os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
}
emitTypeConstraintMethods(opDefs, os, emitDecl);
// emitTypeConstraintMethods(opDefs, os, emitDecl);
}
std::string StaticVerifierFunctionEmitter::getUniqueName(
@ -278,7 +314,7 @@ static std::string getArgumentName(const Operator &op, int index) {
// Returns true if we can use unwrapped value for the given `attr` in builders.
static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
return attr.getReturnType() != attr.getStorageType() &&
return getReturnType(attr) != getStorageType(attr) &&
// We need to wrap the raw value into an attribute in the builder impl
// so we need to make sure that the attribute specifies how to do that.
!attr.getConstBuilderTemplate().empty();
@ -346,12 +382,12 @@ private:
// Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. The generated build() method uses first operand's
// type as all results' types.
void genUseOperandAsResultTypeSeparateParamBuilder();
// void genUseOperandAsResultTypeSeparateParamBuilder();
// Generates the build() method that takes all operands/attributes
// collectively as one parameter. The generated build() method uses first
// operand's type as all results' types.
void genUseOperandAsResultTypeCollectiveParamBuilder();
// void genUseOperandAsResultTypeCollectiveParamBuilder();
// Generates the build() method that takes aggregate operands/attributes
// parameters. This build() method uses inferred types as result types.
@ -361,11 +397,11 @@ private:
// Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. The generated build() method uses first attribute's
// type as all result's types.
void genUseAttrAsResultTypeBuilder();
// void genUseAttrAsResultTypeBuilder();
// Generates the build() method that takes all result types collectively as
// one parameter. Similarly for operands and attributes.
void genCollectiveParamBuilder();
// void genCollectiveParamBuilder();
// The kind of parameter to generate for result types in builders.
enum class TypeParamKind {
@ -391,7 +427,7 @@ private:
AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
// Adds op arguments and regions into operation state for build() methods.
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,llvm::SmallVector<OpMethodParameter, 4> paramList,
bool isRawValueAttr = false);
// Generates canonicalizer declaration for the operation.
@ -588,24 +624,24 @@ OpEmitter::OpEmitter(const Operator &op,
// methods in the generated file.
// genOpAsmInterface();
genOpNameGetter();
genNamedOperandGetters();
genNamedOperandSetters();
genNamedResultGetters();
genNamedRegionGetters();
// genNamedOperandGetters();
// genNamedOperandSetters();
// genNamedResultGetters();
// genNamedRegionGetters();
genNamedSuccessorGetters();
genAttrGetters();
genAttrSetters();
genOptionalAttrRemovers();
// genOptionalAttrRemovers();
genBuilder();
genParser();
genPrinter();
genVerifier();
genCanonicalizerDecls();
genFolderDecls();
genTypeInterfaceMethods();
genOpInterfaceMethods();
generateOpFormat(op, opClass);
genSideEffectInterfaceMethods();
// genParser();
// genPrinter();
// genVerifier();
// genCanonicalizerDecls();
// genFolderDecls();
// genTypeInterfaceMethods();
// genOpInterfaceMethods();
// generateOpFormat(op, opClass);
// genSideEffectInterfaceMethods();
}
void OpEmitter::emitDecl(
@ -631,7 +667,7 @@ void OpEmitter::genAttrGetters() {
Dialect opDialect = op.getDialect();
// Emit the derived attribute body.
auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
auto *method = opClass.addMethodAndPrune(getReturnType(attr), name);
if (!method)
return;
auto &body = method->body();
@ -640,7 +676,7 @@ void OpEmitter::genAttrGetters() {
// Emit with return type specified.
auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
auto *method = opClass.addMethodAndPrune(getReturnType(attr), name);
auto &body = method->body();
body << " auto attr = " << name << "Attr();\n";
if (attr.hasDefaultValue()) {
@ -664,7 +700,7 @@ void OpEmitter::genAttrGetters() {
// the string interface for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
auto *method =
opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
opClass.addMethodAndPrune(getStorageType(attr), (name + "Attr").str());
if (!method)
return;
auto &body = method->body();
@ -673,7 +709,7 @@ void OpEmitter::genAttrGetters() {
body << "dyn_cast_or_null<";
else
body << "cast<";
body << attr.getStorageType() << ">();";
body << getStorageType(attr) << ">();";
};
for (auto &namedAttr : op.getAttributes()) {
@ -759,7 +795,7 @@ void OpEmitter::genAttrSetters() {
// for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
attr.getStorageType(), "attr");
getStorageType(attr), "attr");
if (!method)
return;
auto &body = method->body();
@ -1070,256 +1106,89 @@ static bool canInferType(Operator &op) {
void OpEmitter::genSeparateArgParamBuilder() {
SmallVector<AttrParamKind, 2> attrBuilderType;
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
if (canGenerateUnwrappedBuilder(op))
attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
// if (canGenerateUnwrappedBuilder(op))
// attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
// Emit with separate builders with or without unwrapped attributes and/or
// inferring result type.
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
bool inferType) {
llvm::SmallVector<OpMethodParameter, 4> paramList;
llvm::SmallVector<OpMethodParameter, 4> paramList2;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, paramKind, attrType);
buildParamList(paramList2, resultNames, paramKind, attrType);
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
auto *m = opClass.addMethodAndPrune(
"::builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
// If the builder is redundant, skip generating the method.
if (!m)
return;
auto &body = m->body();
genCodeForAddingArgAndRegionForBuilder(
body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);
body, paramList2, attrType == AttrParamKind::UnwrappedValue);
// Push all result types to the operation state
//"BBBBBBBBBBBB"
// if (inferType) {
// // Generate builder that infers type too.
// // TODO: Subsume this with general checking if type can be
// // inferred automatically.
// // TODO: Expand to handle regions.
// body << formatv(R"(
// ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
// if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
// {1}.location, {1}.operands,
// {1}.attributes.getDictionary({1}.getContext()),
// /*regions=*/{{}, inferredReturnTypes)))
// {1}.addTypes(inferredReturnTypes);
// else
// ::llvm::report_fatal_error("Failed to infer result type(s).");)",
// opClass.getClassName(), builderOpState);
// return;
// }
if (inferType) {
// Generate builder that infers type too.
// TODO: Subsume this with general checking if type can be
// inferred automatically.
// TODO: Expand to handle regions.
body << formatv(R"(
::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
{1}.location, {1}.operands,
{1}.attributes.getDictionary({1}.getContext()),
/*regions=*/{{}, inferredReturnTypes)))
{1}.addTypes(inferredReturnTypes);
else
::llvm::report_fatal_error("Failed to infer result type(s).");)",
opClass.getClassName(), builderOpState);
return;
}
// switch (paramKind) {
// case TypeParamKind::None:
// return;
// case TypeParamKind::Separate:
// for (int i = 0, e = op.getNumResults(); i < e; ++i) {
// if (op.getResult(i).isOptional())
// body << " if (" << resultNames[i] << ")\n ";
// body << " " << builderOpState << ".addTypes(" << resultNames[i]
// << ");\n";
// }
// return;
// case TypeParamKind::Collective: {
// int numResults = op.getNumResults();
// int numVariadicResults = op.getNumVariableLengthResults();
// int numNonVariadicResults = numResults - numVariadicResults;
// bool hasVariadicResult = numVariadicResults != 0;
switch (paramKind) {
case TypeParamKind::None:
return;
case TypeParamKind::Separate:
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
if (op.getResult(i).isOptional())
body << " if (" << resultNames[i] << ")\n ";
body << " " << builderOpState << ".addTypes(" << resultNames[i]
<< ");\n";
}
return;
case TypeParamKind::Collective: {
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariableLengthResults();
int numNonVariadicResults = numResults - numVariadicResults;
bool hasVariadicResult = numVariadicResults != 0;
// Avoid emitting "resultTypes.size() >= 0u" which is always true.
if (!(hasVariadicResult && numNonVariadicResults == 0))
body << " "
<< "assert(resultTypes.size() "
<< (hasVariadicResult ? ">=" : "==") << " "
<< numNonVariadicResults
<< "u && \"mismatched number of results\");\n";
body << " " << builderOpState << ".addTypes(resultTypes);\n";
}
return;
}
llvm_unreachable("unhandled TypeParamKind");
// // Avoid emitting "resultTypes.size() >= 0u" which is always true.
// if (!(hasVariadicResult && numNonVariadicResults == 0))
// body << " "
// << "assert(resultTypes.size() "
// << (hasVariadicResult ? ">=" : "==") << " "
// << numNonVariadicResults
// << "u && \"mismatched number of results\");\n";
// body << " " << builderOpState << ".addTypes(resultTypes);\n";
// }
// return;
// }
// llvm_unreachable("unhandled TypeParamKind");
};
// Some of the build methods generated here may be ambiguous, but TableGen's
// ambiguous function detection will elide those ones.
for (auto attrType : attrBuilderType) {
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
if (canInferType(op))
emit(attrType, TypeParamKind::None, /*inferType=*/true);
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
// if (canInferType(op))
// emit(attrType, TypeParamKind::None, /*inferType=*/true);
// emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
}
}
void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
int numResults = op.getNumResults();
// Signature
llvm::SmallVector<OpMethodParameter, 4> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
// Provide default value for `attributes` when its the last parameter
StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", attributesDefaultValue);
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
auto &body = m->body();
// Operands
body << " " << builderOpState << ".addOperands(operands);\n";
// Attributes
body << " " << builderOpState << ".addAttributes(attributes);\n";
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
body << llvm::formatv(
" for (unsigned i = 0; i != {0}; ++i)\n",
(op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
body << " (void)" << builderOpState << ".addRegion();\n";
}
// Result types
SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
body << " " << builderOpState << ".addTypes({"
<< llvm::join(resultTypes, ", ") << "});\n\n";
}
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
// TODO: Expand to support regions.
SmallVector<OpMethodParameter, 4> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", "{}");
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
auto &body = m->body();
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariableLengthResults();
int numNonVariadicResults = numResults - numVariadicResults;
int numOperands = op.getNumOperands();
int numVariadicOperands = op.getNumVariableLengthOperands();
int numNonVariadicOperands = numOperands - numVariadicOperands;
// Operands
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
body << " assert(operands.size()"
<< (numVariadicOperands != 0 ? " >= " : " == ")
<< numNonVariadicOperands
<< "u && \"mismatched number of parameters\");\n";
body << " " << builderOpState << ".addOperands(operands);\n";
body << " " << builderOpState << ".addAttributes(attributes);\n";
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
body << llvm::formatv(
" for (unsigned i = 0; i != {0}; ++i)\n",
(op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
body << " (void)" << builderOpState << ".addRegion();\n";
}
// Result types
body << formatv(R"(
::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
{1}.location, operands,
{1}.attributes.getDictionary({1}.getContext()),
/*regions=*/{{}, inferredReturnTypes))) {{)",
opClass.getClassName(), builderOpState);
if (numVariadicResults == 0 || numNonVariadicResults != 0)
body << " assert(inferredReturnTypes.size()"
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
<< "u && \"mismatched number of return types\");\n";
body << " " << builderOpState << ".addTypes(inferredReturnTypes);";
body << formatv(R"(
} else
::llvm::report_fatal_error("Failed to infer result type(s).");)",
opClass.getClassName(), builderOpState);
}
void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
llvm::SmallVector<OpMethodParameter, 4> paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, TypeParamKind::None);
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
auto &body = m->body();
genCodeForAddingArgAndRegionForBuilder(body);
auto numResults = op.getNumResults();
if (numResults == 0)
return;
// Push all result types to the operation state
const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
std::string resultType =
formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
body << " " << builderOpState << ".addTypes({" << resultType;
for (int i = 1; i != numResults; ++i)
body << ", " << resultType;
body << "});\n\n";
}
void OpEmitter::genUseAttrAsResultTypeBuilder() {
SmallVector<OpMethodParameter, 4> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", "{}");
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
auto &body = m->body();
// Push all result types to the operation state
std::string resultType;
const auto &namedAttr = op.getAttribute(0);
body << " for (auto attr : attributes) {\n";
body << " if (attr.first != \"" << namedAttr.name << "\") continue;\n";
if (namedAttr.attr.isTypeAttr()) {
resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()";
} else {
resultType = "attr.second.getType()";
}
// Operands
body << " " << builderOpState << ".addOperands(operands);\n";
// Attributes
body << " " << builderOpState << ".addAttributes(attributes);\n";
// Result types
SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
body << " " << builderOpState << ".addTypes({"
<< llvm::join(resultTypes, ", ") << "});\n";
body << " }\n";
}
/// Returns a signature of the builder. Updates the context `fctx` to enable
/// replacement of $_builder and $_state in the body.
static std::string getBuilderSignature(const Builder &builder) {
@ -1352,21 +1221,21 @@ static std::string getBuilderSignature(const Builder &builder) {
void OpEmitter::genBuilder() {
// Handle custom builders if provided.
for (const Builder &builder : op.getBuilders()) {
std::string paramStr = getBuilderSignature(builder);
// for (const Builder &builder : op.getBuilders()) {
// std::string paramStr = getBuilderSignature(builder);
Optional<StringRef> body = builder.getBody();
OpMethod::Property properties =
body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
auto *method =
opClass.addMethodAndPrune("void", "build", properties, paramStr);
// Optional<StringRef> body = builder.getBody();
// OpMethod::Property properties =
// body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
// auto *method =
// opClass.addMethodAndPrune("void", "build", properties, paramStr);
FmtContext fctx;
fctx.withBuilder(odsBuilder);
fctx.addSubst("_state", builderOpState);
if (body)
method->body() << tgfmt(*body, &fctx);
}
// FmtContext fctx;
// fctx.withBuilder(odsBuilder);
// fctx.addSubst("_state", builderOpState);
// if (body)
// method->body() << tgfmt(*body, &fctx);
// }
// Generate default builders that requires all result type, operands, and
// attributes as parameters.
@ -1378,78 +1247,18 @@ void OpEmitter::genBuilder() {
genSeparateArgParamBuilder();
// 2. one having an aggregated parameter for all result types / operands /
// attributes, and
genCollectiveParamBuilder();
// 3. one having a stand-alone parameter for each operand and attribute,
// use the first operand or attribute's type as all result types
// to facilitate different call patterns.
if (op.getNumVariableLengthResults() == 0) {
if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
genUseOperandAsResultTypeSeparateParamBuilder();
genUseOperandAsResultTypeCollectiveParamBuilder();
}
if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
genUseAttrAsResultTypeBuilder();
}
}
void OpEmitter::genCollectiveParamBuilder() {
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariableLengthResults();
int numNonVariadicResults = numResults - numVariadicResults;
int numOperands = op.getNumOperands();
int numVariadicOperands = op.getNumVariableLengthOperands();
int numNonVariadicOperands = numOperands - numVariadicOperands;
SmallVector<OpMethodParameter, 4> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::TypeRange", "resultTypes");
paramList.emplace_back("::mlir::ValueRange", "operands");
// Provide default value for `attributes` when its the last parameter
StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", attributesDefaultValue);
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
auto &body = m->body();
// Operands
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
body << " assert(operands.size()"
<< (numVariadicOperands != 0 ? " >= " : " == ")
<< numNonVariadicOperands
<< "u && \"mismatched number of parameters\");\n";
body << " " << builderOpState << ".addOperands(operands);\n";
// Attributes
body << " " << builderOpState << ".addAttributes(attributes);\n";
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
body << llvm::formatv(
" for (unsigned i = 0; i != {0}; ++i)\n",
(op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
body << " (void)" << builderOpState << ".addRegion();\n";
}
// Result types
if (numVariadicResults == 0 || numNonVariadicResults != 0)
body << " assert(resultTypes.size()"
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
<< "u && \"mismatched number of return types\");\n";
body << " " << builderOpState << ".addTypes(resultTypes);\n";
// Generate builder that infers type too.
// TODO: Expand to handle regions and successors.
if (canInferType(op) && op.getNumSuccessors() == 0)
genInferredTypeCollectiveParamBuilder();
// genCollectiveParamBuilder();
// // 3. one having a stand-alone parameter for each operand and attribute,
// // use the first operand or attribute's type as all result types
// // to facilitate different call patterns.
// if (op.getNumVariableLengthResults() == 0) {
// if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
// genUseOperandAsResultTypeSeparateParamBuilder();
// genUseOperandAsResultTypeCollectiveParamBuilder();
// }
// if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
// genUseAttrAsResultTypeBuilder();
// }
}
void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
@ -1460,8 +1269,9 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
auto numResults = op.getNumResults();
resultTypeNames.reserve(numResults);
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
// paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::builder::Builder &", "builder");
// paramList.emplace_back("::mlir::OperationState &", builderOpState);
switch (typeParamKind) {
case TypeParamKind::None:
@ -1475,7 +1285,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
resultName = std::string(formatv("resultType{0}", i));
StringRef type =
result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
result.isVariadic() ? "std::vector<::builder::Type>" : "::builder::Type";
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (result.isOptional())
properties = OpMethodParameter::PP_Optional;
@ -1512,7 +1322,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
// for APFloat.
// TODO: Adjust the 'returnType' field of such attributes
// to support them.
StringRef retType = namedAttr->attr.getReturnType();
StringRef retType = getReturnType(namedAttr->attr);
if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
break;
@ -1525,7 +1335,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &operand = op.getOperand(numOperands);
StringRef type =
operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value";
operand.isVariadic() ? "std::vector<::builder::Op>" : "::builder::Op";
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (operand.isOptional())
properties = OpMethodParameter::PP_Optional;
@ -1544,13 +1354,13 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
StringRef type;
switch (attrParamKind) {
case AttrParamKind::WrappedAttr:
type = attr.getStorageType();
type = getStorageType(attr);
break;
case AttrParamKind::UnwrappedValue:
if (canUseUnwrappedRawValue(attr))
type = attr.getReturnType();
type = getReturnType(attr);
else
type = attr.getStorageType();
type = getStorageType(attr);
break;
}
@ -1558,7 +1368,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
// Attach default value if requested and possible.
if (attrParamKind == AttrParamKind::UnwrappedValue &&
i >= defaultValuedAttrStartIndex) {
bool isString = attr.getReturnType() == "::llvm::StringRef";
bool isString = getReturnType(attr) == "::llvm::StringRef";
if (isString)
defaultValue.append("\"");
defaultValue += attr.getDefaultValue();
@ -1584,84 +1394,117 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
llvm::formatv("{0}Count", region.name).str());
}
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
OpMethodBody &body, llvm::SmallVector<OpMethodParameter, 4> paramList,
bool isRawValueAttr) {
// Push all operands to the result.
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
std::string argName = getArgumentName(op, i);
if (op.getOperand(i).isOptional())
body << " if (" << argName << ")\n ";
body << " " << builderOpState << ".addOperands(" << argName << ");\n";
}
// If the operation has the operand segment size attribute, add it here.
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
body << " " << builderOpState
<< ".addAttribute(\"operand_segment_sizes\", "
"odsBuilder.getI32VectorAttr({";
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
if (op.getOperand(i).isOptional())
body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
else if (op.getOperand(i).isVariadic())
body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
else
body << "1";
body << "// AAAAAA \n";
// for(auto p : paramList)
// {
// body <<"==== type:"<< p.getType() << " name:"<< p.getName() << "\n";
// }
body << " auto builder = builder.GetImpl();\n";
body << " auto loc = builder->GetLoc();\n";
body << " auto opBuilder = builder->GetBuilder();\n";
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
<< " currentOp =\n";
body << " opBuilder.create<mlir::" << op.getDialectName()
<< "::" << op.getDialectName() << ">(\n";
if (paramList.size() > 1) {
body << " loc,\n";
std::for_each(paramList.begin() + 1, paramList.end() - 1,
[&](OpMethodParameter &p) {
body << " " << p.getName() << ",\n";
});
body << "}));\n";
}
// Push all attributes to the result.
for (const auto &namedAttr : op.getAttributes()) {
auto &attr = namedAttr.attr;
if (!attr.isDerivedAttr()) {
bool emitNotNullCheck = attr.isOptional();
if (emitNotNullCheck) {
body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
}
if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
// If this is a raw value, then we need to wrap it in an Attribute
// instance.
FmtContext fctx;
fctx.withBuilder("odsBuilder");
std::string builderTemplate =
std::string(attr.getConstBuilderTemplate());
// For StringAttr, its constant builder call will wrap the input in
// quotes, which is correct for normal string literals, but incorrect
// here given we use function arguments. So we need to strip the
// wrapping quotes.
if (StringRef(builderTemplate).contains("\"$0\""))
builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
std::string value =
std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
namedAttr.name, value);
body << " " << paramList.back().getName() << "\n";
} else {
body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
namedAttr.name);
}
if (emitNotNullCheck) {
body << " }\n";
}
}
body << " loc\n";
}
body << " );\n";
body << " builder::" << op.getCppClassName() << " builderOp;\n";
body << " auto opImpl = builderOp.GetImpl();\n";
body << " opImpl.SetOperation(currentOp.getOperation());\n";
body << " return builderOp;\n";
// Create the correct number of regions.
for (const NamedRegion &region : op.getRegions()) {
if (region.isVariadic())
body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
region.name);
body << " (void)" << builderOpState << ".addRegion();\n";
}
// Push all successors to the result.
for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
body << formatv(" {0}.addSuccessors({1});\n", builderOpState,
namedSuccessor.name);
}
// // Push all operands to the result.
// for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
// std::string argName = getArgumentName(op, i);
// if (op.getOperand(i).isOptional())
// body << " if (" << argName << ")\n ";
// body << " " << builderOpState << ".addOperands(" << argName << ");\n";
// }
// // If the operation has the operand segment size attribute, add it here.
// if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
// body << " " << builderOpState
// << ".addAttribute(\"operand_segment_sizes\", "
// "odsBuilder.getI32VectorAttr({";
// interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
// if (op.getOperand(i).isOptional())
// body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
// else if (op.getOperand(i).isVariadic())
// body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
// else
// body << "1";
// });
// body << "}));\n";
// }
// // Push all attributes to the result.
// for (const auto &namedAttr : op.getAttributes()) {
// auto &attr = namedAttr.attr;
// if (!attr.isDerivedAttr()) {
// bool emitNotNullCheck = attr.isOptional();
// if (emitNotNullCheck) {
// body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
// }
// if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
// // If this is a raw value, then we need to wrap it in an Attribute
// // instance.
// FmtContext fctx;
// fctx.withBuilder("odsBuilder");
// std::string builderTemplate =
// std::string(attr.getConstBuilderTemplate());
// // For StringAttr, its constant builder call will wrap the input in
// // quotes, which is correct for normal string literals, but incorrect
// // here given we use function arguments. So we need to strip the
// // wrapping quotes.
// if (StringRef(builderTemplate).contains("\"$0\""))
// builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
// std::string value =
// std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
// body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
// namedAttr.name, value);
// } else {
// body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
// namedAttr.name);
// }
// if (emitNotNullCheck) {
// body << " }\n";
// }
// }
// }
// // Create the correct number of regions.
// for (const NamedRegion &region : op.getRegions()) {
// if (region.isVariadic())
// body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
// region.name);
// body << " (void)" << builderOpState << ".addRegion();\n";
// }
// // Push all successors to the result.
// for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
// body << formatv(" {0}.addSuccessors({1});\n", builderOpState,
// namedSuccessor.name);
// }
}
void OpEmitter::genCanonicalizerDecls() {
@ -2305,15 +2148,15 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
auto emitAttr = [&](StringRef name, Attribute attr) {
auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body();
auto &body = adaptor.addMethodAndPrune(getStorageType(attr), name)->body();
body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
<< "\n " << attr.getStorageType() << " attr = "
<< "\n " << getStorageType(attr) << " attr = "
<< "odsAttrs.get(\"" << name << "\").";
if (attr.hasDefaultValue() || attr.isOptional())
body << "dyn_cast_or_null<";
else
body << "cast<";
body << attr.getStorageType() << ">();\n";
body << getStorageType(attr) << ">();\n";
if (attr.hasDefaultValue()) {
// Use the default value if attribute is not set.
@ -2442,11 +2285,11 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
Operator op(*def);
NamespaceEmitter emitter(os, op.getCppNamespace());
if (emitDecl) {
os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
// os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
// OpOperandAdaptorEmitter::emitDecl(op, os);
OpEmitter::emitDecl(op, os, staticVerifierEmitter);
} else {
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
// os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
// OpOperandAdaptorEmitter::emitDef(op, os);
OpEmitter::emitDef(op, os, staticVerifierEmitter);
}
@ -2456,12 +2299,18 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
// Emits a comma-separated list of the ops.
static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
IfDefScope scope("GET_OP_LIST", os);
interleave(
// TODO: We are constructing the Operator wrapper instance just for
// getting it's qualified class name here. Reduce the overhead by having a
// lightweight version of Operator class just for that purpose.
defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
defs,
[&os](Record *def) {
SmallVector<StringRef, 4> namespaces;
std::string className = Operator(def).getQualCppClassName();
llvm::SplitString(StringRef(className), namespaces, StringRef("::"));
if (namespaces.begin() != namespaces.end())
os << "::builder::mhlo::" << namespaces.back().str();
},
[&os]() { os << ",\n"; });
}

View File

@ -0,0 +1,30 @@
#ifndef BUILDER_BUILDERIMPL_
#define BUILDER_BUILDERIMPL_
#include "Builder.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace builder {
class Builder::Impl {
public:
Impl() {}
// mlir::Location GetLoc() { return mlir_loc_; }
// mlir::OpBuilder GetBuilder() { return mlir_builder_; }
mlir::MLIRContext *GetContext() { return &mlir_context_; }
private:
// mlir::Location mlir_loc_;
// mlir::OpBuilder mlir_builder_;
mlir::MLIRContext mlir_context_;
};
} // namespace builder
#endif

View File

@ -0,0 +1,16 @@
#ifndef BUILDER_OP_
#define BUILDER_OP_
#include "iostream"
namespace builder {
class Op {
class Impl;
std::shared_ptr<Impl> GetImpl() { return _impl; }
private:
std::shared_ptr<Impl> _impl;
}
} // namespace builder
#endif

View File

@ -0,0 +1,26 @@
#ifndef BUILDER_OPIMPL_
#define BUILDER_OPIMPL_
#include "Builder.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace builder {
class Op::Impl {
public:
Impl() = default;
void SetOperation(Operation *Op) { op_ = Op; }
private:
Operation *op_;
};
} // namespace builder
#endif

View File

@ -49,13 +49,23 @@ public:
~NamespaceEmitter() {
for (StringRef ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
if (ns.equals("mlir"))
os << "} // namespace "
<< "builder"
<< "\n\n";
else
os << "} // namespace " << ns << "\n\n";
}
private:
void emitNamespaceStarts(raw_ostream &os, StringRef cppNamespace) {
llvm::SplitString(cppNamespace, namespaces, "::");
for (StringRef ns : namespaces)
if (ns.equals("mlir"))
os << "namespace "
<< "builder"
<< " {\n";
else
os << "namespace " << ns << " {\n";
}
raw_ostream &os;

View File

@ -261,7 +261,7 @@ void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
}
void Class::writeDeclTo(raw_ostream &os) const {
bool hasPrivateMethod = false;
os << "class " << className << " {\n";
os << "class " << className << "\n";
os << "public:\n";
forAllMethods([&](const OpMethod &method) {
@ -311,10 +311,10 @@ void OpClass::addTrait(Twine trait) {
}
void OpClass::writeDeclTo(raw_ostream &os) const {
os << "class " << className << " : public ::mlir::Op<" << className;
os << "class " << className;
for (const auto &trait : traitsVec)
os << ", " << trait;
os << "> {\npublic:\n";
os << " : public Op {\npublic:\n";
// << " using Op::Op;\n"
// << " using Op::print;\n"
// << " using Adaptor = " << className << "Adaptor;\n";
@ -330,8 +330,8 @@ void OpClass::writeDeclTo(raw_ostream &os) const {
});
// TODO: Add line control markers to make errors easier to debug.
if (!extraClassDeclaration.empty())
os << extraClassDeclaration << "\n";
// if (!extraClassDeclaration.empty())
// os << extraClassDeclaration << "\n";
if (hasPrivateMethod) {
os << "\nprivate:\n";

View File

@ -61,6 +61,8 @@ public:
void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
const std::string &getType() const { return type; }
const std::string &getName() const { return name; }
bool hasDefaultValue() const { return !defaultValue.empty(); }
private:

View File

@ -0,0 +1,16 @@
#ifndef BUILDER_TENSOR_
#define BUILDER_TENSOR_
#include "iostream"
namespace builder{
class Tensor{
}
}
#endif

View File

@ -0,0 +1,10 @@
#ifndef BUILDER_TYPE_
#define BUILDER_TYPE_
#include "iostream"
namespace builder {
class Type {}
} // namespace builder
#endif