add mlir tblgen builder

This commit is contained in:
colin.liang 2021-07-23 11:38:34 +08:00
parent b0dd7a7518
commit 898eb732de
47 changed files with 12797 additions and 0 deletions

2
.gitignore vendored
View File

@ -2,3 +2,5 @@ build
llvm-project
llvm-build
bazel-*
bazel-bin
.vscode

41
BUILD
View File

@ -122,6 +122,46 @@ gentbl_cc_library(
deps = [":hlo_ops_td_files"],
)
cc_binary(
name = "mlir-tblgen-builder",
srcs = glob([
"tools/mlir-tblgen-builder/*.h",
"tools/mlir-tblgen-builder/*.cpp",
"tools/mlir-tblgen-builder/TableGen/*.h",
"tools/mlir-tblgen-builder/TableGen/*.cpp",
]),
deps = [
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
# "@llvm-project//mlir:TableGen",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//llvm:config",
],
)
gentbl_cc_library(
name = "hlo_ops_builder_gen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-builder-decls"],
"include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.h.inc",
),
(
["-gen-builder-defs"],
"include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.cc.inc",
),
],
tblgen = ":mlir-tblgen-builder",
td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td",
td_includes = [
"external/mlir-hlo/include",
"include",
],
deps = [":hlo_ops_td_files"],
)
gentbl_cc_library(
name = "hlo_ops_base_inc_gen",
strip_include_prefix = "include",
@ -519,6 +559,7 @@ cc_library(
":hlo_ops_base_structs",
":hlo_ops_common",
":hlo_ops_inc_gen",
":hlo_ops_builder_gen",
":hlo_ops_pattern_gen",
":infer_fusibility_op_interface",
"@llvm-project//llvm:Support",

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,28 @@
//===- OpFormatGen.h - MLIR operation format generator ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the interface for generating parsers and printers from the
// declarative format.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_
#define MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_
namespace mlir {
namespace tblgen {
class OpClass;
class Operator;
// Generate the assembly format for the given operator.
void generateOpFormat(const Operator &constOp, OpClass &opClass);
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_

View File

@ -0,0 +1,65 @@
//===- OpGenHelpers.cpp - MLIR operation generator helpers ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines helpers used in the op generators.
//
//===----------------------------------------------------------------------===//
#include "OpGenHelpers.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Regex.h"
#include "llvm/TableGen/Error.h"
using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
cl::OptionCategory opDefGenCat("Options for op definition generators");
static cl::opt<std::string> opIncFilter(
"op-include-regex",
cl::desc("Regex of name of op's to include (no filter if empty)"),
cl::cat(opDefGenCat));
static cl::opt<std::string> opExcFilter(
"op-exclude-regex",
cl::desc("Regex of name of op's to exclude (no filter if empty)"),
cl::cat(opDefGenCat));
static std::string getOperationName(const Record &def) {
auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
auto opName = def.getValueAsString("opName");
if (prefix.empty())
return std::string(opName);
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
}
std::vector<Record *>
mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
Record *classDef = recordKeeper.getClass("Op");
if (!classDef)
PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
std::vector<Record *> defs;
for (const auto &def : recordKeeper.getDefs()) {
if (!def.second->isSubClassOf(classDef))
continue;
// Include if no include filter or include filter matches.
if (!opIncFilter.empty() &&
!includeRegex.match(getOperationName(*def.second)))
continue;
// Unless there is an exclude filter and it matches.
if (!opExcFilter.empty() &&
excludeRegex.match(getOperationName(*def.second)))
continue;
defs.push_back(def.second.get());
}
return defs;
}

View File

@ -0,0 +1,30 @@
//===- OpGenHelpers.h - MLIR operation generator helpers --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines helpers used in the op generators.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
#define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
#include "llvm/TableGen/Record.h"
#include <vector>
namespace mlir {
namespace tblgen {
/// Returns all the op definitions filtered by the user. The filtering is via
/// command-line option "op-include-regex" and "op-exclude-regex".
std::vector<llvm::Record *>
getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_

View File

@ -0,0 +1,21 @@
//===- Argument.cpp - Argument definitions --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Argument.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
bool NamedTypeConstraint::hasPredicate() const {
return !constraint.getPredicate().isNull();
}
bool NamedTypeConstraint::isOptional() const { return constraint.isOptional(); }
bool NamedTypeConstraint::isVariadic() const { return constraint.isVariadic(); }

View File

@ -0,0 +1,65 @@
//===- Argument.h - Argument definitions ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file contains definitions for TableGen operation's arguments.
// Operation arguments fall into two categories:
//
// 1. Operands: SSA values operated on by the operation
// 2. Attributes: compile-time known properties that have influence over
// the operation's behavior
//
// These two categories are modelled with the unified argument concept in
// TableGen because we need similar pattern matching mechanisms for them.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_ARGUMENT_H_
#define MLIR_TABLEGEN_ARGUMENT_H_
#include "Attribute.h"
#include "Type.h"
#include "llvm/ADT/PointerUnion.h"
#include <string>
namespace llvm {
class StringRef;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// A struct wrapping an op attribute and its name together
struct NamedAttribute {
llvm::StringRef name;
Attribute attr;
};
// A struct wrapping an op operand/result's constraint and its name together
struct NamedTypeConstraint {
// Returns true if this operand/result has constraint to be satisfied.
bool hasPredicate() const;
// Returns true if this is an optional type constraint. This is a special case
// of variadic for 0 or 1 type.
bool isOptional() const;
// Returns true if this operand/result is variadic.
bool isVariadic() const;
// Returns true if this is a variable length type constraint. This is either
// variadic or optional.
bool isVariableLength() const { return isOptional() || isVariadic(); }
llvm::StringRef name;
TypeConstraint constraint;
};
// Operation argument: either attribute or operand
using Argument = llvm::PointerUnion<NamedAttribute *, NamedTypeConstraint *>;
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_ARGUMENT_H_

View File

@ -0,0 +1,250 @@
//===- AttrOrTypeDef.cpp - AttrOrTypeDef wrapper classes ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "AttrOrTypeDef.h"
#include "Dialect.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// AttrOrTypeBuilder
//===----------------------------------------------------------------------===//
/// Returns true if this builder is able to infer the MLIRContext parameter.
bool AttrOrTypeBuilder::hasInferredContextParameter() const {
return def->getValueAsBit("hasInferredContextParam");
}
//===----------------------------------------------------------------------===//
// AttrOrTypeDef
//===----------------------------------------------------------------------===//
AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
// Populate the builders.
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (llvm::Init *init : builderList->getValues()) {
AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
def->getLoc());
// Ensure that all parameters have names.
for (const AttrOrTypeBuilder::Parameter &param :
builder.getParameters()) {
if (!param.getName())
PrintFatalError(def->getLoc(), "builder parameters must have a name");
}
builders.emplace_back(builder);
}
}
// Populate the traits.
if (auto *traitList = def->getValueAsListInit("traits")) {
SmallPtrSet<const llvm::Init *, 32> traitSet;
traits.reserve(traitSet.size());
for (auto *traitInit : *traitList)
if (traitSet.insert(traitInit).second)
traits.push_back(Trait::create(traitInit));
}
}
Dialect AttrOrTypeDef::getDialect() const {
auto *dialect = dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
return Dialect(dialect ? dialect->getDef() : nullptr);
}
StringRef AttrOrTypeDef::getName() const { return def->getName(); }
StringRef AttrOrTypeDef::getCppClassName() const {
return def->getValueAsString("cppClassName");
}
StringRef AttrOrTypeDef::getCppBaseClassName() const {
return def->getValueAsString("cppBaseClassName");
}
bool AttrOrTypeDef::hasDescription() const {
const llvm::RecordVal *desc = def->getValue("description");
return desc && isa<llvm::StringInit>(desc->getValue());
}
StringRef AttrOrTypeDef::getDescription() const {
return def->getValueAsString("description");
}
bool AttrOrTypeDef::hasSummary() const {
const llvm::RecordVal *summary = def->getValue("summary");
return summary && isa<llvm::StringInit>(summary->getValue());
}
StringRef AttrOrTypeDef::getSummary() const {
return def->getValueAsString("summary");
}
StringRef AttrOrTypeDef::getStorageClassName() const {
return def->getValueAsString("storageClass");
}
StringRef AttrOrTypeDef::getStorageNamespace() const {
return def->getValueAsString("storageNamespace");
}
bool AttrOrTypeDef::genStorageClass() const {
return def->getValueAsBit("genStorageClass");
}
bool AttrOrTypeDef::hasStorageCustomConstructor() const {
return def->getValueAsBit("hasStorageCustomConstructor");
}
void AttrOrTypeDef::getParameters(
SmallVectorImpl<AttrOrTypeParameter> &parameters) const {
if (auto *parametersDag = def->getValueAsDag("parameters")) {
for (unsigned i = 0, e = parametersDag->getNumArgs(); i < e; ++i)
parameters.push_back(AttrOrTypeParameter(parametersDag, i));
}
}
unsigned AttrOrTypeDef::getNumParameters() const {
auto *parametersDag = def->getValueAsDag("parameters");
return parametersDag ? parametersDag->getNumArgs() : 0;
}
Optional<StringRef> AttrOrTypeDef::getMnemonic() const {
return def->getValueAsOptionalString("mnemonic");
}
Optional<StringRef> AttrOrTypeDef::getPrinterCode() const {
return def->getValueAsOptionalString("printer");
}
Optional<StringRef> AttrOrTypeDef::getParserCode() const {
return def->getValueAsOptionalString("parser");
}
bool AttrOrTypeDef::genAccessors() const {
return def->getValueAsBit("genAccessors");
}
bool AttrOrTypeDef::genVerifyDecl() const {
return def->getValueAsBit("genVerifyDecl");
}
Optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? Optional<StringRef>() : value;
}
ArrayRef<llvm::SMLoc> AttrOrTypeDef::getLoc() const { return def->getLoc(); }
bool AttrOrTypeDef::skipDefaultBuilders() const {
return def->getValueAsBit("skipDefaultBuilders");
}
bool AttrOrTypeDef::operator==(const AttrOrTypeDef &other) const {
return def == other.def;
}
bool AttrOrTypeDef::operator<(const AttrOrTypeDef &other) const {
return getName() < other.getName();
}
//===----------------------------------------------------------------------===//
// AttrDef
//===----------------------------------------------------------------------===//
Optional<StringRef> AttrDef::getTypeBuilder() const {
return def->getValueAsOptionalString("typeBuilder");
}
bool AttrDef::classof(const AttrOrTypeDef *def) {
return def->getDef()->isSubClassOf("AttrDef");
}
//===----------------------------------------------------------------------===//
// AttrOrTypeParameter
//===----------------------------------------------------------------------===//
StringRef AttrOrTypeParameter::getName() const {
return def->getArgName(index)->getValue();
}
Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
llvm::Init *parameterType = def->getArg(index);
if (isa<llvm::StringInit>(parameterType))
return Optional<StringRef>();
if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
return param->getDef()->getValueAsOptionalString("allocator");
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
"defs which inherit from AttrOrTypeParameter\n");
}
Optional<StringRef> AttrOrTypeParameter::getComparator() const {
llvm::Init *parameterType = def->getArg(index);
if (isa<llvm::StringInit>(parameterType))
return Optional<StringRef>();
if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
return param->getDef()->getValueAsOptionalString("comparator");
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
"defs which inherit from AttrOrTypeParameter\n");
}
StringRef AttrOrTypeParameter::getCppType() const {
auto *parameterType = def->getArg(index);
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
return stringType->getValue();
if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
return param->getDef()->getValueAsString("cppType");
llvm::PrintFatalError(
"Parameters DAG arguments must be either strings or defs "
"which inherit from AttrOrTypeParameter\n");
}
Optional<StringRef> AttrOrTypeParameter::getSummary() const {
auto *parameterType = def->getArg(index);
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
const auto *desc = param->getDef()->getValue("summary");
if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(desc->getValue()))
return ci->getValue();
}
return Optional<StringRef>();
}
StringRef AttrOrTypeParameter::getSyntax() const {
auto *parameterType = def->getArg(index);
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
return stringType->getValue();
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
const auto *syntax = param->getDef()->getValue("syntax");
if (syntax && isa<llvm::StringInit>(syntax->getValue()))
return cast<llvm::StringInit>(syntax->getValue())->getValue();
return getCppType();
}
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
"defs which inherit from AttrOrTypeParameter");
}
const llvm::Init *AttrOrTypeParameter::getDef() const {
return def->getArg(index);
}
//===----------------------------------------------------------------------===//
// AttributeSelfTypeParameter
//===----------------------------------------------------------------------===//
bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
const llvm::Init *paramDef = param->getDef();
if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
return false;
}

View File

@ -0,0 +1,229 @@
//===-- AttrOrTypeDef.h - Wrapper for attr and type definitions -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// AttrOrTypeDef, AttrDef, and TypeDef wrappers to simplify using TableGen
// Record defining a MLIR attributes and types.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_ATTRORTYPEDEF_H
#define MLIR_TABLEGEN_ATTRORTYPEDEF_H
#include "mlir/Support/LLVM.h"
#include "Builder.h"
#include "Trait.h"
namespace llvm {
class DagInit;
class Record;
class SMLoc;
} // namespace llvm
namespace mlir {
namespace tblgen {
class Dialect;
class AttrOrTypeParameter;
//===----------------------------------------------------------------------===//
// AttrOrTypeBuilder
//===----------------------------------------------------------------------===//
/// Wrapper class that represents a Tablegen AttrOrTypeBuilder.
class AttrOrTypeBuilder : public Builder {
public:
using Builder::Builder;
/// Returns true if this builder is able to infer the MLIRContext parameter.
bool hasInferredContextParameter() const;
};
//===----------------------------------------------------------------------===//
// AttrOrTypeDef
//===----------------------------------------------------------------------===//
/// Wrapper class that contains a TableGen AttrOrTypeDef's record and provides
/// helper methods for accessing them.
class AttrOrTypeDef {
public:
explicit AttrOrTypeDef(const llvm::Record *def);
// Get the dialect for which this def belongs.
Dialect getDialect() const;
// Returns the name of this AttrOrTypeDef record.
StringRef getName() const;
// Query functions for the documentation of the def.
bool hasDescription() const;
StringRef getDescription() const;
bool hasSummary() const;
StringRef getSummary() const;
// Returns the name of the C++ class to generate.
StringRef getCppClassName() const;
// Returns the name of the C++ base class to use when generating this def.
StringRef getCppBaseClassName() const;
// Returns the name of the storage class for this def.
StringRef getStorageClassName() const;
// Returns the C++ namespace for this def's storage class.
StringRef getStorageNamespace() const;
// Returns true if we should generate the storage class.
bool genStorageClass() const;
// Indicates whether or not to generate the storage class constructor.
bool hasStorageCustomConstructor() const;
// Fill a list with this def's parameters. See AttrOrTypeDef in OpBase.td for
// documentation of parameter usage.
void getParameters(SmallVectorImpl<AttrOrTypeParameter> &) const;
// Return the number of parameters
unsigned getNumParameters() const;
// Return the keyword/mnemonic to use in the printer/parser methods if we are
// supposed to auto-generate them.
Optional<StringRef> getMnemonic() const;
// Returns the code to use as the types printer method. If not specified,
// return a non-value. Otherwise, return the contents of that code block.
Optional<StringRef> getPrinterCode() const;
// Returns the code to use as the parser method. If not specified, returns
// None. Otherwise, returns the contents of that code block.
Optional<StringRef> getParserCode() const;
// Returns true if the accessors based on the parameters should be generated.
bool genAccessors() const;
// Return true if we need to generate the verify declaration and getChecked
// method.
bool genVerifyDecl() const;
// Returns the def's extra class declaration code.
Optional<StringRef> getExtraDecls() const;
// Get the code location (for error printing).
ArrayRef<llvm::SMLoc> getLoc() const;
// Returns true if the default get/getChecked methods should be skipped during
// generation.
bool skipDefaultBuilders() const;
// Returns the builders of this def.
ArrayRef<AttrOrTypeBuilder> getBuilders() const { return builders; }
// Returns the traits of this def.
ArrayRef<Trait> getTraits() const { return traits; }
// Returns whether two AttrOrTypeDefs are equal by checking the equality of
// the underlying record.
bool operator==(const AttrOrTypeDef &other) const;
// Compares two AttrOrTypeDefs by comparing the names of the dialects.
bool operator<(const AttrOrTypeDef &other) const;
// Returns whether the AttrOrTypeDef is defined.
operator bool() const { return def != nullptr; }
// Return the underlying def.
const llvm::Record *getDef() const { return def; }
protected:
const llvm::Record *def;
// The builders of this definition.
SmallVector<AttrOrTypeBuilder> builders;
// The traits of this definition.
SmallVector<Trait> traits;
};
//===----------------------------------------------------------------------===//
// AttrDef
//===----------------------------------------------------------------------===//
/// This class represents a wrapper around a tablegen AttrDef record.
class AttrDef : public AttrOrTypeDef {
public:
using AttrOrTypeDef::AttrOrTypeDef;
// Returns the attributes value type builder code block, or None if it doesn't
// have one.
Optional<StringRef> getTypeBuilder() const;
static bool classof(const AttrOrTypeDef *def);
};
//===----------------------------------------------------------------------===//
// TypeDef
//===----------------------------------------------------------------------===//
/// This class represents a wrapper around a tablegen TypeDef record.
class TypeDef : public AttrOrTypeDef {
public:
using AttrOrTypeDef::AttrOrTypeDef;
};
//===----------------------------------------------------------------------===//
// AttrOrTypeParameter
//===----------------------------------------------------------------------===//
// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to
// AttrOrTypeDefs to parameterize them.
class AttrOrTypeParameter {
public:
explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index)
: def(def), index(index) {}
// Get the parameter name.
StringRef getName() const;
// If specified, get the custom allocator code for this parameter.
Optional<StringRef> getAllocator() const;
// If specified, get the custom comparator code for this parameter.
Optional<StringRef> getComparator() const;
// Get the C++ type of this parameter.
StringRef getCppType() const;
// Get a description of this parameter for documentation purposes.
Optional<StringRef> getSummary() const;
// Get the assembly syntax documentation.
StringRef getSyntax() const;
// Return the underlying def of this parameter.
const llvm::Init *getDef() const;
private:
/// The underlying tablegen parameter list this parameter is a part of.
const llvm::DagInit *def;
/// The index of the parameter within the parameter list (`def`).
unsigned index;
};
//===----------------------------------------------------------------------===//
// AttributeSelfTypeParameter
//===----------------------------------------------------------------------===//
// A wrapper class for the AttributeSelfTypeParameter tblgen class. This
// represents a parameter of mlir::Type that is the value type of an AttrDef.
class AttributeSelfTypeParameter : public AttrOrTypeParameter {
public:
static bool classof(const AttrOrTypeParameter *param);
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_ATTRORTYPEDEF_H

View File

@ -0,0 +1,296 @@
//===- Attribute.cpp - Attribute wrapper class ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Attribute wrapper to simplify using TableGen Record defining a MLIR
// Attribute.
//
//===----------------------------------------------------------------------===//
#include "Format.h"
#include "Operator.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
using llvm::DefInit;
using llvm::Init;
using llvm::Record;
using llvm::StringInit;
// Returns the initializer's value as string if the given TableGen initializer
// is a code or string initializer. Returns the empty StringRef otherwise.
static StringRef getValueAsString(const Init *init) {
if (const auto *str = dyn_cast<StringInit>(init))
return str->getValue().trim();
return {};
}
AttrConstraint::AttrConstraint(const Record *record)
: Constraint(Constraint::CK_Attr, record) {
assert(isSubClassOf("AttrConstraint") &&
"must be subclass of TableGen 'AttrConstraint' class");
}
bool AttrConstraint::isSubClassOf(StringRef className) const {
return def->isSubClassOf(className);
}
Attribute::Attribute(const Record *record) : AttrConstraint(record) {
assert(record->isSubClassOf("Attr") &&
"must be subclass of TableGen 'Attr' class");
}
Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); }
bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); }
bool Attribute::isSymbolRefAttr() const {
StringRef defName = def->getName();
if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr")
return true;
return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
}
bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
StringRef Attribute::getStorageType() const {
const auto *init = def->getValueInit("storageType");
auto type = getValueAsString(init);
if (type.empty())
return "Attribute";
return type;
}
StringRef Attribute::getReturnType() const {
const auto *init = def->getValueInit("returnType");
return getValueAsString(init);
}
// Return the type constraint corresponding to the type of this attribute, or
// None if this is not a TypedAttr.
llvm::Optional<Type> Attribute::getValueType() const {
if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
return Type(defInit->getDef());
return llvm::None;
}
StringRef Attribute::getConvertFromStorageCall() const {
const auto *init = def->getValueInit("convertFromStorage");
return getValueAsString(init);
}
bool Attribute::isConstBuildable() const {
const auto *init = def->getValueInit("constBuilderCall");
return !getValueAsString(init).empty();
}
StringRef Attribute::getConstBuilderTemplate() const {
const auto *init = def->getValueInit("constBuilderCall");
return getValueAsString(init);
}
Attribute Attribute::getBaseAttr() const {
if (const auto *defInit =
llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
return Attribute(defInit).getBaseAttr();
}
return *this;
}
bool Attribute::hasDefaultValue() const {
const auto *init = def->getValueInit("defaultValue");
return !getValueAsString(init).empty();
}
StringRef Attribute::getDefaultValue() const {
const auto *init = def->getValueInit("defaultValue");
return getValueAsString(init);
}
bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); }
StringRef Attribute::getAttrDefName() const {
if (def->isAnonymous()) {
return getBaseAttr().def->getName();
}
return def->getName();
}
StringRef Attribute::getDerivedCodeBody() const {
assert(isDerivedAttr() && "only derived attribute has 'body' field");
return def->getValueAsString("body");
}
Dialect Attribute::getDialect() const {
const llvm::RecordVal *record = def->getValue("dialect");
if (record && record->getValue()) {
if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
return Dialect(init->getDef());
}
return Dialect(nullptr);
}
ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
assert(def->isSubClassOf("ConstantAttr") &&
"must be subclass of TableGen 'ConstantAttr' class");
}
Attribute ConstantAttr::getAttribute() const {
return Attribute(def->getValueAsDef("attr"));
}
StringRef ConstantAttr::getConstantValue() const {
return def->getValueAsString("value");
}
EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
assert(isSubClassOf("EnumAttrCaseInfo") &&
"must be subclass of TableGen 'EnumAttrInfo' class");
}
EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
: EnumAttrCase(init->getDef()) {}
bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); }
StringRef EnumAttrCase::getSymbol() const {
return def->getValueAsString("symbol");
}
StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
const llvm::Record &EnumAttrCase::getDef() const { return *def; }
EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
assert(isSubClassOf("EnumAttrInfo") &&
"must be subclass of TableGen 'EnumAttr' class");
}
EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
bool EnumAttr::classof(const Attribute *attr) {
return attr->isSubClassOf("EnumAttrInfo");
}
bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
StringRef EnumAttr::getEnumClassName() const {
return def->getValueAsString("className");
}
StringRef EnumAttr::getCppNamespace() const {
return def->getValueAsString("cppNamespace");
}
StringRef EnumAttr::getUnderlyingType() const {
return def->getValueAsString("underlyingType");
}
StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
return def->getValueAsString("underlyingToSymbolFnName");
}
StringRef EnumAttr::getStringToSymbolFnName() const {
return def->getValueAsString("stringToSymbolFnName");
}
StringRef EnumAttr::getSymbolToStringFnName() const {
return def->getValueAsString("symbolToStringFnName");
}
StringRef EnumAttr::getSymbolToStringFnRetType() const {
return def->getValueAsString("symbolToStringFnRetType");
}
StringRef EnumAttr::getMaxEnumValFnName() const {
return def->getValueAsString("maxEnumValFnName");
}
std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
const auto *inits = def->getValueAsListInit("enumerants");
std::vector<EnumAttrCase> cases;
cases.reserve(inits->size());
for (const llvm::Init *init : *inits) {
cases.push_back(EnumAttrCase(cast<llvm::DefInit>(init)));
}
return cases;
}
bool EnumAttr::genSpecializedAttr() const {
return def->getValueAsBit("genSpecializedAttr");
}
llvm::Record *EnumAttr::getBaseAttrClass() const {
return def->getValueAsDef("baseAttrClass");
}
StringRef EnumAttr::getSpecializedAttrClassName() const {
return def->getValueAsString("specializedAttrClassName");
}
StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
assert(def->isSubClassOf("StructFieldAttr") &&
"must be subclass of TableGen 'StructFieldAttr' class");
}
StructFieldAttr::StructFieldAttr(const llvm::Record &record)
: StructFieldAttr(&record) {}
StructFieldAttr::StructFieldAttr(const llvm::DefInit *init)
: StructFieldAttr(init->getDef()) {}
StringRef StructFieldAttr::getName() const {
return def->getValueAsString("name");
}
Attribute StructFieldAttr::getType() const {
auto init = def->getValueInit("type");
return Attribute(cast<llvm::DefInit>(init));
}
StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) {
assert(isSubClassOf("StructAttr") &&
"must be subclass of TableGen 'StructAttr' class");
}
StructAttr::StructAttr(const llvm::DefInit *init)
: StructAttr(init->getDef()) {}
StringRef StructAttr::getStructClassName() const {
return def->getValueAsString("className");
}
StringRef StructAttr::getCppNamespace() const {
Dialect dialect(def->getValueAsDef("dialect"));
return dialect.getCppNamespace();
}
std::vector<StructFieldAttr> StructAttr::getAllFields() const {
std::vector<StructFieldAttr> attributes;
const auto *inits = def->getValueAsListInit("fields");
attributes.reserve(inits->size());
for (const llvm::Init *init : *inits) {
attributes.emplace_back(cast<llvm::DefInit>(init));
}
return attributes;
}
const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";

