diff --git a/.gitignore b/.gitignore index 53e8335..a2d5f5a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ build llvm-project llvm-build bazel-* +bazel-bin +.vscode diff --git a/BUILD b/BUILD index adb73c7..9a45be0 100644 --- a/BUILD +++ b/BUILD @@ -122,6 +122,46 @@ gentbl_cc_library( deps = [":hlo_ops_td_files"], ) +cc_binary( + name = "mlir-tblgen-builder", + srcs = glob([ + "tools/mlir-tblgen-builder/*.h", + "tools/mlir-tblgen-builder/*.cpp", + "tools/mlir-tblgen-builder/TableGen/*.h", + "tools/mlir-tblgen-builder/TableGen/*.cpp", + ]), + deps = [ + "@llvm-project//mlir:MlirTableGenMain", + "@llvm-project//mlir:Support", + # "@llvm-project//mlir:TableGen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TableGen", + "@llvm-project//llvm:config", + ], +) + +gentbl_cc_library( + name = "hlo_ops_builder_gen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-builder-decls"], + "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.h.inc", + ), + ( + ["-gen-builder-defs"], + "include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.cc.inc", + ), + ], + tblgen = ":mlir-tblgen-builder", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", + td_includes = [ + "external/mlir-hlo/include", + "include", + ], + deps = [":hlo_ops_td_files"], +) + gentbl_cc_library( name = "hlo_ops_base_inc_gen", strip_include_prefix = "include", @@ -519,6 +559,7 @@ cc_library( ":hlo_ops_base_structs", ":hlo_ops_common", ":hlo_ops_inc_gen", + ":hlo_ops_builder_gen", ":hlo_ops_pattern_gen", ":infer_fusibility_op_interface", "@llvm-project//llvm:Support", diff --git a/tools/mlir-tblgen-builder/BuilderDefinitionsGen.cpp b/tools/mlir-tblgen-builder/BuilderDefinitionsGen.cpp new file mode 100644 index 0000000..c2cc5bf --- /dev/null +++ b/tools/mlir-tblgen-builder/BuilderDefinitionsGen.cpp @@ -0,0 +1,2497 @@ +//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// OpDefinitionsGen uses the description of operations to generate C++ +// definitions for ops. +// +//===----------------------------------------------------------------------===// + +#include "OpFormatGen.h" +#include "OpGenHelpers.h" +#include "TableGen/CodeGenHelpers.h" +#include "TableGen/Format.h" +#include "TableGen/GenInfo.h" +#include "TableGen/Interfaces.h" +#include "TableGen/OpClass.h" +#include "TableGen/Operator.h" +#include "TableGen/SideEffects.h" +#include "TableGen/Trait.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/Signals.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +#define DEBUG_TYPE "mlir-tblgen-opdefgen" + +using namespace llvm; +using namespace mlir; +using namespace mlir::tblgen; + +static const char *const tblgenNamePrefix = "tblgen_"; +static const char *const generatedArgName = "odsArg"; +static const char *const odsBuilder = "odsBuilder"; +static const char *const builderOpState = "odsState"; + +// The logic to calculate the actual value range for a declared operand/result +// of an op with variadic operands/results. Note that this logic is not for +// general use; it assumes all variadic operands/results must have the same +// number of values. +// +// {0}: The list of whether each declared operand/result is variadic. +// {1}: The total number of non-variadic operands/results. +// {2}: The total number of variadic operands/results. +// {3}: The total number of actual values. +// {4}: "operand" or "result". 
+const char *sameVariadicSizeValueRangeCalcCode = R"( + bool isVariadic[] = {{{0}}; + int prevVariadicCount = 0; + for (unsigned i = 0; i < index; ++i) + if (isVariadic[i]) ++prevVariadicCount; + + // Calculate how many dynamic values a static variadic {4} corresponds to. + // This assumes all static variadic {4}s have the same dynamic value count. + int variadicSize = ({3} - {1}) / {2}; + // `index` passed in as the parameter is the static index which counts each + // {4} (variadic or not) as size 1. So here for each previous static variadic + // {4}, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static {4} starts. + int start = index + (variadicSize - 1) * prevVariadicCount; + int size = isVariadic[index] ? variadicSize : 1; + return {{start, size}; +)"; + +// The logic to calculate the actual value range for a declared operand/result +// of an op with variadic operands/results. Note that this logic is assumes +// the op has an attribute specifying the size of each operand/result segment +// (variadic or not). +// +// {0}: The name of the attribute specifying the segment sizes. +const char *adapterSegmentSizeAttrInitCode = R"( + assert(odsAttrs && "missing segment size attribute for op"); + auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>(); +)"; +const char *opSegmentSizeAttrInitCode = R"( + auto sizeAttr = (*this)->getAttr("{0}").cast<::mlir::DenseIntElementsAttr>(); +)"; +const char *attrSizedSegmentValueRangeCalcCode = R"( + auto sizeAttrValues = sizeAttr.getValues(); + unsigned start = 0; + for (unsigned i = 0; i < index; ++i) + start += *(sizeAttrValues.begin() + i); + unsigned size = *(sizeAttrValues.begin() + index); + return {start, size}; +)"; + +// The logic to build a range of either operand or result values. +// +// {0}: The begin iterator of the actual values. +// {1}: The call to generate the start and length of the value range. 
+const char *valueRangeReturnCode = R"( + auto valueRange = {1}; + return {{std::next({0}, valueRange.first), + std::next({0}, valueRange.first + valueRange.second)}; +)"; + +static const char *const opCommentHeader = R"( +//===----------------------------------------------------------------------===// +// {0} {1} +//===----------------------------------------------------------------------===// + +)"; + +//===----------------------------------------------------------------------===// +// StaticVerifierFunctionEmitter +//===----------------------------------------------------------------------===// + +namespace { +/// This class deduplicates shared operation verification code by emitting +/// static functions alongside the op definitions. These methods are local to +/// the definition file, and are invoked within the operation verify methods. +/// An example is shown below: +/// +/// static LogicalResult localVerify(...) +/// +/// LogicalResult OpA::verify(...) { +/// if (failed(localVerify(...))) +/// return failure(); +/// ... +/// } +/// +/// LogicalResult OpB::verify(...) { +/// if (failed(localVerify(...))) +/// return failure(); +/// ... +/// } +/// +class StaticVerifierFunctionEmitter { +public: + StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records, + ArrayRef opDefs, + raw_ostream &os, bool emitDecl); + + /// Get the name of the local function used for the given type constraint. + /// These functions are used for operand and result constraints and have the + /// form: + /// LogicalResult(Operation *op, Type type, StringRef valueKind, + /// unsigned valueGroupStartIndex); + StringRef getTypeConstraintFn(const Constraint &constraint) const { + auto it = localTypeConstraints.find(constraint.getAsOpaquePointer()); + assert(it != localTypeConstraints.end() && "expected valid constraint fn"); + return it->second; + } + +private: + /// Returns a unique name to use when generating local methods. 
+ static std::string getUniqueName(const llvm::RecordKeeper &records); + + /// Emit local methods for the type constraints used within the provided op + /// definitions. + void emitTypeConstraintMethods(ArrayRef opDefs, + raw_ostream &os, bool emitDecl); + + /// A unique label for the file currently being generated. This is used to + /// ensure that the local functions have a unique name. + std::string uniqueOutputLabel; + + /// A set of functions implementing type constraints, used for operand and + /// result verification. + llvm::DenseMap localTypeConstraints; +}; +} // namespace + +StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( + const llvm::RecordKeeper &records, ArrayRef opDefs, + raw_ostream &os, bool emitDecl) + : uniqueOutputLabel(getUniqueName(records)) { + llvm::Optional namespaceEmitter; + if (!emitDecl) { + os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); + namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace()); + } + + emitTypeConstraintMethods(opDefs, os, emitDecl); +} + +std::string StaticVerifierFunctionEmitter::getUniqueName( + const llvm::RecordKeeper &records) { + // Use the input file name when generating a unique name. + std::string inputFilename = records.getInputFilename(); + + // Drop all but the base filename. + StringRef nameRef = llvm::sys::path::filename(inputFilename); + nameRef.consume_back(".td"); + + // Sanitize any invalid characters. + std::string uniqueName; + for (char c : nameRef) { + if (llvm::isAlnum(c) || c == '_') + uniqueName.push_back(c); + else + uniqueName.append(llvm::utohexstr((unsigned char)c)); + } + return uniqueName; +} + +void StaticVerifierFunctionEmitter::emitTypeConstraintMethods( + ArrayRef opDefs, raw_ostream &os, bool emitDecl) { + // Collect a set of all of the used type constraints within the operation + // definitions. 
+ llvm::SetVector typeConstraints; + for (Record *def : opDefs) { + Operator op(*def); + for (NamedTypeConstraint &operand : op.getOperands()) + if (operand.hasPredicate()) + typeConstraints.insert(operand.constraint.getAsOpaquePointer()); + for (NamedTypeConstraint &result : op.getResults()) + if (result.hasPredicate()) + typeConstraints.insert(result.constraint.getAsOpaquePointer()); + } + + FmtContext fctx; + for (auto it : llvm::enumerate(typeConstraints)) { + // Generate an obscure and unique name for this type constraint. + std::string name = (Twine("__mlir_ods_local_type_constraint_") + + uniqueOutputLabel + Twine(it.index())) + .str(); + localTypeConstraints.try_emplace(it.value(), name); + + // Only generate the methods if we are generating definitions. + if (emitDecl) + continue; + + Constraint constraint = Constraint::getFromOpaquePointer(it.value()); + os << "static ::mlir::LogicalResult " << name + << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef " + "valueKind, unsigned valueGroupStartIndex) {\n"; + + os << " if (!(" + << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type")) + << ")) {\n" + << formatv( + " return op->emitOpError(valueKind) << \" #\" << " + "valueGroupStartIndex << \" must be {0}, but got \" << type;\n", + constraint.getSummary()) + << " }\n" + << " return ::mlir::success();\n" + << "}\n\n"; + } +} + +//===----------------------------------------------------------------------===// +// Utility structs and functions +//===----------------------------------------------------------------------===// + +// Replaces all occurrences of `match` in `str` with `substitute`. 
+static std::string replaceAllSubstrs(std::string str, const std::string &match, + const std::string &substitute) { + std::string::size_type scanLoc = 0, matchLoc = std::string::npos; + while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) { + str = str.replace(matchLoc, match.size(), substitute); + scanLoc = matchLoc + substitute.size(); + } + return str; +} + +// Returns whether the record has a value of the given name that can be returned +// via getValueAsString. +static inline bool hasStringAttribute(const Record &record, + StringRef fieldName) { + auto valueInit = record.getValueInit(fieldName); + return isa(valueInit); +} + +static std::string getArgumentName(const Operator &op, int index) { + const auto &operand = op.getOperand(index); + if (!operand.name.empty()) + return std::string(operand.name); + else + return std::string(formatv("{0}_{1}", generatedArgName, index)); +} + +// Returns true if we can use unwrapped value for the given `attr` in builders. +static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) { + return attr.getReturnType() != attr.getStorageType() && + // We need to wrap the raw value into an attribute in the builder impl + // so we need to make sure that the attribute specifies how to do that. + !attr.getConstBuilderTemplate().empty(); +} + +//===----------------------------------------------------------------------===// +// Op emitter +//===----------------------------------------------------------------------===// + +namespace { +// Helper class to emit a record into the given output stream. 
+class OpEmitter { +public: + static void + emitDecl(const Operator &op, raw_ostream &os, + const StaticVerifierFunctionEmitter &staticVerifierEmitter); + static void + emitDef(const Operator &op, raw_ostream &os, + const StaticVerifierFunctionEmitter &staticVerifierEmitter); + +private: + OpEmitter(const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter); + + void emitDecl(raw_ostream &os); + void emitDef(raw_ostream &os); + + // Generates the OpAsmOpInterface for this operation if possible. + void genOpAsmInterface(); + + // Generates the `getOperationName` method for this op. + void genOpNameGetter(); + + // Generates getters for the attributes. + void genAttrGetters(); + + // Generates setter for the attributes. + void genAttrSetters(); + + // Generates removers for optional attributes. + void genOptionalAttrRemovers(); + + // Generates getters for named operands. + void genNamedOperandGetters(); + + // Generates setters for named operands. + void genNamedOperandSetters(); + + // Generates getters for named results. + void genNamedResultGetters(); + + // Generates getters for named regions. + void genNamedRegionGetters(); + + // Generates getters for named successors. + void genNamedSuccessorGetters(); + + // Generates builder methods for the operation. + void genBuilder(); + + // Generates the build() method that takes each operand/attribute + // as a stand-alone parameter. + void genSeparateArgParamBuilder(); + + // Generates the build() method that takes each operand/attribute as a + // stand-alone parameter. The generated build() method uses first operand's + // type as all results' types. + void genUseOperandAsResultTypeSeparateParamBuilder(); + + // Generates the build() method that takes all operands/attributes + // collectively as one parameter. The generated build() method uses first + // operand's type as all results' types. 
+ void genUseOperandAsResultTypeCollectiveParamBuilder(); + + // Generates the build() method that takes aggregate operands/attributes + // parameters. This build() method uses inferred types as result types. + // Requires: The type needs to be inferable via InferTypeOpInterface. + void genInferredTypeCollectiveParamBuilder(); + + // Generates the build() method that takes each operand/attribute as a + // stand-alone parameter. The generated build() method uses first attribute's + // type as all result's types. + void genUseAttrAsResultTypeBuilder(); + + // Generates the build() method that takes all result types collectively as + // one parameter. Similarly for operands and attributes. + void genCollectiveParamBuilder(); + + // The kind of parameter to generate for result types in builders. + enum class TypeParamKind { + None, // No result type in parameter list. + Separate, // A separate parameter for each result type. + Collective, // An ArrayRef for all result types. + }; + + // The kind of parameter to generate for attributes in builders. + enum class AttrParamKind { + WrappedAttr, // A wrapped MLIR Attribute instance. + UnwrappedValue, // A raw value without MLIR Attribute wrapper. + }; + + // Builds the parameter list for build() method of this op. This method writes + // to `paramList` the comma-separated parameter list and updates + // `resultTypeNames` with the names for parameters for specifying result + // types. The given `typeParamKind` and `attrParamKind` controls how result + // types and attributes are placed in the parameter list. + void buildParamList(llvm::SmallVectorImpl ¶mList, + SmallVectorImpl &resultTypeNames, + TypeParamKind typeParamKind, + AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); + + // Adds op arguments and regions into operation state for build() methods. + void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, + bool isRawValueAttr = false); + + // Generates canonicalizer declaration for the operation. 
+ void genCanonicalizerDecls(); + + // Generates the folder declaration for the operation. + void genFolderDecls(); + + // Generates the parser for the operation. + void genParser(); + + // Generates the printer for the operation. + void genPrinter(); + + // Generates verify method for the operation. + void genVerifier(); + + // Generates verify statements for operands and results in the operation. + // The generated code will be attached to `body`. + void genOperandResultVerifier(OpMethodBody &body, + Operator::value_range values, + StringRef valueKind); + + // Generates verify statements for regions in the operation. + // The generated code will be attached to `body`. + void genRegionVerifier(OpMethodBody &body); + + // Generates verify statements for successors in the operation. + // The generated code will be attached to `body`. + void genSuccessorVerifier(OpMethodBody &body); + + // Generates the traits used by the object. + void genTraits(); + + // Generate the OpInterface methods for all interfaces. + void genOpInterfaceMethods(); + + // Generate op interface methods for the given interface. + void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait); + + // Generate op interface method for the given interface method. If + // 'declaration' is true, generates a declaration, else a definition. + OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method, + bool declaration = true); + + // Generate the side effect interface methods. + void genSideEffectInterfaceMethods(); + + // Generate the type inference interface methods. + void genTypeInterfaceMethods(); + +private: + // The TableGen record for this op. + // TODO: OpEmitter should not have a Record directly, + // it should rather go through the Operator for better abstraction. + const Record &def; + + // The wrapper operator class for querying information from this op. 
+ Operator op; + + // The C++ code builder for this op + OpClass opClass; + + // The format context for verification code generation. + FmtContext verifyCtx; + + // The emitter containing all of the locally emitted verification functions. + const StaticVerifierFunctionEmitter &staticVerifierEmitter; +}; +} // end anonymous namespace + +// Populate the format context `ctx` with substitutions of attributes, operands +// and results. +// - attrGet corresponds to the name of the function to call to get value of +// attribute (the generated function call returns an Attribute); +// - operandGet corresponds to the name of the function with which to retrieve +// an operand (the generated function call returns an OperandRange); +// - resultGet corresponds to the name of the function to get an result (the +// generated function call returns a ValueRange); +static void populateSubstitutions(const Operator &op, const char *attrGet, + const char *operandGet, const char *resultGet, + FmtContext &ctx) { + // Populate substitutions for attributes and named operands. + for (const auto &namedAttr : op.getAttributes()) + ctx.addSubst(namedAttr.name, + formatv("{0}(\"{1}\")", attrGet, namedAttr.name)); + for (int i = 0, e = op.getNumOperands(); i < e; ++i) { + auto &value = op.getOperand(i); + if (value.name.empty()) + continue; + + if (value.isVariadic()) + ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i)); + else + ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i)); + } + + // Populate substitutions for results. + for (int i = 0, e = op.getNumResults(); i < e; ++i) { + auto &value = op.getResult(i); + if (value.name.empty()) + continue; + + if (value.isVariadic()) + ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i)); + else + ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i)); + } +} + +// Generate attribute verification. 
If emitVerificationRequiringOp is set then +// only verification for attributes whose value depend on op being known are +// emitted, else only verification that doesn't depend on the op being known are +// generated. +// - emitErrorPrefix is the prefix for the error emitting call which consists +// of the entire function call up to start of error message fragment; +// - emitVerificationRequiringOp specifies whether verification should be +// emitted for verification that require the op to exist; +static void genAttributeVerifier(const Operator &op, const char *attrGet, + const Twine &emitErrorPrefix, + bool emitVerificationRequiringOp, + FmtContext &ctx, OpMethodBody &body) { + for (const auto &namedAttr : op.getAttributes()) { + const auto &attr = namedAttr.attr; + if (attr.isDerivedAttr()) + continue; + + auto attrName = namedAttr.name; + bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional(); + auto attrPred = attr.getPredicate(); + auto condition = attrPred.isNull() ? "" : attrPred.getCondition(); + // There is a condition to emit only if the use of $_op and whether to + // emit verifications for op matches. + bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^ + emitVerificationRequiringOp); + + // Prefix with `tblgen_` to avoid hiding the attribute accessor. + auto varName = tblgenNamePrefix + attrName; + + // If the attribute is + // 1. Required (not allowed missing) and not in op verification, or + // 2. Has a condition that will get verified + // then the variable will be used. + // + // Therefore, for optional attributes whose verification requires that an + // op already exists for verification/emitVerificationRequiringOp is set + // has nothing that can be verified here. 
+ if ((allowMissingAttr || emitVerificationRequiringOp) && + !hasConditionToEmit) + continue; + + body << formatv(" {\n auto {0} = {1}(\"{2}\");\n", varName, attrGet, + attrName); + + if (!emitVerificationRequiringOp && !allowMissingAttr) { + body << " if (!" << varName << ") return " << emitErrorPrefix + << "\"requires attribute '" << attrName << "'\");\n"; + } + + if (!hasConditionToEmit) { + body << " }\n"; + continue; + } + + if (allowMissingAttr) { + // If the attribute has a default value, then only verify the predicate if + // set. This does effectively assume that the default value is valid. + // TODO: verify the debug value is valid (perhaps in debug mode only). + body << " if (" << varName << ") {\n"; + } + + body << tgfmt(" if (!($0)) return $1\"attribute '$2' " + "failed to satisfy constraint: $3\");\n", + /*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)), + emitErrorPrefix, attrName, attr.getSummary()); + if (allowMissingAttr) + body << " }\n"; + body << " }\n"; + } +} + +OpEmitter::OpEmitter(const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter) + : def(op.getDef()), op(op), + opClass(op.getCppClassName(), op.getExtraClassDeclaration()), + staticVerifierEmitter(staticVerifierEmitter) { + verifyCtx.withOp("(*this->getOperation())"); + verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()"); + + // Dot not need traits in buider + // genTraits(); + + // Generate C++ code for various op methods. The order here determines the + // methods in the generated file. 
+ // genOpAsmInterface(); + genOpNameGetter(); + genNamedOperandGetters(); + genNamedOperandSetters(); + genNamedResultGetters(); + genNamedRegionGetters(); + genNamedSuccessorGetters(); + genAttrGetters(); + genAttrSetters(); + genOptionalAttrRemovers(); + genBuilder(); + genParser(); + genPrinter(); + genVerifier(); + genCanonicalizerDecls(); + genFolderDecls(); + genTypeInterfaceMethods(); + genOpInterfaceMethods(); + generateOpFormat(op, opClass); + genSideEffectInterfaceMethods(); +} + +void OpEmitter::emitDecl( + const Operator &op, raw_ostream &os, + const StaticVerifierFunctionEmitter &staticVerifierEmitter) { + OpEmitter(op, staticVerifierEmitter).emitDecl(os); +} + +void OpEmitter::emitDef( + const Operator &op, raw_ostream &os, + const StaticVerifierFunctionEmitter &staticVerifierEmitter) { + OpEmitter(op, staticVerifierEmitter).emitDef(os); +} + +void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); } + +void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); } + +void OpEmitter::genAttrGetters() { + FmtContext fctx; + fctx.withBuilder("::mlir::Builder((*this)->getContext())"); + + Dialect opDialect = op.getDialect(); + // Emit the derived attribute body. + auto emitDerivedAttr = [&](StringRef name, Attribute attr) { + auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); + if (!method) + return; + auto &body = method->body(); + body << " " << attr.getDerivedCodeBody() << "\n"; + }; + + // Emit with return type specified. + auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) { + auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); + auto &body = method->body(); + body << " auto attr = " << name << "Attr();\n"; + if (attr.hasDefaultValue()) { + // Returns the default value if not set. + // TODO: this is inefficient, we are recreating the attribute for every + // call. This should be set instead. 
+ std::string defaultValue = std::string( + tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); + body << " if (!attr)\n return " + << tgfmt(attr.getConvertFromStorageCall(), + &fctx.withSelf(defaultValue)) + << ";\n"; + } + body << " return " + << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr")) + << ";\n"; + }; + + // Generate raw named accessor type. This is a wrapper class that allows + // referring to the attributes via accessors instead of having to use + // the string interface for better compile time verification. + auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { + auto *method = + opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str()); + if (!method) + return; + auto &body = method->body(); + body << " return (*this)->getAttr(\"" << name << "\").template "; + if (attr.isOptional() || attr.hasDefaultValue()) + body << "dyn_cast_or_null<"; + else + body << "cast<"; + body << attr.getStorageType() << ">();"; + }; + + for (auto &namedAttr : op.getAttributes()) { + const auto &name = namedAttr.name; + const auto &attr = namedAttr.attr; + if (attr.isDerivedAttr()) { + emitDerivedAttr(name, attr); + } else { + emitAttrWithStorageType(name, attr); + emitAttrWithReturnType(name, attr); + } + } + + auto derivedAttrs = make_filter_range(op.getAttributes(), + [](const NamedAttribute &namedAttr) { + return namedAttr.attr.isDerivedAttr(); + }); + if (!derivedAttrs.empty()) { + opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait"); + // Generate helper method to query whether a named attribute is a derived + // attribute. This enables, for example, avoiding adding an attribute that + // overlaps with a derived attribute. 
+ { + auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute", + OpMethod::MP_Static, + "::llvm::StringRef", "name"); + auto &body = method->body(); + for (auto namedAttr : derivedAttrs) + body << " if (name == \"" << namedAttr.name << "\") return true;\n"; + body << " return false;"; + } + // Generate method to materialize derived attributes as a DictionaryAttr. + { + auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr", + "materializeDerivedAttributes"); + auto &body = method->body(); + + auto nonMaterializable = + make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) { + return namedAttr.attr.getConvertFromStorageCall().empty(); + }); + if (!nonMaterializable.empty()) { + std::string attrs; + llvm::raw_string_ostream os(attrs); + interleaveComma(nonMaterializable, os, + [&](const NamedAttribute &attr) { os << attr.name; }); + PrintWarning( + op.getLoc(), + formatv( + "op has non-materializable derived attributes '{0}', skipping", + os.str())); + body << formatv(" emitOpError(\"op has non-materializable derived " + "attributes '{0}'\");\n", + attrs); + body << " return nullptr;"; + return; + } + + body << " ::mlir::MLIRContext* ctx = getContext();\n"; + body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n"; + body << " return ::mlir::DictionaryAttr::get("; + body << " ctx, {\n"; + interleave( + derivedAttrs, body, + [&](const NamedAttribute &namedAttr) { + auto tmpl = namedAttr.attr.getConvertFromStorageCall(); + body << " {::mlir::Identifier::get(\"" << namedAttr.name + << "\", ctx),\n" + << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()") + .withBuilder("odsBuilder") + .addSubst("_ctx", "ctx")) + << "}"; + }, + ",\n"); + body << "});"; + } + } +} + +void OpEmitter::genAttrSetters() { + // Generate raw named setter type. This is a wrapper class that allows setting + // to the attributes via setters instead of having to use the string interface + // for better compile time verification. 
+ auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { + auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(), + attr.getStorageType(), "attr"); + if (!method) + return; + auto &body = method->body(); + body << " (*this)->setAttr(\"" << name << "\", attr);"; + }; + + for (auto &namedAttr : op.getAttributes()) { + const auto &name = namedAttr.name; + const auto &attr = namedAttr.attr; + if (!attr.isDerivedAttr()) + emitAttrWithStorageType(name, attr); + } +} + +void OpEmitter::genOptionalAttrRemovers() { + // Generate methods for removing optional attributes, instead of having to + // use the string interface. Enables better compile time verification. + auto emitRemoveAttr = [&](StringRef name) { + auto upperInitial = name.take_front().upper(); + auto suffix = name.drop_front(); + auto *method = opClass.addMethodAndPrune( + "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str()); + if (!method) + return; + auto &body = method->body(); + body << " return (*this)->removeAttr(\"" << name << "\");"; + }; + + for (const auto &namedAttr : op.getAttributes()) { + const auto &name = namedAttr.name; + const auto &attr = namedAttr.attr; + if (attr.isOptional()) + emitRemoveAttr(name); + } +} + +// Generates the code to compute the start and end index of an operand or result +// range. 
+template +static void +generateValueRangeStartAndEnd(Class &opClass, StringRef methodName, + int numVariadic, int numNonVariadic, + StringRef rangeSizeCall, bool hasAttrSegmentSize, + StringRef sizeAttrInit, RangeT &&odsValues) { + auto *method = opClass.addMethodAndPrune("std::pair", + methodName, "unsigned", "index"); + if (!method) + return; + auto &body = method->body(); + if (numVariadic == 0) { + body << " return {index, 1};\n"; + } else if (hasAttrSegmentSize) { + body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode; + } else { + // Because the op can have arbitrarily interleaved variadic and non-variadic + // operands, we need to embed a list in the "sink" getter method for + // calculation at run-time. + llvm::SmallVector isVariadic; + isVariadic.reserve(llvm::size(odsValues)); + for (auto &it : odsValues) + isVariadic.push_back(it.isVariableLength() ? "true" : "false"); + std::string isVariadicList = llvm::join(isVariadic, ", "); + body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, + numNonVariadic, numVariadic, rangeSizeCall, "operand"); + } +} + +// Generates the named operand getter methods for the given Operator `op` and +// puts them in `opClass`. Uses `rangeType` as the return type of getters that +// return a range of operands (individual operands are `Value ` and each +// element in the range must also be `Value `); use `rangeBeginCall` to get +// an iterator to the beginning of the operand range; use `rangeSizeCall` to +// obtain the number of operands. `getOperandCallPattern` contains the code +// necessary to obtain a single operand whose position will be substituted +// instead of +// "{0}" marker in the pattern. Note that the pattern should work for any kind +// of ops, in particular for one-operand ops that may not have the +// `getOperand(unsigned)` method. 
+static void generateNamedOperandGetters(const Operator &op, Class &opClass, + StringRef sizeAttrInit, + StringRef rangeType, + StringRef rangeBeginCall, + StringRef rangeSizeCall, + StringRef getOperandCallPattern) { + const int numOperands = op.getNumOperands(); + const int numVariadicOperands = op.getNumVariableLengthOperands(); + const int numNormalOperands = numOperands - numVariadicOperands; + + const auto *sameVariadicSize = + op.getTrait("::mlir::OpTrait::SameVariadicOperandSize"); + const auto *attrSizedOperands = + op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); + + if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) { + PrintFatalError(op.getLoc(), "op has multiple variadic operands but no " + "specification over their sizes"); + } + + if (numVariadicOperands < 2 && attrSizedOperands) { + PrintFatalError(op.getLoc(), "op must have at least two variadic operands " + "to use 'AttrSizedOperandSegments' trait"); + } + + if (attrSizedOperands && sameVariadicSize) { + PrintFatalError(op.getLoc(), + "op cannot have both 'AttrSizedOperandSegments' and " + "'SameVariadicOperandSize' traits"); + } + + // First emit a few "sink" getter methods upon which we layer all nicer named + // getter methods. + generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength", + numVariadicOperands, numNormalOperands, + rangeSizeCall, attrSizedOperands, sizeAttrInit, + const_cast(op).getOperands()); + + auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned", + "index"); + auto &body = m->body(); + body << formatv(valueRangeReturnCode, rangeBeginCall, + "getODSOperandIndexAndLength(index)"); + + // Then we emit nicer named getter methods by redirecting to the "sink" getter + // method. 
+ for (int i = 0; i != numOperands; ++i) { + const auto &operand = op.getOperand(i); + if (operand.name.empty()) + continue; + + if (operand.isOptional()) { + m = opClass.addMethodAndPrune("::mlir::Value", operand.name); + m->body() + << " auto operands = getODSOperands(" << i << ");\n" + << " return operands.empty() ? ::mlir::Value() : *operands.begin();"; + } else if (operand.isVariadic()) { + m = opClass.addMethodAndPrune(rangeType, operand.name); + m->body() << " return getODSOperands(" << i << ");"; + } else { + m = opClass.addMethodAndPrune("::mlir::Value", operand.name); + m->body() << " return *getODSOperands(" << i << ").begin();"; + } + } +} + +void OpEmitter::genNamedOperandGetters() { + generateNamedOperandGetters( + op, opClass, + /*sizeAttrInit=*/ + formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(), + /*rangeType=*/"::mlir::Operation::operand_range", + /*rangeBeginCall=*/"getOperation()->operand_begin()", + /*rangeSizeCall=*/"getOperation()->getNumOperands()", + /*getOperandCallPattern=*/"getOperation()->getOperand({0})"); +} + +void OpEmitter::genNamedOperandSetters() { + auto *attrSizedOperands = + op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); + for (int i = 0, e = op.getNumOperands(); i != e; ++i) { + const auto &operand = op.getOperand(i); + if (operand.name.empty()) + continue; + auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange", + (operand.name + "Mutable").str()); + auto &body = m->body(); + body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" + << " return ::mlir::MutableOperandRange(getOperation(), " + "range.first, range.second"; + if (attrSizedOperands) + body << ", ::mlir::MutableOperandRange::OperandSegment(" << i + << "u, *getOperation()->getAttrDictionary().getNamed(" + "\"operand_segment_sizes\"))"; + body << ");\n"; + } +} + +void OpEmitter::genNamedResultGetters() { + const int numResults = op.getNumResults(); + const int numVariadicResults = op.getNumVariableLengthResults(); 
+ const int numNormalResults = numResults - numVariadicResults; + + // If we have more than one variadic results, we need more complicated logic + // to calculate the value range for each result. + + const auto *sameVariadicSize = + op.getTrait("::mlir::OpTrait::SameVariadicResultSize"); + const auto *attrSizedResults = + op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"); + + if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) { + PrintFatalError(op.getLoc(), "op has multiple variadic results but no " + "specification over their sizes"); + } + + if (numVariadicResults < 2 && attrSizedResults) { + PrintFatalError(op.getLoc(), "op must have at least two variadic results " + "to use 'AttrSizedResultSegments' trait"); + } + + if (attrSizedResults && sameVariadicSize) { + PrintFatalError(op.getLoc(), + "op cannot have both 'AttrSizedResultSegments' and " + "'SameVariadicResultSize' traits"); + } + + generateValueRangeStartAndEnd( + opClass, "getODSResultIndexAndLength", numVariadicResults, + numNormalResults, "getOperation()->getNumResults()", attrSizedResults, + formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(), + op.getResults()); + + auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range", + "getODSResults", "unsigned", "index"); + m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()", + "getODSResultIndexAndLength(index)"); + + for (int i = 0; i != numResults; ++i) { + const auto &result = op.getResult(i); + if (result.name.empty()) + continue; + + if (result.isOptional()) { + m = opClass.addMethodAndPrune("::mlir::Value", result.name); + m->body() + << " auto results = getODSResults(" << i << ");\n" + << " return results.empty() ? 
::mlir::Value() : *results.begin();"; + } else if (result.isVariadic()) { + m = opClass.addMethodAndPrune("::mlir::Operation::result_range", + result.name); + m->body() << " return getODSResults(" << i << ");"; + } else { + m = opClass.addMethodAndPrune("::mlir::Value", result.name); + m->body() << " return *getODSResults(" << i << ").begin();"; + } + } +} + +void OpEmitter::genNamedRegionGetters() { + unsigned numRegions = op.getNumRegions(); + for (unsigned i = 0; i < numRegions; ++i) { + const auto ®ion = op.getRegion(i); + if (region.name.empty()) + continue; + + // Generate the accessors for a variadic region. + if (region.isVariadic()) { + auto *m = opClass.addMethodAndPrune( + "::mlir::MutableArrayRef<::mlir::Region>", region.name); + m->body() << formatv(" return (*this)->getRegions().drop_front({0});", + i); + continue; + } + + auto *m = opClass.addMethodAndPrune("::mlir::Region &", region.name); + m->body() << formatv(" return (*this)->getRegion({0});", i); + } +} + +void OpEmitter::genNamedSuccessorGetters() { + unsigned numSuccessors = op.getNumSuccessors(); + for (unsigned i = 0; i < numSuccessors; ++i) { + const NamedSuccessor &successor = op.getSuccessor(i); + if (successor.name.empty()) + continue; + + // Generate the accessors for a variadic successor list. + if (successor.isVariadic()) { + auto *m = + opClass.addMethodAndPrune("::mlir::SuccessorRange", successor.name); + m->body() << formatv( + " return {std::next((*this)->successor_begin(), {0}), " + "(*this)->successor_end()};", + i); + continue; + } + + auto *m = opClass.addMethodAndPrune("::mlir::Block *", successor.name); + m->body() << formatv(" return (*this)->getSuccessor({0});", i); + } +} + +static bool canGenerateUnwrappedBuilder(Operator &op) { + // If this op does not have native attributes at all, return directly to avoid + // redefining builders. 
+ if (op.getNumNativeAttributes() == 0) + return false; + + bool canGenerate = false; + // We are generating builders that take raw values for attributes. We need to + // make sure the native attributes have a meaningful "unwrapped" value type + // different from the wrapped mlir::Attribute type to avoid redefining + // builders. This checks for the op has at least one such native attribute. + for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) { + NamedAttribute &namedAttr = op.getAttribute(i); + if (canUseUnwrappedRawValue(namedAttr.attr)) { + canGenerate = true; + break; + } + } + return canGenerate; +} + +static bool canInferType(Operator &op) { + return op.getTrait("::mlir::InferTypeOpInterface::Trait") && + op.getNumRegions() == 0; +} + +void OpEmitter::genSeparateArgParamBuilder() { + SmallVector attrBuilderType; + attrBuilderType.push_back(AttrParamKind::WrappedAttr); + if (canGenerateUnwrappedBuilder(op)) + attrBuilderType.push_back(AttrParamKind::UnwrappedValue); + + // Emit with separate builders with or without unwrapped attributes and/or + // inferring result type. + auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, + bool inferType) { + llvm::SmallVector paramList; + llvm::SmallVector resultNames; + buildParamList(paramList, resultNames, paramKind, attrType); + + auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, + std::move(paramList)); + // If the builder is redundant, skip generating the method. + if (!m) + return; + auto &body = m->body(); + genCodeForAddingArgAndRegionForBuilder( + body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); + + // Push all result types to the operation state + + if (inferType) { + // Generate builder that infers type too. + // TODO: Subsume this with general checking if type can be + // inferred automatically. + // TODO: Expand to handle regions. 
+ body << formatv(R"( + ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; + if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(), + {1}.location, {1}.operands, + {1}.attributes.getDictionary({1}.getContext()), + /*regions=*/{{}, inferredReturnTypes))) + {1}.addTypes(inferredReturnTypes); + else + ::llvm::report_fatal_error("Failed to infer result type(s).");)", + opClass.getClassName(), builderOpState); + return; + } + + switch (paramKind) { + case TypeParamKind::None: + return; + case TypeParamKind::Separate: + for (int i = 0, e = op.getNumResults(); i < e; ++i) { + if (op.getResult(i).isOptional()) + body << " if (" << resultNames[i] << ")\n "; + body << " " << builderOpState << ".addTypes(" << resultNames[i] + << ");\n"; + } + return; + case TypeParamKind::Collective: { + int numResults = op.getNumResults(); + int numVariadicResults = op.getNumVariableLengthResults(); + int numNonVariadicResults = numResults - numVariadicResults; + bool hasVariadicResult = numVariadicResults != 0; + + // Avoid emitting "resultTypes.size() >= 0u" which is always true. + if (!(hasVariadicResult && numNonVariadicResults == 0)) + body << " " + << "assert(resultTypes.size() " + << (hasVariadicResult ? ">=" : "==") << " " + << numNonVariadicResults + << "u && \"mismatched number of results\");\n"; + body << " " << builderOpState << ".addTypes(resultTypes);\n"; + } + return; + } + llvm_unreachable("unhandled TypeParamKind"); + }; + + // Some of the build methods generated here may be ambiguous, but TableGen's + // ambiguous function detection will elide those ones. 
+ for (auto attrType : attrBuilderType) { + emit(attrType, TypeParamKind::Separate, /*inferType=*/false); + if (canInferType(op)) + emit(attrType, TypeParamKind::None, /*inferType=*/true); + emit(attrType, TypeParamKind::Collective, /*inferType=*/false); + } +} + +void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { + int numResults = op.getNumResults(); + + // Signature + llvm::SmallVector paramList; + paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); + paramList.emplace_back("::mlir::OperationState &", builderOpState); + paramList.emplace_back("::mlir::ValueRange", "operands"); + // Provide default value for `attributes` when its the last parameter + StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; + paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", + "attributes", attributesDefaultValue); + if (op.getNumVariadicRegions()) + paramList.emplace_back("unsigned", "numRegions"); + + auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, + std::move(paramList)); + // If the builder is redundant, skip generating the method + if (!m) + return; + auto &body = m->body(); + + // Operands + body << " " << builderOpState << ".addOperands(operands);\n"; + + // Attributes + body << " " << builderOpState << ".addAttributes(attributes);\n"; + + // Create the correct number of regions + if (int numRegions = op.getNumRegions()) { + body << llvm::formatv( + " for (unsigned i = 0; i != {0}; ++i)\n", + (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); + body << " (void)" << builderOpState << ".addRegion();\n"; + } + + // Result types + SmallVector resultTypes(numResults, "operands[0].getType()"); + body << " " << builderOpState << ".addTypes({" + << llvm::join(resultTypes, ", ") << "});\n\n"; +} + +void OpEmitter::genInferredTypeCollectiveParamBuilder() { + // TODO: Expand to support regions. 
+ SmallVector paramList; + paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); + paramList.emplace_back("::mlir::OperationState &", builderOpState); + paramList.emplace_back("::mlir::ValueRange", "operands"); + paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", + "attributes", "{}"); + auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, + std::move(paramList)); + // If the builder is redundant, skip generating the method + if (!m) + return; + auto &body = m->body(); + + int numResults = op.getNumResults(); + int numVariadicResults = op.getNumVariableLengthResults(); + int numNonVariadicResults = numResults - numVariadicResults; + + int numOperands = op.getNumOperands(); + int numVariadicOperands = op.getNumVariableLengthOperands(); + int numNonVariadicOperands = numOperands - numVariadicOperands; + + // Operands + if (numVariadicOperands == 0 || numNonVariadicOperands != 0) + body << " assert(operands.size()" + << (numVariadicOperands != 0 ? " >= " : " == ") + << numNonVariadicOperands + << "u && \"mismatched number of parameters\");\n"; + body << " " << builderOpState << ".addOperands(operands);\n"; + body << " " << builderOpState << ".addAttributes(attributes);\n"; + + // Create the correct number of regions + if (int numRegions = op.getNumRegions()) { + body << llvm::formatv( + " for (unsigned i = 0; i != {0}; ++i)\n", + (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); + body << " (void)" << builderOpState << ".addRegion();\n"; + } + + // Result types + body << formatv(R"( + ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes; + if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(), + {1}.location, operands, + {1}.attributes.getDictionary({1}.getContext()), + /*regions=*/{{}, inferredReturnTypes))) {{)", + opClass.getClassName(), builderOpState); + if (numVariadicResults == 0 || numNonVariadicResults != 0) + body << " assert(inferredReturnTypes.size()" + << (numVariadicResults != 0 ? 
" >= " : " == ") << numNonVariadicResults + << "u && \"mismatched number of return types\");\n"; + body << " " << builderOpState << ".addTypes(inferredReturnTypes);"; + + body << formatv(R"( + } else + ::llvm::report_fatal_error("Failed to infer result type(s).");)", + opClass.getClassName(), builderOpState); +} + +void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { + llvm::SmallVector paramList; + llvm::SmallVector resultNames; + buildParamList(paramList, resultNames, TypeParamKind::None); + + auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, + std::move(paramList)); + // If the builder is redundant, skip generating the method + if (!m) + return; + auto &body = m->body(); + genCodeForAddingArgAndRegionForBuilder(body); + + auto numResults = op.getNumResults(); + if (numResults == 0) + return; + + // Push all result types to the operation state + const char *index = op.getOperand(0).isVariadic() ? ".front()" : ""; + std::string resultType = + formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str(); + body << " " << builderOpState << ".addTypes({" << resultType; + for (int i = 1; i != numResults; ++i) + body << ", " << resultType; + body << "});\n\n"; +} + +void OpEmitter::genUseAttrAsResultTypeBuilder() { + SmallVector paramList; + paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); + paramList.emplace_back("::mlir::OperationState &", builderOpState); + paramList.emplace_back("::mlir::ValueRange", "operands"); + paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", + "attributes", "{}"); + auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, + std::move(paramList)); + // If the builder is redundant, skip generating the method + if (!m) + return; + + auto &body = m->body(); + + // Push all result types to the operation state + std::string resultType; + const auto &namedAttr = op.getAttribute(0); + + body << " for (auto attr : attributes) {\n"; + body << " if (attr.first != \"" 
<< namedAttr.name << "\") continue;\n"; + if (namedAttr.attr.isTypeAttr()) { + resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()"; + } else { + resultType = "attr.second.getType()"; + } + + // Operands + body << " " << builderOpState << ".addOperands(operands);\n"; + + // Attributes + body << " " << builderOpState << ".addAttributes(attributes);\n"; + + // Result types + SmallVector resultTypes(op.getNumResults(), resultType); + body << " " << builderOpState << ".addTypes({" + << llvm::join(resultTypes, ", ") << "});\n"; + body << " }\n"; +} + +/// Returns a signature of the builder. Updates the context `fctx` to enable +/// replacement of $_builder and $_state in the body. +static std::string getBuilderSignature(const Builder &builder) { + ArrayRef params(builder.getParameters()); + + // Inject builder and state arguments. + llvm::SmallVector arguments; + arguments.reserve(params.size() + 2); + arguments.push_back( + llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str()); + arguments.push_back( + llvm::formatv("::mlir::OperationState &{0}", builderOpState).str()); + + for (unsigned i = 0, e = params.size(); i < e; ++i) { + // If no name is provided, generate one. + Optional paramName = params[i].getName(); + std::string name = + paramName ? paramName->str() : "odsArg" + std::to_string(i); + + std::string defaultValue; + if (Optional defaultParamValue = params[i].getDefaultValue()) + defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str(); + arguments.push_back( + llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue) + .str()); + } + + return llvm::join(arguments, ", "); +} + +void OpEmitter::genBuilder() { + // Handle custom builders if provided. + for (const Builder &builder : op.getBuilders()) { + std::string paramStr = getBuilderSignature(builder); + + Optional body = builder.getBody(); + OpMethod::Property properties = + body ? 
OpMethod::MP_Static : OpMethod::MP_StaticDeclaration; + auto *method = + opClass.addMethodAndPrune("void", "build", properties, paramStr); + + FmtContext fctx; + fctx.withBuilder(odsBuilder); + fctx.addSubst("_state", builderOpState); + if (body) + method->body() << tgfmt(*body, &fctx); + } + + // Generate default builders that requires all result type, operands, and + // attributes as parameters. + if (op.skipDefaultBuilders()) + return; + + // We generate three classes of builders here: + // 1. one having a stand-alone parameter for each operand / attribute, and + genSeparateArgParamBuilder(); + // 2. one having an aggregated parameter for all result types / operands / + // attributes, and + genCollectiveParamBuilder(); + // 3. one having a stand-alone parameter for each operand and attribute, + // use the first operand or attribute's type as all result types + // to facilitate different call patterns. + if (op.getNumVariableLengthResults() == 0) { + if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { + genUseOperandAsResultTypeSeparateParamBuilder(); + genUseOperandAsResultTypeCollectiveParamBuilder(); + } + if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType")) + genUseAttrAsResultTypeBuilder(); + } +} + +void OpEmitter::genCollectiveParamBuilder() { + int numResults = op.getNumResults(); + int numVariadicResults = op.getNumVariableLengthResults(); + int numNonVariadicResults = numResults - numVariadicResults; + + int numOperands = op.getNumOperands(); + int numVariadicOperands = op.getNumVariableLengthOperands(); + int numNonVariadicOperands = numOperands - numVariadicOperands; + + SmallVector paramList; + paramList.emplace_back("::mlir::OpBuilder &", ""); + paramList.emplace_back("::mlir::OperationState &", builderOpState); + paramList.emplace_back("::mlir::TypeRange", "resultTypes"); + paramList.emplace_back("::mlir::ValueRange", "operands"); + // Provide default value for `attributes` when its the last parameter + StringRef 
attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; + paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", + "attributes", attributesDefaultValue); + if (op.getNumVariadicRegions()) + paramList.emplace_back("unsigned", "numRegions"); + + auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, + std::move(paramList)); + // If the builder is redundant, skip generating the method + if (!m) + return; + auto &body = m->body(); + + // Operands + if (numVariadicOperands == 0 || numNonVariadicOperands != 0) + body << " assert(operands.size()" + << (numVariadicOperands != 0 ? " >= " : " == ") + << numNonVariadicOperands + << "u && \"mismatched number of parameters\");\n"; + body << " " << builderOpState << ".addOperands(operands);\n"; + + // Attributes + body << " " << builderOpState << ".addAttributes(attributes);\n"; + + // Create the correct number of regions + if (int numRegions = op.getNumRegions()) { + body << llvm::formatv( + " for (unsigned i = 0; i != {0}; ++i)\n", + (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); + body << " (void)" << builderOpState << ".addRegion();\n"; + } + + // Result types + if (numVariadicResults == 0 || numNonVariadicResults != 0) + body << " assert(resultTypes.size()" + << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults + << "u && \"mismatched number of return types\");\n"; + body << " " << builderOpState << ".addTypes(resultTypes);\n"; + + // Generate builder that infers type too. + // TODO: Expand to handle regions and successors. 
+ if (canInferType(op) && op.getNumSuccessors() == 0) + genInferredTypeCollectiveParamBuilder(); +} + +void OpEmitter::buildParamList(SmallVectorImpl ¶mList, + SmallVectorImpl &resultTypeNames, + TypeParamKind typeParamKind, + AttrParamKind attrParamKind) { + resultTypeNames.clear(); + auto numResults = op.getNumResults(); + resultTypeNames.reserve(numResults); + + paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); + paramList.emplace_back("::mlir::OperationState &", builderOpState); + + switch (typeParamKind) { + case TypeParamKind::None: + break; + case TypeParamKind::Separate: { + // Add parameters for all return types + for (int i = 0; i < numResults; ++i) { + const auto &result = op.getResult(i); + std::string resultName = std::string(result.name); + if (resultName.empty()) + resultName = std::string(formatv("resultType{0}", i)); + + StringRef type = + result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type"; + OpMethodParameter::Property properties = OpMethodParameter::PP_None; + if (result.isOptional()) + properties = OpMethodParameter::PP_Optional; + + paramList.emplace_back(type, resultName, properties); + resultTypeNames.emplace_back(std::move(resultName)); + } + } break; + case TypeParamKind::Collective: { + paramList.emplace_back("::mlir::TypeRange", "resultTypes"); + resultTypeNames.push_back("resultTypes"); + } break; + } + + // Add parameters for all arguments (operands and attributes). + + int numOperands = 0; + int numAttrs = 0; + + int defaultValuedAttrStartIndex = op.getNumArgs(); + if (attrParamKind == AttrParamKind::UnwrappedValue) { + // Calculate the start index from which we can attach default values in the + // builder declaration. 
+ for (int i = op.getNumArgs() - 1; i >= 0; --i) { + auto *namedAttr = op.getArg(i).dyn_cast(); + if (!namedAttr || !namedAttr->attr.hasDefaultValue()) + break; + + if (!canUseUnwrappedRawValue(namedAttr->attr)) + break; + + // Creating an APInt requires us to provide bitwidth, value, and + // signedness, which is complicated compared to others. Similarly + // for APFloat. + // TODO: Adjust the 'returnType' field of such attributes + // to support them. + StringRef retType = namedAttr->attr.getReturnType(); + if (retType == "::llvm::APInt" || retType == "::llvm::APFloat") + break; + + defaultValuedAttrStartIndex = i; + } + } + + for (int i = 0, e = op.getNumArgs(); i < e; ++i) { + auto argument = op.getArg(i); + if (argument.is()) { + const auto &operand = op.getOperand(numOperands); + StringRef type = + operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value"; + OpMethodParameter::Property properties = OpMethodParameter::PP_None; + if (operand.isOptional()) + properties = OpMethodParameter::PP_Optional; + + paramList.emplace_back(type, getArgumentName(op, numOperands), + properties); + ++numOperands; + } else { + const auto &namedAttr = op.getAttribute(numAttrs); + const auto &attr = namedAttr.attr; + + OpMethodParameter::Property properties = OpMethodParameter::PP_None; + if (attr.isOptional()) + properties = OpMethodParameter::PP_Optional; + + StringRef type; + switch (attrParamKind) { + case AttrParamKind::WrappedAttr: + type = attr.getStorageType(); + break; + case AttrParamKind::UnwrappedValue: + if (canUseUnwrappedRawValue(attr)) + type = attr.getReturnType(); + else + type = attr.getStorageType(); + break; + } + + std::string defaultValue; + // Attach default value if requested and possible. 
+ if (attrParamKind == AttrParamKind::UnwrappedValue && + i >= defaultValuedAttrStartIndex) { + bool isString = attr.getReturnType() == "::llvm::StringRef"; + if (isString) + defaultValue.append("\""); + defaultValue += attr.getDefaultValue(); + if (isString) + defaultValue.append("\""); + } + paramList.emplace_back(type, namedAttr.name, defaultValue, properties); + ++numAttrs; + } + } + + /// Insert parameters for each successor. + for (const NamedSuccessor &succ : op.getSuccessors()) { + StringRef type = + succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *"; + paramList.emplace_back(type, succ.name); + } + + /// Insert parameters for variadic regions. + for (const NamedRegion ®ion : op.getRegions()) + if (region.isVariadic()) + paramList.emplace_back("unsigned", + llvm::formatv("{0}Count", region.name).str()); +} + +void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, + bool isRawValueAttr) { + // Push all operands to the result. + for (int i = 0, e = op.getNumOperands(); i < e; ++i) { + std::string argName = getArgumentName(op, i); + if (op.getOperand(i).isOptional()) + body << " if (" << argName << ")\n "; + body << " " << builderOpState << ".addOperands(" << argName << ");\n"; + } + + // If the operation has the operand segment size attribute, add it here. + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + body << " " << builderOpState + << ".addAttribute(\"operand_segment_sizes\", " + "odsBuilder.getI32VectorAttr({"; + interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { + if (op.getOperand(i).isOptional()) + body << "(" << getArgumentName(op, i) << " ? 1 : 0)"; + else if (op.getOperand(i).isVariadic()) + body << "static_cast(" << getArgumentName(op, i) << ".size())"; + else + body << "1"; + }); + body << "}));\n"; + } + + // Push all attributes to the result. 
+ for (const auto &namedAttr : op.getAttributes()) { + auto &attr = namedAttr.attr; + if (!attr.isDerivedAttr()) { + bool emitNotNullCheck = attr.isOptional(); + if (emitNotNullCheck) { + body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; + } + if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { + // If this is a raw value, then we need to wrap it in an Attribute + // instance. + FmtContext fctx; + fctx.withBuilder("odsBuilder"); + + std::string builderTemplate = + std::string(attr.getConstBuilderTemplate()); + + // For StringAttr, its constant builder call will wrap the input in + // quotes, which is correct for normal string literals, but incorrect + // here given we use function arguments. So we need to strip the + // wrapping quotes. + if (StringRef(builderTemplate).contains("\"$0\"")) + builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); + + std::string value = + std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); + body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState, + namedAttr.name, value); + } else { + body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState, + namedAttr.name); + } + if (emitNotNullCheck) { + body << " }\n"; + } + } + } + + // Create the correct number of regions. + for (const NamedRegion ®ion : op.getRegions()) { + if (region.isVariadic()) + body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ", + region.name); + + body << " (void)" << builderOpState << ".addRegion();\n"; + } + + // Push all successors to the result. 
+ for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) { + body << formatv(" {0}.addSuccessors({1});\n", builderOpState, + namedSuccessor.name); + } +} + +void OpEmitter::genCanonicalizerDecls() { + bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod"); + if (hasCanonicalizeMethod) { + // static LogicResult FooOp:: + // canonicalize(FooOp op, PatternRewriter &rewriter); + SmallVector paramList; + paramList.emplace_back(op.getCppClassName(), "op"); + paramList.emplace_back("::mlir::PatternRewriter &", "rewriter"); + opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize", + OpMethod::MP_StaticDeclaration, + std::move(paramList)); + } + + // We get a prototype for 'getCanonicalizationPatterns' if requested directly + // or if using a 'canonicalize' method. + bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer"); + if (!hasCanonicalizeMethod && !hasCanonicalizer) + return; + + // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize' + // method, but not implementing 'getCanonicalizationPatterns' manually. + bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer; + + // Add a signature for getCanonicalizationPatterns if implemented by the + // dialect or if synthesized to call 'canonicalize'. + SmallVector paramList; + paramList.emplace_back("::mlir::RewritePatternSet &", "results"); + paramList.emplace_back("::mlir::MLIRContext *", "context"); + auto kind = hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration; + auto *method = opClass.addMethodAndPrune( + "void", "getCanonicalizationPatterns", kind, std::move(paramList)); + + // If synthesizing the method, fill it it. 
+  if (hasBody)
+    method->body() << " results.add(canonicalize);\n";
+}
+
+void OpEmitter::genFolderDecls() {
+  bool hasSingleResult =
+      op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
+
+  if (def.getValueAsBit("hasFolder")) {
+    if (hasSingleResult) {
+      opClass.addMethodAndPrune(
+          "::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
+          "::llvm::ArrayRef<::mlir::Attribute>", "operands");
+    } else {
+      SmallVector<OpMethodParameter, 2> paramList;
+      paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
+      paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
+                             "results");
+      opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
+                                OpMethod::MP_Declaration, std::move(paramList));
+    }
+  }
+}
+
+void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
+  Interface interface = opTrait->getInterface();
+
+  // Get the set of methods that should always be declared.
+  auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
+  llvm::StringSet<> alwaysDeclaredMethods;
+  alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
+                               alwaysDeclaredMethodsVec.end());
+
+  for (const InterfaceMethod &method : interface.getMethods()) {
+    // Don't declare if the method has a body.
+    if (method.getBody())
+      continue;
+    // Don't declare if the method has a default implementation and the op
+    // didn't request that it always be declared.
+    if (method.getDefaultImplementation() &&
+        !alwaysDeclaredMethods.count(method.getName()))
+      continue;
+    genOpInterfaceMethod(method);
+  }
+}
+
+OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
+                                          bool declaration) {
+  SmallVector<OpMethodParameter, 4> paramList;
+  for (const InterfaceMethod::Argument &arg : method.getArguments())
+    paramList.emplace_back(arg.type, arg.name);
+
+  auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
+  if (declaration)
+    properties =
+        static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
+  return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
+                                   properties, std::move(paramList));
+}
+
+void OpEmitter::genOpInterfaceMethods() {
+  for (const auto &trait : op.getTraits()) {
+    if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
+      if (opTrait->shouldDeclareMethods())
+        genOpInterfaceMethods(opTrait);
+  }
+}
+
+void OpEmitter::genSideEffectInterfaceMethods() {
+  enum EffectKind { Operand, Result, Symbol, Static };
+  struct EffectLocation {
+    /// The effect applied.
+    SideEffect effect;
+
+    /// The index if the kind is not static.
+    unsigned index : 30;
+
+    /// The kind of the location.
+    unsigned kind : 2;
+  };
+
+  StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
+  auto resolveDecorators = [&](Operator::var_decorator_range decorators,
+                               unsigned index, unsigned kind) {
+    for (auto decorator : decorators)
+      if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
+        opClass.addTrait(effect->getInterfaceTrait());
+        interfaceEffects[effect->getBaseEffectName()].push_back(
+            EffectLocation{*effect, index, kind});
+      }
+  };
+
+  // Collect effects that were specified via:
+  /// Traits.
+  for (const auto &trait : op.getTraits()) {
+    const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
+    if (!opTrait)
+      continue;
+    auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
+    for (auto decorator : opTrait->getEffects())
+      effects.push_back(EffectLocation{cast<SideEffect>(decorator),
+                                       /*index=*/0, EffectKind::Static});
+  }
+  /// Attributes and Operands.
+  for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
+    Argument arg = op.getArg(i);
+    if (arg.is<NamedTypeConstraint *>()) {
+      resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
+      ++operandIt;
+      continue;
+    }
+    const NamedAttribute *attr = arg.get<NamedAttribute *>();
+    if (attr->attr.getBaseAttr().isSymbolRefAttr())
+      resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
+  }
+  /// Results.
+  for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
+    resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
+
+  // The code used to add an effect instance.
+  // {0}: The effect class.
+  // {1}: Optional value or symbol reference.
+  // {2}: The resource class.
+  const char *addEffectCode =
+      " effects.emplace_back({0}::get(), {1}{2}::get());\n";
+
+  for (auto &it : interfaceEffects) {
+    // Generate the 'getEffects' method.
+    std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
+                                     "SideEffects::EffectInstance<{0}>> &",
+                                     it.first())
+                           .str();
+    auto *getEffects =
+        opClass.addMethodAndPrune("void", "getEffects", type, "effects");
+    auto &body = getEffects->body();
+
+    // Add effect instances for each of the locations marked on the operation.
+    for (auto &location : it.second) {
+      StringRef effect = location.effect.getName();
+      StringRef resource = location.effect.getResource();
+      if (location.kind == EffectKind::Static) {
+        // A static instance has no attached value.
+        body << llvm::formatv(addEffectCode, effect, "", resource).str();
+      } else if (location.kind == EffectKind::Symbol) {
+        // A symbol reference requires adding the proper attribute.
+        const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
+        if (attr->attr.isOptional()) {
+          body << " if (auto symbolRef = " << attr->name << "Attr())\n "
+               << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
+                      .str();
+        } else {
+          body << llvm::formatv(addEffectCode, effect, attr->name + "(), ",
+                                resource)
+                      .str();
+        }
+      } else {
+        // Otherwise this is an operand/result, so we need to attach the Value.
+        body << " for (::mlir::Value value : getODS"
+             << (location.kind == EffectKind::Operand ? "Operands" : "Results")
+             << "(" << location.index << "))\n "
+             << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
+      }
+    }
+  }
+}
+
+void OpEmitter::genTypeInterfaceMethods() {
+  if (!op.allResultTypesKnown())
+    return;
+  // Generate 'inferReturnTypes' method declaration using the interface method
+  // declared in 'InferTypeOpInterface' op interface.
+  const auto *trait = dyn_cast<InterfaceTrait>(
+      op.getTrait("::mlir::InferTypeOpInterface::Trait"));
+  Interface interface = trait->getInterface();
+  OpMethod *method = [&]() -> OpMethod * {
+    for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
+      if (interfaceMethod.getName() == "inferReturnTypes") {
+        return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
+      }
+    }
+    assert(0 && "unable to find inferReturnTypes interface method");
+    return nullptr;
+  }();
+  auto &body = method->body();
+  body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
+
+  FmtContext fctx;
+  fctx.withBuilder("odsBuilder");
+  body << " ::mlir::Builder odsBuilder(context);\n";
+
+  auto emitType =
+      [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
+    if (type.isArg()) {
+      auto argIndex = type.getArg();
+      assert(!op.getArg(argIndex).is<NamedAttribute *>());
+      auto arg = op.getArgToOperandOrAttribute(argIndex);
+      if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
+        return body << "operands[" << arg.operandOrAttributeIndex()
+                    << "].getType()";
+      return body << "attributes[" << arg.operandOrAttributeIndex()
+                  << "].getType()";
+    } else {
+      return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
+    }
+  };
+
+  for (int i = 0, e = op.getNumResults(); i != e; ++i) {
+    body << " inferredReturnTypes[" << i << "] = ";
+    auto types = op.getSameTypeAsResult(i);
+    emitType(types[0]) << ";\n";
+    if (types.size() == 1)
+      continue;
+    // TODO: We could verify equality here, but skipping that for verification.
+  }
+  body << " return ::mlir::success();";
+}
+
+void OpEmitter::genParser() {
+  if (!hasStringAttribute(def, "parser") ||
+      hasStringAttribute(def, "assemblyFormat"))
+    return;
+
+  SmallVector<OpMethodParameter, 4> paramList;
+  paramList.emplace_back("::mlir::OpAsmParser &", "parser");
+  paramList.emplace_back("::mlir::OperationState &", "result");
+  auto *method =
+      opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
+                                OpMethod::MP_Static, std::move(paramList));
+
+  FmtContext fctx;
+  fctx.addSubst("cppClass", opClass.getClassName());
+  auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
+  method->body() << " " << tgfmt(parser, &fctx);
+}
+
+void OpEmitter::genPrinter() {
+  if (hasStringAttribute(def, "assemblyFormat"))
+    return;
+
+  auto valueInit = def.getValueInit("printer");
+  StringInit *stringInit = dyn_cast<StringInit>(valueInit);
+  if (!stringInit)
+    return;
+
+  auto *method =
+      opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
+  FmtContext fctx;
+  fctx.addSubst("cppClass", opClass.getClassName());
+  auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
+  method->body() << " " << tgfmt(printer, &fctx);
+}
+
+void OpEmitter::genVerifier() {
+  auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
+  auto &body = method->body();
+  // body << " if (failed(" << op.getAdaptorName()
+  //      << "(*this).verify((*this)->getLoc()))) "
+  //      << "return ::mlir::failure();\n";
+
+  auto *valueInit = def.getValueInit("verifier");
+  StringInit *stringInit = dyn_cast<StringInit>(valueInit);
+  bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
+  populateSubstitutions(op, "(*this)->getAttr", "this->getODSOperands",
+                        "this->getODSResults", verifyCtx);
+
+  genAttributeVerifier(op, "(*this)->getAttr", "emitOpError(",
+                       /*emitVerificationRequiringOp=*/true, verifyCtx, body);
+  genOperandResultVerifier(body, op.getOperands(), "operand");
+  genOperandResultVerifier(body, op.getResults(), "result");
+
+  for (auto &trait : op.getTraits()) {
+    if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
+      body << tgfmt(" if (!($0))\n "
+                    "return emitOpError(\"failed to verify that $1\");\n",
+                    &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
+                    t->getSummary());
+    }
+  }
+
+  genRegionVerifier(body);
+  genSuccessorVerifier(body);
+
+  if (hasCustomVerify) {
+    FmtContext fctx;
+    fctx.addSubst("cppClass", opClass.getClassName());
+    auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
+    body << " " << tgfmt(printer, &fctx);
+  } else {
+    body << " return ::mlir::success();\n";
+  }
+}
+
+void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
+                                         Operator::value_range values,
+                                         StringRef valueKind) {
+  FmtContext fctx;
+
+  body << " {\n";
+  body << " unsigned index = 0; (void)index;\n";
+
+  for (auto staticValue : llvm::enumerate(values)) {
+    bool hasPredicate = staticValue.value().hasPredicate();
+    bool isOptional = staticValue.value().isOptional();
+    if (!hasPredicate && !isOptional)
+      continue;
+    body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
+                    // Capitalize the first letter to match the function name
+                    valueKind.substr(0, 1).upper(), valueKind.substr(1),
+                    staticValue.index());
+
+    // If the constraint is optional check that the value group has at most 1
+    // value.
+    if (isOptional) {
+      body << formatv(" if (valueGroup{0}.size() > 1)\n"
+                      " return emitOpError(\"{1} group starting at #\") "
+                      "<< index << \" requires 0 or 1 element, but found \" << "
+                      "valueGroup{0}.size();\n",
+                      staticValue.index(), valueKind);
+    }
+
+    // Otherwise, if there is no predicate there is nothing left to do.
+    if (!hasPredicate)
+      continue;
+    // Emit a loop to check all the dynamic values in the pack.
+    StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(
+        staticValue.value().constraint);
+    body << " for (::mlir::Value v : valueGroup" << staticValue.index()
+         << ") {\n"
+         << " if (::mlir::failed(" << constraintFn
+         << "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n"
+         << " return ::mlir::failure();\n"
+         << " ++index;\n"
+         << " }\n";
+  }
+
+  body << " }\n";
+}
+
+void OpEmitter::genRegionVerifier(OpMethodBody &body) {
+  // If we have no regions, there is nothing more to do.
+  unsigned numRegions = op.getNumRegions();
+  if (numRegions == 0)
+    return;
+
+  body << "{\n";
+  body << " unsigned index = 0; (void)index;\n";
+
+  for (unsigned i = 0; i < numRegions; ++i) {
+    const auto &region = op.getRegion(i);
+    if (region.constraint.getPredicate().isNull())
+      continue;
+
+    body << " for (::mlir::Region &region : ";
+    body << formatv(region.isVariadic()
+                        ? "{0}()"
+                        : "::mlir::MutableArrayRef<::mlir::Region>((*this)"
+                          "->getRegion({1}))",
+                    region.name, i);
+    body << ") {\n";
+    auto constraint = tgfmt(region.constraint.getConditionTemplate(),
+                            &verifyCtx.withSelf("region"))
+                          .str();
+
+    body << formatv(" (void)region;\n"
+                    " if (!({0})) {{\n "
+                    "return emitOpError(\"region #\") << index << \" {1}"
+                    "failed to "
+                    "verify constraint: {2}\";\n }\n",
+                    constraint,
+                    region.name.empty() ? "" : "('" + region.name + "') ",
+                    region.constraint.getSummary())
+         << " ++index;\n"
+         << " }\n";
+  }
+  body << " }\n";
+}
+
+void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
+  // If we have no successors, there is nothing more to do.
+  unsigned numSuccessors = op.getNumSuccessors();
+  if (numSuccessors == 0)
+    return;
+
+  body << "{\n";
+  body << " unsigned index = 0; (void)index;\n";
+
+  for (unsigned i = 0; i < numSuccessors; ++i) {
+    const auto &successor = op.getSuccessor(i);
+    if (successor.constraint.getPredicate().isNull())
+      continue;
+
+    if (successor.isVariadic()) {
+      body << formatv(" for (::mlir::Block *successor : {0}()) {{\n",
+                      successor.name);
+    } else {
+      body << " {\n";
+      body << formatv(" ::mlir::Block *successor = {0}();\n",
+                      successor.name);
+    }
+    auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
+                            &verifyCtx.withSelf("successor"))
+                          .str();
+
+    body << formatv(" (void)successor;\n"
+                    " if (!({0})) {{\n "
+                    "return emitOpError(\"successor #\") << index << \"('{1}') "
+                    "failed to "
+                    "verify constraint: {2}\";\n }\n",
+                    constraint, successor.name,
+                    successor.constraint.getSummary())
+         << " ++index;\n"
+         << " }\n";
+  }
+  body << " }\n";
+}
+
+/// Add a size count trait to the given operation class.
+static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
+                              int numTotal, int numVariadic) {
+  if (numVariadic != 0) {
+    if (numTotal == numVariadic)
+      opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
+    else
+      opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
+                       Twine(numTotal - numVariadic) + ">::Impl");
+    return;
+  }
+  switch (numTotal) {
+  case 0:
+    opClass.addTrait("::mlir::OpTrait::Zero" + traitKind);
+    break;
+  case 1:
+    opClass.addTrait("::mlir::OpTrait::One" + traitKind);
+    break;
+  default:
+    opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
+                     ">::Impl");
+    break;
+  }
+}
+
+void OpEmitter::genTraits() {
+  // Add region size trait.
+  unsigned numRegions = op.getNumRegions();
+  unsigned numVariadicRegions = op.getNumVariadicRegions();
+  addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
+
+  // Add result size traits.
+  int numResults = op.getNumResults();
+  int numVariadicResults = op.getNumVariableLengthResults();
+  addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
+
+  // For single result ops with a known specific type, generate a OneTypedResult
+  // trait.
+  if (numResults == 1 && numVariadicResults == 0) {
+    auto cppName = op.getResults().begin()->constraint.getCPPClassName();
+    opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
+  }
+
+  // Add successor size trait.
+  unsigned numSuccessors = op.getNumSuccessors();
+  unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
+  addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
+
+  // Add variadic size trait and normal op traits.
+  int numOperands = op.getNumOperands();
+  int numVariadicOperands = op.getNumVariableLengthOperands();
+
+  // Add operand size trait.
+  if (numVariadicOperands != 0) {
+    if (numOperands == numVariadicOperands)
+      opClass.addTrait("::mlir::OpTrait::VariadicOperands");
+    else
+      opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" +
+                       Twine(numOperands - numVariadicOperands) + ">::Impl");
+  } else {
+    switch (numOperands) {
+    case 0:
+      opClass.addTrait("::mlir::OpTrait::ZeroOperands");
+      break;
+    case 1:
+      opClass.addTrait("::mlir::OpTrait::OneOperand");
+      break;
+    default:
+      opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(numOperands) +
+                       ">::Impl");
+      break;
+    }
+  }
+
+  // Add the native and interface traits.
+  for (const auto &trait : op.getTraits()) {
+    if (auto opTrait = dyn_cast<tblgen::NativeTrait>(&trait))
+      opClass.addTrait(opTrait->getFullyQualifiedTraitName());
+    else if (auto opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
+      opClass.addTrait(opTrait->getFullyQualifiedTraitName());
+  }
+}
+
+void OpEmitter::genOpNameGetter() {
+  auto *method = opClass.addMethodAndPrune(
+      "std::string", "getOperationName",
+      OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
+  method->body() << " return std::string(\"" << op.getOperationName()
+                 << "\");";
+}
+
+void OpEmitter::genOpAsmInterface() {
+  // If the user only has one result or specifically added the Asm trait,
+  // then don't generate it for them. We specifically only handle multi result
+  // operations, because the name of a single result in the common case is not
+  // interesting(generally 'result'/'output'/etc.).
+  // TODO: We could also add a flag to allow operations to opt in to this
+  // generation, even if they only have a single result.
+  int numResults = op.getNumResults();
+  if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
+    return;
+
+  SmallVector<StringRef, 4> resultNames(numResults);
+  for (int i = 0; i != numResults; ++i)
+    resultNames[i] = op.getResultName(i);
+
+  // Don't add the trait if none of the results have a valid name.
+  if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
+    return;
+  opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
+
+  // Generate the right accessor for the number of results.
+  auto *method = opClass.addMethodAndPrune(
+      "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
+  auto &body = method->body();
+  for (int i = 0; i != numResults; ++i) {
+    body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
+         << " if (!llvm::empty(resultGroup" << i << "))\n"
+         << " setNameFn(*resultGroup" << i << ".begin(), \""
+         << resultNames[i] << "\");\n";
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// OpOperandAdaptor emitter
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Helper class to emit Op operand adaptors to an output stream. Operand
+// adaptors are wrappers around ArrayRef<Value> that provide named operand
+// getters identical to those defined in the Op.
+class OpOperandAdaptorEmitter {
+public:
+  static void emitDecl(const Operator &op, raw_ostream &os);
+  static void emitDef(const Operator &op, raw_ostream &os);
+
+private:
+  explicit OpOperandAdaptorEmitter(const Operator &op);
+
+  // Add verification function. This generates a verify method for the adaptor
+  // which verifies all the op-independent attribute constraints.
+  void addVerification();
+
+  const Operator &op;
+  Class adaptor;
+};
+} // end namespace
+
+OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
+    : op(op), adaptor(op.getAdaptorName()) {
+  adaptor.newField("::mlir::ValueRange", "odsOperands");
+  adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
+  adaptor.newField("::mlir::RegionRange", "odsRegions");
+  const auto *attrSizedOperands =
+      op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
+  {
+    SmallVector<OpMethodParameter, 2> paramList;
+    paramList.emplace_back("::mlir::ValueRange", "values");
+    paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
+                           attrSizedOperands ? "" : "nullptr");
+    paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
+    auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList));
+
+    constructor->addMemberInitializer("odsOperands", "values");
+    constructor->addMemberInitializer("odsAttrs", "attrs");
+    constructor->addMemberInitializer("odsRegions", "regions");
+  }
+
+  {
+    auto *constructor = adaptor.addConstructorAndPrune(
+        llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
+    constructor->addMemberInitializer("odsOperands", "op->getOperands()");
+    constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
+    constructor->addMemberInitializer("odsRegions", "op->getRegions()");
+  }
+
+  {
+    auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands");
+    m->body() << " return odsOperands;";
+  }
+  std::string sizeAttrInit =
+      formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
+  generateNamedOperandGetters(op, adaptor, sizeAttrInit,
+                              /*rangeType=*/"::mlir::ValueRange",
+                              /*rangeBeginCall=*/"odsOperands.begin()",
+                              /*rangeSizeCall=*/"odsOperands.size()",
+                              /*getOperandCallPattern=*/"odsOperands[{0}]");
+
+  FmtContext fctx;
+  fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
+
+  auto emitAttr = [&](StringRef name, Attribute attr) {
+    auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body();
+    body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
+         << "\n " << attr.getStorageType() << " attr = "
+         << "odsAttrs.get(\"" << name << "\").";
+    if (attr.hasDefaultValue() || attr.isOptional())
+      body << "dyn_cast_or_null<";
+    else
+      body << "cast<";
+    body << attr.getStorageType() << ">();\n";
+
+    if (attr.hasDefaultValue()) {
+      // Use the default value if attribute is not set.
+      // TODO: this is inefficient, we are recreating the attribute for every
+      // call. This should be set instead.
+      std::string defaultValue = std::string(
+          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+      body << " if (!attr)\n attr = " << defaultValue << ";\n";
+    }
+    body << " return attr;\n";
+  };
+
+  {
+    auto *m =
+        adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes");
+    m->body() << " return odsAttrs;";
+  }
+  for (auto &namedAttr : op.getAttributes()) {
+    const auto &name = namedAttr.name;
+    const auto &attr = namedAttr.attr;
+    if (!attr.isDerivedAttr())
+      emitAttr(name, attr);
+  }
+
+  unsigned numRegions = op.getNumRegions();
+  if (numRegions > 0) {
+    auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions");
+    m->body() << " return odsRegions;";
+  }
+  for (unsigned i = 0; i < numRegions; ++i) {
+    const auto &region = op.getRegion(i);
+    if (region.name.empty())
+      continue;
+
+    // Generate the accessors for a variadic region.
+    if (region.isVariadic()) {
+      auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", region.name);
+      m->body() << formatv(" return odsRegions.drop_front({0});", i);
+      continue;
+    }
+
+    auto *m = adaptor.addMethodAndPrune("::mlir::Region &", region.name);
+    m->body() << formatv(" return *odsRegions[{0}];", i);
+  }
+
+  // Add verification function.
+  addVerification();
+}
+
+void OpOperandAdaptorEmitter::addVerification() {
+  auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
+                                           "::mlir::Location", "loc");
+  auto &body = method->body();
+
+  const char *checkAttrSizedValueSegmentsCode = R"(
+  {
+    auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
+    auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
+    if (numElements != {1})
+      return emitError(loc, "'{0}' attribute for specifying {2} segments "
+                            "must have {1} elements, but got ") << numElements;
+  }
+  )";
+
+  // Verify a few traits first so that we can use
+  // getODSOperands()/getODSResults() in the rest of the verifier.
+  for (auto &trait : op.getTraits()) {
+    if (auto *t = dyn_cast<tblgen::NativeTrait>(&trait)) {
+      if (t->getFullyQualifiedTraitName() ==
+          "::mlir::OpTrait::AttrSizedOperandSegments") {
+        body << formatv(checkAttrSizedValueSegmentsCode,
+                        "operand_segment_sizes", op.getNumOperands(),
+                        "operand");
+      } else if (t->getFullyQualifiedTraitName() ==
+                 "::mlir::OpTrait::AttrSizedResultSegments") {
+        body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
+                        op.getNumResults(), "result");
+      }
+    }
+  }
+
+  FmtContext verifyCtx;
+  populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
+                        "<no results should be generated>", verifyCtx);
+  genAttributeVerifier(op, "odsAttrs.get",
+                       Twine("emitError(loc, \"'") + op.getOperationName() +
+                           "' op \"",
+                       /*emitVerificationRequiringOp*/ false, verifyCtx, body);
+
+  body << " return ::mlir::success();";
+}
+
+void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
+  OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
+}
+
+void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
+  OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
+}
+
+// Emits the opcode enum and op classes.
+static void emitOpClasses(const RecordKeeper &recordKeeper,
+                          const std::vector<Record *> &defs, raw_ostream &os,
+                          bool emitDecl) {
+  // First emit forward declaration for each class, this allows them to refer
+  // to each others in traits for example.
+  if (emitDecl) {
+    os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
+    os << "#undef GET_OP_FWD_DEFINES\n";
+    for (auto *def : defs) {
+      Operator op(*def);
+      NamespaceEmitter emitter(os, op.getCppNamespace());
+      os << "class " << op.getCppClassName() << ";\n";
+    }
+    os << "#endif\n\n";
+  }
+
+  IfDefScope scope("GET_OP_CLASSES", os);
+  if (defs.empty())
+    return;
+
+  // Generate all of the locally instantiated methods first.
+  StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
+                                                      emitDecl);
+  for (auto *def : defs) {
+    Operator op(*def);
+    NamespaceEmitter emitter(os, op.getCppNamespace());
+    if (emitDecl) {
+      os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
+      // OpOperandAdaptorEmitter::emitDecl(op, os);
+      OpEmitter::emitDecl(op, os, staticVerifierEmitter);
+    } else {
+      os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
+      // OpOperandAdaptorEmitter::emitDef(op, os);
+      OpEmitter::emitDef(op, os, staticVerifierEmitter);
+    }
+  }
+}
+
+// Emits a comma-separated list of the ops.
+static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
+  IfDefScope scope("GET_OP_LIST", os);
+
+  interleave(
+      // TODO: We are constructing the Operator wrapper instance just for
+      // getting its qualified class name here. Reduce the overhead by having a
+      // lightweight version of Operator class just for that purpose.
+      defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
+      [&os]() { os << ",\n"; });
+}
+
+static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  emitSourceFileHeader("Op Declarations", os);
+
+  std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
+  emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
+
+  return false;
+}
+
+static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  emitSourceFileHeader("Op Definitions", os);
+
+  std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
+  emitOpList(defs, os);
+  emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
+
+  return false;
+}
+
+static mlir::GenRegistration
+    genOpDecls("gen-builder-decls", "Generate op declarations",
+               [](const RecordKeeper &records, raw_ostream &os) {
+                 return emitOpDecls(records, os);
+               });
+
+static mlir::GenRegistration genOpDefs("gen-builder-defs", "Generate op definitions",
+                                       [](const RecordKeeper &records,
+                                          raw_ostream &os) {
+                                         return emitOpDefs(records, os);
+                                       });
diff --git a/tools/mlir-tblgen-builder/OpFormatGen.cpp b/tools/mlir-tblgen-builder/OpFormatGen.cpp
new file mode 100644
index 0000000..ca7db45
--- /dev/null
+++ b/tools/mlir-tblgen-builder/OpFormatGen.cpp
@@ -0,0 +1,3294 @@
+//===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "OpFormatGen.h"
+#include "mlir/Support/LogicalResult.h"
+#include "TableGen/Format.h"
+#include "TableGen/GenInfo.h"
+#include "TableGen/Interfaces.h"
+#include "TableGen/OpClass.h"
+#include "TableGen/Operator.h"
+#include "TableGen/Trait.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+#define DEBUG_TYPE "mlir-tblgen-opformatgen"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+static llvm::cl::opt<bool> formatErrorIsFatal(
+    "asmformat-error-is-fatal",
+    llvm::cl::desc("Emit a fatal error if format parsing fails"),
+    llvm::cl::init(true));
+
+/// Returns true if the given string can be formatted as a keyword.
+static bool canFormatStringAsKeyword(StringRef value) {
+  if (!isalpha(value.front()) && value.front() != '_')
+    return false;
+  return llvm::all_of(value.drop_front(), [](char c) {
+    return isalnum(c) || c == '_' || c == '$' || c == '.';
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// Element
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class represents a single format element.
+class Element {
+public:
+  enum class Kind {
+    /// This element is a directive.
+    AttrDictDirective,
+    CustomDirective,
+    FunctionalTypeDirective,
+    OperandsDirective,
+    RefDirective,
+    RegionsDirective,
+    ResultsDirective,
+    SuccessorsDirective,
+    TypeDirective,
+
+    /// This element is a literal.
+    Literal,
+
+    /// This element is a whitespace.
+    Newline,
+    Space,
+
+    /// This element is a variable value.
+    AttributeVariable,
+    OperandVariable,
+    RegionVariable,
+    ResultVariable,
+    SuccessorVariable,
+
+    /// This element is an optional element.
+    Optional,
+  };
+  Element(Kind kind) : kind(kind) {}
+  virtual ~Element() = default;
+
+  /// Return the kind of this element.
+  Kind getKind() const { return kind; }
+
+private:
+  /// The kind of this element.
+  Kind kind;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// VariableElement
+
+namespace {
+/// This class represents an instance of a variable element. A variable refers
+/// to something registered on the operation itself, e.g. an argument, result,
+/// etc.
+template <typename VarT, Element::Kind kindVal>
+class VariableElement : public Element {
+public:
+  VariableElement(const VarT *var) : Element(kindVal), var(var) {}
+  static bool classof(const Element *element) {
+    return element->getKind() == kindVal;
+  }
+  const VarT *getVar() { return var; }
+
+protected:
+  const VarT *var;
+};
+
+/// This class represents a variable that refers to an attribute argument.
+struct AttributeVariable
+    : public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> {
+  using VariableElement<NamedAttribute,
+                        Element::Kind::AttributeVariable>::VariableElement;
+
+  /// Return the constant builder call for the type of this attribute, or None
+  /// if it doesn't have one.
+  Optional<StringRef> getTypeBuilder() const {
+    Optional<Type> attrType = var->attr.getValueType();
+    return attrType ? attrType->getBuilderCall() : llvm::None;
+  }
+
+  /// Return if this attribute refers to a UnitAttr.
+  bool isUnitAttr() const {
+    return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
+  }
+};
+
+/// This class represents a variable that refers to an operand argument.
+using OperandVariable =
+    VariableElement<NamedTypeConstraint, Element::Kind::OperandVariable>;
+
+/// This class represents a variable that refers to a region.
+using RegionVariable =
+    VariableElement<NamedRegion, Element::Kind::RegionVariable>;
+
+/// This class represents a variable that refers to a result.
+using ResultVariable =
+    VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
+
+/// This class represents a variable that refers to a successor.
+using SuccessorVariable =
+    VariableElement<NamedSuccessor, Element::Kind::SuccessorVariable>;
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// DirectiveElement
+
+namespace {
+/// This class implements single kind directives.
+template <Element::Kind type>
+class DirectiveElement : public Element {
+public:
+  DirectiveElement() : Element(type){};
+  static bool classof(const Element *ele) { return ele->getKind() == type; }
+};
+/// This class represents the `operands` directive. This directive represents
+/// all of the operands of an operation.
+using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
+
+/// This class represents the `regions` directive. This directive represents
+/// all of the regions of an operation.
+using RegionsDirective = DirectiveElement<Element::Kind::RegionsDirective>;
+
+/// This class represents the `results` directive. This directive represents
+/// all of the results of an operation.
+using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
+
+/// This class represents the `successors` directive. This directive represents
+/// all of the successors of an operation.
+using SuccessorsDirective =
+    DirectiveElement<Element::Kind::SuccessorsDirective>;
+
+/// This class represents the `attr-dict` directive. This directive represents
+/// the attribute dictionary of the operation.
+class AttrDictDirective
+    : public DirectiveElement<Element::Kind::AttrDictDirective> {
+public:
+  explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
+  bool isWithKeyword() const { return withKeyword; }
+
+private:
+  /// If the dictionary should be printed with the 'attributes' keyword.
+  bool withKeyword;
+};
+
+/// This class represents a custom format directive that is implemented by the
+/// user in C++.
+class CustomDirective : public Element {
+public:
+  CustomDirective(StringRef name,
+                  std::vector<std::unique_ptr<Element>> &&arguments)
+      : Element{Kind::CustomDirective}, name(name),
+        arguments(std::move(arguments)) {}
+
+  static bool classof(const Element *element) {
+    return element->getKind() == Kind::CustomDirective;
+  }
+
+  /// Return the name of this custom directive.
+  StringRef getName() const { return name; }
+
+  /// Return the arguments to the custom directive.
+  auto getArguments() const { return llvm::make_pointee_range(arguments); }
+
+private:
+  /// The user provided name of the directive.
+  StringRef name;
+
+  /// The arguments to the custom directive.
+  std::vector<std::unique_ptr<Element>> arguments;
+};
+
+/// This class represents the `functional-type` directive. This directive takes
+/// two arguments and formats them, respectively, as the inputs and results of a
+/// FunctionType.
+class FunctionalTypeDirective
+    : public DirectiveElement<Element::Kind::FunctionalTypeDirective> {
+public:
+  FunctionalTypeDirective(std::unique_ptr<Element> inputs,
+                          std::unique_ptr<Element> results)
+      : inputs(std::move(inputs)), results(std::move(results)) {}
+  Element *getInputs() const { return inputs.get(); }
+  Element *getResults() const { return results.get(); }
+
+private:
+  /// The input and result arguments.
+  std::unique_ptr<Element> inputs, results;
+};
+
+/// This class represents the `ref` directive.
+class RefDirective : public DirectiveElement<Element::Kind::RefDirective> {
+public:
+  RefDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
+  Element *getOperand() const { return operand.get(); }
+
+private:
+  /// The operand that is used to format the directive.
+  std::unique_ptr<Element> operand;
+};
+
+/// This class represents the `type` directive.
+class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
+public:
+  TypeDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
+  Element *getOperand() const { return operand.get(); }
+
+private:
+  /// The operand that is used to format the directive.
+  std::unique_ptr<Element> operand;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// LiteralElement
+
+namespace {
+/// This class represents an instance of a literal element.
+class LiteralElement : public Element {
+public:
+  LiteralElement(StringRef literal)
+      : Element{Kind::Literal}, literal(literal) {}
+  static bool classof(const Element *element) {
+    return element->getKind() == Kind::Literal;
+  }
+
+  /// Return the literal for this element.
+  StringRef getLiteral() const { return literal; }
+
+  /// Returns true if the given string is a valid literal.
+  static bool isValidLiteral(StringRef value);
+
+private:
+  /// The spelling of the literal for this element.
+  StringRef literal;
+};
+} // end anonymous namespace
+
+/// A literal is valid when it is a single letter or one of the punctuation
+/// characters `_:,=<>()[]{}?+*`, the arrow `->`, or any string accepted by
+/// `canFormatStringAsKeyword`. The empty string is never valid.
+bool LiteralElement::isValidLiteral(StringRef value) {
+  if (value.empty())
+    return false;
+  char front = value.front();
+
+  // If there is only one character, this must either be punctuation or a
+  // single character bare identifier. The cast keeps `isalpha` well defined
+  // for characters with a negative `char` value.
+  if (value.size() == 1)
+    return isalpha(static_cast<unsigned char>(front)) ||
+           StringRef("_:,=<>()[]{}?+*").contains(front);
+
+  // Check the punctuation that are larger than a single character.
+  if (value == "->")
+    return true;
+
+  // Otherwise, this must be an identifier.
+  return canFormatStringAsKeyword(value);
+}
+
+//===----------------------------------------------------------------------===//
+// WhitespaceElement
+
+namespace {
+/// This class represents a whitespace element, e.g. newline or space. It's a
+/// literal that is printed but never parsed.
+class WhitespaceElement : public Element {
+public:
+  WhitespaceElement(Kind kind) : Element{kind} {}
+  static bool classof(const Element *element) {
+    Kind kind = element->getKind();
+    return kind == Kind::Newline || kind == Kind::Space;
+  }
+};
+
+/// This class represents an instance of a newline element. It's a literal that
+/// prints a newline. It is ignored by the parser.
+class NewlineElement : public WhitespaceElement {
+public:
+  NewlineElement() : WhitespaceElement(Kind::Newline) {}
+  static bool classof(const Element *element) {
+    return element->getKind() == Kind::Newline;
+  }
+};
+
+/// This class represents an instance of a space element. It's a literal that
+/// prints or omits printing a space. It is ignored by the parser.
+class SpaceElement : public WhitespaceElement {
+public:
+  SpaceElement(bool value) : WhitespaceElement(Kind::Space), value(value) {}
+  static bool classof(const Element *element) {
+    return element->getKind() == Kind::Space;
+  }
+
+  /// Returns true if this element should print as a space. Otherwise, the
+  /// element should omit printing a space between the surrounding elements.
+  bool getValue() const { return value; }
+
+private:
+  bool value;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// OptionalElement
+
+namespace {
+/// This class represents a group of elements that are optionally emitted based
+/// upon an optional variable of the operation, and a group of elements that are
+/// emitted when the anchor element is not present.
+class OptionalElement : public Element {
+public:
+  OptionalElement(std::vector<std::unique_ptr<Element>> &&thenElements,
+                  std::vector<std::unique_ptr<Element>> &&elseElements,
+                  unsigned anchor, unsigned parseStart)
+      : Element{Kind::Optional}, thenElements(std::move(thenElements)),
+        elseElements(std::move(elseElements)), anchor(anchor),
+        parseStart(parseStart) {}
+  static bool classof(const Element *element) {
+    return element->getKind() == Kind::Optional;
+  }
+
+  /// Return the `then` elements of this grouping.
+  auto getThenElements() const {
+    return llvm::make_pointee_range(thenElements);
+  }
+
+  /// Return the `else` elements of this grouping.
+  auto getElseElements() const {
+    return llvm::make_pointee_range(elseElements);
+  }
+
+  /// Return the anchor of this optional group. The anchor is always one of the
+  /// `then` elements.
+  Element *getAnchor() const { return thenElements[anchor].get(); }
+
+  /// Return the index of the first element that needs to be parsed.
+  unsigned getParseStart() const { return parseStart; }
+
+private:
+  /// The child elements of `then` branch of this optional.
+  std::vector<std::unique_ptr<Element>> thenElements;
+  /// The child elements of `else` branch of this optional.
+  std::vector<std::unique_ptr<Element>> elseElements;
+  /// The index of the element that acts as the anchor for the optional group.
+  unsigned anchor;
+  /// The index of the first element that is parsed (is not a
+  /// WhitespaceElement).
+  unsigned parseStart;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// OperationFormat
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Either an attribute or an operand/result constraint that a type can be
+/// resolved from.
+using ConstArgument =
+    llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
+
+struct OperationFormat {
+  /// This class represents a specific resolver for an operand or result type.
+  class TypeResolution {
+  public:
+    TypeResolution() = default;
+
+    /// Get the index into the buildable types for this type, or None.
+    Optional<int> getBuilderIdx() const { return builderIdx; }
+    void setBuilderIdx(int idx) { builderIdx = idx; }
+
+    /// Get the variable this type is resolved to, or nullptr.
+    const NamedTypeConstraint *getVariable() const {
+      return resolver.dyn_cast<const NamedTypeConstraint *>();
+    }
+    /// Get the attribute this type is resolved to, or nullptr.
+    const NamedAttribute *getAttribute() const {
+      return resolver.dyn_cast<const NamedAttribute *>();
+    }
+    /// Get the transformer for the type of the variable, or None.
+    Optional<StringRef> getVarTransformer() const {
+      return variableTransformer;
+    }
+    /// Set the argument this type is resolved from, with an optional
+    /// transformer. `arg` must hold a non-null variable or attribute.
+    void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
+      resolver = arg;
+      variableTransformer = transformer;
+      assert(getVariable() || getAttribute());
+    }
+
+  private:
+    /// If the type is resolved with a buildable type, this is the index into
+    /// 'buildableTypes' in the parent format.
+    Optional<int> builderIdx;
+    /// If the type is resolved based upon another operand or result, this is
+    /// the variable or the attribute that this type is resolved to.
+    ConstArgument resolver;
+    /// If the type is resolved based upon another operand or result, this is
+    /// a transformer to apply to the variable when resolving.
+    Optional<StringRef> variableTransformer;
+  };
+
+  OperationFormat(const Operator &op)
+      : allOperands(false), allOperandTypes(false), allResultTypes(false) {
+    operandTypes.resize(op.getNumOperands(), TypeResolution());
+    resultTypes.resize(op.getNumResults(), TypeResolution());
+
+    hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
+      return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
+    });
+
+    // An implicit terminator implies a single block.
+    hasSingleBlockTrait =
+        hasImplicitTermTrait || op.getTrait("::mlir::OpTrait::SingleBlock");
+  }
+
+  /// Generate the operation parser from this format.
+  void genParser(Operator &op, OpClass &opClass);
+  /// Generate the parser code for a specific format element.
+  void genElementParser(Element *element, OpMethodBody &body,
+                        FmtContext &attrTypeCtx);
+  /// Generate the c++ to resolve the types of operands and results during
+  /// parsing.
+  void genParserTypeResolution(Operator &op, OpMethodBody &body);
+  /// Generate the c++ to resolve regions during parsing.
+  void genParserRegionResolution(Operator &op, OpMethodBody &body);
+  /// Generate the c++ to resolve successors during parsing.
+  void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
+  /// Generate the c++ to handling variadic segment size traits.
+  void genParserVariadicSegmentResolution(Operator &op, OpMethodBody &body);
+
+  /// Generate the operation printer from this format.
+  void genPrinter(Operator &op, OpClass &opClass);
+
+  /// Generate the printer code for a specific format element.
+  void genElementPrinter(Element *element, OpMethodBody &body, Operator &op,
+                         bool &shouldEmitSpace, bool &lastWasPunctuation);
+
+  /// The various elements in this format.
+  std::vector<std::unique_ptr<Element>> elements;
+
+  /// A flag indicating if all operand/result types were seen. If the format
+  /// contains these, it can not contain individual type resolvers.
+  bool allOperands, allOperandTypes, allResultTypes;
+
+  /// A flag indicating if this operation has the SingleBlockImplicitTerminator
+  /// trait.
+  bool hasImplicitTermTrait;
+
+  /// A flag indicating if this operation has the SingleBlock trait.
+  bool hasSingleBlockTrait;
+
+  /// A map of buildable types to indices.
+  llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
+
+  /// The index of the buildable type, if valid, for every operand and result.
+  std::vector<TypeResolution> operandTypes, resultTypes;
+
+  /// The set of attributes explicitly used within the format.
+  SmallVector<const NamedAttribute *, 8> usedAttributes;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Parser Gen
+
+/// Returns true if we can format the given attribute as an EnumAttr in the
+/// parser format.
+static bool canFormatEnumAttr(const NamedAttribute *attr) {
+  Attribute baseAttr = attr->attr.getBaseAttr();
+  const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
+  if (!enumAttr)
+    return false;
+
+  // The attribute must have a valid underlying type and a constant builder.
+  return !enumAttr->getUnderlyingType().empty() &&
+         !enumAttr->getConstBuilderTemplate().empty();
+}
+
+/// Returns true if we should format the given attribute as a SymbolNameAttr.
+static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
+  return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
+}
+
+/// The code snippet used to generate a parser call for an attribute.
+///
+/// {0}: The name of the attribute.
+/// {1}: The type for the attribute.
+const char *const attrParserCode = R"(
+  if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
+    return ::mlir::failure();
+)";
+const char *const optionalAttrParserCode = R"(
+  {
+    ::mlir::OptionalParseResult parseResult =
+      parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
+    if (parseResult.hasValue() && failed(*parseResult))
+      return ::mlir::failure();
+  }
+)";
+
+/// The code snippet used to generate a parser call for a symbol name attribute.
+///
+/// {0}: The name of the attribute.
+const char *const symbolNameAttrParserCode = R"(
+  if (parser.parseSymbolName({0}Attr, "{0}", result.attributes))
+    return ::mlir::failure();
+)";
+const char *const optionalSymbolNameAttrParserCode = R"(
+  // Parsing an optional symbol name doesn't fail, so no need to check the
+  // result.
+  (void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes);
+)";
+
+/// The code snippet used to generate a parser call for an enum attribute.
+///
+/// {0}: The name of the attribute.
+/// {1}: The c++ namespace for the enum symbolize functions.
+/// {2}: The function to symbolize a string of the enum.
+/// {3}: The constant builder call to create an attribute of the enum type.
+/// {4}: The set of allowed enum keywords.
+/// {5}: The error message on failure when the enum isn't present.
+const char *const enumAttrParserCode = R"(
+  {
+    ::llvm::StringRef attrStr;
+    ::mlir::NamedAttrList attrStorage;
+    auto loc = parser.getCurrentLocation();
+    if (parser.parseOptionalKeyword(&attrStr, {4})) {
+      ::mlir::StringAttr attrVal;
+      ::mlir::OptionalParseResult parseResult =
+        parser.parseOptionalAttribute(attrVal,
+                                      parser.getBuilder().getNoneType(),
+                                      "{0}", attrStorage);
+      if (parseResult.hasValue()) {{
+        if (failed(*parseResult))
+          return ::mlir::failure();
+        attrStr = attrVal.getValue();
+      } else {
+        {5}
+      }
+    }
+    if (!attrStr.empty()) {
+      auto attrOptional = {1}::{2}(attrStr);
+      if (!attrOptional)
+        return parser.emitError(loc, "invalid ")
+               << "{0} attribute specification: \"" << attrStr << '"';
+
+      {0}Attr = {3};
+      result.addAttribute("{0}", {0}Attr);
+    }
+  }
+)";
+
+/// The code snippets used to generate a parser call for an operand: one for
+/// a variadic operand list, one for an optional operand, and one for a single
+/// required operand.
+///
+/// {0}: The name of the operand.
+const char *const variadicOperandParserCode = R"(
+  {0}OperandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperandList({0}Operands))
+    return ::mlir::failure();
+)";
+const char *const optionalOperandParserCode = R"(
+  {
+    {0}OperandsLoc = parser.getCurrentLocation();
+    ::mlir::OpAsmParser::OperandType operand;
+    ::mlir::OptionalParseResult parseResult =
+                                    parser.parseOptionalOperand(operand);
+    if (parseResult.hasValue()) {
+      if (failed(*parseResult))
+        return ::mlir::failure();
+      {0}Operands.push_back(operand);
+    }
+  }
+)";
+const char *const operandParserCode = R"(
+  {0}OperandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperand({0}RawOperands[0]))
+    return ::mlir::failure();
+)";
+
+/// The code snippet used to generate a parser call for a type list.
+///
+/// {0}: The name for the type list.
+const char *const variadicTypeParserCode = R"(
+  if (parser.parseTypeList({0}Types))
+    return ::mlir::failure();
+)";
+const char *const optionalTypeParserCode = R"(
+  {
+    ::mlir::Type optionalType;
+    ::mlir::OptionalParseResult parseResult =
+                                    parser.parseOptionalType(optionalType);
+    if (parseResult.hasValue()) {
+      if (failed(*parseResult))
+        return ::mlir::failure();
+      {0}Types.push_back(optionalType);
+    }
+  }
+)";
+const char *const typeParserCode = R"(
+  if (parser.parseType({0}RawTypes[0]))
+    return ::mlir::failure();
+)";
+
+/// The code snippet used to generate a parser call for a functional type.
+///
+/// {0}: The name for the input type list.
+/// {1}: The name for the result type list.
+const char *const functionalTypeParserCode = R"(
+  ::mlir::FunctionType {0}__{1}_functionType;
+  if (parser.parseType({0}__{1}_functionType))
+    return ::mlir::failure();
+  {0}Types = {0}__{1}_functionType.getInputs();
+  {1}Types = {0}__{1}_functionType.getResults();
+)";
+
+/// The code snippet used to generate a parser call for a region list. The
+/// first region is optional; every region after a comma is required.
+///
+/// {0}: The name for the region list.
+// Declared `const char *const` like the other snippets so the pointer is
+// immutable and has internal linkage.
+const char *const regionListParserCode = R"(
+  {
+    std::unique_ptr<::mlir::Region> region;
+    auto firstRegionResult = parser.parseOptionalRegion(region);
+    if (firstRegionResult.hasValue()) {
+      if (failed(*firstRegionResult))
+        return ::mlir::failure();
+      {0}Regions.emplace_back(std::move(region));
+
+      // Parse any trailing regions.
+      while (succeeded(parser.parseOptionalComma())) {
+        region = std::make_unique<::mlir::Region>();
+        if (parser.parseRegion(*region))
+          return ::mlir::failure();
+        {0}Regions.emplace_back(std::move(region));
+      }
+    }
+  }
+)";
+
+/// The code snippet used to ensure a list of regions have terminators.
+///
+/// {0}: The name of the region list.
+const char *const regionListEnsureTerminatorParserCode = R"(
+  for (auto &region : {0}Regions)
+    ensureTerminator(*region, parser.getBuilder(), result.location);
+)";
+
+/// The code snippet used to ensure a list of regions have a block.
+///
+/// {0}: The name of the region list.
+const char *const regionListEnsureSingleBlockParserCode = R"(
+  for (auto &region : {0}Regions)
+    if (region->empty()) region->emplaceBlock();
+)";
+
+/// The code snippet used to generate a parser call for an optional region.
+///
+/// {0}: The name of the region.
+const char *const optionalRegionParserCode = R"(
+  {
+     auto parseResult = parser.parseOptionalRegion(*{0}Region);
+     if (parseResult.hasValue() && failed(*parseResult))
+       return ::mlir::failure();
+  }
+)";
+
+/// The code snippet used to generate a parser call for a region.
+///
+/// {0}: The name of the region.
+const char *const regionParserCode = R"(
+  if (parser.parseRegion(*{0}Region))
+    return ::mlir::failure();
+)";
+
+/// The code snippet used to ensure a region has a terminator.
+///
+/// {0}: The name of the region.
+const char *const regionEnsureTerminatorParserCode = R"(
+  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
+)";
+
+/// The code snippet used to ensure a region has a block.
+///
+/// {0}: The name of the region.
+const char *const regionEnsureSingleBlockParserCode = R"(
+  if ({0}Region->empty()) {0}Region->emplaceBlock();
+)";
+
+/// The code snippet used to generate a parser call for a successor list.
+///
+/// {0}: The name for the successor list.
+const char *const successorListParserCode = R"(
+  {
+    ::mlir::Block *succ;
+    auto firstSucc = parser.parseOptionalSuccessor(succ);
+    if (firstSucc.hasValue()) {
+      if (failed(*firstSucc))
+        return ::mlir::failure();
+      {0}Successors.emplace_back(succ);
+
+      // Parse any trailing successors.
+      while (succeeded(parser.parseOptionalComma())) {
+        if (parser.parseSuccessor(succ))
+          return ::mlir::failure();
+        {0}Successors.emplace_back(succ);
+      }
+    }
+  }
+)";
+
+/// The code snippet used to generate a parser call for a successor.
+///
+/// {0}: The name of the successor.
+const char *const successorParserCode = R"(
+  if (parser.parseSuccessor({0}Successor))
+    return ::mlir::failure();
+)";
+
+namespace {
+/// The type of length for a given parse argument.
+enum class ArgumentLengthKind {
+  /// The argument is variadic, and may contain 0->N elements.
+  Variadic,
+  /// The argument is optional, and may contain 0 or 1 elements.
+  Optional,
+  /// The argument is a single element, i.e. always represents 1 element.
+  Single
+};
+} // end anonymous namespace
+
+/// Get the length kind for the given constraint. Optional is checked before
+/// variadic.
+static ArgumentLengthKind
+getArgumentLengthKind(const NamedTypeConstraint *var) {
+  if (var->isOptional())
+    return ArgumentLengthKind::Optional;
+  if (var->isVariadic())
+    return ArgumentLengthKind::Variadic;
+  return ArgumentLengthKind::Single;
+}
+
+/// Get the name used for the type list for the given type directive operand.
+/// 'lengthKind' is set to the corresponding kind for the given argument; the
+/// `operands` and `results` directives are always treated as variadic.
+static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) {
+  if (auto *operand = dyn_cast<OperandVariable>(arg)) {
+    lengthKind = getArgumentLengthKind(operand->getVar());
+    return operand->getVar()->name;
+  }
+  if (auto *result = dyn_cast<ResultVariable>(arg)) {
+    lengthKind = getArgumentLengthKind(result->getVar());
+    return result->getVar()->name;
+  }
+  lengthKind = ArgumentLengthKind::Variadic;
+  if (isa<OperandsDirective>(arg))
+    return "allOperand";
+  if (isa<ResultsDirective>(arg))
+    return "allResult";
+  llvm_unreachable("unknown 'type' directive argument");
+}
+
+/// Generate the parser for a literal value. Only the suffix of the parser
+/// method name plus its arguments is emitted (e.g. `Keyword("foo")` or
+/// `Arrow()`); the caller emits the `parse`/`parseOptional` prefix.
+static void genLiteralParser(StringRef value, OpMethodBody &body) {
+  // Handle the case of a keyword/identifier. The cast keeps `isalpha` well
+  // defined for characters with a negative `char` value.
+  if (value.front() == '_' ||
+      isalpha(static_cast<unsigned char>(value.front()))) {
+    body << "Keyword(\"" << value << "\")";
+    return;
+  }
+  body << (StringRef)StringSwitch<StringRef>(value)
+              .Case("->", "Arrow()")
+              .Case(":", "Colon()")
+              .Case(",", "Comma()")
+              .Case("=", "Equal()")
+              .Case("<", "Less()")
+              .Case(">", "Greater()")
+              .Case("{", "LBrace()")
+              .Case("}", "RBrace()")
+              .Case("(", "LParen()")
+              .Case(")", "RParen()")
+              .Case("[", "LSquare()")
+              .Case("]", "RSquare()")
+              .Case("?", "Question()")
+              .Case("+", "Plus()")
+              .Case("*", "Star()");
+}
+
+/// Generate the storage code required for parsing the given element, i.e. the
+/// local variable declarations the generated parser stores its results in.
+static void genElementParserStorage(Element *element, OpMethodBody &body) {
+  if (auto *optional = dyn_cast<OptionalElement>(element)) {
+    auto elements = optional->getThenElements();
+
+    // If the anchor is a unit attribute, it won't be parsed directly so elide
+    // it.
+    auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
+    Element *elidedAnchorElement = nullptr;
+    if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr())
+      elidedAnchorElement = anchor;
+    for (auto &childElement : elements)
+      if (&childElement != elidedAnchorElement)
+        genElementParserStorage(&childElement, body);
+    for (auto &childElement : optional->getElseElements())
+      genElementParserStorage(&childElement, body);
+
+  } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
+    for (auto &paramElement : custom->getArguments())
+      genElementParserStorage(&paramElement, body);
+
+  } else if (isa<OperandsDirective>(element)) {
+    body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
+            "allOperands;\n";
+
+  } else if (isa<RegionsDirective>(element)) {
+    body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
+            "fullRegions;\n";
+
+  } else if (isa<SuccessorsDirective>(element)) {
+    body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
+
+  } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
+    const NamedAttribute *var = attr->getVar();
+    body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
+                          var->name);
+
+  } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
+    StringRef name = operand->getVar()->name;
+    if (operand->getVar()->isVariableLength()) {
+      body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
+           << name << "Operands;\n";
+    } else {
+      // A single operand is parsed into a one-element array that is exposed
+      // through an ArrayRef, so later code can treat every operand as a list.
+      body << " ::mlir::OpAsmParser::OperandType " << name
+           << "RawOperands[1];\n"
+           << " ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name
+           << "Operands(" << name << "RawOperands);\n";
+    }
+    body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
+                          " (void){0}OperandsLoc;\n",
+                          name);
+
+  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
+    StringRef name = region->getVar()->name;
+    if (region->getVar()->isVariadic()) {
+      body << llvm::formatv(
+          " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
+          "{0}Regions;\n",
+          name);
+    } else {
+      body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
+                            "std::make_unique<::mlir::Region>();\n",
+                            name);
+    }
+
+  } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
+    StringRef name = successor->getVar()->name;
+    if (successor->getVar()->isVariadic()) {
+      body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
+                            "{0}Successors;\n",
+                            name);
+    } else {
+      body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
+    }
+
+  } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
+    ArgumentLengthKind lengthKind;
+    StringRef name = getTypeListName(dir->getOperand(), lengthKind);
+    if (lengthKind != ArgumentLengthKind::Single)
+      body << " ::mlir::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
+    else
+      body << llvm::formatv(" ::mlir::Type {0}RawTypes[1];\n", name)
+           << llvm::formatv(
+                  " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
+                  name);
+  } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
+    ArgumentLengthKind ignored;
+    body << " ::llvm::ArrayRef<::mlir::Type> "
+         << getTypeListName(dir->getInputs(), ignored) << "Types;\n";
+    body << " ::llvm::ArrayRef<::mlir::Type> "
+         << getTypeListName(dir->getResults(), ignored) << "Types;\n";
+  }
+}
+
+/// Generate the parser for a parameter to a custom directive.
+static void genCustomParameterParser(Element &param, OpMethodBody &body) {
+  if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
+    body << attr->getVar()->name << "Attr";
+  } else if (isa<AttrDictDirective>(&param)) {
+    body << "result.attributes";
+  } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+    StringRef name = operand->getVar()->name;
+    ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
+    if (lengthKind == ArgumentLengthKind::Variadic)
+      body << llvm::formatv("{0}Operands", name);
+    else if (lengthKind == ArgumentLengthKind::Optional)
+      body << llvm::formatv("{0}Operand", name);
+    else
+      body << formatv("{0}RawOperands[0]", name);
+
+  } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
+    StringRef name = region->getVar()->name;
+    if (region->getVar()->isVariadic())
+      body << llvm::formatv("{0}Regions", name);
+    else
+      body << llvm::formatv("*{0}Region", name);
+
+  } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
+    StringRef name = successor->getVar()->name;
+    if (successor->getVar()->isVariadic())
+      body << llvm::formatv("{0}Successors", name);
+    else
+      body << llvm::formatv("{0}Successor", name);
+
+  } else if (auto *refDir = dyn_cast<RefDirective>(&param)) {
+    // A `ref` forwards whatever its wrapped element would have produced.
+    genCustomParameterParser(*refDir->getOperand(), body);
+
+  } else if (auto *typeDir = dyn_cast<TypeDirective>(&param)) {
+    ArgumentLengthKind lengthKind;
+    StringRef listName = getTypeListName(typeDir->getOperand(), lengthKind);
+    if (lengthKind == ArgumentLengthKind::Variadic)
+      body << llvm::formatv("{0}Types", listName);
+    else if (lengthKind == ArgumentLengthKind::Optional)
+      body << llvm::formatv("{0}Type", listName);
+    else
+      body << formatv("{0}RawTypes[0]", listName);
+  } else {
+    llvm_unreachable("unknown custom directive parameter");
+  }
+}
+
+/// Generate the parser for a custom directive. The emitted code declares any
+/// temporaries needed for optional operands/types, calls the user-provided
+/// `parse<Name>` hook, and then records the parsed optional values.
+static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
+  body << " {\n";
+
+  // Preprocess the directive variables.
+  // * Add a local variable for optional operands and types. This provides a
+  //   better API to the user defined parser methods.
+  // * Set the location of operand variables.
+  for (Element &param : dir->getArguments()) {
+    if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+      body << " " << operand->getVar()->name
+           << "OperandsLoc = parser.getCurrentLocation();\n";
+      if (operand->getVar()->isOptional()) {
+        body << llvm::formatv(
+            " llvm::Optional<::mlir::OpAsmParser::OperandType> "
+            "{0}Operand;\n",
+            operand->getVar()->name);
+      }
+    } else if (auto *typeDir = dyn_cast<TypeDirective>(&param)) {
+      ArgumentLengthKind lengthKind;
+      StringRef listName = getTypeListName(typeDir->getOperand(), lengthKind);
+      if (lengthKind == ArgumentLengthKind::Optional)
+        body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
+    } else if (auto *refDir = dyn_cast<RefDirective>(&param)) {
+      // A referenced optional operand/type was parsed earlier; expose the
+      // already-parsed value (or an empty one) under the scalar name.
+      Element *input = refDir->getOperand();
+      if (auto *operand = dyn_cast<OperandVariable>(input)) {
+        if (!operand->getVar()->isOptional())
+          continue;
+        body << llvm::formatv(
+            " {0} {1}Operand = {1}Operands.empty() ? {0}() : "
+            "{1}Operands[0];\n",
+            "llvm::Optional<::mlir::OpAsmParser::OperandType>",
+            operand->getVar()->name);
+
+      } else if (auto *type = dyn_cast<TypeDirective>(input)) {
+        ArgumentLengthKind lengthKind;
+        StringRef listName = getTypeListName(type->getOperand(), lengthKind);
+        if (lengthKind == ArgumentLengthKind::Optional) {
+          body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
+                                "::mlir::Type() : {0}Types[0];\n",
+                                listName);
+        }
+      }
+    }
+  }
+
+  body << " if (parse" << dir->getName() << "(parser";
+  for (Element &param : dir->getArguments()) {
+    body << ", ";
+    genCustomParameterParser(param, body);
+  }
+
+  body << "))\n"
+       << " return ::mlir::failure();\n";
+
+  // After parsing, add handling for any of the optional constructs.
+  for (Element &param : dir->getArguments()) {
+    if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
+      const NamedAttribute *var = attr->getVar();
+      if (var->attr.isOptional())
+        body << llvm::formatv(" if ({0}Attr)\n ", var->name);
+
+      body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
+                            var->name);
+    } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+      const NamedTypeConstraint *var = operand->getVar();
+      if (!var->isOptional())
+        continue;
+      body << llvm::formatv(" if ({0}Operand.hasValue())\n"
+                            " {0}Operands.push_back(*{0}Operand);\n",
+                            var->name);
+    } else if (auto *typeDir = dyn_cast<TypeDirective>(&param)) {
+      ArgumentLengthKind lengthKind;
+      StringRef listName = getTypeListName(typeDir->getOperand(), lengthKind);
+      if (lengthKind == ArgumentLengthKind::Optional) {
+        body << llvm::formatv(" if ({0}Type)\n"
+                              " {0}Types.push_back({0}Type);\n",
+                              listName);
+      }
+    }
+  }
+
+  body << " }\n";
+}
+
+/// Generate the parser for an enum attribute by instantiating
+/// `enumAttrParserCode` for `var`.
+static void genEnumAttrParser(const NamedAttribute *var, OpMethodBody &body,
+                              FmtContext &attrTypeCtx) {
+  Attribute baseAttr = var->attr.getBaseAttr();
+  const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
+  std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
+
+  // Generate the code for building an attribute for this enum.
+  std::string attrBuilderStr;
+  {
+    llvm::raw_string_ostream os(attrBuilderStr);
+    os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
+                "attrOptional.getValue()");
+  }
+
+  // Build a brace-enclosed list of the cases that can be formatted as a
+  // keyword. Separators are written between entries, so an enum with no
+  // keyword-compatible case yields a well-formed empty list "{}" instead of
+  // a lone "}".
+  std::string validCaseKeywordsStr;
+  {
+    llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
+    validCaseKeywordsOS << '{';
+    bool needsSeparator = false;
+    for (const EnumAttrCase &attrCase : cases) {
+      if (!canFormatStringAsKeyword(attrCase.getStr()))
+        continue;
+      if (needsSeparator)
+        validCaseKeywordsOS << ',';
+      validCaseKeywordsOS << '"' << attrCase.getStr() << '"';
+      needsSeparator = true;
+    }
+    validCaseKeywordsOS << '}';
+  }
+
+  // If the attribute is not optional, build an error message for the missing
+  // attribute.
+  std::string errorMessage;
+  if (!var->attr.isOptional()) {
+    llvm::raw_string_ostream errorMessageOS(errorMessage);
+    errorMessageOS
+        << "return parser.emitError(loc, \"expected string or "
+           "keyword containing one of the following enum values for attribute '"
+        << var->name << "' [";
+    llvm::interleaveComma(cases, errorMessageOS, [&](const auto &attrCase) {
+      errorMessageOS << attrCase.getStr();
+    });
+    errorMessageOS << "]\");";
+  }
+
+  body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
+                  enumAttr.getStringToSymbolFnName(), attrBuilderStr,
+                  validCaseKeywordsStr, errorMessage);
+}
+
+void OperationFormat::genParser(Operator &op, OpClass &opClass) {
+  llvm::SmallVector<OpMethodParameter, 4> paramList;
+  paramList.emplace_back("::mlir::OpAsmParser &", "parser");
+  paramList.emplace_back("::mlir::OperationState &", "result");
+
+  auto *method =
+      opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
+                                OpMethod::MP_Static, std::move(paramList));
+  auto &body = method->body();
+
+  // Generate variables to store the operands and type within the format. This
+  // allows for referencing these variables in the presence of optional
+  // groupings.
+  for (auto &element : elements)
+    genElementParserStorage(&*element, body);
+
+  // A format context used when parsing attributes with buildable types.
+  FmtContext attrTypeCtx;
+  attrTypeCtx.withBuilder("parser.getBuilder()");
+
+  // Generate parsers for each of the elements.
+  for (auto &element : elements)
+    genElementParser(element.get(), body, attrTypeCtx);
+
+  // Generate the code to resolve the operand/result types and successors now
+  // that they have been parsed.
+  genParserTypeResolution(op, body);
+  genParserRegionResolution(op, body);
+  genParserSuccessorResolution(op, body);
+  genParserVariadicSegmentResolution(op, body);
+
+  body << " return ::mlir::success();\n";
+}
+
+void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
+                                       FmtContext &attrTypeCtx) {
+  /// Optional Group.
+ if (auto *optional = dyn_cast(element)) { + auto elements = llvm::drop_begin(optional->getThenElements(), + optional->getParseStart()); + + // Generate a special optional parser for the first element to gate the + // parsing of the rest of the elements. + Element *firstElement = &*elements.begin(); + if (auto *attrVar = dyn_cast(firstElement)) { + genElementParser(attrVar, body, attrTypeCtx); + body << " if (" << attrVar->getVar()->name << "Attr) {\n"; + } else if (auto *literal = dyn_cast(firstElement)) { + body << " if (succeeded(parser.parseOptional"; + genLiteralParser(literal->getLiteral(), body); + body << ")) {\n"; + } else if (auto *opVar = dyn_cast(firstElement)) { + genElementParser(opVar, body, attrTypeCtx); + body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n"; + } else if (auto *regionVar = dyn_cast(firstElement)) { + const NamedRegion *region = regionVar->getVar(); + if (region->isVariadic()) { + genElementParser(regionVar, body, attrTypeCtx); + body << " if (!" << region->name << "Regions.empty()) {\n"; + } else { + body << llvm::formatv(optionalRegionParserCode, region->name); + body << " if (!" << region->name << "Region->empty()) {\n "; + if (hasImplicitTermTrait) + body << llvm::formatv(regionEnsureTerminatorParserCode, region->name); + else if (hasSingleBlockTrait) + body << llvm::formatv(regionEnsureSingleBlockParserCode, + region->name); + } + } + + // If the anchor is a unit attribute, we don't need to print it. When + // parsing, we will add this attribute if this group is present. + Element *elidedAnchorElement = nullptr; + auto *anchorAttr = dyn_cast(optional->getAnchor()); + if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) { + elidedAnchorElement = anchorAttr; + + // Add the anchor unit attribute to the operation state. + body << " result.addAttribute(\"" << anchorAttr->getVar()->name + << "\", parser.getBuilder().getUnitAttr());\n"; + } + + // Generate the rest of the elements normally. 
+ for (Element &childElement : llvm::drop_begin(elements, 1)) { + if (&childElement != elidedAnchorElement) + genElementParser(&childElement, body, attrTypeCtx); + } + body << " }"; + + // Generate the else elements. + auto elseElements = optional->getElseElements(); + if (!elseElements.empty()) { + body << " else {\n"; + for (Element &childElement : elseElements) + genElementParser(&childElement, body, attrTypeCtx); + body << " }"; + } + body << "\n"; + + /// Literals. + } else if (LiteralElement *literal = dyn_cast(element)) { + body << " if (parser.parse"; + genLiteralParser(literal->getLiteral(), body); + body << ")\n return ::mlir::failure();\n"; + + /// Whitespaces. + } else if (isa(element)) { + // Nothing to parse. + + /// Arguments. + } else if (auto *attr = dyn_cast(element)) { + const NamedAttribute *var = attr->getVar(); + + // Check to see if we can parse this as an enum attribute. + if (canFormatEnumAttr(var)) + return genEnumAttrParser(var, body, attrTypeCtx); + + // Check to see if we should parse this as a symbol name attribute. + if (shouldFormatSymbolNameAttr(var)) { + body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode + : symbolNameAttrParserCode, + var->name); + return; + } + + // If this attribute has a buildable type, use that when parsing the + // attribute. + std::string attrTypeStr; + if (Optional typeBuilder = attr->getTypeBuilder()) { + llvm::raw_string_ostream os(attrTypeStr); + os << ", " << tgfmt(*typeBuilder, &attrTypeCtx); + } + + body << formatv(var->attr.isOptional() ? 
optionalAttrParserCode + : attrParserCode, + var->name, attrTypeStr); + } else if (auto *operand = dyn_cast(element)) { + ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); + StringRef name = operand->getVar()->name; + if (lengthKind == ArgumentLengthKind::Variadic) + body << llvm::formatv(variadicOperandParserCode, name); + else if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv(optionalOperandParserCode, name); + else + body << formatv(operandParserCode, name); + + } else if (auto *region = dyn_cast(element)) { + bool isVariadic = region->getVar()->isVariadic(); + body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode, + region->getVar()->name); + if (hasImplicitTermTrait) + body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode + : regionEnsureTerminatorParserCode, + region->getVar()->name); + else if (hasSingleBlockTrait) + body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode + : regionEnsureSingleBlockParserCode, + region->getVar()->name); + + } else if (auto *successor = dyn_cast(element)) { + bool isVariadic = successor->getVar()->isVariadic(); + body << formatv(isVariadic ? successorListParserCode : successorParserCode, + successor->getVar()->name); + + /// Directives. + } else if (auto *attrDict = dyn_cast(element)) { + body << " if (parser.parseOptionalAttrDict" + << (attrDict->isWithKeyword() ? 
"WithKeyword" : "") + << "(result.attributes))\n" + << " return ::mlir::failure();\n"; + } else if (auto *customDir = dyn_cast(element)) { + genCustomDirectiveParser(customDir, body); + + } else if (isa(element)) { + body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n" + << " if (parser.parseOperandList(allOperands))\n" + << " return ::mlir::failure();\n"; + + } else if (isa(element)) { + body << llvm::formatv(regionListParserCode, "full"); + if (hasImplicitTermTrait) + body << llvm::formatv(regionListEnsureTerminatorParserCode, "full"); + else if (hasSingleBlockTrait) + body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full"); + + } else if (isa(element)) { + body << llvm::formatv(successorListParserCode, "full"); + + } else if (auto *dir = dyn_cast(element)) { + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Variadic) + body << llvm::formatv(variadicTypeParserCode, listName); + else if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv(optionalTypeParserCode, listName); + else + body << formatv(typeParserCode, listName); + } else if (auto *dir = dyn_cast(element)) { + ArgumentLengthKind ignored; + body << formatv(functionalTypeParserCode, + getTypeListName(dir->getInputs(), ignored), + getTypeListName(dir->getResults(), ignored)); + } else { + llvm_unreachable("unknown format element"); + } +} + +void OperationFormat::genParserTypeResolution(Operator &op, + OpMethodBody &body) { + // If any of type resolutions use transformed variables, make sure that the + // types of those variables are resolved. + SmallPtrSet verifiedVariables; + FmtContext verifierFCtx; + for (TypeResolution &resolver : + llvm::concat(resultTypes, operandTypes)) { + Optional transformer = resolver.getVarTransformer(); + if (!transformer) + continue; + // Ensure that we don't verify the same variables twice. 
+ const NamedTypeConstraint *variable = resolver.getVariable(); + if (!variable || !verifiedVariables.insert(variable).second) + continue; + + auto constraint = variable->constraint; + body << " for (::mlir::Type type : " << variable->name << "Types) {\n" + << " (void)type;\n" + << " if (!(" + << tgfmt(constraint.getConditionTemplate(), + &verifierFCtx.withSelf("type")) + << ")) {\n" + << formatv(" return parser.emitError(parser.getNameLoc()) << " + "\"'{0}' must be {1}, but got \" << type;\n", + variable->name, constraint.getSummary()) + << " }\n" + << " }\n"; + } + + // Initialize the set of buildable types. + if (!buildableTypes.empty()) { + FmtContext typeBuilderCtx; + typeBuilderCtx.withBuilder("parser.getBuilder()"); + for (auto &it : buildableTypes) + body << " ::mlir::Type odsBuildableType" << it.second << " = " + << tgfmt(it.first, &typeBuilderCtx) << ";\n"; + } + + // Emit the code necessary for a type resolver. + auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) { + if (Optional val = resolver.getBuilderIdx()) { + body << "odsBuildableType" << *val; + } else if (const NamedTypeConstraint *var = resolver.getVariable()) { + if (Optional tform = resolver.getVarTransformer()) { + FmtContext fmtContext; + fmtContext.addSubst("_ctxt", "parser.getBuilder().getContext()"); + if (var->isVariadic()) + fmtContext.withSelf(var->name + "Types"); + else + fmtContext.withSelf(var->name + "Types[0]"); + body << tgfmt(*tform, &fmtContext); + } else { + body << var->name << "Types"; + } + } else if (const NamedAttribute *attr = resolver.getAttribute()) { + if (Optional tform = resolver.getVarTransformer()) + body << tgfmt(*tform, + &FmtContext().withSelf(attr->name + "Attr.getType()")); + else + body << attr->name << "Attr.getType()"; + } else { + body << curVar << "Types"; + } + }; + + // Resolve each of the result types. 
+ if (allResultTypes) { + body << " result.addTypes(allResultTypes);\n"; + } else { + for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { + body << " result.addTypes("; + emitTypeResolver(resultTypes[i], op.getResultName(i)); + body << ");\n"; + } + } + + // Early exit if there are no operands. + if (op.getNumOperands() == 0) + return; + + // Handle the case where all operand types are in one group. + if (allOperandTypes) { + // If we have all operands together, use the full operand list directly. + if (allOperands) { + body << " if (parser.resolveOperands(allOperands, allOperandTypes, " + "allOperandLoc, result.operands))\n" + " return ::mlir::failure();\n"; + return; + } + + // Otherwise, use llvm::concat to merge the disjoint operand lists together. + // llvm::concat does not allow the case of a single range, so guard it here. + body << " if (parser.resolveOperands("; + if (op.getNumOperands() > 1) { + body << "::llvm::concat("; + llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) { + body << operand.name << "Operands"; + }); + body << ")"; + } else { + body << op.operand_begin()->name << "Operands"; + } + body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n" + << " return ::mlir::failure();\n"; + return; + } + // Handle the case where all of the operands were grouped together. + if (allOperands) { + body << " if (parser.resolveOperands(allOperands, "; + + // Group all of the operand types together to perform the resolution all at + // once. Use llvm::concat to perform the merge. llvm::concat does not allow + // the case of a single range, so guard it here. 
+ if (op.getNumOperands() > 1) { + body << "::llvm::concat("; + llvm::interleaveComma( + llvm::seq(0, op.getNumOperands()), body, [&](int i) { + body << "::llvm::ArrayRef<::mlir::Type>("; + emitTypeResolver(operandTypes[i], op.getOperand(i).name); + body << ")"; + }); + body << ")"; + } else { + emitTypeResolver(operandTypes.front(), op.getOperand(0).name); + } + + body << ", allOperandLoc, result.operands))\n" + << " return ::mlir::failure();\n"; + return; + } + + // The final case is the one where each of the operands types are resolved + // separately. + for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { + NamedTypeConstraint &operand = op.getOperand(i); + body << " if (parser.resolveOperands(" << operand.name << "Operands, "; + + // Resolve the type of this operand. + TypeResolution &operandType = operandTypes[i]; + emitTypeResolver(operandType, operand.name); + + // If the type is resolved by a non-variadic variable, index into the + // resolved type list. This allows for resolving the types of a variadic + // operand list from a non-variadic variable. + bool verifyOperandAndTypeSize = true; + if (auto *resolverVar = operandType.getVariable()) { + if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) { + body << "[0]"; + verifyOperandAndTypeSize = false; + } + } else { + verifyOperandAndTypeSize = !operandType.getBuilderIdx(); + } + + // Check to see if the sizes between the types and operands must match. If + // they do, provide the operand location to select the proper resolution + // overload. + if (verifyOperandAndTypeSize) + body << ", " << operand.name << "OperandsLoc"; + body << ", result.operands))\n return ::mlir::failure();\n"; + } +} + +void OperationFormat::genParserRegionResolution(Operator &op, + OpMethodBody &body) { + // Check for the case where all regions were parsed. 
+ bool hasAllRegions = llvm::any_of( + elements, [](auto &elt) { return isa(elt.get()); }); + if (hasAllRegions) { + body << " result.addRegions(fullRegions);\n"; + return; + } + + // Otherwise, handle each region individually. + for (const NamedRegion ®ion : op.getRegions()) { + if (region.isVariadic()) + body << " result.addRegions(" << region.name << "Regions);\n"; + else + body << " result.addRegion(std::move(" << region.name << "Region));\n"; + } +} + +void OperationFormat::genParserSuccessorResolution(Operator &op, + OpMethodBody &body) { + // Check for the case where all successors were parsed. + bool hasAllSuccessors = llvm::any_of( + elements, [](auto &elt) { return isa(elt.get()); }); + if (hasAllSuccessors) { + body << " result.addSuccessors(fullSuccessors);\n"; + return; + } + + // Otherwise, handle each successor individually. + for (const NamedSuccessor &successor : op.getSuccessors()) { + if (successor.isVariadic()) + body << " result.addSuccessors(" << successor.name << "Successors);\n"; + else + body << " result.addSuccessors(" << successor.name << "Successor);\n"; + } +} + +void OperationFormat::genParserVariadicSegmentResolution(Operator &op, + OpMethodBody &body) { + if (!allOperands && + op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + body << " result.addAttribute(\"operand_segment_sizes\", " + << "parser.getBuilder().getI32VectorAttr({"; + auto interleaveFn = [&](const NamedTypeConstraint &operand) { + // If the operand is variadic emit the parsed size. 
+ if (operand.isVariableLength()) + body << "static_cast(" << operand.name << "Operands.size())"; + else + body << "1"; + }; + llvm::interleaveComma(op.getOperands(), body, interleaveFn); + body << "}));\n"; + } + + if (!allResultTypes && + op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { + body << " result.addAttribute(\"result_segment_sizes\", " + << "parser.getBuilder().getI32VectorAttr({"; + auto interleaveFn = [&](const NamedTypeConstraint &result) { + // If the result is variadic emit the parsed size. + if (result.isVariableLength()) + body << "static_cast(" << result.name << "Types.size())"; + else + body << "1"; + }; + llvm::interleaveComma(op.getResults(), body, interleaveFn); + body << "}));\n"; + } +} + +//===----------------------------------------------------------------------===// +// PrinterGen + +/// The code snippet used to generate a printer call for a region of an +// operation that has the SingleBlockImplicitTerminator trait. +/// +/// {0}: The name of the region. +const char *regionSingleBlockImplicitTerminatorPrinterCode = R"( + { + bool printTerminator = true; + if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{ + printTerminator = !term->getAttrDictionary().empty() || + term->getNumOperands() != 0 || + term->getNumResults() != 0; + } + p.printRegion({0}, /*printEntryBlockArgs=*/true, + /*printBlockTerminators=*/printTerminator); + } +)"; + +/// The code snippet used to generate a printer call for an enum that has cases +/// that can't be represented with a keyword. +/// +/// {0}: The name of the enum attribute. +/// {1}: The name of the enum attributes symbolToString function. +const char *enumAttrBeginPrinterCode = R"( + { + auto caseValue = {0}(); + auto caseValueStr = {1}(caseValue); +)"; + +/// Generate the printer for the 'attr-dict' directive. 
+static void genAttrDictPrinter(OperationFormat &fmt, Operator &op, + OpMethodBody &body, bool withKeyword) { + body << " p.printOptionalAttrDict" << (withKeyword ? "WithKeyword" : "") + << "((*this)->getAttrs(), /*elidedAttrs=*/{"; + // Elide the variadic segment size attributes if necessary. + if (!fmt.allOperands && + op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) + body << "\"operand_segment_sizes\", "; + if (!fmt.allResultTypes && + op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) + body << "\"result_segment_sizes\", "; + llvm::interleaveComma( + fmt.usedAttributes, body, + [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; }); + body << "});\n"; +} + +/// Generate the printer for a literal value. `shouldEmitSpace` is true if a +/// space should be emitted before this element. `lastWasPunctuation` is true if +/// the previous element was a punctuation literal. +static void genLiteralPrinter(StringRef value, OpMethodBody &body, + bool &shouldEmitSpace, bool &lastWasPunctuation) { + body << " p"; + + // Don't insert a space for certain punctuation. + auto shouldPrintSpaceBeforeLiteral = [&] { + if (value.size() != 1 && value != "->") + return true; + if (lastWasPunctuation) + return !StringRef(">)}],").contains(value.front()); + return !StringRef("<>(){}[],").contains(value.front()); + }; + if (shouldEmitSpace && shouldPrintSpaceBeforeLiteral()) + body << " << ' '"; + body << " << \"" << value << "\";\n"; + + // Insert a space after certain literals. + shouldEmitSpace = + value.size() != 1 || !StringRef("<({[").contains(value.front()); + lastWasPunctuation = !(value.front() == '_' || isalpha(value.front())); +} + +/// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation` +/// are set to false. 
+static void genSpacePrinter(bool value, OpMethodBody &body, + bool &shouldEmitSpace, bool &lastWasPunctuation) { + if (value) { + body << " p << ' ';\n"; + lastWasPunctuation = false; + } else { + lastWasPunctuation = true; + } + shouldEmitSpace = false; +} + +/// Generate the printer for a custom directive parameter. +static void genCustomDirectiveParameterPrinter(Element *element, + OpMethodBody &body) { + if (auto *attr = dyn_cast(element)) { + body << attr->getVar()->name << "Attr()"; + + } else if (isa(element)) { + body << "getOperation()->getAttrDictionary()"; + + } else if (auto *operand = dyn_cast(element)) { + body << operand->getVar()->name << "()"; + + } else if (auto *region = dyn_cast(element)) { + body << region->getVar()->name << "()"; + + } else if (auto *successor = dyn_cast(element)) { + body << successor->getVar()->name << "()"; + + } else if (auto *dir = dyn_cast(element)) { + genCustomDirectiveParameterPrinter(dir->getOperand(), body); + + } else if (auto *dir = dyn_cast(element)) { + auto *typeOperand = dir->getOperand(); + auto *operand = dyn_cast(typeOperand); + auto *var = operand ? operand->getVar() + : cast(typeOperand)->getVar(); + if (var->isVariadic()) + body << var->name << "().getTypes()"; + else if (var->isOptional()) + body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name); + else + body << var->name << "().getType()"; + } else { + llvm_unreachable("unknown custom directive parameter"); + } +} + +/// Generate the printer for a custom directive. +static void genCustomDirectivePrinter(CustomDirective *customDir, + OpMethodBody &body) { + body << " print" << customDir->getName() << "(p, *this"; + for (Element ¶m : customDir->getArguments()) { + body << ", "; + genCustomDirectiveParameterPrinter(¶m, body); + } + body << ");\n"; +} + +/// Generate the printer for a region with the given variable name. 
+static void genRegionPrinter(const Twine ®ionName, OpMethodBody &body, + bool hasImplicitTermTrait) { + if (hasImplicitTermTrait) + body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode, + regionName); + else + body << " p.printRegion(" << regionName << ");\n"; +} +static void genVariadicRegionPrinter(const Twine ®ionListName, + OpMethodBody &body, + bool hasImplicitTermTrait) { + body << " llvm::interleaveComma(" << regionListName + << ", p, [&](::mlir::Region ®ion) {\n "; + genRegionPrinter("region", body, hasImplicitTermTrait); + body << " });\n"; +} + +/// Generate the C++ for an operand to a (*-)type directive. +static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) { + if (isa(arg)) + return body << "getOperation()->getOperandTypes()"; + if (isa(arg)) + return body << "getOperation()->getResultTypes()"; + auto *operand = dyn_cast(arg); + auto *var = operand ? operand->getVar() : cast(arg)->getVar(); + if (var->isVariadic()) + return body << var->name << "().getTypes()"; + if (var->isOptional()) + return body << llvm::formatv( + "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : " + "::llvm::ArrayRef<::mlir::Type>())", + var->name); + return body << "::llvm::ArrayRef<::mlir::Type>(" << var->name + << "().getType())"; +} + +/// Generate the printer for an enum attribute. +static void genEnumAttrPrinter(const NamedAttribute *var, OpMethodBody &body) { + Attribute baseAttr = var->attr.getBaseAttr(); + const EnumAttr &enumAttr = cast(baseAttr); + std::vector cases = enumAttr.getAllCases(); + + body << llvm::formatv(enumAttrBeginPrinterCode, + (var->attr.isOptional() ? "*" : "") + var->name, + enumAttr.getSymbolToStringFnName()); + + // Get a string containing all of the cases that can't be represented with a + // keyword. 
+ llvm::BitVector nonKeywordCases(cases.size()); + bool hasStrCase = false; + for (auto it : llvm::enumerate(cases)) { + hasStrCase = it.value().isStrCase(); + if (!canFormatStringAsKeyword(it.value().getStr())) + nonKeywordCases.set(it.index()); + } + + // If this is a string enum, use the case string to determine which cases + // need to use the string form. + if (hasStrCase) { + if (nonKeywordCases.any()) { + body << " if (llvm::is_contained(llvm::ArrayRef("; + llvm::interleaveComma(nonKeywordCases.set_bits(), body, [&](unsigned it) { + body << '"' << cases[it].getStr() << '"'; + }); + body << ")))\n" + " p << '\"' << caseValueStr << '\"';\n" + " else\n "; + } + body << " p << caseValueStr;\n" + " }\n"; + return; + } + + // Otherwise if this is a bit enum attribute, don't allow cases that may + // overlap with other cases. For simplicity sake, only allow cases with a + // single bit value. + if (enumAttr.isBitEnum()) { + for (auto it : llvm::enumerate(cases)) { + int64_t value = it.value().getValue(); + if (value < 0 || !llvm::isPowerOf2_64(value)) + nonKeywordCases.set(it.index()); + } + } + + // If there are any cases that can't be used with a keyword, switch on the + // case value to determine when to print in the string form. + if (nonKeywordCases.any()) { + body << " switch (caseValue) {\n"; + StringRef cppNamespace = enumAttr.getCppNamespace(); + StringRef enumName = enumAttr.getEnumClassName(); + for (auto it : llvm::enumerate(cases)) { + if (nonKeywordCases.test(it.index())) + continue; + StringRef symbol = it.value().getSymbol(); + body << llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName, + llvm::isDigit(symbol.front()) ? ("_" + symbol) + : symbol); + } + body << " p << caseValueStr;\n" + " break;\n" + " default:\n" + " p << '\"' << caseValueStr << '\"';\n" + " break;\n" + " }\n" + " }\n"; + return; + } + + body << " p << caseValueStr;\n" + " }\n"; +} + +/// Generate the check for the anchor of an optional group. 
+static void genOptionalGroupPrinterAnchor(Element *anchor, OpMethodBody &body) { + TypeSwitch(anchor) + .Case([&](auto *element) { + const NamedTypeConstraint *var = element->getVar(); + if (var->isOptional()) + body << " if (" << var->name << "()) {\n"; + else if (var->isVariadic()) + body << " if (!" << var->name << "().empty()) {\n"; + }) + .Case([&](RegionVariable *element) { + const NamedRegion *var = element->getVar(); + // TODO: Add a check for optional regions here when ODS supports it. + body << " if (!" << var->name << "().empty()) {\n"; + }) + .Case([&](TypeDirective *element) { + genOptionalGroupPrinterAnchor(element->getOperand(), body); + }) + .Case([&](FunctionalTypeDirective *element) { + genOptionalGroupPrinterAnchor(element->getInputs(), body); + }) + .Case([&](AttributeVariable *attr) { + body << " if ((*this)->getAttr(\"" << attr->getVar()->name + << "\")) {\n"; + }); +} + +void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body, + Operator &op, bool &shouldEmitSpace, + bool &lastWasPunctuation) { + if (LiteralElement *literal = dyn_cast(element)) + return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace, + lastWasPunctuation); + + // Emit a whitespace element. + if (isa(element)) { + body << " p.printNewline();\n"; + return; + } + if (SpaceElement *space = dyn_cast(element)) + return genSpacePrinter(space->getValue(), body, shouldEmitSpace, + lastWasPunctuation); + + // Emit an optional group. + if (OptionalElement *optional = dyn_cast(element)) { + // Emit the check for the presence of the anchor element. + Element *anchor = optional->getAnchor(); + genOptionalGroupPrinterAnchor(anchor, body); + + // If the anchor is a unit attribute, we don't need to print it. When + // parsing, we will add this attribute if this group is present. 
+ auto elements = optional->getThenElements(); + Element *elidedAnchorElement = nullptr; + auto *anchorAttr = dyn_cast(anchor); + if (anchorAttr && anchorAttr != &*elements.begin() && + anchorAttr->isUnitAttr()) { + elidedAnchorElement = anchorAttr; + } + + // Emit each of the elements. + for (Element &childElement : elements) { + if (&childElement != elidedAnchorElement) { + genElementPrinter(&childElement, body, op, shouldEmitSpace, + lastWasPunctuation); + } + } + body << " }"; + + // Emit each of the else elements. + auto elseElements = optional->getElseElements(); + if (!elseElements.empty()) { + body << " else {\n"; + for (Element &childElement : elseElements) { + genElementPrinter(&childElement, body, op, shouldEmitSpace, + lastWasPunctuation); + } + body << " }"; + } + + body << "\n"; + return; + } + + // Emit the attribute dictionary. + if (auto *attrDict = dyn_cast(element)) { + genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword()); + lastWasPunctuation = false; + return; + } + + // Optionally insert a space before the next element. The AttrDict printer + // already adds a space as necessary. + if (shouldEmitSpace || !lastWasPunctuation) + body << " p << ' ';\n"; + lastWasPunctuation = false; + shouldEmitSpace = true; + + if (auto *attr = dyn_cast(element)) { + const NamedAttribute *var = attr->getVar(); + + // If we are formatting as an enum, symbolize the attribute as a string. + if (canFormatEnumAttr(var)) + return genEnumAttrPrinter(var, body); + + // If we are formatting as a symbol name, handle it as a symbol name. + if (shouldFormatSymbolNameAttr(var)) { + body << " p.printSymbolName(" << var->name << "Attr().getValue());\n"; + return; + } + + // Elide the attribute type if it is buildable. 
+ if (attr->getTypeBuilder()) + body << " p.printAttributeWithoutType(" << var->name << "Attr());\n"; + else + body << " p.printAttribute(" << var->name << "Attr());\n"; + } else if (auto *operand = dyn_cast(element)) { + if (operand->getVar()->isOptional()) { + body << " if (::mlir::Value value = " << operand->getVar()->name + << "())\n" + << " p << value;\n"; + } else { + body << " p << " << operand->getVar()->name << "();\n"; + } + } else if (auto *region = dyn_cast(element)) { + const NamedRegion *var = region->getVar(); + if (var->isVariadic()) { + genVariadicRegionPrinter(var->name + "()", body, hasImplicitTermTrait); + } else { + genRegionPrinter(var->name + "()", body, hasImplicitTermTrait); + } + } else if (auto *successor = dyn_cast(element)) { + const NamedSuccessor *var = successor->getVar(); + if (var->isVariadic()) + body << " ::llvm::interleaveComma(" << var->name << "(), p);\n"; + else + body << " p << " << var->name << "();\n"; + } else if (auto *dir = dyn_cast(element)) { + genCustomDirectivePrinter(dir, body); + } else if (isa(element)) { + body << " p << getOperation()->getOperands();\n"; + } else if (isa(element)) { + genVariadicRegionPrinter("getOperation()->getRegions()", body, + hasImplicitTermTrait); + } else if (isa(element)) { + body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n"; + } else if (auto *dir = dyn_cast(element)) { + body << " p << "; + genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; + } else if (auto *dir = dyn_cast(element)) { + body << " p.printFunctionalType("; + genTypeOperandPrinter(dir->getInputs(), body) << ", "; + genTypeOperandPrinter(dir->getResults(), body) << ");\n"; + } else { + llvm_unreachable("unknown format element"); + } +} + +void OperationFormat::genPrinter(Operator &op, OpClass &opClass) { + auto *method = + opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &p"); + auto &body = method->body(); + + // Emit the operation name, trimming the prefix if this is the 
standard + // dialect. + body << " p << \""; + std::string opName = op.getOperationName(); + if (op.getDialectName() == "std") + body << StringRef(opName).drop_front(4); + else + body << opName; + body << "\";\n"; + + // Flags for if we should emit a space, and if the last element was + // punctuation. + bool shouldEmitSpace = true, lastWasPunctuation = false; + for (auto &element : elements) + genElementPrinter(element.get(), body, op, shouldEmitSpace, + lastWasPunctuation); +} + +//===----------------------------------------------------------------------===// +// FormatLexer +//===----------------------------------------------------------------------===// + +namespace { +/// This class represents a specific token in the input format. +class Token { +public: + enum Kind { + // Markers. + eof, + error, + + // Tokens with no info. + l_paren, + r_paren, + caret, + colon, + comma, + equal, + less, + greater, + question, + + // Keywords. + keyword_start, + kw_attr_dict, + kw_attr_dict_w_keyword, + kw_custom, + kw_functional_type, + kw_operands, + kw_ref, + kw_regions, + kw_results, + kw_successors, + kw_type, + keyword_end, + + // String valued tokens. + identifier, + literal, + variable, + }; + Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {} + + /// Return the bytes that make up this token. + StringRef getSpelling() const { return spelling; } + + /// Return the kind of this token. + Kind getKind() const { return kind; } + + /// Return a location for this token. + llvm::SMLoc getLoc() const { + return llvm::SMLoc::getFromPointer(spelling.data()); + } + + /// Return if this token is a keyword. + bool isKeyword() const { return kind > keyword_start && kind < keyword_end; } + +private: + /// Discriminator that indicates the kind of token this is. + Kind kind; + + /// A reference to the entire token contents; this is always a pointer into + /// a memory buffer owned by the source manager. 
+ StringRef spelling; +}; + +/// This class implements a simple lexer for operation assembly format strings. +class FormatLexer { +public: + FormatLexer(llvm::SourceMgr &mgr, Operator &op); + + /// Lex the next token and return it. + Token lexToken(); + + /// Emit an error to the lexer with the given location and message. + Token emitError(llvm::SMLoc loc, const Twine &msg); + Token emitError(const char *loc, const Twine &msg); + + Token emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, const Twine ¬e); + +private: + Token formToken(Token::Kind kind, const char *tokStart) { + return Token(kind, StringRef(tokStart, curPtr - tokStart)); + } + + /// Return the next character in the stream. + int getNextChar(); + + /// Lex an identifier, literal, or variable. + Token lexIdentifier(const char *tokStart); + Token lexLiteral(const char *tokStart); + Token lexVariable(const char *tokStart); + + llvm::SourceMgr &srcMgr; + Operator &op; + StringRef curBuffer; + const char *curPtr; +}; +} // end anonymous namespace + +FormatLexer::FormatLexer(llvm::SourceMgr &mgr, Operator &op) + : srcMgr(mgr), op(op) { + curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer(); + curPtr = curBuffer.begin(); +} + +Token FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) { + srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg); + llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note, + "in custom assembly format for this operation"); + return formToken(Token::error, loc.getPointer()); +} +Token FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, + const Twine ¬e) { + srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg); + llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note, + "in custom assembly format for this operation"); + srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note); + return formToken(Token::error, loc.getPointer()); +} +Token FormatLexer::emitError(const char *loc, const Twine &msg) { + return 
emitError(llvm::SMLoc::getFromPointer(loc), msg); +} + +int FormatLexer::getNextChar() { + char curChar = *curPtr++; + switch (curChar) { + default: + return (unsigned char)curChar; + case 0: { + // A nul character in the stream is either the end of the current buffer or + // a random nul in the file. Disambiguate that here. + if (curPtr - 1 != curBuffer.end()) + return 0; + + // Otherwise, return end of file. + --curPtr; + return EOF; + } + case '\n': + case '\r': + // Handle the newline character by ignoring it and incrementing the line + // count. However, be careful about 'dos style' files with \n\r in them. + // Only treat a \n\r or \r\n as a single line. + if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) + ++curPtr; + return '\n'; + } +} + +Token FormatLexer::lexToken() { + const char *tokStart = curPtr; + + // This always consumes at least one character. + int curChar = getNextChar(); + switch (curChar) { + default: + // Handle identifiers: [a-zA-Z_] + if (isalpha(curChar) || curChar == '_') + return lexIdentifier(tokStart); + + // Unknown character, emit an error. + return emitError(tokStart, "unexpected character"); + case EOF: + // Return EOF denoting the end of lexing. + return formToken(Token::eof, tokStart); + + // Lex punctuation. + case '^': + return formToken(Token::caret, tokStart); + case ':': + return formToken(Token::colon, tokStart); + case ',': + return formToken(Token::comma, tokStart); + case '=': + return formToken(Token::equal, tokStart); + case '<': + return formToken(Token::less, tokStart); + case '>': + return formToken(Token::greater, tokStart); + case '?': + return formToken(Token::question, tokStart); + case '(': + return formToken(Token::l_paren, tokStart); + case ')': + return formToken(Token::r_paren, tokStart); + + // Ignore whitespace characters. 
+ case 0: + case ' ': + case '\t': + case '\n': + return lexToken(); + + case '`': + return lexLiteral(tokStart); + case '$': + return lexVariable(tokStart); + } +} + +Token FormatLexer::lexLiteral(const char *tokStart) { + assert(curPtr[-1] == '`'); + + // Lex a literal surrounded by ``. + while (const char curChar = *curPtr++) { + if (curChar == '`') + return formToken(Token::literal, tokStart); + } + return emitError(curPtr - 1, "unexpected end of file in literal"); +} + +Token FormatLexer::lexVariable(const char *tokStart) { + if (!isalpha(curPtr[0]) && curPtr[0] != '_') + return emitError(curPtr - 1, "expected variable name"); + + // Otherwise, consume the rest of the characters. + while (isalnum(*curPtr) || *curPtr == '_') + ++curPtr; + return formToken(Token::variable, tokStart); +} + +Token FormatLexer::lexIdentifier(const char *tokStart) { + // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* + while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-') + ++curPtr; + + // Check to see if this identifier is a keyword. + StringRef str(tokStart, curPtr - tokStart); + Token::Kind kind = + StringSwitch(str) + .Case("attr-dict", Token::kw_attr_dict) + .Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword) + .Case("custom", Token::kw_custom) + .Case("functional-type", Token::kw_functional_type) + .Case("operands", Token::kw_operands) + .Case("ref", Token::kw_ref) + .Case("regions", Token::kw_regions) + .Case("results", Token::kw_results) + .Case("successors", Token::kw_successors) + .Case("type", Token::kw_type) + .Default(Token::identifier); + return Token(kind, str); +} + +//===----------------------------------------------------------------------===// +// FormatParser +//===----------------------------------------------------------------------===// + +/// Function to find an element within the given range that has the same name as +/// 'name'. 
+template +static auto findArg(RangeT &&range, StringRef name) { + auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; }); + return it != range.end() ? &*it : nullptr; +} + +namespace { +/// This class implements a parser for an instance of an operation assembly +/// format. +class FormatParser { +public: + FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op) + : lexer(mgr, op), curToken(lexer.lexToken()), fmt(format), op(op), + seenOperandTypes(op.getNumOperands()), + seenResultTypes(op.getNumResults()) {} + + /// Parse the operation assembly format. + LogicalResult parse(); + +private: + /// The current context of the parser when parsing an element. + enum ParserContext { + /// The element is being parsed in a "top-level" context, i.e. at the top of + /// the format or in an optional group. + TopLevelContext, + /// The element is being parsed as a custom directive child. + CustomDirectiveContext, + /// The element is being parsed as a type directive child. + TypeDirectiveContext, + /// The element is being parsed as a reference directive child. + RefDirectiveContext + }; + + /// This struct represents a type resolution instance. It includes a specific + /// type as well as an optional transformer to apply to that type in order to + /// properly resolve the type of a variable. + struct TypeResolutionInstance { + ConstArgument resolver; + Optional transformer; + }; + + /// An iterator over the elements of a format group. + using ElementsIterT = llvm::pointee_iterator< + std::vector>::const_iterator>; + + /// Verify the state of operation attributes within the format. + LogicalResult verifyAttributes(llvm::SMLoc loc); + /// Verify the attribute elements at the back of the given stack of iterators. + LogicalResult verifyAttributes( + llvm::SMLoc loc, + SmallVectorImpl> &iteratorStack); + + /// Verify the state of operation operands within the format. 
+ LogicalResult + verifyOperands(llvm::SMLoc loc, + llvm::StringMap &variableTyResolver); + + /// Verify the state of operation regions within the format. + LogicalResult verifyRegions(llvm::SMLoc loc); + + /// Verify the state of operation results within the format. + LogicalResult + verifyResults(llvm::SMLoc loc, + llvm::StringMap &variableTyResolver); + + /// Verify the state of operation successors within the format. + LogicalResult verifySuccessors(llvm::SMLoc loc); + + /// Given the values of an `AllTypesMatch` trait, check for inferable type + /// resolution. + void handleAllTypesMatchConstraint( + ArrayRef values, + llvm::StringMap &variableTyResolver); + /// Check for inferable type resolution given all operands, and or results, + /// have the same type. If 'includeResults' is true, the results also have the + /// same type as all of the operands. + void handleSameTypesConstraint( + llvm::StringMap &variableTyResolver, + bool includeResults); + /// Check for inferable type resolution based on another operand, result, or + /// attribute. + void handleTypesMatchConstraint( + llvm::StringMap &variableTyResolver, + llvm::Record def); + + /// Returns an argument or attribute with the given name that has been seen + /// within the format. + ConstArgument findSeenArg(StringRef name); + + /// Parse a specific element. + LogicalResult parseElement(std::unique_ptr &element, + ParserContext context); + LogicalResult parseVariable(std::unique_ptr &element, + ParserContext context); + LogicalResult parseDirective(std::unique_ptr &element, + ParserContext context); + LogicalResult parseLiteral(std::unique_ptr &element, + ParserContext context); + LogicalResult parseOptional(std::unique_ptr &element, + ParserContext context); + LogicalResult parseOptionalChildElement( + std::vector> &childElements, + Optional &anchorIdx); + LogicalResult verifyOptionalChildElement(Element *element, + llvm::SMLoc childLoc, bool isAnchor); + + /// Parse the various different directives. 
+ LogicalResult parseAttrDictDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context, + bool withKeyword); + LogicalResult parseCustomDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context); + LogicalResult parseCustomDirectiveParameter( + std::vector> ¶meters); + LogicalResult parseFunctionalTypeDirective(std::unique_ptr &element, + Token tok, ParserContext context); + LogicalResult parseOperandsDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context); + LogicalResult parseReferenceDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context); + LogicalResult parseRegionsDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context); + LogicalResult parseResultsDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context); + LogicalResult parseSuccessorsDirective(std::unique_ptr &element, + llvm::SMLoc loc, + ParserContext context); + LogicalResult parseTypeDirective(std::unique_ptr &element, Token tok, + ParserContext context); + LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element, + bool isRefChild = false); + + //===--------------------------------------------------------------------===// + // Lexer Utilities + //===--------------------------------------------------------------------===// + + /// Advance the current lexer onto the next token. 
+ void consumeToken() { + assert(curToken.getKind() != Token::eof && + curToken.getKind() != Token::error && + "shouldn't advance past EOF or errors"); + curToken = lexer.lexToken(); + } + LogicalResult parseToken(Token::Kind kind, const Twine &msg) { + if (curToken.getKind() != kind) + return emitError(curToken.getLoc(), msg); + consumeToken(); + return ::mlir::success(); + } + LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) { + lexer.emitError(loc, msg); + return ::mlir::failure(); + } + LogicalResult emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, + const Twine ¬e) { + lexer.emitErrorAndNote(loc, msg, note); + return ::mlir::failure(); + } + + //===--------------------------------------------------------------------===// + // Fields + //===--------------------------------------------------------------------===// + + FormatLexer lexer; + Token curToken; + OperationFormat &fmt; + Operator &op; + + // The following are various bits of format state used for verification + // during parsing. + bool hasAttrDict = false; + bool hasAllRegions = false, hasAllSuccessors = false; + llvm::SmallBitVector seenOperandTypes, seenResultTypes; + llvm::SmallSetVector seenAttrs; + llvm::DenseSet seenOperands; + llvm::DenseSet seenRegions; + llvm::DenseSet seenSuccessors; +}; +} // end anonymous namespace + +LogicalResult FormatParser::parse() { + llvm::SMLoc loc = curToken.getLoc(); + + // Parse each of the format elements into the main format. + while (curToken.getKind() != Token::eof) { + std::unique_ptr element; + if (failed(parseElement(element, TopLevelContext))) + return ::mlir::failure(); + fmt.elements.push_back(std::move(element)); + } + + // Check that the attribute dictionary is in the format. + if (!hasAttrDict) + return emitError(loc, "'attr-dict' directive not found in " + "custom assembly format"); + + // Check for any type traits that we can use for inferring types. 
+ llvm::StringMap variableTyResolver; + for (const Trait &trait : op.getTraits()) { + const llvm::Record &def = trait.getDef(); + if (def.isSubClassOf("AllTypesMatch")) { + handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"), + variableTyResolver); + } else if (def.getName() == "SameTypeOperands") { + handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false); + } else if (def.getName() == "SameOperandsAndResultType") { + handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); + } else if (def.isSubClassOf("TypesMatchWith")) { + handleTypesMatchConstraint(variableTyResolver, def); + } + } + + // Verify the state of the various operation components. + if (failed(verifyAttributes(loc)) || + failed(verifyResults(loc, variableTyResolver)) || + failed(verifyOperands(loc, variableTyResolver)) || + failed(verifyRegions(loc)) || failed(verifySuccessors(loc))) + return ::mlir::failure(); + + // Collect the set of used attributes in the format. + fmt.usedAttributes = seenAttrs.takeVector(); + return ::mlir::success(); +} + +LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) { + // Check that there are no `:` literals after an attribute without a constant + // type. The attribute grammar contains an optional trailing colon type, which + // can lead to unexpected and generally unintended behavior. Given that, it is + // better to just error out here instead. + using ElementsIterT = llvm::pointee_iterator< + std::vector>::const_iterator>; + SmallVector, 1> iteratorStack; + iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end()); + while (!iteratorStack.empty()) + if (failed(verifyAttributes(loc, iteratorStack))) + return ::mlir::failure(); + return ::mlir::success(); +} +/// Verify the attribute elements at the back of the given stack of iterators. 
+LogicalResult FormatParser::verifyAttributes( + llvm::SMLoc loc, + SmallVectorImpl> &iteratorStack) { + auto &stackIt = iteratorStack.back(); + ElementsIterT &it = stackIt.first, e = stackIt.second; + while (it != e) { + Element *element = &*(it++); + + // Traverse into optional groups. + if (auto *optional = dyn_cast(element)) { + auto thenElements = optional->getThenElements(); + iteratorStack.emplace_back(thenElements.begin(), thenElements.end()); + + auto elseElements = optional->getElseElements(); + iteratorStack.emplace_back(elseElements.begin(), elseElements.end()); + return ::mlir::success(); + } + + // We are checking for an attribute element followed by a `:`, so there is + // no need to check the end. + if (it == e && iteratorStack.size() == 1) + break; + + // Check for an attribute with a constant type builder, followed by a `:`. + auto *prevAttr = dyn_cast(element); + if (!prevAttr || prevAttr->getTypeBuilder()) + continue; + + // Check the next iterator within the stack for literal elements. + for (auto &nextItPair : iteratorStack) { + ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second; + for (; nextIt != nextE; ++nextIt) { + // Skip any trailing whitespace, attribute dictionaries, or optional + // groups. + if (isa(*nextIt) || + isa(*nextIt) || isa(*nextIt)) + continue; + + // We are only interested in `:` literals. + auto *literal = dyn_cast(&*nextIt); + if (!literal || literal->getLiteral() != ":") + break; + + // TODO: Use the location of the literal element itself. + return emitError( + loc, llvm::formatv("format ambiguity caused by `:` literal found " + "after attribute `{0}` which does not have " + "a buildable type", + prevAttr->getVar()->name)); + } + } + } + iteratorStack.pop_back(); + return ::mlir::success(); +} + +LogicalResult FormatParser::verifyOperands( + llvm::SMLoc loc, + llvm::StringMap &variableTyResolver) { + // Check that all of the operands are within the format, and their types can + // be inferred. 
+ auto &buildableTypes = fmt.buildableTypes; + for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { + NamedTypeConstraint &operand = op.getOperand(i); + + // Check that the operand itself is in the format. + if (!fmt.allOperands && !seenOperands.count(&operand)) { + return emitErrorAndNote(loc, + "operand #" + Twine(i) + ", named '" + + operand.name + "', not found", + "suggest adding a '$" + operand.name + + "' directive to the custom assembly format"); + } + + // Check that the operand type is in the format, or that it can be inferred. + if (fmt.allOperandTypes || seenOperandTypes.test(i)) + continue; + + // Check to see if we can infer this type from another variable. + auto varResolverIt = variableTyResolver.find(op.getOperand(i).name); + if (varResolverIt != variableTyResolver.end()) { + TypeResolutionInstance &resolver = varResolverIt->second; + fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer); + continue; + } + + // Similarly to results, allow a custom builder for resolving the type if + // we aren't using the 'operands' directive. + Optional builder = operand.constraint.getBuilderCall(); + if (!builder || (fmt.allOperands && operand.isVariableLength())) { + return emitErrorAndNote( + loc, + "type of operand #" + Twine(i) + ", named '" + operand.name + + "', is not buildable and a buildable type cannot be inferred", + "suggest adding a type constraint to the operation or adding a " + "'type($" + + operand.name + ")' directive to the " + "custom assembly format"); + } + auto it = buildableTypes.insert({*builder, buildableTypes.size()}); + fmt.operandTypes[i].setBuilderIdx(it.first->second); + } + return ::mlir::success(); +} + +LogicalResult FormatParser::verifyRegions(llvm::SMLoc loc) { + // Check that all of the regions are within the format. 
+ if (hasAllRegions) + return ::mlir::success(); + + for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) { + const NamedRegion ®ion = op.getRegion(i); + if (!seenRegions.count(®ion)) { + return emitErrorAndNote(loc, + "region #" + Twine(i) + ", named '" + + region.name + "', not found", + "suggest adding a '$" + region.name + + "' directive to the custom assembly format"); + } + } + return ::mlir::success(); +} + +LogicalResult FormatParser::verifyResults( + llvm::SMLoc loc, + llvm::StringMap &variableTyResolver) { + // If we format all of the types together, there is nothing to check. + if (fmt.allResultTypes) + return ::mlir::success(); + + // Check that all of the result types can be inferred. + auto &buildableTypes = fmt.buildableTypes; + for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { + if (seenResultTypes.test(i)) + continue; + + // Check to see if we can infer this type from another variable. + auto varResolverIt = variableTyResolver.find(op.getResultName(i)); + if (varResolverIt != variableTyResolver.end()) { + TypeResolutionInstance resolver = varResolverIt->second; + fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer); + continue; + } + + // If the result is not variable length, allow for the case where the type + // has a builder that we can use. + NamedTypeConstraint &result = op.getResult(i); + Optional builder = result.constraint.getBuilderCall(); + if (!builder || result.isVariableLength()) { + return emitErrorAndNote( + loc, + "type of result #" + Twine(i) + ", named '" + result.name + + "', is not buildable and a buildable type cannot be inferred", + "suggest adding a type constraint to the operation or adding a " + "'type($" + + result.name + ")' directive to the " + "custom assembly format"); + } + // Note in the format that this result uses the custom builder. 
+ auto it = buildableTypes.insert({*builder, buildableTypes.size()}); + fmt.resultTypes[i].setBuilderIdx(it.first->second); + } + return ::mlir::success(); +} + +LogicalResult FormatParser::verifySuccessors(llvm::SMLoc loc) { + // Check that all of the successors are within the format. + if (hasAllSuccessors) + return ::mlir::success(); + + for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) { + const NamedSuccessor &successor = op.getSuccessor(i); + if (!seenSuccessors.count(&successor)) { + return emitErrorAndNote(loc, + "successor #" + Twine(i) + ", named '" + + successor.name + "', not found", + "suggest adding a '$" + successor.name + + "' directive to the custom assembly format"); + } + } + return ::mlir::success(); +} + +void FormatParser::handleAllTypesMatchConstraint( + ArrayRef values, + llvm::StringMap &variableTyResolver) { + for (unsigned i = 0, e = values.size(); i != e; ++i) { + // Check to see if this value matches a resolved operand or result type. + ConstArgument arg = findSeenArg(values[i]); + if (!arg) + continue; + + // Mark this value as the type resolver for the other variables. + for (unsigned j = 0; j != i; ++j) + variableTyResolver[values[j]] = {arg, llvm::None}; + for (unsigned j = i + 1; j != e; ++j) + variableTyResolver[values[j]] = {arg, llvm::None}; + } +} + +void FormatParser::handleSameTypesConstraint( + llvm::StringMap &variableTyResolver, + bool includeResults) { + const NamedTypeConstraint *resolver = nullptr; + int resolvedIt = -1; + + // Check to see if there is an operand or result to use for the resolution. + if ((resolvedIt = seenOperandTypes.find_first()) != -1) + resolver = &op.getOperand(resolvedIt); + else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1) + resolver = &op.getResult(resolvedIt); + else + return; + + // Set the resolvers for each operand and result. 
+ for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) + if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty()) + variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None}; + if (includeResults) { + for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) + if (!seenResultTypes.test(i) && !op.getResultName(i).empty()) + variableTyResolver[op.getResultName(i)] = {resolver, llvm::None}; + } +} + +void FormatParser::handleTypesMatchConstraint( + llvm::StringMap &variableTyResolver, + llvm::Record def) { + StringRef lhsName = def.getValueAsString("lhs"); + StringRef rhsName = def.getValueAsString("rhs"); + StringRef transformer = def.getValueAsString("transformer"); + if (ConstArgument arg = findSeenArg(lhsName)) + variableTyResolver[rhsName] = {arg, transformer}; +} + +ConstArgument FormatParser::findSeenArg(StringRef name) { + if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name)) + return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr; + if (const NamedTypeConstraint *arg = findArg(op.getResults(), name)) + return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr; + if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) + return seenAttrs.count(attr) ? attr : nullptr; + return nullptr; +} + +LogicalResult FormatParser::parseElement(std::unique_ptr &element, + ParserContext context) { + // Directives. + if (curToken.isKeyword()) + return parseDirective(element, context); + // Literals. + if (curToken.getKind() == Token::literal) + return parseLiteral(element, context); + // Optionals. + if (curToken.getKind() == Token::l_paren) + return parseOptional(element, context); + // Variables. 
+ if (curToken.getKind() == Token::variable) + return parseVariable(element, context); + return emitError(curToken.getLoc(), + "expected directive, literal, variable, or optional group"); +} + +LogicalResult FormatParser::parseVariable(std::unique_ptr &element, + ParserContext context) { + Token varTok = curToken; + consumeToken(); + + StringRef name = varTok.getSpelling().drop_front(); + llvm::SMLoc loc = varTok.getLoc(); + + // Check that the parsed argument is something actually registered on the + // op. + /// Attributes + if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) { + if (context == TypeDirectiveContext) + return emitError( + loc, "attributes cannot be used as children to a `type` directive"); + if (context == RefDirectiveContext) { + if (!seenAttrs.count(attr)) + return emitError(loc, "attribute '" + name + + "' must be bound before it is referenced"); + } else if (!seenAttrs.insert(attr)) { + return emitError(loc, "attribute '" + name + "' is already bound"); + } + + element = std::make_unique(attr); + return ::mlir::success(); + } + /// Operands + if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) { + if (context == TopLevelContext || context == CustomDirectiveContext) { + if (fmt.allOperands || !seenOperands.insert(operand).second) + return emitError(loc, "operand '" + name + "' is already bound"); + } else if (context == RefDirectiveContext && !seenOperands.count(operand)) { + return emitError(loc, "operand '" + name + + "' must be bound before it is referenced"); + } + element = std::make_unique(operand); + return ::mlir::success(); + } + /// Regions + if (const NamedRegion *region = findArg(op.getRegions(), name)) { + if (context == TopLevelContext || context == CustomDirectiveContext) { + if (hasAllRegions || !seenRegions.insert(region).second) + return emitError(loc, "region '" + name + "' is already bound"); + } else if (context == RefDirectiveContext && !seenRegions.count(region)) { + return 
emitError(loc, "region '" + name + + "' must be bound before it is referenced"); + } else { + return emitError(loc, "regions can only be used at the top level"); + } + element = std::make_unique(region); + return ::mlir::success(); + } + /// Results. + if (const auto *result = findArg(op.getResults(), name)) { + if (context != TypeDirectiveContext) + return emitError(loc, "result variables can can only be used as a child " + "to a 'type' directive"); + element = std::make_unique(result); + return ::mlir::success(); + } + /// Successors. + if (const auto *successor = findArg(op.getSuccessors(), name)) { + if (context == TopLevelContext || context == CustomDirectiveContext) { + if (hasAllSuccessors || !seenSuccessors.insert(successor).second) + return emitError(loc, "successor '" + name + "' is already bound"); + } else if (context == RefDirectiveContext && + !seenSuccessors.count(successor)) { + return emitError(loc, "successor '" + name + + "' must be bound before it is referenced"); + } else { + return emitError(loc, "successors can only be used at the top level"); + } + + element = std::make_unique(successor); + return ::mlir::success(); + } + return emitError(loc, "expected variable to refer to an argument, region, " + "result, or successor"); +} + +LogicalResult FormatParser::parseDirective(std::unique_ptr &element, + ParserContext context) { + Token dirTok = curToken; + consumeToken(); + + switch (dirTok.getKind()) { + case Token::kw_attr_dict: + return parseAttrDictDirective(element, dirTok.getLoc(), context, + /*withKeyword=*/false); + case Token::kw_attr_dict_w_keyword: + return parseAttrDictDirective(element, dirTok.getLoc(), context, + /*withKeyword=*/true); + case Token::kw_custom: + return parseCustomDirective(element, dirTok.getLoc(), context); + case Token::kw_functional_type: + return parseFunctionalTypeDirective(element, dirTok, context); + case Token::kw_operands: + return parseOperandsDirective(element, dirTok.getLoc(), context); + case 
Token::kw_regions: + return parseRegionsDirective(element, dirTok.getLoc(), context); + case Token::kw_results: + return parseResultsDirective(element, dirTok.getLoc(), context); + case Token::kw_successors: + return parseSuccessorsDirective(element, dirTok.getLoc(), context); + case Token::kw_ref: + return parseReferenceDirective(element, dirTok.getLoc(), context); + case Token::kw_type: + return parseTypeDirective(element, dirTok, context); + + default: + llvm_unreachable("unknown directive token"); + } +} + +LogicalResult FormatParser::parseLiteral(std::unique_ptr &element, + ParserContext context) { + Token literalTok = curToken; + if (context != TopLevelContext) { + return emitError( + literalTok.getLoc(), + "literals may only be used in a top-level section of the format"); + } + consumeToken(); + + StringRef value = literalTok.getSpelling().drop_front().drop_back(); + + // The parsed literal is a space element (`` or ` `). + if (value.empty() || (value.size() == 1 && value.front() == ' ')) { + element = std::make_unique(!value.empty()); + return ::mlir::success(); + } + // The parsed literal is a newline element. + if (value == "\\n") { + element = std::make_unique(); + return ::mlir::success(); + } + + // Check that the parsed literal is valid. + if (!LiteralElement::isValidLiteral(value)) + return emitError(literalTok.getLoc(), "expected valid literal"); + + element = std::make_unique(value); + return ::mlir::success(); +} + +LogicalResult FormatParser::parseOptional(std::unique_ptr &element, + ParserContext context) { + llvm::SMLoc curLoc = curToken.getLoc(); + if (context != TopLevelContext) + return emitError(curLoc, "optional groups can only be used as top-level " + "elements"); + consumeToken(); + + // Parse the child elements for this optional group. 
+ std::vector> thenElements, elseElements; + Optional anchorIdx; + do { + if (failed(parseOptionalChildElement(thenElements, anchorIdx))) + return ::mlir::failure(); + } while (curToken.getKind() != Token::r_paren); + consumeToken(); + + // Parse the `else` elements of this optional group. + if (curToken.getKind() == Token::colon) { + consumeToken(); + if (failed(parseToken(Token::l_paren, "expected '(' to start else branch " + "of optional group"))) + return failure(); + do { + llvm::SMLoc childLoc = curToken.getLoc(); + elseElements.push_back({}); + if (failed(parseElement(elseElements.back(), TopLevelContext)) || + failed(verifyOptionalChildElement(elseElements.back().get(), childLoc, + /*isAnchor=*/false))) + return failure(); + } while (curToken.getKind() != Token::r_paren); + consumeToken(); + } + + if (failed(parseToken(Token::question, "expected '?' after optional group"))) + return ::mlir::failure(); + + // The optional group is required to have an anchor. + if (!anchorIdx) + return emitError(curLoc, "optional group specified no anchor element"); + + // The first parsable element of the group must be able to be parsed in an + // optional fashion. 
+ auto parseBegin = llvm::find_if_not(thenElements, [](auto &element) { + return isa(element.get()); + }); + Element *firstElement = parseBegin->get(); + if (!isa(firstElement) && + !isa(firstElement) && + !isa(firstElement) && !isa(firstElement)) + return emitError(curLoc, + "first parsable element of an operand group must be " + "an attribute, literal, operand, or region"); + + auto parseStart = parseBegin - thenElements.begin(); + element = std::make_unique( + std::move(thenElements), std::move(elseElements), *anchorIdx, parseStart); + return ::mlir::success(); +} + +LogicalResult FormatParser::parseOptionalChildElement( + std::vector> &childElements, + Optional &anchorIdx) { + llvm::SMLoc childLoc = curToken.getLoc(); + childElements.push_back({}); + if (failed(parseElement(childElements.back(), TopLevelContext))) + return ::mlir::failure(); + + // Check to see if this element is the anchor of the optional group. + bool isAnchor = curToken.getKind() == Token::caret; + if (isAnchor) { + if (anchorIdx) + return emitError(childLoc, "only one element can be marked as the anchor " + "of an optional group"); + anchorIdx = childElements.size() - 1; + consumeToken(); + } + + return verifyOptionalChildElement(childElements.back().get(), childLoc, + isAnchor); +} + +LogicalResult FormatParser::verifyOptionalChildElement(Element *element, + llvm::SMLoc childLoc, + bool isAnchor) { + return TypeSwitch(element) + // All attributes can be within the optional group, but only optional + // attributes can be the anchor. + .Case([&](AttributeVariable *attrEle) { + if (isAnchor && !attrEle->getVar()->attr.isOptional()) + return emitError(childLoc, "only optional attributes can be used to " + "anchor an optional group"); + return ::mlir::success(); + }) + // Only optional-like(i.e. variadic) operands can be within an optional + // group. 
+ .Case([&](OperandVariable *ele) { + if (!ele->getVar()->isVariableLength()) + return emitError(childLoc, "only variable length operands can be " + "used within an optional group"); + return ::mlir::success(); + }) + // Only optional-like(i.e. variadic) results can be within an optional + // group. + .Case([&](ResultVariable *ele) { + if (!ele->getVar()->isVariableLength()) + return emitError(childLoc, "only variable length results can be " + "used within an optional group"); + return ::mlir::success(); + }) + .Case([&](RegionVariable *) { + // TODO: When ODS has proper support for marking "optional" regions, add + // a check here. + return ::mlir::success(); + }) + .Case([&](TypeDirective *ele) { + return verifyOptionalChildElement(ele->getOperand(), childLoc, + /*isAnchor=*/false); + }) + .Case([&](FunctionalTypeDirective *ele) { + if (failed(verifyOptionalChildElement(ele->getInputs(), childLoc, + /*isAnchor=*/false))) + return failure(); + return verifyOptionalChildElement(ele->getResults(), childLoc, + /*isAnchor=*/false); + }) + // Literals, whitespace, and custom directives may be used, but they can't + // anchor the group. + .Case([&](Element *) { + if (isAnchor) + return emitError(childLoc, "only variables and types can be used " + "to anchor an optional group"); + return ::mlir::success(); + }) + .Default([&](Element *) { + return emitError(childLoc, "only literals, types, and variables can be " + "used within an optional group"); + }); +} + +LogicalResult +FormatParser::parseAttrDictDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context, + bool withKeyword) { + if (context == TypeDirectiveContext) + return emitError(loc, "'attr-dict' directive can only be used as a " + "top-level directive"); + + if (context == RefDirectiveContext) { + if (!hasAttrDict) + return emitError(loc, "'ref' of 'attr-dict' is not bound by a prior " + "'attr-dict' directive"); + + // Otherwise, this is a top-level context. 
+ } else { + if (hasAttrDict) + return emitError(loc, "'attr-dict' directive has already been seen"); + hasAttrDict = true; + } + + element = std::make_unique(withKeyword); + return ::mlir::success(); +} + +LogicalResult +FormatParser::parseCustomDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context) { + llvm::SMLoc curLoc = curToken.getLoc(); + if (context != TopLevelContext) + return emitError(loc, "'custom' is only valid as a top-level directive"); + + // Parse the custom directive name. + if (failed( + parseToken(Token::less, "expected '<' before custom directive name"))) + return ::mlir::failure(); + + Token nameTok = curToken; + if (failed(parseToken(Token::identifier, + "expected custom directive name identifier")) || + failed(parseToken(Token::greater, + "expected '>' after custom directive name")) || + failed(parseToken(Token::l_paren, + "expected '(' before custom directive parameters"))) + return ::mlir::failure(); + + // Parse the child elements for this optional group.= + std::vector> elements; + do { + if (failed(parseCustomDirectiveParameter(elements))) + return ::mlir::failure(); + if (curToken.getKind() != Token::comma) + break; + consumeToken(); + } while (true); + + if (failed(parseToken(Token::r_paren, + "expected ')' after custom directive parameters"))) + return ::mlir::failure(); + + // After parsing all of the elements, ensure that all type directives refer + // only to variables. 
+ for (auto &ele : elements) { + if (auto *typeEle = dyn_cast(ele.get())) { + if (!isa(typeEle->getOperand())) { + return emitError(curLoc, "type directives within a custom directive " + "may only refer to variables"); + } + } + } + + element = std::make_unique(nameTok.getSpelling(), + std::move(elements)); + return ::mlir::success(); +} + +LogicalResult FormatParser::parseCustomDirectiveParameter( + std::vector> ¶meters) { + llvm::SMLoc childLoc = curToken.getLoc(); + parameters.push_back({}); + if (failed(parseElement(parameters.back(), CustomDirectiveContext))) + return ::mlir::failure(); + + // Verify that the element can be placed within a custom directive. + if (!isa( + parameters.back().get())) { + return emitError(childLoc, "only variables and types may be used as " + "parameters to a custom directive"); + } + return ::mlir::success(); +} + +LogicalResult +FormatParser::parseFunctionalTypeDirective(std::unique_ptr &element, + Token tok, ParserContext context) { + llvm::SMLoc loc = tok.getLoc(); + if (context != TopLevelContext) + return emitError( + loc, "'functional-type' is only valid as a top-level directive"); + + // Parse the main operand. 
+ std::unique_ptr inputs, results; + if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) || + failed(parseTypeDirectiveOperand(inputs)) || + failed(parseToken(Token::comma, "expected ',' after inputs argument")) || + failed(parseTypeDirectiveOperand(results)) || + failed(parseToken(Token::r_paren, "expected ')' after argument list"))) + return ::mlir::failure(); + element = std::make_unique(std::move(inputs), + std::move(results)); + return ::mlir::success(); +} + +LogicalResult +FormatParser::parseOperandsDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context) { + if (context == RefDirectiveContext) { + if (!fmt.allOperands) + return emitError(loc, "'ref' of 'operands' is not bound by a prior " + "'operands' directive"); + + } else if (context == TopLevelContext || context == CustomDirectiveContext) { + if (fmt.allOperands || !seenOperands.empty()) + return emitError(loc, "'operands' directive creates overlap in format"); + fmt.allOperands = true; + } + element = std::make_unique(); + return ::mlir::success(); +} + +LogicalResult +FormatParser::parseReferenceDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context) { + if (context != CustomDirectiveContext) + return emitError(loc, "'ref' is only valid within a `custom` directive"); + + std::unique_ptr operand; + if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) || + failed(parseElement(operand, RefDirectiveContext)) || + failed(parseToken(Token::r_paren, "expected ')' after argument list"))) + return ::mlir::failure(); + + element = std::make_unique(std::move(operand)); + return ::mlir::success(); +} + +LogicalResult +FormatParser::parseRegionsDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context) { + if (context == TypeDirectiveContext) + return emitError(loc, "'regions' is only valid as a top-level directive"); + if (context == RefDirectiveContext) { + if (!hasAllRegions) + return emitError(loc, 
"'ref' of 'regions' is not bound by a prior " + "'regions' directive"); + + // Otherwise, this is a TopLevel directive. + } else { + if (hasAllRegions || !seenRegions.empty()) + return emitError(loc, "'regions' directive creates overlap in format"); + hasAllRegions = true; + } + element = std::make_unique(); + return ::mlir::success(); +} + +LogicalResult +FormatParser::parseResultsDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context) { + if (context != TypeDirectiveContext) + return emitError(loc, "'results' directive can can only be used as a child " + "to a 'type' directive"); + element = std::make_unique(); + return ::mlir::success(); +} + +LogicalResult +FormatParser::parseSuccessorsDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context) { + if (context == TypeDirectiveContext) + return emitError(loc, + "'successors' is only valid as a top-level directive"); + if (context == RefDirectiveContext) { + if (!hasAllSuccessors) + return emitError(loc, "'ref' of 'successors' is not bound by a prior " + "'successors' directive"); + + // Otherwise, this is a TopLevel directive. 
+ } else { + if (hasAllSuccessors || !seenSuccessors.empty()) + return emitError(loc, "'successors' directive creates overlap in format"); + hasAllSuccessors = true; + } + element = std::make_unique(); + return ::mlir::success(); +} + +LogicalResult +FormatParser::parseTypeDirective(std::unique_ptr &element, Token tok, + ParserContext context) { + llvm::SMLoc loc = tok.getLoc(); + if (context == TypeDirectiveContext) + return emitError(loc, "'type' cannot be used as a child of another `type`"); + + bool isRefChild = context == RefDirectiveContext; + std::unique_ptr operand; + if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) || + failed(parseTypeDirectiveOperand(operand, isRefChild)) || + failed(parseToken(Token::r_paren, "expected ')' after argument list"))) + return ::mlir::failure(); + + element = std::make_unique(std::move(operand)); + return ::mlir::success(); +} + +LogicalResult +FormatParser::parseTypeDirectiveOperand(std::unique_ptr &element, + bool isRefChild) { + llvm::SMLoc loc = curToken.getLoc(); + if (failed(parseElement(element, TypeDirectiveContext))) + return ::mlir::failure(); + if (isa(element.get())) + return emitError( + loc, "'type' directive operand expects variable or directive operand"); + + if (auto *var = dyn_cast(element.get())) { + unsigned opIdx = var->getVar() - op.operand_begin(); + if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx))) + return emitError(loc, "'type' of '" + var->getVar()->name + + "' is already bound"); + if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx))) + return emitError(loc, "'ref' of 'type($" + var->getVar()->name + + ")' is not bound by a prior 'type' directive"); + seenOperandTypes.set(opIdx); + } else if (auto *var = dyn_cast(element.get())) { + unsigned resIdx = var->getVar() - op.result_begin(); + if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx))) + return emitError(loc, "'type' of '" + var->getVar()->name + + "' is 
already bound"); + if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(resIdx))) + return emitError(loc, "'ref' of 'type($" + var->getVar()->name + + ")' is not bound by a prior 'type' directive"); + seenResultTypes.set(resIdx); + } else if (isa(&*element)) { + if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any())) + return emitError(loc, "'operands' 'type' is already bound"); + if (isRefChild && !fmt.allOperandTypes) + return emitError(loc, "'ref' of 'type(operands)' is not bound by a prior " + "'type' directive"); + fmt.allOperandTypes = true; + } else if (isa(&*element)) { + if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any())) + return emitError(loc, "'results' 'type' is already bound"); + if (isRefChild && !fmt.allResultTypes) + return emitError(loc, "'ref' of 'type(results)' is not bound by a prior " + "'type' directive"); + fmt.allResultTypes = true; + } else { + return emitError(loc, "invalid argument to 'type' directive"); + } + return ::mlir::success(); +} + +//===----------------------------------------------------------------------===// +// Interface +//===----------------------------------------------------------------------===// + +void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) { + // TODO: Operator doesn't expose all necessary functionality via + // the const interface. + Operator &op = const_cast(constOp); + if (!op.hasAssemblyFormat()) + return; + + // Parse the format description. + llvm::SourceMgr mgr; + mgr.AddNewSourceBuffer( + llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), llvm::SMLoc()); + OperationFormat format(op); + if (failed(FormatParser(mgr, format, op).parse())) { + // Exit the process if format errors are treated as fatal. + if (formatErrorIsFatal) { + // Invoke the interrupt handlers to run the file cleanup handlers. + llvm::sys::RunInterruptHandlers(); + std::exit(1); + } + return; + } + + // Generate the printer and parser based on the parsed format. 
+ format.genParser(op, opClass); + format.genPrinter(op, opClass); +} diff --git a/tools/mlir-tblgen-builder/OpFormatGen.h b/tools/mlir-tblgen-builder/OpFormatGen.h new file mode 100644 index 0000000..ebafec0 --- /dev/null +++ b/tools/mlir-tblgen-builder/OpFormatGen.h @@ -0,0 +1,28 @@ +//===- OpFormatGen.h - MLIR operation format generator ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the interface for generating parsers and printers from the +// declarative format. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_ +#define MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_ + +namespace mlir { +namespace tblgen { +class OpClass; +class Operator; + +// Generate the assembly format for the given operator. +void generateOpFormat(const Operator &constOp, OpClass &opClass); + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_ diff --git a/tools/mlir-tblgen-builder/OpGenHelpers.cpp b/tools/mlir-tblgen-builder/OpGenHelpers.cpp new file mode 100644 index 0000000..b08f3fb --- /dev/null +++ b/tools/mlir-tblgen-builder/OpGenHelpers.cpp @@ -0,0 +1,65 @@ +//===- OpGenHelpers.cpp - MLIR operation generator helpers ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines helpers used in the op generators. 
+// +//===----------------------------------------------------------------------===// + +#include "OpGenHelpers.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Regex.h" +#include "llvm/TableGen/Error.h" + +using namespace llvm; +using namespace mlir; +using namespace mlir::tblgen; + +cl::OptionCategory opDefGenCat("Options for op definition generators"); + +static cl::opt opIncFilter( + "op-include-regex", + cl::desc("Regex of name of op's to include (no filter if empty)"), + cl::cat(opDefGenCat)); +static cl::opt opExcFilter( + "op-exclude-regex", + cl::desc("Regex of name of op's to exclude (no filter if empty)"), + cl::cat(opDefGenCat)); + +static std::string getOperationName(const Record &def) { + auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name"); + auto opName = def.getValueAsString("opName"); + if (prefix.empty()) + return std::string(opName); + return std::string(llvm::formatv("{0}.{1}", prefix, opName)); +} + +std::vector +mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) { + Record *classDef = recordKeeper.getClass("Op"); + if (!classDef) + PrintFatalError("ERROR: Couldn't find the 'Op' class!\n"); + + llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter); + std::vector defs; + for (const auto &def : recordKeeper.getDefs()) { + if (!def.second->isSubClassOf(classDef)) + continue; + // Include if no include filter or include filter matches. + if (!opIncFilter.empty() && + !includeRegex.match(getOperationName(*def.second))) + continue; + // Unless there is an exclude filter and it matches. 
+ if (!opExcFilter.empty() && + excludeRegex.match(getOperationName(*def.second))) + continue; + defs.push_back(def.second.get()); + } + + return defs; +} diff --git a/tools/mlir-tblgen-builder/OpGenHelpers.h b/tools/mlir-tblgen-builder/OpGenHelpers.h new file mode 100644 index 0000000..d4b35d2 --- /dev/null +++ b/tools/mlir-tblgen-builder/OpGenHelpers.h @@ -0,0 +1,30 @@ +//===- OpGenHelpers.h - MLIR operation generator helpers --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines helpers used in the op generators. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_ +#define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_ + +#include "llvm/TableGen/Record.h" +#include + +namespace mlir { +namespace tblgen { + +/// Returns all the op definitions filtered by the user. The filtering is via +/// command-line option "op-include-regex" and "op-exclude-regex". +std::vector +getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper); + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_ diff --git a/tools/mlir-tblgen-builder/TableGen/Argument.cpp b/tools/mlir-tblgen-builder/TableGen/Argument.cpp new file mode 100644 index 0000000..ab3e8d3 --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Argument.cpp @@ -0,0 +1,21 @@ +//===- Argument.cpp - Argument definitions --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Argument.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+// Returns true if the wrapped constraint's predicate is not null.
+bool NamedTypeConstraint::hasPredicate() const {
+  return !constraint.getPredicate().isNull();
+}
+
+// Forwards to the wrapped constraint's isOptional().
+bool NamedTypeConstraint::isOptional() const { return constraint.isOptional(); }
+
+// Forwards to the wrapped constraint's isVariadic().
+bool NamedTypeConstraint::isVariadic() const { return constraint.isVariadic(); }
diff --git a/tools/mlir-tblgen-builder/TableGen/Argument.h b/tools/mlir-tblgen-builder/TableGen/Argument.h
new file mode 100644
index 0000000..8a05bbc
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/Argument.h
@@ -0,0 +1,65 @@
+//===- Argument.h - Argument definitions ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file contains definitions for TableGen operation's arguments.
+// Operation arguments fall into two categories:
+//
+// 1. Operands: SSA values operated on by the operation
+// 2. Attributes: compile-time known properties that have influence over
+//    the operation's behavior
+//
+// These two categories are modelled with the unified argument concept in
+// TableGen because we need similar pattern matching mechanisms for them.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_ARGUMENT_H_
+#define MLIR_TABLEGEN_ARGUMENT_H_
+
+#include "Attribute.h"
+#include "Type.h"
+#include "llvm/ADT/PointerUnion.h"
+// NOTE(review): the header name of the following #include is missing in this
+// copy of the file - confirm against the original source.
+#include
+
+namespace llvm {
+class StringRef;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// A struct wrapping an op attribute and its name together.
+struct NamedAttribute {
+  // Name of the attribute.
+  llvm::StringRef name;
+  // The wrapped attribute.
+  Attribute attr;
+};
+
+// A struct wrapping an op operand/result's constraint and its name together.
+struct NamedTypeConstraint {
+  // Returns true if this operand/result has constraint to be satisfied.
+  bool hasPredicate() const;
+  // Returns true if this is an optional type constraint. This is a special case
+  // of variadic for 0 or 1 type.
+  bool isOptional() const;
+  // Returns true if this operand/result is variadic.
+  bool isVariadic() const;
+  // Returns true if this is a variable length type constraint. This is either
+  // variadic or optional.
+  bool isVariableLength() const { return isOptional() || isVariadic(); }
+
+  // Name of the operand/result.
+  llvm::StringRef name;
+  // The type constraint the queries above are answered from.
+  TypeConstraint constraint;
+};
+
+// Operation argument: either attribute or operand.
+// NOTE(review): the template arguments of PointerUnion are missing in this
+// copy of the file - confirm against the original source.
+using Argument = llvm::PointerUnion;
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_ARGUMENT_H_
diff --git a/tools/mlir-tblgen-builder/TableGen/AttrOrTypeDef.cpp b/tools/mlir-tblgen-builder/TableGen/AttrOrTypeDef.cpp
new file mode 100644
index 0000000..08dea54
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/AttrOrTypeDef.cpp
@@ -0,0 +1,250 @@
+//===- AttrOrTypeDef.cpp - AttrOrTypeDef wrapper classes ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "AttrOrTypeDef.h" +#include "Dialect.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +//===----------------------------------------------------------------------===// +// AttrOrTypeBuilder +//===----------------------------------------------------------------------===// + +/// Returns true if this builder is able to infer the MLIRContext parameter. +bool AttrOrTypeBuilder::hasInferredContextParameter() const { + return def->getValueAsBit("hasInferredContextParam"); +} + +//===----------------------------------------------------------------------===// +// AttrOrTypeDef +//===----------------------------------------------------------------------===// + +AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) { + // Populate the builders. + auto *builderList = + dyn_cast_or_null(def->getValueInit("builders")); + if (builderList && !builderList->empty()) { + for (llvm::Init *init : builderList->getValues()) { + AttrOrTypeBuilder builder(cast(init)->getDef(), + def->getLoc()); + + // Ensure that all parameters have names. + for (const AttrOrTypeBuilder::Parameter ¶m : + builder.getParameters()) { + if (!param.getName()) + PrintFatalError(def->getLoc(), "builder parameters must have a name"); + } + builders.emplace_back(builder); + } + } + + // Populate the traits. + if (auto *traitList = def->getValueAsListInit("traits")) { + SmallPtrSet traitSet; + traits.reserve(traitSet.size()); + for (auto *traitInit : *traitList) + if (traitSet.insert(traitInit).second) + traits.push_back(Trait::create(traitInit)); + } +} + +Dialect AttrOrTypeDef::getDialect() const { + auto *dialect = dyn_cast(def->getValue("dialect")->getValue()); + return Dialect(dialect ? 
dialect->getDef() : nullptr); +} + +StringRef AttrOrTypeDef::getName() const { return def->getName(); } + +StringRef AttrOrTypeDef::getCppClassName() const { + return def->getValueAsString("cppClassName"); +} + +StringRef AttrOrTypeDef::getCppBaseClassName() const { + return def->getValueAsString("cppBaseClassName"); +} + +bool AttrOrTypeDef::hasDescription() const { + const llvm::RecordVal *desc = def->getValue("description"); + return desc && isa(desc->getValue()); +} + +StringRef AttrOrTypeDef::getDescription() const { + return def->getValueAsString("description"); +} + +bool AttrOrTypeDef::hasSummary() const { + const llvm::RecordVal *summary = def->getValue("summary"); + return summary && isa(summary->getValue()); +} + +StringRef AttrOrTypeDef::getSummary() const { + return def->getValueAsString("summary"); +} + +StringRef AttrOrTypeDef::getStorageClassName() const { + return def->getValueAsString("storageClass"); +} + +StringRef AttrOrTypeDef::getStorageNamespace() const { + return def->getValueAsString("storageNamespace"); +} + +bool AttrOrTypeDef::genStorageClass() const { + return def->getValueAsBit("genStorageClass"); +} + +bool AttrOrTypeDef::hasStorageCustomConstructor() const { + return def->getValueAsBit("hasStorageCustomConstructor"); +} + +void AttrOrTypeDef::getParameters( + SmallVectorImpl ¶meters) const { + if (auto *parametersDag = def->getValueAsDag("parameters")) { + for (unsigned i = 0, e = parametersDag->getNumArgs(); i < e; ++i) + parameters.push_back(AttrOrTypeParameter(parametersDag, i)); + } +} + +unsigned AttrOrTypeDef::getNumParameters() const { + auto *parametersDag = def->getValueAsDag("parameters"); + return parametersDag ? 
parametersDag->getNumArgs() : 0; +} + +Optional AttrOrTypeDef::getMnemonic() const { + return def->getValueAsOptionalString("mnemonic"); +} + +Optional AttrOrTypeDef::getPrinterCode() const { + return def->getValueAsOptionalString("printer"); +} + +Optional AttrOrTypeDef::getParserCode() const { + return def->getValueAsOptionalString("parser"); +} + +bool AttrOrTypeDef::genAccessors() const { + return def->getValueAsBit("genAccessors"); +} + +bool AttrOrTypeDef::genVerifyDecl() const { + return def->getValueAsBit("genVerifyDecl"); +} + +Optional AttrOrTypeDef::getExtraDecls() const { + auto value = def->getValueAsString("extraClassDeclaration"); + return value.empty() ? Optional() : value; +} + +ArrayRef AttrOrTypeDef::getLoc() const { return def->getLoc(); } + +bool AttrOrTypeDef::skipDefaultBuilders() const { + return def->getValueAsBit("skipDefaultBuilders"); +} + +bool AttrOrTypeDef::operator==(const AttrOrTypeDef &other) const { + return def == other.def; +} + +bool AttrOrTypeDef::operator<(const AttrOrTypeDef &other) const { + return getName() < other.getName(); +} + +//===----------------------------------------------------------------------===// +// AttrDef +//===----------------------------------------------------------------------===// + +Optional AttrDef::getTypeBuilder() const { + return def->getValueAsOptionalString("typeBuilder"); +} + +bool AttrDef::classof(const AttrOrTypeDef *def) { + return def->getDef()->isSubClassOf("AttrDef"); +} + +//===----------------------------------------------------------------------===// +// AttrOrTypeParameter +//===----------------------------------------------------------------------===// + +StringRef AttrOrTypeParameter::getName() const { + return def->getArgName(index)->getValue(); +} + +Optional AttrOrTypeParameter::getAllocator() const { + llvm::Init *parameterType = def->getArg(index); + if (isa(parameterType)) + return Optional(); + if (auto *param = dyn_cast(parameterType)) + return 
param->getDef()->getValueAsOptionalString("allocator"); + llvm::PrintFatalError("Parameters DAG arguments must be either strings or " + "defs which inherit from AttrOrTypeParameter\n"); +} + +Optional AttrOrTypeParameter::getComparator() const { + llvm::Init *parameterType = def->getArg(index); + if (isa(parameterType)) + return Optional(); + if (auto *param = dyn_cast(parameterType)) + return param->getDef()->getValueAsOptionalString("comparator"); + llvm::PrintFatalError("Parameters DAG arguments must be either strings or " + "defs which inherit from AttrOrTypeParameter\n"); +} + +StringRef AttrOrTypeParameter::getCppType() const { + auto *parameterType = def->getArg(index); + if (auto *stringType = dyn_cast(parameterType)) + return stringType->getValue(); + if (auto *param = dyn_cast(parameterType)) + return param->getDef()->getValueAsString("cppType"); + llvm::PrintFatalError( + "Parameters DAG arguments must be either strings or defs " + "which inherit from AttrOrTypeParameter\n"); +} + +Optional AttrOrTypeParameter::getSummary() const { + auto *parameterType = def->getArg(index); + if (auto *param = dyn_cast(parameterType)) { + const auto *desc = param->getDef()->getValue("summary"); + if (llvm::StringInit *ci = dyn_cast(desc->getValue())) + return ci->getValue(); + } + return Optional(); +} + +StringRef AttrOrTypeParameter::getSyntax() const { + auto *parameterType = def->getArg(index); + if (auto *stringType = dyn_cast(parameterType)) + return stringType->getValue(); + if (auto *param = dyn_cast(parameterType)) { + const auto *syntax = param->getDef()->getValue("syntax"); + if (syntax && isa(syntax->getValue())) + return cast(syntax->getValue())->getValue(); + return getCppType(); + } + llvm::PrintFatalError("Parameters DAG arguments must be either strings or " + "defs which inherit from AttrOrTypeParameter"); +} + +const llvm::Init *AttrOrTypeParameter::getDef() const { + return def->getArg(index); +} + 
+//===----------------------------------------------------------------------===// +// AttributeSelfTypeParameter +//===----------------------------------------------------------------------===// + +bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) { + const llvm::Init *paramDef = param->getDef(); + if (auto *paramDefInit = dyn_cast(paramDef)) + return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter"); + return false; +} diff --git a/tools/mlir-tblgen-builder/TableGen/AttrOrTypeDef.h b/tools/mlir-tblgen-builder/TableGen/AttrOrTypeDef.h new file mode 100644 index 0000000..8c279f4 --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/AttrOrTypeDef.h @@ -0,0 +1,229 @@ +//===-- AttrOrTypeDef.h - Wrapper for attr and type definitions -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// AttrOrTypeDef, AttrDef, and TypeDef wrappers to simplify using TableGen +// Record defining a MLIR attributes and types. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_ATTRORTYPEDEF_H +#define MLIR_TABLEGEN_ATTRORTYPEDEF_H + +#include "mlir/Support/LLVM.h" +#include "Builder.h" +#include "Trait.h" + +namespace llvm { +class DagInit; +class Record; +class SMLoc; +} // namespace llvm + +namespace mlir { +namespace tblgen { +class Dialect; +class AttrOrTypeParameter; + +//===----------------------------------------------------------------------===// +// AttrOrTypeBuilder +//===----------------------------------------------------------------------===// + +/// Wrapper class that represents a Tablegen AttrOrTypeBuilder. 
+class AttrOrTypeBuilder : public Builder {
+public:
+  using Builder::Builder;
+
+  /// Returns true if this builder is able to infer the MLIRContext parameter.
+  bool hasInferredContextParameter() const;
+};
+
+//===----------------------------------------------------------------------===//
+// AttrOrTypeDef
+//===----------------------------------------------------------------------===//
+
+/// Wrapper class that contains a TableGen AttrOrTypeDef's record and provides
+/// helper methods for accessing them.
+class AttrOrTypeDef {
+public:
+  explicit AttrOrTypeDef(const llvm::Record *def);
+
+  // Get the dialect to which this def belongs.
+  Dialect getDialect() const;
+
+  // Returns the name of this AttrOrTypeDef record.
+  StringRef getName() const;
+
+  // Query functions for the documentation of the def.
+  bool hasDescription() const;
+  StringRef getDescription() const;
+  bool hasSummary() const;
+  StringRef getSummary() const;
+
+  // Returns the name of the C++ class to generate.
+  StringRef getCppClassName() const;
+
+  // Returns the name of the C++ base class to use when generating this def.
+  StringRef getCppBaseClassName() const;
+
+  // Returns the name of the storage class for this def.
+  StringRef getStorageClassName() const;
+
+  // Returns the C++ namespace for this def's storage class.
+  StringRef getStorageNamespace() const;
+
+  // Returns true if we should generate the storage class.
+  bool genStorageClass() const;
+
+  // Indicates whether or not to generate the storage class constructor.
+  bool hasStorageCustomConstructor() const;
+
+  // Fill a list with this def's parameters. See AttrOrTypeDef in OpBase.td for
+  // documentation of parameter usage.
+  void getParameters(SmallVectorImpl &) const;
+
+  // Return the number of parameters.
+  unsigned getNumParameters() const;
+
+  // Return the keyword/mnemonic to use in the printer/parser methods if we are
+  // supposed to auto-generate them.
+  Optional getMnemonic() const;
+
+  // Returns the code to use as the printer method. If not specified, returns
+  // None. Otherwise, returns the contents of that code block.
+  Optional getPrinterCode() const;
+
+  // Returns the code to use as the parser method. If not specified, returns
+  // None. Otherwise, returns the contents of that code block.
+  Optional getParserCode() const;
+
+  // Returns true if the accessors based on the parameters should be generated.
+  bool genAccessors() const;
+
+  // Return true if we need to generate the verify declaration and getChecked
+  // method.
+  bool genVerifyDecl() const;
+
+  // Returns the def's extra class declaration code.
+  Optional getExtraDecls() const;
+
+  // Get the code location (for error printing).
+  ArrayRef getLoc() const;
+
+  // Returns true if the default get/getChecked methods should be skipped during
+  // generation.
+  bool skipDefaultBuilders() const;
+
+  // Returns the builders of this def.
+  ArrayRef getBuilders() const { return builders; }
+
+  // Returns the traits of this def.
+  ArrayRef getTraits() const { return traits; }
+
+  // Returns whether two AttrOrTypeDefs are equal by checking the equality of
+  // the underlying record.
+  bool operator==(const AttrOrTypeDef &other) const;
+
+  // Orders two AttrOrTypeDefs by the names of their underlying records (the
+  // value returned by getName()).
+  bool operator<(const AttrOrTypeDef &other) const;
+
+  // Returns whether the AttrOrTypeDef is defined, i.e. wraps a non-null
+  // record.
+  operator bool() const { return def != nullptr; }
+
+  // Return the underlying def.
+  const llvm::Record *getDef() const { return def; }
+
+protected:
+  // The record wrapped by this definition.
+  const llvm::Record *def;
+
+  // The builders of this definition.
+  SmallVector builders;
+
+  // The traits of this definition.
+  SmallVector traits;
+};
+
+//===----------------------------------------------------------------------===//
+// AttrDef
+//===----------------------------------------------------------------------===//
+
+/// This class represents a wrapper around a tablegen AttrDef record.
+class AttrDef : public AttrOrTypeDef {
+public:
+  using AttrOrTypeDef::AttrOrTypeDef;
+
+  // Returns the attribute's value type builder code block, or None if it
+  // doesn't have one.
+  Optional getTypeBuilder() const;
+
+  // Support for LLVM-style casting from AttrOrTypeDef.
+  static bool classof(const AttrOrTypeDef *def);
+};
+
+//===----------------------------------------------------------------------===//
+// TypeDef
+//===----------------------------------------------------------------------===//
+
+/// This class represents a wrapper around a tablegen TypeDef record.
+class TypeDef : public AttrOrTypeDef {
+public:
+  using AttrOrTypeDef::AttrOrTypeDef;
+};
+
+//===----------------------------------------------------------------------===//
+// AttrOrTypeParameter
+//===----------------------------------------------------------------------===//
+
+// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to
+// AttrOrTypeDefs to parameterize them.
+class AttrOrTypeParameter {
+public:
+  explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index)
+      : def(def), index(index) {}
+
+  // Get the parameter name.
+  StringRef getName() const;
+
+  // If specified, get the custom allocator code for this parameter.
+  Optional getAllocator() const;
+
+  // If specified, get the custom comparator code for this parameter.
+  Optional getComparator() const;
+
+  // Get the C++ type of this parameter.
+  StringRef getCppType() const;
+
+  // Get a description of this parameter for documentation purposes.
+  Optional getSummary() const;
+
+  // Get the assembly syntax documentation.
+  StringRef getSyntax() const;
+
+  // Return the underlying def of this parameter, i.e. the argument at `index`
+  // in the parameter list.
+  const llvm::Init *getDef() const;
+
+private:
+  /// The underlying tablegen parameter list this parameter is a part of.
+  const llvm::DagInit *def;
+  /// The index of the parameter within the parameter list (`def`).
+  unsigned index;
+};
+
+//===----------------------------------------------------------------------===//
+// AttributeSelfTypeParameter
+//===----------------------------------------------------------------------===//
+
+// A wrapper class for the AttributeSelfTypeParameter tblgen class. This
+// represents a parameter of mlir::Type that is the value type of an AttrDef.
+class AttributeSelfTypeParameter : public AttrOrTypeParameter {
+public:
+  // Support for LLVM-style casting from AttrOrTypeParameter.
+  static bool classof(const AttrOrTypeParameter *param);
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_ATTRORTYPEDEF_H
diff --git a/tools/mlir-tblgen-builder/TableGen/Attribute.cpp b/tools/mlir-tblgen-builder/TableGen/Attribute.cpp
new file mode 100644
index 0000000..26a5cba
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/Attribute.cpp
@@ -0,0 +1,296 @@
+//===- Attribute.cpp - Attribute wrapper class ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Attribute wrapper to simplify using TableGen Record defining a MLIR
+// Attribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Format.h"
+#include "Operator.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+using llvm::DefInit;
+using llvm::Init;
+using llvm::Record;
+using llvm::StringInit;
+
+// Returns the initializer's value as string if the given TableGen initializer
+// is a code or string initializer.
Returns the empty StringRef otherwise. +static StringRef getValueAsString(const Init *init) { + if (const auto *str = dyn_cast(init)) + return str->getValue().trim(); + return {}; +} + +AttrConstraint::AttrConstraint(const Record *record) + : Constraint(Constraint::CK_Attr, record) { + assert(isSubClassOf("AttrConstraint") && + "must be subclass of TableGen 'AttrConstraint' class"); +} + +bool AttrConstraint::isSubClassOf(StringRef className) const { + return def->isSubClassOf(className); +} + +Attribute::Attribute(const Record *record) : AttrConstraint(record) { + assert(record->isSubClassOf("Attr") && + "must be subclass of TableGen 'Attr' class"); +} + +Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {} + +bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); } + +bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); } + +bool Attribute::isSymbolRefAttr() const { + StringRef defName = def->getName(); + if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr") + return true; + return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr"); +} + +bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); } + +StringRef Attribute::getStorageType() const { + const auto *init = def->getValueInit("storageType"); + auto type = getValueAsString(init); + if (type.empty()) + return "Attribute"; + return type; +} + +StringRef Attribute::getReturnType() const { + const auto *init = def->getValueInit("returnType"); + return getValueAsString(init); +} + +// Return the type constraint corresponding to the type of this attribute, or +// None if this is not a TypedAttr. 
+llvm::Optional Attribute::getValueType() const { + if (auto *defInit = dyn_cast(def->getValueInit("valueType"))) + return Type(defInit->getDef()); + return llvm::None; +} + +StringRef Attribute::getConvertFromStorageCall() const { + const auto *init = def->getValueInit("convertFromStorage"); + return getValueAsString(init); +} + +bool Attribute::isConstBuildable() const { + const auto *init = def->getValueInit("constBuilderCall"); + return !getValueAsString(init).empty(); +} + +StringRef Attribute::getConstBuilderTemplate() const { + const auto *init = def->getValueInit("constBuilderCall"); + return getValueAsString(init); +} + +Attribute Attribute::getBaseAttr() const { + if (const auto *defInit = + llvm::dyn_cast(def->getValueInit("baseAttr"))) { + return Attribute(defInit).getBaseAttr(); + } + return *this; +} + +bool Attribute::hasDefaultValue() const { + const auto *init = def->getValueInit("defaultValue"); + return !getValueAsString(init).empty(); +} + +StringRef Attribute::getDefaultValue() const { + const auto *init = def->getValueInit("defaultValue"); + return getValueAsString(init); +} + +bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); } + +StringRef Attribute::getAttrDefName() const { + if (def->isAnonymous()) { + return getBaseAttr().def->getName(); + } + return def->getName(); +} + +StringRef Attribute::getDerivedCodeBody() const { + assert(isDerivedAttr() && "only derived attribute has 'body' field"); + return def->getValueAsString("body"); +} + +Dialect Attribute::getDialect() const { + const llvm::RecordVal *record = def->getValue("dialect"); + if (record && record->getValue()) { + if (DefInit *init = dyn_cast(record->getValue())) + return Dialect(init->getDef()); + } + return Dialect(nullptr); +} + +ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { + assert(def->isSubClassOf("ConstantAttr") && + "must be subclass of TableGen 'ConstantAttr' class"); +} + +Attribute ConstantAttr::getAttribute() 
const { + return Attribute(def->getValueAsDef("attr")); +} + +StringRef ConstantAttr::getConstantValue() const { + return def->getValueAsString("value"); +} + +EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) { + assert(isSubClassOf("EnumAttrCaseInfo") && + "must be subclass of TableGen 'EnumAttrInfo' class"); +} + +EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) + : EnumAttrCase(init->getDef()) {} + +bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); } + +StringRef EnumAttrCase::getSymbol() const { + return def->getValueAsString("symbol"); +} + +StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); } + +int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); } + +const llvm::Record &EnumAttrCase::getDef() const { return *def; } + +EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { + assert(isSubClassOf("EnumAttrInfo") && + "must be subclass of TableGen 'EnumAttr' class"); +} + +EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} + +EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} + +bool EnumAttr::classof(const Attribute *attr) { + return attr->isSubClassOf("EnumAttrInfo"); +} + +bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); } + +StringRef EnumAttr::getEnumClassName() const { + return def->getValueAsString("className"); +} + +StringRef EnumAttr::getCppNamespace() const { + return def->getValueAsString("cppNamespace"); +} + +StringRef EnumAttr::getUnderlyingType() const { + return def->getValueAsString("underlyingType"); +} + +StringRef EnumAttr::getUnderlyingToSymbolFnName() const { + return def->getValueAsString("underlyingToSymbolFnName"); +} + +StringRef EnumAttr::getStringToSymbolFnName() const { + return def->getValueAsString("stringToSymbolFnName"); +} + +StringRef EnumAttr::getSymbolToStringFnName() const { + return def->getValueAsString("symbolToStringFnName"); +} + 
+StringRef EnumAttr::getSymbolToStringFnRetType() const { + return def->getValueAsString("symbolToStringFnRetType"); +} + +StringRef EnumAttr::getMaxEnumValFnName() const { + return def->getValueAsString("maxEnumValFnName"); +} + +std::vector EnumAttr::getAllCases() const { + const auto *inits = def->getValueAsListInit("enumerants"); + + std::vector cases; + cases.reserve(inits->size()); + + for (const llvm::Init *init : *inits) { + cases.push_back(EnumAttrCase(cast(init))); + } + + return cases; +} + +bool EnumAttr::genSpecializedAttr() const { + return def->getValueAsBit("genSpecializedAttr"); +} + +llvm::Record *EnumAttr::getBaseAttrClass() const { + return def->getValueAsDef("baseAttrClass"); +} + +StringRef EnumAttr::getSpecializedAttrClassName() const { + return def->getValueAsString("specializedAttrClassName"); +} + +StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) { + assert(def->isSubClassOf("StructFieldAttr") && + "must be subclass of TableGen 'StructFieldAttr' class"); +} + +StructFieldAttr::StructFieldAttr(const llvm::Record &record) + : StructFieldAttr(&record) {} + +StructFieldAttr::StructFieldAttr(const llvm::DefInit *init) + : StructFieldAttr(init->getDef()) {} + +StringRef StructFieldAttr::getName() const { + return def->getValueAsString("name"); +} + +Attribute StructFieldAttr::getType() const { + auto init = def->getValueInit("type"); + return Attribute(cast(init)); +} + +StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) { + assert(isSubClassOf("StructAttr") && + "must be subclass of TableGen 'StructAttr' class"); +} + +StructAttr::StructAttr(const llvm::DefInit *init) + : StructAttr(init->getDef()) {} + +StringRef StructAttr::getStructClassName() const { + return def->getValueAsString("className"); +} + +StringRef StructAttr::getCppNamespace() const { + Dialect dialect(def->getValueAsDef("dialect")); + return dialect.getCppNamespace(); +} + +std::vector StructAttr::getAllFields() const { + 
std::vector attributes; + + const auto *inits = def->getValueAsListInit("fields"); + attributes.reserve(inits->size()); + + for (const llvm::Init *init : *inits) { + attributes.emplace_back(cast(init)); + } + + return attributes; +} + +const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface"; diff --git a/tools/mlir-tblgen-builder/TableGen/Attribute.h b/tools/mlir-tblgen-builder/TableGen/Attribute.h new file mode 100644 index 0000000..7401e42 --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Attribute.h @@ -0,0 +1,247 @@ +//===- Attribute.h - Attribute wrapper class --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Attribute wrapper to simplify using TableGen Record defining a MLIR +// Attribute. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_ATTRIBUTE_H_ +#define MLIR_TABLEGEN_ATTRIBUTE_H_ + +#include "mlir/Support/LLVM.h" +#include "Constraint.h" +#include "llvm/ADT/StringRef.h" + +namespace llvm { +class DefInit; +class Record; +} // end namespace llvm + +namespace mlir { +namespace tblgen { +class Dialect; +class Type; + +// Wrapper class with helper methods for accessing attribute constraints defined +// in TableGen. +class AttrConstraint : public Constraint { +public: + explicit AttrConstraint(const llvm::Record *record); + + static bool classof(const Constraint *c) { return c->getKind() == CK_Attr; } + + // Returns true if this constraint is a subclass of the given `className` + // class defined in TableGen. + bool isSubClassOf(StringRef className) const; +}; + +// Wrapper class providing helper methods for accessing MLIR Attribute defined +// in TableGen. 
This class should closely reflect what is defined as class +// `Attr` in TableGen. +class Attribute : public AttrConstraint { +public: + explicit Attribute(const llvm::Record *record); + explicit Attribute(const llvm::DefInit *init); + + // Returns the storage type if set. Returns the default storage type + // ("Attribute") otherwise. + StringRef getStorageType() const; + + // Returns the return type for this attribute. + StringRef getReturnType() const; + + // Return the type constraint corresponding to the type of this attribute, or + // None if this is not a TypedAttr. + llvm::Optional getValueType() const; + + // Returns the template getter method call which reads this attribute's + // storage and returns the value as of the desired return type. + // The call will contain a `{0}` which will be expanded to this attribute. + StringRef getConvertFromStorageCall() const; + + // Returns true if this attribute can be built from a constant value. + bool isConstBuildable() const; + + // Returns the template that can be used to produce an instance of the + // attribute. + // Syntax: `$builder` should be replaced with a builder, `$0` should be + // replaced with the constant value. + StringRef getConstBuilderTemplate() const; + + // Returns the base-level attribute that this attribute constraint is + // built upon. + Attribute getBaseAttr() const; + + // Returns whether this attribute has a default value. + bool hasDefaultValue() const; + // Returns the default value for this attribute. + StringRef getDefaultValue() const; + + // Returns whether this attribute is optional. + bool isOptional() const; + + // Returns true if this attribute is a derived attribute (i.e., a subclass + // of `DerivedAttr`). + bool isDerivedAttr() const; + + // Returns true if this attribute is a type attribute (i.e., a subclass + // of `TypeAttrBase`). 
+ bool isTypeAttr() const; + + // Returns true if this attribute is a symbol reference attribute (i.e., a + // subclass of `SymbolRefAttr` or `FlatSymbolRefAttr`). + bool isSymbolRefAttr() const; + + // Returns true if this attribute is an enum attribute (i.e., a subclass of + // `EnumAttrInfo`) + bool isEnumAttr() const; + + // Returns this attribute's TableGen def name. If this is an `OptionalAttr` + // or `DefaultValuedAttr` without explicit name, returns the base attribute's + // name. + StringRef getAttrDefName() const; + + // Returns the code body for derived attribute. Aborts if this is not a + // derived attribute. + StringRef getDerivedCodeBody() const; + + // Returns the dialect for the attribute if defined. + Dialect getDialect() const; +}; + +// Wrapper class providing helper methods for accessing MLIR constant attribute +// defined in TableGen. This class should closely reflect what is defined as +// class `ConstantAttr` in TableGen. +class ConstantAttr { +public: + explicit ConstantAttr(const llvm::DefInit *init); + + // Returns the attribute kind. + Attribute getAttribute() const; + + // Returns the constant value. + StringRef getConstantValue() const; + +private: + // The TableGen definition of this constant attribute. + const llvm::Record *def; +}; + +// Wrapper class providing helper methods for accessing enum attribute cases +// defined in TableGen. This is used for enum attribute case backed by both +// StringAttr and IntegerAttr. +class EnumAttrCase : public Attribute { +public: + explicit EnumAttrCase(const llvm::Record *record); + explicit EnumAttrCase(const llvm::DefInit *init); + + // Returns true if this EnumAttrCase is backed by a StringAttr. + bool isStrCase() const; + + // Returns the symbol of this enum attribute case. + StringRef getSymbol() const; + + // Returns the textual representation of this enum attribute case. + StringRef getStr() const; + + // Returns the value of this enum attribute case. 
+ int64_t getValue() const; + + // Returns the TableGen definition this EnumAttrCase was constructed from. + const llvm::Record &getDef() const; +}; + +// Wrapper class providing helper methods for accessing enum attributes defined +// in TableGen. This is used for enum attribute case backed by both StringAttr +// and IntegerAttr. +class EnumAttr : public Attribute { +public: + explicit EnumAttr(const llvm::Record *record); + explicit EnumAttr(const llvm::Record &record); + explicit EnumAttr(const llvm::DefInit *init); + + static bool classof(const Attribute *attr); + + // Returns true if this is a bit enum attribute. + bool isBitEnum() const; + + // Returns the enum class name. + StringRef getEnumClassName() const; + + // Returns the C++ namespaces this enum class should be placed in. + StringRef getCppNamespace() const; + + // Returns the underlying type. + StringRef getUnderlyingType() const; + + // Returns the name of the utility function that converts a value of the + // underlying type to the corresponding symbol. + StringRef getUnderlyingToSymbolFnName() const; + + // Returns the name of the utility function that converts a string to the + // corresponding symbol. + StringRef getStringToSymbolFnName() const; + + // Returns the name of the utility function that converts a symbol to the + // corresponding string. + StringRef getSymbolToStringFnName() const; + + // Returns the return type of the utility function that converts a symbol to + // the corresponding string. + StringRef getSymbolToStringFnRetType() const; + + // Returns the name of the utility function that returns the max enum value + // used within the enum class. + StringRef getMaxEnumValFnName() const; + + // Returns all allowed cases for this enum attribute.
+ std::vector getAllCases() const; + + bool genSpecializedAttr() const; + llvm::Record *getBaseAttrClass() const; + StringRef getSpecializedAttrClassName() const; +}; + +class StructFieldAttr { +public: + explicit StructFieldAttr(const llvm::Record *record); + explicit StructFieldAttr(const llvm::Record &record); + explicit StructFieldAttr(const llvm::DefInit *init); + + StringRef getName() const; + Attribute getType() const; + +private: + const llvm::Record *def; +}; + +// Wrapper class providing helper methods for accessing struct attributes +// defined in TableGen. +class StructAttr : public Attribute { +public: + explicit StructAttr(const llvm::Record *record); + explicit StructAttr(const llvm::Record &record) : StructAttr(&record){}; + explicit StructAttr(const llvm::DefInit *init); + + // Returns the struct class name. + StringRef getStructClassName() const; + + // Returns the C++ namespaces this struct class should be placed in. + StringRef getCppNamespace() const; + + std::vector getAllFields() const; +}; + +// Name of infer type op interface. +extern const char *inferTypeOpInterface; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_ATTRIBUTE_H_ diff --git a/tools/mlir-tblgen-builder/TableGen/Builder.cpp b/tools/mlir-tblgen-builder/TableGen/Builder.cpp new file mode 100644 index 0000000..a9d9034 --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Builder.cpp @@ -0,0 +1,74 @@ +//===- Builder.cpp - Builder definitions ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Builder.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +//===----------------------------------------------------------------------===// +// Builder::Parameter +//===----------------------------------------------------------------------===// + +/// Return a string containing the C++ type of this parameter. +StringRef Builder::Parameter::getCppType() const { + if (const auto *stringInit = dyn_cast(def)) + return stringInit->getValue(); + const llvm::Record *record = cast(def)->getDef(); + return record->getValueAsString("type"); +} + +/// Return an optional string containing the default value to use for this +/// parameter. +Optional Builder::Parameter::getDefaultValue() const { + if (isa(def)) + return llvm::None; + const llvm::Record *record = cast(def)->getDef(); + Optional value = record->getValueAsOptionalString("defaultValue"); + return value && !value->empty() ? value : llvm::None; +} + +//===----------------------------------------------------------------------===// +// Builder +//===----------------------------------------------------------------------===// + +Builder::Builder(const llvm::Record *record, ArrayRef loc) + : def(record) { + // Initialize the parameters of the builder. + const llvm::DagInit *dag = def->getValueAsDag("dagParams"); + auto *defInit = dyn_cast(dag->getOperator()); + if (!defInit || !defInit->getDef()->getName().equals("ins")) + PrintFatalError(def->getLoc(), "expected 'ins' in builders"); + + bool seenDefaultValue = false; + for (unsigned i = 0, e = dag->getNumArgs(); i < e; ++i) { + const llvm::StringInit *paramName = dag->getArgName(i); + const llvm::Init *paramValue = dag->getArg(i); + Parameter param(paramName ? 
paramName->getValue() : Optional(), + paramValue); + + // Similarly to C++, once an argument with a default value is detected, the + // following arguments must have default values as well. + if (param.getDefaultValue()) { + seenDefaultValue = true; + } else if (seenDefaultValue) { + PrintFatalError(loc, + "expected an argument with default value after other " + "arguments with default values"); + } + parameters.emplace_back(param); + } +} + +/// Return an optional string containing the body of the builder. +Optional Builder::getBody() const { + Optional body = def->getValueAsOptionalString("body"); + return body && !body->empty() ? body : llvm::None; +} diff --git a/tools/mlir-tblgen-builder/TableGen/Builder.h b/tools/mlir-tblgen-builder/TableGen/Builder.h new file mode 100644 index 0000000..b901c84 --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Builder.h @@ -0,0 +1,85 @@ +//===- Builder.h - Builder classes ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Builder wrapper to simplify using TableGen Record for building +// operations/types/etc. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_BUILDER_H_ +#define MLIR_TABLEGEN_BUILDER_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +namespace llvm { +class Init; +class Record; +class SMLoc; +} // end namespace llvm + +namespace mlir { +namespace tblgen { + +/// Wrapper class with helper methods for accessing Builders defined in +/// TableGen. +class Builder { +public: + /// This class represents a single parameter to a builder method. 
+ class Parameter { + public: + /// Return a string containing the C++ type of this parameter. + StringRef getCppType() const; + + /// Return an optional string containing the name of this parameter. If + /// None, no name was specified for this parameter by the user. + Optional getName() const { return name; } + + /// Return an optional string containing the default value to use for this + /// parameter. + Optional getDefaultValue() const; + + private: + Parameter(Optional name, const llvm::Init *def) + : name(name), def(def) {} + + /// The optional name of the parameter. + Optional name; + + /// The tablegen definition of the parameter. This is either a StringInit, + /// or a CArg DefInit. + const llvm::Init *def; + + // Allow access to the constructor. + friend Builder; + }; + + /// Construct a builder from the given Record instance. + Builder(const llvm::Record *record, ArrayRef loc); + + /// Return a list of parameters used in this build method. + ArrayRef getParameters() const { return parameters; } + + /// Return an optional string containing the body of the builder. + Optional getBody() const; + +protected: + /// The TableGen definition of this builder. + const llvm::Record *def; + +private: + /// A collection of parameters to the builder. + SmallVector parameters; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_BUILDER_H_ diff --git a/tools/mlir-tblgen-builder/TableGen/CodeGenHelpers.h b/tools/mlir-tblgen-builder/TableGen/CodeGenHelpers.h new file mode 100644 index 0000000..3da4758 --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/CodeGenHelpers.h @@ -0,0 +1,68 @@ +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines common utilities for generating C++ from tablegen +// structures. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_CODEGENHELPERS_H +#define MLIR_TABLEGEN_CODEGENHELPERS_H + +#include "Dialect.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace tblgen { + +// Simple RAII helper for defining ifdef-undef-endif scopes. +class IfDefScope { +public: + IfDefScope(llvm::StringRef name, llvm::raw_ostream &os) + : name(name.str()), os(os) { + os << "#ifdef " << name << "\n" + << "#undef " << name << "\n\n"; + } + ~IfDefScope() { os << "\n#endif // " << name << "\n\n"; } + +private: + std::string name; + llvm::raw_ostream &os; +}; + +// A helper RAII class to emit nested namespaces for this op. +class NamespaceEmitter { +public: + NamespaceEmitter(raw_ostream &os, const Dialect &dialect) : os(os) { + if (!dialect) + return; + emitNamespaceStarts(os, dialect.getCppNamespace()); + } + NamespaceEmitter(raw_ostream &os, StringRef cppNamespace) : os(os) { + emitNamespaceStarts(os, cppNamespace); + } + + ~NamespaceEmitter() { + for (StringRef ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; + } + +private: + void emitNamespaceStarts(raw_ostream &os, StringRef cppNamespace) { + llvm::SplitString(cppNamespace, namespaces, "::"); + for (StringRef ns : namespaces) + os << "namespace " << ns << " {\n"; + } + raw_ostream &os; + SmallVector namespaces; +}; + +} // namespace tblgen +} // namespace mlir + +#endif // MLIR_TABLEGEN_CODEGENHELPERS_H diff --git a/tools/mlir-tblgen-builder/TableGen/Constraint.cpp b/tools/mlir-tblgen-builder/TableGen/Constraint.cpp new file mode 100644 index 0000000..c99846a --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Constraint.cpp @@ -0,0 +1,70 
@@ +//===- Constraint.cpp - Constraint class ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Constraint wrapper to simplify using TableGen Record for constraints. +// +//===----------------------------------------------------------------------===// + +#include "Constraint.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +Constraint::Constraint(const llvm::Record *record) + : def(record), kind(CK_Uncategorized) { + // Look through OpVariable's to their constraint. + if (def->isSubClassOf("OpVariable")) + def = def->getValueAsDef("constraint"); + if (def->isSubClassOf("TypeConstraint")) { + kind = CK_Type; + } else if (def->isSubClassOf("AttrConstraint")) { + kind = CK_Attr; + } else if (def->isSubClassOf("RegionConstraint")) { + kind = CK_Region; + } else if (def->isSubClassOf("SuccessorConstraint")) { + kind = CK_Successor; + } else { + assert(def->isSubClassOf("Constraint")); + } +} + +Constraint::Constraint(Kind kind, const llvm::Record *record) + : def(record), kind(kind) { + // Look through OpVariable's to their constraint. + if (def->isSubClassOf("OpVariable")) + def = def->getValueAsDef("constraint"); +} + +Pred Constraint::getPredicate() const { + auto *val = def->getValue("predicate"); + + // If no predicate is specified, then return the null predicate (which + // corresponds to true). 
+ if (!val) + return Pred(); + + const auto *pred = dyn_cast(val->getValue()); + return Pred(pred); +} + +std::string Constraint::getConditionTemplate() const { + return getPredicate().getCondition(); +} + +StringRef Constraint::getSummary() const { + if (Optional summary = def->getValueAsOptionalString("summary")) + return *summary; + return def->getName(); +} + +AppliedConstraint::AppliedConstraint(Constraint &&constraint, + llvm::StringRef self, + std::vector &&entities) + : constraint(constraint), self(std::string(self)), + entities(std::move(entities)) {} diff --git a/tools/mlir-tblgen-builder/TableGen/Constraint.h b/tools/mlir-tblgen-builder/TableGen/Constraint.h new file mode 100644 index 0000000..1870ae9 --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Constraint.h @@ -0,0 +1,88 @@ +//===- Constraint.h - Constraint class --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Constraint wrapper to simplify using TableGen Record for constraints. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_CONSTRAINT_H_ +#define MLIR_TABLEGEN_CONSTRAINT_H_ + +#include "mlir/Support/LLVM.h" +#include "Predicate.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +namespace llvm { +class Record; +} // end namespace llvm + +namespace mlir { +namespace tblgen { + +// Wrapper class with helper methods for accessing Constraint defined in +// TableGen. 
+class Constraint { +public: + Constraint(const llvm::Record *record); + + bool operator==(const Constraint &that) { return def == that.def; } + bool operator!=(const Constraint &that) { return def != that.def; } + + // Returns the predicate for this constraint. + Pred getPredicate() const; + + // Returns the condition template that can be used to check if a type or + // attribute satisfies this constraint. The template may contain "{0}" that + // must be substituted with an expression returning an mlir::Type or + // mlir::Attribute. + std::string getConditionTemplate() const; + + // Returns the user-readable description of this constraint. If the + // description is not provided, returns the TableGen def name. + StringRef getSummary() const; + + // Constraint kind + enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized }; + + Kind getKind() const { return kind; } + + /// Get an opaque pointer to the constraint. + const void *getAsOpaquePointer() const { return def; } + /// Construct a constraint from the opaque pointer representation. + static Constraint getFromOpaquePointer(const void *ptr) { + return Constraint(reinterpret_cast(ptr)); + } + +protected: + Constraint(Kind kind, const llvm::Record *record); + + // The TableGen definition of this constraint. + const llvm::Record *def; + +private: + // What kind of constraint this is. + Kind kind; +}; + +// A constraint and the concrete entities to place the constraint on. +struct AppliedConstraint { + AppliedConstraint(Constraint &&constraint, StringRef self, + std::vector &&entities); + + Constraint constraint; + // The symbol to replace `$_self` special placeholder in the constraint. + std::string self; + // The symbols to replace `$N` positional placeholders in the constraint.
+ std::vector entities; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_CONSTRAINT_H_ diff --git a/tools/mlir-tblgen-builder/TableGen/Dialect.cpp b/tools/mlir-tblgen-builder/TableGen/Dialect.cpp new file mode 100644 index 0000000..25d6b30 --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Dialect.cpp @@ -0,0 +1,94 @@ +//===- Dialect.cpp - Dialect wrapper class --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Dialect wrapper to simplify using TableGen Record defining a MLIR dialect. +// +//===----------------------------------------------------------------------===// + +#include "Dialect.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; +Dialect::Dialect(const llvm::Record *def) : def(def) { + if (def == nullptr) + return; + for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects")) + dependentDialects.push_back(dialect); +} + +StringRef Dialect::getName() const { return def->getValueAsString("name"); } + +StringRef Dialect::getCppNamespace() const { + return def->getValueAsString("cppNamespace"); +} + +std::string Dialect::getCppClassName() const { + // Simply use the name and remove any '_' tokens. 
+ std::string cppName = def->getName().str(); + llvm::erase_if(cppName, [](char c) { return c == '_'; }); + return cppName; +} + +static StringRef getAsStringOrEmpty(const llvm::Record &record, + StringRef fieldName) { + if (auto valueInit = record.getValueInit(fieldName)) { + if (llvm::isa(valueInit)) + return record.getValueAsString(fieldName); + } + return ""; +} + +StringRef Dialect::getSummary() const { + return getAsStringOrEmpty(*def, "summary"); +} + +StringRef Dialect::getDescription() const { + return getAsStringOrEmpty(*def, "description"); +} + +ArrayRef Dialect::getDependentDialects() const { + return dependentDialects; +} + +llvm::Optional Dialect::getExtraClassDeclaration() const { + auto value = def->getValueAsString("extraClassDeclaration"); + return value.empty() ? llvm::Optional() : value; +} + +bool Dialect::hasCanonicalizer() const { + return def->getValueAsBit("hasCanonicalizer"); +} + +bool Dialect::hasConstantMaterializer() const { + return def->getValueAsBit("hasConstantMaterializer"); +} + +bool Dialect::hasOperationAttrVerify() const { + return def->getValueAsBit("hasOperationAttrVerify"); +} + +bool Dialect::hasRegionArgAttrVerify() const { + return def->getValueAsBit("hasRegionArgAttrVerify"); +} + +bool Dialect::hasRegionResultAttrVerify() const { + return def->getValueAsBit("hasRegionResultAttrVerify"); +} + +bool Dialect::hasOperationInterfaceFallback() const { + return def->getValueAsBit("hasOperationInterfaceFallback"); +} + +bool Dialect::operator==(const Dialect &other) const { + return def == other.def; +} + +bool Dialect::operator<(const Dialect &other) const { + return getName() < other.getName(); +} diff --git a/tools/mlir-tblgen-builder/TableGen/Dialect.h b/tools/mlir-tblgen-builder/TableGen/Dialect.h new file mode 100644 index 0000000..609bf4e --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Dialect.h @@ -0,0 +1,91 @@ +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Dialect wrapper to simplify using TableGen Record defining a MLIR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_DIALECT_H_
+#define MLIR_TABLEGEN_DIALECT_H_
+
+#include "mlir/Support/LLVM.h"
+#include <string>
+#include <vector>
+
+namespace llvm {
+class Record;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+// Wrapper class that contains a MLIR dialect's information defined in TableGen
+// and provides helper methods for accessing them.
+class Dialect {
+public:
+  explicit Dialect(const llvm::Record *def);
+
+  // Returns the name of this dialect.
+  StringRef getName() const;
+
+  // Returns the C++ namespaces that ops of this dialect should be placed into.
+  StringRef getCppNamespace() const;
+
+  // Returns this dialect's C++ class name.
+  std::string getCppClassName() const;
+
+  // Returns the summary description of the dialect. Returns empty string if
+  // none.
+  StringRef getSummary() const;
+
+  // Returns the description of the dialect. Returns empty string if none.
+  StringRef getDescription() const;
+
+  // Returns the list of dialect (class names) that this dialect depends on.
+  // These are dialects that will be loaded on construction of this dialect.
+  ArrayRef<StringRef> getDependentDialects() const;
+
+  // Returns the dialects extra class declaration code.
+  llvm::Optional<StringRef> getExtraClassDeclaration() const;
+
+  /// Returns true if this dialect has a canonicalizer.
+  bool hasCanonicalizer() const;
+
+  // Returns true if this dialect has a constant materializer.
+  bool hasConstantMaterializer() const;
+
+  /// Returns true if this dialect has an operation attribute verifier.
+  bool hasOperationAttrVerify() const;
+
+  /// Returns true if this dialect has a region argument attribute verifier.
+  bool hasRegionArgAttrVerify() const;
+
+  /// Returns true if this dialect has a region result attribute verifier.
+  bool hasRegionResultAttrVerify() const;
+
+  /// Returns true if this dialect has fallback interfaces for its operations.
+  bool hasOperationInterfaceFallback() const;
+
+  // Returns whether two dialects are equal by checking the equality of the
+  // underlying record.
+  bool operator==(const Dialect &other) const;
+
+  bool operator!=(const Dialect &other) const { return !(*this == other); }
+
+  // Compares two dialects by comparing the names of the dialects.
+  bool operator<(const Dialect &other) const;
+
+  // Returns whether the dialect is defined.
+  explicit operator bool() const { return def != nullptr; }
+
+private:
+  const llvm::Record *def;
+  std::vector<StringRef> dependentDialects;
+};
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_DIALECT_H_
diff --git a/tools/mlir-tblgen-builder/TableGen/Format.cpp b/tools/mlir-tblgen-builder/TableGen/Format.cpp
new file mode 100644
index 0000000..938f5f6
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/Format.cpp
@@ -0,0 +1,194 @@
+//===- Format.cpp - Utilities for String Format ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines utilities for formatting strings. They are specially
+// tailored to the needs of TableGen'ing op definitions and rewrite rules,
+// so they are not expected to be used as widely applicable utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Format.h"
+#include <cctype>
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+// Marker to indicate an error happened when replacing a placeholder.
+const char *const kMarkerForNoSubst = "<no-subst-found>";
+
+FmtContext &FmtContext::addSubst(StringRef placeholder, Twine subst) {
+  customSubstMap[placeholder] = subst.str();
+  return *this;
+}
+
+FmtContext &FmtContext::withBuilder(Twine subst) {
+  builtinSubstMap[PHKind::Builder] = subst.str();
+  return *this;
+}
+
+FmtContext &FmtContext::withOp(Twine subst) {
+  builtinSubstMap[PHKind::Op] = subst.str();
+  return *this;
+}
+
+FmtContext &FmtContext::withSelf(Twine subst) {
+  builtinSubstMap[PHKind::Self] = subst.str();
+  return *this;
+}
+
+Optional<StringRef>
+FmtContext::getSubstFor(FmtContext::PHKind placeholder) const {
+  if (placeholder == FmtContext::PHKind::None ||
+      placeholder == FmtContext::PHKind::Custom)
+    return {};
+  auto it = builtinSubstMap.find(placeholder);
+  if (it == builtinSubstMap.end())
+    return {};
+  return StringRef(it->second);
+}
+
+Optional<StringRef> FmtContext::getSubstFor(StringRef placeholder) const {
+  auto it = customSubstMap.find(placeholder);
+  if (it == customSubstMap.end())
+    return {};
+  return StringRef(it->second);
+}
+
+FmtContext::PHKind FmtContext::getPlaceHolderKind(StringRef str) {
+  return StringSwitch<FmtContext::PHKind>(str)
+      .Case("_builder", FmtContext::PHKind::Builder)
+      .Case("_op", FmtContext::PHKind::Op)
+      .Case("_self", FmtContext::PHKind::Self)
+      .Case("", FmtContext::PHKind::None)
+      .Default(FmtContext::PHKind::Custom);
+}
+
+std::pair<FmtReplacement, StringRef>
+FmtObjectBase::splitFmtSegment(StringRef fmt) {
+  size_t begin = fmt.find_first_of('$');
+  if (begin == StringRef::npos) {
+    // No placeholders: the whole format string should be returned as a
+    // literal string.
+    return {FmtReplacement{fmt}, StringRef()};
+  }
+  if (begin != 0) {
+    // The first placeholder is not at the beginning: we can split the format
+    // string into a literal string and the rest.
+    return {FmtReplacement{fmt.substr(0, begin)}, fmt.substr(begin)};
+  }
+
+  // The first placeholder is at the beginning
+
+  if (fmt.size() == 1) {
+    // The whole format string just contains '$': treat as literal.
+    return {FmtReplacement{fmt}, StringRef()};
+  }
+
+  // Allow escaping dollar with '$$'
+  if (fmt[1] == '$') {
+    return {FmtReplacement{fmt.substr(0, 1)}, fmt.substr(2)};
+  }
+
+  // First try to see if it's a positional placeholder, and then handle special
+  // placeholders. <cctype> predicates take `unsigned char` to stay defined.
+
+  size_t end = fmt.find_if_not([](unsigned char c) { return std::isdigit(c); }, 1);
+  if (end != 1) {
+    // We have a positional placeholder. Parse the index.
+    size_t index = 0;
+    if (fmt.substr(1, end - 1).consumeInteger(0, index)) {
+      llvm_unreachable("invalid replacement sequence index");
+    }
+
+    if (end == StringRef::npos) {
+      // All the remaining characters are part of the positional placeholder.
+      return {FmtReplacement{fmt, index}, StringRef()};
+    }
+    return {FmtReplacement{fmt.substr(0, end), index}, fmt.substr(end)};
+  }
+
+  end = fmt.find_if_not([](unsigned char c) { return std::isalnum(c) || c == '_'; }, 1);
+  auto placeholder = FmtContext::getPlaceHolderKind(fmt.substr(1, end - 1));
+  if (end == StringRef::npos) {
+    // All the remaining characters are part of the special placeholder.
+    return {FmtReplacement{fmt, placeholder}, StringRef()};
+  }
+  return {FmtReplacement{fmt.substr(0, end), placeholder}, fmt.substr(end)};
+}
+
+std::vector<FmtReplacement> FmtObjectBase::parseFormatString(StringRef fmt) {
+  std::vector<FmtReplacement> replacements;
+  FmtReplacement repl;
+  while (!fmt.empty()) {
+    std::tie(repl, fmt) = splitFmtSegment(fmt);
+    if (repl.type != FmtReplacement::Type::Empty)
+      replacements.push_back(repl);
+  }
+  return replacements;
+}
+
+void FmtObjectBase::format(raw_ostream &s) const {
+  for (auto &repl : replacements) {
+    if (repl.type == FmtReplacement::Type::Empty)
+      continue;
+
+    if (repl.type == FmtReplacement::Type::Literal) {
+      s << repl.spec;
+      continue;
+    }
+
+    if (repl.type == FmtReplacement::Type::SpecialPH) {
+      if (repl.placeholder == FmtContext::PHKind::None) {
+        s << repl.spec;
+      } else if (!context) {
+        // We need the context to replace special placeholders.
+        s << repl.spec << kMarkerForNoSubst;
+      } else {
+        Optional<StringRef> subst;
+        if (repl.placeholder == FmtContext::PHKind::Custom) {
+          // Skip the leading '$' sign for the custom placeholder
+          subst = context->getSubstFor(repl.spec.substr(1));
+        } else {
+          subst = context->getSubstFor(repl.placeholder);
+        }
+        if (subst)
+          s << *subst;
+        else
+          s << repl.spec << kMarkerForNoSubst;
+      }
+      continue;
+    }
+
+    assert(repl.type == FmtReplacement::Type::PositionalPH);
+
+    if (repl.index >= adapters.size()) {
+      s << repl.spec << kMarkerForNoSubst;
+      continue;
+    }
+    adapters[repl.index]->format(s, /*Options=*/"");
+  }
+}
+
+FmtStrVecObject::FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
+                                 ArrayRef<std::string> params)
+    : FmtObjectBase(fmt, ctx, params.size()) {
+  parameters.reserve(params.size());
+  for (std::string p : params)
+    parameters.push_back(llvm::detail::build_format_adapter(std::move(p)));
+
+  adapters.reserve(parameters.size());
+  for (auto &p : parameters)
+    adapters.push_back(&p);
+}
+
+FmtStrVecObject::FmtStrVecObject(FmtStrVecObject &&that)
+    : FmtObjectBase(std::move(that)), parameters(std::move(that.parameters)) {
+  adapters.reserve(parameters.size());
+  for (auto &p : parameters)
+    adapters.push_back(&p);
+}
diff --git a/tools/mlir-tblgen-builder/TableGen/Format.h b/tools/mlir-tblgen-builder/TableGen/Format.h
new file mode 100644
index 0000000..441e05c
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/Format.h
@@ -0,0 +1,259 @@
+//===- Format.h - Utilities for String Format -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares utilities for formatting strings. They are specially
+// tailored to the needs of TableGen'ing op definitions and rewrite rules,
+// so they are not expected to be used as widely applicable utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_FORMAT_H_
+#define MLIR_TABLEGEN_FORMAT_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace mlir {
+namespace tblgen {
+
+/// Format context containing substitutions for special placeholders.
+///
+/// This context divides special placeholders into two categories: builtin ones
+/// and custom ones.
+///
+/// Builtin placeholders are baked into `FmtContext` and each one of them has a
+/// dedicated setter. They can be used in all dialects. Their names follow the
+/// convention of `$_<name>`. The rationale of the leading underscore is to
+/// avoid confusion and name collision: op arguments/attributes/results are
+/// named as $<name>, and we can potentially support referencing those entities
+/// directly in the format template in the future.
+//
+/// Custom ones are registered by dialect-specific TableGen backends and use the
+/// same unified setter.
+class FmtContext {
+public:
+  // Placeholder kinds
+  enum class PHKind : char {
+    None,
+    Custom,  // For custom placeholders
+    Builder, // For the $_builder placeholder
+    Op,      // For the $_op placeholder
+    Self,    // For the $_self placeholder
+  };
+
+  FmtContext() = default;
+
+  // Setter for custom placeholders
+  FmtContext &addSubst(StringRef placeholder, Twine subst);
+
+  // Setters for builtin placeholders
+  FmtContext &withBuilder(Twine subst);
+  FmtContext &withOp(Twine subst);
+  FmtContext &withSelf(Twine subst);
+
+  Optional<StringRef> getSubstFor(PHKind placeholder) const;
+  Optional<StringRef> getSubstFor(StringRef placeholder) const;
+
+  static PHKind getPlaceHolderKind(StringRef str);
+
+private:
+  struct PHKindInfo : DenseMapInfo<PHKind> {
+    using CharInfo = DenseMapInfo<char>;
+
+    static inline PHKind getEmptyKey() {
+      return static_cast<PHKind>(CharInfo::getEmptyKey());
+    }
+    static inline PHKind getTombstoneKey() {
+      return static_cast<PHKind>(CharInfo::getTombstoneKey());
+    }
+    static unsigned getHashValue(const PHKind &val) {
+      return CharInfo::getHashValue(static_cast<char>(val));
+    }
+
+    static bool isEqual(const PHKind &lhs, const PHKind &rhs) {
+      return lhs == rhs;
+    }
+  };
+
+  llvm::SmallDenseMap<PHKind, std::string, 4, PHKindInfo> builtinSubstMap;
+  llvm::StringMap<std::string> customSubstMap;
+};
+
+/// Struct representing a replacement segment for the formatted string. It can
+/// be a segment of the formatting template (for `Literal`) or a replacement
+/// parameter (for `PositionalPH` and `SpecialPH`).
+struct FmtReplacement {
+  enum class Type { Empty, Literal, PositionalPH, SpecialPH };
+
+  FmtReplacement() = default;
+  explicit FmtReplacement(StringRef literal)
+      : type(Type::Literal), spec(literal) {}
+  FmtReplacement(StringRef spec, size_t index)
+      : type(Type::PositionalPH), spec(spec), index(index) {}
+  FmtReplacement(StringRef spec, FmtContext::PHKind placeholder)
+      : type(Type::SpecialPH), spec(spec), placeholder(placeholder) {}
+
+  Type type = Type::Empty;
+  StringRef spec;
+  size_t index = 0;
+  FmtContext::PHKind placeholder = FmtContext::PHKind::None;
+};
+
+class FmtObjectBase {
+private:
+  static std::pair<FmtReplacement, StringRef> splitFmtSegment(StringRef fmt);
+  static std::vector<FmtReplacement> parseFormatString(StringRef fmt);
+
+protected:
+  // The parameters are stored in a std::tuple, which does not provide runtime
+  // indexing capabilities. In order to enable runtime indexing, we use this
+  // structure to put the parameters into a std::vector. Since the parameters
+  // are not all the same type, we use some type-erasure by wrapping the
+  // parameters in a template class that derives from a non-template superclass.
+  // Essentially, we are converting a std::tuple<Derived<Ts...>> to a
+  // std::vector<Base*>.
+  struct CreateAdapters {
+    template <typename... Ts>
+    std::vector<llvm::detail::format_adapter *> operator()(Ts &... items) {
+      return std::vector<llvm::detail::format_adapter *>{&items...};
+    }
+  };
+
+  StringRef fmt;
+  const FmtContext *context;
+  std::vector<llvm::detail::format_adapter *> adapters;
+  std::vector<FmtReplacement> replacements;
+
+public:
+  FmtObjectBase(StringRef fmt, const FmtContext *ctx, size_t numParams)
+      : fmt(fmt), context(ctx), replacements(parseFormatString(fmt)) {}
+
+  FmtObjectBase(const FmtObjectBase &that) = delete;
+
+  FmtObjectBase(FmtObjectBase &&that)
+      : fmt(std::move(that.fmt)), context(that.context),
+        adapters(), // adapters are initialized by FmtObject
+        replacements(std::move(that.replacements)) {}
+
+  void format(llvm::raw_ostream &s) const;
+
+  std::string str() const {
+    std::string result;
+    llvm::raw_string_ostream s(result);
+    format(s);
+    return s.str();
+  }
+
+  template <unsigned N> SmallString<N> sstr() const {
+    SmallString<N> result;
+    llvm::raw_svector_ostream s(result);
+    format(s);
+    return result;
+  }
+
+  template <unsigned N> operator SmallString<N>() const { return sstr<N>(); }
+
+  operator std::string() const { return str(); }
+};
+
+template <typename Tuple> class FmtObject : public FmtObjectBase {
+  // Storage for the parameter adapters. Since the base class erases the type
+  // of the parameters, we have to own the storage for the parameters here, and
+  // have the base class store type-erased pointers into this tuple.
+  Tuple parameters;
+
+public:
+  FmtObject(StringRef fmt, const FmtContext *ctx, Tuple &&params)
+      : FmtObjectBase(fmt, ctx, std::tuple_size<Tuple>::value),
+        parameters(std::move(params)) {
+    adapters.reserve(std::tuple_size<Tuple>::value);
+    adapters = llvm::apply_tuple(CreateAdapters(), parameters);
+  }
+
+  FmtObject(FmtObject const &that) = delete;
+
+  FmtObject(FmtObject &&that)
+      : FmtObjectBase(std::move(that)), parameters(std::move(that.parameters)) {
+    adapters.reserve(that.adapters.size());
+    adapters = llvm::apply_tuple(CreateAdapters(), parameters);
+  }
+};
+
+class FmtStrVecObject : public FmtObjectBase {
+public:
+  using StrFormatAdapter =
+      decltype(llvm::detail::build_format_adapter(std::declval<std::string>()));
+
+  FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
+                  ArrayRef<std::string> params);
+  FmtStrVecObject(FmtStrVecObject const &that) = delete;
+  FmtStrVecObject(FmtStrVecObject &&that);
+
+private:
+  SmallVector<StrFormatAdapter, 16> parameters;
+};
+
+/// Formats text by substituting placeholders in format string with replacement
+/// parameters.
+///
+/// There are two categories of placeholders accepted, both led by a '$' sign:
+///
+/// 1. Positional placeholder: $[0-9]+
+/// 2. Special placeholder:    $[a-zA-Z_][a-zA-Z0-9_]*
+///
+/// Replacement parameters for positional placeholders are supplied as the
+/// `vals` parameter pack with 1:1 mapping. That is, $0 will be replaced by the
+/// first parameter in `vals`, $1 by the second one, and so on. Note that you
+/// can use the positional placeholders in any order and repeat any times, for
+/// example, "$2 $1 $1 $0" is accepted.
+///
+/// Replacement parameters for special placeholders are supplied using the `ctx`
+/// format context.
+///
+/// The `fmt` is recorded as a `StringRef` inside the returned `FmtObject`.
+/// The caller needs to make sure the underlying data is available when the
+/// `FmtObject` is used.
+///
+/// `ctx` accepts a nullptr if no special placeholder is used.
+///
+/// If no substitution is provided for a placeholder or any error happens during
+/// format string parsing or replacement, the placeholder will be outputted
+/// as-is with an additional marker '<no-subst-found>', to aid debugging.
+///
+/// To print a '$' literally, escape it with '$$'.
+///
+/// This utility function is inspired by LLVM formatv(), with modifications
+/// specially tailored for TableGen C++ generation usage:
+///
+/// 1. This utility uses '$' instead of '{' and '}' for denoting the placeholder
+///    because '{' and '}' are frequently used in C++ code.
+/// 2. This utility does not support format layout because it is rarely needed
+///    in C++ code generation.
+template <typename... Ts>
+inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals)
+    -> FmtObject<decltype(std::make_tuple(
+        llvm::detail::build_format_adapter(std::forward<Ts>(vals))...))> {
+  using ParamTuple = decltype(std::make_tuple(
+      llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
+  return FmtObject<ParamTuple>(
+      fmt, ctx,
+      std::make_tuple(
+          llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
+}
+
+inline FmtStrVecObject tgfmt(StringRef fmt, const FmtContext *ctx,
+                             ArrayRef<std::string> params) {
+  return FmtStrVecObject(fmt, ctx, params);
+}
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_FORMAT_H_
diff --git a/tools/mlir-tblgen-builder/TableGen/GenInfo.h b/tools/mlir-tblgen-builder/TableGen/GenInfo.h
new file mode 100644
index 0000000..16ed559
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/GenInfo.h
@@ -0,0 +1,72 @@
+//===- GenInfo.h - Generator info -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENINFO_H_
+#define MLIR_TABLEGEN_GENINFO_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+#include <functional>
+
+namespace llvm {
+class RecordKeeper;
+} // end namespace llvm
+
+namespace mlir {
+
+/// Generator function to invoke.
+using GenFunction = std::function<bool(const llvm::RecordKeeper &recordKeeper,
+                                       raw_ostream &os)>;
+
+/// Structure to group information about a generator (argument to invoke via
+/// mlir-tblgen, description, and generator function).
+class GenInfo {
+public:
+  /// GenInfo constructor should not be invoked directly, instead use
+  /// GenRegistration or registerGen.
+  GenInfo(StringRef arg, StringRef description, GenFunction generator)
+      : arg(arg), description(description), generator(generator) {}
+
+  /// Invokes the generator and returns whether the generator failed.
+  bool invoke(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
+    assert(generator && "Cannot call generator with null generator");
+    return generator(recordKeeper, os);
+  }
+
+  /// Returns the command line option that may be passed to 'mlir-tblgen' to
+  /// invoke this generator.
+  StringRef getGenArgument() const { return arg; }
+
+  /// Returns a description for the generator.
+  StringRef getGenDescription() const { return description; }
+
+private:
+  // The argument with which to invoke the generator via mlir-tblgen.
+  StringRef arg;
+
+  // Description of the generator.
+  StringRef description;
+
+  // Generator function.
+  GenFunction generator;
+};
+
+/// GenRegistration provides a global initializer that registers a generator
+/// function.
+///
+/// Usage:
+///
+///   // At namespace scope.
+///   static GenRegistration Print("print", "Print records", [](...){...});
+struct GenRegistration {
+  GenRegistration(StringRef arg, StringRef description, GenFunction function);
+};
+
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_GENINFO_H_
diff --git a/tools/mlir-tblgen-builder/TableGen/GenNameParser.h b/tools/mlir-tblgen-builder/TableGen/GenNameParser.h
new file mode 100644
index 0000000..b029951
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/GenNameParser.h
@@ -0,0 +1,31 @@
+//===- GenNameParser.h - Command line parser for generators -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The GenNameParser class adds a command line option to the tool for each
+// registered generator.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_GENNAMEPARSER_H_
+#define MLIR_TABLEGEN_GENNAMEPARSER_H_
+
+#include "llvm/Support/CommandLine.h"
+
+namespace mlir {
+class GenInfo;
+
+/// Adds command line option for each registered generator.
+struct GenNameParser : public llvm::cl::parser<const GenInfo *> {
+  GenNameParser(llvm::cl::Option &opt);
+
+  void printOptionInfo(const llvm::cl::Option &O,
+                       size_t GlobalWidth) const override;
+};
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_GENNAMEPARSER_H_
diff --git a/tools/mlir-tblgen-builder/TableGen/Interfaces.cpp b/tools/mlir-tblgen-builder/TableGen/Interfaces.cpp
new file mode 100644
index 0000000..d94f902
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/Interfaces.cpp
@@ -0,0 +1,144 @@
+//===- Interfaces.cpp - Interface classes ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Interfaces.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// InterfaceMethod
+//===----------------------------------------------------------------------===//
+
+InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) {
+  llvm::DagInit *args = def->getValueAsDag("arguments");
+  for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
+    arguments.push_back(
+        {llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
+         args->getArgNameStr(i)});
+  }
+}
+
+StringRef InterfaceMethod::getReturnType() const {
+  return def->getValueAsString("returnType");
+}
+
+// Return the name of this method.
+StringRef InterfaceMethod::getName() const {
+  return def->getValueAsString("name");
+}
+
+// Return if this method is static.
+bool InterfaceMethod::isStatic() const {
+  return def->isSubClassOf("StaticInterfaceMethod");
+}
+
+// Return the body for this method if it has one.
+llvm::Optional<StringRef> InterfaceMethod::getBody() const {
+  auto value = def->getValueAsString("body");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
+// Return the default implementation for this method if it has one.
+llvm::Optional<StringRef> InterfaceMethod::getDefaultImplementation() const {
+  auto value = def->getValueAsString("defaultBody");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
+// Return the description of this method if it has one.
+llvm::Optional<StringRef> InterfaceMethod::getDescription() const {
+  auto value = def->getValueAsString("description");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
+ArrayRef<InterfaceMethod::Argument> InterfaceMethod::getArguments() const {
+  return arguments;
+}
+
+bool InterfaceMethod::arg_empty() const { return arguments.empty(); }
+
+//===----------------------------------------------------------------------===//
+// Interface
+//===----------------------------------------------------------------------===//
+
+Interface::Interface(const llvm::Record *def) : def(def) {
+  assert(def->isSubClassOf("Interface") &&
+         "must be subclass of TableGen 'Interface' class");
+
+  auto *listInit = cast<llvm::ListInit>(def->getValueInit("methods"));
+  for (llvm::Init *init : listInit->getValues())
+    methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
+}
+
+// Return the name of this interface.
+StringRef Interface::getName() const {
+  return def->getValueAsString("cppClassName");
+}
+
+// Return the C++ namespace of this interface.
+StringRef Interface::getCppNamespace() const {
+  return def->getValueAsString("cppNamespace");
+}
+
+// Return the methods of this interface.
+ArrayRef<InterfaceMethod> Interface::getMethods() const { return methods; }
+
+// Return the description of this interface if it has one.
+llvm::Optional<StringRef> Interface::getDescription() const {
+  auto value = def->getValueAsString("description");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
+// Return the interfaces extra class declaration code.
+llvm::Optional<StringRef> Interface::getExtraClassDeclaration() const {
+  auto value = def->getValueAsString("extraClassDeclaration");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
+// Return the traits extra class declaration code.
+llvm::Optional<StringRef> Interface::getExtraTraitClassDeclaration() const {
+  auto value = def->getValueAsString("extraTraitClassDeclaration");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
+// Return the verify method body if it has one.
+llvm::Optional<StringRef> Interface::getVerify() const {
+  // Only OpInterface supports the verify method.
+  if (!isa<OpInterface>(this))
+    return llvm::None;
+  auto value = def->getValueAsString("verify");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
+//===----------------------------------------------------------------------===//
+// AttrInterface
+//===----------------------------------------------------------------------===//
+
+bool AttrInterface::classof(const Interface *interface) {
+  return interface->getDef().isSubClassOf("AttrInterface");
+}
+
+//===----------------------------------------------------------------------===//
+// OpInterface
+//===----------------------------------------------------------------------===//
+
+bool OpInterface::classof(const Interface *interface) {
+  return interface->getDef().isSubClassOf("OpInterface");
+}
+
+//===----------------------------------------------------------------------===//
+// TypeInterface
+//===----------------------------------------------------------------------===//
+
+bool TypeInterface::classof(const Interface *interface) {
+  return interface->getDef().isSubClassOf("TypeInterface");
+}
diff --git a/tools/mlir-tblgen-builder/TableGen/Interfaces.h b/tools/mlir-tblgen-builder/TableGen/Interfaces.h
new file mode 100644
index 0000000..a346209
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/Interfaces.h
@@ -0,0 +1,129 @@
+//===- Interfaces.h - Interface wrapper classes -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_INTERFACES_H_
+#define MLIR_TABLEGEN_INTERFACES_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class Init;
+class Record;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class with helper methods for accessing InterfaceMethod defined
+// in TableGen.
+class InterfaceMethod {
+public:
+  // This struct represents a single method argument.
+  struct Argument {
+    StringRef type;
+    StringRef name;
+  };
+
+  explicit InterfaceMethod(const llvm::Record *def);
+
+  // Return the return type of this method.
+  StringRef getReturnType() const;
+
+  // Return the name of this method.
+  StringRef getName() const;
+
+  // Return if this method is static.
+  bool isStatic() const;
+
+  // Return the body for this method if it has one.
+  llvm::Optional<StringRef> getBody() const;
+
+  // Return the default implementation for this method if it has one.
+  llvm::Optional<StringRef> getDefaultImplementation() const;
+
+  // Return the description of this method if it has one.
+  llvm::Optional<StringRef> getDescription() const;
+
+  // Arguments.
+  ArrayRef<Argument> getArguments() const;
+  bool arg_empty() const;
+
+private:
+  // The TableGen definition of this method.
+  const llvm::Record *def;
+
+  // The arguments of this method.
+  SmallVector<Argument, 2> arguments;
+};
+
+//===----------------------------------------------------------------------===//
+// Interface
+//===----------------------------------------------------------------------===//
+
+// Wrapper class with helper methods for accessing Interfaces defined in
+// TableGen.
+class Interface {
+public:
+  explicit Interface(const llvm::Record *def);
+
+  // Return the name of this interface.
+  StringRef getName() const;
+
+  // Return the C++ namespace of this interface.
+  StringRef getCppNamespace() const;
+
+  // Return the methods of this interface.
+  ArrayRef<InterfaceMethod> getMethods() const;
+
+  // Return the description of this interface if it has one.
+  llvm::Optional<StringRef> getDescription() const;
+
+  // Return the interfaces extra class declaration code.
+  llvm::Optional<StringRef> getExtraClassDeclaration() const;
+
+  // Return the traits extra class declaration code.
+  llvm::Optional<StringRef> getExtraTraitClassDeclaration() const;
+
+  // Return the verify method body if it has one.
+  llvm::Optional<StringRef> getVerify() const;
+
+  // Returns the Tablegen definition this interface was constructed from.
+  const llvm::Record &getDef() const { return *def; }
+
+private:
+  // The TableGen definition of this interface.
+  const llvm::Record *def;
+
+  // The methods of this interface.
+  SmallVector<InterfaceMethod, 8> methods;
+};
+
+// An interface that is registered to an Attribute.
+struct AttrInterface : public Interface {
+  using Interface::Interface;
+
+  static bool classof(const Interface *interface);
+};
+// An interface that is registered to an Operation.
+struct OpInterface : public Interface {
+  using Interface::Interface;
+
+  static bool classof(const Interface *interface);
+};
+// An interface that is registered to a Type.
+struct TypeInterface : public Interface {
+  using Interface::Interface;
+
+  static bool classof(const Interface *interface);
+};
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_INTERFACES_H_
diff --git a/tools/mlir-tblgen-builder/TableGen/OpClass.cpp b/tools/mlir-tblgen-builder/TableGen/OpClass.cpp
new file mode 100644
index 0000000..2fffe5a
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/OpClass.cpp
@@ -0,0 +1,347 @@
+//===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "OpClass.h" + +#include "Format.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "mlir-tblgen-opclass" + +using namespace mlir; +using namespace mlir::tblgen; + +namespace { + +// Returns space to be emitted after the given C++ `type`. return "" if the +// ends with '&' or '*', or is empty, else returns " ". +StringRef getSpaceAfterType(StringRef type) { + return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " "; +} + +} // namespace + +//===----------------------------------------------------------------------===// +// OpMethodParameter definitions +//===----------------------------------------------------------------------===// + +void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const { + if (properties & PP_Optional) + os << "/*optional*/"; + os << type << getSpaceAfterType(type) << name; + if (emitDefault && !defaultValue.empty()) + os << " = " << defaultValue; +} + +//===----------------------------------------------------------------------===// +// OpMethodParameters definitions +//===----------------------------------------------------------------------===// + +// Factory methods to construct the correct type of `OpMethodParameters` +// object based on the arguments. 
+std::unique_ptr OpMethodParameters::create() { + return std::make_unique(); +} + +std::unique_ptr +OpMethodParameters::create(StringRef params) { + return std::make_unique(params); +} + +std::unique_ptr +OpMethodParameters::create(llvm::SmallVectorImpl &¶ms) { + return std::make_unique(std::move(params)); +} + +std::unique_ptr +OpMethodParameters::create(StringRef type, StringRef name, + StringRef defaultValue) { + return std::make_unique(type, name, defaultValue); +} + +//===----------------------------------------------------------------------===// +// OpMethodUnresolvedParameters definitions +//===----------------------------------------------------------------------===// +void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const { + os << parameters; +} + +void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const { + // We need to remove the default values for parameters in method definition. + // TODO: We are using '=' and ',' as delimiters for parameter + // initializers. This is incorrect for initializer list with more than one + // element. Change to a more robust approach. + llvm::SmallVector tokens; + StringRef params = parameters; + while (!params.empty()) { + std::pair parts = params.split("="); + tokens.push_back(parts.first); + params = parts.second.split(',').second; + } + llvm::interleaveComma(tokens, os, [&](StringRef token) { os << token; }); +} + +//===----------------------------------------------------------------------===// +// OpMethodResolvedParameters definitions +//===----------------------------------------------------------------------===// + +// Returns true if a method with these parameters makes a method with parameters +// `other` redundant. This should return true only if all possible calls to the +// other method can be replaced by calls to this method. 
+bool OpMethodResolvedParameters::makesRedundant( + const OpMethodResolvedParameters &other) const { + const size_t otherNumParams = other.getNumParameters(); + const size_t thisNumParams = getNumParameters(); + + // All calls to the other method can be replaced this method only if this + // method has the same or more arguments number of arguments as the other, and + // the common arguments have the same type. + if (thisNumParams < otherNumParams) + return false; + for (int idx : llvm::seq(0, otherNumParams)) + if (parameters[idx].getType() != other.parameters[idx].getType()) + return false; + + // If all the common arguments have the same type, we can elide the other + // method if this method has the same number of arguments as other or the + // first argument after the common ones has a default value (and by C++ + // requirement, all the later ones will also have a default value). + return thisNumParams == otherNumParams || + parameters[otherNumParams].hasDefaultValue(); +} + +void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const { + llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) { + param.writeDeclTo(os); + }); +} + +void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const { + llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) { + param.writeDefTo(os); + }); +} + +//===----------------------------------------------------------------------===// +// OpMethodSignature definitions +//===----------------------------------------------------------------------===// + +// Returns if a method with this signature makes a method with `other` signature +// redundant. Only supports resolved parameters. 
+bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const { + if (methodName != other.methodName) + return false; + auto *resolvedThis = dyn_cast(parameters.get()); + auto *resolvedOther = + dyn_cast(other.parameters.get()); + if (resolvedThis && resolvedOther) + return resolvedThis->makesRedundant(*resolvedOther); + return false; +} + +void OpMethodSignature::writeDeclTo(raw_ostream &os) const { + os << returnType << getSpaceAfterType(returnType) << methodName << "("; + parameters->writeDeclTo(os); + os << ")"; +} + +void OpMethodSignature::writeDefTo(raw_ostream &os, + StringRef namePrefix) const { + os << returnType << getSpaceAfterType(returnType) << namePrefix + << (namePrefix.empty() ? "" : "::") << methodName << "("; + parameters->writeDefTo(os); + os << ")"; +} + +//===----------------------------------------------------------------------===// +// OpMethodBody definitions +//===----------------------------------------------------------------------===// + +OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {} + +OpMethodBody &OpMethodBody::operator<<(Twine content) { + if (isEffective) + body.append(content.str()); + return *this; +} + +OpMethodBody &OpMethodBody::operator<<(int content) { + if (isEffective) + body.append(std::to_string(content)); + return *this; +} + +OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) { + if (isEffective) + body.append(content.str()); + return *this; +} + +void OpMethodBody::writeTo(raw_ostream &os) const { + auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; }); + os << bodyRef; + if (bodyRef.empty() || bodyRef.back() != '\n') + os << "\n"; +} + +//===----------------------------------------------------------------------===// +// OpMethod definitions +//===----------------------------------------------------------------------===// + +void OpMethod::writeDeclTo(raw_ostream &os) const { + os.indent(2); + if (isStatic()) + os << "static "; + if 
(properties & MP_Constexpr) + os << "constexpr "; + methodSignature.writeDeclTo(os); + if (!isInline()) + os << ";"; + else { + os << " {\n"; + methodBody.writeTo(os); + os << "}"; + } +} + +void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const { + // Do not write definition if the method is decl only. + if (properties & MP_Declaration) + return; + // Do not generate separate definition for inline method + if (isInline()) + return; + methodSignature.writeDefTo(os, namePrefix); + os << " {\n"; + methodBody.writeTo(os); + os << "}"; +} + +//===----------------------------------------------------------------------===// +// OpConstructor definitions +//===----------------------------------------------------------------------===// + +void OpConstructor::addMemberInitializer(StringRef name, StringRef value) { + memberInitializers.append(std::string(llvm::formatv( + "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value))); +} + +void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const { + // Do not write definition if the method is decl only. + if (properties & MP_Declaration) + return; + + methodSignature.writeDefTo(os, namePrefix); + os << " " << memberInitializers << " {\n"; + methodBody.writeTo(os); + os << "}"; +} + +//===----------------------------------------------------------------------===// +// Class definitions +//===----------------------------------------------------------------------===// + +Class::Class(StringRef name) : className(name) {} + +void Class::newField(StringRef type, StringRef name, StringRef defaultValue) { + std::string varName = formatv("{0} {1}", type, name).str(); + std::string field = defaultValue.empty() + ? 
varName + : formatv("{0} = {1}", varName, defaultValue).str(); + fields.push_back(std::move(field)); +} +void Class::writeDeclTo(raw_ostream &os) const { + bool hasPrivateMethod = false; + os << "class " << className << " {\n"; + os << "public:\n"; + + forAllMethods([&](const OpMethod &method) { + if (!method.isPrivate()) { + method.writeDeclTo(os); + os << '\n'; + } else { + hasPrivateMethod = true; + } + }); + + os << '\n'; + os << "private:\n"; + if (hasPrivateMethod) { + forAllMethods([&](const OpMethod &method) { + if (method.isPrivate()) { + method.writeDeclTo(os); + os << '\n'; + } + }); + os << '\n'; + } + + for (const auto &field : fields) + os.indent(2) << field << ";\n"; + os << "};\n"; +} + +void Class::writeDefTo(raw_ostream &os) const { + forAllMethods([&](const OpMethod &method) { + method.writeDefTo(os, className); + os << "\n\n"; + }); +} + +//===----------------------------------------------------------------------===// +// OpClass definitions +//===----------------------------------------------------------------------===// + +OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) + : Class(name), extraClassDeclaration(extraClassDeclaration) {} + +void OpClass::addTrait(Twine trait) { + auto traitStr = trait.str(); + if (traitsSet.insert(traitStr).second) + traitsVec.push_back(std::move(traitStr)); +} + +void OpClass::writeDeclTo(raw_ostream &os) const { + os << "class " << className << " : public ::mlir::Op<" << className; + for (const auto &trait : traitsVec) + os << ", " << trait; + os << "> {\npublic:\n"; + // << " using Op::Op;\n" + // << " using Op::print;\n" + // << " using Adaptor = " << className << "Adaptor;\n"; + + bool hasPrivateMethod = false; + forAllMethods([&](const OpMethod &method) { + if (!method.isPrivate()) { + method.writeDeclTo(os); + os << "\n"; + } else { + hasPrivateMethod = true; + } + }); + + // TODO: Add line control markers to make errors easier to debug. 
+ if (!extraClassDeclaration.empty()) + os << extraClassDeclaration << "\n"; + + if (hasPrivateMethod) { + os << "\nprivate:\n"; + forAllMethods([&](const OpMethod &method) { + if (method.isPrivate()) { + method.writeDeclTo(os); + os << "\n"; + } + }); + } + + os << "};\n"; +} diff --git a/tools/mlir-tblgen-builder/TableGen/OpClass.h b/tools/mlir-tblgen-builder/TableGen/OpClass.h new file mode 100644 index 0000000..243e7fa --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/OpClass.h @@ -0,0 +1,442 @@ +//===- OpClass.h - Helper classes for Op C++ code emission ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines several classes for Op C++ code emission. They are only +// expected to be used by MLIR TableGen backends. +// +// We emit the op declaration and definition into separate files: *Ops.h.inc +// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and +// the latter for dialect *Ops.cpp. This way provides a cleaner interface. +// +// In order to do this split, we need to track method signature and +// implementation logic separately. Signature information is used for both +// declaration and definition, while implementation logic is only for +// definition. So we have the following classes for C++ code emission. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_OPCLASS_H_ +#define MLIR_TABLEGEN_OPCLASS_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include + +namespace mlir { +namespace tblgen { +class FmtObjectBase; + +// Class for holding a single parameter of an op's method for C++ code emission. +class OpMethodParameter { +public: + // Properties (qualifiers) for the parameter. + enum Property { + PP_None = 0x0, + PP_Optional = 0x1, + }; + + OpMethodParameter(StringRef type, StringRef name, StringRef defaultValue = "", + Property properties = PP_None) + : type(type), name(name), defaultValue(defaultValue), + properties(properties) {} + + OpMethodParameter(StringRef type, StringRef name, Property property) + : OpMethodParameter(type, name, "", property) {} + + // Writes the parameter as a part of a method declaration to `os`. + void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); } + + // Writes the parameter as a part of a method definition to `os` + void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); } + + const std::string &getType() const { return type; } + bool hasDefaultValue() const { return !defaultValue.empty(); } + +private: + void writeTo(raw_ostream &os, bool emitDefault) const; + + std::string type; + std::string name; + std::string defaultValue; + Property properties; +}; + +// Base class for holding parameters of an op's method for C++ code emission. +class OpMethodParameters { +public: + // Discriminator for LLVM-style RTTI. + enum ParamsKind { + // Separate type and name for each parameter is not known. + PK_Unresolved, + // Each parameter is resolved to a type and name. 
+ PK_Resolved, + }; + + OpMethodParameters(ParamsKind kind) : kind(kind) {} + virtual ~OpMethodParameters() {} + + // LLVM-style RTTI support. + ParamsKind getKind() const { return kind; } + + // Writes the parameters as a part of a method declaration to `os`. + virtual void writeDeclTo(raw_ostream &os) const = 0; + + // Writes the parameters as a part of a method definition to `os` + virtual void writeDefTo(raw_ostream &os) const = 0; + + // Factory methods to create the correct type of `OpMethodParameters` + // object based on the arguments. + static std::unique_ptr create(); + + static std::unique_ptr create(StringRef params); + + static std::unique_ptr + create(llvm::SmallVectorImpl &¶ms); + + static std::unique_ptr + create(StringRef type, StringRef name, StringRef defaultValue = ""); + +private: + const ParamsKind kind; +}; + +// Class for holding unresolved parameters. +class OpMethodUnresolvedParameters : public OpMethodParameters { +public: + OpMethodUnresolvedParameters(StringRef params) + : OpMethodParameters(PK_Unresolved), parameters(params) {} + + // write the parameters as a part of a method declaration to the given `os`. + void writeDeclTo(raw_ostream &os) const override; + + // write the parameters as a part of a method definition to the given `os` + void writeDefTo(raw_ostream &os) const override; + + // LLVM-style RTTI support. + static bool classof(const OpMethodParameters *params) { + return params->getKind() == PK_Unresolved; + } + +private: + std::string parameters; +}; + +// Class for holding resolved parameters. 
+class OpMethodResolvedParameters : public OpMethodParameters { +public: + OpMethodResolvedParameters() : OpMethodParameters(PK_Resolved) {} + + OpMethodResolvedParameters(llvm::SmallVectorImpl &¶ms) + : OpMethodParameters(PK_Resolved) { + for (OpMethodParameter ¶m : params) + parameters.emplace_back(std::move(param)); + } + + OpMethodResolvedParameters(StringRef type, StringRef name, + StringRef defaultValue) + : OpMethodParameters(PK_Resolved) { + parameters.emplace_back(type, name, defaultValue); + } + + // Returns the number of parameters. + size_t getNumParameters() const { return parameters.size(); } + + // Returns if this method makes the `other` method redundant. Note that this + // is more than just finding conflicting methods. This method determines if + // the 2 set of parameters are conflicting and if so, returns true if this + // method has a more general set of parameters that can replace all possible + // calls to the `other` method. + bool makesRedundant(const OpMethodResolvedParameters &other) const; + + // write the parameters as a part of a method declaration to the given `os`. + void writeDeclTo(raw_ostream &os) const override; + + // write the parameters as a part of a method definition to the given `os` + void writeDefTo(raw_ostream &os) const override; + + // LLVM-style RTTI support. + static bool classof(const OpMethodParameters *params) { + return params->getKind() == PK_Resolved; + } + +private: + llvm::SmallVector parameters; +}; + +// Class for holding the signature of an op's method for C++ code emission +class OpMethodSignature { +public: + template + OpMethodSignature(StringRef retType, StringRef name, Args &&...args) + : returnType(retType), methodName(name), + parameters(OpMethodParameters::create(std::forward(args)...)) {} + OpMethodSignature(OpMethodSignature &&) = default; + + // Returns if a method with this signature makes a method with `other` + // signature redundant. Only supports resolved parameters. 
+ bool makesRedundant(const OpMethodSignature &other) const; + + // Returns the number of parameters (for resolved parameters). + size_t getNumParameters() const { + return cast(parameters.get()) + ->getNumParameters(); + } + + // Returns the name of the method. + StringRef getName() const { return methodName; } + + // Writes the signature as a method declaration to the given `os`. + void writeDeclTo(raw_ostream &os) const; + + // Writes the signature as the start of a method definition to the given `os`. + // `namePrefix` is the prefix to be prepended to the method name (typically + // namespaces for qualifying the method definition). + void writeDefTo(raw_ostream &os, StringRef namePrefix) const; + +private: + std::string returnType; + std::string methodName; + std::unique_ptr parameters; +}; + +// Class for holding the body of an op's method for C++ code emission +class OpMethodBody { +public: + explicit OpMethodBody(bool declOnly); + + OpMethodBody &operator<<(Twine content); + OpMethodBody &operator<<(int content); + OpMethodBody &operator<<(const FmtObjectBase &content); + + void writeTo(raw_ostream &os) const; + +private: + // Whether this class should record method body. + bool isEffective; + std::string body; +}; + +// Class for holding an op's method for C++ code emission +class OpMethod { +public: + // Properties (qualifiers) of class methods. Bitfield is used here to help + // querying properties. 
+ enum Property { + MP_None = 0x0, + MP_Static = 0x1, + MP_Constructor = 0x2, + MP_Private = 0x4, + MP_Declaration = 0x8, + MP_Inline = 0x10, + MP_Constexpr = 0x20 | MP_Inline, + MP_StaticDeclaration = MP_Static | MP_Declaration, + }; + + template + OpMethod(StringRef retType, StringRef name, Property property, unsigned id, + Args &&...args) + : properties(property), + methodSignature(retType, name, std::forward(args)...), + methodBody(properties & MP_Declaration), id(id) {} + + OpMethod(OpMethod &&) = default; + + virtual ~OpMethod() = default; + + OpMethodBody &body() { return methodBody; } + + // Returns true if this is a static method. + bool isStatic() const { return properties & MP_Static; } + + // Returns true if this is a private method. + bool isPrivate() const { return properties & MP_Private; } + + // Returns true if this is an inline method. + bool isInline() const { return properties & MP_Inline; } + + // Returns the name of this method. + StringRef getName() const { return methodSignature.getName(); } + + // Returns the ID for this method + unsigned getID() const { return id; } + + // Returns if this method makes the `other` method redundant. + bool makesRedundant(const OpMethod &other) const { + return methodSignature.makesRedundant(other.methodSignature); + } + + // Writes the method as a declaration to the given `os`. + virtual void writeDeclTo(raw_ostream &os) const; + + // Writes the method as a definition to the given `os`. `namePrefix` is the + // prefix to be prepended to the method name (typically namespaces for + // qualifying the method definition). + virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const; + +protected: + Property properties; + OpMethodSignature methodSignature; + OpMethodBody methodBody; + const unsigned id; +}; + +// Class for holding an op's constructor method for C++ code emission. 
+class OpConstructor : public OpMethod { +public: + template + OpConstructor(StringRef className, Property property, unsigned id, + Args &&...args) + : OpMethod("", className, property, id, std::forward(args)...) {} + + // Add member initializer to constructor initializing `name` with `value`. + void addMemberInitializer(StringRef name, StringRef value); + + // Writes the method as a definition to the given `os`. `namePrefix` is the + // prefix to be prepended to the method name (typically namespaces for + // qualifying the method definition). + void writeDefTo(raw_ostream &os, StringRef namePrefix) const override; + +private: + // Member initializers. + std::string memberInitializers; +}; + +// A class used to emit C++ classes from Tablegen. Contains a list of public +// methods and a list of private fields to be emitted. +class Class { +public: + explicit Class(StringRef name); + + // Adds a new method to this class and prune redundant methods. Returns null + // if the method was not added (because an existing method would make it + // redundant), else returns a pointer to the added method. Note that this call + // may also delete existing methods that are made redundant by a method to the + // class. 
+ template + OpMethod *addMethodAndPrune(StringRef retType, StringRef name, + OpMethod::Property properties, Args &&...args) { + auto newMethod = std::make_unique( + retType, name, properties, nextMethodID++, std::forward(args)...); + return addMethodAndPrune(methods, std::move(newMethod)); + } + + template + OpMethod *addMethodAndPrune(StringRef retType, StringRef name, + Args &&...args) { + return addMethodAndPrune(retType, name, OpMethod::MP_None, + std::forward(args)...); + } + + template + OpConstructor *addConstructorAndPrune(Args &&...args) { + auto newConstructor = std::make_unique( + getClassName(), OpMethod::MP_Constructor, nextMethodID++, + std::forward(args)...); + return addMethodAndPrune(constructors, std::move(newConstructor)); + } + + // Creates a new field in this class. + void newField(StringRef type, StringRef name, StringRef defaultValue = ""); + + // Writes this op's class as a declaration to the given `os`. + void writeDeclTo(raw_ostream &os) const; + // Writes the method definitions in this op's class to the given `os`. + void writeDefTo(raw_ostream &os) const; + + // Returns the C++ class name of the op. + StringRef getClassName() const { return className; } + +protected: + // Get a list of all the methods to emit, filtering out hidden ones. + void forAllMethods(llvm::function_ref func) const { + using ConsRef = const std::unique_ptr &; + using MethodRef = const std::unique_ptr &; + llvm::for_each(constructors, [&](ConsRef ptr) { func(*ptr); }); + llvm::for_each(methods, [&](MethodRef ptr) { func(*ptr); }); + } + + // For deterministic code generation, keep methods sorted in the order in + // which they were generated. 
+ template + struct MethodCompare { + bool operator()(const std::unique_ptr &x, + const std::unique_ptr &y) const { + return x->getID() < y->getID(); + } + }; + + template + using MethodSet = + std::set, MethodCompare>; + + template + MethodTy *addMethodAndPrune(MethodSet &set, + std::unique_ptr &&newMethod) { + // Check if the new method will be made redundant by existing methods. + for (auto &method : set) + if (method->makesRedundant(*newMethod)) + return nullptr; + + // We can add this a method to the set. Prune any existing methods that will + // be made redundant by adding this new method. Note that the redundant + // check between two methods is more than a conflict check. makesRedundant() + // below will check if the new method conflicts with an existing method and + // if so, returns true if the new method makes the existing method redundant + // because all calls to the existing method can be subsumed by the new + // method. So makesRedundant() does a combined job of finding conflicts and + // deciding which of the 2 conflicting methods survive. + // + // Note: llvm::erase_if does not work with sets of std::unique_ptr, so doing + // it manually here. + for (auto it = set.begin(), end = set.end(); it != end;) { + if (newMethod->makesRedundant(*(it->get()))) + it = set.erase(it); + else + ++it; + } + + MethodTy *ret = newMethod.get(); + set.insert(std::move(newMethod)); + return ret; + } + + std::string className; + MethodSet constructors; + MethodSet methods; + unsigned nextMethodID = 0; + SmallVector fields; +}; + +// Class for holding an op for C++ code emission +class OpClass : public Class { +public: + explicit OpClass(StringRef name, StringRef extraClassDeclaration = ""); + + // Adds an op trait. + void addTrait(Twine trait); + + // Writes this op's class as a declaration to the given `os`. Redefines + // Class::writeDeclTo to also emit traits and extra class declarations. 
+ void writeDeclTo(raw_ostream &os) const; + +private: + StringRef extraClassDeclaration; + SmallVector traitsVec; + StringSet<> traitsSet; +}; + +} // namespace tblgen +} // namespace mlir + +#endif // MLIR_TABLEGEN_OPCLASS_H_ diff --git a/tools/mlir-tblgen-builder/TableGen/Operator.cpp b/tools/mlir-tblgen-builder/TableGen/Operator.cpp new file mode 100644 index 0000000..3331b9f --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Operator.cpp @@ -0,0 +1,592 @@ +//===- Operator.cpp - Operator class --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Operator wrapper to simplify using TableGen Record defining a MLIR Op. +// +//===----------------------------------------------------------------------===// + +#include "Operator.h" +#include "Predicate.h" +#include "Trait.h" +#include "Type.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +#define DEBUG_TYPE "mlir-tblgen-operator" + +using namespace mlir; +using namespace mlir::tblgen; + +using llvm::DagInit; +using llvm::DefInit; +using llvm::Record; + +Operator::Operator(const llvm::Record &def) + : dialect(def.getValueAsDef("opDialect")), def(def) { + // The first `_` in the op's TableGen def name is treated as separating the + // dialect prefix and the op class name. The dialect prefix will be ignored if + // not empty. Otherwise, if def name starts with a `_`, the `_` is considered + // as part of the class name. 
+ StringRef prefix; + std::tie(prefix, cppClassName) = def.getName().split('_'); + if (prefix.empty()) { + // Class name with a leading underscore and without dialect prefix + cppClassName = def.getName(); + } else if (cppClassName.empty()) { + // Class name without dialect prefix + cppClassName = prefix; + } + + cppNamespace = def.getValueAsString("cppNamespace"); + + populateOpStructure(); +} + +std::string Operator::getOperationName() const { + auto prefix = dialect.getName(); + auto opName = def.getValueAsString("opName"); + if (prefix.empty()) + return std::string(opName); + return std::string(llvm::formatv("{0}.{1}", prefix, opName)); +} + +std::string Operator::getAdaptorName() const { + return std::string(llvm::formatv("{0}Adaptor", getCppClassName())); +} + +StringRef Operator::getDialectName() const { return dialect.getName(); } + +StringRef Operator::getCppClassName() const { return cppClassName; } + +std::string Operator::getQualCppClassName() const { + if (cppNamespace.empty()) + return std::string(cppClassName); + return std::string(llvm::formatv("{0}::{1}", cppNamespace, cppClassName)); +} + +StringRef Operator::getCppNamespace() const { return cppNamespace; } + +int Operator::getNumResults() const { + DagInit *results = def.getValueAsDag("results"); + return results->getNumArgs(); +} + +StringRef Operator::getExtraClassDeclaration() const { + constexpr auto attr = "extraClassDeclaration"; + if (def.isValueUnset(attr)) + return {}; + return def.getValueAsString(attr); +} + +const llvm::Record &Operator::getDef() const { return def; } + +bool Operator::skipDefaultBuilders() const { + return def.getValueAsBit("skipDefaultBuilders"); +} + +auto Operator::result_begin() -> value_iterator { return results.begin(); } + +auto Operator::result_end() -> value_iterator { return results.end(); } + +auto Operator::getResults() -> value_range { + return {result_begin(), result_end()}; +} + +TypeConstraint Operator::getResultTypeConstraint(int index) const { + 
// Returns the name bound to the `index`-th entry of the op's `results` dag.
StringRef Operator::getResultName(int index) const {
  DagInit *results = def.getValueAsDag("results");
  return results->getArgNameStr(index);
}

// Returns the decorators of the `index`-th result. The range is empty
// (null begin/end) unless the result's record derives from `OpVariable`.
auto Operator::getResultDecorators(int index) const -> var_decorator_range {
  Record *result =
      cast(def.getValueAsDag("results")->getArg(index))->getDef();
  if (!result->isSubClassOf("OpVariable"))
    return var_decorator_range(nullptr, nullptr);
  return *result->getValueAsListInit("decorators");
}

// Counts the results whose type constraint is variable length.
unsigned Operator::getNumVariableLengthResults() const {
  return llvm::count_if(results, [](const NamedTypeConstraint &c) {
    return c.constraint.isVariableLength();
  });
}

// Counts the operands whose type constraint is variable length.
unsigned Operator::getNumVariableLengthOperands() const {
  return llvm::count_if(operands, [](const NamedTypeConstraint &c) {
    return c.constraint.isVariableLength();
  });
}

// Returns true when the op has exactly one argument and operand 0 is
// variadic. The `is()` check selects which kind of argument is accepted.
bool Operator::hasSingleVariadicArg() const {
  return getNumArgs() == 1 && getArg(0).is() &&
         getOperand(0).isVariadic();
}

// Iterator to the first entry of `arguments`.
Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); }

// Iterator one past the last entry of `arguments`.
Operator::arg_iterator Operator::arg_end() const { return arguments.end(); }

// Range over all entries of `arguments`.
Operator::arg_range Operator::getArgs() const {
  return {arg_begin(), arg_end()};
}

// Returns the name bound to the `index`-th entry of the op's `arguments` dag.
StringRef Operator::getArgName(int index) const {
  DagInit *argumentValues = def.getValueAsDag("arguments");
  return argumentValues->getArgNameStr(index);
}

// Returns the decorators of the `index`-th argument. The range is empty
// (null begin/end) unless the argument's record derives from `OpVariable`.
auto Operator::getArgDecorators(int index) const -> var_decorator_range {
  Record *arg =
      cast(def.getValueAsDag("arguments")->getArg(index))->getDef();
  if (!arg->isSubClassOf("OpVariable"))
    return var_decorator_range(nullptr, nullptr);
  return *arg->getValueAsListInit("decorators");
}

// Returns the first trait in `traits` whose fully qualified trait name equals
// `trait`, or nullptr when none matches. Each branch tries one dyn_cast
// alternative; a trait that casts successfully but has a different name is
// skipped without trying the remaining alternatives.
const Trait *Operator::getTrait(StringRef trait) const {
  for (const auto &t : traits) {
    if (const auto *traitDef = dyn_cast(&t)) {
      if (traitDef->getFullyQualifiedTraitName() == trait)
        return traitDef;
    } else if (const auto *traitDef = dyn_cast(&t)) {
      if (traitDef->getFullyQualifiedTraitName() == trait)
        return traitDef;
    } else if (const auto *traitDef = dyn_cast(&t)) {
      if (traitDef->getFullyQualifiedTraitName() == trait)
        return traitDef;
    }
  }
  return nullptr;
}

// Iterator to the first entry of `regions`.
auto Operator::region_begin() const -> const_region_iterator {
  return regions.begin();
}
// Iterator one past the last entry of `regions`.
auto Operator::region_end() const -> const_region_iterator {
  return regions.end();
}
// Range over all entries of `regions`.
auto Operator::getRegions() const
    -> llvm::iterator_range {
  return {region_begin(), region_end()};
}

// Number of regions recorded for this op.
unsigned Operator::getNumRegions() const { return regions.size(); }

// Returns the `index`-th region; `index` is not bounds-checked here.
const NamedRegion &Operator::getRegion(unsigned index) const {
  return regions[index];
}

// Counts the regions that report themselves as variadic.
unsigned Operator::getNumVariadicRegions() const {
  return llvm::count_if(regions,
                        [](const NamedRegion &c) { return c.isVariadic(); });
}

// Iterator to the first entry of `successors`.
auto Operator::successor_begin() const -> const_successor_iterator {
  return successors.begin();
}
// Iterator one past the last entry of `successors`.
auto Operator::successor_end() const -> const_successor_iterator {
  return successors.end();
}
// Range over all entries of `successors`.
auto Operator::getSuccessors() const
    -> llvm::iterator_range {
  return {successor_begin(), successor_end()};
}

// Number of successors recorded for this op.
unsigned Operator::getNumSuccessors() const { return successors.size(); }

// Returns the `index`-th successor; `index` is not bounds-checked here.
const NamedSuccessor &Operator::getSuccessor(unsigned index) const {
  return successors[index];
}

// Counts the successors that report themselves as variadic.
unsigned Operator::getNumVariadicSuccessors() const {
  return llvm::count_if(successors,
                        [](const NamedSuccessor &c) { return c.isVariadic(); });
}

// Iterator to the first entry of `traits`.
auto Operator::trait_begin() const -> const_trait_iterator {
  return traits.begin();
}
// Iterator one past the last entry of `traits`.
auto Operator::trait_end() const -> const_trait_iterator {
  return traits.end();
}
// Range over all entries of `traits`.
auto Operator::getTraits() const -> llvm::iterator_range {
  return {trait_begin(), trait_end()};
}

// Iterator to the first entry of `attributes`.
auto Operator::attribute_begin() const -> attribute_iterator {
  return attributes.begin();
}
// Iterator one past the last entry of `attributes`.
auto Operator::attribute_end() const -> attribute_iterator {
  return attributes.end();
}
Operator::getAttributes() const + -> llvm::iterator_range { + return {attribute_begin(), attribute_end()}; +} + +auto Operator::operand_begin() -> value_iterator { return operands.begin(); } +auto Operator::operand_end() -> value_iterator { return operands.end(); } +auto Operator::getOperands() -> value_range { + return {operand_begin(), operand_end()}; +} + +auto Operator::getArg(int index) const -> Argument { return arguments[index]; } + +// Mapping from result index to combined argument and result index. Arguments +// are indexed to match getArg index, while the result indexes are mapped to +// avoid overlap. +static int resultIndex(int i) { return -1 - i; } + +bool Operator::isVariadic() const { + return any_of(llvm::concat(operands, results), + [](const NamedTypeConstraint &op) { return op.isVariadic(); }); +} + +void Operator::populateTypeInferenceInfo( + const llvm::StringMap &argumentsAndResultsIndex) { + // If the type inference op interface is not registered, then do not attempt + // to determine if the result types an be inferred. + auto &recordKeeper = def.getRecords(); + auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface); + allResultsHaveKnownTypes = false; + if (!inferTrait) + return; + + // If there are no results, the skip this else the build method generated + // overlaps with another autogenerated builder. + if (getNumResults() == 0) + return; + + // Skip for ops with variadic operands/results. + // TODO: This can be relaxed. + if (isVariadic()) + return; + + // Skip cases currently being custom generated. + // TODO: Remove special cases. + if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) + return; + + // We create equivalence classes of argument/result types where arguments + // and results are mapped into the same index space and indices corresponding + // to the same type are in the same equivalence class. 
+ llvm::EquivalenceClasses ecs; + resultTypeMapping.resize(getNumResults()); + // Captures the argument whose type matches a given result type. Preference + // towards capturing operands first before attributes. + auto captureMapping = [&](int i) { + bool found = false; + ecs.insert(resultIndex(i)); + auto mi = ecs.findLeader(resultIndex(i)); + for (auto me = ecs.member_end(); mi != me; ++mi) { + if (*mi < 0) { + auto tc = getResultTypeConstraint(i); + if (tc.getBuilderCall().hasValue()) { + resultTypeMapping[i].emplace_back(tc); + found = true; + } + continue; + } + + if (getArg(*mi).is()) { + // TODO: Handle attributes. + continue; + } else { + resultTypeMapping[i].emplace_back(*mi); + found = true; + } + } + return found; + }; + + for (const Trait &trait : traits) { + const llvm::Record &def = trait.getDef(); + // If the infer type op interface was manually added, then treat it as + // intention that the op needs special handling. + // TODO: Reconsider whether to always generate, this is more conservative + // and keeps existing behavior so starting that way for now. + if (def.isSubClassOf( + llvm::formatv("{0}::Trait", inferTypeOpInterface).str())) + return; + if (const auto *traitDef = dyn_cast(&trait)) + if (&traitDef->getDef() == inferTrait) + return; + + if (!def.isSubClassOf("AllTypesMatch")) + continue; + + auto values = def.getValueAsListOfStrings("values"); + auto root = argumentsAndResultsIndex.lookup(values.front()); + for (StringRef str : values) + ecs.unionSets(argumentsAndResultsIndex.lookup(str), root); + } + + // Verifies that all output types have a corresponding known input type + // and chooses matching operand or attribute (in that order) that + // matches it. + allResultsHaveKnownTypes = + all_of(llvm::seq(0, getNumResults()), captureMapping); + + // If the types could be computed, then add type inference trait. 
+ if (allResultsHaveKnownTypes) + traits.push_back(Trait::create(inferTrait->getDefInit())); +} + +void Operator::populateOpStructure() { + auto &recordKeeper = def.getRecords(); + auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint"); + auto *attrClass = recordKeeper.getClass("Attr"); + auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr"); + auto *opVarClass = recordKeeper.getClass("OpVariable"); + numNativeAttributes = 0; + + DagInit *argumentValues = def.getValueAsDag("arguments"); + unsigned numArgs = argumentValues->getNumArgs(); + + // Mapping from name of to argument or result index. Arguments are indexed + // to match getArg index, while the results are negatively indexed. + llvm::StringMap argumentsAndResultsIndex; + + // Handle operands and native attributes. + for (unsigned i = 0; i != numArgs; ++i) { + auto *arg = argumentValues->getArg(i); + auto givenName = argumentValues->getArgNameStr(i); + auto *argDefInit = dyn_cast(arg); + if (!argDefInit) + PrintFatalError(def.getLoc(), + Twine("undefined type for argument #") + Twine(i)); + Record *argDef = argDefInit->getDef(); + if (argDef->isSubClassOf(opVarClass)) + argDef = argDef->getValueAsDef("constraint"); + + if (argDef->isSubClassOf(typeConstraintClass)) { + operands.push_back( + NamedTypeConstraint{givenName, TypeConstraint(argDef)}); + } else if (argDef->isSubClassOf(attrClass)) { + if (givenName.empty()) + PrintFatalError(argDef->getLoc(), "attributes must be named"); + if (argDef->isSubClassOf(derivedAttrClass)) + PrintFatalError(argDef->getLoc(), + "derived attributes not allowed in argument list"); + attributes.push_back({givenName, Attribute(argDef)}); + ++numNativeAttributes; + } else { + PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving " + "from TypeConstraint or Attr are allowed"); + } + if (!givenName.empty()) + argumentsAndResultsIndex[givenName] = i; + } + + // Handle derived attributes. 
+ for (const auto &val : def.getValues()) { + if (auto *record = dyn_cast(val.getType())) { + if (!record->isSubClassOf(attrClass)) + continue; + if (!record->isSubClassOf(derivedAttrClass)) + PrintFatalError(def.getLoc(), + "unexpected Attr where only DerivedAttr is allowed"); + + if (record->getClasses().size() != 1) { + PrintFatalError( + def.getLoc(), + "unsupported attribute modelling, only single class expected"); + } + attributes.push_back( + {cast(val.getNameInit())->getValue(), + Attribute(cast(val.getValue()))}); + } + } + + // Populate `arguments`. This must happen after we've finalized `operands` and + // `attributes` because we will put their elements' pointers in `arguments`. + // SmallVector may perform re-allocation under the hood when adding new + // elements. + int operandIndex = 0, attrIndex = 0; + for (unsigned i = 0; i != numArgs; ++i) { + Record *argDef = dyn_cast(argumentValues->getArg(i))->getDef(); + if (argDef->isSubClassOf(opVarClass)) + argDef = argDef->getValueAsDef("constraint"); + + if (argDef->isSubClassOf(typeConstraintClass)) { + attrOrOperandMapping.push_back( + {OperandOrAttribute::Kind::Operand, operandIndex}); + arguments.emplace_back(&operands[operandIndex++]); + } else { + assert(argDef->isSubClassOf(attrClass)); + attrOrOperandMapping.push_back( + {OperandOrAttribute::Kind::Attribute, attrIndex}); + arguments.emplace_back(&attributes[attrIndex++]); + } + } + + auto *resultsDag = def.getValueAsDag("results"); + auto *outsOp = dyn_cast(resultsDag->getOperator()); + if (!outsOp || outsOp->getDef()->getName() != "outs") { + PrintFatalError(def.getLoc(), "'results' must have 'outs' directive"); + } + + // Handle results. 
+ for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) { + auto name = resultsDag->getArgNameStr(i); + auto *resultInit = dyn_cast(resultsDag->getArg(i)); + if (!resultInit) { + PrintFatalError(def.getLoc(), + Twine("undefined type for result #") + Twine(i)); + } + auto *resultDef = resultInit->getDef(); + if (resultDef->isSubClassOf(opVarClass)) + resultDef = resultDef->getValueAsDef("constraint"); + results.push_back({name, TypeConstraint(resultDef)}); + if (!name.empty()) + argumentsAndResultsIndex[name] = resultIndex(i); + } + + // Handle successors + auto *successorsDag = def.getValueAsDag("successors"); + auto *successorsOp = dyn_cast(successorsDag->getOperator()); + if (!successorsOp || successorsOp->getDef()->getName() != "successor") { + PrintFatalError(def.getLoc(), + "'successors' must have 'successor' directive"); + } + + for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) { + auto name = successorsDag->getArgNameStr(i); + auto *successorInit = dyn_cast(successorsDag->getArg(i)); + if (!successorInit) { + PrintFatalError(def.getLoc(), + Twine("undefined kind for successor #") + Twine(i)); + } + Successor successor(successorInit->getDef()); + + // Only support variadic successors if it is the last one for now. + if (i != e - 1 && successor.isVariadic()) + PrintFatalError(def.getLoc(), "only the last successor can be variadic"); + successors.push_back({name, successor}); + } + + // Create list of traits, skipping over duplicates: appending to lists in + // tablegen is easy, making them unique less so, so dedupe here. + if (auto *traitList = def.getValueAsListInit("traits")) { + // This is uniquing based on pointers of the trait. + SmallPtrSet traitSet; + traits.reserve(traitSet.size()); + for (auto *traitInit : *traitList) { + // Keep traits in the same order while skipping over duplicates. 
+ if (traitSet.insert(traitInit).second) + traits.push_back(Trait::create(traitInit)); + } + } + + populateTypeInferenceInfo(argumentsAndResultsIndex); + + // Handle regions + auto *regionsDag = def.getValueAsDag("regions"); + auto *regionsOp = dyn_cast(regionsDag->getOperator()); + if (!regionsOp || regionsOp->getDef()->getName() != "region") { + PrintFatalError(def.getLoc(), "'regions' must have 'region' directive"); + } + + for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) { + auto name = regionsDag->getArgNameStr(i); + auto *regionInit = dyn_cast(regionsDag->getArg(i)); + if (!regionInit) { + PrintFatalError(def.getLoc(), + Twine("undefined kind for region #") + Twine(i)); + } + Region region(regionInit->getDef()); + if (region.isVariadic()) { + // Only support variadic regions if it is the last one for now. + if (i != e - 1) + PrintFatalError(def.getLoc(), "only the last region can be variadic"); + if (name.empty()) + PrintFatalError(def.getLoc(), "variadic regions must be named"); + } + + regions.push_back({name, region}); + } + + // Populate the builders. 
+ auto *builderList = + dyn_cast_or_null(def.getValueInit("builders")); + if (builderList && !builderList->empty()) { + for (llvm::Init *init : builderList->getValues()) + builders.emplace_back(cast(init)->getDef(), def.getLoc()); + } else if (skipDefaultBuilders()) { + PrintFatalError( + def.getLoc(), + "default builders are skipped and no custom builders provided"); + } + + LLVM_DEBUG(print(llvm::dbgs())); +} + +auto Operator::getSameTypeAsResult(int index) const -> ArrayRef { + assert(allResultTypesKnown()); + return resultTypeMapping[index]; +} + +ArrayRef Operator::getLoc() const { return def.getLoc(); } + +bool Operator::hasDescription() const { + return def.getValue("description") != nullptr; +} + +StringRef Operator::getDescription() const { + return def.getValueAsString("description"); +} + +bool Operator::hasSummary() const { return def.getValue("summary") != nullptr; } + +StringRef Operator::getSummary() const { + return def.getValueAsString("summary"); +} + +bool Operator::hasAssemblyFormat() const { + auto *valueInit = def.getValueInit("assemblyFormat"); + return isa(valueInit); +} + +StringRef Operator::getAssemblyFormat() const { + return TypeSwitch(def.getValueInit("assemblyFormat")) + .Case( + [&](auto *init) { return init->getValue(); }); +} + +void Operator::print(llvm::raw_ostream &os) const { + os << "op '" << getOperationName() << "'\n"; + for (Argument arg : arguments) { + if (auto *attr = arg.dyn_cast()) + os << "[attribute] " << attr->name << '\n'; + else + os << "[operand] " << arg.get()->name << '\n'; + } +} + +auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init) + -> VariableDecorator { + return VariableDecorator(cast(init)->getDef()); +} + +auto Operator::getArgToOperandOrAttribute(int index) const + -> OperandOrAttribute { + return attrOrOperandMapping[index]; +} diff --git a/tools/mlir-tblgen-builder/TableGen/Operator.h b/tools/mlir-tblgen-builder/TableGen/Operator.h new file mode 100644 index 0000000..de6c892 --- 
/dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Operator.h @@ -0,0 +1,360 @@ +//===- Operator.h - Operator class ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Operator wrapper to simplify using TableGen Record defining a MLIR Op. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_OPERATOR_H_ +#define MLIR_TABLEGEN_OPERATOR_H_ + +#include "mlir/Support/LLVM.h" +#include "Argument.h" +#include "Attribute.h" +#include "Builder.h" +#include "Dialect.h" +#include "Region.h" +#include "Successor.h" +#include "Trait.h" +#include "Type.h" +#include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/SMLoc.h" + +namespace llvm { +class DefInit; +class Record; +class StringInit; +} // end namespace llvm + +namespace mlir { +namespace tblgen { + +// Wrapper class that contains a MLIR op's information (e.g., operands, +// attributes) defined in TableGen and provides helper methods for +// accessing them. +class Operator { +public: + explicit Operator(const llvm::Record &def); + explicit Operator(const llvm::Record *def) : Operator(*def) {} + + // Returns this op's dialect name. + StringRef getDialectName() const; + + // Returns the operation name. The name will follow the "." + // format if its dialect name is not empty. + std::string getOperationName() const; + + // Returns this op's C++ class name. + StringRef getCppClassName() const; + + // Returns this op's C++ class name prefixed with namespaces. + std::string getQualCppClassName() const; + + // Returns this op's C++ namespace. 
+ StringRef getCppNamespace() const; + + // Returns the name of op's adaptor C++ class. + std::string getAdaptorName() const; + + /// A class used to represent the decorators of an operator variable, i.e. + /// argument or result. + struct VariableDecorator { + public: + explicit VariableDecorator(const llvm::Record *def) : def(def) {} + const llvm::Record &getDef() const { return *def; } + + protected: + // The TableGen definition of this decorator. + const llvm::Record *def; + }; + + // A utility iterator over a list of variable decorators. + struct VariableDecoratorIterator + : public llvm::mapped_iterator { + using reference = VariableDecorator; + + /// Initializes the iterator to the specified iterator. + VariableDecoratorIterator(llvm::Init *const *it) + : llvm::mapped_iterator(it, + &unwrap) {} + static VariableDecorator unwrap(llvm::Init *init); + }; + using var_decorator_iterator = VariableDecoratorIterator; + using var_decorator_range = llvm::iterator_range; + + using value_iterator = NamedTypeConstraint *; + using value_range = llvm::iterator_range; + + // Returns true if this op has variable length operands or results. + bool isVariadic() const; + + // Returns true if default builders should not be generated. + bool skipDefaultBuilders() const; + + // Op result iterators. + value_iterator result_begin(); + value_iterator result_end(); + value_range getResults(); + + // Returns the number of results this op produces. + int getNumResults() const; + + // Returns the op result at the given `index`. + NamedTypeConstraint &getResult(int index) { return results[index]; } + const NamedTypeConstraint &getResult(int index) const { + return results[index]; + } + + // Returns the `index`-th result's type constraint. + TypeConstraint getResultTypeConstraint(int index) const; + // Returns the `index`-th result's name. + StringRef getResultName(int index) const; + // Returns the `index`-th result's decorators. 
+ var_decorator_range getResultDecorators(int index) const; + + // Returns the number of variable length results in this operation. + unsigned getNumVariableLengthResults() const; + + // Op attribute iterators. + using attribute_iterator = const NamedAttribute *; + attribute_iterator attribute_begin() const; + attribute_iterator attribute_end() const; + llvm::iterator_range getAttributes() const; + + int getNumAttributes() const { return attributes.size(); } + int getNumNativeAttributes() const { return numNativeAttributes; } + + // Op attribute accessors. + NamedAttribute &getAttribute(int index) { return attributes[index]; } + + // Op operand iterators. + value_iterator operand_begin(); + value_iterator operand_end(); + value_range getOperands(); + + int getNumOperands() const { return operands.size(); } + NamedTypeConstraint &getOperand(int index) { return operands[index]; } + const NamedTypeConstraint &getOperand(int index) const { + return operands[index]; + } + + // Returns the number of variadic operands in this operation. + unsigned getNumVariableLengthOperands() const; + + // Returns the total number of arguments. + int getNumArgs() const { return arguments.size(); } + + // Returns true of the operation has a single variadic arg. + bool hasSingleVariadicArg() const; + + // Returns true if the operation has a single variadic result. + bool hasSingleVariadicResult() const { + return getNumResults() == 1 && getResult(0).isVariadic(); + } + + // Returns true of the operation has no variadic regions. + bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; } + + using arg_iterator = const Argument *; + using arg_range = llvm::iterator_range; + + // Op argument (attribute or operand) iterators. + arg_iterator arg_begin() const; + arg_iterator arg_end() const; + arg_range getArgs() const; + + // Op argument (attribute or operand) accessors. 
+ Argument getArg(int index) const; + StringRef getArgName(int index) const; + var_decorator_range getArgDecorators(int index) const; + + // Returns the trait wrapper for the given MLIR C++ `trait`. + const Trait *getTrait(llvm::StringRef trait) const; + + // Regions. + using const_region_iterator = const NamedRegion *; + const_region_iterator region_begin() const; + const_region_iterator region_end() const; + llvm::iterator_range getRegions() const; + + // Returns the number of regions. + unsigned getNumRegions() const; + // Returns the `index`-th region. + const NamedRegion &getRegion(unsigned index) const; + + // Returns the number of variadic regions in this operation. + unsigned getNumVariadicRegions() const; + + // Successors. + using const_successor_iterator = const NamedSuccessor *; + const_successor_iterator successor_begin() const; + const_successor_iterator successor_end() const; + llvm::iterator_range getSuccessors() const; + + // Returns the number of successors. + unsigned getNumSuccessors() const; + // Returns the `index`-th successor. + const NamedSuccessor &getSuccessor(unsigned index) const; + + // Returns the number of variadic successors in this operation. + unsigned getNumVariadicSuccessors() const; + + // Trait. + using const_trait_iterator = const Trait *; + const_trait_iterator trait_begin() const; + const_trait_iterator trait_end() const; + llvm::iterator_range getTraits() const; + + ArrayRef getLoc() const; + + // Query functions for the documentation of the operator. + bool hasDescription() const; + StringRef getDescription() const; + bool hasSummary() const; + StringRef getSummary() const; + + // Query functions for the assembly format of the operator. + bool hasAssemblyFormat() const; + StringRef getAssemblyFormat() const; + + // Returns this op's extra class declaration code. + StringRef getExtraClassDeclaration() const; + + // Returns the Tablegen definition this operator was constructed from. 
// Pair consisting kind of argument and index into operands or attributes.
//
// Both values are packed into one int: bit 0 stores the kind (1 for
// attribute, 0 for operand) and the remaining bits store the index.
struct OperandOrAttribute {
  enum class Kind { Operand, Attribute };

  // Packs `kind` and `index`. The two parts must be combined with bitwise OR:
  // `index << 1` always has bit 0 clear, so AND-ing it with the kind bit would
  // always yield 0 and lose both values. The shift assumes a non-negative
  // `index`.
  OperandOrAttribute(Kind kind, int index)
      : packed((index << 1) | (kind == Kind::Attribute ? 1 : 0)) {}

  // Returns the index into the operand or attribute list.
  int operandOrAttributeIndex() const { return (packed >> 1); }

  // Returns which list the index refers to.
  Kind kind() const { return (packed & 0x1) ? Kind::Attribute : Kind::Operand; }

private:
  int packed;
};
+ OperandOrAttribute getArgToOperandOrAttribute(int index) const; + + // Returns the builders of this operation. + ArrayRef getBuilders() const { return builders; } + +private: + // Populates the vectors containing operands, attributes, results and traits. + void populateOpStructure(); + + // Populates type inference info (mostly equality) with input a mapping from + // names to indices for arguments and results. + void populateTypeInferenceInfo( + const llvm::StringMap &argumentsAndResultsIndex); + + // The dialect of this op. + Dialect dialect; + + // The unqualified C++ class name of the op. + StringRef cppClassName; + + // The C++ namespace for this op. + StringRef cppNamespace; + + // The operands of the op. + SmallVector operands; + + // The attributes of the op. Contains native attributes (corresponding to the + // actual stored attributed of the operation) followed by derived attributes + // (corresponding to dynamic properties of the operation that are computed + // upon request). + SmallVector attributes; + + // The arguments of the op (operands and native attributes). + SmallVector arguments; + + // The results of the op. + SmallVector results; + + // The successors of this op. + SmallVector successors; + + // The traits of the op. + SmallVector traits; + + // The regions of this op. + SmallVector regions; + + // The argument with the same type as the result. + SmallVector, 4> resultTypeMapping; + + // Map from argument to attribute or operand number. + SmallVector attrOrOperandMapping; + + // The builders of this operator. + SmallVector builders; + + // The number of native attributes stored in the leading positions of + // `attributes`. + int numNativeAttributes; + + // The TableGen definition of this op. + const llvm::Record &def; + + // Whether the type of all results are known. 
+ bool allResultsHaveKnownTypes; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_OPERATOR_H_ diff --git a/tools/mlir-tblgen-builder/TableGen/Pass.cpp b/tools/mlir-tblgen-builder/TableGen/Pass.cpp new file mode 100644 index 0000000..0db3fd8 --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Pass.cpp @@ -0,0 +1,99 @@ +//===- Pass.cpp - Pass related classes ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Pass.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +//===----------------------------------------------------------------------===// +// PassOption +//===----------------------------------------------------------------------===// + +StringRef PassOption::getCppVariableName() const { + return def->getValueAsString("cppName"); +} + +StringRef PassOption::getArgument() const { + return def->getValueAsString("argument"); +} + +StringRef PassOption::getType() const { return def->getValueAsString("type"); } + +Optional PassOption::getDefaultValue() const { + StringRef defaultVal = def->getValueAsString("defaultValue"); + return defaultVal.empty() ? Optional() : defaultVal; +} + +StringRef PassOption::getDescription() const { + return def->getValueAsString("description"); +} + +Optional PassOption::getAdditionalFlags() const { + StringRef additionalFlags = def->getValueAsString("additionalOptFlags"); + return additionalFlags.empty() ? 
Optional() : additionalFlags; +} + +bool PassOption::isListOption() const { + return def->isSubClassOf("ListOption"); +} + +//===----------------------------------------------------------------------===// +// PassStatistic +//===----------------------------------------------------------------------===// + +StringRef PassStatistic::getCppVariableName() const { + return def->getValueAsString("cppName"); +} + +StringRef PassStatistic::getName() const { + return def->getValueAsString("name"); +} + +StringRef PassStatistic::getDescription() const { + return def->getValueAsString("description"); +} + +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +Pass::Pass(const llvm::Record *def) : def(def) { + for (auto *init : def->getValueAsListOfDefs("options")) + options.push_back(PassOption(init)); + for (auto *init : def->getValueAsListOfDefs("statistics")) + statistics.push_back(PassStatistic(init)); + for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects")) + dependentDialects.push_back(dialect); +} + +StringRef Pass::getArgument() const { + return def->getValueAsString("argument"); +} + +StringRef Pass::getBaseClass() const { + return def->getValueAsString("baseClass"); +} + +StringRef Pass::getSummary() const { return def->getValueAsString("summary"); } + +StringRef Pass::getDescription() const { + return def->getValueAsString("description"); +} + +StringRef Pass::getConstructor() const { + return def->getValueAsString("constructor"); +} +ArrayRef Pass::getDependentDialects() const { + return dependentDialects; +} + +ArrayRef Pass::getOptions() const { return options; } + +ArrayRef Pass::getStatistics() const { return statistics; } diff --git a/tools/mlir-tblgen-builder/TableGen/Pass.h b/tools/mlir-tblgen-builder/TableGen/Pass.h new file mode 100644 index 0000000..968c854 --- /dev/null +++ 
/// Wrapper around the TableGen record of a single pass option. Each accessor
/// reads one field of that record; the record is not owned.
class PassOption {
public:
  explicit PassOption(const llvm::Record *def) : def(def) {}

  /// Return the name for the C++ option variable (record field `cppName`).
  StringRef getCppVariableName() const;

  /// Return the command line argument to use for this option (record field
  /// `argument`).
  StringRef getArgument() const;

  /// Return the C++ type of the option (record field `type`).
  StringRef getType() const;

  /// Return the default value of the option (record field `defaultValue`), or
  /// an empty Optional when that string is empty.
  Optional getDefaultValue() const;

  /// Return the description for this option (record field `description`).
  StringRef getDescription() const;

  /// Return the additional flags passed to the option constructor (record
  /// field `additionalOptFlags`), or an empty Optional when that string is
  /// empty.
  Optional getAdditionalFlags() const;

  /// Flag indicating if this is a list option, i.e. whether the record derives
  /// from `ListOption`.
  bool isListOption() const;

private:
  /// The TableGen definition of this option.
  const llvm::Record *def;
};
/// Wrapper class providing helper methods for Passes defined in TableGen.
/// The constructor collects the options, statistics and dependent dialects
/// up front; the remaining accessors read fields of the record on demand.
class Pass {
public:
  explicit Pass(const llvm::Record *def);

  /// Return the command line argument of the pass (record field `argument`).
  StringRef getArgument() const;

  /// Return the name for the C++ base class (record field `baseClass`).
  StringRef getBaseClass() const;

  /// Return the short 1-line summary of the pass (record field `summary`).
  StringRef getSummary() const;

  /// Return the description of the pass (record field `description`).
  StringRef getDescription() const;

  /// Return the C++ constructor call to create an instance of this pass
  /// (record field `constructor`).
  StringRef getConstructor() const;

  /// Return the dialects this pass needs to be registered.
  ArrayRef getDependentDialects() const;

  /// Return the options provided by this pass.
  ArrayRef getOptions() const;

  /// Return the statistics provided by this pass.
  ArrayRef getStatistics() const;

  /// Return the underlying TableGen record.
  const llvm::Record *getDef() const { return def; }

private:
  /// The TableGen definition of this pass (not owned).
  const llvm::Record *def;
  /// Entries of the record's `dependentDialects` list.
  std::vector dependentDialects;
  /// One wrapper per record in the `options` list.
  std::vector options;
  /// One wrapper per record in the `statistics` list.
  std::vector statistics;
};
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Pattern wrapper class to simplify using TableGen Record defining a MLIR
+// Pattern.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Pattern.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+#define DEBUG_TYPE "mlir-tblgen-pattern"
+
+using namespace mlir;
+using namespace tblgen;
+
+using llvm::formatv;
+
+//===----------------------------------------------------------------------===//
+// DagLeaf
+//===----------------------------------------------------------------------===//
+
+bool DagLeaf::isUnspecified() const {
+  return dyn_cast_or_null<llvm::UnsetInit>(def);
+}
+
+bool DagLeaf::isOperandMatcher() const {
+  // Operand matchers specify a type constraint.
+  return isSubClassOf("TypeConstraint");
+}
+
+bool DagLeaf::isAttrMatcher() const {
+  // Attribute matchers specify an attribute constraint.
+  return isSubClassOf("AttrConstraint");
+}
+
+bool DagLeaf::isNativeCodeCall() const {
+  return isSubClassOf("NativeCodeCall");
+}
+
+bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
+
+bool DagLeaf::isEnumAttrCase() const {
+  return isSubClassOf("EnumAttrCaseInfo");
+}
+
+bool DagLeaf::isStringAttr() const {
+  return isa<llvm::StringInit, llvm::CodeInit>(def);
+}
+
+Constraint DagLeaf::getAsConstraint() const {
+  assert((isOperandMatcher() || isAttrMatcher()) &&
+         "the DAG leaf must be operand or attribute");
+  return Constraint(cast<llvm::DefInit>(def)->getDef());
+}
+
+ConstantAttr DagLeaf::getAsConstantAttr() const {
+  assert(isConstantAttr() && "the DAG leaf must be constant attribute");
+  return ConstantAttr(cast<llvm::DefInit>(def));
+}
+
+EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
+  assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
+  return EnumAttrCase(cast<llvm::DefInit>(def));
+}
+
+std::string DagLeaf::getConditionTemplate() const {
+  return getAsConstraint().getConditionTemplate();
+}
+
+llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
+  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
+  return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
+}
+
+std::string DagLeaf::getStringAttr() const {
+  assert(isStringAttr() && "the DAG leaf must be string attribute");
+  return def->getAsUnquotedString();
+}
+bool DagLeaf::isSubClassOf(StringRef superclass) const {
+  if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
+    return defInit->getDef()->isSubClassOf(superclass);
+  return false;
+}
+
+void DagLeaf::print(raw_ostream &os) const {
+  if (def)
+    def->print(os);
+}
+
+//===----------------------------------------------------------------------===//
+// DagNode
+//===----------------------------------------------------------------------===//
+
+bool DagNode::isNativeCodeCall() const {
+  if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
+    return defInit->getDef()->isSubClassOf("NativeCodeCall");
+  return false;
+}
+
+bool DagNode::isOperation() const {
+  return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
+}
+
+llvm::StringRef DagNode::getNativeCodeTemplate() const {
+  assert(isNativeCodeCall() && "the DAG node must be NativeCodeCall");
+  return cast<llvm::DefInit>(node->getOperator())
+      ->getDef()
+      ->getValueAsString("expression");
+}
+
+llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
+
+Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
+  llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
+  auto it = mapper->find(opDef);
+  if (it != mapper->end())
+    return *it->second;
+  return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
+              .first->second;
+}
+
+int DagNode::getNumOps() const {
+  int count = isReplaceWithValue() ? 0 : 1;
+  for (int i = 0, e = getNumArgs(); i != e; ++i) {
+    if (auto child = getArgAsNestedDag(i))
+      count += child.getNumOps();
+  }
+  return count;
+}
+
+int DagNode::getNumArgs() const { return node->getNumArgs(); }
+
+bool DagNode::isNestedDagArg(unsigned index) const {
+  return isa<llvm::DagInit>(node->getArg(index));
+}
+
+DagNode DagNode::getArgAsNestedDag(unsigned index) const {
+  return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
+}
+
+DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
+  assert(!isNestedDagArg(index));
+  return DagLeaf(node->getArg(index));
+}
+
+StringRef DagNode::getArgName(unsigned index) const {
+  return node->getArgNameStr(index);
+}
+
+bool DagNode::isReplaceWithValue() const {
+  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
+  return dagOpDef->getName() == "replaceWithValue";
+}
+
+bool DagNode::isLocationDirective() const {
+  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
+  return dagOpDef->getName() == "location";
+}
+
+void DagNode::print(raw_ostream &os) const {
+  if (node)
+    node->print(os);
+}
+
+//===----------------------------------------------------------------------===//
+// SymbolInfoMap
+//===----------------------------------------------------------------------===//
+
+StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
+  StringRef name, indexStr;
+  int idx = -1;
+  std::tie(name, indexStr) = symbol.rsplit("__");
+
+  if (indexStr.consumeInteger(10, idx)) {
+    // The second part is not an index; we return the whole symbol as-is.
+    return symbol;
+  }
+  if (index) {
+    *index = idx;
+  }
+  return name;
+}
+
+SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
+                                      Optional<int> index)
+    : op(op), kind(kind), argIndex(index) {}
+
+int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
+  switch (kind) {
+  case Kind::Attr:
+  case Kind::Operand:
+  case Kind::Value:
+    return 1;
+  case Kind::Result:
+    return op->getNumResults();
+  }
+  llvm_unreachable("unknown kind");
+}
+
+std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
+  return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
+}
+
+std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
+  LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
+  switch (kind) {
+  case Kind::Attr: {
+    if (op) {
+      auto type =
+          op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
+      return std::string(formatv("{0} {1};\n", type, name));
+    }
+    // TODO(suderman): Use a more exact type when available.
+    return std::string(formatv("Attribute {0};\n", name));
+  }
+  case Kind::Operand: {
+    // Use operand range for captured operands (to support potential variadic
+    // operands).
+    return std::string(
+        formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
+                getVarName(name)));
+  }
+  case Kind::Value: {
+    return std::string(formatv("::mlir::Value {0};\n", name));
+  }
+  case Kind::Result: {
+    // Use the op itself for captured results.
+    return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
+  }
+  }
+  llvm_unreachable("unknown kind");
+}
+
+std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
+    StringRef name, int index, const char *fmt, const char *separator) const {
+  LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
+  switch (kind) {
+  case Kind::Attr: {
+    assert(index < 0);
+    auto repl = formatv(fmt, name);
+    LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
+    return std::string(repl);
+  }
+  case Kind::Operand: {
+    assert(index < 0);
+    auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
+    // If this operand is variadic, then return a range. Otherwise, return the
+    // value itself.
+    if (operand->isVariableLength()) {
+      auto repl = formatv(fmt, name);
+      LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
+      return std::string(repl);
+    }
+    auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
+    LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
+    return std::string(repl);
+  }
+  case Kind::Result: {
+    // If `index` is non-negative, then we are referencing a specific
+    // result of a multi-result op. The result can still be variadic.
+    if (index >= 0) {
+      std::string v =
+          std::string(formatv("{0}.getODSResults({1})", name, index));
+      if (!op->getResult(index).isVariadic())
+        v = std::string(formatv("(*{0}.begin())", v));
+      auto repl = formatv(fmt, v);
+      LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
+      return std::string(repl);
+    }
+
+    // If this op has no result at all but still we bind a symbol to it, it
+    // means we want to capture the op itself.
+    if (op->getNumResults() == 0) {
+      LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
+      return std::string(name);
+    }
+
+    // We are referencing all results of the multi-result op. A specific result
+    // can either be a value or a range. Then join them with `separator`.
+    SmallVector<std::string, 4> values;
+    values.reserve(op->getNumResults());
+
+    for (int i = 0, e = op->getNumResults(); i < e; ++i) {
+      std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
+      if (!op->getResult(i).isVariadic()) {
+        v = std::string(formatv("(*{0}.begin())", v));
+      }
+      values.push_back(std::string(formatv(fmt, v)));
+    }
+    auto repl = llvm::join(values, separator);
+    LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
+    return repl;
+  }
+  case Kind::Value: {
+    assert(index < 0);
+    assert(op == nullptr);
+    auto repl = formatv(fmt, name);
+    LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
+    return std::string(repl);
+  }
+  }
+  llvm_unreachable("unknown kind");
+}
+
+std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
+    StringRef name, int index, const char *fmt, const char *separator) const {
+  LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
+  switch (kind) {
+  case Kind::Attr:
+  case Kind::Operand: {
+    assert(index < 0 && "only allowed for symbol bound to result");
+    auto repl = formatv(fmt, name);
+    LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
+    return std::string(repl);
+  }
+  case Kind::Result: {
+    if (index >= 0) {
+      auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
+      LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
+      return std::string(repl);
+    }
+
+    // We are referencing all results of the multi-result op. Each result should
+    // have a value range, and then join them with `separator`.
+    SmallVector<std::string, 4> values;
+    values.reserve(op->getNumResults());
+
+    for (int i = 0, e = op->getNumResults(); i < e; ++i) {
+      values.push_back(std::string(
+          formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
+    }
+    auto repl = llvm::join(values, separator);
+    LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
+    return repl;
+  }
+  case Kind::Value: {
+    assert(index < 0 && "only allowed for symbol bound to result");
+    assert(op == nullptr);
+    auto repl = formatv(fmt, formatv("{{{0}}", name));
+    LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
+    return std::string(repl);
+  }
+  }
+  llvm_unreachable("unknown kind");
+}
+
+bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
+                                   int argIndex) {
+  StringRef name = getValuePackName(symbol);
+  if (name != symbol) {
+    auto error = formatv(
+        "symbol '{0}' with trailing index cannot bind to op argument", symbol);
+    PrintFatalError(loc, error);
+  }
+
+  auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
+                     ? SymbolInfo::getAttr(&op, argIndex)
+                     : SymbolInfo::getOperand(&op, argIndex);
+
+  std::string key = symbol.str();
+  if (symbolInfoMap.count(key)) {
+    // Only operands are allowed to share a non-unique name.
+    if (symInfo.kind != SymbolInfo::Kind::Operand) {
+      return false;
+    }
+
+    // Cannot add a new operand if a non-operand is already bound to the same
+    // name.
+    if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
+      return false;
+    }
+  }
+
+  symbolInfoMap.emplace(key, symInfo);
+  return true;
+}
+
+bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
+  std::string name = getValuePackName(symbol).str();
+  auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
+
+  return symbolInfoMap.count(inserted->first) == 1;
+}
+
+bool SymbolInfoMap::bindValue(StringRef symbol) {
+  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
+  return symbolInfoMap.count(inserted->first) == 1;
+}
+
+bool SymbolInfoMap::bindAttr(StringRef symbol) {
+  auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
+  return symbolInfoMap.count(inserted->first) == 1;
+}
+
+bool SymbolInfoMap::contains(StringRef symbol) const {
+  return find(symbol) != symbolInfoMap.end();
+}
+
+SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
+  std::string name = getValuePackName(key).str();
+
+  return symbolInfoMap.find(name);
+}
+
+SymbolInfoMap::const_iterator
+SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
+                               int argIndex) const {
+  std::string name = getValuePackName(key).str();
+  auto range = symbolInfoMap.equal_range(name);
+
+  for (auto it = range.first; it != range.second; ++it) {
+    if (it->second.op == &op && it->second.argIndex == argIndex) {
+      return it;
+    }
+  }
+
+  return symbolInfoMap.end();
+}
+
+std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
+SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
+  std::string name = getValuePackName(key).str();
+
+  return symbolInfoMap.equal_range(name);
+}
+
+int SymbolInfoMap::count(StringRef key) const {
+  std::string name = getValuePackName(key).str();
+  return symbolInfoMap.count(name);
+}
+
+int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
+  StringRef name = getValuePackName(symbol);
+  if (name != symbol) {
+    // If there is a trailing index inside symbol, it references just one
+    // static value.
+    return 1;
+  }
+  // Otherwise, find how many it represents by querying the symbol's info.
+  return find(name)->second.getStaticValueCount();
+}
+
+std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
+                                               const char *fmt,
+                                               const char *separator) const {
+  int index = -1;
+  StringRef name = getValuePackName(symbol, &index);
+
+  auto it = symbolInfoMap.find(name.str());
+  if (it == symbolInfoMap.end()) {
+    auto error = formatv("referencing unbound symbol '{0}'", symbol);
+    PrintFatalError(loc, error);
+  }
+
+  return it->second.getValueAndRangeUse(name, index, fmt, separator);
+}
+
+std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
+                                          const char *separator) const {
+  int index = -1;
+  StringRef name = getValuePackName(symbol, &index);
+
+  auto it = symbolInfoMap.find(name.str());
+  if (it == symbolInfoMap.end()) {
+    auto error = formatv("referencing unbound symbol '{0}'", symbol);
+    PrintFatalError(loc, error);
+  }
+
+  return it->second.getAllRangeUse(name, index, fmt, separator);
+}
+
+void SymbolInfoMap::assignUniqueAlternativeNames() {
+  llvm::StringSet<> usedNames;
+
+  for (auto symbolInfoIt = symbolInfoMap.begin();
+       symbolInfoIt != symbolInfoMap.end();) {
+    auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
+    auto startRange = range.first;
+    auto endRange = range.second;
+
+    auto operandName = symbolInfoIt->first;
+    int startSearchIndex = 0;
+    for (++startRange; startRange != endRange; ++startRange) {
+      // Current operand name is not unique, find a unique one
+      // and set the alternative name.
+      for (int i = startSearchIndex;; ++i) {
+        std::string alternativeName = operandName + std::to_string(i);
+        if (!usedNames.contains(alternativeName) &&
+            symbolInfoMap.count(alternativeName) == 0) {
+          usedNames.insert(alternativeName);
+          startRange->second.alternativeName = alternativeName;
+          startSearchIndex = i + 1;
+
+          break;
+        }
+      }
+    }
+
+    symbolInfoIt = endRange;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern
+//===----------------------------------------------------------------------===//
+
+Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
+    : def(*def), recordOpMap(mapper) {}
+
+DagNode Pattern::getSourcePattern() const {
+  return DagNode(def.getValueAsDag("sourcePattern"));
+}
+
+int Pattern::getNumResultPatterns() const {
+  auto *results = def.getValueAsListInit("resultPatterns");
+  return results->size();
+}
+
+DagNode Pattern::getResultPattern(unsigned index) const {
+  auto *results = def.getValueAsListInit("resultPatterns");
+  return DagNode(cast<llvm::DagInit>(results->getElement(index)));
+}
+
+void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
+  LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
+  collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
+  LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
+
+  LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
+  infoMap.assignUniqueAlternativeNames();
+  LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
+}
+
+void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
+  LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
+  for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
+    auto pattern = getResultPattern(i);
+    collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
+  }
+  LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
+}
+
+const Operator &Pattern::getSourceRootOp() {
+  return getSourcePattern().getDialectOp(recordOpMap);
+}
+
+Operator &Pattern::getDialectOp(DagNode node) {
+  return node.getDialectOp(recordOpMap);
+}
+
+std::vector<AppliedConstraint> Pattern::getConstraints() const {
+  auto *listInit = def.getValueAsListInit("constraints");
+  std::vector<AppliedConstraint> ret;
+  ret.reserve(listInit->size());
+
+  for (auto it : *listInit) {
+    auto *dagInit = dyn_cast<llvm::DagInit>(it);
+    if (!dagInit)
+      PrintFatalError(&def, "all elements in Pattern multi-entity "
+                            "constraints should be DAG nodes");
+
+    std::vector<std::string> entities;
+    entities.reserve(dagInit->arg_size());
+    for (auto *argName : dagInit->getArgNames()) {
+      if (!argName) {
+        PrintFatalError(
+            &def,
+            "operands to additional constraints can only be symbol references");
+      }
+      entities.push_back(std::string(argName->getValue()));
+    }
+
+    ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
+                     dagInit->getNameStr(), std::move(entities));
+  }
+  return ret;
+}
+
+int Pattern::getBenefit() const {
+  // The initial benefit value is a heuristic with number of ops in the source
+  // pattern.
+  int initBenefit = getSourcePattern().getNumOps();
+  llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
+  if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
+    PrintFatalError(&def,
+                    "The 'addBenefit' takes and only takes one integer value");
+  }
+  return initBenefit + cast<llvm::IntInit>(delta->getArg(0))->getValue();
+}
+
+std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
+  std::vector<std::pair<StringRef, unsigned>> result;
+  result.reserve(def.getLoc().size());
+  for (auto loc : def.getLoc()) {
+    unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
+    assert(buf && "invalid source location");
+    result.emplace_back(
+        llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
+        llvm::SrcMgr.getLineAndColumn(loc, buf).first);
+  }
+  return result;
+}
+
+void Pattern::verifyBind(bool result, StringRef symbolName) {
+  if (!result) {
+    auto err = formatv("symbol '{0}' bound more than once", symbolName);
+    PrintFatalError(&def, err);
+  }
+}
+
+void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
+                                  bool isSrcPattern) {
+  auto treeName = tree.getSymbol();
+  auto numTreeArgs = tree.getNumArgs();
+
+  if (tree.isNativeCodeCall()) {
+    if (!treeName.empty()) {
+      if (!isSrcPattern) {
+        LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
+                                << treeName << '\n');
+        verifyBind(infoMap.bindValue(treeName), treeName);
+      } else {
+        PrintFatalError(&def,
+                        formatv("binding symbol '{0}' to NativecodeCall in "
+                                "MatchPattern is not supported",
+                                treeName));
+      }
+    }
+
+    for (int i = 0; i != numTreeArgs; ++i) {
+      if (auto treeArg = tree.getArgAsNestedDag(i)) {
+        // This DAG node argument is a DAG node itself. Go inside recursively.
+        collectBoundSymbols(treeArg, infoMap, isSrcPattern);
+        continue;
+      }
+
+      if (!isSrcPattern)
+        continue;
+
+      // We can only bind symbols to arguments in source pattern. Those
+      // symbols are referenced in result patterns.
+      auto treeArgName = tree.getArgName(i);
+
+      // `$_` is a special symbol meaning ignore the current argument.
+      if (!treeArgName.empty() && treeArgName != "_") {
+        DagLeaf leaf = tree.getArgAsLeaf(i);
+
+        // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
+        if (leaf.isUnspecified()) {
+          // This is case of $c, a Value without any constraints.
+          verifyBind(infoMap.bindValue(treeArgName), treeArgName);
+        } else {
+          auto constraint = leaf.getAsConstraint();
+          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
+                        leaf.isConstantAttr() ||
+                        constraint.getKind() == Constraint::Kind::CK_Attr;
+
+          if (isAttr) {
+            // This is case of $a, a binding to a certain attribute.
+            verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
+            continue;
+          }
+
+          // This is case of $b, a binding to a certain type.
+          verifyBind(infoMap.bindValue(treeArgName), treeArgName);
+        }
+      }
+    }
+
+    return;
+  }
+
+  if (tree.isOperation()) {
+    auto &op = getDialectOp(tree);
+    auto numOpArgs = op.getNumArgs();
+
+    // The pattern might have the last argument specifying the location.
+    bool hasLocDirective = false;
+    if (numTreeArgs != 0) {
+      if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
+        hasLocDirective = lastArg.isLocationDirective();
+    }
+
+    if (numOpArgs != numTreeArgs - hasLocDirective) {
+      auto err = formatv("op '{0}' argument number mismatch: "
+                         "{1} in pattern vs. {2} in definition",
+                         op.getOperationName(), numTreeArgs, numOpArgs);
+      PrintFatalError(&def, err);
+    }
+
+    // The name attached to the DAG node's operator is for representing the
+    // results generated from this op. It should be remembered as bound results.
+    if (!treeName.empty()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "found symbol bound to op result: " << treeName << '\n');
+      verifyBind(infoMap.bindOpResult(treeName, op), treeName);
+    }
+
+    for (int i = 0; i != numTreeArgs; ++i) {
+      if (auto treeArg = tree.getArgAsNestedDag(i)) {
+        // This DAG node argument is a DAG node itself. Go inside recursively.
+        collectBoundSymbols(treeArg, infoMap, isSrcPattern);
+        continue;
+      }
+
+      if (isSrcPattern) {
+        // We can only bind symbols to op arguments in source pattern. Those
+        // symbols are referenced in result patterns.
+        auto treeArgName = tree.getArgName(i);
+        // `$_` is a special symbol meaning ignore the current argument.
+        if (!treeArgName.empty() && treeArgName != "_") {
+          LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
+                                  << treeArgName << '\n');
+          verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName);
+        }
+      }
+    }
+    return;
+  }
+
+  if (!treeName.empty()) {
+    PrintFatalError(
+        &def, formatv("binding symbol '{0}' to non-operation/native code call "
+                      "unsupported right now",
+                      treeName));
+  }
+  return;
+}
diff --git a/tools/mlir-tblgen-builder/TableGen/Pattern.h b/tools/mlir-tblgen-builder/TableGen/Pattern.h
new file mode 100644
index 0000000..32d15cd
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/Pattern.h
@@ -0,0 +1,451 @@
+//===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Pattern wrapper class to simplify using TableGen Record defining a MLIR
+// Pattern.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_PATTERN_H_
+#define MLIR_TABLEGEN_PATTERN_H_
+
+#include "mlir/Support/LLVM.h"
+#include "Argument.h"
+#include "Operator.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
+
+#include <unordered_map>
+
+namespace llvm {
+class DagInit;
+class Init;
+class Record;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Mapping from TableGen Record to Operator wrapper object.
+//
+// We allocate each wrapper object in heap to make sure the pointer to it is
+// valid throughout the lifetime of this map. This is important because this map
+// is shared among multiple patterns to avoid creating the wrapper object for
+// the same op again and again. But this map will continuously grow.
+using RecordOperatorMap =
+    DenseMap<const llvm::Record *, std::unique_ptr<Operator>>;
+
+class Pattern;
+
+// Wrapper class providing helper methods for accessing TableGen DAG leaves
+// used inside Patterns. This class is lightweight and designed to be used like
+// values.
+//
+// A TableGen DAG construct is of the syntax
+//   `(operator, arg0, arg1, ...)`.
+//
+// This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
+// for handy helper methods. It only works on `arg*`s that are not nested DAG
+// constructs.
+class DagLeaf {
+public:
+  explicit DagLeaf(const llvm::Init *def) : def(def) {}
+
+  // Returns true if this DAG leaf is not specified in the pattern. That is, it
+  // places no further constraints/transforms and just carries over the original
+  // value.
+  bool isUnspecified() const;
+
+  // Returns true if this DAG leaf is matching an operand. That is, it specifies
+  // a type constraint.
+  bool isOperandMatcher() const;
+
+  // Returns true if this DAG leaf is matching an attribute. That is, it
+  // specifies an attribute constraint.
+  bool isAttrMatcher() const;
+
+  // Returns true if this DAG leaf is wrapping native code call.
+  bool isNativeCodeCall() const;
+
+  // Returns true if this DAG leaf is specifying a constant attribute.
+  bool isConstantAttr() const;
+
+  // Returns true if this DAG leaf is specifying an enum attribute case.
+  bool isEnumAttrCase() const;
+
+  // Returns true if this DAG leaf is specifying a string attribute.
+  bool isStringAttr() const;
+
+  // Returns this DAG leaf as a constraint. Asserts if fails.
+  Constraint getAsConstraint() const;
+
+  // Returns this DAG leaf as a constant attribute. Asserts if fails.
+  ConstantAttr getAsConstantAttr() const;
+
+  // Returns this DAG leaf as an enum attribute case.
+  // Precondition: isEnumAttrCase()
+  EnumAttrCase getAsEnumAttrCase() const;
+
+  // Returns the matching condition template inside this DAG leaf. Assumes the
+  // leaf is an operand/attribute matcher and asserts otherwise.
+  std::string getConditionTemplate() const;
+
+  // Returns the native code call template inside this DAG leaf.
+  // Precondition: isNativeCodeCall()
+  StringRef getNativeCodeTemplate() const;
+
+  // Returns the string associated with the leaf.
+  // Precondition: isStringAttr()
+  std::string getStringAttr() const;
+
+  void print(raw_ostream &os) const;
+
+private:
+  // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
+  // also a subclass of the given `superclass`.
+  bool isSubClassOf(StringRef superclass) const;
+
+  const llvm::Init *def;
+};
+
+// Wrapper class providing helper methods for accessing TableGen DAG constructs
+// used inside Patterns. This class is lightweight and designed to be used like
+// values.
+//
+// A TableGen DAG construct is of the syntax
+//   `(operator, arg0, arg1, ...)`.
+//
+// When used inside Patterns, `operator` corresponds to some dialect op, or
+// a known list of verbs that defines special transformation actions. This
+// `arg*` can be a nested DAG construct. This class provides getters to
+// retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
+// methods.
+//
+// A null DagNode contains a nullptr and converts to false implicitly.
+class DagNode {
+public:
+  explicit DagNode(const llvm::DagInit *node) : node(node) {}
+
+  // Implicit bool converter that returns true if this DagNode is not a null
+  // DagNode.
+  operator bool() const { return node != nullptr; }
+
+  // Returns the symbol bound to this DAG node.
+  StringRef getSymbol() const;
+
+  // Returns the operator wrapper object corresponding to the dialect op matched
+  // by this DAG. The operator wrapper will be queried from the given `mapper`
+  // and created in it if not existing.
+  Operator &getDialectOp(RecordOperatorMap *mapper) const;
+
+  // Returns the number of operations recursively involved in the DAG tree
+  // rooted from this node.
+  int getNumOps() const;
+
+  // Returns the number of immediate arguments to this DAG node.
+  int getNumArgs() const;
+
+  // Returns true if the `index`-th argument is a nested DAG construct.
+  bool isNestedDagArg(unsigned index) const;
+
+  // Gets the `index`-th argument as a nested DAG construct if possible. Returns
+  // null DagNode otherwise.
+  DagNode getArgAsNestedDag(unsigned index) const;
+
+  // Gets the `index`-th argument as a DAG leaf.
+  DagLeaf getArgAsLeaf(unsigned index) const;
+
+  // Returns the specified name of the `index`-th argument.
+  StringRef getArgName(unsigned index) const;
+
+  // Returns true if this DAG construct means to replace with an existing SSA
+  // value.
+  bool isReplaceWithValue() const;
+
+  // Returns whether this DAG represents the location of an op creation.
+  bool isLocationDirective() const;
+
+  // Returns true if this DAG node is wrapping native code call.
+  bool isNativeCodeCall() const;
+
+  // Returns true if this DAG node is an operation.
+  bool isOperation() const;
+
+  // Returns the native code call template inside this DAG node.
+  // Precondition: isNativeCodeCall()
+  StringRef getNativeCodeTemplate() const;
+
+  void print(raw_ostream &os) const;
+
+private:
+  const llvm::DagInit *node; // nullptr means null DagNode
+};
+
+// A class for maintaining information for symbols bound in patterns and
+// provides methods for resolving them according to specific use cases.
+//
+// Symbols can be bound to
+//
+// * Op arguments and op results in the source pattern and
+// * Op results in result patterns.
+//
+// Symbols can be referenced in result patterns and additional constraints to
+// the pattern.
+//
+// For example, in
+//
+// ```
+// def : Pattern<
+//   (SrcOp:$results1 $arg0, $arg1),
+//   [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
+// ```
+//
+// `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
+// `SrcOp`. `$results2` is bound to `ResOp1`. `$results2` is referenced to
+// build `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
+//
+// If a symbol binds to a multi-result op and it does not have the `__N`
+// suffix, the symbol is expanded to represent all results generated by the
+// multi-result op. If the symbol has a `__N` suffix, then it will expand to
+// only the N-th *static* result as declared in ODS, and that can still
+// correspond to multiple *dynamic* values if the N-th *static* result is
+// variadic.
+//
+// This class keeps track of such symbols and resolves them into their bound
+// values in a suitable way.
+class SymbolInfoMap {
+public:
+  explicit SymbolInfoMap(ArrayRef<llvm::SMLoc> loc) : loc(loc) {}
+
+  // Class for information regarding a symbol.
+  class SymbolInfo {
+  public:
+    // Returns a string for defining a variable named as `name` to store the
+    // value bound by this symbol.
+    std::string getVarDecl(StringRef name) const;
+
+    // Returns a variable name for the symbol named as `name`.
+    std::string getVarName(StringRef name) const;
+
+  private:
+    // Allow SymbolInfoMap to access private methods.
+    friend class SymbolInfoMap;
+
+    // What kind of entity this symbol represents:
+    // * Attr: op attribute
+    // * Operand: op operand
+    // * Result: op result
+    // * Value: a value not attached to an op (e.g., from NativeCodeCall)
+    enum class Kind : uint8_t { Attr, Operand, Result, Value };
+
+    // Creates a SymbolInfo instance. `index` is only used for `Attr` and
+    // `Operand`; the factories below pass `llvm::None` when there is no
+    // argument index (`Result`, `Value`, and op-less `Attr` symbols).
+    SymbolInfo(const Operator *op, Kind kind, Optional<int> index);
+
+    // Static methods for creating SymbolInfo.
+    static SymbolInfo getAttr(const Operator *op, int index) {
+      return SymbolInfo(op, Kind::Attr, index);
+    }
+    static SymbolInfo getAttr() {
+      return SymbolInfo(nullptr, Kind::Attr, llvm::None);
+    }
+    static SymbolInfo getOperand(const Operator *op, int index) {
+      return SymbolInfo(op, Kind::Operand, index);
+    }
+    static SymbolInfo getResult(const Operator *op) {
+      return SymbolInfo(op, Kind::Result, llvm::None);
+    }
+    static SymbolInfo getValue() {
+      return SymbolInfo(nullptr, Kind::Value, llvm::None);
+    }
+
+    // Returns the number of static values this symbol corresponds to.
+    // A static value is an operand/result declared in ODS. Normally a symbol
+    // only represents one static value, but symbols bound to op results can
+    // represent more than one if the op is a multi-result op.
+    int getStaticValueCount() const;
+
+    // Returns a string containing the C++ expression for referencing this
+    // symbol as a value (if this symbol represents one static value) or a value
+    // range (if this symbol represents multiple static values). `name` is the
+    // name of the C++ variable that this symbol is bound to. `index` should
+    // only be used for indexing results. `fmt` is used to format each value.
+    // `separator` is used to separate values if this is a value range.
+    std::string getValueAndRangeUse(StringRef name, int index, const char *fmt,
+                                    const char *separator) const;
+
+    // Returns a string containing the C++ expression for referencing this
+    // symbol as a value range regardless of how many static values this symbol
+    // represents. `name` is the name of the C++ variable that this symbol
+    // is bound to. `index` should only be used for indexing results. `fmt` is
+    // used to format each value. `separator` is used to separate values in the
+    // range.
+    std::string getAllRangeUse(StringRef name, int index, const char *fmt,
+                               const char *separator) const;
+
+    const Operator *op; // The op where the bound entity belongs
+    Kind kind;          // The kind of the bound entity
+    // The argument index (for `Attr` and `Operand` only)
+    Optional<int> argIndex;
+    // Alternative name for the symbol. It is used in case the name
+    // is not unique. Applicable for `Operand` only.
+    Optional<std::string> alternativeName;
+  };
+
+  // A multimap is used because the same symbol name may be bound more than
+  // once (see `count()` and `getRangeOfEqualElements()`).
+  using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
+
+  // Iterators for accessing all symbols.
+  using iterator = BaseT::iterator;
+  iterator begin() { return symbolInfoMap.begin(); }
+  iterator end() { return symbolInfoMap.end(); }
+
+  // Const iterators for accessing all symbols.
+  using const_iterator = BaseT::const_iterator;
+  const_iterator begin() const { return symbolInfoMap.begin(); }
+  const_iterator end() const { return symbolInfoMap.end(); }
+
+  // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
+  // Returns false if `symbol` is already bound and symbols are not operands.
+  bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
+
+  // Binds the given `symbol` to the results of the given `op`. Returns false
+  // if `symbol` is already bound.
+  bool bindOpResult(StringRef symbol, const Operator &op);
+
+  // Registers the given `symbol` as bound to a value. Returns false if `symbol`
+  // is already bound.
+  bool bindValue(StringRef symbol);
+
+  // Registers the given `symbol` as bound to an attr. Returns false if `symbol`
+  // is already bound.
+  bool bindAttr(StringRef symbol);
+
+  // Returns true if the given `symbol` is bound.
+  bool contains(StringRef symbol) const;
+
+  // Returns an iterator to the information of the given symbol named as `key`.
+  const_iterator find(StringRef key) const;
+
+  // Returns an iterator to the information of the given symbol named as `key`,
+  // with index `argIndex` for operator `op`.
+  const_iterator findBoundSymbol(StringRef key, const Operator &op,
+                                 int argIndex) const;
+
+  // Returns the bounds of a range that includes all the elements which
+  // bind to the `key`.
+  std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
+
+  // Returns number of times symbol named as `key` was used.
+  int count(StringRef key) const;
+
+  // Returns the number of static values the given `symbol` corresponds to.
+  // A static value is an operand/result declared in ODS. Normally a symbol only
+  // represents one static value, but symbols bound to op results can represent
+  // more than one if the op is a multi-result op.
+  int getStaticValueCount(StringRef symbol) const;
+
+  // Returns a string containing the C++ expression for referencing this
+  // symbol as a value (if this symbol represents one static value) or a value
+  // range (if this symbol represents multiple static values). `fmt` is used to
+  // format each value. `separator` is used to separate values if `symbol`
+  // represents a value range.
+  std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}",
+                                  const char *separator = ", ") const;
+
+  // Returns a string containing the C++ expression for referencing this
+  // symbol as a value range regardless of how many static values this symbol
+  // represents. `fmt` is used to format each value. `separator` is used to
+  // separate values in the range.
+  std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
+                             const char *separator = ", ") const;
+
+  // Assign alternative unique names to Operands that have equal names.
+  void assignUniqueAlternativeNames();
+
+  // Splits the given `symbol` into a value pack name and an index. Returns the
+  // value pack name and writes the index to `index` on success. Returns
+  // `symbol` itself if it does not contain an index.
+  //
+  // We can use `name__N` to access the `N`-th value in the value pack bound to
+  // `name`. `name` is typically the results of a multi-result op.
+  static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
+
+private:
+  BaseT symbolInfoMap;
+
+  // Pattern instantiation location. This is intended to be used as parameter
+  // to PrintFatalError() to report errors.
+  ArrayRef<llvm::SMLoc> loc;
+};
+
+// Wrapper class providing helper methods for accessing MLIR Pattern defined
+// in TableGen. This class should closely reflect what is defined as class
+// `Pattern` in TableGen. This class contains maps so it is not intended to be
+// used as values.
+class Pattern {
+public:
+  explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
+
+  // Returns the source pattern to match.
+  DagNode getSourcePattern() const;
+
+  // Returns the number of result patterns generated by applying this rewrite
+  // rule.
+  int getNumResultPatterns() const;
+
+  // Returns the DAG tree root node of the `index`-th result pattern.
+  DagNode getResultPattern(unsigned index) const;
+
+  // Collects all symbols bound in the source pattern into `infoMap`.
+  void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
+
+  // Collects all symbols bound in result patterns into `infoMap`.
+  void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
+
+  // Returns the op that the root node of the source pattern matches.
+  const Operator &getSourceRootOp();
+
+  // Returns the operator wrapper object corresponding to the given `node`'s DAG
+  // operator.
+  Operator &getDialectOp(DagNode node);
+
+  // Returns the constraints.
+  std::vector<AppliedConstraint> getConstraints() const;
+
+  // Returns the benefit score of the pattern.
+  int getBenefit() const;
+
+  using IdentifierLine = std::pair<StringRef, unsigned>;
+
+  // Returns the file location of the pattern (buffer identifier + line number
+  // pair).
+  std::vector<IdentifierLine> getLocation() const;
+
+private:
+  // Helper function to verify variable binding.
+  void verifyBind(bool result, StringRef symbolName);
+
+  // Recursively collects all bound symbols inside the DAG tree rooted
+  // at `tree` and updates the given `infoMap`.
+  void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
+                           bool isSrcPattern);
+
+  // The TableGen definition of this pattern.
+  const llvm::Record &def;
+
+  // All operators.
+  // TODO: we need a proper context manager, like MLIRContext, for managing the
+  // lifetime of shared entities.
+  RecordOperatorMap *recordOpMap;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_PATTERN_H_
diff --git a/tools/mlir-tblgen-builder/TableGen/Predicate.cpp b/tools/mlir-tblgen-builder/TableGen/Predicate.cpp
new file mode 100644
index 0000000..27f8e8e
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/Predicate.cpp
@@ -0,0 +1,376 @@
+//===- Predicate.cpp - Predicate class ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Wrapper around predicates defined in TableGen.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Predicate.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace tblgen;
+
+// Construct a Predicate from a record.
+Pred::Pred(const llvm::Record *record) : def(record) {
+  assert(def->isSubClassOf("Pred") &&
+         "must be a subclass of TableGen 'Pred' class");
+}
+
+// Construct a Predicate from an initializer.
+Pred::Pred(const llvm::Init *init) : def(nullptr) {
+  // Anything that is not a DefInit (including a null `init`) leaves `def`
+  // null, i.e. yields the null predicate.
+  if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
+    def = defInit->getDef();
+}
+
+// Returns the C++ condition string by dispatching on the TableGen class of
+// the underlying record. Must not be called on a null predicate.
+std::string Pred::getCondition() const {
+  assert(!isNull() && "null predicate does not have a condition");
+  // Static dispatch to subclasses.
+  if (def->isSubClassOf("CombinedPred"))
+    return static_cast<const CombinedPred *>(this)->getConditionImpl();
+  if (def->isSubClassOf("CPred"))
+    return static_cast<const CPred *>(this)->getConditionImpl();
+  llvm_unreachable("Pred::getCondition must be overridden in subclasses");
+}
+
+// Returns true if this is a non-null predicate whose record derives from
+// 'CombinedPred'.
+bool Pred::isCombined() const {
+  return def && def->isSubClassOf("CombinedPred");
+}
+
+// Returns the source locations of the underlying record. Must not be called
+// on a null predicate.
+ArrayRef<llvm::SMLoc> Pred::getLoc() const {
+  assert(!isNull() && "null predicate does not have a location");
+  return def->getLoc();
+}
+
+CPred::CPred(const llvm::Record *record) : Pred(record) {
+  assert(def->isSubClassOf("CPred") &&
+         "must be a subclass of Tablegen 'CPred' class");
+}
+
+// A null `def` is accepted here: it denotes the null predicate.
+CPred::CPred(const llvm::Init *init) : Pred(init) {
+  assert((!def || def->isSubClassOf("CPred")) &&
+         "must be a subclass of Tablegen 'CPred' class");
+}
+
+// Get condition of the C Predicate.
+std::string CPred::getConditionImpl() const {
+  assert(!isNull() && "null predicate does not have a condition");
+  return std::string(def->getValueAsString("predExpr"));
+}
+
+CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
+  assert(def->isSubClassOf("CombinedPred") &&
+         "must be a subclass of Tablegen 'CombinedPred' class");
+}
+
+// A null `def` is accepted here: it denotes the null predicate.
+CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
+  assert((!def || def->isSubClassOf("CombinedPred")) &&
+         "must be a subclass of Tablegen 'CombinedPred' class");
+}
+
+// Returns the record stored in the 'kind' field, which names the combiner.
+const llvm::Record *CombinedPred::getCombinerDef() const {
+  assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
+  return def->getValueAsDef("kind");
+}
+
+// Returns the records listed in the 'children' field.
+const std::vector<llvm::Record *> CombinedPred::getChildren() const {
+  assert(def->getValue("children") &&
+         "CombinedPred must have a value 'children'");
+  return def->getValueAsListOfDefs("children");
+}
+
+namespace {
+// Kinds of nodes in a logical predicate tree.
+enum class PredCombinerKind {
+  Leaf,
+  And,
+  Or,
+  Not,
+  SubstLeaves,
+  Concat,
+  // Special kinds that are used in simplification.
+  False,
+  True
+};
+
+// A node in a logical predicate tree.
+struct PredNode {
+  // Scalars get defaults so a freshly constructed node never holds
+  // indeterminate values.
+  PredCombinerKind kind = PredCombinerKind::Leaf;
+  const Pred *predicate = nullptr;
+  SmallVector<PredNode *, 4> children;
+  std::string expr;
+
+  // Prefix and suffix are used by ConcatPred.
+  std::string prefix;
+  std::string suffix;
+};
+} // end anonymous namespace
+
+// Get a predicate tree node kind based on the kind used in the predicate
+// TableGen record. Non-combined predicates are leaves.
+static PredCombinerKind getPredCombinerKind(const Pred &pred) {
+  if (!pred.isCombined())
+    return PredCombinerKind::Leaf;
+
+  const auto &combinedPred = static_cast<const CombinedPred &>(pred);
+  // NOTE(review): there is no `.Default(...)`, so the combiner record is
+  // assumed to carry one of the five names below.
+  return StringSwitch<PredCombinerKind>(
+             combinedPred.getCombinerDef()->getName())
+      .Case("PredCombinerAnd", PredCombinerKind::And)
+      .Case("PredCombinerOr", PredCombinerKind::Or)
+      .Case("PredCombinerNot", PredCombinerKind::Not)
+      .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
+      .Case("PredCombinerConcat", PredCombinerKind::Concat);
+}
+
+namespace {
+// Substitution: a (pattern, replacement) pair.
+using Subst = std::pair<StringRef, StringRef>;
+} // end anonymous namespace
+
+/// Perform the given substitutions on 'str' in-place.
+///
+/// Every occurrence of a substitution's pattern is replaced. Text inserted by
+/// a replacement is not rescanned for the same pattern. Substitutions with an
+/// empty pattern are skipped.
+static void performSubstitutions(std::string &str,
+                                 ArrayRef<Subst> substitutions) {
+  // Apply all parent substitutions from innermost to outermost.
+  for (const auto &subst : llvm::reverse(substitutions)) {
+    // An empty pattern matches at every position, so the scan below would
+    // never terminate; there is nothing meaningful to replace.
+    if (subst.first.empty())
+      continue;
+
+    // Materialize both strings once instead of on every iteration.
+    const std::string pattern = subst.first.str();
+    const std::string replacement = subst.second.str();
+
+    // After each replacement, resume the search just past the inserted text,
+    // which itself may contain the pattern.
+    for (auto pos = str.find(pattern); pos != std::string::npos;
+         pos = str.find(pattern, pos + replacement.size()))
+      str.replace(pos, pattern.size(), replacement);
+  }
+}
+
+// Build the predicate tree starting from the top-level predicate, which may
+// have children, and perform leaf substitutions inplace.
+// Note that after substitution, nodes are still pointing to the original
+// TableGen record.
+// All nodes are created within "allocator".
+//
+// `substitutions` holds the leaf substitutions collected from the enclosing
+// SubstLeaves predicates; they are applied to leaf expressions and to the
+// prefix/suffix of Concat nodes.
+//
+// NOTE(review): child nodes are built from a temporary `Pred(record)`, so the
+// `predicate` pointer stored in a child refers to an object that no longer
+// exists once the recursive call returns. It looks safe to compare but not to
+// dereference -- confirm against the users of `PredNode::predicate`.
+static PredNode *
+buildPredicateTree(const Pred &root,
+                   llvm::SpecificBumpPtrAllocator<PredNode> &allocator,
+                   ArrayRef<Subst> substitutions) {
+  auto *rootNode = allocator.Allocate();
+  new (rootNode) PredNode;
+  rootNode->kind = getPredCombinerKind(root);
+  rootNode->predicate = &root;
+  if (!root.isCombined()) {
+    rootNode->expr = root.getCondition();
+    performSubstitutions(rootNode->expr, substitutions);
+    return rootNode;
+  }
+
+  // If the current combined predicate is a leaf substitution, append it to the
+  // list before continuing.
+  auto allSubstitutions = llvm::to_vector<4>(substitutions);
+  if (rootNode->kind == PredCombinerKind::SubstLeaves) {
+    const auto &substPred = static_cast<const SubstLeavesPred &>(root);
+    allSubstitutions.push_back(
+        {substPred.getPattern(), substPred.getReplacement()});
+
+    // If the current predicate is a ConcatPred, record the prefix and suffix.
+  } else if (rootNode->kind == PredCombinerKind::Concat) {
+    const auto &concatPred = static_cast<const ConcatPred &>(root);
+    rootNode->prefix = std::string(concatPred.getPrefix());
+    performSubstitutions(rootNode->prefix, substitutions);
+    rootNode->suffix = std::string(concatPred.getSuffix());
+    performSubstitutions(rootNode->suffix, substitutions);
+  }
+
+  // Build child subtrees. Bind by reference: a plain `auto` would copy the
+  // predicate for no benefit.
+  const auto &combined = static_cast<const CombinedPred &>(root);
+  for (const auto *record : combined.getChildren()) {
+    auto childTree =
+        buildPredicateTree(Pred(record), allocator, allSubstitutions);
+    rootNode->children.push_back(childTree);
+  }
+  return rootNode;
+}
+
+// Simplify a predicate tree rooted at "node" using the predicates that are
+// known to be true(false). For AND(OR) combined predicates, if any of the
+// children is known to be false(true), the result is also false(true).
+// Furthermore, for AND(OR) combined predicates, children that are known to be
+// true(false) don't have to be checked dynamically.
+static PredNode *
+propagateGroundTruth(PredNode *node,
+                     const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
+                     const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
+  // If the current predicate is known to be true or false, change the kind of
+  // the node and return immediately.
+  if (knownTruePreds.count(node->predicate) != 0) {
+    node->kind = PredCombinerKind::True;
+    node->children.clear();
+    return node;
+  }
+  if (knownFalsePreds.count(node->predicate) != 0) {
+    node->kind = PredCombinerKind::False;
+    node->children.clear();
+    return node;
+  }
+
+  // If the current node is a substitution, stop recursion now.
+  // The expressions in the leaves below this node were rewritten, but the nodes
+  // still point to the original predicate records. While the original
+  // predicate may be known to be true or false, it is not necessarily the case
+  // after rewriting.
+  // TODO: we can support ground truth for rewritten
+  // predicates by either (a) having our own unique'ing of the predicates
+  // instead of relying on TableGen record pointers or (b) taking ground truth
+  // values optionally prefixed with a list of substitutions to apply, e.g.
+  // "predX is true by itself as well as predSubY leaf substitution had been
+  // applied to it".
+  if (node->kind == PredCombinerKind::SubstLeaves) {
+    return node;
+  }
+
+  // Otherwise, look at child nodes.
+
+  // Move child nodes into some local variable so that they can be optimized
+  // separately and re-added if necessary.
+  llvm::SmallVector<PredNode *, 4> children;
+  std::swap(node->children, children);
+
+  for (auto &child : children) {
+    // First, simplify the child. This maintains the predicate as it was.
+    auto simplifiedChild =
+        propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
+
+    // Just add the child if we don't know how to simplify the current node.
+    if (node->kind != PredCombinerKind::And &&
+        node->kind != PredCombinerKind::Or) {
+      node->children.push_back(simplifiedChild);
+      continue;
+    }
+
+    // Second, based on the type define which known values of child predicates
+    // immediately collapse this predicate to a known value, and which others
+    // may be safely ignored.
+    //   OR(..., True, ...) = True
+    //   OR(..., False, ...) = OR(..., ...)
+    //   AND(..., False, ...) = False
+    //   AND(..., True, ...) = AND(..., ...)
+    auto collapseKind = node->kind == PredCombinerKind::And
+                            ? PredCombinerKind::False
+                            : PredCombinerKind::True;
+    auto eraseKind = node->kind == PredCombinerKind::And
+                         ? PredCombinerKind::True
+                         : PredCombinerKind::False;
+    const auto &collapseList =
+        node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
+    const auto &eraseList =
+        node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
+    if (simplifiedChild->kind == collapseKind ||
+        collapseList.count(simplifiedChild->predicate) != 0) {
+      node->kind = collapseKind;
+      node->children.clear();
+      return node;
+    } else if (simplifiedChild->kind == eraseKind ||
+               eraseList.count(simplifiedChild->predicate) != 0) {
+      continue;
+    }
+    node->children.push_back(simplifiedChild);
+  }
+  return node;
+}
+
+// Combine a list of predicate expressions using a binary combiner. If a list
+// is empty, return "init". A single expression is returned unchanged;
+// otherwise every expression is parenthesized and the results are joined with
+// " <combiner> ".
+static std::string combineBinary(ArrayRef<std::string> children,
+                                 const std::string &combiner,
+                                 const std::string &init) {
+  if (children.empty())
+    return init;
+
+  auto size = children.size();
+  if (size == 1)
+    return children.front();
+
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  os << '(' << children.front() << ')';
+  for (unsigned i = 1; i < size; ++i) {
+    os << ' ' << combiner << " (" << children[i] << ')';
+  }
+  return os.str();
+}
+
+// Prepend negation to the only condition in the predicate expression list.
+static std::string combineNot(ArrayRef<std::string> children) {
+  assert(children.size() == 1 && "expected exactly one child predicate of Neg");
+  return (Twine("!(") + children.front() + Twine(')')).str();
+}
+
+// Recursively traverse the predicate tree in depth-first post-order and build
+// the final expression.
+static std::string getCombinedCondition(const PredNode &root) {
+  // Immediately return for non-combiner predicates that don't have children.
+  if (root.kind == PredCombinerKind::Leaf)
+    return root.expr;
+  if (root.kind == PredCombinerKind::True)
+    return "true";
+  if (root.kind == PredCombinerKind::False)
+    return "false";
+
+  // Recurse into children.
+  llvm::SmallVector<std::string, 4> childExpressions;
+  childExpressions.reserve(root.children.size());
+  for (const auto &child : root.children)
+    childExpressions.push_back(getCombinedCondition(*child));
+
+  // Combine the expressions based on the predicate node kind.
+  if (root.kind == PredCombinerKind::And)
+    return combineBinary(childExpressions, "&&", "true");
+  if (root.kind == PredCombinerKind::Or)
+    return combineBinary(childExpressions, "||", "false");
+  if (root.kind == PredCombinerKind::Not)
+    return combineNot(childExpressions);
+  if (root.kind == PredCombinerKind::Concat) {
+    assert(childExpressions.size() == 1 &&
+           "ConcatPred should only have one child");
+    return root.prefix + childExpressions.front() + root.suffix;
+  }
+
+  // Substitutions were applied before so just ignore them.
+  if (root.kind == PredCombinerKind::SubstLeaves) {
+    assert(childExpressions.size() == 1 &&
+           "substitution predicate must have one child");
+    return childExpressions[0];
+  }
+
+  llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
+}
+
+// Builds the predicate tree for this combined predicate, simplifies it (with
+// empty known-true/known-false sets) and renders it as one expression.
+std::string CombinedPred::getConditionImpl() const {
+  // All tree nodes live in `allocator` for the duration of this call.
+  llvm::SpecificBumpPtrAllocator<PredNode> allocator;
+  auto predicateTree = buildPredicateTree(*this, allocator, {});
+  predicateTree =
+      propagateGroundTruth(predicateTree,
+                           /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
+                           /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
+
+  return getCombinedCondition(*predicateTree);
+}
+
+// Returns the record's 'pattern' field.
+StringRef SubstLeavesPred::getPattern() const {
+  return def->getValueAsString("pattern");
+}
+
+// Returns the record's 'replacement' field.
+StringRef SubstLeavesPred::getReplacement() const {
+  return def->getValueAsString("replacement");
+}
+
+// Returns the record's 'prefix' field.
+StringRef ConcatPred::getPrefix() const {
+  return def->getValueAsString("prefix");
+}
+
+// Returns the record's 'suffix' field.
+StringRef ConcatPred::getSuffix() const {
+  return def->getValueAsString("suffix");
+}
diff --git a/tools/mlir-tblgen-builder/TableGen/Predicate.h b/tools/mlir-tblgen-builder/TableGen/Predicate.h
new file mode 100644
index 0000000..7caea7c
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/Predicate.h
@@ -0,0 +1,119 @@
+//===- Predicate.h - Predicate class ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Wrapper around predicates defined in TableGen.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_PREDICATE_H_
+#define MLIR_TABLEGEN_PREDICATE_H_
+
+#include "mlir/Support/LLVM.h"
+
+#include <string>
+#include <vector>
+
+namespace llvm {
+class Init;
+class ListInit;
+class Record;
+class SMLoc;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// A logical predicate. This class must closely follow the definition of
+// TableGen class 'Pred'.
+class Pred {
+public:
+  // Constructs the null Predicate (e.g., always true).
+  explicit Pred() : def(nullptr) {}
+  // Construct a Predicate from a record.
+  explicit Pred(const llvm::Record *record);
+  // Construct a Predicate from an initializer.
+  explicit Pred(const llvm::Init *init);
+
+  // Check if the predicate is defined. Callers may use this to interpret the
+  // missing predicate as either true (e.g. in filters) or false (e.g. in
+  // precondition verification).
+  bool isNull() const { return def == nullptr; }
+
+  // Get the predicate condition. This may dispatch to getConditionImpl() of
+  // the underlying predicate type.
+  std::string getCondition() const;
+
+  // Whether the predicate is a combination of other predicates, i.e. a
+  // record of type CombinedPred.
+  bool isCombined() const;
+
+  // Records are pointer-comparable.
+  bool operator==(const Pred &other) const { return def == other.def; }
+
+  // Get the location of the predicate.
+  ArrayRef<llvm::SMLoc> getLoc() const;
+
+protected:
+  // The TableGen definition of this predicate.
+  const llvm::Record *def;
+};
+
+// A logical predicate wrapping a C expression. This class must closely follow
+// the definition of TableGen class 'CPred'.
+class CPred : public Pred {
+public:
+  // Construct a CPred from a record.
+  explicit CPred(const llvm::Record *record);
+  // Construct a CPred from an initializer.
+  explicit CPred(const llvm::Init *init);
+
+  // Get the predicate condition.
+  std::string getConditionImpl() const;
+};
+
+// A logical predicate that is a combination of other predicates. This class
+// must closely follow the definition of TableGen class 'CombinedPred'.
+class CombinedPred : public Pred {
+public:
+  // Construct a CombinedPred from a record.
+  explicit CombinedPred(const llvm::Record *record);
+  // Construct a CombinedPred from an initializer.
+  explicit CombinedPred(const llvm::Init *init);
+
+  // Get the predicate condition.
+  std::string getConditionImpl() const;
+
+  // Get the definition of the combiner used in this predicate.
+  const llvm::Record *getCombinerDef() const;
+
+  // Get the predicates that are combined by this predicate.
+  const std::vector<llvm::Record *> getChildren() const;
+};
+
+// A combined predicate that requires all child predicates of 'CPred' type to
+// have their expression rewritten with a simple string substitution rule.
+class SubstLeavesPred : public CombinedPred {
+public:
+  // Get the replacement pattern.
+  StringRef getPattern() const;
+  // Get the string used to replace the pattern.
+  StringRef getReplacement() const;
+};
+
+// A combined predicate that prepends a prefix and appends a suffix to the
+// predicate string composed from a child predicate.
+class ConcatPred : public CombinedPred {
+public:
+  // Get the string prepended to the child's predicate string.
+  StringRef getPrefix() const;
+  // Get the string appended to the child's predicate string.
+  StringRef getSuffix() const;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_PREDICATE_H_
diff --git a/tools/mlir-tblgen-builder/TableGen/Region.cpp b/tools/mlir-tblgen-builder/TableGen/Region.cpp
new file mode 100644
index 0000000..cbe7204
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/Region.cpp
@@ -0,0 +1,20 @@
+//===- Region.cpp - Region class ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Region wrapper to simplify using TableGen Record defining a MLIR Region.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Region.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+// Returns true if this region is variadic.
+bool Region::isVariadic() const { return def->isSubClassOf("VariadicRegion"); }
diff --git a/tools/mlir-tblgen-builder/TableGen/Region.h b/tools/mlir-tblgen-builder/TableGen/Region.h
new file mode 100644
index 0000000..523a72c
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/Region.h
@@ -0,0 +1,42 @@
+//===- Region.h - TableGen region definitions -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_REGION_H_
+#define MLIR_TABLEGEN_REGION_H_
+
+#include "mlir/Support/LLVM.h"
+#include "Constraint.h"
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class providing helper methods for accessing Region defined in
+// TableGen.
+class Region : public Constraint {
+public:
+  using Constraint::Constraint;
+
+  static bool classof(const Constraint *c) { return c->getKind() == CK_Region; }
+
+  // Returns true if this region is variadic.
+  bool isVariadic() const;
+};
+
+// A struct bundling a region's constraint and its name.
+struct NamedRegion {
+  // Returns true if this region is variadic.
+  bool isVariadic() const { return constraint.isVariadic(); }
+
+  StringRef name;
+  Region constraint;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_REGION_H_
diff --git a/tools/mlir-tblgen-builder/TableGen/SideEffects.cpp b/tools/mlir-tblgen-builder/TableGen/SideEffects.cpp
new file mode 100644
index 0000000..929e6e5
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/SideEffects.cpp
@@ -0,0 +1,58 @@
+//===- SideEffects.cpp - SideEffect classes -------------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "SideEffects.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// SideEffect
+//===----------------------------------------------------------------------===//
+
+// Returns the record's 'effect' field.
+StringRef SideEffect::getName() const {
+  return def->getValueAsString("effect");
+}
+
+// Returns the record's 'baseEffectName' field.
+StringRef SideEffect::getBaseEffectName() const {
+  return def->getValueAsString("baseEffectName");
+}
+
+// Returns the 'interfaceTrait' field, qualified with 'cppNamespace' (joined
+// by "::") when that namespace is non-empty.
+std::string SideEffect::getInterfaceTrait() const {
+  StringRef trait = def->getValueAsString("interfaceTrait");
+  StringRef cppNamespace = def->getValueAsString("cppNamespace");
+  return cppNamespace.empty() ? trait.str()
+                              : (cppNamespace + "::" + trait).str();
+}
+
+// Returns the record's 'resource' field.
+StringRef SideEffect::getResource() const {
+  return def->getValueAsString("resource");
+}
+
+bool SideEffect::classof(const Operator::VariableDecorator *var) {
+  return var->getDef().isSubClassOf("SideEffect");
+}
+
+//===----------------------------------------------------------------------===//
+// SideEffectTrait
+//===----------------------------------------------------------------------===//
+
+// Returns the decorators listed in the record's 'effects' field.
+Operator::var_decorator_range SideEffectTrait::getEffects() const {
+  // The result is dereferenced unconditionally, so use the asserting `cast`
+  // rather than `dyn_cast`, which would yield null on a type mismatch.
+  auto *listInit = cast<llvm::ListInit>(def->getValueInit("effects"));
+  return {listInit->begin(), listInit->end()};
+}
+
+// Returns the record's 'baseEffectName' field.
+StringRef SideEffectTrait::getBaseEffectName() const {
+  return def->getValueAsString("baseEffectName");
+}
+
+bool SideEffectTrait::classof(const Trait *t) {
+  return t->getDef().isSubClassOf("SideEffectsTraitBase");
+}
diff --git a/tools/mlir-tblgen-builder/TableGen/SideEffects.h b/tools/mlir-tblgen-builder/TableGen/SideEffects.h
new file mode 100644
index 0000000..7449646
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/SideEffects.h
@@ -0,0 +1,58 @@
+//===- SideEffects.h - Side Effects classes ---------------------*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Wrapper around side effect related classes defined in TableGen.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_SIDEEFFECTS_H_
+#define MLIR_TABLEGEN_SIDEEFFECTS_H_
+
+#include "mlir/Support/LLVM.h"
+#include "Operator.h"
+
+namespace mlir {
+namespace tblgen {
+
+// This class represents a specific instance of an effect that is being
+// exhibited.
+class SideEffect : public Operator::VariableDecorator {
+public:
+  // Name of the C++ effect.
+  StringRef getName() const;
+
+  // Name of the base C++ effect.
+  StringRef getBaseEffectName() const;
+
+  // Name of the Interface that the effect belongs to.
+  std::string getInterfaceTrait() const;
+
+  // Name of the resource class.
+  StringRef getResource() const;
+
+  // LLVM-style RTTI hook.
+  static bool classof(const Operator::VariableDecorator *var);
+};
+
+// An instance of a side effect interface applied to an operation: a wrapper
+// around an OpInterfaceTrait that additionally exposes the effects that are
+// applied.
+class SideEffectTrait : public InterfaceTrait {
+public:
+  // Effects that are attached to the side effect interface.
+  Operator::var_decorator_range getEffects() const;
+
+  // Name of the base C++ effect.
+  StringRef getBaseEffectName() const;
+
+  // LLVM-style RTTI hook.
+  static bool classof(const Trait *t);
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_SIDEEFFECTS_H_
diff --git a/tools/mlir-tblgen-builder/TableGen/Successor.cpp b/tools/mlir-tblgen-builder/TableGen/Successor.cpp
new file mode 100644
index 0000000..1292c52
--- /dev/null
+++ b/tools/mlir-tblgen-builder/TableGen/Successor.cpp
@@ -0,0 +1,24 @@
+//===- Successor.cpp - Successor class ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Successor wrapper to simplify using TableGen Record defining a MLIR
+// Successor.
+// +//===----------------------------------------------------------------------===// + +#include "Successor.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +// Returns true if this successor is variadic. +bool Successor::isVariadic() const { + return def->isSubClassOf("VariadicSuccessor"); +} diff --git a/tools/mlir-tblgen-builder/TableGen/Successor.h b/tools/mlir-tblgen-builder/TableGen/Successor.h new file mode 100644 index 0000000..3821fe1 --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Successor.h @@ -0,0 +1,44 @@ +//===- Successor.h - TableGen successor definitions -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_SUCCESSOR_H_ +#define MLIR_TABLEGEN_SUCCESSOR_H_ + +#include "mlir/Support/LLVM.h" +#include "Constraint.h" + +namespace mlir { +namespace tblgen { + +// Wrapper class providing helper methods for accessing Successor defined in +// TableGen. +class Successor : public Constraint { +public: + using Constraint::Constraint; + + static bool classof(const Constraint *c) { + return c->getKind() == CK_Successor; + } + + // Returns true if this successor is variadic. + bool isVariadic() const; +}; + +// A struct bundling a successor's constraint and its name. +struct NamedSuccessor { + // Returns true if this successor is variadic. 
+ bool isVariadic() const { return constraint.isVariadic(); } + + StringRef name; + Successor constraint; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_SUCCESSOR_H_ diff --git a/tools/mlir-tblgen-builder/TableGen/Trait.cpp b/tools/mlir-tblgen-builder/TableGen/Trait.cpp new file mode 100644 index 0000000..baf700b --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Trait.cpp @@ -0,0 +1,93 @@ +//===- Trait.cpp ----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Trait wrapper to simplify using TableGen Record defining a MLIR Trait. +// +//===----------------------------------------------------------------------===// + +#include "Trait.h" +#include "Interfaces.h" +#include "Predicate.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +//===----------------------------------------------------------------------===// +// Trait +//===----------------------------------------------------------------------===// + +Trait Trait::create(const llvm::Init *init) { + auto def = cast(init)->getDef(); + if (def->isSubClassOf("PredTrait")) + return Trait(Kind::Pred, def); + if (def->isSubClassOf("GenInternalTrait")) + return Trait(Kind::Internal, def); + if (def->isSubClassOf("InterfaceTrait")) + return Trait(Kind::Interface, def); + assert(def->isSubClassOf("NativeTrait")); + return Trait(Kind::Native, def); +} + +Trait::Trait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {} + +//===----------------------------------------------------------------------===// +// 
NativeTrait +//===----------------------------------------------------------------------===// + +std::string NativeTrait::getFullyQualifiedTraitName() const { + llvm::StringRef trait = def->getValueAsString("trait"); + llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace"); + return cppNamespace.empty() ? trait.str() + : (cppNamespace + "::" + trait).str(); +} + +//===----------------------------------------------------------------------===// +// InternalTrait +//===----------------------------------------------------------------------===// + +llvm::StringRef InternalTrait::getFullyQualifiedTraitName() const { + return def->getValueAsString("trait"); +} + +//===----------------------------------------------------------------------===// +// PredTrait +//===----------------------------------------------------------------------===// + +std::string PredTrait::getPredTemplate() const { + auto pred = Pred(def->getValueInit("predicate")); + return pred.getCondition(); +} + +llvm::StringRef PredTrait::getSummary() const { + return def->getValueAsString("summary"); +} + +//===----------------------------------------------------------------------===// +// InterfaceTrait +//===----------------------------------------------------------------------===// + +Interface InterfaceTrait::getInterface() const { return Interface(def); } + +std::string InterfaceTrait::getFullyQualifiedTraitName() const { + llvm::StringRef trait = def->getValueAsString("trait"); + llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace"); + return cppNamespace.empty() ? 
trait.str() + : (cppNamespace + "::" + trait).str(); +} + +bool InterfaceTrait::shouldDeclareMethods() const { + return def->isSubClassOf("DeclareInterfaceMethods"); +} + +std::vector InterfaceTrait::getAlwaysDeclaredMethods() const { + return def->getValueAsListOfStrings("alwaysOverriddenMethods"); +} diff --git a/tools/mlir-tblgen-builder/TableGen/Trait.h b/tools/mlir-tblgen-builder/TableGen/Trait.h new file mode 100644 index 0000000..52d056d --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Trait.h @@ -0,0 +1,116 @@ +//===- Trait.h - Trait wrapper class ----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Trait wrapper to simplify using TableGen Record defining an MLIR Trait. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_TRAIT_H_ +#define MLIR_TABLEGEN_TRAIT_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" +#include + +namespace llvm { +class Init; +class Record; +} // end namespace llvm + +namespace mlir { +namespace tblgen { + +class Interface; + +// Wrapper class with helper methods for accessing Trait constraints defined in +// TableGen. +class Trait { +public: + // Discriminator for kinds of traits. + enum class Kind { + // Trait corresponding to C++ class. + Native, + // Trait corresponding to a predicate. + Pred, + // Trait controlling definition generator internals. + Internal, + // Trait corresponding to an Interface. + Interface + }; + + explicit Trait(Kind kind, const llvm::Record *def); + + // Returns an Trait corresponding to the init provided. 
+ static Trait create(const llvm::Init *init); + + Kind getKind() const { return kind; } + + // Returns the Tablegen definition this operator was constructed from. + const llvm::Record &getDef() const { return *def; } + +protected: + // The TableGen definition of this trait. + const llvm::Record *def; + Kind kind; +}; + +// Trait corresponding to a native C++ Trait. +class NativeTrait : public Trait { +public: + // Returns the trait corresponding to a C++ trait class. + std::string getFullyQualifiedTraitName() const; + + static bool classof(const Trait *t) { return t->getKind() == Kind::Native; } +}; + +// Trait corresponding to a predicate on the operation. +class PredTrait : public Trait { +public: + // Returns the template for constructing the predicate. + std::string getPredTemplate() const; + + // Returns the description of what the predicate is verifying. + StringRef getSummary() const; + + static bool classof(const Trait *t) { return t->getKind() == Kind::Pred; } +}; + +// Trait controlling op definition generator internals. +class InternalTrait : public Trait { +public: + // Returns the trait controlling op definition generator internals. + StringRef getFullyQualifiedTraitName() const; + + static bool classof(const Trait *t) { return t->getKind() == Kind::Internal; } +}; + +// Trait corresponding to an OpInterface on the operation. +class InterfaceTrait : public Trait { +public: + // Returns interface corresponding to the trait. + Interface getInterface() const; + + // Returns the trait corresponding to a C++ trait class. + std::string getFullyQualifiedTraitName() const; + + static bool classof(const Trait *t) { + return t->getKind() == Kind::Interface; + } + + // Whether the declaration of methods for this trait should be emitted. + bool shouldDeclareMethods() const; + + // Returns the methods that should always be declared if this interface is + // emitting declarations. 
+ std::vector getAlwaysDeclaredMethods() const; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_TRAIT_H_ diff --git a/tools/mlir-tblgen-builder/TableGen/Type.cpp b/tools/mlir-tblgen-builder/TableGen/Type.cpp new file mode 100644 index 0000000..9b2b584 --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Type.cpp @@ -0,0 +1,82 @@ +//===- Type.cpp - Type class ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Type wrapper to simplify using TableGen Record defining a MLIR Type. +// +//===----------------------------------------------------------------------===// + +#include "Type.h" +#include "Dialect.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +TypeConstraint::TypeConstraint(const llvm::Record *record) + : Constraint(Constraint::CK_Type, record) { + assert(def->isSubClassOf("TypeConstraint") && + "must be subclass of TableGen 'TypeConstraint' class"); +} + +TypeConstraint::TypeConstraint(const llvm::DefInit *init) + : TypeConstraint(init->getDef()) {} + +bool TypeConstraint::isOptional() const { + return def->isSubClassOf("Optional"); +} + +bool TypeConstraint::isVariadic() const { + return def->isSubClassOf("Variadic"); +} + +// Returns the builder call for this constraint if this is a buildable type, +// returns None otherwise. +Optional TypeConstraint::getBuilderCall() const { + const llvm::Record *baseType = def; + if (isVariableLength()) + baseType = baseType->getValueAsDef("baseType"); + + // Check to see if this type constraint has a builder call. 
+ const llvm::RecordVal *builderCall = baseType->getValue("builderCall"); + if (!builderCall || !builderCall->getValue()) + return llvm::None; + return TypeSwitch>(builderCall->getValue()) + .Case([&](auto *init) { + StringRef value = init->getValue(); + return value.empty() ? Optional() : value; + }) + .Default([](auto *) { return llvm::None; }); +} + +// Return the C++ class name for this type (which may just be ::mlir::Type). +std::string TypeConstraint::getCPPClassName() const { + StringRef className = def->getValueAsString("cppClassName"); + + // If the class name is already namespace resolved, use it. + if (className.contains("::")) + return className.str(); + + // Otherwise, check to see if there is a namespace from a dialect to prepend. + if (const llvm::RecordVal *value = def->getValue("dialect")) { + Dialect dialect(cast(value->getValue())->getDef()); + return (dialect.getCppNamespace() + "::" + className).str(); + } + return className.str(); +} + +Type::Type(const llvm::Record *record) : TypeConstraint(record) {} + +StringRef Type::getDescription() const { + return def->getValueAsString("description"); +} + +Dialect Type::getDialect() const { + return Dialect(def->getValueAsDef("dialect")); +} diff --git a/tools/mlir-tblgen-builder/TableGen/Type.h b/tools/mlir-tblgen-builder/TableGen/Type.h new file mode 100644 index 0000000..9dce16b --- /dev/null +++ b/tools/mlir-tblgen-builder/TableGen/Type.h @@ -0,0 +1,70 @@ +//===- Type.h - Type class --------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Type wrapper to simplify using TableGen Record defining a MLIR Type. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_TYPE_H_ +#define MLIR_TABLEGEN_TYPE_H_ + +#include "mlir/Support/LLVM.h" +#include "Constraint.h" +#include "Dialect.h" + +namespace llvm { +class DefInit; +class Record; +} // end namespace llvm + +namespace mlir { +namespace tblgen { + +// Wrapper class with helper methods for accessing Type constraints defined in +// TableGen. +class TypeConstraint : public Constraint { +public: + explicit TypeConstraint(const llvm::Record *record); + explicit TypeConstraint(const llvm::DefInit *init); + + static bool classof(const Constraint *c) { return c->getKind() == CK_Type; } + + // Returns true if this is an optional type constraint. + bool isOptional() const; + + // Returns true if this is a variadic type constraint. + bool isVariadic() const; + + // Returns true if this is a variable length type constraint. This is either + // variadic or optional. + bool isVariableLength() const { return isOptional() || isVariadic(); } + + // Returns the builder call for this constraint if this is a buildable type, + // returns None otherwise. + Optional getBuilderCall() const; + + // Return the C++ class name for this type (which may just be ::mlir::Type). + std::string getCPPClassName() const; +}; + +// Wrapper class with helper methods for accessing Types defined in TableGen. +class Type : public TypeConstraint { +public: + explicit Type(const llvm::Record *record); + + // Returns the description of the type. + StringRef getDescription() const; + + // Returns the dialect for the type if defined. 
+ Dialect getDialect() const; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_TYPE_H_ diff --git a/tools/mlir-tblgen-builder/mlir-tblgen-builder.cpp b/tools/mlir-tblgen-builder/mlir-tblgen-builder.cpp new file mode 100644 index 0000000..0699552 --- /dev/null +++ b/tools/mlir-tblgen-builder/mlir-tblgen-builder.cpp @@ -0,0 +1,83 @@ +//===- mlir-tblgen.cpp - Top-Level TableGen implementation for MLIR -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the main function for MLIR's TableGen. +// +//===----------------------------------------------------------------------===// + +#include "TableGen/GenInfo.h" +#include "TableGen/GenNameParser.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/Signals.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Main.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +using namespace llvm; +using namespace mlir; + +static llvm::ManagedStatic> generatorRegistry; + +mlir::GenRegistration::GenRegistration(StringRef arg, StringRef description, + GenFunction function) { + generatorRegistry->emplace_back(arg, description, function); +} + +GenNameParser::GenNameParser(llvm::cl::Option &opt) + : llvm::cl::parser(opt) { + for (const auto &kv : *generatorRegistry) { + addLiteralOption(kv.getGenArgument(), &kv, kv.getGenDescription()); + } +} + +void GenNameParser::printOptionInfo(const llvm::cl::Option &O, + size_t GlobalWidth) const { + GenNameParser *TP = const_cast(this); + llvm::array_pod_sort(TP->Values.begin(), 
TP->Values.end(), + [](const GenNameParser::OptionInfo *VT1, + const GenNameParser::OptionInfo *VT2) { + return VT1->Name.compare(VT2->Name); + }); + using llvm::cl::parser; + parser::printOptionInfo(O, GlobalWidth); +} + +// Generator that prints records. +GenRegistration printRecords("print-records", "Print all records to stdout", + [](const RecordKeeper &records, raw_ostream &os) { + os << records; + return false; + }); + +// Generator to invoke. +const mlir::GenInfo *generator; + +// TableGenMain requires a function pointer so this function is passed in which +// simply wraps the call to the generator. +static bool MlirTableGenMain(raw_ostream &os, RecordKeeper &records) { + if (!generator) { + os << records; + return false; + } + return generator->invoke(records, os); +} + +int main(int argc, char **argv) { + llvm::InitLLVM y(argc, argv); + llvm::cl::opt generator( + "", llvm::cl::desc("Generator to run")); + cl::ParseCommandLineOptions(argc, argv); + ::generator = generator.getValue(); + + return TableGenMain(argv[0], &MlirTableGenMain); +}