modify code gen of build function

This commit is contained in:
colin.liang 2021-08-04 20:24:07 +08:00
parent 898eb732de
commit 9d5166684c
13 changed files with 437 additions and 430 deletions

2
BUILD
View File

@ -133,7 +133,7 @@ cc_binary(
deps = [ deps = [
"@llvm-project//mlir:MlirTableGenMain", "@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
# "@llvm-project//mlir:TableGen", "@llvm-project//mlir:IR",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen", "@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config", "@llvm-project//llvm:config",

View File

@ -0,0 +1,10 @@
#ifndef BUILDER_ARRAY_
#define BUILDER_ARRAY_
#include "iostream"
namespace builder {
class Array {}
} // namespace builder
#endif

View File

@ -0,0 +1,17 @@
#include "Builder.h"
#include "BuilderImpl.h"
#include "llvm/Support/Casting.h"
// #include "mlir/Dialect/StandardOps/Ops.h"
// #include "mlir/IR/Attributes.h"
// #include "mlir/IR/Operation.h"
// #include "mlir/IR/StandardTypes.h"
// #include "mlir/IR/Types.h"
// #include "mlir/IR/Value.h"
namespace builder {
Builder::Builder() : _impl(std::make_shared<Impl>()) {}
void Builder::DumpModule() {}
} // namespace builder

View File

@ -0,0 +1,21 @@
#ifndef BUILDER_BUILDER_
#define BUILDER_BUILDER_
#include <memory>
namespace builder {
class Builder {
public:
class Impl;
Builder();
void DumpModule();
std::shared_ptr<Impl> GetImpl() { return _impl; }
private:
std::shared_ptr<Impl> _impl;
};
} // namespace builder
#endif

View File

@ -29,6 +29,8 @@
#include "llvm/TableGen/Record.h" #include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h" #include "llvm/TableGen/TableGenBackend.h"
#include "iostream"
#define DEBUG_TYPE "mlir-tblgen-opdefgen" #define DEBUG_TYPE "mlir-tblgen-opdefgen"
using namespace llvm; using namespace llvm;
@ -40,6 +42,40 @@ static const char *const generatedArgName = "odsArg";
static const char *const odsBuilder = "odsBuilder"; static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState"; static const char *const builderOpState = "odsState";
static const std::map<std::string, std::string> typeMapMLIR = {
{"::mlir::StringAttr", "std::string"},
{"::mlir::IntegerAttr", "int"},
{"::mlir::DenseIntElementsAttr", "std::vector<int>"},
{"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
{"::mlir::FloatAttr", "float"},
{"::mlir::BoolAttr", "bool"},
{"::mlir::ElementsAttr", "::builder::Array"},
{"::mlir::DenseElementsAttr", "::builder::Tensor"},
// current only support string array.
{"::mlir::ArrayAttr", "std::vector<std::string>"},
{"::mlir::mhlo::ConvDimensionNumbers", "::builder::ConvDimensionNumbers"},
{"::mlir::mhlo::DotDimensionNumbers", "::builder::DotDimensionNumbers"},
{"::mlir::mhlo::GatherDimensionNumbers",
"::builder::GatherDimensionNumbers"},
{"::mlir::mhlo::ScatterDimensionNumbers",
"::builder::ScatterDimensionNumbers"},
};
StringRef typeConvertFromMLIR(StringRef type) {
auto re = typeMapMLIR.find(type.str());
if (re != typeMapMLIR.end()) return StringRef(re->second);
return type;
}
StringRef getStorageType(const Attribute &att) {
auto type = att.getStorageType();
return typeConvertFromMLIR(type);
}
StringRef getReturnType(const Attribute &att) {
auto type = att.getStorageType();
return typeConvertFromMLIR(type);
}
// The logic to calculate the actual value range for a declared operand/result // The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for // of an op with variadic operands/results. Note that this logic is not for
// general use; it assumes all variadic operands/results must have the same // general use; it assumes all variadic operands/results must have the same
@ -173,11 +209,11 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
: uniqueOutputLabel(getUniqueName(records)) { : uniqueOutputLabel(getUniqueName(records)) {
llvm::Optional<NamespaceEmitter> namespaceEmitter; llvm::Optional<NamespaceEmitter> namespaceEmitter;
if (!emitDecl) { if (!emitDecl) {
os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); // os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace()); namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
} }
emitTypeConstraintMethods(opDefs, os, emitDecl); // emitTypeConstraintMethods(opDefs, os, emitDecl);
} }
std::string StaticVerifierFunctionEmitter::getUniqueName( std::string StaticVerifierFunctionEmitter::getUniqueName(
@ -278,7 +314,7 @@ static std::string getArgumentName(const Operator &op, int index) {
// Returns true if we can use unwrapped value for the given `attr` in builders. // Returns true if we can use unwrapped value for the given `attr` in builders.
static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) { static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
return attr.getReturnType() != attr.getStorageType() && return getReturnType(attr) != getStorageType(attr) &&
// We need to wrap the raw value into an attribute in the builder impl // We need to wrap the raw value into an attribute in the builder impl
// so we need to make sure that the attribute specifies how to do that. // so we need to make sure that the attribute specifies how to do that.
!attr.getConstBuilderTemplate().empty(); !attr.getConstBuilderTemplate().empty();
@ -346,12 +382,12 @@ private:
// Generates the build() method that takes each operand/attribute as a // Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. The generated build() method uses first operand's // stand-alone parameter. The generated build() method uses first operand's
// type as all results' types. // type as all results' types.
void genUseOperandAsResultTypeSeparateParamBuilder(); // void genUseOperandAsResultTypeSeparateParamBuilder();
// Generates the build() method that takes all operands/attributes // Generates the build() method that takes all operands/attributes
// collectively as one parameter. The generated build() method uses first // collectively as one parameter. The generated build() method uses first
// operand's type as all results' types. // operand's type as all results' types.
void genUseOperandAsResultTypeCollectiveParamBuilder(); // void genUseOperandAsResultTypeCollectiveParamBuilder();
// Generates the build() method that takes aggregate operands/attributes // Generates the build() method that takes aggregate operands/attributes
// parameters. This build() method uses inferred types as result types. // parameters. This build() method uses inferred types as result types.
@ -361,11 +397,11 @@ private:
// Generates the build() method that takes each operand/attribute as a // Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. The generated build() method uses first attribute's // stand-alone parameter. The generated build() method uses first attribute's
// type as all result's types. // type as all result's types.
void genUseAttrAsResultTypeBuilder(); // void genUseAttrAsResultTypeBuilder();
// Generates the build() method that takes all result types collectively as // Generates the build() method that takes all result types collectively as
// one parameter. Similarly for operands and attributes. // one parameter. Similarly for operands and attributes.
void genCollectiveParamBuilder(); // void genCollectiveParamBuilder();
// The kind of parameter to generate for result types in builders. // The kind of parameter to generate for result types in builders.
enum class TypeParamKind { enum class TypeParamKind {
@ -391,7 +427,7 @@ private:
AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
// Adds op arguments and regions into operation state for build() methods. // Adds op arguments and regions into operation state for build() methods.
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,llvm::SmallVector<OpMethodParameter, 4> paramList,
bool isRawValueAttr = false); bool isRawValueAttr = false);
// Generates canonicalizer declaration for the operation. // Generates canonicalizer declaration for the operation.
@ -588,24 +624,24 @@ OpEmitter::OpEmitter(const Operator &op,
// methods in the generated file. // methods in the generated file.
// genOpAsmInterface(); // genOpAsmInterface();
genOpNameGetter(); genOpNameGetter();
genNamedOperandGetters(); // genNamedOperandGetters();
genNamedOperandSetters(); // genNamedOperandSetters();
genNamedResultGetters(); // genNamedResultGetters();
genNamedRegionGetters(); // genNamedRegionGetters();
genNamedSuccessorGetters(); genNamedSuccessorGetters();
genAttrGetters(); genAttrGetters();
genAttrSetters(); genAttrSetters();
genOptionalAttrRemovers(); // genOptionalAttrRemovers();
genBuilder(); genBuilder();
genParser(); // genParser();
genPrinter(); // genPrinter();
genVerifier(); // genVerifier();
genCanonicalizerDecls(); // genCanonicalizerDecls();
genFolderDecls(); // genFolderDecls();
genTypeInterfaceMethods(); // genTypeInterfaceMethods();
genOpInterfaceMethods(); // genOpInterfaceMethods();
generateOpFormat(op, opClass); // generateOpFormat(op, opClass);
genSideEffectInterfaceMethods(); // genSideEffectInterfaceMethods();
} }
void OpEmitter::emitDecl( void OpEmitter::emitDecl(
@ -631,7 +667,7 @@ void OpEmitter::genAttrGetters() {
Dialect opDialect = op.getDialect(); Dialect opDialect = op.getDialect();
// Emit the derived attribute body. // Emit the derived attribute body.
auto emitDerivedAttr = [&](StringRef name, Attribute attr) { auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); auto *method = opClass.addMethodAndPrune(getReturnType(attr), name);
if (!method) if (!method)
return; return;
auto &body = method->body(); auto &body = method->body();
@ -640,7 +676,7 @@ void OpEmitter::genAttrGetters() {
// Emit with return type specified. // Emit with return type specified.
auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) { auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); auto *method = opClass.addMethodAndPrune(getReturnType(attr), name);
auto &body = method->body(); auto &body = method->body();
body << " auto attr = " << name << "Attr();\n"; body << " auto attr = " << name << "Attr();\n";
if (attr.hasDefaultValue()) { if (attr.hasDefaultValue()) {
@ -664,7 +700,7 @@ void OpEmitter::genAttrGetters() {
// the string interface for better compile time verification. // the string interface for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
auto *method = auto *method =
opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str()); opClass.addMethodAndPrune(getStorageType(attr), (name + "Attr").str());
if (!method) if (!method)
return; return;
auto &body = method->body(); auto &body = method->body();
@ -673,7 +709,7 @@ void OpEmitter::genAttrGetters() {
body << "dyn_cast_or_null<"; body << "dyn_cast_or_null<";
else else
body << "cast<"; body << "cast<";
body << attr.getStorageType() << ">();"; body << getStorageType(attr) << ">();";
}; };
for (auto &namedAttr : op.getAttributes()) { for (auto &namedAttr : op.getAttributes()) {
@ -759,7 +795,7 @@ void OpEmitter::genAttrSetters() {
// for better compile time verification. // for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(), auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
attr.getStorageType(), "attr"); getStorageType(attr), "attr");
if (!method) if (!method)
return; return;
auto &body = method->body(); auto &body = method->body();
@ -1070,256 +1106,89 @@ static bool canInferType(Operator &op) {
void OpEmitter::genSeparateArgParamBuilder() { void OpEmitter::genSeparateArgParamBuilder() {
SmallVector<AttrParamKind, 2> attrBuilderType; SmallVector<AttrParamKind, 2> attrBuilderType;
attrBuilderType.push_back(AttrParamKind::WrappedAttr); attrBuilderType.push_back(AttrParamKind::WrappedAttr);
if (canGenerateUnwrappedBuilder(op)) // if (canGenerateUnwrappedBuilder(op))
attrBuilderType.push_back(AttrParamKind::UnwrappedValue); // attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
// Emit with separate builders with or without unwrapped attributes and/or // Emit with separate builders with or without unwrapped attributes and/or
// inferring result type. // inferring result type.
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
bool inferType) { bool inferType) {
llvm::SmallVector<OpMethodParameter, 4> paramList; llvm::SmallVector<OpMethodParameter, 4> paramList;
llvm::SmallVector<OpMethodParameter, 4> paramList2;
llvm::SmallVector<std::string, 4> resultNames; llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, paramKind, attrType); buildParamList(paramList, resultNames, paramKind, attrType);
buildParamList(paramList2, resultNames, paramKind, attrType);
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, auto *m = opClass.addMethodAndPrune(
std::move(paramList)); "::builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
// If the builder is redundant, skip generating the method. // If the builder is redundant, skip generating the method.
if (!m) if (!m)
return; return;
auto &body = m->body(); auto &body = m->body();
genCodeForAddingArgAndRegionForBuilder( genCodeForAddingArgAndRegionForBuilder(
body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); body, paramList2, attrType == AttrParamKind::UnwrappedValue);
// Push all result types to the operation state // Push all result types to the operation state
//"BBBBBBBBBBBB"
// if (inferType) {
// // Generate builder that infers type too.
// // TODO: Subsume this with general checking if type can be
// // inferred automatically.
// // TODO: Expand to handle regions.
// body << formatv(R"(
// ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
// if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
// {1}.location, {1}.operands,
// {1}.attributes.getDictionary({1}.getContext()),
// /*regions=*/{{}, inferredReturnTypes)))
// {1}.addTypes(inferredReturnTypes);
// else
// ::llvm::report_fatal_error("Failed to infer result type(s).");)",
// opClass.getClassName(), builderOpState);
// return;
// }
if (inferType) { // switch (paramKind) {
// Generate builder that infers type too. // case TypeParamKind::None:
// TODO: Subsume this with general checking if type can be // return;
// inferred automatically. // case TypeParamKind::Separate:
// TODO: Expand to handle regions. // for (int i = 0, e = op.getNumResults(); i < e; ++i) {
body << formatv(R"( // if (op.getResult(i).isOptional())
::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; // body << " if (" << resultNames[i] << ")\n ";
if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(), // body << " " << builderOpState << ".addTypes(" << resultNames[i]
{1}.location, {1}.operands, // << ");\n";
{1}.attributes.getDictionary({1}.getContext()), // }
/*regions=*/{{}, inferredReturnTypes))) // return;
{1}.addTypes(inferredReturnTypes); // case TypeParamKind::Collective: {
else // int numResults = op.getNumResults();
::llvm::report_fatal_error("Failed to infer result type(s).");)", // int numVariadicResults = op.getNumVariableLengthResults();
opClass.getClassName(), builderOpState); // int numNonVariadicResults = numResults - numVariadicResults;
return; // bool hasVariadicResult = numVariadicResults != 0;
}
switch (paramKind) { // // Avoid emitting "resultTypes.size() >= 0u" which is always true.
case TypeParamKind::None: // if (!(hasVariadicResult && numNonVariadicResults == 0))
return; // body << " "
case TypeParamKind::Separate: // << "assert(resultTypes.size() "
for (int i = 0, e = op.getNumResults(); i < e; ++i) { // << (hasVariadicResult ? ">=" : "==") << " "
if (op.getResult(i).isOptional()) // << numNonVariadicResults
body << " if (" << resultNames[i] << ")\n "; // << "u && \"mismatched number of results\");\n";
body << " " << builderOpState << ".addTypes(" << resultNames[i] // body << " " << builderOpState << ".addTypes(resultTypes);\n";
<< ");\n"; // }
} // return;
return; // }
case TypeParamKind::Collective: { // llvm_unreachable("unhandled TypeParamKind");
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariableLengthResults();
int numNonVariadicResults = numResults - numVariadicResults;
bool hasVariadicResult = numVariadicResults != 0;
// Avoid emitting "resultTypes.size() >= 0u" which is always true.
if (!(hasVariadicResult && numNonVariadicResults == 0))
body << " "
<< "assert(resultTypes.size() "
<< (hasVariadicResult ? ">=" : "==") << " "
<< numNonVariadicResults
<< "u && \"mismatched number of results\");\n";
body << " " << builderOpState << ".addTypes(resultTypes);\n";
}
return;
}
llvm_unreachable("unhandled TypeParamKind");
}; };
// Some of the build methods generated here may be ambiguous, but TableGen's // Some of the build methods generated here may be ambiguous, but TableGen's
// ambiguous function detection will elide those ones. // ambiguous function detection will elide those ones.
for (auto attrType : attrBuilderType) { for (auto attrType : attrBuilderType) {
emit(attrType, TypeParamKind::Separate, /*inferType=*/false); emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
if (canInferType(op)) // if (canInferType(op))
emit(attrType, TypeParamKind::None, /*inferType=*/true); // emit(attrType, TypeParamKind::None, /*inferType=*/true);
emit(attrType, TypeParamKind::Collective, /*inferType=*/false); // emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
} }
} }
void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
int numResults = op.getNumResults();
// Signature
llvm::SmallVector<OpMethodParameter, 4> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
// Provide default value for `attributes` when its the last parameter
StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", attributesDefaultValue);
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
auto &body = m->body();
// Operands
body << " " << builderOpState << ".addOperands(operands);\n";
// Attributes
body << " " << builderOpState << ".addAttributes(attributes);\n";
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
body << llvm::formatv(
" for (unsigned i = 0; i != {0}; ++i)\n",
(op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
body << " (void)" << builderOpState << ".addRegion();\n";
}
// Result types
SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
body << " " << builderOpState << ".addTypes({"
<< llvm::join(resultTypes, ", ") << "});\n\n";
}
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
// TODO: Expand to support regions.
SmallVector<OpMethodParameter, 4> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", "{}");
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
auto &body = m->body();
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariableLengthResults();
int numNonVariadicResults = numResults - numVariadicResults;
int numOperands = op.getNumOperands();
int numVariadicOperands = op.getNumVariableLengthOperands();
int numNonVariadicOperands = numOperands - numVariadicOperands;
// Operands
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
body << " assert(operands.size()"
<< (numVariadicOperands != 0 ? " >= " : " == ")
<< numNonVariadicOperands
<< "u && \"mismatched number of parameters\");\n";
body << " " << builderOpState << ".addOperands(operands);\n";
body << " " << builderOpState << ".addAttributes(attributes);\n";
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
body << llvm::formatv(
" for (unsigned i = 0; i != {0}; ++i)\n",
(op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
body << " (void)" << builderOpState << ".addRegion();\n";
}
// Result types
body << formatv(R"(
::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
{1}.location, operands,
{1}.attributes.getDictionary({1}.getContext()),
/*regions=*/{{}, inferredReturnTypes))) {{)",
opClass.getClassName(), builderOpState);
if (numVariadicResults == 0 || numNonVariadicResults != 0)
body << " assert(inferredReturnTypes.size()"
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
<< "u && \"mismatched number of return types\");\n";
body << " " << builderOpState << ".addTypes(inferredReturnTypes);";
body << formatv(R"(
} else
::llvm::report_fatal_error("Failed to infer result type(s).");)",
opClass.getClassName(), builderOpState);
}
void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
llvm::SmallVector<OpMethodParameter, 4> paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, TypeParamKind::None);
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
auto &body = m->body();
genCodeForAddingArgAndRegionForBuilder(body);
auto numResults = op.getNumResults();
if (numResults == 0)
return;
// Push all result types to the operation state
const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
std::string resultType =
formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
body << " " << builderOpState << ".addTypes({" << resultType;
for (int i = 1; i != numResults; ++i)
body << ", " << resultType;
body << "});\n\n";
}
void OpEmitter::genUseAttrAsResultTypeBuilder() {
SmallVector<OpMethodParameter, 4> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", "{}");
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
auto &body = m->body();
// Push all result types to the operation state
std::string resultType;
const auto &namedAttr = op.getAttribute(0);
body << " for (auto attr : attributes) {\n";
body << " if (attr.first != \"" << namedAttr.name << "\") continue;\n";
if (namedAttr.attr.isTypeAttr()) {
resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()";
} else {
resultType = "attr.second.getType()";
}
// Operands
body << " " << builderOpState << ".addOperands(operands);\n";
// Attributes
body << " " << builderOpState << ".addAttributes(attributes);\n";
// Result types
SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
body << " " << builderOpState << ".addTypes({"
<< llvm::join(resultTypes, ", ") << "});\n";
body << " }\n";
}
/// Returns a signature of the builder. Updates the context `fctx` to enable /// Returns a signature of the builder. Updates the context `fctx` to enable
/// replacement of $_builder and $_state in the body. /// replacement of $_builder and $_state in the body.
static std::string getBuilderSignature(const Builder &builder) { static std::string getBuilderSignature(const Builder &builder) {
@ -1352,21 +1221,21 @@ static std::string getBuilderSignature(const Builder &builder) {
void OpEmitter::genBuilder() { void OpEmitter::genBuilder() {
// Handle custom builders if provided. // Handle custom builders if provided.
for (const Builder &builder : op.getBuilders()) { // for (const Builder &builder : op.getBuilders()) {
std::string paramStr = getBuilderSignature(builder); // std::string paramStr = getBuilderSignature(builder);
Optional<StringRef> body = builder.getBody(); // Optional<StringRef> body = builder.getBody();
OpMethod::Property properties = // OpMethod::Property properties =
body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration; // body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
auto *method = // auto *method =
opClass.addMethodAndPrune("void", "build", properties, paramStr); // opClass.addMethodAndPrune("void", "build", properties, paramStr);
FmtContext fctx; // FmtContext fctx;
fctx.withBuilder(odsBuilder); // fctx.withBuilder(odsBuilder);
fctx.addSubst("_state", builderOpState); // fctx.addSubst("_state", builderOpState);
if (body) // if (body)
method->body() << tgfmt(*body, &fctx); // method->body() << tgfmt(*body, &fctx);
} // }
// Generate default builders that requires all result type, operands, and // Generate default builders that requires all result type, operands, and
// attributes as parameters. // attributes as parameters.
@ -1378,78 +1247,18 @@ void OpEmitter::genBuilder() {
genSeparateArgParamBuilder(); genSeparateArgParamBuilder();
// 2. one having an aggregated parameter for all result types / operands / // 2. one having an aggregated parameter for all result types / operands /
// attributes, and // attributes, and
genCollectiveParamBuilder(); // genCollectiveParamBuilder();
// 3. one having a stand-alone parameter for each operand and attribute, // // 3. one having a stand-alone parameter for each operand and attribute,
// use the first operand or attribute's type as all result types // // use the first operand or attribute's type as all result types
// to facilitate different call patterns. // // to facilitate different call patterns.
if (op.getNumVariableLengthResults() == 0) { // if (op.getNumVariableLengthResults() == 0) {
if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { // if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
genUseOperandAsResultTypeSeparateParamBuilder(); // genUseOperandAsResultTypeSeparateParamBuilder();
genUseOperandAsResultTypeCollectiveParamBuilder(); // genUseOperandAsResultTypeCollectiveParamBuilder();
} // }
if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType")) // if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
genUseAttrAsResultTypeBuilder(); // genUseAttrAsResultTypeBuilder();
} // }
}
void OpEmitter::genCollectiveParamBuilder() {
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariableLengthResults();
int numNonVariadicResults = numResults - numVariadicResults;
int numOperands = op.getNumOperands();
int numVariadicOperands = op.getNumVariableLengthOperands();
int numNonVariadicOperands = numOperands - numVariadicOperands;
SmallVector<OpMethodParameter, 4> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::TypeRange", "resultTypes");
paramList.emplace_back("::mlir::ValueRange", "operands");
// Provide default value for `attributes` when its the last parameter
StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", attributesDefaultValue);
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
auto &body = m->body();
// Operands
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
body << " assert(operands.size()"
<< (numVariadicOperands != 0 ? " >= " : " == ")
<< numNonVariadicOperands
<< "u && \"mismatched number of parameters\");\n";
body << " " << builderOpState << ".addOperands(operands);\n";
// Attributes
body << " " << builderOpState << ".addAttributes(attributes);\n";
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
body << llvm::formatv(
" for (unsigned i = 0; i != {0}; ++i)\n",
(op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
body << " (void)" << builderOpState << ".addRegion();\n";
}
// Result types
if (numVariadicResults == 0 || numNonVariadicResults != 0)
body << " assert(resultTypes.size()"
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
<< "u && \"mismatched number of return types\");\n";
body << " " << builderOpState << ".addTypes(resultTypes);\n";
// Generate builder that infers type too.
// TODO: Expand to handle regions and successors.
if (canInferType(op) && op.getNumSuccessors() == 0)
genInferredTypeCollectiveParamBuilder();
} }
void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList, void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
@ -1460,8 +1269,9 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
auto numResults = op.getNumResults(); auto numResults = op.getNumResults();
resultTypeNames.reserve(numResults); resultTypeNames.reserve(numResults);
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); // paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::builder::Builder &", "builder");
// paramList.emplace_back("::mlir::OperationState &", builderOpState);
switch (typeParamKind) { switch (typeParamKind) {
case TypeParamKind::None: case TypeParamKind::None:
@ -1475,7 +1285,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
resultName = std::string(formatv("resultType{0}", i)); resultName = std::string(formatv("resultType{0}", i));
StringRef type = StringRef type =
result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type"; result.isVariadic() ? "std::vector<::builder::Type>" : "::builder::Type";
OpMethodParameter::Property properties = OpMethodParameter::PP_None; OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (result.isOptional()) if (result.isOptional())
properties = OpMethodParameter::PP_Optional; properties = OpMethodParameter::PP_Optional;
@ -1512,7 +1322,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
// for APFloat. // for APFloat.
// TODO: Adjust the 'returnType' field of such attributes // TODO: Adjust the 'returnType' field of such attributes
// to support them. // to support them.
StringRef retType = namedAttr->attr.getReturnType(); StringRef retType = getReturnType(namedAttr->attr);
if (retType == "::llvm::APInt" || retType == "::llvm::APFloat") if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
break; break;
@ -1525,7 +1335,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
if (argument.is<tblgen::NamedTypeConstraint *>()) { if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &operand = op.getOperand(numOperands); const auto &operand = op.getOperand(numOperands);
StringRef type = StringRef type =
operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value"; operand.isVariadic() ? "std::vector<::builder::Op>" : "::builder::Op";
OpMethodParameter::Property properties = OpMethodParameter::PP_None; OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (operand.isOptional()) if (operand.isOptional())
properties = OpMethodParameter::PP_Optional; properties = OpMethodParameter::PP_Optional;
@ -1544,13 +1354,13 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
StringRef type; StringRef type;
switch (attrParamKind) { switch (attrParamKind) {
case AttrParamKind::WrappedAttr: case AttrParamKind::WrappedAttr:
type = attr.getStorageType(); type = getStorageType(attr);
break; break;
case AttrParamKind::UnwrappedValue: case AttrParamKind::UnwrappedValue:
if (canUseUnwrappedRawValue(attr)) if (canUseUnwrappedRawValue(attr))
type = attr.getReturnType(); type = getReturnType(attr);
else else
type = attr.getStorageType(); type = getStorageType(attr);
break; break;
} }
@ -1558,7 +1368,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
// Attach default value if requested and possible. // Attach default value if requested and possible.
if (attrParamKind == AttrParamKind::UnwrappedValue && if (attrParamKind == AttrParamKind::UnwrappedValue &&
i >= defaultValuedAttrStartIndex) { i >= defaultValuedAttrStartIndex) {
bool isString = attr.getReturnType() == "::llvm::StringRef"; bool isString = getReturnType(attr) == "::llvm::StringRef";
if (isString) if (isString)
defaultValue.append("\""); defaultValue.append("\"");
defaultValue += attr.getDefaultValue(); defaultValue += attr.getDefaultValue();
@ -1584,84 +1394,117 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
llvm::formatv("{0}Count", region.name).str()); llvm::formatv("{0}Count", region.name).str());
} }
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
bool isRawValueAttr) { OpMethodBody &body, llvm::SmallVector<OpMethodParameter, 4> paramList,
// Push all operands to the result. bool isRawValueAttr) {
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
std::string argName = getArgumentName(op, i); body << "// AAAAAA \n";
if (op.getOperand(i).isOptional()) // for(auto p : paramList)
body << " if (" << argName << ")\n "; // {
body << " " << builderOpState << ".addOperands(" << argName << ");\n"; // body <<"==== type:"<< p.getType() << " name:"<< p.getName() << "\n";
// }
body << " auto builder = builder.GetImpl();\n";
body << " auto loc = builder->GetLoc();\n";
body << " auto opBuilder = builder->GetBuilder();\n";
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
<< " currentOp =\n";
body << " opBuilder.create<mlir::" << op.getDialectName()
<< "::" << op.getDialectName() << ">(\n";
if (paramList.size() > 1) {
body << " loc,\n";
std::for_each(paramList.begin() + 1, paramList.end() - 1,
[&](OpMethodParameter &p) {
body << " " << p.getName() << ",\n";
});
body << " " << paramList.back().getName() << "\n";
} else {
body << " loc\n";
} }
body << " );\n";
body << " builder::" << op.getCppClassName() << " builderOp;\n";
body << " auto opImpl = builderOp.GetImpl();\n";
body << " opImpl.SetOperation(currentOp.getOperation());\n";
body << " return builderOp;\n";
// If the operation has the operand segment size attribute, add it here.
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
body << " " << builderOpState
<< ".addAttribute(\"operand_segment_sizes\", "
"odsBuilder.getI32VectorAttr({";
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
if (op.getOperand(i).isOptional())
body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
else if (op.getOperand(i).isVariadic())
body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
else
body << "1";
});
body << "}));\n";
}
// Push all attributes to the result.
for (const auto &namedAttr : op.getAttributes()) {
auto &attr = namedAttr.attr;
if (!attr.isDerivedAttr()) {
bool emitNotNullCheck = attr.isOptional();
if (emitNotNullCheck) {
body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
}
if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
// If this is a raw value, then we need to wrap it in an Attribute
// instance.
FmtContext fctx;
fctx.withBuilder("odsBuilder");
std::string builderTemplate = // // Push all operands to the result.
std::string(attr.getConstBuilderTemplate()); // for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
// std::string argName = getArgumentName(op, i);
// if (op.getOperand(i).isOptional())
// body << " if (" << argName << ")\n ";
// body << " " << builderOpState << ".addOperands(" << argName << ");\n";
// }
// For StringAttr, its constant builder call will wrap the input in // // If the operation has the operand segment size attribute, add it here.
// quotes, which is correct for normal string literals, but incorrect // if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
// here given we use function arguments. So we need to strip the // body << " " << builderOpState
// wrapping quotes. // << ".addAttribute(\"operand_segment_sizes\", "
if (StringRef(builderTemplate).contains("\"$0\"")) // "odsBuilder.getI32VectorAttr({";
builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); // interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
// if (op.getOperand(i).isOptional())
// body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
// else if (op.getOperand(i).isVariadic())
// body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
// else
// body << "1";
// });
// body << "}));\n";
// }
std::string value = // // Push all attributes to the result.
std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); // for (const auto &namedAttr : op.getAttributes()) {
body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState, // auto &attr = namedAttr.attr;
namedAttr.name, value); // if (!attr.isDerivedAttr()) {
} else { // bool emitNotNullCheck = attr.isOptional();
body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState, // if (emitNotNullCheck) {
namedAttr.name); // body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
} // }
if (emitNotNullCheck) { // if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
body << " }\n"; // // If this is a raw value, then we need to wrap it in an Attribute
} // // instance.
} // FmtContext fctx;
} // fctx.withBuilder("odsBuilder");
// Create the correct number of regions. // std::string builderTemplate =
for (const NamedRegion &region : op.getRegions()) { // std::string(attr.getConstBuilderTemplate());
if (region.isVariadic())
body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
region.name);
body << " (void)" << builderOpState << ".addRegion();\n"; // // For StringAttr, its constant builder call will wrap the input in
} // // quotes, which is correct for normal string literals, but incorrect
// // here given we use function arguments. So we need to strip the
// // wrapping quotes.
// if (StringRef(builderTemplate).contains("\"$0\""))
// builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
// Push all successors to the result. // std::string value =
for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) { // std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
body << formatv(" {0}.addSuccessors({1});\n", builderOpState, // body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
namedSuccessor.name); // namedAttr.name, value);
} // } else {
// body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
// namedAttr.name);
// }
// if (emitNotNullCheck) {
// body << " }\n";
// }
// }
// }
// // Create the correct number of regions.
// for (const NamedRegion &region : op.getRegions()) {
// if (region.isVariadic())
// body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
// region.name);
// body << " (void)" << builderOpState << ".addRegion();\n";
// }
// // Push all successors to the result.
// for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
// body << formatv(" {0}.addSuccessors({1});\n", builderOpState,
// namedSuccessor.name);
// }
} }
void OpEmitter::genCanonicalizerDecls() { void OpEmitter::genCanonicalizerDecls() {
@ -2200,7 +2043,7 @@ void OpEmitter::genOpNameGetter() {
auto *method = opClass.addMethodAndPrune( auto *method = opClass.addMethodAndPrune(
"std::string", "getOperationName", "std::string", "getOperationName",
OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr)); OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
method->body() << " return std::string(\"" << op.getOperationName() method->body() << " return std::string(\"" << op.getOperationName()
<< "\");"; << "\");";
} }
@ -2305,15 +2148,15 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
auto emitAttr = [&](StringRef name, Attribute attr) { auto emitAttr = [&](StringRef name, Attribute attr) {
auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body(); auto &body = adaptor.addMethodAndPrune(getStorageType(attr), name)->body();
body << " assert(odsAttrs && \"no attributes when constructing adapter\");" body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
<< "\n " << attr.getStorageType() << " attr = " << "\n " << getStorageType(attr) << " attr = "
<< "odsAttrs.get(\"" << name << "\")."; << "odsAttrs.get(\"" << name << "\").";
if (attr.hasDefaultValue() || attr.isOptional()) if (attr.hasDefaultValue() || attr.isOptional())
body << "dyn_cast_or_null<"; body << "dyn_cast_or_null<";
else else
body << "cast<"; body << "cast<";
body << attr.getStorageType() << ">();\n"; body << getStorageType(attr) << ">();\n";
if (attr.hasDefaultValue()) { if (attr.hasDefaultValue()) {
// Use the default value if attribute is not set. // Use the default value if attribute is not set.
@ -2442,11 +2285,11 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
Operator op(*def); Operator op(*def);
NamespaceEmitter emitter(os, op.getCppNamespace()); NamespaceEmitter emitter(os, op.getCppNamespace());
if (emitDecl) { if (emitDecl) {
os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); // os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
// OpOperandAdaptorEmitter::emitDecl(op, os); // OpOperandAdaptorEmitter::emitDecl(op, os);
OpEmitter::emitDecl(op, os, staticVerifierEmitter); OpEmitter::emitDecl(op, os, staticVerifierEmitter);
} else { } else {
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); // os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
// OpOperandAdaptorEmitter::emitDef(op, os); // OpOperandAdaptorEmitter::emitDef(op, os);
OpEmitter::emitDef(op, os, staticVerifierEmitter); OpEmitter::emitDef(op, os, staticVerifierEmitter);
} }
@ -2456,12 +2299,18 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
// Emits a comma-separated list of the ops. // Emits a comma-separated list of the ops.
static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) { static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
IfDefScope scope("GET_OP_LIST", os); IfDefScope scope("GET_OP_LIST", os);
interleave( interleave(
// TODO: We are constructing the Operator wrapper instance just for // TODO: We are constructing the Operator wrapper instance just for
// getting it's qualified class name here. Reduce the overhead by having a // getting it's qualified class name here. Reduce the overhead by having a
// lightweight version of Operator class just for that purpose. // lightweight version of Operator class just for that purpose.
defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); }, defs,
[&os](Record *def) {
SmallVector<StringRef, 4> namespaces;
std::string className = Operator(def).getQualCppClassName();
llvm::SplitString(StringRef(className), namespaces, StringRef("::"));
if (namespaces.begin() != namespaces.end())
os << "::builder::mhlo::" << namespaces.back().str();
},
[&os]() { os << ",\n"; }); [&os]() { os << ",\n"; });
} }