View File

@ -0,0 +1,247 @@
//===- Attribute.h - Attribute wrapper class --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Attribute wrapper to simplify using TableGen Record defining a MLIR
// Attribute.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_ATTRIBUTE_H_
#define MLIR_TABLEGEN_ATTRIBUTE_H_
#include "mlir/Support/LLVM.h"
#include "Constraint.h"
#include "llvm/ADT/StringRef.h"
namespace llvm {
class DefInit;
class Record;
} // end namespace llvm
namespace mlir {
namespace tblgen {
class Dialect;
class Type;
// Wrapper class with helper methods for accessing attribute constraints defined
// in TableGen.
class AttrConstraint : public Constraint {
public:
explicit AttrConstraint(const llvm::Record *record);
static bool classof(const Constraint *c) { return c->getKind() == CK_Attr; }
// Returns true if this constraint is a subclass of the given `className`
// class defined in TableGen.
bool isSubClassOf(StringRef className) const;
};
// Wrapper class providing helper methods for accessing MLIR Attribute defined
// in TableGen. This class should closely reflect what is defined as class
// `Attr` in TableGen.
class Attribute : public AttrConstraint {
public:
explicit Attribute(const llvm::Record *record);
explicit Attribute(const llvm::DefInit *init);
// Returns the storage type if set. Returns the default storage type
// ("Attribute") otherwise.
StringRef getStorageType() const;
// Returns the return type for this attribute.
StringRef getReturnType() const;
// Return the type constraint corresponding to the type of this attribute, or
// None if this is not a TypedAttr.
llvm::Optional<Type> getValueType() const;
// Returns the template getter method call which reads this attribute's
// storage and returns the value as of the desired return type.
// The call will contain a `{0}` which will be expanded to this attribute.
StringRef getConvertFromStorageCall() const;
// Returns true if this attribute can be built from a constant value.
bool isConstBuildable() const;
// Returns the template that can be used to produce an instance of the
// attribute.
// Syntax: `$builder` should be replaced with a builder, `$0` should be
// replaced with the constant value.
StringRef getConstBuilderTemplate() const;
// Returns the base-level attribute that this attribute constraint is
// built upon.
Attribute getBaseAttr() const;
// Returns whether this attribute has a default value.
bool hasDefaultValue() const;
// Returns the default value for this attribute.
StringRef getDefaultValue() const;
// Returns whether this attribute is optional.
bool isOptional() const;
// Returns true if this attribute is a derived attribute (i.e., a subclass
// of `DerivedAttr`).
bool isDerivedAttr() const;
// Returns true if this attribute is a type attribute (i.e., a subclass
// of `TypeAttrBase`).
bool isTypeAttr() const;
// Returns true if this attribute is a symbol reference attribute (i.e., a
// subclass of `SymbolRefAttr` or `FlatSymbolRefAttr`).
bool isSymbolRefAttr() const;
// Returns true if this attribute is an enum attribute (i.e., a subclass of
// `EnumAttrInfo`)
bool isEnumAttr() const;
// Returns this attribute's TableGen def name. If this is an `OptionalAttr`
// or `DefaultValuedAttr` without explicit name, returns the base attribute's
// name.
StringRef getAttrDefName() const;
// Returns the code body for derived attribute. Aborts if this is not a
// derived attribute.
StringRef getDerivedCodeBody() const;
// Returns the dialect for the attribute if defined.
Dialect getDialect() const;
};
// Wrapper class providing helper methods for accessing MLIR constant attribute
// defined in TableGen. This class should closely reflect what is defined as
// class `ConstantAttr` in TableGen.
class ConstantAttr {
public:
explicit ConstantAttr(const llvm::DefInit *init);
// Returns the attribute kind.
Attribute getAttribute() const;
// Returns the constant value.
StringRef getConstantValue() const;
private:
// The TableGen definition of this constant attribute.
const llvm::Record *def;
};
// Wrapper class providing helper methods for accessing enum attribute cases
// defined in TableGen. This is used for enum attribute case backed by both
// StringAttr and IntegerAttr.
class EnumAttrCase : public Attribute {
public:
explicit EnumAttrCase(const llvm::Record *record);
explicit EnumAttrCase(const llvm::DefInit *init);
// Returns true if this EnumAttrCase is backed by a StringAttr.
bool isStrCase() const;
// Returns the symbol of this enum attribute case.
StringRef getSymbol() const;
// Returns the textual representation of this enum attribute case.
StringRef getStr() const;
// Returns the value of this enum attribute case.
int64_t getValue() const;
// Returns the TableGen definition this EnumAttrCase was constructed from.
const llvm::Record &getDef() const;
};
// Wrapper class providing helper methods for accessing enum attributes defined
// in TableGen.This is used for enum attribute case backed by both StringAttr
// and IntegerAttr.
class EnumAttr : public Attribute {
public:
explicit EnumAttr(const llvm::Record *record);
explicit EnumAttr(const llvm::Record &record);
explicit EnumAttr(const llvm::DefInit *init);
static bool classof(const Attribute *attr);
// Returns true if this is a bit enum attribute.
bool isBitEnum() const;
// Returns the enum class name.
StringRef getEnumClassName() const;
// Returns the C++ namespaces this enum class should be placed in.
StringRef getCppNamespace() const;
// Returns the underlying type.
StringRef getUnderlyingType() const;
// Returns the name of the utility function that converts a value of the
// underlying type to the corresponding symbol.
StringRef getUnderlyingToSymbolFnName() const;
// Returns the name of the utility function that converts a string to the
// corresponding symbol.
StringRef getStringToSymbolFnName() const;
// Returns the name of the utility function that converts a symbol to the
// corresponding string.
StringRef getSymbolToStringFnName() const;
// Returns the return type of the utility function that converts a symbol to
// the corresponding string.
StringRef getSymbolToStringFnRetType() const;
// Returns the name of the utilit function that returns the max enum value
// used within the enum class.
StringRef getMaxEnumValFnName() const;
// Returns all allowed cases for this enum attribute.
std::vector<EnumAttrCase> getAllCases() const;
bool genSpecializedAttr() const;
llvm::Record *getBaseAttrClass() const;
StringRef getSpecializedAttrClassName() const;
};
class StructFieldAttr {
public:
explicit StructFieldAttr(const llvm::Record *record);
explicit StructFieldAttr(const llvm::Record &record);
explicit StructFieldAttr(const llvm::DefInit *init);
StringRef getName() const;
Attribute getType() const;
private:
const llvm::Record *def;
};
// Wrapper class providing helper methods for accessing struct attributes
// defined in TableGen.
class StructAttr : public Attribute {
public:
explicit StructAttr(const llvm::Record *record);
explicit StructAttr(const llvm::Record &record) : StructAttr(&record){};
explicit StructAttr(const llvm::DefInit *init);
// Returns the struct class name.
StringRef getStructClassName() const;
// Returns the C++ namespaces this struct class should be placed in.
StringRef getCppNamespace() const;
std::vector<StructFieldAttr> getAllFields() const;
};
// Name of infer type op interface.
extern const char *inferTypeOpInterface;
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_ATTRIBUTE_H_

View File

@ -0,0 +1,74 @@
//===- Builder.cpp - Builder definitions ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Builder.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// Builder::Parameter
//===----------------------------------------------------------------------===//
/// Return a string containing the C++ type of this parameter.
StringRef Builder::Parameter::getCppType() const {
if (const auto *stringInit = dyn_cast<llvm::StringInit>(def))
return stringInit->getValue();
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
return record->getValueAsString("type");
}
/// Return an optional string containing the default value to use for this
/// parameter.
Optional<StringRef> Builder::Parameter::getDefaultValue() const {
if (isa<llvm::StringInit>(def))
return llvm::None;
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
Optional<StringRef> value = record->getValueAsOptionalString("defaultValue");
return value && !value->empty() ? value : llvm::None;
}
//===----------------------------------------------------------------------===//
// Builder
//===----------------------------------------------------------------------===//
Builder::Builder(const llvm::Record *record, ArrayRef<llvm::SMLoc> loc)
: def(record) {
// Initialize the parameters of the builder.
const llvm::DagInit *dag = def->getValueAsDag("dagParams");
auto *defInit = dyn_cast<llvm::DefInit>(dag->getOperator());
if (!defInit || !defInit->getDef()->getName().equals("ins"))
PrintFatalError(def->getLoc(), "expected 'ins' in builders");
bool seenDefaultValue = false;
for (unsigned i = 0, e = dag->getNumArgs(); i < e; ++i) {
const llvm::StringInit *paramName = dag->getArgName(i);
const llvm::Init *paramValue = dag->getArg(i);
Parameter param(paramName ? paramName->getValue() : Optional<StringRef>(),
paramValue);
// Similarly to C++, once an argument with a default value is detected, the
// following arguments must have default values as well.
if (param.getDefaultValue()) {
seenDefaultValue = true;
} else if (seenDefaultValue) {
PrintFatalError(loc,
"expected an argument with default value after other "
"arguments with default values");
}
parameters.emplace_back(param);
}
}
/// Return an optional string containing the body of the builder.
Optional<StringRef> Builder::getBody() const {
Optional<StringRef> body = def->getValueAsOptionalString("body");
return body && !body->empty() ? body : llvm::None;
}

View File

@ -0,0 +1,85 @@
//===- Builder.h - Builder classes ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Builder wrapper to simplify using TableGen Record for building
// operations/types/etc.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_BUILDER_H_
#define MLIR_TABLEGEN_BUILDER_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
namespace llvm {
class Init;
class Record;
class SMLoc;
} // end namespace llvm
namespace mlir {
namespace tblgen {
/// Wrapper class with helper methods for accessing Builders defined in
/// TableGen.
class Builder {
public:
/// This class represents a single parameter to a builder method.
class Parameter {
public:
/// Return a string containing the C++ type of this parameter.
StringRef getCppType() const;
/// Return an optional string containing the name of this parameter. If
/// None, no name was specified for this parameter by the user.
Optional<StringRef> getName() const { return name; }
/// Return an optional string containing the default value to use for this
/// parameter.
Optional<StringRef> getDefaultValue() const;
private:
Parameter(Optional<StringRef> name, const llvm::Init *def)
: name(name), def(def) {}
/// The optional name of the parameter.
Optional<StringRef> name;
/// The tablegen definition of the parameter. This is either a StringInit,
/// or a CArg DefInit.
const llvm::Init *def;
// Allow access to the constructor.
friend Builder;
};
/// Construct a builder from the given Record instance.
Builder(const llvm::Record *record, ArrayRef<llvm::SMLoc> loc);
/// Return a list of parameters used in this build method.
ArrayRef<Parameter> getParameters() const { return parameters; }
/// Return an optional string containing the body of the builder.
Optional<StringRef> getBody() const;
protected:
/// The TableGen definition of this builder.
const llvm::Record *def;
private:
/// A collection of parameters to the builder.
SmallVector<Parameter> parameters;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_BUILDER_H_

View File

@ -0,0 +1,68 @@
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines common utilities for generating C++ from tablegen
// structures.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_CODEGENHELPERS_H
#define MLIR_TABLEGEN_CODEGENHELPERS_H
#include "Dialect.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
namespace mlir {
namespace tblgen {
// Simple RAII helper for defining ifdef-undef-endif scopes.
class IfDefScope {
public:
IfDefScope(llvm::StringRef name, llvm::raw_ostream &os)
: name(name.str()), os(os) {
os << "#ifdef " << name << "\n"
<< "#undef " << name << "\n\n";
}
~IfDefScope() { os << "\n#endif // " << name << "\n\n"; }
private:
std::string name;
llvm::raw_ostream &os;
};
// A helper RAII class to emit nested namespaces for this op.
class NamespaceEmitter {
public:
NamespaceEmitter(raw_ostream &os, const Dialect &dialect) : os(os) {
if (!dialect)
return;
emitNamespaceStarts(os, dialect.getCppNamespace());
}
NamespaceEmitter(raw_ostream &os, StringRef cppNamespace) : os(os) {
emitNamespaceStarts(os, cppNamespace);
}
~NamespaceEmitter() {
for (StringRef ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
}
private:
void emitNamespaceStarts(raw_ostream &os, StringRef cppNamespace) {
llvm::SplitString(cppNamespace, namespaces, "::");
for (StringRef ns : namespaces)
os << "namespace " << ns << " {\n";
}
raw_ostream &os;
SmallVector<StringRef, 2> namespaces;
};
} // namespace tblgen
} // namespace mlir
#endif // MLIR_TABLEGEN_CODEGENHELPERS_H

View File

@ -0,0 +1,70 @@
//===- Constraint.cpp - Constraint class ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Constraint wrapper to simplify using TableGen Record for constraints.
//
//===----------------------------------------------------------------------===//
#include "Constraint.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
Constraint::Constraint(const llvm::Record *record)
: def(record), kind(CK_Uncategorized) {
// Look through OpVariable's to their constraint.
if (def->isSubClassOf("OpVariable"))
def = def->getValueAsDef("constraint");
if (def->isSubClassOf("TypeConstraint")) {
kind = CK_Type;
} else if (def->isSubClassOf("AttrConstraint")) {
kind = CK_Attr;
} else if (def->isSubClassOf("RegionConstraint")) {
kind = CK_Region;
} else if (def->isSubClassOf("SuccessorConstraint")) {
kind = CK_Successor;
} else {
assert(def->isSubClassOf("Constraint"));
}
}
Constraint::Constraint(Kind kind, const llvm::Record *record)
: def(record), kind(kind) {
// Look through OpVariable's to their constraint.
if (def->isSubClassOf("OpVariable"))
def = def->getValueAsDef("constraint");
}
Pred Constraint::getPredicate() const {
auto *val = def->getValue("predicate");
// If no predicate is specified, then return the null predicate (which
// corresponds to true).
if (!val)
return Pred();
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
return Pred(pred);
}
std::string Constraint::getConditionTemplate() const {
return getPredicate().getCondition();
}
StringRef Constraint::getSummary() const {
if (Optional<StringRef> summary = def->getValueAsOptionalString("summary"))
return *summary;
return def->getName();
}
AppliedConstraint::AppliedConstraint(Constraint &&constraint,
llvm::StringRef self,
std::vector<std::string> &&entities)
: constraint(constraint), self(std::string(self)),
entities(std::move(entities)) {}

View File

@ -0,0 +1,88 @@
//===- Constraint.h - Constraint class --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Constraint wrapper to simplify using TableGen Record for constraints.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_CONSTRAINT_H_
#define MLIR_TABLEGEN_CONSTRAINT_H_
#include "mlir/Support/LLVM.h"
#include "Predicate.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
namespace llvm {
class Record;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// Wrapper class with helper methods for accessing Constraint defined in
// TableGen.
class Constraint {
public:
Constraint(const llvm::Record *record);
bool operator==(const Constraint &that) { return def == that.def; }
bool operator!=(const Constraint &that) { return def != that.def; }
// Returns the predicate for this constraint.
Pred getPredicate() const;
// Returns the condition template that can be used to check if a type or
// attribute satisfies this constraint. The template may contain "{0}" that
// must be substituted with an expression returning an mlir::Type or
// mlir::Attribute.
std::string getConditionTemplate() const;
// Returns the user-readable description of this constraint. If the
// description is not provided, returns the TableGen def name.
StringRef getSummary() const;
// Constraint kind
enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized };
Kind getKind() const { return kind; }
/// Get an opaque pointer to the constraint.
const void *getAsOpaquePointer() const { return def; }
/// Construct a constraint from the opaque pointer representation.
static Constraint getFromOpaquePointer(const void *ptr) {
return Constraint(reinterpret_cast<const llvm::Record *>(ptr));
}
protected:
Constraint(Kind kind, const llvm::Record *record);
// The TableGen definition of this constraint.
const llvm::Record *def;
private:
// What kind of constraint this is.
Kind kind;
};
// An constraint and the concrete entities to place the constraint on.
struct AppliedConstraint {
AppliedConstraint(Constraint &&constraint, StringRef self,
std::vector<std::string> &&entities);
Constraint constraint;
// The symbol to replace `$_self` special placeholder in the constraint.
std::string self;
// The symbols to replace `$N` positional placeholders in the constraint.
std::vector<std::string> entities;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_CONSTRAINT_H_

View File

@ -0,0 +1,94 @@
//===- Dialect.cpp - Dialect wrapper class --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Dialect wrapper to simplify using TableGen Record defining a MLIR dialect.
//
//===----------------------------------------------------------------------===//
#include "Dialect.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
Dialect::Dialect(const llvm::Record *def) : def(def) {
if (def == nullptr)
return;
for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
dependentDialects.push_back(dialect);
}
StringRef Dialect::getName() const { return def->getValueAsString("name"); }
StringRef Dialect::getCppNamespace() const {
return def->getValueAsString("cppNamespace");
}
std::string Dialect::getCppClassName() const {
// Simply use the name and remove any '_' tokens.
std::string cppName = def->getName().str();
llvm::erase_if(cppName, [](char c) { return c == '_'; });
return cppName;
}
static StringRef getAsStringOrEmpty(const llvm::Record &record,
StringRef fieldName) {
if (auto valueInit = record.getValueInit(fieldName)) {
if (llvm::isa<llvm::StringInit>(valueInit))
return record.getValueAsString(fieldName);
}
return "";
}
StringRef Dialect::getSummary() const {
return getAsStringOrEmpty(*def, "summary");
}
StringRef Dialect::getDescription() const {
return getAsStringOrEmpty(*def, "description");
}
ArrayRef<StringRef> Dialect::getDependentDialects() const {
return dependentDialects;
}
llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
bool Dialect::hasCanonicalizer() const {
return def->getValueAsBit("hasCanonicalizer");
}
bool Dialect::hasConstantMaterializer() const {
return def->getValueAsBit("hasConstantMaterializer");
}
bool Dialect::hasOperationAttrVerify() const {
return def->getValueAsBit("hasOperationAttrVerify");
}
bool Dialect::hasRegionArgAttrVerify() const {
return def->getValueAsBit("hasRegionArgAttrVerify");
}
bool Dialect::hasRegionResultAttrVerify() const {
return def->getValueAsBit("hasRegionResultAttrVerify");
}
bool Dialect::hasOperationInterfaceFallback() const {
return def->getValueAsBit("hasOperationInterfaceFallback");
}
bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}
bool Dialect::operator<(const Dialect &other) const {
return getName() < other.getName();
}

View File

@ -0,0 +1,91 @@
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Dialect wrapper to simplify using TableGen Record defining a MLIR dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_DIALECT_H_
#define MLIR_TABLEGEN_DIALECT_H_
#include "mlir/Support/LLVM.h"
#include <string>
#include <vector>
namespace llvm {
class Record;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// Wrapper class that contains a MLIR dialect's information defined in TableGen
// and provides helper methods for accessing them.
class Dialect {
public:
explicit Dialect(const llvm::Record *def);
// Returns the name of this dialect.
StringRef getName() const;
// Returns the C++ namespaces that ops of this dialect should be placed into.
StringRef getCppNamespace() const;
// Returns this dialect's C++ class name.
std::string getCppClassName() const;
// Returns the summary description of the dialect. Returns empty string if
// none.
StringRef getSummary() const;
// Returns the description of the dialect. Returns empty string if none.
StringRef getDescription() const;
// Returns the list of dialect (class names) that this dialect depends on.
// These are dialects that will be loaded on construction of this dialect.
ArrayRef<StringRef> getDependentDialects() const;
// Returns the dialects extra class declaration code.
llvm::Optional<StringRef> getExtraClassDeclaration() const;
/// Returns true if this dialect has a canonicalizer.
bool hasCanonicalizer() const;
// Returns true if this dialect has a constant materializer.
bool hasConstantMaterializer() const;
/// Returns true if this dialect has an operation attribute verifier.
bool hasOperationAttrVerify() const;
/// Returns true if this dialect has a region argument attribute verifier.
bool hasRegionArgAttrVerify() const;
/// Returns true if this dialect has a region result attribute verifier.
bool hasRegionResultAttrVerify() const;
/// Returns true if this dialect has fallback interfaces for its operations.
bool hasOperationInterfaceFallback() const;
// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;
bool operator!=(const Dialect &other) const { return !(*this == other); }
// Compares two dialects by comparing the names of the dialects.
bool operator<(const Dialect &other) const;
// Returns whether the dialect is defined.
explicit operator bool() const { return def != nullptr; }
private:
const llvm::Record *def;
std::vector<StringRef> dependentDialects;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_DIALECT_H_

View File

@ -0,0 +1,194 @@
//===- Format.cpp - Utilities for String Format ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines utilities for formatting strings. They are specially
// tailored to the needs of TableGen'ing op definitions and rewrite rules,
// so they are not expected to be used as widely applicable utilities.
//
//===----------------------------------------------------------------------===//
#include "Format.h"
#include <cctype>
using namespace mlir;
using namespace mlir::tblgen;
// Marker to indicate an error happened when replacing a placeholder.
const char *const kMarkerForNoSubst = "<no-subst-found>";
FmtContext &FmtContext::addSubst(StringRef placeholder, Twine subst) {
customSubstMap[placeholder] = subst.str();
return *this;
}
FmtContext &FmtContext::withBuilder(Twine subst) {
builtinSubstMap[PHKind::Builder] = subst.str();
return *this;
}
FmtContext &FmtContext::withOp(Twine subst) {
builtinSubstMap[PHKind::Op] = subst.str();
return *this;
}
FmtContext &FmtContext::withSelf(Twine subst) {
builtinSubstMap[PHKind::Self] = subst.str();
return *this;
}
Optional<StringRef>
FmtContext::getSubstFor(FmtContext::PHKind placeholder) const {
if (placeholder == FmtContext::PHKind::None ||
placeholder == FmtContext::PHKind::Custom)
return {};
auto it = builtinSubstMap.find(placeholder);
if (it == builtinSubstMap.end())
return {};
return StringRef(it->second);
}
Optional<StringRef> FmtContext::getSubstFor(StringRef placeholder) const {
auto it = customSubstMap.find(placeholder);
if (it == customSubstMap.end())
return {};
return StringRef(it->second);
}
FmtContext::PHKind FmtContext::getPlaceHolderKind(StringRef str) {
return StringSwitch<FmtContext::PHKind>(str)
.Case("_builder", FmtContext::PHKind::Builder)
.Case("_op", FmtContext::PHKind::Op)
.Case("_self", FmtContext::PHKind::Self)
.Case("", FmtContext::PHKind::None)
.Default(FmtContext::PHKind::Custom);
}
std::pair<FmtReplacement, StringRef>
FmtObjectBase::splitFmtSegment(StringRef fmt) {
size_t begin = fmt.find_first_of('$');
if (begin == StringRef::npos) {
// No placeholders: the whole format string should be returned as a
// literal string.
return {FmtReplacement{fmt}, StringRef()};
}
if (begin != 0) {
// The first placeholder is not at the beginning: we can split the format
// string into a literal string and the rest.
return {FmtReplacement{fmt.substr(0, begin)}, fmt.substr(begin)};
}
// The first placeholder is at the beginning
if (fmt.size() == 1) {
// The whole format string just contains '$': treat as literal.
return {FmtReplacement{fmt}, StringRef()};
}
// Allow escaping dollar with '$$'
if (fmt[1] == '$') {
return {FmtReplacement{fmt.substr(0, 1)}, fmt.substr(2)};
}
// First try to see if it's a positional placeholder, and then handle special
// placeholders.
size_t end = fmt.find_if_not([](char c) { return std::isdigit(c); }, 1);
if (end != 1) {
// We have a positional placeholder. Parse the index.
size_t index = 0;
if (fmt.substr(1, end - 1).consumeInteger(0, index)) {
llvm_unreachable("invalid replacement sequence index");
}
if (end == StringRef::npos) {
// All the remaining characters are part of the positional placeholder.
return {FmtReplacement{fmt, index}, StringRef()};
}
return {FmtReplacement{fmt.substr(0, end), index}, fmt.substr(end)};
}
end = fmt.find_if_not([](char c) { return std::isalnum(c) || c == '_'; }, 1);
auto placeholder = FmtContext::getPlaceHolderKind(fmt.substr(1, end - 1));
if (end == StringRef::npos) {
// All the remaining characters are part of the special placeholder.
return {FmtReplacement{fmt, placeholder}, StringRef()};
}
return {FmtReplacement{fmt.substr(0, end), placeholder}, fmt.substr(end)};
}
std::vector<FmtReplacement> FmtObjectBase::parseFormatString(StringRef fmt) {
std::vector<FmtReplacement> replacements;
FmtReplacement repl;
while (!fmt.empty()) {
std::tie(repl, fmt) = splitFmtSegment(fmt);
if (repl.type != FmtReplacement::Type::Empty)
replacements.push_back(repl);
}
return replacements;
}
void FmtObjectBase::format(raw_ostream &s) const {
for (auto &repl : replacements) {
if (repl.type == FmtReplacement::Type::Empty)
continue;
if (repl.type == FmtReplacement::Type::Literal) {
s << repl.spec;
continue;
}
if (repl.type == FmtReplacement::Type::SpecialPH) {
if (repl.placeholder == FmtContext::PHKind::None) {
s << repl.spec;
} else if (!context) {
// We need the context to replace special placeholders.
s << repl.spec << kMarkerForNoSubst;
} else {
Optional<StringRef> subst;
if (repl.placeholder == FmtContext::PHKind::Custom) {
// Skip the leading '$' sign for the custom placeholder
subst = context->getSubstFor(repl.spec.substr(1));
} else {
subst = context->getSubstFor(repl.placeholder);
}
if (subst)
s << *subst;
else
s << repl.spec << kMarkerForNoSubst;
}
continue;
}
assert(repl.type == FmtReplacement::Type::PositionalPH);
if (repl.index >= adapters.size()) {
s << repl.spec << kMarkerForNoSubst;
continue;
}
adapters[repl.index]->format(s, /*Options=*/"");
}
}
FmtStrVecObject::FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
ArrayRef<std::string> params)
: FmtObjectBase(fmt, ctx, params.size()) {
parameters.reserve(params.size());
for (std::string p : params)
parameters.push_back(llvm::detail::build_format_adapter(std::move(p)));
adapters.reserve(parameters.size());
for (auto &p : parameters)
adapters.push_back(&p);
}
FmtStrVecObject::FmtStrVecObject(FmtStrVecObject &&that)
: FmtObjectBase(std::move(that)), parameters(std::move(that.parameters)) {
adapters.reserve(parameters.size());
for (auto &p : parameters)
adapters.push_back(&p);
}

View File

@ -0,0 +1,259 @@
//===- Format.h - Utilities for String Format -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares utilities for formatting strings. They are specially
// tailored to the needs of TableGen'ing op definitions and rewrite rules,
// so they are not expected to be used as widely applicable utilities.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_FORMAT_H_
#define MLIR_TABLEGEN_FORMAT_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tblgen {
/// Format context containing substitutions for special placeholders.
///
/// This context divides special placeholders into two categories: builtin ones
/// and custom ones.
///
/// Builtin placeholders are baked into `FmtContext` and each one of them has a
/// dedicated setter. They can be used in all dialects. Their names follow the
/// convention of `$_<name>`. The rationale of the leading underscore is to
/// avoid confusion and name collision: op arguments/attributes/results are
/// named as $<name>, and we can potentially support referencing those entities
/// directly in the format template in the future.
//
/// Custom ones are registered by dialect-specific TableGen backends and use the
/// same unified setter.
class FmtContext {
public:
// Placeholder kinds
enum class PHKind : char {
None,
Custom, // For custom placeholders
Builder, // For the $_builder placeholder
Op, // For the $_op placeholder
Self, // For the $_self placeholder
};
FmtContext() = default;
// Setter for custom placeholders
FmtContext &addSubst(StringRef placeholder, Twine subst);
// Setters for builtin placeholders
FmtContext &withBuilder(Twine subst);
FmtContext &withOp(Twine subst);
FmtContext &withSelf(Twine subst);
Optional<StringRef> getSubstFor(PHKind placeholder) const;
Optional<StringRef> getSubstFor(StringRef placeholder) const;
static PHKind getPlaceHolderKind(StringRef str);
private:
struct PHKindInfo : DenseMapInfo<PHKind> {
using CharInfo = DenseMapInfo<char>;
static inline PHKind getEmptyKey() {
return static_cast<PHKind>(CharInfo::getEmptyKey());
}
static inline PHKind getTombstoneKey() {
return static_cast<PHKind>(CharInfo::getTombstoneKey());
}
static unsigned getHashValue(const PHKind &val) {
return CharInfo::getHashValue(static_cast<char>(val));
}
static bool isEqual(const PHKind &lhs, const PHKind &rhs) {
return lhs == rhs;
}
};
llvm::SmallDenseMap<PHKind, std::string, 4, PHKindInfo> builtinSubstMap;
llvm::StringMap<std::string> customSubstMap;
};
/// Struct representing a replacement segment for the formatted string. It can
/// be a segment of the formatting template (for `Literal`) or a replacement
/// parameter (for `PositionalPH` and `SpecialPH`).
struct FmtReplacement {
enum class Type { Empty, Literal, PositionalPH, SpecialPH };
FmtReplacement() = default;
explicit FmtReplacement(StringRef literal)
: type(Type::Literal), spec(literal) {}
FmtReplacement(StringRef spec, size_t index)
: type(Type::PositionalPH), spec(spec), index(index) {}
FmtReplacement(StringRef spec, FmtContext::PHKind placeholder)
: type(Type::SpecialPH), spec(spec), placeholder(placeholder) {}
Type type = Type::Empty;
StringRef spec;
size_t index = 0;
FmtContext::PHKind placeholder = FmtContext::PHKind::None;
};
class FmtObjectBase {
private:
static std::pair<FmtReplacement, StringRef> splitFmtSegment(StringRef fmt);
static std::vector<FmtReplacement> parseFormatString(StringRef fmt);
protected:
// The parameters are stored in a std::tuple, which does not provide runtime
// indexing capabilities. In order to enable runtime indexing, we use this
// structure to put the parameters into a std::vector. Since the parameters
// are not all the same type, we use some type-erasure by wrapping the
// parameters in a template class that derives from a non-template superclass.
// Essentially, we are converting a std::tuple<Derived<Ts...>> to a
// std::vector<Base*>.
struct CreateAdapters {
template <typename... Ts>
std::vector<llvm::detail::format_adapter *> operator()(Ts &... items) {
return std::vector<llvm::detail::format_adapter *>{&items...};
}
};
StringRef fmt;
const FmtContext *context;
std::vector<llvm::detail::format_adapter *> adapters;
std::vector<FmtReplacement> replacements;
public:
FmtObjectBase(StringRef fmt, const FmtContext *ctx, size_t numParams)
: fmt(fmt), context(ctx), replacements(parseFormatString(fmt)) {}
FmtObjectBase(const FmtObjectBase &that) = delete;
FmtObjectBase(FmtObjectBase &&that)
: fmt(std::move(that.fmt)), context(that.context),
adapters(), // adapters are initialized by FmtObject
replacements(std::move(that.replacements)) {}
void format(llvm::raw_ostream &s) const;
std::string str() const {
std::string result;
llvm::raw_string_ostream s(result);
format(s);
return s.str();
}
template <unsigned N> SmallString<N> sstr() const {
SmallString<N> result;
llvm::raw_svector_ostream s(result);
format(s);
return result;
}
template <unsigned N> operator SmallString<N>() const { return sstr<N>(); }
operator std::string() const { return str(); }
};
template <typename Tuple> class FmtObject : public FmtObjectBase {
// Storage for the parameter adapters. Since the base class erases the type
// of the parameters, we have to own the storage for the parameters here, and
// have the base class store type-erased pointers into this tuple.
Tuple parameters;
public:
FmtObject(StringRef fmt, const FmtContext *ctx, Tuple &&params)
: FmtObjectBase(fmt, ctx, std::tuple_size<Tuple>::value),
parameters(std::move(params)) {
adapters.reserve(std::tuple_size<Tuple>::value);
adapters = llvm::apply_tuple(CreateAdapters(), parameters);
}
FmtObject(FmtObject const &that) = delete;
FmtObject(FmtObject &&that)
: FmtObjectBase(std::move(that)), parameters(std::move(that.parameters)) {
adapters.reserve(that.adapters.size());
adapters = llvm::apply_tuple(CreateAdapters(), parameters);
}
};
class FmtStrVecObject : public FmtObjectBase {
public:
using StrFormatAdapter =
decltype(llvm::detail::build_format_adapter(std::declval<std::string>()));
FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
ArrayRef<std::string> params);
FmtStrVecObject(FmtStrVecObject const &that) = delete;
FmtStrVecObject(FmtStrVecObject &&that);
private:
SmallVector<StrFormatAdapter, 16> parameters;
};
/// Formats text by substituting placeholders in format string with replacement
/// parameters.
///
/// There are two categories of placeholders accepted, both led by a '$' sign:
///
/// 1. Positional placeholder: $[0-9]+
/// 2. Special placeholder: $[a-zA-Z_][a-zA-Z0-9_]*
///
/// Replacement parameters for positional placeholders are supplied as the
/// `vals` parameter pack with 1:1 mapping. That is, $0 will be replaced by the
/// first parameter in `vals`, $1 by the second one, and so on. Note that you
/// can use the positional placeholders in any order and repeat any times, for
/// example, "$2 $1 $1 $0" is accepted.
///
/// Replacement parameters for special placeholders are supplied using the `ctx`
/// format context.
///
/// The `fmt` is recorded as a `StringRef` inside the returned `FmtObject`.
/// The caller needs to make sure the underlying data is available when the
/// `FmtObject` is used.
///
/// `ctx` accepts a nullptr if there is no special placeholder is used.
///
/// If no substitution is provided for a placeholder or any error happens during
/// format string parsing or replacement, the placeholder will be outputted
/// as-is with an additional marker '<no-subst-found>', to aid debugging.
///
/// To print a '$' literally, escape it with '$$'.
///
/// This utility function is inspired by LLVM formatv(), with modifications
/// specially tailored for TableGen C++ generation usage:
///
/// 1. This utility use '$' instead of '{' and '}' for denoting the placeholder
/// because '{' and '}' are frequently used in C++ code.
/// 2. This utility does not support format layout because it is rarely needed
/// in C++ code generation.
template <typename... Ts>
inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals)
-> FmtObject<decltype(std::make_tuple(
llvm::detail::build_format_adapter(std::forward<Ts>(vals))...))> {
using ParamTuple = decltype(std::make_tuple(
llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
return FmtObject<ParamTuple>(
fmt, ctx,
std::make_tuple(
llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
}
inline FmtStrVecObject tgfmt(StringRef fmt, const FmtContext *ctx,
ArrayRef<std::string> params) {
return FmtStrVecObject(fmt, ctx, params);
}
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_FORMAT_H_

View File

@ -0,0 +1,72 @@
//===- GenInfo.h - Generator info -------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_GENINFO_H_
#define MLIR_TABLEGEN_GENINFO_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include <functional>
namespace llvm {
class RecordKeeper;
} // end namespace llvm
namespace mlir {
/// Generator function to invoke.
using GenFunction = std::function<bool(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os)>;
/// Structure to group information about a generator (argument to invoke via
/// mlir-tblgen, description, and generator function).
class GenInfo {
public:
/// GenInfo constructor should not be invoked directly, instead use
/// GenRegistration or registerGen.
GenInfo(StringRef arg, StringRef description, GenFunction generator)
: arg(arg), description(description), generator(generator) {}
/// Invokes the generator and returns whether the generator failed.
bool invoke(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
assert(generator && "Cannot call generator with null generator");
return generator(recordKeeper, os);
}
/// Returns the command line option that may be passed to 'mlir-tblgen' to
/// invoke this generator.
StringRef getGenArgument() const { return arg; }
/// Returns a description for the generator.
StringRef getGenDescription() const { return description; }
private:
// The argument with which to invoke the generator via mlir-tblgen.
StringRef arg;
// Description of the generator.
StringRef description;
// Generator function.
GenFunction generator;
};
/// GenRegistration provides a global initializer that registers a generator
/// function.
///
/// Usage:
///
/// // At namespace scope.
/// static GenRegistration Print("print", "Print records", [](...){...});
struct GenRegistration {
GenRegistration(StringRef arg, StringRef description, GenFunction function);
};
} // end namespace mlir
#endif // MLIR_TABLEGEN_GENINFO_H_

View File

@ -0,0 +1,31 @@
//===- GenNameParser.h - Command line parser for generators -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The GenNameParser class adds all passes linked in to the system that are
// creatable to the tool.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_GENNAMEPARSER_H_
#define MLIR_TABLEGEN_GENNAMEPARSER_H_
#include "llvm/Support/CommandLine.h"
namespace mlir {
class GenInfo;
/// Adds command line option for each registered generator.
struct GenNameParser : public llvm::cl::parser<const GenInfo *> {
GenNameParser(llvm::cl::Option &opt);
void printOptionInfo(const llvm::cl::Option &O,
size_t GlobalWidth) const override;
};
} // end namespace mlir
#endif // MLIR_TABLEGEN_GENNAMEPARSER_H_

View File

@ -0,0 +1,144 @@
//===- Interfaces.cpp - Interface classes ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Interfaces.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// InterfaceMethod
//===----------------------------------------------------------------------===//
InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) {
llvm::DagInit *args = def->getValueAsDag("arguments");
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
arguments.push_back(
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
args->getArgNameStr(i)});
}
}
StringRef InterfaceMethod::getReturnType() const {
return def->getValueAsString("returnType");
}
// Return the name of this method.
StringRef InterfaceMethod::getName() const {
return def->getValueAsString("name");
}
// Return if this method is static.
bool InterfaceMethod::isStatic() const {
return def->isSubClassOf("StaticInterfaceMethod");
}
// Return the body for this method if it has one.
llvm::Optional<StringRef> InterfaceMethod::getBody() const {
auto value = def->getValueAsString("body");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the default implementation for this method if it has one.
llvm::Optional<StringRef> InterfaceMethod::getDefaultImplementation() const {
auto value = def->getValueAsString("defaultBody");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the description of this method if it has one.
llvm::Optional<StringRef> InterfaceMethod::getDescription() const {
auto value = def->getValueAsString("description");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
ArrayRef<InterfaceMethod::Argument> InterfaceMethod::getArguments() const {
return arguments;
}
bool InterfaceMethod::arg_empty() const { return arguments.empty(); }
//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//
Interface::Interface(const llvm::Record *def) : def(def) {
assert(def->isSubClassOf("Interface") &&
"must be subclass of TableGen 'Interface' class");
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
for (llvm::Init *init : listInit->getValues())
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
}
// Return the name of this interface.
StringRef Interface::getName() const {
return def->getValueAsString("cppClassName");
}
// Return the C++ namespace of this interface.
StringRef Interface::getCppNamespace() const {
return def->getValueAsString("cppNamespace");
}
// Return the methods of this interface.
ArrayRef<InterfaceMethod> Interface::getMethods() const { return methods; }
// Return the description of this method if it has one.
llvm::Optional<StringRef> Interface::getDescription() const {
auto value = def->getValueAsString("description");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the interfaces extra class declaration code.
llvm::Optional<StringRef> Interface::getExtraClassDeclaration() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the traits extra class declaration code.
llvm::Optional<StringRef> Interface::getExtraTraitClassDeclaration() const {
auto value = def->getValueAsString("extraTraitClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Return the body for this method if it has one.
llvm::Optional<StringRef> Interface::getVerify() const {
// Only OpInterface supports the verify method.
if (!isa<OpInterface>(this))
return llvm::None;
auto value = def->getValueAsString("verify");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
//===----------------------------------------------------------------------===//
// AttrInterface
//===----------------------------------------------------------------------===//
bool AttrInterface::classof(const Interface *interface) {
return interface->getDef().isSubClassOf("AttrInterface");
}
//===----------------------------------------------------------------------===//
// OpInterface
//===----------------------------------------------------------------------===//
bool OpInterface::classof(const Interface *interface) {
return interface->getDef().isSubClassOf("OpInterface");
}
//===----------------------------------------------------------------------===//
// TypeInterface
//===----------------------------------------------------------------------===//
bool TypeInterface::classof(const Interface *interface) {
return interface->getDef().isSubClassOf("TypeInterface");
}

View File

@ -0,0 +1,129 @@
//===- Interfaces.h - Interface wrapper classes -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_INTERFACES_H_
#define MLIR_TABLEGEN_INTERFACES_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
namespace llvm {
class Init;
class Record;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// Wrapper class with helper methods for accessing InterfaceMethod defined
// in TableGen.
class InterfaceMethod {
public:
// This struct represents a single method argument.
struct Argument {
StringRef type;
StringRef name;
};
explicit InterfaceMethod(const llvm::Record *def);
// Return the return type of this method.
StringRef getReturnType() const;
// Return the name of this method.
StringRef getName() const;
// Return if this method is static.
bool isStatic() const;
// Return the body for this method if it has one.
llvm::Optional<StringRef> getBody() const;
// Return the default implementation for this method if it has one.
llvm::Optional<StringRef> getDefaultImplementation() const;
// Return the description of this method if it has one.
llvm::Optional<StringRef> getDescription() const;
// Arguments.
ArrayRef<Argument> getArguments() const;
bool arg_empty() const;
private:
// The TableGen definition of this method.
const llvm::Record *def;
// The arguments of this method.
SmallVector<Argument, 2> arguments;
};
//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//
// Wrapper class with helper methods for accessing Interfaces defined in
// TableGen.
class Interface {
public:
explicit Interface(const llvm::Record *def);
// Return the name of this interface.
StringRef getName() const;
// Return the C++ namespace of this interface.
StringRef getCppNamespace() const;
// Return the methods of this interface.
ArrayRef<InterfaceMethod> getMethods() const;
// Return the description of this method if it has one.
llvm::Optional<StringRef> getDescription() const;
// Return the interfaces extra class declaration code.
llvm::Optional<StringRef> getExtraClassDeclaration() const;
// Return the traits extra class declaration code.
llvm::Optional<StringRef> getExtraTraitClassDeclaration() const;
// Return the verify method body if it has one.
llvm::Optional<StringRef> getVerify() const;
// Returns the Tablegen definition this interface was constructed from.
const llvm::Record &getDef() const { return *def; }
private:
// The TableGen definition of this interface.
const llvm::Record *def;
// The methods of this interface.
SmallVector<InterfaceMethod, 8> methods;
};
// An interface that is registered to an Attribute.
struct AttrInterface : public Interface {
using Interface::Interface;
static bool classof(const Interface *interface);
};
// An interface that is registered to an Operation.
struct OpInterface : public Interface {
using Interface::Interface;
static bool classof(const Interface *interface);
};
// An interface that is registered to a Type.
struct TypeInterface : public Interface {
using Interface::Interface;
static bool classof(const Interface *interface);
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_INTERFACES_H_

View File

@ -0,0 +1,347 @@
//===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "OpClass.h"
#include "Format.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <unordered_set>
#define DEBUG_TYPE "mlir-tblgen-opclass"
using namespace mlir;
using namespace mlir::tblgen;
namespace {
// Returns space to be emitted after the given C++ `type`. return "" if the
// ends with '&' or '*', or is empty, else returns " ".
StringRef getSpaceAfterType(StringRef type) {
return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
}
} // namespace
//===----------------------------------------------------------------------===//
// OpMethodParameter definitions
//===----------------------------------------------------------------------===//
void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
if (properties & PP_Optional)
os << "/*optional*/";
os << type << getSpaceAfterType(type) << name;
if (emitDefault && !defaultValue.empty())
os << " = " << defaultValue;
}
//===----------------------------------------------------------------------===//
// OpMethodParameters definitions
//===----------------------------------------------------------------------===//
// Factory methods to construct the correct type of `OpMethodParameters`
// object based on the arguments.
std::unique_ptr<OpMethodParameters> OpMethodParameters::create() {
return std::make_unique<OpMethodResolvedParameters>();
}
std::unique_ptr<OpMethodParameters>
OpMethodParameters::create(StringRef params) {
return std::make_unique<OpMethodUnresolvedParameters>(params);
}
std::unique_ptr<OpMethodParameters>
OpMethodParameters::create(llvm::SmallVectorImpl<OpMethodParameter> &&params) {
return std::make_unique<OpMethodResolvedParameters>(std::move(params));
}
std::unique_ptr<OpMethodParameters>
OpMethodParameters::create(StringRef type, StringRef name,
StringRef defaultValue) {
return std::make_unique<OpMethodResolvedParameters>(type, name, defaultValue);
}
//===----------------------------------------------------------------------===//
// OpMethodUnresolvedParameters definitions
//===----------------------------------------------------------------------===//
void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const {
os << parameters;
}
void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const {
// We need to remove the default values for parameters in method definition.
// TODO: We are using '=' and ',' as delimiters for parameter
// initializers. This is incorrect for initializer list with more than one
// element. Change to a more robust approach.
llvm::SmallVector<StringRef, 4> tokens;
StringRef params = parameters;
while (!params.empty()) {
std::pair<StringRef, StringRef> parts = params.split("=");
tokens.push_back(parts.first);
params = parts.second.split(',').second;
}
llvm::interleaveComma(tokens, os, [&](StringRef token) { os << token; });
}
//===----------------------------------------------------------------------===//
// OpMethodResolvedParameters definitions
//===----------------------------------------------------------------------===//
// Returns true if a method with these parameters makes a method with parameters
// `other` redundant. This should return true only if all possible calls to the
// other method can be replaced by calls to this method.
bool OpMethodResolvedParameters::makesRedundant(
const OpMethodResolvedParameters &other) const {
const size_t otherNumParams = other.getNumParameters();
const size_t thisNumParams = getNumParameters();
// All calls to the other method can be replaced this method only if this
// method has the same or more arguments number of arguments as the other, and
// the common arguments have the same type.
if (thisNumParams < otherNumParams)
return false;
for (int idx : llvm::seq<int>(0, otherNumParams))
if (parameters[idx].getType() != other.parameters[idx].getType())
return false;
// If all the common arguments have the same type, we can elide the other
// method if this method has the same number of arguments as other or the
// first argument after the common ones has a default value (and by C++
// requirement, all the later ones will also have a default value).
return thisNumParams == otherNumParams ||
parameters[otherNumParams].hasDefaultValue();
}
void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const {
llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) {
param.writeDeclTo(os);
});
}
void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const {
llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) {
param.writeDefTo(os);
});
}
//===----------------------------------------------------------------------===//
// OpMethodSignature definitions
//===----------------------------------------------------------------------===//
// Returns if a method with this signature makes a method with `other` signature
// redundant. Only supports resolved parameters.
bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const {
if (methodName != other.methodName)
return false;
auto *resolvedThis = dyn_cast<OpMethodResolvedParameters>(parameters.get());
auto *resolvedOther =
dyn_cast<OpMethodResolvedParameters>(other.parameters.get());
if (resolvedThis && resolvedOther)
return resolvedThis->makesRedundant(*resolvedOther);
return false;
}
void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
os << returnType << getSpaceAfterType(returnType) << methodName << "(";
parameters->writeDeclTo(os);
os << ")";
}
void OpMethodSignature::writeDefTo(raw_ostream &os,
StringRef namePrefix) const {
os << returnType << getSpaceAfterType(returnType) << namePrefix
<< (namePrefix.empty() ? "" : "::") << methodName << "(";
parameters->writeDefTo(os);
os << ")";
}
//===----------------------------------------------------------------------===//
// OpMethodBody definitions
//===----------------------------------------------------------------------===//
OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
OpMethodBody &OpMethodBody::operator<<(Twine content) {
if (isEffective)
body.append(content.str());
return *this;
}
OpMethodBody &OpMethodBody::operator<<(int content) {
if (isEffective)
body.append(std::to_string(content));
return *this;
}
OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
if (isEffective)
body.append(content.str());
return *this;
}
void OpMethodBody::writeTo(raw_ostream &os) const {
auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
os << bodyRef;
if (bodyRef.empty() || bodyRef.back() != '\n')
os << "\n";
}
//===----------------------------------------------------------------------===//
// OpMethod definitions
//===----------------------------------------------------------------------===//
void OpMethod::writeDeclTo(raw_ostream &os) const {
os.indent(2);
if (isStatic())
os << "static ";
if (properties & MP_Constexpr)
os << "constexpr ";
methodSignature.writeDeclTo(os);
if (!isInline())
os << ";";
else {
os << " {\n";
methodBody.writeTo(os);
os << "}";
}
}
void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
// Do not write definition if the method is decl only.
if (properties & MP_Declaration)
return;
// Do not generate separate definition for inline method
if (isInline())
return;
methodSignature.writeDefTo(os, namePrefix);
os << " {\n";
methodBody.writeTo(os);
os << "}";
}
//===----------------------------------------------------------------------===//
// OpConstructor definitions
//===----------------------------------------------------------------------===//
void OpConstructor::addMemberInitializer(StringRef name, StringRef value) {
memberInitializers.append(std::string(llvm::formatv(
"{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
}
void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
// Do not write definition if the method is decl only.
if (properties & MP_Declaration)
return;
methodSignature.writeDefTo(os, namePrefix);
os << " " << memberInitializers << " {\n";
methodBody.writeTo(os);
os << "}";
}
//===----------------------------------------------------------------------===//
// Class definitions
//===----------------------------------------------------------------------===//
Class::Class(StringRef name) : className(name) {}
void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
std::string varName = formatv("{0} {1}", type, name).str();
std::string field = defaultValue.empty()
? varName
: formatv("{0} = {1}", varName, defaultValue).str();
fields.push_back(std::move(field));
}
void Class::writeDeclTo(raw_ostream &os) const {
bool hasPrivateMethod = false;
os << "class " << className << " {\n";
os << "public:\n";
forAllMethods([&](const OpMethod &method) {
if (!method.isPrivate()) {
method.writeDeclTo(os);
os << '\n';
} else {
hasPrivateMethod = true;
}
});
os << '\n';
os << "private:\n";
if (hasPrivateMethod) {
forAllMethods([&](const OpMethod &method) {
if (method.isPrivate()) {
method.writeDeclTo(os);
os << '\n';
}
});
os << '\n';
}
for (const auto &field : fields)
os.indent(2) << field << ";\n";
os << "};\n";
}
void Class::writeDefTo(raw_ostream &os) const {
forAllMethods([&](const OpMethod &method) {
method.writeDefTo(os, className);
os << "\n\n";
});
}
//===----------------------------------------------------------------------===//
// OpClass definitions
//===----------------------------------------------------------------------===//
OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
: Class(name), extraClassDeclaration(extraClassDeclaration) {}
void OpClass::addTrait(Twine trait) {
auto traitStr = trait.str();
if (traitsSet.insert(traitStr).second)
traitsVec.push_back(std::move(traitStr));
}
void OpClass::writeDeclTo(raw_ostream &os) const {
os << "class " << className << " : public ::mlir::Op<" << className;
for (const auto &trait : traitsVec)
os << ", " << trait;
os << "> {\npublic:\n";
// << " using Op::Op;\n"
// << " using Op::print;\n"
// << " using Adaptor = " << className << "Adaptor;\n";
bool hasPrivateMethod = false;
forAllMethods([&](const OpMethod &method) {
if (!method.isPrivate()) {
method.writeDeclTo(os);
os << "\n";
} else {
hasPrivateMethod = true;
}
});
// TODO: Add line control markers to make errors easier to debug.
if (!extraClassDeclaration.empty())
os << extraClassDeclaration << "\n";
if (hasPrivateMethod) {
os << "\nprivate:\n";
forAllMethods([&](const OpMethod &method) {
if (method.isPrivate()) {
method.writeDeclTo(os);
os << "\n";
}
});
}
os << "};\n";
}

View File

@ -0,0 +1,442 @@
//===- OpClass.h - Helper classes for Op C++ code emission ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines several classes for Op C++ code emission. They are only
// expected to be used by MLIR TableGen backends.
//
// We emit the op declaration and definition into separate files: *Ops.h.inc
// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
//
// In order to do this split, we need to track method signature and
// implementation logic separately. Signature information is used for both
// declaration and definition, while implementation logic is only for
// definition. So we have the following classes for C++ code emission.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_OPCLASS_H_
#define MLIR_TABLEGEN_OPCLASS_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/raw_ostream.h"
#include <set>
#include <string>
namespace mlir {
namespace tblgen {
class FmtObjectBase;
// Class for holding a single parameter of an op's method for C++ code emission.
class OpMethodParameter {
public:
// Properties (qualifiers) for the parameter.
enum Property {
PP_None = 0x0,
PP_Optional = 0x1,
};
OpMethodParameter(StringRef type, StringRef name, StringRef defaultValue = "",
Property properties = PP_None)
: type(type), name(name), defaultValue(defaultValue),
properties(properties) {}
OpMethodParameter(StringRef type, StringRef name, Property property)
: OpMethodParameter(type, name, "", property) {}
// Writes the parameter as a part of a method declaration to `os`.
void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); }
// Writes the parameter as a part of a method definition to `os`
void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
const std::string &getType() const { return type; }
bool hasDefaultValue() const { return !defaultValue.empty(); }
private:
void writeTo(raw_ostream &os, bool emitDefault) const;
std::string type;
std::string name;
std::string defaultValue;
Property properties;
};
// Base class for holding parameters of an op's method for C++ code emission.
class OpMethodParameters {
public:
// Discriminator for LLVM-style RTTI.
enum ParamsKind {
// Separate type and name for each parameter is not known.
PK_Unresolved,
// Each parameter is resolved to a type and name.
PK_Resolved,
};
OpMethodParameters(ParamsKind kind) : kind(kind) {}
virtual ~OpMethodParameters() {}
// LLVM-style RTTI support.
ParamsKind getKind() const { return kind; }
// Writes the parameters as a part of a method declaration to `os`.
virtual void writeDeclTo(raw_ostream &os) const = 0;
// Writes the parameters as a part of a method definition to `os`
virtual void writeDefTo(raw_ostream &os) const = 0;
// Factory methods to create the correct type of `OpMethodParameters`
// object based on the arguments.
static std::unique_ptr<OpMethodParameters> create();
static std::unique_ptr<OpMethodParameters> create(StringRef params);
static std::unique_ptr<OpMethodParameters>
create(llvm::SmallVectorImpl<OpMethodParameter> &&params);
static std::unique_ptr<OpMethodParameters>
create(StringRef type, StringRef name, StringRef defaultValue = "");
private:
const ParamsKind kind;
};
// Class for holding unresolved parameters.
class OpMethodUnresolvedParameters : public OpMethodParameters {
public:
OpMethodUnresolvedParameters(StringRef params)
: OpMethodParameters(PK_Unresolved), parameters(params) {}
// write the parameters as a part of a method declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const override;
// write the parameters as a part of a method definition to the given `os`
void writeDefTo(raw_ostream &os) const override;
// LLVM-style RTTI support.
static bool classof(const OpMethodParameters *params) {
return params->getKind() == PK_Unresolved;
}
private:
std::string parameters;
};
// Class for holding resolved parameters.
class OpMethodResolvedParameters : public OpMethodParameters {
public:
OpMethodResolvedParameters() : OpMethodParameters(PK_Resolved) {}
OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> &&params)
: OpMethodParameters(PK_Resolved) {
for (OpMethodParameter &param : params)
parameters.emplace_back(std::move(param));
}
OpMethodResolvedParameters(StringRef type, StringRef name,
StringRef defaultValue)
: OpMethodParameters(PK_Resolved) {
parameters.emplace_back(type, name, defaultValue);
}
// Returns the number of parameters.
size_t getNumParameters() const { return parameters.size(); }
// Returns if this method makes the `other` method redundant. Note that this
// is more than just finding conflicting methods. This method determines if
// the 2 set of parameters are conflicting and if so, returns true if this
// method has a more general set of parameters that can replace all possible
// calls to the `other` method.
bool makesRedundant(const OpMethodResolvedParameters &other) const;
// write the parameters as a part of a method declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const override;
// write the parameters as a part of a method definition to the given `os`
void writeDefTo(raw_ostream &os) const override;
// LLVM-style RTTI support.
static bool classof(const OpMethodParameters *params) {
return params->getKind() == PK_Resolved;
}
private:
llvm::SmallVector<OpMethodParameter, 4> parameters;
};
// Class for holding the signature of an op's method for C++ code emission
class OpMethodSignature {
public:
template <typename... Args>
OpMethodSignature(StringRef retType, StringRef name, Args &&...args)
: returnType(retType), methodName(name),
parameters(OpMethodParameters::create(std::forward<Args>(args)...)) {}
OpMethodSignature(OpMethodSignature &&) = default;
// Returns if a method with this signature makes a method with `other`
// signature redundant. Only supports resolved parameters.
bool makesRedundant(const OpMethodSignature &other) const;
// Returns the number of parameters (for resolved parameters).
size_t getNumParameters() const {
return cast<OpMethodResolvedParameters>(parameters.get())
->getNumParameters();
}
// Returns the name of the method.
StringRef getName() const { return methodName; }
// Writes the signature as a method declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const;
// Writes the signature as the start of a method definition to the given `os`.
// `namePrefix` is the prefix to be prepended to the method name (typically
// namespaces for qualifying the method definition).
void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
private:
std::string returnType;
std::string methodName;
std::unique_ptr<OpMethodParameters> parameters;
};
// Class for holding the body of an op's method for C++ code emission
class OpMethodBody {
public:
explicit OpMethodBody(bool declOnly);
OpMethodBody &operator<<(Twine content);
OpMethodBody &operator<<(int content);
OpMethodBody &operator<<(const FmtObjectBase &content);
void writeTo(raw_ostream &os) const;
private:
// Whether this class should record method body.
bool isEffective;
std::string body;
};
// Class for holding an op's method for C++ code emission
class OpMethod {
public:
// Properties (qualifiers) of class methods. Bitfield is used here to help
// querying properties.
enum Property {
MP_None = 0x0,
MP_Static = 0x1,
MP_Constructor = 0x2,
MP_Private = 0x4,
MP_Declaration = 0x8,
MP_Inline = 0x10,
MP_Constexpr = 0x20 | MP_Inline,
MP_StaticDeclaration = MP_Static | MP_Declaration,
};
template <typename... Args>
OpMethod(StringRef retType, StringRef name, Property property, unsigned id,
Args &&...args)
: properties(property),
methodSignature(retType, name, std::forward<Args>(args)...),
methodBody(properties & MP_Declaration), id(id) {}
OpMethod(OpMethod &&) = default;
virtual ~OpMethod() = default;
OpMethodBody &body() { return methodBody; }
// Returns true if this is a static method.
bool isStatic() const { return properties & MP_Static; }
// Returns true if this is a private method.
bool isPrivate() const { return properties & MP_Private; }
// Returns true if this is an inline method.
bool isInline() const { return properties & MP_Inline; }
// Returns the name of this method.
StringRef getName() const { return methodSignature.getName(); }
// Returns the ID for this method
unsigned getID() const { return id; }
// Returns if this method makes the `other` method redundant.
bool makesRedundant(const OpMethod &other) const {
return methodSignature.makesRedundant(other.methodSignature);
}
// Writes the method as a declaration to the given `os`.
virtual void writeDeclTo(raw_ostream &os) const;
// Writes the method as a definition to the given `os`. `namePrefix` is the
// prefix to be prepended to the method name (typically namespaces for
// qualifying the method definition).
virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
protected:
Property properties;
OpMethodSignature methodSignature;
OpMethodBody methodBody;
const unsigned id;
};
// Class for holding an op's constructor method for C++ code emission.
class OpConstructor : public OpMethod {
public:
template <typename... Args>
OpConstructor(StringRef className, Property property, unsigned id,
Args &&...args)
: OpMethod("", className, property, id, std::forward<Args>(args)...) {}
// Add member initializer to constructor initializing `name` with `value`.
void addMemberInitializer(StringRef name, StringRef value);
// Writes the method as a definition to the given `os`. `namePrefix` is the
// prefix to be prepended to the method name (typically namespaces for
// qualifying the method definition).
void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
private:
// Member initializers.
std::string memberInitializers;
};
// A class used to emit C++ classes from Tablegen. Contains a list of public
// methods and a list of private fields to be emitted.
class Class {
public:
explicit Class(StringRef name);
// Adds a new method to this class and prune redundant methods. Returns null
// if the method was not added (because an existing method would make it
// redundant), else returns a pointer to the added method. Note that this call
// may also delete existing methods that are made redundant by a method to the
// class.
template <typename... Args>
OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
OpMethod::Property properties, Args &&...args) {
auto newMethod = std::make_unique<OpMethod>(
retType, name, properties, nextMethodID++, std::forward<Args>(args)...);
return addMethodAndPrune(methods, std::move(newMethod));
}
template <typename... Args>
OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
Args &&...args) {
return addMethodAndPrune(retType, name, OpMethod::MP_None,
std::forward<Args>(args)...);
}
template <typename... Args>
OpConstructor *addConstructorAndPrune(Args &&...args) {
auto newConstructor = std::make_unique<OpConstructor>(
getClassName(), OpMethod::MP_Constructor, nextMethodID++,
std::forward<Args>(args)...);
return addMethodAndPrune(constructors, std::move(newConstructor));
}
// Creates a new field in this class.
void newField(StringRef type, StringRef name, StringRef defaultValue = "");
// Writes this op's class as a declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const;
// Writes the method definitions in this op's class to the given `os`.
void writeDefTo(raw_ostream &os) const;
// Returns the C++ class name of the op.
StringRef getClassName() const { return className; }
protected:
// Get a list of all the methods to emit, filtering out hidden ones.
void forAllMethods(llvm::function_ref<void(const OpMethod &)> func) const {
using ConsRef = const std::unique_ptr<OpConstructor> &;
using MethodRef = const std::unique_ptr<OpMethod> &;
llvm::for_each(constructors, [&](ConsRef ptr) { func(*ptr); });
llvm::for_each(methods, [&](MethodRef ptr) { func(*ptr); });
}
// For deterministic code generation, keep methods sorted in the order in
// which they were generated.
template <typename MethodTy>
struct MethodCompare {
bool operator()(const std::unique_ptr<MethodTy> &x,
const std::unique_ptr<MethodTy> &y) const {
return x->getID() < y->getID();
}
};
template <typename MethodTy>
using MethodSet =
std::set<std::unique_ptr<MethodTy>, MethodCompare<MethodTy>>;
template <typename MethodTy>
MethodTy *addMethodAndPrune(MethodSet<MethodTy> &set,
std::unique_ptr<MethodTy> &&newMethod) {
// Check if the new method will be made redundant by existing methods.
for (auto &method : set)
if (method->makesRedundant(*newMethod))
return nullptr;
// We can add this a method to the set. Prune any existing methods that will
// be made redundant by adding this new method. Note that the redundant
// check between two methods is more than a conflict check. makesRedundant()
// below will check if the new method conflicts with an existing method and
// if so, returns true if the new method makes the existing method redundant
// because all calls to the existing method can be subsumed by the new
// method. So makesRedundant() does a combined job of finding conflicts and
// deciding which of the 2 conflicting methods survive.
//
// Note: llvm::erase_if does not work with sets of std::unique_ptr, so doing
// it manually here.
for (auto it = set.begin(), end = set.end(); it != end;) {
if (newMethod->makesRedundant(*(it->get())))
it = set.erase(it);
else
++it;
}
MethodTy *ret = newMethod.get();
set.insert(std::move(newMethod));
return ret;
}
std::string className;
MethodSet<OpConstructor> constructors;
MethodSet<OpMethod> methods;
unsigned nextMethodID = 0;
SmallVector<std::string, 4> fields;
};
// Class for holding an op for C++ code emission
class OpClass : public Class {
public:
explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
// Adds an op trait.
void addTrait(Twine trait);
// Writes this op's class as a declaration to the given `os`. Redefines
// Class::writeDeclTo to also emit traits and extra class declarations.
void writeDeclTo(raw_ostream &os) const;
private:
StringRef extraClassDeclaration;
SmallVector<std::string, 4> traitsVec;
StringSet<> traitsSet;
};
} // namespace tblgen
} // namespace mlir
#endif // MLIR_TABLEGEN_OPCLASS_H_

View File

@ -0,0 +1,592 @@
//===- Operator.cpp - Operator class --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
//
//===----------------------------------------------------------------------===//
#include "Operator.h"
#include "Predicate.h"
#include "Trait.h"
#include "Type.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#define DEBUG_TYPE "mlir-tblgen-operator"
using namespace mlir;
using namespace mlir::tblgen;
using llvm::DagInit;
using llvm::DefInit;
using llvm::Record;
Operator::Operator(const llvm::Record &def)
: dialect(def.getValueAsDef("opDialect")), def(def) {
// The first `_` in the op's TableGen def name is treated as separating the
// dialect prefix and the op class name. The dialect prefix will be ignored if
// not empty. Otherwise, if def name starts with a `_`, the `_` is considered
// as part of the class name.
StringRef prefix;
std::tie(prefix, cppClassName) = def.getName().split('_');
if (prefix.empty()) {
// Class name with a leading underscore and without dialect prefix
cppClassName = def.getName();
} else if (cppClassName.empty()) {
// Class name without dialect prefix
cppClassName = prefix;
}
cppNamespace = def.getValueAsString("cppNamespace");
populateOpStructure();
}
std::string Operator::getOperationName() const {
auto prefix = dialect.getName();
auto opName = def.getValueAsString("opName");
if (prefix.empty())
return std::string(opName);
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
}
std::string Operator::getAdaptorName() const {
return std::string(llvm::formatv("{0}Adaptor", getCppClassName()));
}
StringRef Operator::getDialectName() const { return dialect.getName(); }
StringRef Operator::getCppClassName() const { return cppClassName; }
std::string Operator::getQualCppClassName() const {
if (cppNamespace.empty())
return std::string(cppClassName);
return std::string(llvm::formatv("{0}::{1}", cppNamespace, cppClassName));
}
StringRef Operator::getCppNamespace() const { return cppNamespace; }
int Operator::getNumResults() const {
DagInit *results = def.getValueAsDag("results");
return results->getNumArgs();
}
StringRef Operator::getExtraClassDeclaration() const {
constexpr auto attr = "extraClassDeclaration";
if (def.isValueUnset(attr))
return {};
return def.getValueAsString(attr);
}
const llvm::Record &Operator::getDef() const { return def; }
bool Operator::skipDefaultBuilders() const {
return def.getValueAsBit("skipDefaultBuilders");
}
auto Operator::result_begin() -> value_iterator { return results.begin(); }
auto Operator::result_end() -> value_iterator { return results.end(); }
auto Operator::getResults() -> value_range {
return {result_begin(), result_end()};
}
TypeConstraint Operator::getResultTypeConstraint(int index) const {
DagInit *results = def.getValueAsDag("results");
return TypeConstraint(cast<DefInit>(results->getArg(index)));
}
StringRef Operator::getResultName(int index) const {
DagInit *results = def.getValueAsDag("results");
return results->getArgNameStr(index);
}
auto Operator::getResultDecorators(int index) const -> var_decorator_range {
Record *result =
cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef();
if (!result->isSubClassOf("OpVariable"))
return var_decorator_range(nullptr, nullptr);
return *result->getValueAsListInit("decorators");
}
unsigned Operator::getNumVariableLengthResults() const {
return llvm::count_if(results, [](const NamedTypeConstraint &c) {
return c.constraint.isVariableLength();
});
}
unsigned Operator::getNumVariableLengthOperands() const {
return llvm::count_if(operands, [](const NamedTypeConstraint &c) {
return c.constraint.isVariableLength();
});
}
bool Operator::hasSingleVariadicArg() const {
return getNumArgs() == 1 && getArg(0).is<NamedTypeConstraint *>() &&
getOperand(0).isVariadic();
}
Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); }
Operator::arg_iterator Operator::arg_end() const { return arguments.end(); }
Operator::arg_range Operator::getArgs() const {
return {arg_begin(), arg_end()};
}
StringRef Operator::getArgName(int index) const {
DagInit *argumentValues = def.getValueAsDag("arguments");
return argumentValues->getArgNameStr(index);
}
auto Operator::getArgDecorators(int index) const -> var_decorator_range {
Record *arg =
cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef();
if (!arg->isSubClassOf("OpVariable"))
return var_decorator_range(nullptr, nullptr);
return *arg->getValueAsListInit("decorators");
}
const Trait *Operator::getTrait(StringRef trait) const {
for (const auto &t : traits) {
if (const auto *traitDef = dyn_cast<NativeTrait>(&t)) {
if (traitDef->getFullyQualifiedTraitName() == trait)
return traitDef;
} else if (const auto *traitDef = dyn_cast<InternalTrait>(&t)) {
if (traitDef->getFullyQualifiedTraitName() == trait)
return traitDef;
} else if (const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) {
if (traitDef->getFullyQualifiedTraitName() == trait)
return traitDef;
}
}
return nullptr;
}
auto Operator::region_begin() const -> const_region_iterator {
return regions.begin();
}
auto Operator::region_end() const -> const_region_iterator {
return regions.end();
}
auto Operator::getRegions() const
-> llvm::iterator_range<const_region_iterator> {
return {region_begin(), region_end()};
}
unsigned Operator::getNumRegions() const { return regions.size(); }
const NamedRegion &Operator::getRegion(unsigned index) const {
return regions[index];
}
unsigned Operator::getNumVariadicRegions() const {
return llvm::count_if(regions,
[](const NamedRegion &c) { return c.isVariadic(); });
}
auto Operator::successor_begin() const -> const_successor_iterator {
return successors.begin();
}
auto Operator::successor_end() const -> const_successor_iterator {
return successors.end();
}
auto Operator::getSuccessors() const
-> llvm::iterator_range<const_successor_iterator> {
return {successor_begin(), successor_end()};
}
unsigned Operator::getNumSuccessors() const { return successors.size(); }
const NamedSuccessor &Operator::getSuccessor(unsigned index) const {
return successors[index];
}
unsigned Operator::getNumVariadicSuccessors() const {
return llvm::count_if(successors,
[](const NamedSuccessor &c) { return c.isVariadic(); });
}
auto Operator::trait_begin() const -> const_trait_iterator {
return traits.begin();
}
auto Operator::trait_end() const -> const_trait_iterator {
return traits.end();
}
auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> {
return {trait_begin(), trait_end()};
}
auto Operator::attribute_begin() const -> attribute_iterator {
return attributes.begin();
}
auto Operator::attribute_end() const -> attribute_iterator {
return attributes.end();
}
auto Operator::getAttributes() const
-> llvm::iterator_range<attribute_iterator> {
return {attribute_begin(), attribute_end()};
}
auto Operator::operand_begin() -> value_iterator { return operands.begin(); }
auto Operator::operand_end() -> value_iterator { return operands.end(); }
auto Operator::getOperands() -> value_range {
return {operand_begin(), operand_end()};
}
auto Operator::getArg(int index) const -> Argument { return arguments[index]; }
// Mapping from result index to combined argument and result index. Arguments
// are indexed to match getArg index, while the result indexes are mapped to
// avoid overlap.
static int resultIndex(int i) { return -1 - i; }
bool Operator::isVariadic() const {
return any_of(llvm::concat<const NamedTypeConstraint>(operands, results),
[](const NamedTypeConstraint &op) { return op.isVariadic(); });
}
void Operator::populateTypeInferenceInfo(
const llvm::StringMap<int> &argumentsAndResultsIndex) {
// If the type inference op interface is not registered, then do not attempt
// to determine if the result types an be inferred.
auto &recordKeeper = def.getRecords();
auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface);
allResultsHaveKnownTypes = false;
if (!inferTrait)
return;
// If there are no results, the skip this else the build method generated
// overlaps with another autogenerated builder.
if (getNumResults() == 0)
return;
// Skip for ops with variadic operands/results.
// TODO: This can be relaxed.
if (isVariadic())
return;
// Skip cases currently being custom generated.
// TODO: Remove special cases.
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType"))
return;
// We create equivalence classes of argument/result types where arguments
// and results are mapped into the same index space and indices corresponding
// to the same type are in the same equivalence class.
llvm::EquivalenceClasses<int> ecs;
resultTypeMapping.resize(getNumResults());
// Captures the argument whose type matches a given result type. Preference
// towards capturing operands first before attributes.
auto captureMapping = [&](int i) {
bool found = false;
ecs.insert(resultIndex(i));
auto mi = ecs.findLeader(resultIndex(i));
for (auto me = ecs.member_end(); mi != me; ++mi) {
if (*mi < 0) {
auto tc = getResultTypeConstraint(i);
if (tc.getBuilderCall().hasValue()) {
resultTypeMapping[i].emplace_back(tc);
found = true;
}
continue;
}
if (getArg(*mi).is<NamedAttribute *>()) {
// TODO: Handle attributes.
continue;
} else {
resultTypeMapping[i].emplace_back(*mi);
found = true;
}
}
return found;
};
for (const Trait &trait : traits) {
const llvm::Record &def = trait.getDef();
// If the infer type op interface was manually added, then treat it as
// intention that the op needs special handling.
// TODO: Reconsider whether to always generate, this is more conservative
// and keeps existing behavior so starting that way for now.
if (def.isSubClassOf(
llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
return;
if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait))
if (&traitDef->getDef() == inferTrait)
return;
if (!def.isSubClassOf("AllTypesMatch"))
continue;
auto values = def.getValueAsListOfStrings("values");
auto root = argumentsAndResultsIndex.lookup(values.front());
for (StringRef str : values)
ecs.unionSets(argumentsAndResultsIndex.lookup(str), root);
}
// Verifies that all output types have a corresponding known input type
// and chooses matching operand or attribute (in that order) that
// matches it.
allResultsHaveKnownTypes =
all_of(llvm::seq<int>(0, getNumResults()), captureMapping);
// If the types could be computed, then add type inference trait.
if (allResultsHaveKnownTypes)
traits.push_back(Trait::create(inferTrait->getDefInit()));
}
void Operator::populateOpStructure() {
auto &recordKeeper = def.getRecords();
auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint");
auto *attrClass = recordKeeper.getClass("Attr");
auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr");
auto *opVarClass = recordKeeper.getClass("OpVariable");
numNativeAttributes = 0;
DagInit *argumentValues = def.getValueAsDag("arguments");
unsigned numArgs = argumentValues->getNumArgs();
// Mapping from name of to argument or result index. Arguments are indexed
// to match getArg index, while the results are negatively indexed.
llvm::StringMap<int> argumentsAndResultsIndex;
// Handle operands and native attributes.
for (unsigned i = 0; i != numArgs; ++i) {
auto *arg = argumentValues->getArg(i);
auto givenName = argumentValues->getArgNameStr(i);
auto *argDefInit = dyn_cast<DefInit>(arg);
if (!argDefInit)
PrintFatalError(def.getLoc(),
Twine("undefined type for argument #") + Twine(i));
Record *argDef = argDefInit->getDef();
if (argDef->isSubClassOf(opVarClass))
argDef = argDef->getValueAsDef("constraint");
if (argDef->isSubClassOf(typeConstraintClass)) {
operands.push_back(
NamedTypeConstraint{givenName, TypeConstraint(argDef)});
} else if (argDef->isSubClassOf(attrClass)) {
if (givenName.empty())
PrintFatalError(argDef->getLoc(), "attributes must be named");
if (argDef->isSubClassOf(derivedAttrClass))
PrintFatalError(argDef->getLoc(),
"derived attributes not allowed in argument list");
attributes.push_back({givenName, Attribute(argDef)});
++numNativeAttributes;
} else {
PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
"from TypeConstraint or Attr are allowed");
}
if (!givenName.empty())
argumentsAndResultsIndex[givenName] = i;
}
// Handle derived attributes.
for (const auto &val : def.getValues()) {
if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
if (!record->isSubClassOf(attrClass))
continue;
if (!record->isSubClassOf(derivedAttrClass))
PrintFatalError(def.getLoc(),
"unexpected Attr where only DerivedAttr is allowed");
if (record->getClasses().size() != 1) {
PrintFatalError(
def.getLoc(),
"unsupported attribute modelling, only single class expected");
}
attributes.push_back(
{cast<llvm::StringInit>(val.getNameInit())->getValue(),
Attribute(cast<DefInit>(val.getValue()))});
}
}
// Populate `arguments`. This must happen after we've finalized `operands` and
// `attributes` because we will put their elements' pointers in `arguments`.
// SmallVector may perform re-allocation under the hood when adding new
// elements.
int operandIndex = 0, attrIndex = 0;
for (unsigned i = 0; i != numArgs; ++i) {
Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
if (argDef->isSubClassOf(opVarClass))
argDef = argDef->getValueAsDef("constraint");
if (argDef->isSubClassOf(typeConstraintClass)) {
attrOrOperandMapping.push_back(
{OperandOrAttribute::Kind::Operand, operandIndex});
arguments.emplace_back(&operands[operandIndex++]);
} else {
assert(argDef->isSubClassOf(attrClass));
attrOrOperandMapping.push_back(
{OperandOrAttribute::Kind::Attribute, attrIndex});
arguments.emplace_back(&attributes[attrIndex++]);
}
}
auto *resultsDag = def.getValueAsDag("results");
auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
if (!outsOp || outsOp->getDef()->getName() != "outs") {
PrintFatalError(def.getLoc(), "'results' must have 'outs' directive");
}
// Handle results.
for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
auto name = resultsDag->getArgNameStr(i);
auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
if (!resultInit) {
PrintFatalError(def.getLoc(),
Twine("undefined type for result #") + Twine(i));
}
auto *resultDef = resultInit->getDef();
if (resultDef->isSubClassOf(opVarClass))
resultDef = resultDef->getValueAsDef("constraint");
results.push_back({name, TypeConstraint(resultDef)});
if (!name.empty())
argumentsAndResultsIndex[name] = resultIndex(i);
}
// Handle successors
auto *successorsDag = def.getValueAsDag("successors");
auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator());
if (!successorsOp || successorsOp->getDef()->getName() != "successor") {
PrintFatalError(def.getLoc(),
"'successors' must have 'successor' directive");
}
for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
auto name = successorsDag->getArgNameStr(i);
auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i));
if (!successorInit) {
PrintFatalError(def.getLoc(),
Twine("undefined kind for successor #") + Twine(i));
}
Successor successor(successorInit->getDef());
// Only support variadic successors if it is the last one for now.
if (i != e - 1 && successor.isVariadic())
PrintFatalError(def.getLoc(), "only the last successor can be variadic");
successors.push_back({name, successor});
}
// Create list of traits, skipping over duplicates: appending to lists in
// tablegen is easy, making them unique less so, so dedupe here.
if (auto *traitList = def.getValueAsListInit("traits")) {
// This is uniquing based on pointers of the trait.
SmallPtrSet<const llvm::Init *, 32> traitSet;
traits.reserve(traitSet.size());
for (auto *traitInit : *traitList) {
// Keep traits in the same order while skipping over duplicates.
if (traitSet.insert(traitInit).second)
traits.push_back(Trait::create(traitInit));
}
}
populateTypeInferenceInfo(argumentsAndResultsIndex);
// Handle regions
auto *regionsDag = def.getValueAsDag("regions");
auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
if (!regionsOp || regionsOp->getDef()->getName() != "region") {
PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
}
for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
auto name = regionsDag->getArgNameStr(i);
auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
if (!regionInit) {
PrintFatalError(def.getLoc(),
Twine("undefined kind for region #") + Twine(i));
}
Region region(regionInit->getDef());
if (region.isVariadic()) {
// Only support variadic regions if it is the last one for now.
if (i != e - 1)
PrintFatalError(def.getLoc(), "only the last region can be variadic");
if (name.empty())
PrintFatalError(def.getLoc(), "variadic regions must be named");
}
regions.push_back({name, region});
}
// Populate the builders.
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (llvm::Init *init : builderList->getValues())
builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc());
} else if (skipDefaultBuilders()) {
PrintFatalError(
def.getLoc(),
"default builders are skipped and no custom builders provided");
}
LLVM_DEBUG(print(llvm::dbgs()));
}
auto Operator::getSameTypeAsResult(int index) const -> ArrayRef<ArgOrType> {
assert(allResultTypesKnown());
return resultTypeMapping[index];
}
ArrayRef<llvm::SMLoc> Operator::getLoc() const { return def.getLoc(); }
bool Operator::hasDescription() const {
return def.getValue("description") != nullptr;
}
StringRef Operator::getDescription() const {
return def.getValueAsString("description");
}
bool Operator::hasSummary() const { return def.getValue("summary") != nullptr; }
StringRef Operator::getSummary() const {
return def.getValueAsString("summary");
}
bool Operator::hasAssemblyFormat() const {
auto *valueInit = def.getValueInit("assemblyFormat");
return isa<llvm::StringInit>(valueInit);
}
StringRef Operator::getAssemblyFormat() const {
return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
.Case<llvm::StringInit>(
[&](auto *init) { return init->getValue(); });
}
void Operator::print(llvm::raw_ostream &os) const {
os << "op '" << getOperationName() << "'\n";
for (Argument arg : arguments) {
if (auto *attr = arg.dyn_cast<NamedAttribute *>())
os << "[attribute] " << attr->name << '\n';
else
os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
}
}
auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
-> VariableDecorator {
return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
}
auto Operator::getArgToOperandOrAttribute(int index) const
-> OperandOrAttribute {
return attrOrOperandMapping[index];
}

View File

@ -0,0 +1,360 @@
//===- Operator.h - Operator class ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_OPERATOR_H_
#define MLIR_TABLEGEN_OPERATOR_H_
#include "mlir/Support/LLVM.h"
#include "Argument.h"
#include "Attribute.h"
#include "Builder.h"
#include "Dialect.h"
#include "Region.h"
#include "Successor.h"
#include "Trait.h"
#include "Type.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"
namespace llvm {
class DefInit;
class Record;
class StringInit;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// Wrapper class that contains a MLIR op's information (e.g., operands,
// attributes) defined in TableGen and provides helper methods for
// accessing them.
class Operator {
public:
explicit Operator(const llvm::Record &def);
explicit Operator(const llvm::Record *def) : Operator(*def) {}
// Returns this op's dialect name.
StringRef getDialectName() const;
// Returns the operation name. The name will follow the "<dialect>.<op-name>"
// format if its dialect name is not empty.
std::string getOperationName() const;
// Returns this op's C++ class name.
StringRef getCppClassName() const;
// Returns this op's C++ class name prefixed with namespaces.
std::string getQualCppClassName() const;
// Returns this op's C++ namespace.
StringRef getCppNamespace() const;
// Returns the name of op's adaptor C++ class.
std::string getAdaptorName() const;
/// A class used to represent the decorators of an operator variable, i.e.
/// argument or result.
struct VariableDecorator {
public:
explicit VariableDecorator(const llvm::Record *def) : def(def) {}
const llvm::Record &getDef() const { return *def; }
protected:
// The TableGen definition of this decorator.
const llvm::Record *def;
};
// A utility iterator over a list of variable decorators.
struct VariableDecoratorIterator
: public llvm::mapped_iterator<llvm::Init *const *,
VariableDecorator (*)(llvm::Init *)> {
using reference = VariableDecorator;
/// Initializes the iterator to the specified iterator.
VariableDecoratorIterator(llvm::Init *const *it)
: llvm::mapped_iterator<llvm::Init *const *,
VariableDecorator (*)(llvm::Init *)>(it,
&unwrap) {}
static VariableDecorator unwrap(llvm::Init *init);
};
using var_decorator_iterator = VariableDecoratorIterator;
using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
using value_iterator = NamedTypeConstraint *;
using value_range = llvm::iterator_range<value_iterator>;
// Returns true if this op has variable length operands or results.
bool isVariadic() const;
// Returns true if default builders should not be generated.
bool skipDefaultBuilders() const;
// Op result iterators.
value_iterator result_begin();
value_iterator result_end();
value_range getResults();
// Returns the number of results this op produces.
int getNumResults() const;
// Returns the op result at the given `index`.
NamedTypeConstraint &getResult(int index) { return results[index]; }
const NamedTypeConstraint &getResult(int index) const {
return results[index];
}
// Returns the `index`-th result's type constraint.
TypeConstraint getResultTypeConstraint(int index) const;
// Returns the `index`-th result's name.
StringRef getResultName(int index) const;
// Returns the `index`-th result's decorators.
var_decorator_range getResultDecorators(int index) const;
// Returns the number of variable length results in this operation.
unsigned getNumVariableLengthResults() const;
// Op attribute iterators.
using attribute_iterator = const NamedAttribute *;
attribute_iterator attribute_begin() const;
attribute_iterator attribute_end() const;
llvm::iterator_range<attribute_iterator> getAttributes() const;
int getNumAttributes() const { return attributes.size(); }
int getNumNativeAttributes() const { return numNativeAttributes; }
// Op attribute accessors.
NamedAttribute &getAttribute(int index) { return attributes[index]; }
// Op operand iterators.
value_iterator operand_begin();
value_iterator operand_end();
value_range getOperands();
int getNumOperands() const { return operands.size(); }
NamedTypeConstraint &getOperand(int index) { return operands[index]; }
const NamedTypeConstraint &getOperand(int index) const {
return operands[index];
}
// Returns the number of variadic operands in this operation.
unsigned getNumVariableLengthOperands() const;
// Returns the total number of arguments.
int getNumArgs() const { return arguments.size(); }
// Returns true of the operation has a single variadic arg.
bool hasSingleVariadicArg() const;
// Returns true if the operation has a single variadic result.
bool hasSingleVariadicResult() const {
return getNumResults() == 1 && getResult(0).isVariadic();
}
// Returns true of the operation has no variadic regions.
bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; }
using arg_iterator = const Argument *;
using arg_range = llvm::iterator_range<arg_iterator>;
// Op argument (attribute or operand) iterators.
arg_iterator arg_begin() const;
arg_iterator arg_end() const;
arg_range getArgs() const;
// Op argument (attribute or operand) accessors.
Argument getArg(int index) const;
StringRef getArgName(int index) const;
var_decorator_range getArgDecorators(int index) const;
// Returns the trait wrapper for the given MLIR C++ `trait`.
const Trait *getTrait(llvm::StringRef trait) const;
// Regions.
using const_region_iterator = const NamedRegion *;
const_region_iterator region_begin() const;
const_region_iterator region_end() const;
llvm::iterator_range<const_region_iterator> getRegions() const;
// Returns the number of regions.
unsigned getNumRegions() const;
// Returns the `index`-th region.
const NamedRegion &getRegion(unsigned index) const;
// Returns the number of variadic regions in this operation.
unsigned getNumVariadicRegions() const;
// Successors.
using const_successor_iterator = const NamedSuccessor *;
const_successor_iterator successor_begin() const;
const_successor_iterator successor_end() const;
llvm::iterator_range<const_successor_iterator> getSuccessors() const;
// Returns the number of successors.
unsigned getNumSuccessors() const;
// Returns the `index`-th successor.
const NamedSuccessor &getSuccessor(unsigned index) const;
// Returns the number of variadic successors in this operation.
unsigned getNumVariadicSuccessors() const;
// Trait.
using const_trait_iterator = const Trait *;
const_trait_iterator trait_begin() const;
const_trait_iterator trait_end() const;
llvm::iterator_range<const_trait_iterator> getTraits() const;
ArrayRef<llvm::SMLoc> getLoc() const;
// Query functions for the documentation of the operator.
bool hasDescription() const;
StringRef getDescription() const;
bool hasSummary() const;
StringRef getSummary() const;
// Query functions for the assembly format of the operator.
bool hasAssemblyFormat() const;
StringRef getAssemblyFormat() const;
// Returns this op's extra class declaration code.
StringRef getExtraClassDeclaration() const;
// Returns the Tablegen definition this operator was constructed from.
// TODO: do not expose the TableGen record, this is a temporary solution to
// OpEmitter requiring a Record because Operator does not provide enough
// methods.
const llvm::Record &getDef() const;
// Returns the dialect of the op.
const Dialect &getDialect() const { return dialect; }
// Prints the contents in this operator to the given `os`. This is used for
// debugging purposes.
void print(llvm::raw_ostream &os) const;
// Return whether all the result types are known.
bool allResultTypesKnown() const { return allResultsHaveKnownTypes; };
// Pair representing either a index to an argument or a type constraint. Only
// one of these entries should have the non-default value.
struct ArgOrType {
explicit ArgOrType(int index) : index(index), constraint(None) {}
explicit ArgOrType(TypeConstraint constraint)
: index(None), constraint(constraint) {}
bool isArg() const {
assert(constraint.hasValue() ^ index.hasValue());
return index.hasValue();
}
bool isType() const {
assert(constraint.hasValue() ^ index.hasValue());
return constraint.hasValue();
}
int getArg() const { return *index; }
TypeConstraint getType() const { return *constraint; }
private:
Optional<int> index;
Optional<TypeConstraint> constraint;
};
// Return all arguments or type constraints with same type as result[index].
// Requires: all result types are known.
ArrayRef<ArgOrType> getSameTypeAsResult(int index) const;
// Pair consisting kind of argument and index into operands or attributes.
struct OperandOrAttribute {
enum class Kind { Operand, Attribute };
OperandOrAttribute(Kind kind, int index) {
packed = (index << 1) & (kind == Kind::Attribute);
}
int operandOrAttributeIndex() const { return (packed >> 1); }
Kind kind() { return (packed & 0x1) ? Kind::Attribute : Kind::Operand; }
private:
int packed;
};
// Returns the OperandOrAttribute corresponding to the index.
OperandOrAttribute getArgToOperandOrAttribute(int index) const;
// Returns the builders of this operation.
ArrayRef<Builder> getBuilders() const { return builders; }
private:
// Populates the vectors containing operands, attributes, results and traits.
void populateOpStructure();
// Populates type inference info (mostly equality) with input a mapping from
// names to indices for arguments and results.
void populateTypeInferenceInfo(
const llvm::StringMap<int> &argumentsAndResultsIndex);
// The dialect of this op.
Dialect dialect;
// The unqualified C++ class name of the op.
StringRef cppClassName;
// The C++ namespace for this op.
StringRef cppNamespace;
// The operands of the op.
SmallVector<NamedTypeConstraint, 4> operands;
// The attributes of the op. Contains native attributes (corresponding to the
// actual stored attributed of the operation) followed by derived attributes
// (corresponding to dynamic properties of the operation that are computed
// upon request).
SmallVector<NamedAttribute, 4> attributes;
// The arguments of the op (operands and native attributes).
SmallVector<Argument, 4> arguments;
// The results of the op.
SmallVector<NamedTypeConstraint, 4> results;
// The successors of this op.
SmallVector<NamedSuccessor, 0> successors;
// The traits of the op.
SmallVector<Trait, 4> traits;
// The regions of this op.
SmallVector<NamedRegion, 1> regions;
// The argument with the same type as the result.
SmallVector<SmallVector<ArgOrType, 2>, 4> resultTypeMapping;
// Map from argument to attribute or operand number.
SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
// The builders of this operator.
SmallVector<Builder> builders;
// The number of native attributes stored in the leading positions of
// `attributes`.
int numNativeAttributes;
// The TableGen definition of this op.
const llvm::Record &def;
// Whether the type of all results are known.
bool allResultsHaveKnownTypes;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_OPERATOR_H_

View File

@ -0,0 +1,99 @@
//===- Pass.cpp - Pass related classes ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Pass.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// PassOption
//===----------------------------------------------------------------------===//
StringRef PassOption::getCppVariableName() const {
return def->getValueAsString("cppName");
}
StringRef PassOption::getArgument() const {
return def->getValueAsString("argument");
}
StringRef PassOption::getType() const { return def->getValueAsString("type"); }
Optional<StringRef> PassOption::getDefaultValue() const {
StringRef defaultVal = def->getValueAsString("defaultValue");
return defaultVal.empty() ? Optional<StringRef>() : defaultVal;
}
StringRef PassOption::getDescription() const {
return def->getValueAsString("description");
}
Optional<StringRef> PassOption::getAdditionalFlags() const {
StringRef additionalFlags = def->getValueAsString("additionalOptFlags");
return additionalFlags.empty() ? Optional<StringRef>() : additionalFlags;
}
bool PassOption::isListOption() const {
return def->isSubClassOf("ListOption");
}
//===----------------------------------------------------------------------===//
// PassStatistic
//===----------------------------------------------------------------------===//
StringRef PassStatistic::getCppVariableName() const {
return def->getValueAsString("cppName");
}
StringRef PassStatistic::getName() const {
return def->getValueAsString("name");
}
StringRef PassStatistic::getDescription() const {
return def->getValueAsString("description");
}
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
Pass::Pass(const llvm::Record *def) : def(def) {
for (auto *init : def->getValueAsListOfDefs("options"))
options.push_back(PassOption(init));
for (auto *init : def->getValueAsListOfDefs("statistics"))
statistics.push_back(PassStatistic(init));
for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
dependentDialects.push_back(dialect);
}
StringRef Pass::getArgument() const {
return def->getValueAsString("argument");
}
StringRef Pass::getBaseClass() const {
return def->getValueAsString("baseClass");
}
StringRef Pass::getSummary() const { return def->getValueAsString("summary"); }
StringRef Pass::getDescription() const {
return def->getValueAsString("description");
}
StringRef Pass::getConstructor() const {
return def->getValueAsString("constructor");
}
ArrayRef<StringRef> Pass::getDependentDialects() const {
return dependentDialects;
}
ArrayRef<PassOption> Pass::getOptions() const { return options; }
ArrayRef<PassStatistic> Pass::getStatistics() const { return statistics; }

View File

@ -0,0 +1,118 @@
//===- Pass.h - TableGen pass definitions -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_PASS_H_
#define MLIR_TABLEGEN_PASS_H_
#include "mlir/Support/LLVM.h"
#include <vector>
namespace llvm {
class Record;
} // end namespace llvm
namespace mlir {
namespace tblgen {
//===----------------------------------------------------------------------===//
// PassOption
//===----------------------------------------------------------------------===//
class PassOption {
public:
explicit PassOption(const llvm::Record *def) : def(def) {}
/// Return the name for the C++ option variable.
StringRef getCppVariableName() const;
/// Return the command line argument to use for this option.
StringRef getArgument() const;
/// Return the C++ type of the option.
StringRef getType() const;
/// Return the default value of the option.
Optional<StringRef> getDefaultValue() const;
/// Return the description for this option.
StringRef getDescription() const;
/// Return the additional flags passed to the option constructor.
Optional<StringRef> getAdditionalFlags() const;
/// Flag indicating if this is a list option.
bool isListOption() const;
private:
const llvm::Record *def;
};
//===----------------------------------------------------------------------===//
// PassStatistic
//===----------------------------------------------------------------------===//
class PassStatistic {
public:
explicit PassStatistic(const llvm::Record *def) : def(def) {}
/// Return the name for the C++ statistic variable.
StringRef getCppVariableName() const;
/// Return the name of the statistic.
StringRef getName() const;
/// Return the description for this statistic.
StringRef getDescription() const;
private:
const llvm::Record *def;
};
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
/// Wrapper class providing helper methods for Passes defined in TableGen.
class Pass {
public:
explicit Pass(const llvm::Record *def);
/// Return the command line argument of the pass.
StringRef getArgument() const;
/// Return the name for the C++ base class.
StringRef getBaseClass() const;
/// Return the short 1-line summary of the pass.
StringRef getSummary() const;
/// Return the description of the pass.
StringRef getDescription() const;
/// Return the C++ constructor call to create an instance of this pass.
StringRef getConstructor() const;
/// Return the dialects this pass needs to be registered.
ArrayRef<StringRef> getDependentDialects() const;
/// Return the options provided by this pass.
ArrayRef<PassOption> getOptions() const;
/// Return the statistics provided by this pass.
ArrayRef<PassStatistic> getStatistics() const;
const llvm::Record *getDef() const { return def; }
private:
const llvm::Record *def;
std::vector<StringRef> dependentDialects;
std::vector<PassOption> options;
std::vector<PassStatistic> statistics;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_PASS_H_

View File

@ -0,0 +1,739 @@
//===- Pattern.cpp - Pattern wrapper class --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Pattern wrapper class to simplify using TableGen Record defining a MLIR
// Pattern.
//
//===----------------------------------------------------------------------===//
#include "Pattern.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#define DEBUG_TYPE "mlir-tblgen-pattern"
using namespace mlir;
using namespace tblgen;
using llvm::formatv;
//===----------------------------------------------------------------------===//
// DagLeaf
//===----------------------------------------------------------------------===//
bool DagLeaf::isUnspecified() const {
return dyn_cast_or_null<llvm::UnsetInit>(def);
}
bool DagLeaf::isOperandMatcher() const {
// Operand matchers specify a type constraint.
return isSubClassOf("TypeConstraint");
}
bool DagLeaf::isAttrMatcher() const {
// Attribute matchers specify an attribute constraint.
return isSubClassOf("AttrConstraint");
}
bool DagLeaf::isNativeCodeCall() const {
return isSubClassOf("NativeCodeCall");
}
bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
bool DagLeaf::isEnumAttrCase() const {
return isSubClassOf("EnumAttrCaseInfo");
}
bool DagLeaf::isStringAttr() const {
return isa<llvm::StringInit>(def);
}
Constraint DagLeaf::getAsConstraint() const {
assert((isOperandMatcher() || isAttrMatcher()) &&
"the DAG leaf must be operand or attribute");
return Constraint(cast<llvm::DefInit>(def)->getDef());
}
ConstantAttr DagLeaf::getAsConstantAttr() const {
assert(isConstantAttr() && "the DAG leaf must be constant attribute");
return ConstantAttr(cast<llvm::DefInit>(def));
}
EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
return EnumAttrCase(cast<llvm::DefInit>(def));
}
std::string DagLeaf::getConditionTemplate() const {
return getAsConstraint().getConditionTemplate();
}
llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
}
std::string DagLeaf::getStringAttr() const {
assert(isStringAttr() && "the DAG leaf must be string attribute");
return def->getAsUnquotedString();
}
bool DagLeaf::isSubClassOf(StringRef superclass) const {
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
return defInit->getDef()->isSubClassOf(superclass);
return false;
}
void DagLeaf::print(raw_ostream &os) const {
if (def)
def->print(os);
}
//===----------------------------------------------------------------------===//
// DagNode
//===----------------------------------------------------------------------===//
bool DagNode::isNativeCodeCall() const {
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
return defInit->getDef()->isSubClassOf("NativeCodeCall");
return false;
}
bool DagNode::isOperation() const {
return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
}
llvm::StringRef DagNode::getNativeCodeTemplate() const {
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
return cast<llvm::DefInit>(node->getOperator())
->getDef()
->getValueAsString("expression");
}
llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
auto it = mapper->find(opDef);
if (it != mapper->end())
return *it->second;
return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
.first->second;
}
int DagNode::getNumOps() const {
int count = isReplaceWithValue() ? 0 : 1;
for (int i = 0, e = getNumArgs(); i != e; ++i) {
if (auto child = getArgAsNestedDag(i))
count += child.getNumOps();
}
return count;
}
int DagNode::getNumArgs() const { return node->getNumArgs(); }
bool DagNode::isNestedDagArg(unsigned index) const {
return isa<llvm::DagInit>(node->getArg(index));
}
DagNode DagNode::getArgAsNestedDag(unsigned index) const {
return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
}
DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
assert(!isNestedDagArg(index));
return DagLeaf(node->getArg(index));
}
StringRef DagNode::getArgName(unsigned index) const {
return node->getArgNameStr(index);
}
bool DagNode::isReplaceWithValue() const {
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->getName() == "replaceWithValue";
}
bool DagNode::isLocationDirective() const {
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->getName() == "location";
}
void DagNode::print(raw_ostream &os) const {
if (node)
node->print(os);
}
//===----------------------------------------------------------------------===//
// SymbolInfoMap
//===----------------------------------------------------------------------===//
StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
StringRef name, indexStr;
int idx = -1;
std::tie(name, indexStr) = symbol.rsplit("__");
if (indexStr.consumeInteger(10, idx)) {
// The second part is not an index; we return the whole symbol as-is.
return symbol;
}
if (index) {
*index = idx;
}
return name;
}
SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
Optional<int> index)
: op(op), kind(kind), argIndex(index) {}
int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
switch (kind) {
case Kind::Attr:
case Kind::Operand:
case Kind::Value:
return 1;
case Kind::Result:
return op->getNumResults();
}
llvm_unreachable("unknown kind");
}
std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
}
std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
switch (kind) {
case Kind::Attr: {
if (op) {
auto type =
op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
return std::string(formatv("{0} {1};\n", type, name));
}
// TODO(suderman): Use a more exact type when available.
return std::string(formatv("Attribute {0};\n", name));
}
case Kind::Operand: {
// Use operand range for captured operands (to support potential variadic
// operands).
return std::string(
formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
getVarName(name)));
}
case Kind::Value: {
return std::string(formatv("::mlir::Value {0};\n", name));
}
case Kind::Result: {
// Use the op itself for captured results.
return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
}
}
llvm_unreachable("unknown kind");
}
std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
StringRef name, int index, const char *fmt, const char *separator) const {
LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
switch (kind) {
case Kind::Attr: {
assert(index < 0);
auto repl = formatv(fmt, name);
LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
return std::string(repl);
}
case Kind::Operand: {
assert(index < 0);
auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
// If this operand is variadic, then return a range. Otherwise, return the
// value itself.
if (operand->isVariableLength()) {
auto repl = formatv(fmt, name);
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
return std::string(repl);
}
auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
return std::string(repl);
}
case Kind::Result: {
// If `index` is greater than zero, then we are referencing a specific
// result of a multi-result op. The result can still be variadic.
if (index >= 0) {
std::string v =
std::string(formatv("{0}.getODSResults({1})", name, index));
if (!op->getResult(index).isVariadic())
v = std::string(formatv("(*{0}.begin())", v));
auto repl = formatv(fmt, v);
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
return std::string(repl);
}
// If this op has no result at all but still we bind a symbol to it, it
// means we want to capture the op itself.
if (op->getNumResults() == 0) {
LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
return std::string(name);
}
// We are referencing all results of the multi-result op. A specific result
// can either be a value or a range. Then join them with `separator`.
SmallVector<std::string, 4> values;
values.reserve(op->getNumResults());
for (int i = 0, e = op->getNumResults(); i < e; ++i) {
std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
if (!op->getResult(i).isVariadic()) {
v = std::string(formatv("(*{0}.begin())", v));
}
values.push_back(std::string(formatv(fmt, v)));
}
auto repl = llvm::join(values, separator);
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
return repl;
}
case Kind::Value: {
assert(index < 0);
assert(op == nullptr);
auto repl = formatv(fmt, name);
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
return std::string(repl);
}
}
llvm_unreachable("unknown kind");
}
std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
StringRef name, int index, const char *fmt, const char *separator) const {
LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
switch (kind) {
case Kind::Attr:
case Kind::Operand: {
assert(index < 0 && "only allowed for symbol bound to result");
auto repl = formatv(fmt, name);
LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
return std::string(repl);
}
case Kind::Result: {
if (index >= 0) {
auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
return std::string(repl);
}
// We are referencing all results of the multi-result op. Each result should
// have a value range, and then join them with `separator`.
SmallVector<std::string, 4> values;
values.reserve(op->getNumResults());
for (int i = 0, e = op->getNumResults(); i < e; ++i) {
values.push_back(std::string(
formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
}
auto repl = llvm::join(values, separator);
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
return repl;
}
case Kind::Value: {
assert(index < 0 && "only allowed for symbol bound to result");
assert(op == nullptr);
auto repl = formatv(fmt, formatv("{{{0}}", name));
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
return std::string(repl);
}
}
llvm_unreachable("unknown kind");
}
bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
int argIndex) {
StringRef name = getValuePackName(symbol);
if (name != symbol) {
auto error = formatv(
"symbol '{0}' with trailing index cannot bind to op argument", symbol);
PrintFatalError(loc, error);
}
auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
? SymbolInfo::getAttr(&op, argIndex)
: SymbolInfo::getOperand(&op, argIndex);
std::string key = symbol.str();
if (symbolInfoMap.count(key)) {
// Only non unique name for the operand is supported.
if (symInfo.kind != SymbolInfo::Kind::Operand) {
return false;
}
// Cannot add new operand if there is already non operand with the same
// name.
if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
return false;
}
}
symbolInfoMap.emplace(key, symInfo);
return true;
}
bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
std::string name = getValuePackName(symbol).str();
auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
return symbolInfoMap.count(inserted->first) == 1;
}
bool SymbolInfoMap::bindValue(StringRef symbol) {
auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
return symbolInfoMap.count(inserted->first) == 1;
}
bool SymbolInfoMap::bindAttr(StringRef symbol) {
auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
return symbolInfoMap.count(inserted->first) == 1;
}
bool SymbolInfoMap::contains(StringRef symbol) const {
return find(symbol) != symbolInfoMap.end();
}
SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
std::string name = getValuePackName(key).str();
return symbolInfoMap.find(name);
}
SymbolInfoMap::const_iterator
SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
int argIndex) const {
std::string name = getValuePackName(key).str();
auto range = symbolInfoMap.equal_range(name);
for (auto it = range.first; it != range.second; ++it) {
if (it->second.op == &op && it->second.argIndex == argIndex) {
return it;
}
}
return symbolInfoMap.end();
}
std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
std::string name = getValuePackName(key).str();
return symbolInfoMap.equal_range(name);
}
int SymbolInfoMap::count(StringRef key) const {
std::string name = getValuePackName(key).str();
return symbolInfoMap.count(name);
}
int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
StringRef name = getValuePackName(symbol);
if (name != symbol) {
// If there is a trailing index inside symbol, it references just one
// static value.
return 1;
}
// Otherwise, find how many it represents by querying the symbol's info.
return find(name)->second.getStaticValueCount();
}
std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
const char *fmt,
const char *separator) const {
int index = -1;
StringRef name = getValuePackName(symbol, &index);
auto it = symbolInfoMap.find(name.str());
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}
return it->second.getValueAndRangeUse(name, index, fmt, separator);
}
std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
const char *separator) const {
int index = -1;
StringRef name = getValuePackName(symbol, &index);
auto it = symbolInfoMap.find(name.str());
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}
return it->second.getAllRangeUse(name, index, fmt, separator);
}
void SymbolInfoMap::assignUniqueAlternativeNames() {
llvm::StringSet<> usedNames;
for (auto symbolInfoIt = symbolInfoMap.begin();
symbolInfoIt != symbolInfoMap.end();) {
auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
auto startRange = range.first;
auto endRange = range.second;
auto operandName = symbolInfoIt->first;
int startSearchIndex = 0;
for (++startRange; startRange != endRange; ++startRange) {
// Current operand name is not unique, find a unique one
// and set the alternative name.
for (int i = startSearchIndex;; ++i) {
std::string alternativeName = operandName + std::to_string(i);
if (!usedNames.contains(alternativeName) &&
symbolInfoMap.count(alternativeName) == 0) {
usedNames.insert(alternativeName);
startRange->second.alternativeName = alternativeName;
startSearchIndex = i + 1;
break;
}
}
}
symbolInfoIt = endRange;
}
}
//===----------------------------------------------------------------------===//
// Pattern
//==----------------------------------------------------------------------===//
Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
: def(*def), recordOpMap(mapper) {}
DagNode Pattern::getSourcePattern() const {
return DagNode(def.getValueAsDag("sourcePattern"));
}
int Pattern::getNumResultPatterns() const {
auto *results = def.getValueAsListInit("resultPatterns");
return results->size();
}
DagNode Pattern::getResultPattern(unsigned index) const {
auto *results = def.getValueAsListInit("resultPatterns");
return DagNode(cast<llvm::DagInit>(results->getElement(index)));
}
void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
infoMap.assignUniqueAlternativeNames();
LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
}
void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
auto pattern = getResultPattern(i);
collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
}
LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
}
const Operator &Pattern::getSourceRootOp() {
return getSourcePattern().getDialectOp(recordOpMap);
}
Operator &Pattern::getDialectOp(DagNode node) {
return node.getDialectOp(recordOpMap);
}
std::vector<AppliedConstraint> Pattern::getConstraints() const {
auto *listInit = def.getValueAsListInit("constraints");
std::vector<AppliedConstraint> ret;
ret.reserve(listInit->size());
for (auto it : *listInit) {
auto *dagInit = dyn_cast<llvm::DagInit>(it);
if (!dagInit)
PrintFatalError(&def, "all elements in Pattern multi-entity "
"constraints should be DAG nodes");
std::vector<std::string> entities;
entities.reserve(dagInit->arg_size());
for (auto *argName : dagInit->getArgNames()) {
if (!argName) {
PrintFatalError(
&def,
"operands to additional constraints can only be symbol references");
}
entities.push_back(std::string(argName->getValue()));
}
ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
dagInit->getNameStr(), std::move(entities));
}
return ret;
}
int Pattern::getBenefit() const {
// The initial benefit value is a heuristic with number of ops in the source
// pattern.
int initBenefit = getSourcePattern().getNumOps();
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
PrintFatalError(&def,
"The 'addBenefit' takes and only takes one integer value");
}
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
}
std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
std::vector<std::pair<StringRef, unsigned>> result;
result.reserve(def.getLoc().size());
for (auto loc : def.getLoc()) {
unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
assert(buf && "invalid source location");
result.emplace_back(
llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
llvm::SrcMgr.getLineAndColumn(loc, buf).first);
}
return result;
}
void Pattern::verifyBind(bool result, StringRef symbolName) {
if (!result) {
auto err = formatv("symbol '{0}' bound more than once", symbolName);
PrintFatalError(&def, err);
}
}
void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
bool isSrcPattern) {
auto treeName = tree.getSymbol();
auto numTreeArgs = tree.getNumArgs();
if (tree.isNativeCodeCall()) {
if (!treeName.empty()) {
if (!isSrcPattern) {
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
<< treeName << '\n');
verifyBind(infoMap.bindValue(treeName), treeName);
} else {
PrintFatalError(&def,
formatv("binding symbol '{0}' to NativecodeCall in "
"MatchPattern is not supported",
treeName));
}
}
for (int i = 0; i != numTreeArgs; ++i) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
// This DAG node argument is a DAG node itself. Go inside recursively.
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
continue;
}
if (!isSrcPattern)
continue;
// We can only bind symbols to arguments in source pattern. Those
// symbols are referenced in result patterns.
auto treeArgName = tree.getArgName(i);
// `$_` is a special symbol meaning ignore the current argument.
if (!treeArgName.empty() && treeArgName != "_") {
DagLeaf leaf = tree.getArgAsLeaf(i);
// In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
if (leaf.isUnspecified()) {
// This is case of $c, a Value without any constraints.
verifyBind(infoMap.bindValue(treeArgName), treeArgName);
} else {
auto constraint = leaf.getAsConstraint();
bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
leaf.isConstantAttr() ||
constraint.getKind() == Constraint::Kind::CK_Attr;
if (isAttr) {
// This is case of $a, a binding to a certain attribute.
verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
continue;
}
// This is case of $b, a binding to a certain type.
verifyBind(infoMap.bindValue(treeArgName), treeArgName);
}
}
}
return;
}
if (tree.isOperation()) {
auto &op = getDialectOp(tree);
auto numOpArgs = op.getNumArgs();
// The pattern might have the last argument specifying the location.
bool hasLocDirective = false;
if (numTreeArgs != 0) {
if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
hasLocDirective = lastArg.isLocationDirective();
}
if (numOpArgs != numTreeArgs - hasLocDirective) {
auto err = formatv("op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
op.getOperationName(), numTreeArgs, numOpArgs);
PrintFatalError(&def, err);
}
// The name attached to the DAG node's operator is for representing the
// results generated from this op. It should be remembered as bound results.
if (!treeName.empty()) {
LLVM_DEBUG(llvm::dbgs()
<< "found symbol bound to op result: " << treeName << '\n');
verifyBind(infoMap.bindOpResult(treeName, op), treeName);
}
for (int i = 0; i != numTreeArgs; ++i) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
// This DAG node argument is a DAG node itself. Go inside recursively.
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
continue;
}
if (isSrcPattern) {
// We can only bind symbols to op arguments in source pattern. Those
// symbols are referenced in result patterns.
auto treeArgName = tree.getArgName(i);
// `$_` is a special symbol meaning ignore the current argument.
if (!treeArgName.empty() && treeArgName != "_") {
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
<< treeArgName << '\n');
verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName);
}
}
}
return;
}
if (!treeName.empty()) {
PrintFatalError(
&def, formatv("binding symbol '{0}' to non-operation/native code call "
"unsupported right now",
treeName));
}
return;
}