View File

@ -0,0 +1,30 @@
#ifndef BUILDER_BUILDERIMPL_
#define BUILDER_BUILDERIMPL_
#include "Builder.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace builder {
class Builder::Impl {
public:
Impl() {}
// mlir::Location GetLoc() { return mlir_loc_; }
// mlir::OpBuilder GetBuilder() { return mlir_builder_; }
mlir::MLIRContext *GetContext() { return &mlir_context_; }
private:
// mlir::Location mlir_loc_;
// mlir::OpBuilder mlir_builder_;
mlir::MLIRContext mlir_context_;
};
} // namespace builder
#endif

View File

@ -0,0 +1,16 @@
#ifndef BUILDER_OP_
#define BUILDER_OP_
#include "iostream"
namespace builder {
class Op {
class Impl;
std::shared_ptr<Impl> GetImpl() { return _impl; }
private:
std::shared_ptr<Impl> _impl;
}
} // namespace builder
#endif

View File

@ -0,0 +1,26 @@
#ifndef BUILDER_OPIMPL_
#define BUILDER_OPIMPL_
#include "Builder.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
namespace builder {
class Op::Impl {
public:
Impl() = default;
void SetOperation(Operation *Op) { op_ = Op; }
private:
Operation *op_;
};
} // namespace builder
#endif