View File

@ -0,0 +1,451 @@
//===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Pattern wrapper class to simplify using TableGen Record defining a MLIR
// Pattern.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_PATTERN_H_
#define MLIR_TABLEGEN_PATTERN_H_
#include "mlir/Support/LLVM.h"
#include "Argument.h"
#include "Operator.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include <unordered_map>
namespace llvm {
class DagInit;
class Init;
class Record;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// Mapping from TableGen Record to Operator wrapper object.
//
// We allocate each wrapper object in heap to make sure the pointer to it is
// valid throughout the lifetime of this map. This is important because this map
// is shared among multiple patterns to avoid creating the wrapper object for
// the same op again and again. But this map will continuously grow.
using RecordOperatorMap =
DenseMap<const llvm::Record *, std::unique_ptr<Operator>>;
class Pattern;
// Wrapper class providing helper methods for accessing TableGen DAG leaves
// used inside Patterns. This class is lightweight and designed to be used like
// values.
//
// A TableGen DAG construct is of the syntax
// `(operator, arg0, arg1, ...)`.
//
// This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
// for handy helper methods. It only works on `arg*`s that are not nested DAG
// constructs.
class DagLeaf {
public:
explicit DagLeaf(const llvm::Init *def) : def(def) {}
// Returns true if this DAG leaf is not specified in the pattern. That is, it
// places no further constraints/transforms and just carries over the original
// value.
bool isUnspecified() const;
// Returns true if this DAG leaf is matching an operand. That is, it specifies
// a type constraint.
bool isOperandMatcher() const;
// Returns true if this DAG leaf is matching an attribute. That is, it
// specifies an attribute constraint.
bool isAttrMatcher() const;
// Returns true if this DAG leaf is wrapping native code call.
bool isNativeCodeCall() const;
// Returns true if this DAG leaf is specifying a constant attribute.
bool isConstantAttr() const;
// Returns true if this DAG leaf is specifying an enum attribute case.
bool isEnumAttrCase() const;
// Returns true if this DAG leaf is specifying a string attribute.
bool isStringAttr() const;
// Returns this DAG leaf as a constraint. Asserts if fails.
Constraint getAsConstraint() const;
// Returns this DAG leaf as an constant attribute. Asserts if fails.
ConstantAttr getAsConstantAttr() const;
// Returns this DAG leaf as an enum attribute case.
// Precondition: isEnumAttrCase()
EnumAttrCase getAsEnumAttrCase() const;
// Returns the matching condition template inside this DAG leaf. Assumes the
// leaf is an operand/attribute matcher and asserts otherwise.
std::string getConditionTemplate() const;
// Returns the native code call template inside this DAG leaf.
// Precondition: isNativeCodeCall()
StringRef getNativeCodeTemplate() const;
// Returns the string associated with the leaf.
// Precondition: isStringAttr()
std::string getStringAttr() const;
void print(raw_ostream &os) const;
private:
// Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
// also a subclass of the given `superclass`.
bool isSubClassOf(StringRef superclass) const;
const llvm::Init *def;
};
// Wrapper class providing helper methods for accessing TableGen DAG constructs
// used inside Patterns. This class is lightweight and designed to be used like
// values.
//
// A TableGen DAG construct is of the syntax
// `(operator, arg0, arg1, ...)`.
//
// When used inside Patterns, `operator` corresponds to some dialect op, or
// a known list of verbs that defines special transformation actions. This
// `arg*` can be a nested DAG construct. This class provides getters to
// retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
// methods.
//
// A null DagNode contains a nullptr and converts to false implicitly.
class DagNode {
public:
explicit DagNode(const llvm::DagInit *node) : node(node) {}
// Implicit bool converter that returns true if this DagNode is not a null
// DagNode.
operator bool() const { return node != nullptr; }
// Returns the symbol bound to this DAG node.
StringRef getSymbol() const;
// Returns the operator wrapper object corresponding to the dialect op matched
// by this DAG. The operator wrapper will be queried from the given `mapper`
// and created in it if not existing.
Operator &getDialectOp(RecordOperatorMap *mapper) const;
// Returns the number of operations recursively involved in the DAG tree
// rooted from this node.
int getNumOps() const;
// Returns the number of immediate arguments to this DAG node.
int getNumArgs() const;
// Returns true if the `index`-th argument is a nested DAG construct.
bool isNestedDagArg(unsigned index) const;
// Gets the `index`-th argument as a nested DAG construct if possible. Returns
// null DagNode otherwise.
DagNode getArgAsNestedDag(unsigned index) const;
// Gets the `index`-th argument as a DAG leaf.
DagLeaf getArgAsLeaf(unsigned index) const;
// Returns the specified name of the `index`-th argument.
StringRef getArgName(unsigned index) const;
// Returns true if this DAG construct means to replace with an existing SSA
// value.
bool isReplaceWithValue() const;
// Returns whether this DAG represents the location of an op creation.
bool isLocationDirective() const;
// Returns true if this DAG node is wrapping native code call.
bool isNativeCodeCall() const;
// Returns true if this DAG node is an operation.
bool isOperation() const;
// Returns the native code call template inside this DAG node.
// Precondition: isNativeCodeCall()
StringRef getNativeCodeTemplate() const;
void print(raw_ostream &os) const;
private:
const llvm::DagInit *node; // nullptr means null DagNode
};
// A class for maintaining information for symbols bound in patterns and
// provides methods for resolving them according to specific use cases.
//
// Symbols can be bound to
//
// * Op arguments and op results in the source pattern and
// * Op results in result patterns.
//
// Symbols can be referenced in result patterns and additional constraints to
// the pattern.
//
// For example, in
//
// ```
// def : Pattern<
// (SrcOp:$results1 $arg0, %arg1),
// [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
// ```
//
// `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
// `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build
// `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
//
// If a symbol binds to a multi-result op and it does not have the `__N`
// suffix, the symbol is expanded to represent all results generated by the
// multi-result op. If the symbol has a `__N` suffix, then it will expand to
// only the N-th *static* result as declared in ODS, and that can still
// corresponds to multiple *dynamic* values if the N-th *static* result is
// variadic.
//
// This class keeps track of such symbols and resolves them into their bound
// values in a suitable way.
class SymbolInfoMap {
public:
explicit SymbolInfoMap(ArrayRef<llvm::SMLoc> loc) : loc(loc) {}
// Class for information regarding a symbol.
class SymbolInfo {
public:
// Returns a string for defining a variable named as `name` to store the
// value bound by this symbol.
std::string getVarDecl(StringRef name) const;
// Returns a variable name for the symbol named as `name`.
std::string getVarName(StringRef name) const;
private:
// Allow SymbolInfoMap to access private methods.
friend class SymbolInfoMap;
// What kind of entity this symbol represents:
// * Attr: op attribute
// * Operand: op operand
// * Result: op result
// * Value: a value not attached to an op (e.g., from NativeCodeCall)
enum class Kind : uint8_t { Attr, Operand, Result, Value };
// Creates a SymbolInfo instance. `index` is only used for `Attr` and
// `Operand` so should be negative for `Result` and `Value` kind.
SymbolInfo(const Operator *op, Kind kind, Optional<int> index);
// Static methods for creating SymbolInfo.
static SymbolInfo getAttr(const Operator *op, int index) {
return SymbolInfo(op, Kind::Attr, index);
}
static SymbolInfo getAttr() {
return SymbolInfo(nullptr, Kind::Attr, llvm::None);
}
static SymbolInfo getOperand(const Operator *op, int index) {
return SymbolInfo(op, Kind::Operand, index);
}
static SymbolInfo getResult(const Operator *op) {
return SymbolInfo(op, Kind::Result, llvm::None);
}
static SymbolInfo getValue() {
return SymbolInfo(nullptr, Kind::Value, llvm::None);
}
// Returns the number of static values this symbol corresponds to.
// A static value is an operand/result declared in ODS. Normally a symbol
// only represents one static value, but symbols bound to op results can
// represent more than one if the op is a multi-result op.
int getStaticValueCount() const;
// Returns a string containing the C++ expression for referencing this
// symbol as a value (if this symbol represents one static value) or a value
// range (if this symbol represents multiple static values). `name` is the
// name of the C++ variable that this symbol bounds to. `index` should only
// be used for indexing results. `fmt` is used to format each value.
// `separator` is used to separate values if this is a value range.
std::string getValueAndRangeUse(StringRef name, int index, const char *fmt,
const char *separator) const;
// Returns a string containing the C++ expression for referencing this
// symbol as a value range regardless of how many static values this symbol
// represents. `name` is the name of the C++ variable that this symbol
// bounds to. `index` should only be used for indexing results. `fmt` is
// used to format each value. `separator` is used to separate values in the
// range.
std::string getAllRangeUse(StringRef name, int index, const char *fmt,
const char *separator) const;
const Operator *op; // The op where the bound entity belongs
Kind kind; // The kind of the bound entity
// The argument index (for `Attr` and `Operand` only)
Optional<int> argIndex;
// Alternative name for the symbol. It is used in case the name
// is not unique. Applicable for `Operand` only.
Optional<std::string> alternativeName;
};
using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
// Iterators for accessing all symbols.
using iterator = BaseT::iterator;
iterator begin() { return symbolInfoMap.begin(); }
iterator end() { return symbolInfoMap.end(); }
// Const iterators for accessing all symbols.
using const_iterator = BaseT::const_iterator;
const_iterator begin() const { return symbolInfoMap.begin(); }
const_iterator end() const { return symbolInfoMap.end(); }
// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
// Returns false if `symbol` is already bound and symbols are not operands.
bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
// Binds the given `symbol` to the results the given `op`. Returns false if
// `symbol` is already bound.
bool bindOpResult(StringRef symbol, const Operator &op);
// Registers the given `symbol` as bound to a value. Returns false if `symbol`
// is already bound.
bool bindValue(StringRef symbol);
// Registers the given `symbol` as bound to an attr. Returns false if `symbol`
// is already bound.
bool bindAttr(StringRef symbol);
// Returns true if the given `symbol` is bound.
bool contains(StringRef symbol) const;
// Returns an iterator to the information of the given symbol named as `key`.
const_iterator find(StringRef key) const;
// Returns an iterator to the information of the given symbol named as `key`,
// with index `argIndex` for operator `op`.
const_iterator findBoundSymbol(StringRef key, const Operator &op,
int argIndex) const;
// Returns the bounds of a range that includes all the elements which
// bind to the `key`.
std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
// Returns number of times symbol named as `key` was used.
int count(StringRef key) const;
// Returns the number of static values of the given `symbol` corresponds to.
// A static value is an operand/result declared in ODS. Normally a symbol only
// represents one static value, but symbols bound to op results can represent
// more than one if the op is a multi-result op.
int getStaticValueCount(StringRef symbol) const;
// Returns a string containing the C++ expression for referencing this
// symbol as a value (if this symbol represents one static value) or a value
// range (if this symbol represents multiple static values). `fmt` is used to
// format each value. `separator` is used to separate values if `symbol`
// represents a value range.
std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}",
const char *separator = ", ") const;
// Returns a string containing the C++ expression for referencing this
// symbol as a value range regardless of how many static values this symbol
// represents. `fmt` is used to format each value. `separator` is used to
// separate values in the range.
std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
const char *separator = ", ") const;
// Assign alternative unique names to Operands that have equal names.
void assignUniqueAlternativeNames();
// Splits the given `symbol` into a value pack name and an index. Returns the
// value pack name and writes the index to `index` on success. Returns
// `symbol` itself if it does not contain an index.
//
// We can use `name__N` to access the `N`-th value in the value pack bound to
// `name`. `name` is typically the results of an multi-result op.
static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
private:
BaseT symbolInfoMap;
// Pattern instantiation location. This is intended to be used as parameter
// to PrintFatalError() to report errors.
ArrayRef<llvm::SMLoc> loc;
};
// Wrapper class providing helper methods for accessing MLIR Pattern defined
// in TableGen. This class should closely reflect what is defined as class
// `Pattern` in TableGen. This class contains maps so it is not intended to be
// used as values.
class Pattern {
public:
explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
// Returns the source pattern to match.
DagNode getSourcePattern() const;
// Returns the number of result patterns generated by applying this rewrite
// rule.
int getNumResultPatterns() const;
// Returns the DAG tree root node of the `index`-th result pattern.
DagNode getResultPattern(unsigned index) const;
// Collects all symbols bound in the source pattern into `infoMap`.
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
// Collects all symbols bound in result patterns into `infoMap`.
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
// Returns the op that the root node of the source pattern matches.
const Operator &getSourceRootOp();
// Returns the operator wrapper object corresponding to the given `node`'s DAG
// operator.
Operator &getDialectOp(DagNode node);
// Returns the constraints.
std::vector<AppliedConstraint> getConstraints() const;
// Returns the benefit score of the pattern.
int getBenefit() const;
using IdentifierLine = std::pair<StringRef, unsigned>;
// Returns the file location of the pattern (buffer identifier + line number
// pair).
std::vector<IdentifierLine> getLocation() const;
private:
// Helper function to verify variabld binding.
void verifyBind(bool result, StringRef symbolName);
// Recursively collects all bound symbols inside the DAG tree rooted
// at `tree` and updates the given `infoMap`.
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
bool isSrcPattern);
// The TableGen definition of this pattern.
const llvm::Record &def;
// All operators.
// TODO: we need a proper context manager, like MLIRContext, for managing the
// lifetime of shared entities.
RecordOperatorMap *recordOpMap;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_PATTERN_H_