View File

@ -49,14 +49,24 @@ public:
~NamespaceEmitter() { ~NamespaceEmitter() {
for (StringRef ns : llvm::reverse(namespaces)) for (StringRef ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n"; if (ns.equals("mlir"))
os << "} // namespace "
<< "builder"
<< "\n\n";
else
os << "} // namespace " << ns << "\n\n";
} }
private: private:
void emitNamespaceStarts(raw_ostream &os, StringRef cppNamespace) { void emitNamespaceStarts(raw_ostream &os, StringRef cppNamespace) {
llvm::SplitString(cppNamespace, namespaces, "::"); llvm::SplitString(cppNamespace, namespaces, "::");
for (StringRef ns : namespaces) for (StringRef ns : namespaces)
os << "namespace " << ns << " {\n"; if (ns.equals("mlir"))
os << "namespace "
<< "builder"
<< " {\n";
else
os << "namespace " << ns << " {\n";
} }
raw_ostream &os; raw_ostream &os;
SmallVector<StringRef, 2> namespaces; SmallVector<StringRef, 2> namespaces;

View File

@ -209,7 +209,7 @@ void OpMethod::writeDeclTo(raw_ostream &os) const {
else { else {
os << " {\n"; os << " {\n";
methodBody.writeTo(os); methodBody.writeTo(os);
os << "}"; os << " }";
} }
} }
@ -261,7 +261,7 @@ void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
} }
void Class::writeDeclTo(raw_ostream &os) const { void Class::writeDeclTo(raw_ostream &os) const {
bool hasPrivateMethod = false; bool hasPrivateMethod = false;
os << "class " << className << " {\n"; os << "class " << className << "\n";
os << "public:\n"; os << "public:\n";
forAllMethods([&](const OpMethod &method) { forAllMethods([&](const OpMethod &method) {
@ -311,10 +311,10 @@ void OpClass::addTrait(Twine trait) {
} }
void OpClass::writeDeclTo(raw_ostream &os) const { void OpClass::writeDeclTo(raw_ostream &os) const {
os << "class " << className << " : public ::mlir::Op<" << className; os << "class " << className;
for (const auto &trait : traitsVec) for (const auto &trait : traitsVec)
os << ", " << trait; os << ", " << trait;
os << "> {\npublic:\n"; os << " : public Op {\npublic:\n";
// << " using Op::Op;\n" // << " using Op::Op;\n"
// << " using Op::print;\n" // << " using Op::print;\n"
// << " using Adaptor = " << className << "Adaptor;\n"; // << " using Adaptor = " << className << "Adaptor;\n";
@ -330,8 +330,8 @@ void OpClass::writeDeclTo(raw_ostream &os) const {
}); });
// TODO: Add line control markers to make errors easier to debug. // TODO: Add line control markers to make errors easier to debug.
if (!extraClassDeclaration.empty()) // if (!extraClassDeclaration.empty())
os << extraClassDeclaration << "\n"; // os << extraClassDeclaration << "\n";
if (hasPrivateMethod) { if (hasPrivateMethod) {
os << "\nprivate:\n"; os << "\nprivate:\n";

View File

@ -61,6 +61,8 @@ public:
void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); } void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
const std::string &getType() const { return type; } const std::string &getType() const { return type; }
const std::string &getName() const { return name; }
bool hasDefaultValue() const { return !defaultValue.empty(); } bool hasDefaultValue() const { return !defaultValue.empty(); }
private: private:

View File

@ -0,0 +1,16 @@
#ifndef BUILDER_TENSOR_
#define BUILDER_TENSOR_
#include "iostream"
namespace builder{
class Tensor{
}
}
#endif

View File

@ -0,0 +1,10 @@
#ifndef BUILDER_TYPE_
#define BUILDER_TYPE_
#include "iostream"
namespace builder {
class Type {}
} // namespace builder
#endif