View File

@ -0,0 +1,376 @@
//===- Predicate.cpp - Predicate class ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Wrapper around predicates defined in TableGen.
//
//===----------------------------------------------------------------------===//
#include "Predicate.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace tblgen;
// Construct a Predicate from a record.
Pred::Pred(const llvm::Record *record) : def(record) {
assert(def->isSubClassOf("Pred") &&
"must be a subclass of TableGen 'Pred' class");
}
// Construct a Predicate from an initializer.
Pred::Pred(const llvm::Init *init) : def(nullptr) {
if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
def = defInit->getDef();
}
std::string Pred::getCondition() const {
// Static dispatch to subclasses.
if (def->isSubClassOf("CombinedPred"))
return static_cast<const CombinedPred *>(this)->getConditionImpl();
if (def->isSubClassOf("CPred"))
return static_cast<const CPred *>(this)->getConditionImpl();
llvm_unreachable("Pred::getCondition must be overridden in subclasses");
}
bool Pred::isCombined() const {
return def && def->isSubClassOf("CombinedPred");
}
ArrayRef<llvm::SMLoc> Pred::getLoc() const { return def->getLoc(); }
CPred::CPred(const llvm::Record *record) : Pred(record) {
assert(def->isSubClassOf("CPred") &&
"must be a subclass of Tablegen 'CPred' class");
}
CPred::CPred(const llvm::Init *init) : Pred(init) {
assert((!def || def->isSubClassOf("CPred")) &&
"must be a subclass of Tablegen 'CPred' class");
}
// Get condition of the C Predicate.
std::string CPred::getConditionImpl() const {
assert(!isNull() && "null predicate does not have a condition");
return std::string(def->getValueAsString("predExpr"));
}
CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
assert(def->isSubClassOf("CombinedPred") &&
"must be a subclass of Tablegen 'CombinedPred' class");
}
CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
assert((!def || def->isSubClassOf("CombinedPred")) &&
"must be a subclass of Tablegen 'CombinedPred' class");
}
const llvm::Record *CombinedPred::getCombinerDef() const {
assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
return def->getValueAsDef("kind");
}
const std::vector<llvm::Record *> CombinedPred::getChildren() const {
assert(def->getValue("children") &&
"CombinedPred must have a value 'children'");
return def->getValueAsListOfDefs("children");
}
namespace {
// Kinds of nodes in a logical predicate tree.
enum class PredCombinerKind {
Leaf,
And,
Or,
Not,
SubstLeaves,
Concat,
// Special kinds that are used in simplification.
False,
True
};
// A node in a logical predicate tree.
struct PredNode {
PredCombinerKind kind;
const Pred *predicate;
SmallVector<PredNode *, 4> children;
std::string expr;
// Prefix and suffix are used by ConcatPred.
std::string prefix;
std::string suffix;
};
} // end anonymous namespace
// Get a predicate tree node kind based on the kind used in the predicate
// TableGen record.
static PredCombinerKind getPredCombinerKind(const Pred &pred) {
if (!pred.isCombined())
return PredCombinerKind::Leaf;
const auto &combinedPred = static_cast<const CombinedPred &>(pred);
return StringSwitch<PredCombinerKind>(
combinedPred.getCombinerDef()->getName())
.Case("PredCombinerAnd", PredCombinerKind::And)
.Case("PredCombinerOr", PredCombinerKind::Or)
.Case("PredCombinerNot", PredCombinerKind::Not)
.Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
.Case("PredCombinerConcat", PredCombinerKind::Concat);
}
namespace {
// Substitution<pattern, replacement>.
using Subst = std::pair<StringRef, StringRef>;
} // end anonymous namespace
/// Perform the given substitutions on 'str' in-place.
static void performSubstitutions(std::string &str,
ArrayRef<Subst> substitutions) {
// Apply all parent substitutions from innermost to outermost.
for (const auto &subst : llvm::reverse(substitutions)) {
auto pos = str.find(std::string(subst.first));
while (pos != std::string::npos) {
str.replace(pos, subst.first.size(), std::string(subst.second));
// Skip the newly inserted substring, which itself may consider the
// pattern to match.
pos += subst.second.size();
// Find the next possible match position.
pos = str.find(std::string(subst.first), pos);
}
}
}
// Build the predicate tree starting from the top-level predicate, which may
// have children, and perform leaf substitutions inplace. Note that after
// substitution, nodes are still pointing to the original TableGen record.
// All nodes are created within "allocator".
static PredNode *
buildPredicateTree(const Pred &root,
llvm::SpecificBumpPtrAllocator<PredNode> &allocator,
ArrayRef<Subst> substitutions) {
auto *rootNode = allocator.Allocate();
new (rootNode) PredNode;
rootNode->kind = getPredCombinerKind(root);
rootNode->predicate = &root;
if (!root.isCombined()) {
rootNode->expr = root.getCondition();
performSubstitutions(rootNode->expr, substitutions);
return rootNode;
}
// If the current combined predicate is a leaf substitution, append it to the
// list before continuing.
auto allSubstitutions = llvm::to_vector<4>(substitutions);
if (rootNode->kind == PredCombinerKind::SubstLeaves) {
const auto &substPred = static_cast<const SubstLeavesPred &>(root);
allSubstitutions.push_back(
{substPred.getPattern(), substPred.getReplacement()});
// If the current predicate is a ConcatPred, record the prefix and suffix.
} else if (rootNode->kind == PredCombinerKind::Concat) {
const auto &concatPred = static_cast<const ConcatPred &>(root);
rootNode->prefix = std::string(concatPred.getPrefix());
performSubstitutions(rootNode->prefix, substitutions);
rootNode->suffix = std::string(concatPred.getSuffix());
performSubstitutions(rootNode->suffix, substitutions);
}
// Build child subtrees.
auto combined = static_cast<const CombinedPred &>(root);
for (const auto *record : combined.getChildren()) {
auto childTree =
buildPredicateTree(Pred(record), allocator, allSubstitutions);
rootNode->children.push_back(childTree);
}
return rootNode;
}
// Simplify a predicate tree rooted at "node" using the predicates that are
// known to be true(false). For AND(OR) combined predicates, if any of the
// children is known to be false(true), the result is also false(true).
// Furthermore, for AND(OR) combined predicates, children that are known to be
// true(false) don't have to be checked dynamically.
static PredNode *
propagateGroundTruth(PredNode *node,
const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
// If the current predicate is known to be true or false, change the kind of
// the node and return immediately.
if (knownTruePreds.count(node->predicate) != 0) {
node->kind = PredCombinerKind::True;
node->children.clear();
return node;
}
if (knownFalsePreds.count(node->predicate) != 0) {
node->kind = PredCombinerKind::False;
node->children.clear();
return node;
}
// If the current node is a substitution, stop recursion now.
// The expressions in the leaves below this node were rewritten, but the nodes
// still point to the original predicate records. While the original
// predicate may be known to be true or false, it is not necessarily the case
// after rewriting.
// TODO: we can support ground truth for rewritten
// predicates by either (a) having our own unique'ing of the predicates
// instead of relying on TableGen record pointers or (b) taking ground truth
// values optionally prefixed with a list of substitutions to apply, e.g.
// "predX is true by itself as well as predSubY leaf substitution had been
// applied to it".
if (node->kind == PredCombinerKind::SubstLeaves) {
return node;
}
// Otherwise, look at child nodes.
// Move child nodes into some local variable so that they can be optimized
// separately and re-added if necessary.
llvm::SmallVector<PredNode *, 4> children;
std::swap(node->children, children);
for (auto &child : children) {
// First, simplify the child. This maintains the predicate as it was.
auto simplifiedChild =
propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
// Just add the child if we don't know how to simplify the current node.
if (node->kind != PredCombinerKind::And &&
node->kind != PredCombinerKind::Or) {
node->children.push_back(simplifiedChild);
continue;
}
// Second, based on the type define which known values of child predicates
// immediately collapse this predicate to a known value, and which others
// may be safely ignored.
// OR(..., True, ...) = True
// OR(..., False, ...) = OR(..., ...)
// AND(..., False, ...) = False
// AND(..., True, ...) = AND(..., ...)
auto collapseKind = node->kind == PredCombinerKind::And
? PredCombinerKind::False
: PredCombinerKind::True;
auto eraseKind = node->kind == PredCombinerKind::And
? PredCombinerKind::True
: PredCombinerKind::False;
const auto &collapseList =
node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
const auto &eraseList =
node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
if (simplifiedChild->kind == collapseKind ||
collapseList.count(simplifiedChild->predicate) != 0) {
node->kind = collapseKind;
node->children.clear();
return node;
} else if (simplifiedChild->kind == eraseKind ||
eraseList.count(simplifiedChild->predicate) != 0) {
continue;
}
node->children.push_back(simplifiedChild);
}
return node;
}
// Combine a list of predicate expressions using a binary combiner. If a list
// is empty, return "init".
static std::string combineBinary(ArrayRef<std::string> children,
std::string combiner, std::string init) {
if (children.empty())
return init;
auto size = children.size();
if (size == 1)
return children.front();
std::string str;
llvm::raw_string_ostream os(str);
os << '(' << children.front() << ')';
for (unsigned i = 1; i < size; ++i) {
os << ' ' << combiner << " (" << children[i] << ')';
}
return os.str();
}
// Prepend negation to the only condition in the predicate expression list.
static std::string combineNot(ArrayRef<std::string> children) {
assert(children.size() == 1 && "expected exactly one child predicate of Neg");
return (Twine("!(") + children.front() + Twine(')')).str();
}
// Recursively traverse the predicate tree in depth-first post-order and build
// the final expression.
static std::string getCombinedCondition(const PredNode &root) {
// Immediately return for non-combiner predicates that don't have children.
if (root.kind == PredCombinerKind::Leaf)
return root.expr;
if (root.kind == PredCombinerKind::True)
return "true";
if (root.kind == PredCombinerKind::False)
return "false";
// Recurse into children.
llvm::SmallVector<std::string, 4> childExpressions;
childExpressions.reserve(root.children.size());
for (const auto &child : root.children)
childExpressions.push_back(getCombinedCondition(*child));
// Combine the expressions based on the predicate node kind.
if (root.kind == PredCombinerKind::And)
return combineBinary(childExpressions, "&&", "true");
if (root.kind == PredCombinerKind::Or)
return combineBinary(childExpressions, "||", "false");
if (root.kind == PredCombinerKind::Not)
return combineNot(childExpressions);
if (root.kind == PredCombinerKind::Concat) {
assert(childExpressions.size() == 1 &&
"ConcatPred should only have one child");
return root.prefix + childExpressions.front() + root.suffix;
}
// Substitutions were applied before so just ignore them.
if (root.kind == PredCombinerKind::SubstLeaves) {
assert(childExpressions.size() == 1 &&
"substitution predicate must have one child");
return childExpressions[0];
}
llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
}
std::string CombinedPred::getConditionImpl() const {
llvm::SpecificBumpPtrAllocator<PredNode> allocator;
auto predicateTree = buildPredicateTree(*this, allocator, {});
predicateTree =
propagateGroundTruth(predicateTree,
/*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
/*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
return getCombinedCondition(*predicateTree);
}
StringRef SubstLeavesPred::getPattern() const {
return def->getValueAsString("pattern");
}
StringRef SubstLeavesPred::getReplacement() const {
return def->getValueAsString("replacement");
}
StringRef ConcatPred::getPrefix() const {
return def->getValueAsString("prefix");
}
StringRef ConcatPred::getSuffix() const {
return def->getValueAsString("suffix");
}

View File

@ -0,0 +1,119 @@
//===- Predicate.h - Predicate class ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Wrapper around predicates defined in TableGen.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_PREDICATE_H_
#define MLIR_TABLEGEN_PREDICATE_H_
#include "mlir/Support/LLVM.h"
#include <string>
#include <vector>
namespace llvm {
class Init;
class ListInit;
class Record;
class SMLoc;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// A logical predicate. This class must closely follow the definition of
// TableGen class 'Pred'.
class Pred {
public:
// Constructs the null Predicate (e.g., always true).
explicit Pred() : def(nullptr) {}
// Construct a Predicate from a record.
explicit Pred(const llvm::Record *record);
// Construct a Predicate from an initializer.
explicit Pred(const llvm::Init *init);
// Check if the predicate is defined. Callers may use this to interpret the
// missing predicate as either true (e.g. in filters) or false (e.g. in
// precondition verification).
bool isNull() const { return def == nullptr; }
// Get the predicate condition. This may dispatch to getConditionImpl() of
// the underlying predicate type.
std::string getCondition() const;
// Whether the predicate is a combination of other predicates, i.e. an
// record of type CombinedPred.
bool isCombined() const;
// Records are pointer-comparable.
bool operator==(const Pred &other) const { return def == other.def; }
// Get the location of the predicate.
ArrayRef<llvm::SMLoc> getLoc() const;
protected:
// The TableGen definition of this predicate.
const llvm::Record *def;
};
// A logical predicate wrapping a C expression. This class must closely follow
// the definition of TableGen class 'CPred'.
class CPred : public Pred {
public:
// Construct a CPred from a record.
explicit CPred(const llvm::Record *record);
// Construct a CPred an initializer.
explicit CPred(const llvm::Init *init);
// Get the predicate condition.
std::string getConditionImpl() const;
};
// A logical predicate that is a combination of other predicates. This class
// must closely follow the definition of TableGen class 'CombinedPred'.
class CombinedPred : public Pred {
public:
// Construct a CombinedPred from a record.
explicit CombinedPred(const llvm::Record *record);
// Construct a CombinedPred from an initializer.
explicit CombinedPred(const llvm::Init *init);
// Get the predicate condition.
std::string getConditionImpl() const;
// Get the definition of the combiner used in this predicate.
const llvm::Record *getCombinerDef() const;
// Get the predicates that are combined by this predicate.
const std::vector<llvm::Record *> getChildren() const;
};
// A combined predicate that requires all child predicates of 'CPred' type to
// have their expression rewritten with a simple string substitution rule.
class SubstLeavesPred : public CombinedPred {
public:
// Get the replacement pattern.
StringRef getPattern() const;
// Get the string used to replace the pattern.
StringRef getReplacement() const;
};
// A combined predicate that prepends a prefix and appends a suffix to the
// predicate string composed from a child predicate.
class ConcatPred : public CombinedPred {
public:
StringRef getPrefix() const;
StringRef getSuffix() const;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_PREDICATE_H_

View File

@ -0,0 +1,20 @@
//===- Region.cpp - Region class ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Region wrapper to simplify using TableGen Record defining a MLIR Region.
//
//===----------------------------------------------------------------------===//
#include "Region.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
// Returns true if this region is variadic.
bool Region::isVariadic() const { return def->isSubClassOf("VariadicRegion"); }

View File

@ -0,0 +1,42 @@
//===- TGRegion.h - TableGen region definitions -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_REGION_H_
#define MLIR_TABLEGEN_REGION_H_
#include "mlir/Support/LLVM.h"
#include "Constraint.h"
namespace mlir {
namespace tblgen {
// Wrapper class providing helper methods for accessing Region defined in
// TableGen.
class Region : public Constraint {
public:
using Constraint::Constraint;
static bool classof(const Constraint *c) { return c->getKind() == CK_Region; }
// Returns true if this region is variadic.
bool isVariadic() const;
};
// A struct bundling a region's constraint and its name.
struct NamedRegion {
// Returns true if this region is variadic.
bool isVariadic() const { return constraint.isVariadic(); }
StringRef name;
Region constraint;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_REGION_H_

View File

@ -0,0 +1,58 @@
//===- SideEffects.cpp - SideEffect classes -------------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "SideEffects.h"
#include "llvm/ADT/Twine.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// SideEffect
//===----------------------------------------------------------------------===//
StringRef SideEffect::getName() const {
return def->getValueAsString("effect");
}
StringRef SideEffect::getBaseEffectName() const {
return def->getValueAsString("baseEffectName");
}
std::string SideEffect::getInterfaceTrait() const {
StringRef trait = def->getValueAsString("interfaceTrait");
StringRef cppNamespace = def->getValueAsString("cppNamespace");
return cppNamespace.empty() ? trait.str()
: (cppNamespace + "::" + trait).str();
}
StringRef SideEffect::getResource() const {
return def->getValueAsString("resource");
}
bool SideEffect::classof(const Operator::VariableDecorator *var) {
return var->getDef().isSubClassOf("SideEffect");
}
//===----------------------------------------------------------------------===//
// SideEffectsTrait
//===----------------------------------------------------------------------===//
Operator::var_decorator_range SideEffectTrait::getEffects() const {
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("effects"));
return {listInit->begin(), listInit->end()};
}
StringRef SideEffectTrait::getBaseEffectName() const {
return def->getValueAsString("baseEffectName");
}
bool SideEffectTrait::classof(const Trait *t) {
return t->getDef().isSubClassOf("SideEffectsTraitBase");
}

View File

@ -0,0 +1,58 @@
//===- SideEffects.h - Side Effects classes ---------------------*- C++ -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Wrapper around side effect related classes defined in TableGen.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_SIDEEFFECTS_H_
#define MLIR_TABLEGEN_SIDEEFFECTS_H_
#include "mlir/Support/LLVM.h"
#include "Operator.h"
namespace mlir {
namespace tblgen {
// This class represents a specific instance of an effect that is being
// exhibited.
class SideEffect : public Operator::VariableDecorator {
public:
// Return the name of the C++ effect.
StringRef getName() const;
// Return the name of the base C++ effect.
StringRef getBaseEffectName() const;
// Return the name of the Interface that the effect belongs to.
std::string getInterfaceTrait() const;
// Return the name of the resource class.
StringRef getResource() const;
static bool classof(const Operator::VariableDecorator *var);
};
// This class represents an instance of a side effect interface applied to an
// operation. This is a wrapper around an OpInterfaceTrait that also includes
// the effects that are applied.
class SideEffectTrait : public InterfaceTrait {
public:
// Return the effects that are attached to the side effect interface.
Operator::var_decorator_range getEffects() const;
// Return the name of the base C++ effect.
StringRef getBaseEffectName() const;
static bool classof(const Trait *t);
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_SIDEEFFECTS_H_

View File

@ -0,0 +1,24 @@
//===- Successor.cpp - Successor class ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Successor wrapper to simplify using TableGen Record defining a MLIR
// Successor.
//
//===----------------------------------------------------------------------===//
#include "Successor.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
// Returns true if this successor is variadic.
bool Successor::isVariadic() const {
return def->isSubClassOf("VariadicSuccessor");
}

View File

@ -0,0 +1,44 @@
//===- Successor.h - TableGen successor definitions -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_SUCCESSOR_H_
#define MLIR_TABLEGEN_SUCCESSOR_H_
#include "mlir/Support/LLVM.h"
#include "Constraint.h"
namespace mlir {
namespace tblgen {
// Wrapper class providing helper methods for accessing Successor defined in
// TableGen.
class Successor : public Constraint {
public:
using Constraint::Constraint;
static bool classof(const Constraint *c) {
return c->getKind() == CK_Successor;
}
// Returns true if this successor is variadic.
bool isVariadic() const;
};
// A struct bundling a successor's constraint and its name.
struct NamedSuccessor {
// Returns true if this successor is variadic.
bool isVariadic() const { return constraint.isVariadic(); }
StringRef name;
Successor constraint;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_SUCCESSOR_H_

View File

@ -0,0 +1,93 @@
//===- Trait.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Trait wrapper to simplify using TableGen Record defining a MLIR Trait.
//
//===----------------------------------------------------------------------===//
#include "Trait.h"
#include "Interfaces.h"
#include "Predicate.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// Trait
//===----------------------------------------------------------------------===//
Trait Trait::create(const llvm::Init *init) {
auto def = cast<llvm::DefInit>(init)->getDef();
if (def->isSubClassOf("PredTrait"))
return Trait(Kind::Pred, def);
if (def->isSubClassOf("GenInternalTrait"))
return Trait(Kind::Internal, def);
if (def->isSubClassOf("InterfaceTrait"))
return Trait(Kind::Interface, def);
assert(def->isSubClassOf("NativeTrait"));
return Trait(Kind::Native, def);
}
Trait::Trait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {}
//===----------------------------------------------------------------------===//
// NativeTrait
//===----------------------------------------------------------------------===//
std::string NativeTrait::getFullyQualifiedTraitName() const {
llvm::StringRef trait = def->getValueAsString("trait");
llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
return cppNamespace.empty() ? trait.str()
: (cppNamespace + "::" + trait).str();
}
//===----------------------------------------------------------------------===//
// InternalTrait
//===----------------------------------------------------------------------===//
llvm::StringRef InternalTrait::getFullyQualifiedTraitName() const {
return def->getValueAsString("trait");
}
//===----------------------------------------------------------------------===//
// PredTrait
//===----------------------------------------------------------------------===//
std::string PredTrait::getPredTemplate() const {
auto pred = Pred(def->getValueInit("predicate"));
return pred.getCondition();
}
llvm::StringRef PredTrait::getSummary() const {
return def->getValueAsString("summary");
}
//===----------------------------------------------------------------------===//
// InterfaceTrait
//===----------------------------------------------------------------------===//
Interface InterfaceTrait::getInterface() const { return Interface(def); }
std::string InterfaceTrait::getFullyQualifiedTraitName() const {
llvm::StringRef trait = def->getValueAsString("trait");
llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
return cppNamespace.empty() ? trait.str()
: (cppNamespace + "::" + trait).str();
}
bool InterfaceTrait::shouldDeclareMethods() const {
return def->isSubClassOf("DeclareInterfaceMethods");
}
std::vector<StringRef> InterfaceTrait::getAlwaysDeclaredMethods() const {
return def->getValueAsListOfStrings("alwaysOverriddenMethods");
}

View File

@ -0,0 +1,116 @@
//===- Trait.h - Trait wrapper class ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Trait wrapper to simplify using TableGen Record defining an MLIR Trait.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_TRAIT_H_
#define MLIR_TABLEGEN_TRAIT_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include <vector>
namespace llvm {
class Init;
class Record;
} // end namespace llvm
namespace mlir {
namespace tblgen {
class Interface;
// Wrapper class with helper methods for accessing Trait constraints defined in
// TableGen.
class Trait {
public:
// Discriminator for kinds of traits.
enum class Kind {
// Trait corresponding to C++ class.
Native,
// Trait corresponding to a predicate.
Pred,
// Trait controlling definition generator internals.
Internal,
// Trait corresponding to an Interface.
Interface
};
explicit Trait(Kind kind, const llvm::Record *def);
// Returns an Trait corresponding to the init provided.
static Trait create(const llvm::Init *init);
Kind getKind() const { return kind; }
// Returns the Tablegen definition this operator was constructed from.
const llvm::Record &getDef() const { return *def; }
protected:
// The TableGen definition of this trait.
const llvm::Record *def;
Kind kind;
};
// Trait corresponding to a native C++ Trait.
class NativeTrait : public Trait {
public:
// Returns the trait corresponding to a C++ trait class.
std::string getFullyQualifiedTraitName() const;
static bool classof(const Trait *t) { return t->getKind() == Kind::Native; }
};
// Trait corresponding to a predicate on the operation.
class PredTrait : public Trait {
public:
// Returns the template for constructing the predicate.
std::string getPredTemplate() const;
// Returns the description of what the predicate is verifying.
StringRef getSummary() const;
static bool classof(const Trait *t) { return t->getKind() == Kind::Pred; }
};
// Trait controlling op definition generator internals.
class InternalTrait : public Trait {
public:
// Returns the trait controlling op definition generator internals.
StringRef getFullyQualifiedTraitName() const;
static bool classof(const Trait *t) { return t->getKind() == Kind::Internal; }
};
// Trait corresponding to an OpInterface on the operation.
class InterfaceTrait : public Trait {
public:
// Returns interface corresponding to the trait.
Interface getInterface() const;
// Returns the trait corresponding to a C++ trait class.
std::string getFullyQualifiedTraitName() const;
static bool classof(const Trait *t) {
return t->getKind() == Kind::Interface;
}
// Whether the declaration of methods for this trait should be emitted.
bool shouldDeclareMethods() const;
// Returns the methods that should always be declared if this interface is
// emitting declarations.
std::vector<StringRef> getAlwaysDeclaredMethods() const;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_TRAIT_H_

View File

@ -0,0 +1,82 @@
//===- Type.cpp - Type class ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Type wrapper to simplify using TableGen Record defining a MLIR Type.
//
//===----------------------------------------------------------------------===//
#include "Type.h"
#include "Dialect.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
TypeConstraint::TypeConstraint(const llvm::Record *record)
: Constraint(Constraint::CK_Type, record) {
assert(def->isSubClassOf("TypeConstraint") &&
"must be subclass of TableGen 'TypeConstraint' class");
}
TypeConstraint::TypeConstraint(const llvm::DefInit *init)
: TypeConstraint(init->getDef()) {}
bool TypeConstraint::isOptional() const {
return def->isSubClassOf("Optional");
}
bool TypeConstraint::isVariadic() const {
return def->isSubClassOf("Variadic");
}
// Returns the builder call for this constraint if this is a buildable type,
// returns None otherwise.
Optional<StringRef> TypeConstraint::getBuilderCall() const {
const llvm::Record *baseType = def;
if (isVariableLength())
baseType = baseType->getValueAsDef("baseType");
// Check to see if this type constraint has a builder call.
const llvm::RecordVal *builderCall = baseType->getValue("builderCall");
if (!builderCall || !builderCall->getValue())
return llvm::None;
return TypeSwitch<llvm::Init *, Optional<StringRef>>(builderCall->getValue())
.Case<llvm::StringInit>([&](auto *init) {
StringRef value = init->getValue();
return value.empty() ? Optional<StringRef>() : value;
})
.Default([](auto *) { return llvm::None; });
}
// Return the C++ class name for this type (which may just be ::mlir::Type).
std::string TypeConstraint::getCPPClassName() const {
StringRef className = def->getValueAsString("cppClassName");
// If the class name is already namespace resolved, use it.
if (className.contains("::"))
return className.str();
// Otherwise, check to see if there is a namespace from a dialect to prepend.
if (const llvm::RecordVal *value = def->getValue("dialect")) {
Dialect dialect(cast<const llvm::DefInit>(value->getValue())->getDef());
return (dialect.getCppNamespace() + "::" + className).str();
}
return className.str();
}
Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
StringRef Type::getDescription() const {
return def->getValueAsString("description");
}
Dialect Type::getDialect() const {
return Dialect(def->getValueAsDef("dialect"));
}

View File

@ -0,0 +1,70 @@
//===- Type.h - Type class --------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Type wrapper to simplify using TableGen Record defining a MLIR Type.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_TYPE_H_
#define MLIR_TABLEGEN_TYPE_H_
#include "mlir/Support/LLVM.h"
#include "Constraint.h"
#include "Dialect.h"
namespace llvm {
class DefInit;
class Record;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// Wrapper class with helper methods for accessing Type constraints defined in
// TableGen.
class TypeConstraint : public Constraint {
public:
explicit TypeConstraint(const llvm::Record *record);
explicit TypeConstraint(const llvm::DefInit *init);
static bool classof(const Constraint *c) { return c->getKind() == CK_Type; }
// Returns true if this is an optional type constraint.
bool isOptional() const;
// Returns true if this is a variadic type constraint.
bool isVariadic() const;
// Returns true if this is a variable length type constraint. This is either
// variadic or optional.
bool isVariableLength() const { return isOptional() || isVariadic(); }
// Returns the builder call for this constraint if this is a buildable type,
// returns None otherwise.
Optional<StringRef> getBuilderCall() const;
// Return the C++ class name for this type (which may just be ::mlir::Type).
std::string getCPPClassName() const;
};
// Wrapper class with helper methods for accessing Types defined in TableGen.
class Type : public TypeConstraint {
public:
explicit Type(const llvm::Record *record);
// Returns the description of the type.
StringRef getDescription() const;
// Returns the dialect for the type if defined.
Dialect getDialect() const;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_TYPE_H_

View File

@ -0,0 +1,83 @@
//===- mlir-tblgen.cpp - Top-Level TableGen implementation for MLIR -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the main function for MLIR's TableGen.
//
//===----------------------------------------------------------------------===//
#include "TableGen/GenInfo.h"
#include "TableGen/GenNameParser.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
static llvm::ManagedStatic<std::vector<GenInfo>> generatorRegistry;
mlir::GenRegistration::GenRegistration(StringRef arg, StringRef description,
GenFunction function) {
generatorRegistry->emplace_back(arg, description, function);
}
GenNameParser::GenNameParser(llvm::cl::Option &opt)
: llvm::cl::parser<const GenInfo *>(opt) {
for (const auto &kv : *generatorRegistry) {
addLiteralOption(kv.getGenArgument(), &kv, kv.getGenDescription());
}
}
void GenNameParser::printOptionInfo(const llvm::cl::Option &O,
size_t GlobalWidth) const {
GenNameParser *TP = const_cast<GenNameParser *>(this);
llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(),
[](const GenNameParser::OptionInfo *VT1,
const GenNameParser::OptionInfo *VT2) {
return VT1->Name.compare(VT2->Name);
});
using llvm::cl::parser;
parser<const GenInfo *>::printOptionInfo(O, GlobalWidth);
}
// Generator that prints records.
GenRegistration printRecords("print-records", "Print all records to stdout",
[](const RecordKeeper &records, raw_ostream &os) {
os << records;
return false;
});
// Generator to invoke.
const mlir::GenInfo *generator;
// TableGenMain requires a function pointer so this function is passed in which
// simply wraps the call to the generator.
static bool MlirTableGenMain(raw_ostream &os, RecordKeeper &records) {
if (!generator) {
os << records;
return false;
}
return generator->invoke(records, os);
}
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::opt<const mlir::GenInfo *, false, mlir::GenNameParser> generator(
"", llvm::cl::desc("Generator to run"));
cl::ParseCommandLineOptions(argc, argv);
::generator = generator.getValue();
return TableGenMain(argv[0], &MlirTableGenMain);
}