add mlir tblgen builder
This commit is contained in:
parent
b0dd7a7518
commit
898eb732de
|
@ -2,3 +2,5 @@ build
|
||||||
llvm-project
|
llvm-project
|
||||||
llvm-build
|
llvm-build
|
||||||
bazel-*
|
bazel-*
|
||||||
|
bazel-bin
|
||||||
|
.vscode
|
||||||
|
|
41
BUILD
41
BUILD
|
@ -122,6 +122,46 @@ gentbl_cc_library(
|
||||||
deps = [":hlo_ops_td_files"],
|
deps = [":hlo_ops_td_files"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "mlir-tblgen-builder",
|
||||||
|
srcs = glob([
|
||||||
|
"tools/mlir-tblgen-builder/*.h",
|
||||||
|
"tools/mlir-tblgen-builder/*.cpp",
|
||||||
|
"tools/mlir-tblgen-builder/TableGen/*.h",
|
||||||
|
"tools/mlir-tblgen-builder/TableGen/*.cpp",
|
||||||
|
]),
|
||||||
|
deps = [
|
||||||
|
"@llvm-project//mlir:MlirTableGenMain",
|
||||||
|
"@llvm-project//mlir:Support",
|
||||||
|
# "@llvm-project//mlir:TableGen",
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
|
"@llvm-project//llvm:TableGen",
|
||||||
|
"@llvm-project//llvm:config",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
gentbl_cc_library(
|
||||||
|
name = "hlo_ops_builder_gen",
|
||||||
|
strip_include_prefix = "include",
|
||||||
|
tbl_outs = [
|
||||||
|
(
|
||||||
|
["-gen-builder-decls"],
|
||||||
|
"include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.h.inc",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
["-gen-builder-defs"],
|
||||||
|
"include/mlir-hlo/Dialect/mhlo/IR/hlo_builder.cc.inc",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tblgen = ":mlir-tblgen-builder",
|
||||||
|
td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td",
|
||||||
|
td_includes = [
|
||||||
|
"external/mlir-hlo/include",
|
||||||
|
"include",
|
||||||
|
],
|
||||||
|
deps = [":hlo_ops_td_files"],
|
||||||
|
)
|
||||||
|
|
||||||
gentbl_cc_library(
|
gentbl_cc_library(
|
||||||
name = "hlo_ops_base_inc_gen",
|
name = "hlo_ops_base_inc_gen",
|
||||||
strip_include_prefix = "include",
|
strip_include_prefix = "include",
|
||||||
|
@ -519,6 +559,7 @@ cc_library(
|
||||||
":hlo_ops_base_structs",
|
":hlo_ops_base_structs",
|
||||||
":hlo_ops_common",
|
":hlo_ops_common",
|
||||||
":hlo_ops_inc_gen",
|
":hlo_ops_inc_gen",
|
||||||
|
":hlo_ops_builder_gen",
|
||||||
":hlo_ops_pattern_gen",
|
":hlo_ops_pattern_gen",
|
||||||
":infer_fusibility_op_interface",
|
":infer_fusibility_op_interface",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,28 @@
|
||||||
|
//===- OpFormatGen.h - MLIR operation format generator ----------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file defines the interface for generating parsers and printers from the
|
||||||
|
// declarative format.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_
|
||||||
|
#define MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
class OpClass;
|
||||||
|
class Operator;
|
||||||
|
|
||||||
|
// Generate the assembly format for the given operator.
|
||||||
|
void generateOpFormat(const Operator &constOp, OpClass &opClass);
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TOOLS_MLIRTBLGEN_OPFORMATGEN_H_
|
|
@ -0,0 +1,65 @@
|
||||||
|
//===- OpGenHelpers.cpp - MLIR operation generator helpers ----------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file defines helpers used in the op generators.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "OpGenHelpers.h"
|
||||||
|
#include "llvm/Support/CommandLine.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/Support/Regex.h"
|
||||||
|
#include "llvm/TableGen/Error.h"
|
||||||
|
|
||||||
|
using namespace llvm;
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
cl::OptionCategory opDefGenCat("Options for op definition generators");
|
||||||
|
|
||||||
|
static cl::opt<std::string> opIncFilter(
|
||||||
|
"op-include-regex",
|
||||||
|
cl::desc("Regex of name of op's to include (no filter if empty)"),
|
||||||
|
cl::cat(opDefGenCat));
|
||||||
|
static cl::opt<std::string> opExcFilter(
|
||||||
|
"op-exclude-regex",
|
||||||
|
cl::desc("Regex of name of op's to exclude (no filter if empty)"),
|
||||||
|
cl::cat(opDefGenCat));
|
||||||
|
|
||||||
|
static std::string getOperationName(const Record &def) {
|
||||||
|
auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
|
||||||
|
auto opName = def.getValueAsString("opName");
|
||||||
|
if (prefix.empty())
|
||||||
|
return std::string(opName);
|
||||||
|
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Record *>
|
||||||
|
mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
|
||||||
|
Record *classDef = recordKeeper.getClass("Op");
|
||||||
|
if (!classDef)
|
||||||
|
PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
|
||||||
|
|
||||||
|
llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
|
||||||
|
std::vector<Record *> defs;
|
||||||
|
for (const auto &def : recordKeeper.getDefs()) {
|
||||||
|
if (!def.second->isSubClassOf(classDef))
|
||||||
|
continue;
|
||||||
|
// Include if no include filter or include filter matches.
|
||||||
|
if (!opIncFilter.empty() &&
|
||||||
|
!includeRegex.match(getOperationName(*def.second)))
|
||||||
|
continue;
|
||||||
|
// Unless there is an exclude filter and it matches.
|
||||||
|
if (!opExcFilter.empty() &&
|
||||||
|
excludeRegex.match(getOperationName(*def.second)))
|
||||||
|
continue;
|
||||||
|
defs.push_back(def.second.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
return defs;
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
//===- OpGenHelpers.h - MLIR operation generator helpers --------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file defines helpers used in the op generators.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
|
||||||
|
#define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
|
||||||
|
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
/// Returns all the op definitions filtered by the user. The filtering is via
|
||||||
|
/// command-line option "op-include-regex" and "op-exclude-regex".
|
||||||
|
std::vector<llvm::Record *>
|
||||||
|
getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
|
|
@ -0,0 +1,21 @@
|
||||||
|
//===- Argument.cpp - Argument definitions --------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Argument.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
bool NamedTypeConstraint::hasPredicate() const {
|
||||||
|
return !constraint.getPredicate().isNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool NamedTypeConstraint::isOptional() const { return constraint.isOptional(); }
|
||||||
|
|
||||||
|
bool NamedTypeConstraint::isVariadic() const { return constraint.isVariadic(); }
|
|
@ -0,0 +1,65 @@
|
||||||
|
//===- Argument.h - Argument definitions ------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This header file contains definitions for TableGen operation's arguments.
|
||||||
|
// Operation arguments fall into two categories:
|
||||||
|
//
|
||||||
|
// 1. Operands: SSA values operated on by the operation
|
||||||
|
// 2. Attributes: compile-time known properties that have influence over
|
||||||
|
// the operation's behavior
|
||||||
|
//
|
||||||
|
// These two categories are modelled with the unified argument concept in
|
||||||
|
// TableGen because we need similar pattern matching mechanisms for them.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_ARGUMENT_H_
|
||||||
|
#define MLIR_TABLEGEN_ARGUMENT_H_
|
||||||
|
|
||||||
|
#include "Attribute.h"
|
||||||
|
#include "Type.h"
|
||||||
|
#include "llvm/ADT/PointerUnion.h"
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class StringRef;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
// A struct wrapping an op attribute and its name together
|
||||||
|
struct NamedAttribute {
|
||||||
|
llvm::StringRef name;
|
||||||
|
Attribute attr;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A struct wrapping an op operand/result's constraint and its name together
|
||||||
|
struct NamedTypeConstraint {
|
||||||
|
// Returns true if this operand/result has constraint to be satisfied.
|
||||||
|
bool hasPredicate() const;
|
||||||
|
// Returns true if this is an optional type constraint. This is a special case
|
||||||
|
// of variadic for 0 or 1 type.
|
||||||
|
bool isOptional() const;
|
||||||
|
// Returns true if this operand/result is variadic.
|
||||||
|
bool isVariadic() const;
|
||||||
|
// Returns true if this is a variable length type constraint. This is either
|
||||||
|
// variadic or optional.
|
||||||
|
bool isVariableLength() const { return isOptional() || isVariadic(); }
|
||||||
|
|
||||||
|
llvm::StringRef name;
|
||||||
|
TypeConstraint constraint;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Operation argument: either attribute or operand
|
||||||
|
using Argument = llvm::PointerUnion<NamedAttribute *, NamedTypeConstraint *>;
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_ARGUMENT_H_
|
|
@ -0,0 +1,250 @@
|
||||||
|
//===- AttrOrTypeDef.cpp - AttrOrTypeDef wrapper classes ------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "AttrOrTypeDef.h"
|
||||||
|
#include "Dialect.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include "llvm/TableGen/Error.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AttrOrTypeBuilder
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Returns true if this builder is able to infer the MLIRContext parameter.
|
||||||
|
bool AttrOrTypeBuilder::hasInferredContextParameter() const {
|
||||||
|
return def->getValueAsBit("hasInferredContextParam");
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AttrOrTypeDef
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
|
||||||
|
// Populate the builders.
|
||||||
|
auto *builderList =
|
||||||
|
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
|
||||||
|
if (builderList && !builderList->empty()) {
|
||||||
|
for (llvm::Init *init : builderList->getValues()) {
|
||||||
|
AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
|
||||||
|
def->getLoc());
|
||||||
|
|
||||||
|
// Ensure that all parameters have names.
|
||||||
|
for (const AttrOrTypeBuilder::Parameter ¶m :
|
||||||
|
builder.getParameters()) {
|
||||||
|
if (!param.getName())
|
||||||
|
PrintFatalError(def->getLoc(), "builder parameters must have a name");
|
||||||
|
}
|
||||||
|
builders.emplace_back(builder);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate the traits.
|
||||||
|
if (auto *traitList = def->getValueAsListInit("traits")) {
|
||||||
|
SmallPtrSet<const llvm::Init *, 32> traitSet;
|
||||||
|
traits.reserve(traitSet.size());
|
||||||
|
for (auto *traitInit : *traitList)
|
||||||
|
if (traitSet.insert(traitInit).second)
|
||||||
|
traits.push_back(Trait::create(traitInit));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Dialect AttrOrTypeDef::getDialect() const {
|
||||||
|
auto *dialect = dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
|
||||||
|
return Dialect(dialect ? dialect->getDef() : nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef AttrOrTypeDef::getName() const { return def->getName(); }
|
||||||
|
|
||||||
|
StringRef AttrOrTypeDef::getCppClassName() const {
|
||||||
|
return def->getValueAsString("cppClassName");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef AttrOrTypeDef::getCppBaseClassName() const {
|
||||||
|
return def->getValueAsString("cppBaseClassName");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AttrOrTypeDef::hasDescription() const {
|
||||||
|
const llvm::RecordVal *desc = def->getValue("description");
|
||||||
|
return desc && isa<llvm::StringInit>(desc->getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef AttrOrTypeDef::getDescription() const {
|
||||||
|
return def->getValueAsString("description");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AttrOrTypeDef::hasSummary() const {
|
||||||
|
const llvm::RecordVal *summary = def->getValue("summary");
|
||||||
|
return summary && isa<llvm::StringInit>(summary->getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef AttrOrTypeDef::getSummary() const {
|
||||||
|
return def->getValueAsString("summary");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef AttrOrTypeDef::getStorageClassName() const {
|
||||||
|
return def->getValueAsString("storageClass");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef AttrOrTypeDef::getStorageNamespace() const {
|
||||||
|
return def->getValueAsString("storageNamespace");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AttrOrTypeDef::genStorageClass() const {
|
||||||
|
return def->getValueAsBit("genStorageClass");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AttrOrTypeDef::hasStorageCustomConstructor() const {
|
||||||
|
return def->getValueAsBit("hasStorageCustomConstructor");
|
||||||
|
}
|
||||||
|
|
||||||
|
void AttrOrTypeDef::getParameters(
|
||||||
|
SmallVectorImpl<AttrOrTypeParameter> ¶meters) const {
|
||||||
|
if (auto *parametersDag = def->getValueAsDag("parameters")) {
|
||||||
|
for (unsigned i = 0, e = parametersDag->getNumArgs(); i < e; ++i)
|
||||||
|
parameters.push_back(AttrOrTypeParameter(parametersDag, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned AttrOrTypeDef::getNumParameters() const {
|
||||||
|
auto *parametersDag = def->getValueAsDag("parameters");
|
||||||
|
return parametersDag ? parametersDag->getNumArgs() : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<StringRef> AttrOrTypeDef::getMnemonic() const {
|
||||||
|
return def->getValueAsOptionalString("mnemonic");
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<StringRef> AttrOrTypeDef::getPrinterCode() const {
|
||||||
|
return def->getValueAsOptionalString("printer");
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<StringRef> AttrOrTypeDef::getParserCode() const {
|
||||||
|
return def->getValueAsOptionalString("parser");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AttrOrTypeDef::genAccessors() const {
|
||||||
|
return def->getValueAsBit("genAccessors");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AttrOrTypeDef::genVerifyDecl() const {
|
||||||
|
return def->getValueAsBit("genVerifyDecl");
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
|
||||||
|
auto value = def->getValueAsString("extraClassDeclaration");
|
||||||
|
return value.empty() ? Optional<StringRef>() : value;
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<llvm::SMLoc> AttrOrTypeDef::getLoc() const { return def->getLoc(); }
|
||||||
|
|
||||||
|
bool AttrOrTypeDef::skipDefaultBuilders() const {
|
||||||
|
return def->getValueAsBit("skipDefaultBuilders");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AttrOrTypeDef::operator==(const AttrOrTypeDef &other) const {
|
||||||
|
return def == other.def;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AttrOrTypeDef::operator<(const AttrOrTypeDef &other) const {
|
||||||
|
return getName() < other.getName();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AttrDef
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Optional<StringRef> AttrDef::getTypeBuilder() const {
|
||||||
|
return def->getValueAsOptionalString("typeBuilder");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AttrDef::classof(const AttrOrTypeDef *def) {
|
||||||
|
return def->getDef()->isSubClassOf("AttrDef");
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AttrOrTypeParameter
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
StringRef AttrOrTypeParameter::getName() const {
|
||||||
|
return def->getArgName(index)->getValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
|
||||||
|
llvm::Init *parameterType = def->getArg(index);
|
||||||
|
if (isa<llvm::StringInit>(parameterType))
|
||||||
|
return Optional<StringRef>();
|
||||||
|
if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
|
||||||
|
return param->getDef()->getValueAsOptionalString("allocator");
|
||||||
|
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
|
||||||
|
"defs which inherit from AttrOrTypeParameter\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<StringRef> AttrOrTypeParameter::getComparator() const {
|
||||||
|
llvm::Init *parameterType = def->getArg(index);
|
||||||
|
if (isa<llvm::StringInit>(parameterType))
|
||||||
|
return Optional<StringRef>();
|
||||||
|
if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
|
||||||
|
return param->getDef()->getValueAsOptionalString("comparator");
|
||||||
|
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
|
||||||
|
"defs which inherit from AttrOrTypeParameter\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef AttrOrTypeParameter::getCppType() const {
|
||||||
|
auto *parameterType = def->getArg(index);
|
||||||
|
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
|
||||||
|
return stringType->getValue();
|
||||||
|
if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
|
||||||
|
return param->getDef()->getValueAsString("cppType");
|
||||||
|
llvm::PrintFatalError(
|
||||||
|
"Parameters DAG arguments must be either strings or defs "
|
||||||
|
"which inherit from AttrOrTypeParameter\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<StringRef> AttrOrTypeParameter::getSummary() const {
|
||||||
|
auto *parameterType = def->getArg(index);
|
||||||
|
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
|
||||||
|
const auto *desc = param->getDef()->getValue("summary");
|
||||||
|
if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(desc->getValue()))
|
||||||
|
return ci->getValue();
|
||||||
|
}
|
||||||
|
return Optional<StringRef>();
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef AttrOrTypeParameter::getSyntax() const {
|
||||||
|
auto *parameterType = def->getArg(index);
|
||||||
|
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
|
||||||
|
return stringType->getValue();
|
||||||
|
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
|
||||||
|
const auto *syntax = param->getDef()->getValue("syntax");
|
||||||
|
if (syntax && isa<llvm::StringInit>(syntax->getValue()))
|
||||||
|
return cast<llvm::StringInit>(syntax->getValue())->getValue();
|
||||||
|
return getCppType();
|
||||||
|
}
|
||||||
|
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
|
||||||
|
"defs which inherit from AttrOrTypeParameter");
|
||||||
|
}
|
||||||
|
|
||||||
|
const llvm::Init *AttrOrTypeParameter::getDef() const {
|
||||||
|
return def->getArg(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AttributeSelfTypeParameter
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
|
||||||
|
const llvm::Init *paramDef = param->getDef();
|
||||||
|
if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
|
||||||
|
return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
|
||||||
|
return false;
|
||||||
|
}
|
|
@ -0,0 +1,229 @@
|
||||||
|
//===-- AttrOrTypeDef.h - Wrapper for attr and type definitions -*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// AttrOrTypeDef, AttrDef, and TypeDef wrappers to simplify using TableGen
|
||||||
|
// Record defining a MLIR attributes and types.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_ATTRORTYPEDEF_H
|
||||||
|
#define MLIR_TABLEGEN_ATTRORTYPEDEF_H
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "Builder.h"
|
||||||
|
#include "Trait.h"
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class DagInit;
|
||||||
|
class Record;
|
||||||
|
class SMLoc;
|
||||||
|
} // namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
class Dialect;
|
||||||
|
class AttrOrTypeParameter;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AttrOrTypeBuilder
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Wrapper class that represents a Tablegen AttrOrTypeBuilder.
|
||||||
|
class AttrOrTypeBuilder : public Builder {
|
||||||
|
public:
|
||||||
|
using Builder::Builder;
|
||||||
|
|
||||||
|
/// Returns true if this builder is able to infer the MLIRContext parameter.
|
||||||
|
bool hasInferredContextParameter() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AttrOrTypeDef
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Wrapper class that contains a TableGen AttrOrTypeDef's record and provides
|
||||||
|
/// helper methods for accessing them.
|
||||||
|
class AttrOrTypeDef {
|
||||||
|
public:
|
||||||
|
explicit AttrOrTypeDef(const llvm::Record *def);
|
||||||
|
|
||||||
|
// Get the dialect for which this def belongs.
|
||||||
|
Dialect getDialect() const;
|
||||||
|
|
||||||
|
// Returns the name of this AttrOrTypeDef record.
|
||||||
|
StringRef getName() const;
|
||||||
|
|
||||||
|
// Query functions for the documentation of the def.
|
||||||
|
bool hasDescription() const;
|
||||||
|
StringRef getDescription() const;
|
||||||
|
bool hasSummary() const;
|
||||||
|
StringRef getSummary() const;
|
||||||
|
|
||||||
|
// Returns the name of the C++ class to generate.
|
||||||
|
StringRef getCppClassName() const;
|
||||||
|
|
||||||
|
// Returns the name of the C++ base class to use when generating this def.
|
||||||
|
StringRef getCppBaseClassName() const;
|
||||||
|
|
||||||
|
// Returns the name of the storage class for this def.
|
||||||
|
StringRef getStorageClassName() const;
|
||||||
|
|
||||||
|
// Returns the C++ namespace for this def's storage class.
|
||||||
|
StringRef getStorageNamespace() const;
|
||||||
|
|
||||||
|
// Returns true if we should generate the storage class.
|
||||||
|
bool genStorageClass() const;
|
||||||
|
|
||||||
|
// Indicates whether or not to generate the storage class constructor.
|
||||||
|
bool hasStorageCustomConstructor() const;
|
||||||
|
|
||||||
|
// Fill a list with this def's parameters. See AttrOrTypeDef in OpBase.td for
|
||||||
|
// documentation of parameter usage.
|
||||||
|
void getParameters(SmallVectorImpl<AttrOrTypeParameter> &) const;
|
||||||
|
|
||||||
|
// Return the number of parameters
|
||||||
|
unsigned getNumParameters() const;
|
||||||
|
|
||||||
|
// Return the keyword/mnemonic to use in the printer/parser methods if we are
|
||||||
|
// supposed to auto-generate them.
|
||||||
|
Optional<StringRef> getMnemonic() const;
|
||||||
|
|
||||||
|
// Returns the code to use as the types printer method. If not specified,
|
||||||
|
// return a non-value. Otherwise, return the contents of that code block.
|
||||||
|
Optional<StringRef> getPrinterCode() const;
|
||||||
|
|
||||||
|
// Returns the code to use as the parser method. If not specified, returns
|
||||||
|
// None. Otherwise, returns the contents of that code block.
|
||||||
|
Optional<StringRef> getParserCode() const;
|
||||||
|
|
||||||
|
// Returns true if the accessors based on the parameters should be generated.
|
||||||
|
bool genAccessors() const;
|
||||||
|
|
||||||
|
// Return true if we need to generate the verify declaration and getChecked
|
||||||
|
// method.
|
||||||
|
bool genVerifyDecl() const;
|
||||||
|
|
||||||
|
// Returns the def's extra class declaration code.
|
||||||
|
Optional<StringRef> getExtraDecls() const;
|
||||||
|
|
||||||
|
// Get the code location (for error printing).
|
||||||
|
ArrayRef<llvm::SMLoc> getLoc() const;
|
||||||
|
|
||||||
|
// Returns true if the default get/getChecked methods should be skipped during
|
||||||
|
// generation.
|
||||||
|
bool skipDefaultBuilders() const;
|
||||||
|
|
||||||
|
// Returns the builders of this def.
|
||||||
|
ArrayRef<AttrOrTypeBuilder> getBuilders() const { return builders; }
|
||||||
|
|
||||||
|
// Returns the traits of this def.
|
||||||
|
ArrayRef<Trait> getTraits() const { return traits; }
|
||||||
|
|
||||||
|
// Returns whether two AttrOrTypeDefs are equal by checking the equality of
|
||||||
|
// the underlying record.
|
||||||
|
bool operator==(const AttrOrTypeDef &other) const;
|
||||||
|
|
||||||
|
// Compares two AttrOrTypeDefs by comparing the names of the dialects.
|
||||||
|
bool operator<(const AttrOrTypeDef &other) const;
|
||||||
|
|
||||||
|
// Returns whether the AttrOrTypeDef is defined.
|
||||||
|
operator bool() const { return def != nullptr; }
|
||||||
|
|
||||||
|
// Return the underlying def.
|
||||||
|
const llvm::Record *getDef() const { return def; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const llvm::Record *def;
|
||||||
|
|
||||||
|
// The builders of this definition.
|
||||||
|
SmallVector<AttrOrTypeBuilder> builders;
|
||||||
|
|
||||||
|
// The traits of this definition.
|
||||||
|
SmallVector<Trait> traits;
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AttrDef
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// This class represents a wrapper around a tablegen AttrDef record.
|
||||||
|
class AttrDef : public AttrOrTypeDef {
|
||||||
|
public:
|
||||||
|
using AttrOrTypeDef::AttrOrTypeDef;
|
||||||
|
|
||||||
|
// Returns the attributes value type builder code block, or None if it doesn't
|
||||||
|
// have one.
|
||||||
|
Optional<StringRef> getTypeBuilder() const;
|
||||||
|
|
||||||
|
static bool classof(const AttrOrTypeDef *def);
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TypeDef
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// This class represents a wrapper around a tablegen TypeDef record.
|
||||||
|
class TypeDef : public AttrOrTypeDef {
|
||||||
|
public:
|
||||||
|
using AttrOrTypeDef::AttrOrTypeDef;
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AttrOrTypeParameter
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to
|
||||||
|
// AttrOrTypeDefs to parameterize them.
|
||||||
|
class AttrOrTypeParameter {
|
||||||
|
public:
|
||||||
|
explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index)
|
||||||
|
: def(def), index(index) {}
|
||||||
|
|
||||||
|
// Get the parameter name.
|
||||||
|
StringRef getName() const;
|
||||||
|
|
||||||
|
// If specified, get the custom allocator code for this parameter.
|
||||||
|
Optional<StringRef> getAllocator() const;
|
||||||
|
|
||||||
|
// If specified, get the custom comparator code for this parameter.
|
||||||
|
Optional<StringRef> getComparator() const;
|
||||||
|
|
||||||
|
// Get the C++ type of this parameter.
|
||||||
|
StringRef getCppType() const;
|
||||||
|
|
||||||
|
// Get a description of this parameter for documentation purposes.
|
||||||
|
Optional<StringRef> getSummary() const;
|
||||||
|
|
||||||
|
// Get the assembly syntax documentation.
|
||||||
|
StringRef getSyntax() const;
|
||||||
|
|
||||||
|
// Return the underlying def of this parameter.
|
||||||
|
const llvm::Init *getDef() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// The underlying tablegen parameter list this parameter is a part of.
|
||||||
|
const llvm::DagInit *def;
|
||||||
|
/// The index of the parameter within the parameter list (`def`).
|
||||||
|
unsigned index;
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AttributeSelfTypeParameter
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// A wrapper class for the AttributeSelfTypeParameter tblgen class. This
|
||||||
|
// represents a parameter of mlir::Type that is the value type of an AttrDef.
|
||||||
|
class AttributeSelfTypeParameter : public AttrOrTypeParameter {
|
||||||
|
public:
|
||||||
|
static bool classof(const AttrOrTypeParameter *param);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_ATTRORTYPEDEF_H
|
|
@ -0,0 +1,296 @@
|
||||||
|
//===- Attribute.cpp - Attribute wrapper class ----------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Attribute wrapper to simplify using TableGen Record defining a MLIR
|
||||||
|
// Attribute.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Format.h"
|
||||||
|
#include "Operator.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
using llvm::DefInit;
|
||||||
|
using llvm::Init;
|
||||||
|
using llvm::Record;
|
||||||
|
using llvm::StringInit;
|
||||||
|
|
||||||
|
// Returns the initializer's value as string if the given TableGen initializer
|
||||||
|
// is a code or string initializer. Returns the empty StringRef otherwise.
|
||||||
|
static StringRef getValueAsString(const Init *init) {
|
||||||
|
if (const auto *str = dyn_cast<StringInit>(init))
|
||||||
|
return str->getValue().trim();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
AttrConstraint::AttrConstraint(const Record *record)
|
||||||
|
: Constraint(Constraint::CK_Attr, record) {
|
||||||
|
assert(isSubClassOf("AttrConstraint") &&
|
||||||
|
"must be subclass of TableGen 'AttrConstraint' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AttrConstraint::isSubClassOf(StringRef className) const {
|
||||||
|
return def->isSubClassOf(className);
|
||||||
|
}
|
||||||
|
|
||||||
|
Attribute::Attribute(const Record *record) : AttrConstraint(record) {
|
||||||
|
assert(record->isSubClassOf("Attr") &&
|
||||||
|
"must be subclass of TableGen 'Attr' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
|
||||||
|
|
||||||
|
bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); }
|
||||||
|
|
||||||
|
bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); }
|
||||||
|
|
||||||
|
bool Attribute::isSymbolRefAttr() const {
|
||||||
|
StringRef defName = def->getName();
|
||||||
|
if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr")
|
||||||
|
return true;
|
||||||
|
return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
|
||||||
|
|
||||||
|
StringRef Attribute::getStorageType() const {
|
||||||
|
const auto *init = def->getValueInit("storageType");
|
||||||
|
auto type = getValueAsString(init);
|
||||||
|
if (type.empty())
|
||||||
|
return "Attribute";
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Attribute::getReturnType() const {
|
||||||
|
const auto *init = def->getValueInit("returnType");
|
||||||
|
return getValueAsString(init);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the type constraint corresponding to the type of this attribute, or
|
||||||
|
// None if this is not a TypedAttr.
|
||||||
|
llvm::Optional<Type> Attribute::getValueType() const {
|
||||||
|
if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
|
||||||
|
return Type(defInit->getDef());
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Attribute::getConvertFromStorageCall() const {
|
||||||
|
const auto *init = def->getValueInit("convertFromStorage");
|
||||||
|
return getValueAsString(init);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Attribute::isConstBuildable() const {
|
||||||
|
const auto *init = def->getValueInit("constBuilderCall");
|
||||||
|
return !getValueAsString(init).empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Attribute::getConstBuilderTemplate() const {
|
||||||
|
const auto *init = def->getValueInit("constBuilderCall");
|
||||||
|
return getValueAsString(init);
|
||||||
|
}
|
||||||
|
|
||||||
|
Attribute Attribute::getBaseAttr() const {
|
||||||
|
if (const auto *defInit =
|
||||||
|
llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
|
||||||
|
return Attribute(defInit).getBaseAttr();
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Attribute::hasDefaultValue() const {
|
||||||
|
const auto *init = def->getValueInit("defaultValue");
|
||||||
|
return !getValueAsString(init).empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Attribute::getDefaultValue() const {
|
||||||
|
const auto *init = def->getValueInit("defaultValue");
|
||||||
|
return getValueAsString(init);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); }
|
||||||
|
|
||||||
|
StringRef Attribute::getAttrDefName() const {
|
||||||
|
if (def->isAnonymous()) {
|
||||||
|
return getBaseAttr().def->getName();
|
||||||
|
}
|
||||||
|
return def->getName();
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Attribute::getDerivedCodeBody() const {
|
||||||
|
assert(isDerivedAttr() && "only derived attribute has 'body' field");
|
||||||
|
return def->getValueAsString("body");
|
||||||
|
}
|
||||||
|
|
||||||
|
Dialect Attribute::getDialect() const {
|
||||||
|
const llvm::RecordVal *record = def->getValue("dialect");
|
||||||
|
if (record && record->getValue()) {
|
||||||
|
if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
|
||||||
|
return Dialect(init->getDef());
|
||||||
|
}
|
||||||
|
return Dialect(nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
|
||||||
|
assert(def->isSubClassOf("ConstantAttr") &&
|
||||||
|
"must be subclass of TableGen 'ConstantAttr' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
Attribute ConstantAttr::getAttribute() const {
|
||||||
|
return Attribute(def->getValueAsDef("attr"));
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef ConstantAttr::getConstantValue() const {
|
||||||
|
return def->getValueAsString("value");
|
||||||
|
}
|
||||||
|
|
||||||
|
EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
|
||||||
|
assert(isSubClassOf("EnumAttrCaseInfo") &&
|
||||||
|
"must be subclass of TableGen 'EnumAttrInfo' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
|
||||||
|
: EnumAttrCase(init->getDef()) {}
|
||||||
|
|
||||||
|
bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); }
|
||||||
|
|
||||||
|
StringRef EnumAttrCase::getSymbol() const {
|
||||||
|
return def->getValueAsString("symbol");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
|
||||||
|
|
||||||
|
int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
|
||||||
|
|
||||||
|
const llvm::Record &EnumAttrCase::getDef() const { return *def; }
|
||||||
|
|
||||||
|
EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
|
||||||
|
assert(isSubClassOf("EnumAttrInfo") &&
|
||||||
|
"must be subclass of TableGen 'EnumAttr' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
|
||||||
|
|
||||||
|
EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
|
||||||
|
|
||||||
|
bool EnumAttr::classof(const Attribute *attr) {
|
||||||
|
return attr->isSubClassOf("EnumAttrInfo");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
|
||||||
|
|
||||||
|
StringRef EnumAttr::getEnumClassName() const {
|
||||||
|
return def->getValueAsString("className");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef EnumAttr::getCppNamespace() const {
|
||||||
|
return def->getValueAsString("cppNamespace");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef EnumAttr::getUnderlyingType() const {
|
||||||
|
return def->getValueAsString("underlyingType");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
|
||||||
|
return def->getValueAsString("underlyingToSymbolFnName");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef EnumAttr::getStringToSymbolFnName() const {
|
||||||
|
return def->getValueAsString("stringToSymbolFnName");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef EnumAttr::getSymbolToStringFnName() const {
|
||||||
|
return def->getValueAsString("symbolToStringFnName");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef EnumAttr::getSymbolToStringFnRetType() const {
|
||||||
|
return def->getValueAsString("symbolToStringFnRetType");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef EnumAttr::getMaxEnumValFnName() const {
|
||||||
|
return def->getValueAsString("maxEnumValFnName");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
|
||||||
|
const auto *inits = def->getValueAsListInit("enumerants");
|
||||||
|
|
||||||
|
std::vector<EnumAttrCase> cases;
|
||||||
|
cases.reserve(inits->size());
|
||||||
|
|
||||||
|
for (const llvm::Init *init : *inits) {
|
||||||
|
cases.push_back(EnumAttrCase(cast<llvm::DefInit>(init)));
|
||||||
|
}
|
||||||
|
|
||||||
|
return cases;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool EnumAttr::genSpecializedAttr() const {
|
||||||
|
return def->getValueAsBit("genSpecializedAttr");
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::Record *EnumAttr::getBaseAttrClass() const {
|
||||||
|
return def->getValueAsDef("baseAttrClass");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef EnumAttr::getSpecializedAttrClassName() const {
|
||||||
|
return def->getValueAsString("specializedAttrClassName");
|
||||||
|
}
|
||||||
|
|
||||||
|
StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
|
||||||
|
assert(def->isSubClassOf("StructFieldAttr") &&
|
||||||
|
"must be subclass of TableGen 'StructFieldAttr' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
StructFieldAttr::StructFieldAttr(const llvm::Record &record)
|
||||||
|
: StructFieldAttr(&record) {}
|
||||||
|
|
||||||
|
StructFieldAttr::StructFieldAttr(const llvm::DefInit *init)
|
||||||
|
: StructFieldAttr(init->getDef()) {}
|
||||||
|
|
||||||
|
StringRef StructFieldAttr::getName() const {
|
||||||
|
return def->getValueAsString("name");
|
||||||
|
}
|
||||||
|
|
||||||
|
Attribute StructFieldAttr::getType() const {
|
||||||
|
auto init = def->getValueInit("type");
|
||||||
|
return Attribute(cast<llvm::DefInit>(init));
|
||||||
|
}
|
||||||
|
|
||||||
|
StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) {
|
||||||
|
assert(isSubClassOf("StructAttr") &&
|
||||||
|
"must be subclass of TableGen 'StructAttr' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
StructAttr::StructAttr(const llvm::DefInit *init)
|
||||||
|
: StructAttr(init->getDef()) {}
|
||||||
|
|
||||||
|
StringRef StructAttr::getStructClassName() const {
|
||||||
|
return def->getValueAsString("className");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef StructAttr::getCppNamespace() const {
|
||||||
|
Dialect dialect(def->getValueAsDef("dialect"));
|
||||||
|
return dialect.getCppNamespace();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<StructFieldAttr> StructAttr::getAllFields() const {
|
||||||
|
std::vector<StructFieldAttr> attributes;
|
||||||
|
|
||||||
|
const auto *inits = def->getValueAsListInit("fields");
|
||||||
|
attributes.reserve(inits->size());
|
||||||
|
|
||||||
|
for (const llvm::Init *init : *inits) {
|
||||||
|
attributes.emplace_back(cast<llvm::DefInit>(init));
|
||||||
|
}
|
||||||
|
|
||||||
|
return attributes;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";
|
|
@ -0,0 +1,247 @@
|
||||||
|
//===- Attribute.h - Attribute wrapper class --------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Attribute wrapper to simplify using TableGen Record defining a MLIR
|
||||||
|
// Attribute.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_ATTRIBUTE_H_
|
||||||
|
#define MLIR_TABLEGEN_ATTRIBUTE_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "Constraint.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class DefInit;
|
||||||
|
class Record;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
class Dialect;
|
||||||
|
class Type;
|
||||||
|
|
||||||
|
// Wrapper class with helper methods for accessing attribute constraints defined
|
||||||
|
// in TableGen.
|
||||||
|
class AttrConstraint : public Constraint {
|
||||||
|
public:
|
||||||
|
explicit AttrConstraint(const llvm::Record *record);
|
||||||
|
|
||||||
|
static bool classof(const Constraint *c) { return c->getKind() == CK_Attr; }
|
||||||
|
|
||||||
|
// Returns true if this constraint is a subclass of the given `className`
|
||||||
|
// class defined in TableGen.
|
||||||
|
bool isSubClassOf(StringRef className) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wrapper class providing helper methods for accessing MLIR Attribute defined
|
||||||
|
// in TableGen. This class should closely reflect what is defined as class
|
||||||
|
// `Attr` in TableGen.
|
||||||
|
class Attribute : public AttrConstraint {
|
||||||
|
public:
|
||||||
|
explicit Attribute(const llvm::Record *record);
|
||||||
|
explicit Attribute(const llvm::DefInit *init);
|
||||||
|
|
||||||
|
// Returns the storage type if set. Returns the default storage type
|
||||||
|
// ("Attribute") otherwise.
|
||||||
|
StringRef getStorageType() const;
|
||||||
|
|
||||||
|
// Returns the return type for this attribute.
|
||||||
|
StringRef getReturnType() const;
|
||||||
|
|
||||||
|
// Return the type constraint corresponding to the type of this attribute, or
|
||||||
|
// None if this is not a TypedAttr.
|
||||||
|
llvm::Optional<Type> getValueType() const;
|
||||||
|
|
||||||
|
// Returns the template getter method call which reads this attribute's
|
||||||
|
// storage and returns the value as of the desired return type.
|
||||||
|
// The call will contain a `{0}` which will be expanded to this attribute.
|
||||||
|
StringRef getConvertFromStorageCall() const;
|
||||||
|
|
||||||
|
// Returns true if this attribute can be built from a constant value.
|
||||||
|
bool isConstBuildable() const;
|
||||||
|
|
||||||
|
// Returns the template that can be used to produce an instance of the
|
||||||
|
// attribute.
|
||||||
|
// Syntax: `$builder` should be replaced with a builder, `$0` should be
|
||||||
|
// replaced with the constant value.
|
||||||
|
StringRef getConstBuilderTemplate() const;
|
||||||
|
|
||||||
|
// Returns the base-level attribute that this attribute constraint is
|
||||||
|
// built upon.
|
||||||
|
Attribute getBaseAttr() const;
|
||||||
|
|
||||||
|
// Returns whether this attribute has a default value.
|
||||||
|
bool hasDefaultValue() const;
|
||||||
|
// Returns the default value for this attribute.
|
||||||
|
StringRef getDefaultValue() const;
|
||||||
|
|
||||||
|
// Returns whether this attribute is optional.
|
||||||
|
bool isOptional() const;
|
||||||
|
|
||||||
|
// Returns true if this attribute is a derived attribute (i.e., a subclass
|
||||||
|
// of `DerivedAttr`).
|
||||||
|
bool isDerivedAttr() const;
|
||||||
|
|
||||||
|
// Returns true if this attribute is a type attribute (i.e., a subclass
|
||||||
|
// of `TypeAttrBase`).
|
||||||
|
bool isTypeAttr() const;
|
||||||
|
|
||||||
|
// Returns true if this attribute is a symbol reference attribute (i.e., a
|
||||||
|
// subclass of `SymbolRefAttr` or `FlatSymbolRefAttr`).
|
||||||
|
bool isSymbolRefAttr() const;
|
||||||
|
|
||||||
|
// Returns true if this attribute is an enum attribute (i.e., a subclass of
|
||||||
|
// `EnumAttrInfo`)
|
||||||
|
bool isEnumAttr() const;
|
||||||
|
|
||||||
|
// Returns this attribute's TableGen def name. If this is an `OptionalAttr`
|
||||||
|
// or `DefaultValuedAttr` without explicit name, returns the base attribute's
|
||||||
|
// name.
|
||||||
|
StringRef getAttrDefName() const;
|
||||||
|
|
||||||
|
// Returns the code body for derived attribute. Aborts if this is not a
|
||||||
|
// derived attribute.
|
||||||
|
StringRef getDerivedCodeBody() const;
|
||||||
|
|
||||||
|
// Returns the dialect for the attribute if defined.
|
||||||
|
Dialect getDialect() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wrapper class providing helper methods for accessing MLIR constant attribute
|
||||||
|
// defined in TableGen. This class should closely reflect what is defined as
|
||||||
|
// class `ConstantAttr` in TableGen.
|
||||||
|
class ConstantAttr {
|
||||||
|
public:
|
||||||
|
explicit ConstantAttr(const llvm::DefInit *init);
|
||||||
|
|
||||||
|
// Returns the attribute kind.
|
||||||
|
Attribute getAttribute() const;
|
||||||
|
|
||||||
|
// Returns the constant value.
|
||||||
|
StringRef getConstantValue() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The TableGen definition of this constant attribute.
|
||||||
|
const llvm::Record *def;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wrapper class providing helper methods for accessing enum attribute cases
|
||||||
|
// defined in TableGen. This is used for enum attribute case backed by both
|
||||||
|
// StringAttr and IntegerAttr.
|
||||||
|
class EnumAttrCase : public Attribute {
|
||||||
|
public:
|
||||||
|
explicit EnumAttrCase(const llvm::Record *record);
|
||||||
|
explicit EnumAttrCase(const llvm::DefInit *init);
|
||||||
|
|
||||||
|
// Returns true if this EnumAttrCase is backed by a StringAttr.
|
||||||
|
bool isStrCase() const;
|
||||||
|
|
||||||
|
// Returns the symbol of this enum attribute case.
|
||||||
|
StringRef getSymbol() const;
|
||||||
|
|
||||||
|
// Returns the textual representation of this enum attribute case.
|
||||||
|
StringRef getStr() const;
|
||||||
|
|
||||||
|
// Returns the value of this enum attribute case.
|
||||||
|
int64_t getValue() const;
|
||||||
|
|
||||||
|
// Returns the TableGen definition this EnumAttrCase was constructed from.
|
||||||
|
const llvm::Record &getDef() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wrapper class providing helper methods for accessing enum attributes defined
|
||||||
|
// in TableGen.This is used for enum attribute case backed by both StringAttr
|
||||||
|
// and IntegerAttr.
|
||||||
|
class EnumAttr : public Attribute {
|
||||||
|
public:
|
||||||
|
explicit EnumAttr(const llvm::Record *record);
|
||||||
|
explicit EnumAttr(const llvm::Record &record);
|
||||||
|
explicit EnumAttr(const llvm::DefInit *init);
|
||||||
|
|
||||||
|
static bool classof(const Attribute *attr);
|
||||||
|
|
||||||
|
// Returns true if this is a bit enum attribute.
|
||||||
|
bool isBitEnum() const;
|
||||||
|
|
||||||
|
// Returns the enum class name.
|
||||||
|
StringRef getEnumClassName() const;
|
||||||
|
|
||||||
|
// Returns the C++ namespaces this enum class should be placed in.
|
||||||
|
StringRef getCppNamespace() const;
|
||||||
|
|
||||||
|
// Returns the underlying type.
|
||||||
|
StringRef getUnderlyingType() const;
|
||||||
|
|
||||||
|
// Returns the name of the utility function that converts a value of the
|
||||||
|
// underlying type to the corresponding symbol.
|
||||||
|
StringRef getUnderlyingToSymbolFnName() const;
|
||||||
|
|
||||||
|
// Returns the name of the utility function that converts a string to the
|
||||||
|
// corresponding symbol.
|
||||||
|
StringRef getStringToSymbolFnName() const;
|
||||||
|
|
||||||
|
// Returns the name of the utility function that converts a symbol to the
|
||||||
|
// corresponding string.
|
||||||
|
StringRef getSymbolToStringFnName() const;
|
||||||
|
|
||||||
|
// Returns the return type of the utility function that converts a symbol to
|
||||||
|
// the corresponding string.
|
||||||
|
StringRef getSymbolToStringFnRetType() const;
|
||||||
|
|
||||||
|
// Returns the name of the utilit function that returns the max enum value
|
||||||
|
// used within the enum class.
|
||||||
|
StringRef getMaxEnumValFnName() const;
|
||||||
|
|
||||||
|
// Returns all allowed cases for this enum attribute.
|
||||||
|
std::vector<EnumAttrCase> getAllCases() const;
|
||||||
|
|
||||||
|
bool genSpecializedAttr() const;
|
||||||
|
llvm::Record *getBaseAttrClass() const;
|
||||||
|
StringRef getSpecializedAttrClassName() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class StructFieldAttr {
|
||||||
|
public:
|
||||||
|
explicit StructFieldAttr(const llvm::Record *record);
|
||||||
|
explicit StructFieldAttr(const llvm::Record &record);
|
||||||
|
explicit StructFieldAttr(const llvm::DefInit *init);
|
||||||
|
|
||||||
|
StringRef getName() const;
|
||||||
|
Attribute getType() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const llvm::Record *def;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wrapper class providing helper methods for accessing struct attributes
|
||||||
|
// defined in TableGen.
|
||||||
|
class StructAttr : public Attribute {
|
||||||
|
public:
|
||||||
|
explicit StructAttr(const llvm::Record *record);
|
||||||
|
explicit StructAttr(const llvm::Record &record) : StructAttr(&record){};
|
||||||
|
explicit StructAttr(const llvm::DefInit *init);
|
||||||
|
|
||||||
|
// Returns the struct class name.
|
||||||
|
StringRef getStructClassName() const;
|
||||||
|
|
||||||
|
// Returns the C++ namespaces this struct class should be placed in.
|
||||||
|
StringRef getCppNamespace() const;
|
||||||
|
|
||||||
|
std::vector<StructFieldAttr> getAllFields() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Name of infer type op interface.
|
||||||
|
extern const char *inferTypeOpInterface;
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_ATTRIBUTE_H_
|
|
@ -0,0 +1,74 @@
|
||||||
|
//===- Builder.cpp - Builder definitions ----------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Builder.h"
|
||||||
|
#include "llvm/TableGen/Error.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Builder::Parameter
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Return a string containing the C++ type of this parameter.
|
||||||
|
StringRef Builder::Parameter::getCppType() const {
|
||||||
|
if (const auto *stringInit = dyn_cast<llvm::StringInit>(def))
|
||||||
|
return stringInit->getValue();
|
||||||
|
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
|
||||||
|
return record->getValueAsString("type");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return an optional string containing the default value to use for this
|
||||||
|
/// parameter.
|
||||||
|
Optional<StringRef> Builder::Parameter::getDefaultValue() const {
|
||||||
|
if (isa<llvm::StringInit>(def))
|
||||||
|
return llvm::None;
|
||||||
|
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
|
||||||
|
Optional<StringRef> value = record->getValueAsOptionalString("defaultValue");
|
||||||
|
return value && !value->empty() ? value : llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Builder
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Builder::Builder(const llvm::Record *record, ArrayRef<llvm::SMLoc> loc)
|
||||||
|
: def(record) {
|
||||||
|
// Initialize the parameters of the builder.
|
||||||
|
const llvm::DagInit *dag = def->getValueAsDag("dagParams");
|
||||||
|
auto *defInit = dyn_cast<llvm::DefInit>(dag->getOperator());
|
||||||
|
if (!defInit || !defInit->getDef()->getName().equals("ins"))
|
||||||
|
PrintFatalError(def->getLoc(), "expected 'ins' in builders");
|
||||||
|
|
||||||
|
bool seenDefaultValue = false;
|
||||||
|
for (unsigned i = 0, e = dag->getNumArgs(); i < e; ++i) {
|
||||||
|
const llvm::StringInit *paramName = dag->getArgName(i);
|
||||||
|
const llvm::Init *paramValue = dag->getArg(i);
|
||||||
|
Parameter param(paramName ? paramName->getValue() : Optional<StringRef>(),
|
||||||
|
paramValue);
|
||||||
|
|
||||||
|
// Similarly to C++, once an argument with a default value is detected, the
|
||||||
|
// following arguments must have default values as well.
|
||||||
|
if (param.getDefaultValue()) {
|
||||||
|
seenDefaultValue = true;
|
||||||
|
} else if (seenDefaultValue) {
|
||||||
|
PrintFatalError(loc,
|
||||||
|
"expected an argument with default value after other "
|
||||||
|
"arguments with default values");
|
||||||
|
}
|
||||||
|
parameters.emplace_back(param);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return an optional string containing the body of the builder.
|
||||||
|
Optional<StringRef> Builder::getBody() const {
|
||||||
|
Optional<StringRef> body = def->getValueAsOptionalString("body");
|
||||||
|
return body && !body->empty() ? body : llvm::None;
|
||||||
|
}
|
|
@ -0,0 +1,85 @@
|
||||||
|
//===- Builder.h - Builder classes ------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Builder wrapper to simplify using TableGen Record for building
|
||||||
|
// operations/types/etc.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_BUILDER_H_
|
||||||
|
#define MLIR_TABLEGEN_BUILDER_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class Init;
|
||||||
|
class Record;
|
||||||
|
class SMLoc;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
/// Wrapper class with helper methods for accessing Builders defined in
|
||||||
|
/// TableGen.
|
||||||
|
class Builder {
|
||||||
|
public:
|
||||||
|
/// This class represents a single parameter to a builder method.
|
||||||
|
class Parameter {
|
||||||
|
public:
|
||||||
|
/// Return a string containing the C++ type of this parameter.
|
||||||
|
StringRef getCppType() const;
|
||||||
|
|
||||||
|
/// Return an optional string containing the name of this parameter. If
|
||||||
|
/// None, no name was specified for this parameter by the user.
|
||||||
|
Optional<StringRef> getName() const { return name; }
|
||||||
|
|
||||||
|
/// Return an optional string containing the default value to use for this
|
||||||
|
/// parameter.
|
||||||
|
Optional<StringRef> getDefaultValue() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Parameter(Optional<StringRef> name, const llvm::Init *def)
|
||||||
|
: name(name), def(def) {}
|
||||||
|
|
||||||
|
/// The optional name of the parameter.
|
||||||
|
Optional<StringRef> name;
|
||||||
|
|
||||||
|
/// The tablegen definition of the parameter. This is either a StringInit,
|
||||||
|
/// or a CArg DefInit.
|
||||||
|
const llvm::Init *def;
|
||||||
|
|
||||||
|
// Allow access to the constructor.
|
||||||
|
friend Builder;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Construct a builder from the given Record instance.
|
||||||
|
Builder(const llvm::Record *record, ArrayRef<llvm::SMLoc> loc);
|
||||||
|
|
||||||
|
/// Return a list of parameters used in this build method.
|
||||||
|
ArrayRef<Parameter> getParameters() const { return parameters; }
|
||||||
|
|
||||||
|
/// Return an optional string containing the body of the builder.
|
||||||
|
Optional<StringRef> getBody() const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// The TableGen definition of this builder.
|
||||||
|
const llvm::Record *def;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// A collection of parameters to the builder.
|
||||||
|
SmallVector<Parameter> parameters;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_BUILDER_H_
|
|
@ -0,0 +1,68 @@
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file defines common utilities for generating C++ from tablegen
|
||||||
|
// structures.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_CODEGENHELPERS_H
|
||||||
|
#define MLIR_TABLEGEN_CODEGENHELPERS_H
|
||||||
|
|
||||||
|
#include "Dialect.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
// Simple RAII helper for defining ifdef-undef-endif scopes.
|
||||||
|
class IfDefScope {
|
||||||
|
public:
|
||||||
|
IfDefScope(llvm::StringRef name, llvm::raw_ostream &os)
|
||||||
|
: name(name.str()), os(os) {
|
||||||
|
os << "#ifdef " << name << "\n"
|
||||||
|
<< "#undef " << name << "\n\n";
|
||||||
|
}
|
||||||
|
~IfDefScope() { os << "\n#endif // " << name << "\n\n"; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string name;
|
||||||
|
llvm::raw_ostream &os;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A helper RAII class to emit nested namespaces for this op.
|
||||||
|
class NamespaceEmitter {
|
||||||
|
public:
|
||||||
|
NamespaceEmitter(raw_ostream &os, const Dialect &dialect) : os(os) {
|
||||||
|
if (!dialect)
|
||||||
|
return;
|
||||||
|
emitNamespaceStarts(os, dialect.getCppNamespace());
|
||||||
|
}
|
||||||
|
NamespaceEmitter(raw_ostream &os, StringRef cppNamespace) : os(os) {
|
||||||
|
emitNamespaceStarts(os, cppNamespace);
|
||||||
|
}
|
||||||
|
|
||||||
|
~NamespaceEmitter() {
|
||||||
|
for (StringRef ns : llvm::reverse(namespaces))
|
||||||
|
os << "} // namespace " << ns << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void emitNamespaceStarts(raw_ostream &os, StringRef cppNamespace) {
|
||||||
|
llvm::SplitString(cppNamespace, namespaces, "::");
|
||||||
|
for (StringRef ns : namespaces)
|
||||||
|
os << "namespace " << ns << " {\n";
|
||||||
|
}
|
||||||
|
raw_ostream &os;
|
||||||
|
SmallVector<StringRef, 2> namespaces;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tblgen
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_CODEGENHELPERS_H
|
|
@ -0,0 +1,70 @@
|
||||||
|
//===- Constraint.cpp - Constraint class ----------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Constraint wrapper to simplify using TableGen Record for constraints.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Constraint.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
Constraint::Constraint(const llvm::Record *record)
|
||||||
|
: def(record), kind(CK_Uncategorized) {
|
||||||
|
// Look through OpVariable's to their constraint.
|
||||||
|
if (def->isSubClassOf("OpVariable"))
|
||||||
|
def = def->getValueAsDef("constraint");
|
||||||
|
if (def->isSubClassOf("TypeConstraint")) {
|
||||||
|
kind = CK_Type;
|
||||||
|
} else if (def->isSubClassOf("AttrConstraint")) {
|
||||||
|
kind = CK_Attr;
|
||||||
|
} else if (def->isSubClassOf("RegionConstraint")) {
|
||||||
|
kind = CK_Region;
|
||||||
|
} else if (def->isSubClassOf("SuccessorConstraint")) {
|
||||||
|
kind = CK_Successor;
|
||||||
|
} else {
|
||||||
|
assert(def->isSubClassOf("Constraint"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Constraint::Constraint(Kind kind, const llvm::Record *record)
|
||||||
|
: def(record), kind(kind) {
|
||||||
|
// Look through OpVariable's to their constraint.
|
||||||
|
if (def->isSubClassOf("OpVariable"))
|
||||||
|
def = def->getValueAsDef("constraint");
|
||||||
|
}
|
||||||
|
|
||||||
|
Pred Constraint::getPredicate() const {
|
||||||
|
auto *val = def->getValue("predicate");
|
||||||
|
|
||||||
|
// If no predicate is specified, then return the null predicate (which
|
||||||
|
// corresponds to true).
|
||||||
|
if (!val)
|
||||||
|
return Pred();
|
||||||
|
|
||||||
|
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
|
||||||
|
return Pred(pred);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Constraint::getConditionTemplate() const {
|
||||||
|
return getPredicate().getCondition();
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Constraint::getSummary() const {
|
||||||
|
if (Optional<StringRef> summary = def->getValueAsOptionalString("summary"))
|
||||||
|
return *summary;
|
||||||
|
return def->getName();
|
||||||
|
}
|
||||||
|
|
||||||
|
AppliedConstraint::AppliedConstraint(Constraint &&constraint,
|
||||||
|
llvm::StringRef self,
|
||||||
|
std::vector<std::string> &&entities)
|
||||||
|
: constraint(constraint), self(std::string(self)),
|
||||||
|
entities(std::move(entities)) {}
|
|
@ -0,0 +1,88 @@
|
||||||
|
//===- Constraint.h - Constraint class --------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Constraint wrapper to simplify using TableGen Record for constraints.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_CONSTRAINT_H_
|
||||||
|
#define MLIR_TABLEGEN_CONSTRAINT_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "Predicate.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class Record;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
// Wrapper class with helper methods for accessing Constraint defined in
|
||||||
|
// TableGen.
|
||||||
|
class Constraint {
|
||||||
|
public:
|
||||||
|
Constraint(const llvm::Record *record);
|
||||||
|
|
||||||
|
bool operator==(const Constraint &that) { return def == that.def; }
|
||||||
|
bool operator!=(const Constraint &that) { return def != that.def; }
|
||||||
|
|
||||||
|
// Returns the predicate for this constraint.
|
||||||
|
Pred getPredicate() const;
|
||||||
|
|
||||||
|
// Returns the condition template that can be used to check if a type or
|
||||||
|
// attribute satisfies this constraint. The template may contain "{0}" that
|
||||||
|
// must be substituted with an expression returning an mlir::Type or
|
||||||
|
// mlir::Attribute.
|
||||||
|
std::string getConditionTemplate() const;
|
||||||
|
|
||||||
|
// Returns the user-readable description of this constraint. If the
|
||||||
|
// description is not provided, returns the TableGen def name.
|
||||||
|
StringRef getSummary() const;
|
||||||
|
|
||||||
|
// Constraint kind
|
||||||
|
enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized };
|
||||||
|
|
||||||
|
Kind getKind() const { return kind; }
|
||||||
|
|
||||||
|
/// Get an opaque pointer to the constraint.
|
||||||
|
const void *getAsOpaquePointer() const { return def; }
|
||||||
|
/// Construct a constraint from the opaque pointer representation.
|
||||||
|
static Constraint getFromOpaquePointer(const void *ptr) {
|
||||||
|
return Constraint(reinterpret_cast<const llvm::Record *>(ptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Constraint(Kind kind, const llvm::Record *record);
|
||||||
|
|
||||||
|
// The TableGen definition of this constraint.
|
||||||
|
const llvm::Record *def;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// What kind of constraint this is.
|
||||||
|
Kind kind;
|
||||||
|
};
|
||||||
|
|
||||||
|
// An constraint and the concrete entities to place the constraint on.
|
||||||
|
struct AppliedConstraint {
|
||||||
|
AppliedConstraint(Constraint &&constraint, StringRef self,
|
||||||
|
std::vector<std::string> &&entities);
|
||||||
|
|
||||||
|
Constraint constraint;
|
||||||
|
// The symbol to replace `$_self` special placeholder in the constraint.
|
||||||
|
std::string self;
|
||||||
|
// The symbols to replace `$N` positional placeholders in the constraint.
|
||||||
|
std::vector<std::string> entities;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_CONSTRAINT_H_
|
|
@ -0,0 +1,94 @@
|
||||||
|
//===- Dialect.cpp - Dialect wrapper class --------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Dialect wrapper to simplify using TableGen Record defining a MLIR dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Dialect.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
Dialect::Dialect(const llvm::Record *def) : def(def) {
|
||||||
|
if (def == nullptr)
|
||||||
|
return;
|
||||||
|
for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
|
||||||
|
dependentDialects.push_back(dialect);
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Dialect::getName() const { return def->getValueAsString("name"); }
|
||||||
|
|
||||||
|
StringRef Dialect::getCppNamespace() const {
|
||||||
|
return def->getValueAsString("cppNamespace");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Dialect::getCppClassName() const {
|
||||||
|
// Simply use the name and remove any '_' tokens.
|
||||||
|
std::string cppName = def->getName().str();
|
||||||
|
llvm::erase_if(cppName, [](char c) { return c == '_'; });
|
||||||
|
return cppName;
|
||||||
|
}
|
||||||
|
|
||||||
|
static StringRef getAsStringOrEmpty(const llvm::Record &record,
|
||||||
|
StringRef fieldName) {
|
||||||
|
if (auto valueInit = record.getValueInit(fieldName)) {
|
||||||
|
if (llvm::isa<llvm::StringInit>(valueInit))
|
||||||
|
return record.getValueAsString(fieldName);
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Dialect::getSummary() const {
|
||||||
|
return getAsStringOrEmpty(*def, "summary");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Dialect::getDescription() const {
|
||||||
|
return getAsStringOrEmpty(*def, "description");
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<StringRef> Dialect::getDependentDialects() const {
|
||||||
|
return dependentDialects;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const {
|
||||||
|
auto value = def->getValueAsString("extraClassDeclaration");
|
||||||
|
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Dialect::hasCanonicalizer() const {
|
||||||
|
return def->getValueAsBit("hasCanonicalizer");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Dialect::hasConstantMaterializer() const {
|
||||||
|
return def->getValueAsBit("hasConstantMaterializer");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Dialect::hasOperationAttrVerify() const {
|
||||||
|
return def->getValueAsBit("hasOperationAttrVerify");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Dialect::hasRegionArgAttrVerify() const {
|
||||||
|
return def->getValueAsBit("hasRegionArgAttrVerify");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Dialect::hasRegionResultAttrVerify() const {
|
||||||
|
return def->getValueAsBit("hasRegionResultAttrVerify");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Dialect::hasOperationInterfaceFallback() const {
|
||||||
|
return def->getValueAsBit("hasOperationInterfaceFallback");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Dialect::operator==(const Dialect &other) const {
|
||||||
|
return def == other.def;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Dialect::operator<(const Dialect &other) const {
|
||||||
|
return getName() < other.getName();
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Dialect wrapper to simplify using TableGen Record defining a MLIR dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_DIALECT_H_
|
||||||
|
#define MLIR_TABLEGEN_DIALECT_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class Record;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
// Wrapper class that contains a MLIR dialect's information defined in TableGen
|
||||||
|
// and provides helper methods for accessing them.
|
||||||
|
class Dialect {
|
||||||
|
public:
|
||||||
|
explicit Dialect(const llvm::Record *def);
|
||||||
|
|
||||||
|
// Returns the name of this dialect.
|
||||||
|
StringRef getName() const;
|
||||||
|
|
||||||
|
// Returns the C++ namespaces that ops of this dialect should be placed into.
|
||||||
|
StringRef getCppNamespace() const;
|
||||||
|
|
||||||
|
// Returns this dialect's C++ class name.
|
||||||
|
std::string getCppClassName() const;
|
||||||
|
|
||||||
|
// Returns the summary description of the dialect. Returns empty string if
|
||||||
|
// none.
|
||||||
|
StringRef getSummary() const;
|
||||||
|
|
||||||
|
// Returns the description of the dialect. Returns empty string if none.
|
||||||
|
StringRef getDescription() const;
|
||||||
|
|
||||||
|
// Returns the list of dialect (class names) that this dialect depends on.
|
||||||
|
// These are dialects that will be loaded on construction of this dialect.
|
||||||
|
ArrayRef<StringRef> getDependentDialects() const;
|
||||||
|
|
||||||
|
// Returns the dialects extra class declaration code.
|
||||||
|
llvm::Optional<StringRef> getExtraClassDeclaration() const;
|
||||||
|
|
||||||
|
/// Returns true if this dialect has a canonicalizer.
|
||||||
|
bool hasCanonicalizer() const;
|
||||||
|
|
||||||
|
// Returns true if this dialect has a constant materializer.
|
||||||
|
bool hasConstantMaterializer() const;
|
||||||
|
|
||||||
|
/// Returns true if this dialect has an operation attribute verifier.
|
||||||
|
bool hasOperationAttrVerify() const;
|
||||||
|
|
||||||
|
/// Returns true if this dialect has a region argument attribute verifier.
|
||||||
|
bool hasRegionArgAttrVerify() const;
|
||||||
|
|
||||||
|
/// Returns true if this dialect has a region result attribute verifier.
|
||||||
|
bool hasRegionResultAttrVerify() const;
|
||||||
|
|
||||||
|
/// Returns true if this dialect has fallback interfaces for its operations.
|
||||||
|
bool hasOperationInterfaceFallback() const;
|
||||||
|
|
||||||
|
// Returns whether two dialects are equal by checking the equality of the
|
||||||
|
// underlying record.
|
||||||
|
bool operator==(const Dialect &other) const;
|
||||||
|
|
||||||
|
bool operator!=(const Dialect &other) const { return !(*this == other); }
|
||||||
|
|
||||||
|
// Compares two dialects by comparing the names of the dialects.
|
||||||
|
bool operator<(const Dialect &other) const;
|
||||||
|
|
||||||
|
// Returns whether the dialect is defined.
|
||||||
|
explicit operator bool() const { return def != nullptr; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
const llvm::Record *def;
|
||||||
|
std::vector<StringRef> dependentDialects;
|
||||||
|
};
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_DIALECT_H_
|
|
@ -0,0 +1,194 @@
|
||||||
|
//===- Format.cpp - Utilities for String Format ---------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file defines utilities for formatting strings. They are specially
|
||||||
|
// tailored to the needs of TableGen'ing op definitions and rewrite rules,
|
||||||
|
// so they are not expected to be used as widely applicable utilities.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Format.h"
|
||||||
|
#include <cctype>
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
// Marker to indicate an error happened when replacing a placeholder.
|
||||||
|
const char *const kMarkerForNoSubst = "<no-subst-found>";
|
||||||
|
|
||||||
|
FmtContext &FmtContext::addSubst(StringRef placeholder, Twine subst) {
|
||||||
|
customSubstMap[placeholder] = subst.str();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
FmtContext &FmtContext::withBuilder(Twine subst) {
|
||||||
|
builtinSubstMap[PHKind::Builder] = subst.str();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
FmtContext &FmtContext::withOp(Twine subst) {
|
||||||
|
builtinSubstMap[PHKind::Op] = subst.str();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
FmtContext &FmtContext::withSelf(Twine subst) {
|
||||||
|
builtinSubstMap[PHKind::Self] = subst.str();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<StringRef>
|
||||||
|
FmtContext::getSubstFor(FmtContext::PHKind placeholder) const {
|
||||||
|
if (placeholder == FmtContext::PHKind::None ||
|
||||||
|
placeholder == FmtContext::PHKind::Custom)
|
||||||
|
return {};
|
||||||
|
auto it = builtinSubstMap.find(placeholder);
|
||||||
|
if (it == builtinSubstMap.end())
|
||||||
|
return {};
|
||||||
|
return StringRef(it->second);
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<StringRef> FmtContext::getSubstFor(StringRef placeholder) const {
|
||||||
|
auto it = customSubstMap.find(placeholder);
|
||||||
|
if (it == customSubstMap.end())
|
||||||
|
return {};
|
||||||
|
return StringRef(it->second);
|
||||||
|
}
|
||||||
|
|
||||||
|
FmtContext::PHKind FmtContext::getPlaceHolderKind(StringRef str) {
|
||||||
|
return StringSwitch<FmtContext::PHKind>(str)
|
||||||
|
.Case("_builder", FmtContext::PHKind::Builder)
|
||||||
|
.Case("_op", FmtContext::PHKind::Op)
|
||||||
|
.Case("_self", FmtContext::PHKind::Self)
|
||||||
|
.Case("", FmtContext::PHKind::None)
|
||||||
|
.Default(FmtContext::PHKind::Custom);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<FmtReplacement, StringRef>
|
||||||
|
FmtObjectBase::splitFmtSegment(StringRef fmt) {
|
||||||
|
size_t begin = fmt.find_first_of('$');
|
||||||
|
if (begin == StringRef::npos) {
|
||||||
|
// No placeholders: the whole format string should be returned as a
|
||||||
|
// literal string.
|
||||||
|
return {FmtReplacement{fmt}, StringRef()};
|
||||||
|
}
|
||||||
|
if (begin != 0) {
|
||||||
|
// The first placeholder is not at the beginning: we can split the format
|
||||||
|
// string into a literal string and the rest.
|
||||||
|
return {FmtReplacement{fmt.substr(0, begin)}, fmt.substr(begin)};
|
||||||
|
}
|
||||||
|
|
||||||
|
// The first placeholder is at the beginning
|
||||||
|
|
||||||
|
if (fmt.size() == 1) {
|
||||||
|
// The whole format string just contains '$': treat as literal.
|
||||||
|
return {FmtReplacement{fmt}, StringRef()};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow escaping dollar with '$$'
|
||||||
|
if (fmt[1] == '$') {
|
||||||
|
return {FmtReplacement{fmt.substr(0, 1)}, fmt.substr(2)};
|
||||||
|
}
|
||||||
|
|
||||||
|
// First try to see if it's a positional placeholder, and then handle special
|
||||||
|
// placeholders.
|
||||||
|
|
||||||
|
size_t end = fmt.find_if_not([](char c) { return std::isdigit(c); }, 1);
|
||||||
|
if (end != 1) {
|
||||||
|
// We have a positional placeholder. Parse the index.
|
||||||
|
size_t index = 0;
|
||||||
|
if (fmt.substr(1, end - 1).consumeInteger(0, index)) {
|
||||||
|
llvm_unreachable("invalid replacement sequence index");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (end == StringRef::npos) {
|
||||||
|
// All the remaining characters are part of the positional placeholder.
|
||||||
|
return {FmtReplacement{fmt, index}, StringRef()};
|
||||||
|
}
|
||||||
|
return {FmtReplacement{fmt.substr(0, end), index}, fmt.substr(end)};
|
||||||
|
}
|
||||||
|
|
||||||
|
end = fmt.find_if_not([](char c) { return std::isalnum(c) || c == '_'; }, 1);
|
||||||
|
auto placeholder = FmtContext::getPlaceHolderKind(fmt.substr(1, end - 1));
|
||||||
|
if (end == StringRef::npos) {
|
||||||
|
// All the remaining characters are part of the special placeholder.
|
||||||
|
return {FmtReplacement{fmt, placeholder}, StringRef()};
|
||||||
|
}
|
||||||
|
return {FmtReplacement{fmt.substr(0, end), placeholder}, fmt.substr(end)};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<FmtReplacement> FmtObjectBase::parseFormatString(StringRef fmt) {
|
||||||
|
std::vector<FmtReplacement> replacements;
|
||||||
|
FmtReplacement repl;
|
||||||
|
while (!fmt.empty()) {
|
||||||
|
std::tie(repl, fmt) = splitFmtSegment(fmt);
|
||||||
|
if (repl.type != FmtReplacement::Type::Empty)
|
||||||
|
replacements.push_back(repl);
|
||||||
|
}
|
||||||
|
return replacements;
|
||||||
|
}
|
||||||
|
|
||||||
|
void FmtObjectBase::format(raw_ostream &s) const {
|
||||||
|
for (auto &repl : replacements) {
|
||||||
|
if (repl.type == FmtReplacement::Type::Empty)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (repl.type == FmtReplacement::Type::Literal) {
|
||||||
|
s << repl.spec;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (repl.type == FmtReplacement::Type::SpecialPH) {
|
||||||
|
if (repl.placeholder == FmtContext::PHKind::None) {
|
||||||
|
s << repl.spec;
|
||||||
|
} else if (!context) {
|
||||||
|
// We need the context to replace special placeholders.
|
||||||
|
s << repl.spec << kMarkerForNoSubst;
|
||||||
|
} else {
|
||||||
|
Optional<StringRef> subst;
|
||||||
|
if (repl.placeholder == FmtContext::PHKind::Custom) {
|
||||||
|
// Skip the leading '$' sign for the custom placeholder
|
||||||
|
subst = context->getSubstFor(repl.spec.substr(1));
|
||||||
|
} else {
|
||||||
|
subst = context->getSubstFor(repl.placeholder);
|
||||||
|
}
|
||||||
|
if (subst)
|
||||||
|
s << *subst;
|
||||||
|
else
|
||||||
|
s << repl.spec << kMarkerForNoSubst;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(repl.type == FmtReplacement::Type::PositionalPH);
|
||||||
|
|
||||||
|
if (repl.index >= adapters.size()) {
|
||||||
|
s << repl.spec << kMarkerForNoSubst;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
adapters[repl.index]->format(s, /*Options=*/"");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FmtStrVecObject::FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
|
||||||
|
ArrayRef<std::string> params)
|
||||||
|
: FmtObjectBase(fmt, ctx, params.size()) {
|
||||||
|
parameters.reserve(params.size());
|
||||||
|
for (std::string p : params)
|
||||||
|
parameters.push_back(llvm::detail::build_format_adapter(std::move(p)));
|
||||||
|
|
||||||
|
adapters.reserve(parameters.size());
|
||||||
|
for (auto &p : parameters)
|
||||||
|
adapters.push_back(&p);
|
||||||
|
}
|
||||||
|
|
||||||
|
FmtStrVecObject::FmtStrVecObject(FmtStrVecObject &&that)
|
||||||
|
: FmtObjectBase(std::move(that)), parameters(std::move(that.parameters)) {
|
||||||
|
adapters.reserve(parameters.size());
|
||||||
|
for (auto &p : parameters)
|
||||||
|
adapters.push_back(&p);
|
||||||
|
}
|
|
@ -0,0 +1,259 @@
|
||||||
|
//===- Format.h - Utilities for String Format -------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file declares utilities for formatting strings. They are specially
|
||||||
|
// tailored to the needs of TableGen'ing op definitions and rewrite rules,
|
||||||
|
// so they are not expected to be used as widely applicable utilities.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_FORMAT_H_
|
||||||
|
#define MLIR_TABLEGEN_FORMAT_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/StringMap.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
/// Format context containing substitutions for special placeholders.
|
||||||
|
///
|
||||||
|
/// This context divides special placeholders into two categories: builtin ones
|
||||||
|
/// and custom ones.
|
||||||
|
///
|
||||||
|
/// Builtin placeholders are baked into `FmtContext` and each one of them has a
|
||||||
|
/// dedicated setter. They can be used in all dialects. Their names follow the
|
||||||
|
/// convention of `$_<name>`. The rationale of the leading underscore is to
|
||||||
|
/// avoid confusion and name collision: op arguments/attributes/results are
|
||||||
|
/// named as $<name>, and we can potentially support referencing those entities
|
||||||
|
/// directly in the format template in the future.
|
||||||
|
//
|
||||||
|
/// Custom ones are registered by dialect-specific TableGen backends and use the
|
||||||
|
/// same unified setter.
|
||||||
|
class FmtContext {
|
||||||
|
public:
|
||||||
|
// Placeholder kinds
|
||||||
|
enum class PHKind : char {
|
||||||
|
None,
|
||||||
|
Custom, // For custom placeholders
|
||||||
|
Builder, // For the $_builder placeholder
|
||||||
|
Op, // For the $_op placeholder
|
||||||
|
Self, // For the $_self placeholder
|
||||||
|
};
|
||||||
|
|
||||||
|
FmtContext() = default;
|
||||||
|
|
||||||
|
// Setter for custom placeholders
|
||||||
|
FmtContext &addSubst(StringRef placeholder, Twine subst);
|
||||||
|
|
||||||
|
// Setters for builtin placeholders
|
||||||
|
FmtContext &withBuilder(Twine subst);
|
||||||
|
FmtContext &withOp(Twine subst);
|
||||||
|
FmtContext &withSelf(Twine subst);
|
||||||
|
|
||||||
|
Optional<StringRef> getSubstFor(PHKind placeholder) const;
|
||||||
|
Optional<StringRef> getSubstFor(StringRef placeholder) const;
|
||||||
|
|
||||||
|
static PHKind getPlaceHolderKind(StringRef str);
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct PHKindInfo : DenseMapInfo<PHKind> {
|
||||||
|
using CharInfo = DenseMapInfo<char>;
|
||||||
|
|
||||||
|
static inline PHKind getEmptyKey() {
|
||||||
|
return static_cast<PHKind>(CharInfo::getEmptyKey());
|
||||||
|
}
|
||||||
|
static inline PHKind getTombstoneKey() {
|
||||||
|
return static_cast<PHKind>(CharInfo::getTombstoneKey());
|
||||||
|
}
|
||||||
|
static unsigned getHashValue(const PHKind &val) {
|
||||||
|
return CharInfo::getHashValue(static_cast<char>(val));
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isEqual(const PHKind &lhs, const PHKind &rhs) {
|
||||||
|
return lhs == rhs;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
llvm::SmallDenseMap<PHKind, std::string, 4, PHKindInfo> builtinSubstMap;
|
||||||
|
llvm::StringMap<std::string> customSubstMap;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Struct representing a replacement segment for the formatted string. It can
|
||||||
|
/// be a segment of the formatting template (for `Literal`) or a replacement
|
||||||
|
/// parameter (for `PositionalPH` and `SpecialPH`).
|
||||||
|
struct FmtReplacement {
|
||||||
|
enum class Type { Empty, Literal, PositionalPH, SpecialPH };
|
||||||
|
|
||||||
|
FmtReplacement() = default;
|
||||||
|
explicit FmtReplacement(StringRef literal)
|
||||||
|
: type(Type::Literal), spec(literal) {}
|
||||||
|
FmtReplacement(StringRef spec, size_t index)
|
||||||
|
: type(Type::PositionalPH), spec(spec), index(index) {}
|
||||||
|
FmtReplacement(StringRef spec, FmtContext::PHKind placeholder)
|
||||||
|
: type(Type::SpecialPH), spec(spec), placeholder(placeholder) {}
|
||||||
|
|
||||||
|
Type type = Type::Empty;
|
||||||
|
StringRef spec;
|
||||||
|
size_t index = 0;
|
||||||
|
FmtContext::PHKind placeholder = FmtContext::PHKind::None;
|
||||||
|
};
|
||||||
|
|
||||||
|
class FmtObjectBase {
|
||||||
|
private:
|
||||||
|
static std::pair<FmtReplacement, StringRef> splitFmtSegment(StringRef fmt);
|
||||||
|
static std::vector<FmtReplacement> parseFormatString(StringRef fmt);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// The parameters are stored in a std::tuple, which does not provide runtime
|
||||||
|
// indexing capabilities. In order to enable runtime indexing, we use this
|
||||||
|
// structure to put the parameters into a std::vector. Since the parameters
|
||||||
|
// are not all the same type, we use some type-erasure by wrapping the
|
||||||
|
// parameters in a template class that derives from a non-template superclass.
|
||||||
|
// Essentially, we are converting a std::tuple<Derived<Ts...>> to a
|
||||||
|
// std::vector<Base*>.
|
||||||
|
struct CreateAdapters {
|
||||||
|
template <typename... Ts>
|
||||||
|
std::vector<llvm::detail::format_adapter *> operator()(Ts &... items) {
|
||||||
|
return std::vector<llvm::detail::format_adapter *>{&items...};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
StringRef fmt;
|
||||||
|
const FmtContext *context;
|
||||||
|
std::vector<llvm::detail::format_adapter *> adapters;
|
||||||
|
std::vector<FmtReplacement> replacements;
|
||||||
|
|
||||||
|
public:
|
||||||
|
FmtObjectBase(StringRef fmt, const FmtContext *ctx, size_t numParams)
|
||||||
|
: fmt(fmt), context(ctx), replacements(parseFormatString(fmt)) {}
|
||||||
|
|
||||||
|
FmtObjectBase(const FmtObjectBase &that) = delete;
|
||||||
|
|
||||||
|
FmtObjectBase(FmtObjectBase &&that)
|
||||||
|
: fmt(std::move(that.fmt)), context(that.context),
|
||||||
|
adapters(), // adapters are initialized by FmtObject
|
||||||
|
replacements(std::move(that.replacements)) {}
|
||||||
|
|
||||||
|
void format(llvm::raw_ostream &s) const;
|
||||||
|
|
||||||
|
std::string str() const {
|
||||||
|
std::string result;
|
||||||
|
llvm::raw_string_ostream s(result);
|
||||||
|
format(s);
|
||||||
|
return s.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <unsigned N> SmallString<N> sstr() const {
|
||||||
|
SmallString<N> result;
|
||||||
|
llvm::raw_svector_ostream s(result);
|
||||||
|
format(s);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <unsigned N> operator SmallString<N>() const { return sstr<N>(); }
|
||||||
|
|
||||||
|
operator std::string() const { return str(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Tuple> class FmtObject : public FmtObjectBase {
|
||||||
|
// Storage for the parameter adapters. Since the base class erases the type
|
||||||
|
// of the parameters, we have to own the storage for the parameters here, and
|
||||||
|
// have the base class store type-erased pointers into this tuple.
|
||||||
|
Tuple parameters;
|
||||||
|
|
||||||
|
public:
|
||||||
|
FmtObject(StringRef fmt, const FmtContext *ctx, Tuple &¶ms)
|
||||||
|
: FmtObjectBase(fmt, ctx, std::tuple_size<Tuple>::value),
|
||||||
|
parameters(std::move(params)) {
|
||||||
|
adapters.reserve(std::tuple_size<Tuple>::value);
|
||||||
|
adapters = llvm::apply_tuple(CreateAdapters(), parameters);
|
||||||
|
}
|
||||||
|
|
||||||
|
FmtObject(FmtObject const &that) = delete;
|
||||||
|
|
||||||
|
FmtObject(FmtObject &&that)
|
||||||
|
: FmtObjectBase(std::move(that)), parameters(std::move(that.parameters)) {
|
||||||
|
adapters.reserve(that.adapters.size());
|
||||||
|
adapters = llvm::apply_tuple(CreateAdapters(), parameters);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class FmtStrVecObject : public FmtObjectBase {
|
||||||
|
public:
|
||||||
|
using StrFormatAdapter =
|
||||||
|
decltype(llvm::detail::build_format_adapter(std::declval<std::string>()));
|
||||||
|
|
||||||
|
FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
|
||||||
|
ArrayRef<std::string> params);
|
||||||
|
FmtStrVecObject(FmtStrVecObject const &that) = delete;
|
||||||
|
FmtStrVecObject(FmtStrVecObject &&that);
|
||||||
|
|
||||||
|
private:
|
||||||
|
SmallVector<StrFormatAdapter, 16> parameters;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Formats text by substituting placeholders in format string with replacement
|
||||||
|
/// parameters.
|
||||||
|
///
|
||||||
|
/// There are two categories of placeholders accepted, both led by a '$' sign:
|
||||||
|
///
|
||||||
|
/// 1. Positional placeholder: $[0-9]+
|
||||||
|
/// 2. Special placeholder: $[a-zA-Z_][a-zA-Z0-9_]*
|
||||||
|
///
|
||||||
|
/// Replacement parameters for positional placeholders are supplied as the
|
||||||
|
/// `vals` parameter pack with 1:1 mapping. That is, $0 will be replaced by the
|
||||||
|
/// first parameter in `vals`, $1 by the second one, and so on. Note that you
|
||||||
|
/// can use the positional placeholders in any order and repeat any times, for
|
||||||
|
/// example, "$2 $1 $1 $0" is accepted.
|
||||||
|
///
|
||||||
|
/// Replacement parameters for special placeholders are supplied using the `ctx`
|
||||||
|
/// format context.
|
||||||
|
///
|
||||||
|
/// The `fmt` is recorded as a `StringRef` inside the returned `FmtObject`.
|
||||||
|
/// The caller needs to make sure the underlying data is available when the
|
||||||
|
/// `FmtObject` is used.
|
||||||
|
///
|
||||||
|
/// `ctx` accepts a nullptr if there is no special placeholder is used.
|
||||||
|
///
|
||||||
|
/// If no substitution is provided for a placeholder or any error happens during
|
||||||
|
/// format string parsing or replacement, the placeholder will be outputted
|
||||||
|
/// as-is with an additional marker '<no-subst-found>', to aid debugging.
|
||||||
|
///
|
||||||
|
/// To print a '$' literally, escape it with '$$'.
|
||||||
|
///
|
||||||
|
/// This utility function is inspired by LLVM formatv(), with modifications
|
||||||
|
/// specially tailored for TableGen C++ generation usage:
|
||||||
|
///
|
||||||
|
/// 1. This utility use '$' instead of '{' and '}' for denoting the placeholder
|
||||||
|
/// because '{' and '}' are frequently used in C++ code.
|
||||||
|
/// 2. This utility does not support format layout because it is rarely needed
|
||||||
|
/// in C++ code generation.
|
||||||
|
template <typename... Ts>
|
||||||
|
inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals)
|
||||||
|
-> FmtObject<decltype(std::make_tuple(
|
||||||
|
llvm::detail::build_format_adapter(std::forward<Ts>(vals))...))> {
|
||||||
|
using ParamTuple = decltype(std::make_tuple(
|
||||||
|
llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
|
||||||
|
return FmtObject<ParamTuple>(
|
||||||
|
fmt, ctx,
|
||||||
|
std::make_tuple(
|
||||||
|
llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline FmtStrVecObject tgfmt(StringRef fmt, const FmtContext *ctx,
|
||||||
|
ArrayRef<std::string> params) {
|
||||||
|
return FmtStrVecObject(fmt, ctx, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_FORMAT_H_
|
|
@ -0,0 +1,72 @@
|
||||||
|
//===- GenInfo.h - Generator info -------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_GENINFO_H_
|
||||||
|
#define MLIR_TABLEGEN_GENINFO_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class RecordKeeper;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
/// Generator function to invoke.
|
||||||
|
using GenFunction = std::function<bool(const llvm::RecordKeeper &recordKeeper,
|
||||||
|
raw_ostream &os)>;
|
||||||
|
|
||||||
|
/// Structure to group information about a generator (argument to invoke via
|
||||||
|
/// mlir-tblgen, description, and generator function).
|
||||||
|
class GenInfo {
|
||||||
|
public:
|
||||||
|
/// GenInfo constructor should not be invoked directly, instead use
|
||||||
|
/// GenRegistration or registerGen.
|
||||||
|
GenInfo(StringRef arg, StringRef description, GenFunction generator)
|
||||||
|
: arg(arg), description(description), generator(generator) {}
|
||||||
|
|
||||||
|
/// Invokes the generator and returns whether the generator failed.
|
||||||
|
bool invoke(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
|
||||||
|
assert(generator && "Cannot call generator with null generator");
|
||||||
|
return generator(recordKeeper, os);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the command line option that may be passed to 'mlir-tblgen' to
|
||||||
|
/// invoke this generator.
|
||||||
|
StringRef getGenArgument() const { return arg; }
|
||||||
|
|
||||||
|
/// Returns a description for the generator.
|
||||||
|
StringRef getGenDescription() const { return description; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The argument with which to invoke the generator via mlir-tblgen.
|
||||||
|
StringRef arg;
|
||||||
|
|
||||||
|
// Description of the generator.
|
||||||
|
StringRef description;
|
||||||
|
|
||||||
|
// Generator function.
|
||||||
|
GenFunction generator;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// GenRegistration provides a global initializer that registers a generator
|
||||||
|
/// function.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
///
|
||||||
|
/// // At namespace scope.
|
||||||
|
/// static GenRegistration Print("print", "Print records", [](...){...});
|
||||||
|
struct GenRegistration {
|
||||||
|
GenRegistration(StringRef arg, StringRef description, GenFunction function);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_GENINFO_H_
|
|
@ -0,0 +1,31 @@
|
||||||
|
//===- GenNameParser.h - Command line parser for generators -----*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// The GenNameParser class adds all passes linked in to the system that are
|
||||||
|
// creatable to the tool.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_GENNAMEPARSER_H_
|
||||||
|
#define MLIR_TABLEGEN_GENNAMEPARSER_H_
|
||||||
|
|
||||||
|
#include "llvm/Support/CommandLine.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class GenInfo;
|
||||||
|
|
||||||
|
/// Adds command line option for each registered generator.
|
||||||
|
struct GenNameParser : public llvm::cl::parser<const GenInfo *> {
|
||||||
|
GenNameParser(llvm::cl::Option &opt);
|
||||||
|
|
||||||
|
void printOptionInfo(const llvm::cl::Option &O,
|
||||||
|
size_t GlobalWidth) const override;
|
||||||
|
};
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_GENNAMEPARSER_H_
|
|
@ -0,0 +1,144 @@
|
||||||
|
//===- Interfaces.cpp - Interface classes ---------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Interfaces.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/TableGen/Error.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// InterfaceMethod
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) {
|
||||||
|
llvm::DagInit *args = def->getValueAsDag("arguments");
|
||||||
|
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
|
||||||
|
arguments.push_back(
|
||||||
|
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
|
||||||
|
args->getArgNameStr(i)});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef InterfaceMethod::getReturnType() const {
|
||||||
|
return def->getValueAsString("returnType");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the name of this method.
|
||||||
|
StringRef InterfaceMethod::getName() const {
|
||||||
|
return def->getValueAsString("name");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return if this method is static.
|
||||||
|
bool InterfaceMethod::isStatic() const {
|
||||||
|
return def->isSubClassOf("StaticInterfaceMethod");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the body for this method if it has one.
|
||||||
|
llvm::Optional<StringRef> InterfaceMethod::getBody() const {
|
||||||
|
auto value = def->getValueAsString("body");
|
||||||
|
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the default implementation for this method if it has one.
|
||||||
|
llvm::Optional<StringRef> InterfaceMethod::getDefaultImplementation() const {
|
||||||
|
auto value = def->getValueAsString("defaultBody");
|
||||||
|
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the description of this method if it has one.
|
||||||
|
llvm::Optional<StringRef> InterfaceMethod::getDescription() const {
|
||||||
|
auto value = def->getValueAsString("description");
|
||||||
|
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<InterfaceMethod::Argument> InterfaceMethod::getArguments() const {
|
||||||
|
return arguments;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool InterfaceMethod::arg_empty() const { return arguments.empty(); }
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Interface
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Interface::Interface(const llvm::Record *def) : def(def) {
|
||||||
|
assert(def->isSubClassOf("Interface") &&
|
||||||
|
"must be subclass of TableGen 'Interface' class");
|
||||||
|
|
||||||
|
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
|
||||||
|
for (llvm::Init *init : listInit->getValues())
|
||||||
|
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the name of this interface.
|
||||||
|
StringRef Interface::getName() const {
|
||||||
|
return def->getValueAsString("cppClassName");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the C++ namespace of this interface.
|
||||||
|
StringRef Interface::getCppNamespace() const {
|
||||||
|
return def->getValueAsString("cppNamespace");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the methods of this interface.
|
||||||
|
ArrayRef<InterfaceMethod> Interface::getMethods() const { return methods; }
|
||||||
|
|
||||||
|
// Return the description of this method if it has one.
|
||||||
|
llvm::Optional<StringRef> Interface::getDescription() const {
|
||||||
|
auto value = def->getValueAsString("description");
|
||||||
|
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the interfaces extra class declaration code.
|
||||||
|
llvm::Optional<StringRef> Interface::getExtraClassDeclaration() const {
|
||||||
|
auto value = def->getValueAsString("extraClassDeclaration");
|
||||||
|
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the traits extra class declaration code.
|
||||||
|
llvm::Optional<StringRef> Interface::getExtraTraitClassDeclaration() const {
|
||||||
|
auto value = def->getValueAsString("extraTraitClassDeclaration");
|
||||||
|
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the body for this method if it has one.
|
||||||
|
llvm::Optional<StringRef> Interface::getVerify() const {
|
||||||
|
// Only OpInterface supports the verify method.
|
||||||
|
if (!isa<OpInterface>(this))
|
||||||
|
return llvm::None;
|
||||||
|
auto value = def->getValueAsString("verify");
|
||||||
|
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AttrInterface
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
bool AttrInterface::classof(const Interface *interface) {
|
||||||
|
return interface->getDef().isSubClassOf("AttrInterface");
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OpInterface
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
bool OpInterface::classof(const Interface *interface) {
|
||||||
|
return interface->getDef().isSubClassOf("OpInterface");
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TypeInterface
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
bool TypeInterface::classof(const Interface *interface) {
|
||||||
|
return interface->getDef().isSubClassOf("TypeInterface");
|
||||||
|
}
|
|
@ -0,0 +1,129 @@
|
||||||
|
//===- Interfaces.h - Interface wrapper classes -----------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_INTERFACES_H_
|
||||||
|
#define MLIR_TABLEGEN_INTERFACES_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class Init;
|
||||||
|
class Record;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
// Wrapper class with helper methods for accessing InterfaceMethod defined
|
||||||
|
// in TableGen.
|
||||||
|
class InterfaceMethod {
|
||||||
|
public:
|
||||||
|
// This struct represents a single method argument.
|
||||||
|
struct Argument {
|
||||||
|
StringRef type;
|
||||||
|
StringRef name;
|
||||||
|
};
|
||||||
|
|
||||||
|
explicit InterfaceMethod(const llvm::Record *def);
|
||||||
|
|
||||||
|
// Return the return type of this method.
|
||||||
|
StringRef getReturnType() const;
|
||||||
|
|
||||||
|
// Return the name of this method.
|
||||||
|
StringRef getName() const;
|
||||||
|
|
||||||
|
// Return if this method is static.
|
||||||
|
bool isStatic() const;
|
||||||
|
|
||||||
|
// Return the body for this method if it has one.
|
||||||
|
llvm::Optional<StringRef> getBody() const;
|
||||||
|
|
||||||
|
// Return the default implementation for this method if it has one.
|
||||||
|
llvm::Optional<StringRef> getDefaultImplementation() const;
|
||||||
|
|
||||||
|
// Return the description of this method if it has one.
|
||||||
|
llvm::Optional<StringRef> getDescription() const;
|
||||||
|
|
||||||
|
// Arguments.
|
||||||
|
ArrayRef<Argument> getArguments() const;
|
||||||
|
bool arg_empty() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The TableGen definition of this method.
|
||||||
|
const llvm::Record *def;
|
||||||
|
|
||||||
|
// The arguments of this method.
|
||||||
|
SmallVector<Argument, 2> arguments;
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Interface
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Wrapper class with helper methods for accessing Interfaces defined in
|
||||||
|
// TableGen.
|
||||||
|
class Interface {
|
||||||
|
public:
|
||||||
|
explicit Interface(const llvm::Record *def);
|
||||||
|
|
||||||
|
// Return the name of this interface.
|
||||||
|
StringRef getName() const;
|
||||||
|
|
||||||
|
// Return the C++ namespace of this interface.
|
||||||
|
StringRef getCppNamespace() const;
|
||||||
|
|
||||||
|
// Return the methods of this interface.
|
||||||
|
ArrayRef<InterfaceMethod> getMethods() const;
|
||||||
|
|
||||||
|
// Return the description of this method if it has one.
|
||||||
|
llvm::Optional<StringRef> getDescription() const;
|
||||||
|
|
||||||
|
// Return the interfaces extra class declaration code.
|
||||||
|
llvm::Optional<StringRef> getExtraClassDeclaration() const;
|
||||||
|
|
||||||
|
// Return the traits extra class declaration code.
|
||||||
|
llvm::Optional<StringRef> getExtraTraitClassDeclaration() const;
|
||||||
|
|
||||||
|
// Return the verify method body if it has one.
|
||||||
|
llvm::Optional<StringRef> getVerify() const;
|
||||||
|
|
||||||
|
// Returns the Tablegen definition this interface was constructed from.
|
||||||
|
const llvm::Record &getDef() const { return *def; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The TableGen definition of this interface.
|
||||||
|
const llvm::Record *def;
|
||||||
|
|
||||||
|
// The methods of this interface.
|
||||||
|
SmallVector<InterfaceMethod, 8> methods;
|
||||||
|
};
|
||||||
|
|
||||||
|
// An interface that is registered to an Attribute.
|
||||||
|
struct AttrInterface : public Interface {
|
||||||
|
using Interface::Interface;
|
||||||
|
|
||||||
|
static bool classof(const Interface *interface);
|
||||||
|
};
|
||||||
|
// An interface that is registered to an Operation.
|
||||||
|
struct OpInterface : public Interface {
|
||||||
|
using Interface::Interface;
|
||||||
|
|
||||||
|
static bool classof(const Interface *interface);
|
||||||
|
};
|
||||||
|
// An interface that is registered to a Type.
|
||||||
|
struct TypeInterface : public Interface {
|
||||||
|
using Interface::Interface;
|
||||||
|
|
||||||
|
static bool classof(const Interface *interface);
|
||||||
|
};
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_INTERFACES_H_
|
|
@ -0,0 +1,347 @@
|
||||||
|
//===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "OpClass.h"
|
||||||
|
|
||||||
|
#include "Format.h"
|
||||||
|
#include "llvm/ADT/Sequence.h"
|
||||||
|
#include "llvm/ADT/Twine.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "mlir-tblgen-opclass"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Returns space to be emitted after the given C++ `type`. return "" if the
|
||||||
|
// ends with '&' or '*', or is empty, else returns " ".
|
||||||
|
StringRef getSpaceAfterType(StringRef type) {
|
||||||
|
return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OpMethodParameter definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
|
||||||
|
if (properties & PP_Optional)
|
||||||
|
os << "/*optional*/";
|
||||||
|
os << type << getSpaceAfterType(type) << name;
|
||||||
|
if (emitDefault && !defaultValue.empty())
|
||||||
|
os << " = " << defaultValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OpMethodParameters definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Factory methods to construct the correct type of `OpMethodParameters`
|
||||||
|
// object based on the arguments.
|
||||||
|
std::unique_ptr<OpMethodParameters> OpMethodParameters::create() {
|
||||||
|
return std::make_unique<OpMethodResolvedParameters>();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OpMethodParameters>
|
||||||
|
OpMethodParameters::create(StringRef params) {
|
||||||
|
return std::make_unique<OpMethodUnresolvedParameters>(params);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OpMethodParameters>
|
||||||
|
OpMethodParameters::create(llvm::SmallVectorImpl<OpMethodParameter> &¶ms) {
|
||||||
|
return std::make_unique<OpMethodResolvedParameters>(std::move(params));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OpMethodParameters>
|
||||||
|
OpMethodParameters::create(StringRef type, StringRef name,
|
||||||
|
StringRef defaultValue) {
|
||||||
|
return std::make_unique<OpMethodResolvedParameters>(type, name, defaultValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OpMethodUnresolvedParameters definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const {
|
||||||
|
os << parameters;
|
||||||
|
}
|
||||||
|
|
||||||
|
void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const {
|
||||||
|
// We need to remove the default values for parameters in method definition.
|
||||||
|
// TODO: We are using '=' and ',' as delimiters for parameter
|
||||||
|
// initializers. This is incorrect for initializer list with more than one
|
||||||
|
// element. Change to a more robust approach.
|
||||||
|
llvm::SmallVector<StringRef, 4> tokens;
|
||||||
|
StringRef params = parameters;
|
||||||
|
while (!params.empty()) {
|
||||||
|
std::pair<StringRef, StringRef> parts = params.split("=");
|
||||||
|
tokens.push_back(parts.first);
|
||||||
|
params = parts.second.split(',').second;
|
||||||
|
}
|
||||||
|
llvm::interleaveComma(tokens, os, [&](StringRef token) { os << token; });
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OpMethodResolvedParameters definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Returns true if a method with these parameters makes a method with parameters
|
||||||
|
// `other` redundant. This should return true only if all possible calls to the
|
||||||
|
// other method can be replaced by calls to this method.
|
||||||
|
bool OpMethodResolvedParameters::makesRedundant(
|
||||||
|
const OpMethodResolvedParameters &other) const {
|
||||||
|
const size_t otherNumParams = other.getNumParameters();
|
||||||
|
const size_t thisNumParams = getNumParameters();
|
||||||
|
|
||||||
|
// All calls to the other method can be replaced this method only if this
|
||||||
|
// method has the same or more arguments number of arguments as the other, and
|
||||||
|
// the common arguments have the same type.
|
||||||
|
if (thisNumParams < otherNumParams)
|
||||||
|
return false;
|
||||||
|
for (int idx : llvm::seq<int>(0, otherNumParams))
|
||||||
|
if (parameters[idx].getType() != other.parameters[idx].getType())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// If all the common arguments have the same type, we can elide the other
|
||||||
|
// method if this method has the same number of arguments as other or the
|
||||||
|
// first argument after the common ones has a default value (and by C++
|
||||||
|
// requirement, all the later ones will also have a default value).
|
||||||
|
return thisNumParams == otherNumParams ||
|
||||||
|
parameters[otherNumParams].hasDefaultValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const {
|
||||||
|
llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) {
|
||||||
|
param.writeDeclTo(os);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const {
|
||||||
|
llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) {
|
||||||
|
param.writeDefTo(os);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OpMethodSignature definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Returns if a method with this signature makes a method with `other` signature
|
||||||
|
// redundant. Only supports resolved parameters.
|
||||||
|
bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const {
|
||||||
|
if (methodName != other.methodName)
|
||||||
|
return false;
|
||||||
|
auto *resolvedThis = dyn_cast<OpMethodResolvedParameters>(parameters.get());
|
||||||
|
auto *resolvedOther =
|
||||||
|
dyn_cast<OpMethodResolvedParameters>(other.parameters.get());
|
||||||
|
if (resolvedThis && resolvedOther)
|
||||||
|
return resolvedThis->makesRedundant(*resolvedOther);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
|
||||||
|
os << returnType << getSpaceAfterType(returnType) << methodName << "(";
|
||||||
|
parameters->writeDeclTo(os);
|
||||||
|
os << ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
void OpMethodSignature::writeDefTo(raw_ostream &os,
|
||||||
|
StringRef namePrefix) const {
|
||||||
|
os << returnType << getSpaceAfterType(returnType) << namePrefix
|
||||||
|
<< (namePrefix.empty() ? "" : "::") << methodName << "(";
|
||||||
|
parameters->writeDefTo(os);
|
||||||
|
os << ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OpMethodBody definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
|
||||||
|
|
||||||
|
OpMethodBody &OpMethodBody::operator<<(Twine content) {
|
||||||
|
if (isEffective)
|
||||||
|
body.append(content.str());
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
OpMethodBody &OpMethodBody::operator<<(int content) {
|
||||||
|
if (isEffective)
|
||||||
|
body.append(std::to_string(content));
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
|
||||||
|
if (isEffective)
|
||||||
|
body.append(content.str());
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void OpMethodBody::writeTo(raw_ostream &os) const {
|
||||||
|
auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
|
||||||
|
os << bodyRef;
|
||||||
|
if (bodyRef.empty() || bodyRef.back() != '\n')
|
||||||
|
os << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OpMethod definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void OpMethod::writeDeclTo(raw_ostream &os) const {
|
||||||
|
os.indent(2);
|
||||||
|
if (isStatic())
|
||||||
|
os << "static ";
|
||||||
|
if (properties & MP_Constexpr)
|
||||||
|
os << "constexpr ";
|
||||||
|
methodSignature.writeDeclTo(os);
|
||||||
|
if (!isInline())
|
||||||
|
os << ";";
|
||||||
|
else {
|
||||||
|
os << " {\n";
|
||||||
|
methodBody.writeTo(os);
|
||||||
|
os << "}";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
|
||||||
|
// Do not write definition if the method is decl only.
|
||||||
|
if (properties & MP_Declaration)
|
||||||
|
return;
|
||||||
|
// Do not generate separate definition for inline method
|
||||||
|
if (isInline())
|
||||||
|
return;
|
||||||
|
methodSignature.writeDefTo(os, namePrefix);
|
||||||
|
os << " {\n";
|
||||||
|
methodBody.writeTo(os);
|
||||||
|
os << "}";
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OpConstructor definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void OpConstructor::addMemberInitializer(StringRef name, StringRef value) {
|
||||||
|
memberInitializers.append(std::string(llvm::formatv(
|
||||||
|
"{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
|
||||||
|
}
|
||||||
|
|
||||||
|
void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
|
||||||
|
// Do not write definition if the method is decl only.
|
||||||
|
if (properties & MP_Declaration)
|
||||||
|
return;
|
||||||
|
|
||||||
|
methodSignature.writeDefTo(os, namePrefix);
|
||||||
|
os << " " << memberInitializers << " {\n";
|
||||||
|
methodBody.writeTo(os);
|
||||||
|
os << "}";
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Class definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Class::Class(StringRef name) : className(name) {}
|
||||||
|
|
||||||
|
void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
|
||||||
|
std::string varName = formatv("{0} {1}", type, name).str();
|
||||||
|
std::string field = defaultValue.empty()
|
||||||
|
? varName
|
||||||
|
: formatv("{0} = {1}", varName, defaultValue).str();
|
||||||
|
fields.push_back(std::move(field));
|
||||||
|
}
|
||||||
|
void Class::writeDeclTo(raw_ostream &os) const {
|
||||||
|
bool hasPrivateMethod = false;
|
||||||
|
os << "class " << className << " {\n";
|
||||||
|
os << "public:\n";
|
||||||
|
|
||||||
|
forAllMethods([&](const OpMethod &method) {
|
||||||
|
if (!method.isPrivate()) {
|
||||||
|
method.writeDeclTo(os);
|
||||||
|
os << '\n';
|
||||||
|
} else {
|
||||||
|
hasPrivateMethod = true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
os << '\n';
|
||||||
|
os << "private:\n";
|
||||||
|
if (hasPrivateMethod) {
|
||||||
|
forAllMethods([&](const OpMethod &method) {
|
||||||
|
if (method.isPrivate()) {
|
||||||
|
method.writeDeclTo(os);
|
||||||
|
os << '\n';
|
||||||
|
}
|
||||||
|
});
|
||||||
|
os << '\n';
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto &field : fields)
|
||||||
|
os.indent(2) << field << ";\n";
|
||||||
|
os << "};\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void Class::writeDefTo(raw_ostream &os) const {
|
||||||
|
forAllMethods([&](const OpMethod &method) {
|
||||||
|
method.writeDefTo(os, className);
|
||||||
|
os << "\n\n";
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OpClass definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
|
||||||
|
: Class(name), extraClassDeclaration(extraClassDeclaration) {}
|
||||||
|
|
||||||
|
void OpClass::addTrait(Twine trait) {
|
||||||
|
auto traitStr = trait.str();
|
||||||
|
if (traitsSet.insert(traitStr).second)
|
||||||
|
traitsVec.push_back(std::move(traitStr));
|
||||||
|
}
|
||||||
|
|
||||||
|
void OpClass::writeDeclTo(raw_ostream &os) const {
|
||||||
|
os << "class " << className << " : public ::mlir::Op<" << className;
|
||||||
|
for (const auto &trait : traitsVec)
|
||||||
|
os << ", " << trait;
|
||||||
|
os << "> {\npublic:\n";
|
||||||
|
// << " using Op::Op;\n"
|
||||||
|
// << " using Op::print;\n"
|
||||||
|
// << " using Adaptor = " << className << "Adaptor;\n";
|
||||||
|
|
||||||
|
bool hasPrivateMethod = false;
|
||||||
|
forAllMethods([&](const OpMethod &method) {
|
||||||
|
if (!method.isPrivate()) {
|
||||||
|
method.writeDeclTo(os);
|
||||||
|
os << "\n";
|
||||||
|
} else {
|
||||||
|
hasPrivateMethod = true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// TODO: Add line control markers to make errors easier to debug.
|
||||||
|
if (!extraClassDeclaration.empty())
|
||||||
|
os << extraClassDeclaration << "\n";
|
||||||
|
|
||||||
|
if (hasPrivateMethod) {
|
||||||
|
os << "\nprivate:\n";
|
||||||
|
forAllMethods([&](const OpMethod &method) {
|
||||||
|
if (method.isPrivate()) {
|
||||||
|
method.writeDeclTo(os);
|
||||||
|
os << "\n";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
os << "};\n";
|
||||||
|
}
|
|
@ -0,0 +1,442 @@
|
||||||
|
//===- OpClass.h - Helper classes for Op C++ code emission ------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file defines several classes for Op C++ code emission. They are only
|
||||||
|
// expected to be used by MLIR TableGen backends.
|
||||||
|
//
|
||||||
|
// We emit the op declaration and definition into separate files: *Ops.h.inc
|
||||||
|
// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
|
||||||
|
// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
|
||||||
|
//
|
||||||
|
// In order to do this split, we need to track method signature and
|
||||||
|
// implementation logic separately. Signature information is used for both
|
||||||
|
// declaration and definition, while implementation logic is only for
|
||||||
|
// definition. So we have the following classes for C++ code emission.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_OPCLASS_H_
|
||||||
|
#define MLIR_TABLEGEN_OPCLASS_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/ADT/StringSet.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
class FmtObjectBase;
|
||||||
|
|
||||||
|
// Class for holding a single parameter of an op's method for C++ code emission.
|
||||||
|
class OpMethodParameter {
|
||||||
|
public:
|
||||||
|
// Properties (qualifiers) for the parameter.
|
||||||
|
enum Property {
|
||||||
|
PP_None = 0x0,
|
||||||
|
PP_Optional = 0x1,
|
||||||
|
};
|
||||||
|
|
||||||
|
OpMethodParameter(StringRef type, StringRef name, StringRef defaultValue = "",
|
||||||
|
Property properties = PP_None)
|
||||||
|
: type(type), name(name), defaultValue(defaultValue),
|
||||||
|
properties(properties) {}
|
||||||
|
|
||||||
|
OpMethodParameter(StringRef type, StringRef name, Property property)
|
||||||
|
: OpMethodParameter(type, name, "", property) {}
|
||||||
|
|
||||||
|
// Writes the parameter as a part of a method declaration to `os`.
|
||||||
|
void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); }
|
||||||
|
|
||||||
|
// Writes the parameter as a part of a method definition to `os`
|
||||||
|
void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
|
||||||
|
|
||||||
|
const std::string &getType() const { return type; }
|
||||||
|
bool hasDefaultValue() const { return !defaultValue.empty(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void writeTo(raw_ostream &os, bool emitDefault) const;
|
||||||
|
|
||||||
|
std::string type;
|
||||||
|
std::string name;
|
||||||
|
std::string defaultValue;
|
||||||
|
Property properties;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Base class for holding parameters of an op's method for C++ code emission.
|
||||||
|
class OpMethodParameters {
|
||||||
|
public:
|
||||||
|
// Discriminator for LLVM-style RTTI.
|
||||||
|
enum ParamsKind {
|
||||||
|
// Separate type and name for each parameter is not known.
|
||||||
|
PK_Unresolved,
|
||||||
|
// Each parameter is resolved to a type and name.
|
||||||
|
PK_Resolved,
|
||||||
|
};
|
||||||
|
|
||||||
|
OpMethodParameters(ParamsKind kind) : kind(kind) {}
|
||||||
|
virtual ~OpMethodParameters() {}
|
||||||
|
|
||||||
|
// LLVM-style RTTI support.
|
||||||
|
ParamsKind getKind() const { return kind; }
|
||||||
|
|
||||||
|
// Writes the parameters as a part of a method declaration to `os`.
|
||||||
|
virtual void writeDeclTo(raw_ostream &os) const = 0;
|
||||||
|
|
||||||
|
// Writes the parameters as a part of a method definition to `os`
|
||||||
|
virtual void writeDefTo(raw_ostream &os) const = 0;
|
||||||
|
|
||||||
|
// Factory methods to create the correct type of `OpMethodParameters`
|
||||||
|
// object based on the arguments.
|
||||||
|
static std::unique_ptr<OpMethodParameters> create();
|
||||||
|
|
||||||
|
static std::unique_ptr<OpMethodParameters> create(StringRef params);
|
||||||
|
|
||||||
|
static std::unique_ptr<OpMethodParameters>
|
||||||
|
create(llvm::SmallVectorImpl<OpMethodParameter> &¶ms);
|
||||||
|
|
||||||
|
static std::unique_ptr<OpMethodParameters>
|
||||||
|
create(StringRef type, StringRef name, StringRef defaultValue = "");
|
||||||
|
|
||||||
|
private:
|
||||||
|
const ParamsKind kind;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Class for holding unresolved parameters.
|
||||||
|
class OpMethodUnresolvedParameters : public OpMethodParameters {
|
||||||
|
public:
|
||||||
|
OpMethodUnresolvedParameters(StringRef params)
|
||||||
|
: OpMethodParameters(PK_Unresolved), parameters(params) {}
|
||||||
|
|
||||||
|
// write the parameters as a part of a method declaration to the given `os`.
|
||||||
|
void writeDeclTo(raw_ostream &os) const override;
|
||||||
|
|
||||||
|
// write the parameters as a part of a method definition to the given `os`
|
||||||
|
void writeDefTo(raw_ostream &os) const override;
|
||||||
|
|
||||||
|
// LLVM-style RTTI support.
|
||||||
|
static bool classof(const OpMethodParameters *params) {
|
||||||
|
return params->getKind() == PK_Unresolved;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string parameters;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Class for holding resolved parameters.
|
||||||
|
class OpMethodResolvedParameters : public OpMethodParameters {
|
||||||
|
public:
|
||||||
|
OpMethodResolvedParameters() : OpMethodParameters(PK_Resolved) {}
|
||||||
|
|
||||||
|
OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> &¶ms)
|
||||||
|
: OpMethodParameters(PK_Resolved) {
|
||||||
|
for (OpMethodParameter ¶m : params)
|
||||||
|
parameters.emplace_back(std::move(param));
|
||||||
|
}
|
||||||
|
|
||||||
|
OpMethodResolvedParameters(StringRef type, StringRef name,
|
||||||
|
StringRef defaultValue)
|
||||||
|
: OpMethodParameters(PK_Resolved) {
|
||||||
|
parameters.emplace_back(type, name, defaultValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the number of parameters.
|
||||||
|
size_t getNumParameters() const { return parameters.size(); }
|
||||||
|
|
||||||
|
// Returns if this method makes the `other` method redundant. Note that this
|
||||||
|
// is more than just finding conflicting methods. This method determines if
|
||||||
|
// the 2 set of parameters are conflicting and if so, returns true if this
|
||||||
|
// method has a more general set of parameters that can replace all possible
|
||||||
|
// calls to the `other` method.
|
||||||
|
bool makesRedundant(const OpMethodResolvedParameters &other) const;
|
||||||
|
|
||||||
|
// write the parameters as a part of a method declaration to the given `os`.
|
||||||
|
void writeDeclTo(raw_ostream &os) const override;
|
||||||
|
|
||||||
|
// write the parameters as a part of a method definition to the given `os`
|
||||||
|
void writeDefTo(raw_ostream &os) const override;
|
||||||
|
|
||||||
|
// LLVM-style RTTI support.
|
||||||
|
static bool classof(const OpMethodParameters *params) {
|
||||||
|
return params->getKind() == PK_Resolved;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
llvm::SmallVector<OpMethodParameter, 4> parameters;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Class for holding the signature of an op's method for C++ code emission
|
||||||
|
class OpMethodSignature {
|
||||||
|
public:
|
||||||
|
template <typename... Args>
|
||||||
|
OpMethodSignature(StringRef retType, StringRef name, Args &&...args)
|
||||||
|
: returnType(retType), methodName(name),
|
||||||
|
parameters(OpMethodParameters::create(std::forward<Args>(args)...)) {}
|
||||||
|
OpMethodSignature(OpMethodSignature &&) = default;
|
||||||
|
|
||||||
|
// Returns if a method with this signature makes a method with `other`
|
||||||
|
// signature redundant. Only supports resolved parameters.
|
||||||
|
bool makesRedundant(const OpMethodSignature &other) const;
|
||||||
|
|
||||||
|
// Returns the number of parameters (for resolved parameters).
|
||||||
|
size_t getNumParameters() const {
|
||||||
|
return cast<OpMethodResolvedParameters>(parameters.get())
|
||||||
|
->getNumParameters();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the name of the method.
|
||||||
|
StringRef getName() const { return methodName; }
|
||||||
|
|
||||||
|
// Writes the signature as a method declaration to the given `os`.
|
||||||
|
void writeDeclTo(raw_ostream &os) const;
|
||||||
|
|
||||||
|
// Writes the signature as the start of a method definition to the given `os`.
|
||||||
|
// `namePrefix` is the prefix to be prepended to the method name (typically
|
||||||
|
// namespaces for qualifying the method definition).
|
||||||
|
void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string returnType;
|
||||||
|
std::string methodName;
|
||||||
|
std::unique_ptr<OpMethodParameters> parameters;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Class for holding the body of an op's method for C++ code emission
|
||||||
|
class OpMethodBody {
|
||||||
|
public:
|
||||||
|
explicit OpMethodBody(bool declOnly);
|
||||||
|
|
||||||
|
OpMethodBody &operator<<(Twine content);
|
||||||
|
OpMethodBody &operator<<(int content);
|
||||||
|
OpMethodBody &operator<<(const FmtObjectBase &content);
|
||||||
|
|
||||||
|
void writeTo(raw_ostream &os) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Whether this class should record method body.
|
||||||
|
bool isEffective;
|
||||||
|
std::string body;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Class for holding an op's method for C++ code emission
|
||||||
|
class OpMethod {
|
||||||
|
public:
|
||||||
|
// Properties (qualifiers) of class methods. Bitfield is used here to help
|
||||||
|
// querying properties.
|
||||||
|
enum Property {
|
||||||
|
MP_None = 0x0,
|
||||||
|
MP_Static = 0x1,
|
||||||
|
MP_Constructor = 0x2,
|
||||||
|
MP_Private = 0x4,
|
||||||
|
MP_Declaration = 0x8,
|
||||||
|
MP_Inline = 0x10,
|
||||||
|
MP_Constexpr = 0x20 | MP_Inline,
|
||||||
|
MP_StaticDeclaration = MP_Static | MP_Declaration,
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
OpMethod(StringRef retType, StringRef name, Property property, unsigned id,
|
||||||
|
Args &&...args)
|
||||||
|
: properties(property),
|
||||||
|
methodSignature(retType, name, std::forward<Args>(args)...),
|
||||||
|
methodBody(properties & MP_Declaration), id(id) {}
|
||||||
|
|
||||||
|
OpMethod(OpMethod &&) = default;
|
||||||
|
|
||||||
|
virtual ~OpMethod() = default;
|
||||||
|
|
||||||
|
OpMethodBody &body() { return methodBody; }
|
||||||
|
|
||||||
|
// Returns true if this is a static method.
|
||||||
|
bool isStatic() const { return properties & MP_Static; }
|
||||||
|
|
||||||
|
// Returns true if this is a private method.
|
||||||
|
bool isPrivate() const { return properties & MP_Private; }
|
||||||
|
|
||||||
|
// Returns true if this is an inline method.
|
||||||
|
bool isInline() const { return properties & MP_Inline; }
|
||||||
|
|
||||||
|
// Returns the name of this method.
|
||||||
|
StringRef getName() const { return methodSignature.getName(); }
|
||||||
|
|
||||||
|
// Returns the ID for this method
|
||||||
|
unsigned getID() const { return id; }
|
||||||
|
|
||||||
|
// Returns if this method makes the `other` method redundant.
|
||||||
|
bool makesRedundant(const OpMethod &other) const {
|
||||||
|
return methodSignature.makesRedundant(other.methodSignature);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writes the method as a declaration to the given `os`.
|
||||||
|
virtual void writeDeclTo(raw_ostream &os) const;
|
||||||
|
|
||||||
|
// Writes the method as a definition to the given `os`. `namePrefix` is the
|
||||||
|
// prefix to be prepended to the method name (typically namespaces for
|
||||||
|
// qualifying the method definition).
|
||||||
|
virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Property properties;
|
||||||
|
OpMethodSignature methodSignature;
|
||||||
|
OpMethodBody methodBody;
|
||||||
|
const unsigned id;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Class for holding an op's constructor method for C++ code emission.
|
||||||
|
class OpConstructor : public OpMethod {
|
||||||
|
public:
|
||||||
|
template <typename... Args>
|
||||||
|
OpConstructor(StringRef className, Property property, unsigned id,
|
||||||
|
Args &&...args)
|
||||||
|
: OpMethod("", className, property, id, std::forward<Args>(args)...) {}
|
||||||
|
|
||||||
|
// Add member initializer to constructor initializing `name` with `value`.
|
||||||
|
void addMemberInitializer(StringRef name, StringRef value);
|
||||||
|
|
||||||
|
// Writes the method as a definition to the given `os`. `namePrefix` is the
|
||||||
|
// prefix to be prepended to the method name (typically namespaces for
|
||||||
|
// qualifying the method definition).
|
||||||
|
void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Member initializers.
|
||||||
|
std::string memberInitializers;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A class used to emit C++ classes from Tablegen. Contains a list of public
|
||||||
|
// methods and a list of private fields to be emitted.
|
||||||
|
class Class {
|
||||||
|
public:
|
||||||
|
explicit Class(StringRef name);
|
||||||
|
|
||||||
|
// Adds a new method to this class and prune redundant methods. Returns null
|
||||||
|
// if the method was not added (because an existing method would make it
|
||||||
|
// redundant), else returns a pointer to the added method. Note that this call
|
||||||
|
// may also delete existing methods that are made redundant by a method to the
|
||||||
|
// class.
|
||||||
|
template <typename... Args>
|
||||||
|
OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
|
||||||
|
OpMethod::Property properties, Args &&...args) {
|
||||||
|
auto newMethod = std::make_unique<OpMethod>(
|
||||||
|
retType, name, properties, nextMethodID++, std::forward<Args>(args)...);
|
||||||
|
return addMethodAndPrune(methods, std::move(newMethod));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
|
||||||
|
Args &&...args) {
|
||||||
|
return addMethodAndPrune(retType, name, OpMethod::MP_None,
|
||||||
|
std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
OpConstructor *addConstructorAndPrune(Args &&...args) {
|
||||||
|
auto newConstructor = std::make_unique<OpConstructor>(
|
||||||
|
getClassName(), OpMethod::MP_Constructor, nextMethodID++,
|
||||||
|
std::forward<Args>(args)...);
|
||||||
|
return addMethodAndPrune(constructors, std::move(newConstructor));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a new field in this class.
|
||||||
|
void newField(StringRef type, StringRef name, StringRef defaultValue = "");
|
||||||
|
|
||||||
|
// Writes this op's class as a declaration to the given `os`.
|
||||||
|
void writeDeclTo(raw_ostream &os) const;
|
||||||
|
// Writes the method definitions in this op's class to the given `os`.
|
||||||
|
void writeDefTo(raw_ostream &os) const;
|
||||||
|
|
||||||
|
// Returns the C++ class name of the op.
|
||||||
|
StringRef getClassName() const { return className; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Get a list of all the methods to emit, filtering out hidden ones.
|
||||||
|
void forAllMethods(llvm::function_ref<void(const OpMethod &)> func) const {
|
||||||
|
using ConsRef = const std::unique_ptr<OpConstructor> &;
|
||||||
|
using MethodRef = const std::unique_ptr<OpMethod> &;
|
||||||
|
llvm::for_each(constructors, [&](ConsRef ptr) { func(*ptr); });
|
||||||
|
llvm::for_each(methods, [&](MethodRef ptr) { func(*ptr); });
|
||||||
|
}
|
||||||
|
|
||||||
|
// For deterministic code generation, keep methods sorted in the order in
|
||||||
|
// which they were generated.
|
||||||
|
template <typename MethodTy>
|
||||||
|
struct MethodCompare {
|
||||||
|
bool operator()(const std::unique_ptr<MethodTy> &x,
|
||||||
|
const std::unique_ptr<MethodTy> &y) const {
|
||||||
|
return x->getID() < y->getID();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename MethodTy>
|
||||||
|
using MethodSet =
|
||||||
|
std::set<std::unique_ptr<MethodTy>, MethodCompare<MethodTy>>;
|
||||||
|
|
||||||
|
template <typename MethodTy>
|
||||||
|
MethodTy *addMethodAndPrune(MethodSet<MethodTy> &set,
|
||||||
|
std::unique_ptr<MethodTy> &&newMethod) {
|
||||||
|
// Check if the new method will be made redundant by existing methods.
|
||||||
|
for (auto &method : set)
|
||||||
|
if (method->makesRedundant(*newMethod))
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
// We can add this a method to the set. Prune any existing methods that will
|
||||||
|
// be made redundant by adding this new method. Note that the redundant
|
||||||
|
// check between two methods is more than a conflict check. makesRedundant()
|
||||||
|
// below will check if the new method conflicts with an existing method and
|
||||||
|
// if so, returns true if the new method makes the existing method redundant
|
||||||
|
// because all calls to the existing method can be subsumed by the new
|
||||||
|
// method. So makesRedundant() does a combined job of finding conflicts and
|
||||||
|
// deciding which of the 2 conflicting methods survive.
|
||||||
|
//
|
||||||
|
// Note: llvm::erase_if does not work with sets of std::unique_ptr, so doing
|
||||||
|
// it manually here.
|
||||||
|
for (auto it = set.begin(), end = set.end(); it != end;) {
|
||||||
|
if (newMethod->makesRedundant(*(it->get())))
|
||||||
|
it = set.erase(it);
|
||||||
|
else
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
|
||||||
|
MethodTy *ret = newMethod.get();
|
||||||
|
set.insert(std::move(newMethod));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string className;
|
||||||
|
MethodSet<OpConstructor> constructors;
|
||||||
|
MethodSet<OpMethod> methods;
|
||||||
|
unsigned nextMethodID = 0;
|
||||||
|
SmallVector<std::string, 4> fields;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Class for holding an op for C++ code emission
|
||||||
|
class OpClass : public Class {
|
||||||
|
public:
|
||||||
|
explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
|
||||||
|
|
||||||
|
// Adds an op trait.
|
||||||
|
void addTrait(Twine trait);
|
||||||
|
|
||||||
|
// Writes this op's class as a declaration to the given `os`. Redefines
|
||||||
|
// Class::writeDeclTo to also emit traits and extra class declarations.
|
||||||
|
void writeDeclTo(raw_ostream &os) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
StringRef extraClassDeclaration;
|
||||||
|
SmallVector<std::string, 4> traitsVec;
|
||||||
|
StringSet<> traitsSet;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tblgen
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_OPCLASS_H_
|
|
@ -0,0 +1,592 @@
|
||||||
|
//===- Operator.cpp - Operator class --------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Operator.h"
|
||||||
|
#include "Predicate.h"
|
||||||
|
#include "Trait.h"
|
||||||
|
#include "Type.h"
|
||||||
|
#include "llvm/ADT/EquivalenceClasses.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/Sequence.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/TableGen/Error.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "mlir-tblgen-operator"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
using llvm::DagInit;
|
||||||
|
using llvm::DefInit;
|
||||||
|
using llvm::Record;
|
||||||
|
|
||||||
|
Operator::Operator(const llvm::Record &def)
|
||||||
|
: dialect(def.getValueAsDef("opDialect")), def(def) {
|
||||||
|
// The first `_` in the op's TableGen def name is treated as separating the
|
||||||
|
// dialect prefix and the op class name. The dialect prefix will be ignored if
|
||||||
|
// not empty. Otherwise, if def name starts with a `_`, the `_` is considered
|
||||||
|
// as part of the class name.
|
||||||
|
StringRef prefix;
|
||||||
|
std::tie(prefix, cppClassName) = def.getName().split('_');
|
||||||
|
if (prefix.empty()) {
|
||||||
|
// Class name with a leading underscore and without dialect prefix
|
||||||
|
cppClassName = def.getName();
|
||||||
|
} else if (cppClassName.empty()) {
|
||||||
|
// Class name without dialect prefix
|
||||||
|
cppClassName = prefix;
|
||||||
|
}
|
||||||
|
|
||||||
|
cppNamespace = def.getValueAsString("cppNamespace");
|
||||||
|
|
||||||
|
populateOpStructure();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Operator::getOperationName() const {
|
||||||
|
auto prefix = dialect.getName();
|
||||||
|
auto opName = def.getValueAsString("opName");
|
||||||
|
if (prefix.empty())
|
||||||
|
return std::string(opName);
|
||||||
|
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Operator::getAdaptorName() const {
|
||||||
|
return std::string(llvm::formatv("{0}Adaptor", getCppClassName()));
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Operator::getDialectName() const { return dialect.getName(); }
|
||||||
|
|
||||||
|
StringRef Operator::getCppClassName() const { return cppClassName; }
|
||||||
|
|
||||||
|
std::string Operator::getQualCppClassName() const {
|
||||||
|
if (cppNamespace.empty())
|
||||||
|
return std::string(cppClassName);
|
||||||
|
return std::string(llvm::formatv("{0}::{1}", cppNamespace, cppClassName));
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Operator::getCppNamespace() const { return cppNamespace; }
|
||||||
|
|
||||||
|
int Operator::getNumResults() const {
|
||||||
|
DagInit *results = def.getValueAsDag("results");
|
||||||
|
return results->getNumArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Operator::getExtraClassDeclaration() const {
|
||||||
|
constexpr auto attr = "extraClassDeclaration";
|
||||||
|
if (def.isValueUnset(attr))
|
||||||
|
return {};
|
||||||
|
return def.getValueAsString(attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
const llvm::Record &Operator::getDef() const { return def; }
|
||||||
|
|
||||||
|
bool Operator::skipDefaultBuilders() const {
|
||||||
|
return def.getValueAsBit("skipDefaultBuilders");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto Operator::result_begin() -> value_iterator { return results.begin(); }
|
||||||
|
|
||||||
|
auto Operator::result_end() -> value_iterator { return results.end(); }
|
||||||
|
|
||||||
|
auto Operator::getResults() -> value_range {
|
||||||
|
return {result_begin(), result_end()};
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeConstraint Operator::getResultTypeConstraint(int index) const {
|
||||||
|
DagInit *results = def.getValueAsDag("results");
|
||||||
|
return TypeConstraint(cast<DefInit>(results->getArg(index)));
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Operator::getResultName(int index) const {
|
||||||
|
DagInit *results = def.getValueAsDag("results");
|
||||||
|
return results->getArgNameStr(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto Operator::getResultDecorators(int index) const -> var_decorator_range {
|
||||||
|
Record *result =
|
||||||
|
cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef();
|
||||||
|
if (!result->isSubClassOf("OpVariable"))
|
||||||
|
return var_decorator_range(nullptr, nullptr);
|
||||||
|
return *result->getValueAsListInit("decorators");
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned Operator::getNumVariableLengthResults() const {
|
||||||
|
return llvm::count_if(results, [](const NamedTypeConstraint &c) {
|
||||||
|
return c.constraint.isVariableLength();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned Operator::getNumVariableLengthOperands() const {
|
||||||
|
return llvm::count_if(operands, [](const NamedTypeConstraint &c) {
|
||||||
|
return c.constraint.isVariableLength();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Operator::hasSingleVariadicArg() const {
|
||||||
|
return getNumArgs() == 1 && getArg(0).is<NamedTypeConstraint *>() &&
|
||||||
|
getOperand(0).isVariadic();
|
||||||
|
}
|
||||||
|
|
||||||
|
Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); }
|
||||||
|
|
||||||
|
Operator::arg_iterator Operator::arg_end() const { return arguments.end(); }
|
||||||
|
|
||||||
|
Operator::arg_range Operator::getArgs() const {
|
||||||
|
return {arg_begin(), arg_end()};
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Operator::getArgName(int index) const {
|
||||||
|
DagInit *argumentValues = def.getValueAsDag("arguments");
|
||||||
|
return argumentValues->getArgNameStr(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto Operator::getArgDecorators(int index) const -> var_decorator_range {
|
||||||
|
Record *arg =
|
||||||
|
cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef();
|
||||||
|
if (!arg->isSubClassOf("OpVariable"))
|
||||||
|
return var_decorator_range(nullptr, nullptr);
|
||||||
|
return *arg->getValueAsListInit("decorators");
|
||||||
|
}
|
||||||
|
|
||||||
|
const Trait *Operator::getTrait(StringRef trait) const {
|
||||||
|
for (const auto &t : traits) {
|
||||||
|
if (const auto *traitDef = dyn_cast<NativeTrait>(&t)) {
|
||||||
|
if (traitDef->getFullyQualifiedTraitName() == trait)
|
||||||
|
return traitDef;
|
||||||
|
} else if (const auto *traitDef = dyn_cast<InternalTrait>(&t)) {
|
||||||
|
if (traitDef->getFullyQualifiedTraitName() == trait)
|
||||||
|
return traitDef;
|
||||||
|
} else if (const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) {
|
||||||
|
if (traitDef->getFullyQualifiedTraitName() == trait)
|
||||||
|
return traitDef;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto Operator::region_begin() const -> const_region_iterator {
|
||||||
|
return regions.begin();
|
||||||
|
}
|
||||||
|
auto Operator::region_end() const -> const_region_iterator {
|
||||||
|
return regions.end();
|
||||||
|
}
|
||||||
|
auto Operator::getRegions() const
|
||||||
|
-> llvm::iterator_range<const_region_iterator> {
|
||||||
|
return {region_begin(), region_end()};
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned Operator::getNumRegions() const { return regions.size(); }
|
||||||
|
|
||||||
|
const NamedRegion &Operator::getRegion(unsigned index) const {
|
||||||
|
return regions[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned Operator::getNumVariadicRegions() const {
|
||||||
|
return llvm::count_if(regions,
|
||||||
|
[](const NamedRegion &c) { return c.isVariadic(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
auto Operator::successor_begin() const -> const_successor_iterator {
|
||||||
|
return successors.begin();
|
||||||
|
}
|
||||||
|
auto Operator::successor_end() const -> const_successor_iterator {
|
||||||
|
return successors.end();
|
||||||
|
}
|
||||||
|
auto Operator::getSuccessors() const
|
||||||
|
-> llvm::iterator_range<const_successor_iterator> {
|
||||||
|
return {successor_begin(), successor_end()};
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned Operator::getNumSuccessors() const { return successors.size(); }
|
||||||
|
|
||||||
|
const NamedSuccessor &Operator::getSuccessor(unsigned index) const {
|
||||||
|
return successors[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned Operator::getNumVariadicSuccessors() const {
|
||||||
|
return llvm::count_if(successors,
|
||||||
|
[](const NamedSuccessor &c) { return c.isVariadic(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
auto Operator::trait_begin() const -> const_trait_iterator {
|
||||||
|
return traits.begin();
|
||||||
|
}
|
||||||
|
auto Operator::trait_end() const -> const_trait_iterator {
|
||||||
|
return traits.end();
|
||||||
|
}
|
||||||
|
auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> {
|
||||||
|
return {trait_begin(), trait_end()};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto Operator::attribute_begin() const -> attribute_iterator {
|
||||||
|
return attributes.begin();
|
||||||
|
}
|
||||||
|
auto Operator::attribute_end() const -> attribute_iterator {
|
||||||
|
return attributes.end();
|
||||||
|
}
|
||||||
|
auto Operator::getAttributes() const
|
||||||
|
-> llvm::iterator_range<attribute_iterator> {
|
||||||
|
return {attribute_begin(), attribute_end()};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto Operator::operand_begin() -> value_iterator { return operands.begin(); }
|
||||||
|
auto Operator::operand_end() -> value_iterator { return operands.end(); }
|
||||||
|
auto Operator::getOperands() -> value_range {
|
||||||
|
return {operand_begin(), operand_end()};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto Operator::getArg(int index) const -> Argument { return arguments[index]; }
|
||||||
|
|
||||||
|
// Mapping from result index to combined argument and result index. Arguments
|
||||||
|
// are indexed to match getArg index, while the result indexes are mapped to
|
||||||
|
// avoid overlap.
|
||||||
|
static int resultIndex(int i) { return -1 - i; }
|
||||||
|
|
||||||
|
bool Operator::isVariadic() const {
|
||||||
|
return any_of(llvm::concat<const NamedTypeConstraint>(operands, results),
|
||||||
|
[](const NamedTypeConstraint &op) { return op.isVariadic(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
void Operator::populateTypeInferenceInfo(
|
||||||
|
const llvm::StringMap<int> &argumentsAndResultsIndex) {
|
||||||
|
// If the type inference op interface is not registered, then do not attempt
|
||||||
|
// to determine if the result types an be inferred.
|
||||||
|
auto &recordKeeper = def.getRecords();
|
||||||
|
auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface);
|
||||||
|
allResultsHaveKnownTypes = false;
|
||||||
|
if (!inferTrait)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// If there are no results, the skip this else the build method generated
|
||||||
|
// overlaps with another autogenerated builder.
|
||||||
|
if (getNumResults() == 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Skip for ops with variadic operands/results.
|
||||||
|
// TODO: This can be relaxed.
|
||||||
|
if (isVariadic())
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Skip cases currently being custom generated.
|
||||||
|
// TODO: Remove special cases.
|
||||||
|
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType"))
|
||||||
|
return;
|
||||||
|
|
||||||
|
// We create equivalence classes of argument/result types where arguments
|
||||||
|
// and results are mapped into the same index space and indices corresponding
|
||||||
|
// to the same type are in the same equivalence class.
|
||||||
|
llvm::EquivalenceClasses<int> ecs;
|
||||||
|
resultTypeMapping.resize(getNumResults());
|
||||||
|
// Captures the argument whose type matches a given result type. Preference
|
||||||
|
// towards capturing operands first before attributes.
|
||||||
|
auto captureMapping = [&](int i) {
|
||||||
|
bool found = false;
|
||||||
|
ecs.insert(resultIndex(i));
|
||||||
|
auto mi = ecs.findLeader(resultIndex(i));
|
||||||
|
for (auto me = ecs.member_end(); mi != me; ++mi) {
|
||||||
|
if (*mi < 0) {
|
||||||
|
auto tc = getResultTypeConstraint(i);
|
||||||
|
if (tc.getBuilderCall().hasValue()) {
|
||||||
|
resultTypeMapping[i].emplace_back(tc);
|
||||||
|
found = true;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getArg(*mi).is<NamedAttribute *>()) {
|
||||||
|
// TODO: Handle attributes.
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
resultTypeMapping[i].emplace_back(*mi);
|
||||||
|
found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return found;
|
||||||
|
};
|
||||||
|
|
||||||
|
for (const Trait &trait : traits) {
|
||||||
|
const llvm::Record &def = trait.getDef();
|
||||||
|
// If the infer type op interface was manually added, then treat it as
|
||||||
|
// intention that the op needs special handling.
|
||||||
|
// TODO: Reconsider whether to always generate, this is more conservative
|
||||||
|
// and keeps existing behavior so starting that way for now.
|
||||||
|
if (def.isSubClassOf(
|
||||||
|
llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
|
||||||
|
return;
|
||||||
|
if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait))
|
||||||
|
if (&traitDef->getDef() == inferTrait)
|
||||||
|
return;
|
||||||
|
|
||||||
|
if (!def.isSubClassOf("AllTypesMatch"))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto values = def.getValueAsListOfStrings("values");
|
||||||
|
auto root = argumentsAndResultsIndex.lookup(values.front());
|
||||||
|
for (StringRef str : values)
|
||||||
|
ecs.unionSets(argumentsAndResultsIndex.lookup(str), root);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verifies that all output types have a corresponding known input type
|
||||||
|
// and chooses matching operand or attribute (in that order) that
|
||||||
|
// matches it.
|
||||||
|
allResultsHaveKnownTypes =
|
||||||
|
all_of(llvm::seq<int>(0, getNumResults()), captureMapping);
|
||||||
|
|
||||||
|
// If the types could be computed, then add type inference trait.
|
||||||
|
if (allResultsHaveKnownTypes)
|
||||||
|
traits.push_back(Trait::create(inferTrait->getDefInit()));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Operator::populateOpStructure() {
|
||||||
|
auto &recordKeeper = def.getRecords();
|
||||||
|
auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint");
|
||||||
|
auto *attrClass = recordKeeper.getClass("Attr");
|
||||||
|
auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr");
|
||||||
|
auto *opVarClass = recordKeeper.getClass("OpVariable");
|
||||||
|
numNativeAttributes = 0;
|
||||||
|
|
||||||
|
DagInit *argumentValues = def.getValueAsDag("arguments");
|
||||||
|
unsigned numArgs = argumentValues->getNumArgs();
|
||||||
|
|
||||||
|
// Mapping from name of to argument or result index. Arguments are indexed
|
||||||
|
// to match getArg index, while the results are negatively indexed.
|
||||||
|
llvm::StringMap<int> argumentsAndResultsIndex;
|
||||||
|
|
||||||
|
// Handle operands and native attributes.
|
||||||
|
for (unsigned i = 0; i != numArgs; ++i) {
|
||||||
|
auto *arg = argumentValues->getArg(i);
|
||||||
|
auto givenName = argumentValues->getArgNameStr(i);
|
||||||
|
auto *argDefInit = dyn_cast<DefInit>(arg);
|
||||||
|
if (!argDefInit)
|
||||||
|
PrintFatalError(def.getLoc(),
|
||||||
|
Twine("undefined type for argument #") + Twine(i));
|
||||||
|
Record *argDef = argDefInit->getDef();
|
||||||
|
if (argDef->isSubClassOf(opVarClass))
|
||||||
|
argDef = argDef->getValueAsDef("constraint");
|
||||||
|
|
||||||
|
if (argDef->isSubClassOf(typeConstraintClass)) {
|
||||||
|
operands.push_back(
|
||||||
|
NamedTypeConstraint{givenName, TypeConstraint(argDef)});
|
||||||
|
} else if (argDef->isSubClassOf(attrClass)) {
|
||||||
|
if (givenName.empty())
|
||||||
|
PrintFatalError(argDef->getLoc(), "attributes must be named");
|
||||||
|
if (argDef->isSubClassOf(derivedAttrClass))
|
||||||
|
PrintFatalError(argDef->getLoc(),
|
||||||
|
"derived attributes not allowed in argument list");
|
||||||
|
attributes.push_back({givenName, Attribute(argDef)});
|
||||||
|
++numNativeAttributes;
|
||||||
|
} else {
|
||||||
|
PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
|
||||||
|
"from TypeConstraint or Attr are allowed");
|
||||||
|
}
|
||||||
|
if (!givenName.empty())
|
||||||
|
argumentsAndResultsIndex[givenName] = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle derived attributes.
|
||||||
|
for (const auto &val : def.getValues()) {
|
||||||
|
if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
|
||||||
|
if (!record->isSubClassOf(attrClass))
|
||||||
|
continue;
|
||||||
|
if (!record->isSubClassOf(derivedAttrClass))
|
||||||
|
PrintFatalError(def.getLoc(),
|
||||||
|
"unexpected Attr where only DerivedAttr is allowed");
|
||||||
|
|
||||||
|
if (record->getClasses().size() != 1) {
|
||||||
|
PrintFatalError(
|
||||||
|
def.getLoc(),
|
||||||
|
"unsupported attribute modelling, only single class expected");
|
||||||
|
}
|
||||||
|
attributes.push_back(
|
||||||
|
{cast<llvm::StringInit>(val.getNameInit())->getValue(),
|
||||||
|
Attribute(cast<DefInit>(val.getValue()))});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate `arguments`. This must happen after we've finalized `operands` and
|
||||||
|
// `attributes` because we will put their elements' pointers in `arguments`.
|
||||||
|
// SmallVector may perform re-allocation under the hood when adding new
|
||||||
|
// elements.
|
||||||
|
int operandIndex = 0, attrIndex = 0;
|
||||||
|
for (unsigned i = 0; i != numArgs; ++i) {
|
||||||
|
Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
|
||||||
|
if (argDef->isSubClassOf(opVarClass))
|
||||||
|
argDef = argDef->getValueAsDef("constraint");
|
||||||
|
|
||||||
|
if (argDef->isSubClassOf(typeConstraintClass)) {
|
||||||
|
attrOrOperandMapping.push_back(
|
||||||
|
{OperandOrAttribute::Kind::Operand, operandIndex});
|
||||||
|
arguments.emplace_back(&operands[operandIndex++]);
|
||||||
|
} else {
|
||||||
|
assert(argDef->isSubClassOf(attrClass));
|
||||||
|
attrOrOperandMapping.push_back(
|
||||||
|
{OperandOrAttribute::Kind::Attribute, attrIndex});
|
||||||
|
arguments.emplace_back(&attributes[attrIndex++]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto *resultsDag = def.getValueAsDag("results");
|
||||||
|
auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
|
||||||
|
if (!outsOp || outsOp->getDef()->getName() != "outs") {
|
||||||
|
PrintFatalError(def.getLoc(), "'results' must have 'outs' directive");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle results.
|
||||||
|
for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
|
||||||
|
auto name = resultsDag->getArgNameStr(i);
|
||||||
|
auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
|
||||||
|
if (!resultInit) {
|
||||||
|
PrintFatalError(def.getLoc(),
|
||||||
|
Twine("undefined type for result #") + Twine(i));
|
||||||
|
}
|
||||||
|
auto *resultDef = resultInit->getDef();
|
||||||
|
if (resultDef->isSubClassOf(opVarClass))
|
||||||
|
resultDef = resultDef->getValueAsDef("constraint");
|
||||||
|
results.push_back({name, TypeConstraint(resultDef)});
|
||||||
|
if (!name.empty())
|
||||||
|
argumentsAndResultsIndex[name] = resultIndex(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle successors
|
||||||
|
auto *successorsDag = def.getValueAsDag("successors");
|
||||||
|
auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator());
|
||||||
|
if (!successorsOp || successorsOp->getDef()->getName() != "successor") {
|
||||||
|
PrintFatalError(def.getLoc(),
|
||||||
|
"'successors' must have 'successor' directive");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
|
||||||
|
auto name = successorsDag->getArgNameStr(i);
|
||||||
|
auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i));
|
||||||
|
if (!successorInit) {
|
||||||
|
PrintFatalError(def.getLoc(),
|
||||||
|
Twine("undefined kind for successor #") + Twine(i));
|
||||||
|
}
|
||||||
|
Successor successor(successorInit->getDef());
|
||||||
|
|
||||||
|
// Only support variadic successors if it is the last one for now.
|
||||||
|
if (i != e - 1 && successor.isVariadic())
|
||||||
|
PrintFatalError(def.getLoc(), "only the last successor can be variadic");
|
||||||
|
successors.push_back({name, successor});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create list of traits, skipping over duplicates: appending to lists in
|
||||||
|
// tablegen is easy, making them unique less so, so dedupe here.
|
||||||
|
if (auto *traitList = def.getValueAsListInit("traits")) {
|
||||||
|
// This is uniquing based on pointers of the trait.
|
||||||
|
SmallPtrSet<const llvm::Init *, 32> traitSet;
|
||||||
|
traits.reserve(traitSet.size());
|
||||||
|
for (auto *traitInit : *traitList) {
|
||||||
|
// Keep traits in the same order while skipping over duplicates.
|
||||||
|
if (traitSet.insert(traitInit).second)
|
||||||
|
traits.push_back(Trait::create(traitInit));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
populateTypeInferenceInfo(argumentsAndResultsIndex);
|
||||||
|
|
||||||
|
// Handle regions
|
||||||
|
auto *regionsDag = def.getValueAsDag("regions");
|
||||||
|
auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
|
||||||
|
if (!regionsOp || regionsOp->getDef()->getName() != "region") {
|
||||||
|
PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
|
||||||
|
auto name = regionsDag->getArgNameStr(i);
|
||||||
|
auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
|
||||||
|
if (!regionInit) {
|
||||||
|
PrintFatalError(def.getLoc(),
|
||||||
|
Twine("undefined kind for region #") + Twine(i));
|
||||||
|
}
|
||||||
|
Region region(regionInit->getDef());
|
||||||
|
if (region.isVariadic()) {
|
||||||
|
// Only support variadic regions if it is the last one for now.
|
||||||
|
if (i != e - 1)
|
||||||
|
PrintFatalError(def.getLoc(), "only the last region can be variadic");
|
||||||
|
if (name.empty())
|
||||||
|
PrintFatalError(def.getLoc(), "variadic regions must be named");
|
||||||
|
}
|
||||||
|
|
||||||
|
regions.push_back({name, region});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate the builders.
|
||||||
|
auto *builderList =
|
||||||
|
dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders"));
|
||||||
|
if (builderList && !builderList->empty()) {
|
||||||
|
for (llvm::Init *init : builderList->getValues())
|
||||||
|
builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc());
|
||||||
|
} else if (skipDefaultBuilders()) {
|
||||||
|
PrintFatalError(
|
||||||
|
def.getLoc(),
|
||||||
|
"default builders are skipped and no custom builders provided");
|
||||||
|
}
|
||||||
|
|
||||||
|
LLVM_DEBUG(print(llvm::dbgs()));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto Operator::getSameTypeAsResult(int index) const -> ArrayRef<ArgOrType> {
|
||||||
|
assert(allResultTypesKnown());
|
||||||
|
return resultTypeMapping[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<llvm::SMLoc> Operator::getLoc() const { return def.getLoc(); }
|
||||||
|
|
||||||
|
bool Operator::hasDescription() const {
|
||||||
|
return def.getValue("description") != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Operator::getDescription() const {
|
||||||
|
return def.getValueAsString("description");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Operator::hasSummary() const { return def.getValue("summary") != nullptr; }
|
||||||
|
|
||||||
|
StringRef Operator::getSummary() const {
|
||||||
|
return def.getValueAsString("summary");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Operator::hasAssemblyFormat() const {
|
||||||
|
auto *valueInit = def.getValueInit("assemblyFormat");
|
||||||
|
return isa<llvm::StringInit>(valueInit);
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Operator::getAssemblyFormat() const {
|
||||||
|
return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
|
||||||
|
.Case<llvm::StringInit>(
|
||||||
|
[&](auto *init) { return init->getValue(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
void Operator::print(llvm::raw_ostream &os) const {
|
||||||
|
os << "op '" << getOperationName() << "'\n";
|
||||||
|
for (Argument arg : arguments) {
|
||||||
|
if (auto *attr = arg.dyn_cast<NamedAttribute *>())
|
||||||
|
os << "[attribute] " << attr->name << '\n';
|
||||||
|
else
|
||||||
|
os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
|
||||||
|
-> VariableDecorator {
|
||||||
|
return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto Operator::getArgToOperandOrAttribute(int index) const
|
||||||
|
-> OperandOrAttribute {
|
||||||
|
return attrOrOperandMapping[index];
|
||||||
|
}
|
|
@ -0,0 +1,360 @@
|
||||||
|
//===- Operator.h - Operator class ------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_OPERATOR_H_
|
||||||
|
#define MLIR_TABLEGEN_OPERATOR_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "Argument.h"
|
||||||
|
#include "Attribute.h"
|
||||||
|
#include "Builder.h"
|
||||||
|
#include "Dialect.h"
|
||||||
|
#include "Region.h"
|
||||||
|
#include "Successor.h"
|
||||||
|
#include "Trait.h"
|
||||||
|
#include "Type.h"
|
||||||
|
#include "llvm/ADT/PointerUnion.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringMap.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/Support/SMLoc.h"
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class DefInit;
|
||||||
|
class Record;
|
||||||
|
class StringInit;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
// Wrapper class that contains a MLIR op's information (e.g., operands,
|
||||||
|
// attributes) defined in TableGen and provides helper methods for
|
||||||
|
// accessing them.
|
||||||
|
class Operator {
|
||||||
|
public:
|
||||||
|
explicit Operator(const llvm::Record &def);
|
||||||
|
explicit Operator(const llvm::Record *def) : Operator(*def) {}
|
||||||
|
|
||||||
|
// Returns this op's dialect name.
|
||||||
|
StringRef getDialectName() const;
|
||||||
|
|
||||||
|
// Returns the operation name. The name will follow the "<dialect>.<op-name>"
|
||||||
|
// format if its dialect name is not empty.
|
||||||
|
std::string getOperationName() const;
|
||||||
|
|
||||||
|
// Returns this op's C++ class name.
|
||||||
|
StringRef getCppClassName() const;
|
||||||
|
|
||||||
|
// Returns this op's C++ class name prefixed with namespaces.
|
||||||
|
std::string getQualCppClassName() const;
|
||||||
|
|
||||||
|
// Returns this op's C++ namespace.
|
||||||
|
StringRef getCppNamespace() const;
|
||||||
|
|
||||||
|
// Returns the name of op's adaptor C++ class.
|
||||||
|
std::string getAdaptorName() const;
|
||||||
|
|
||||||
|
/// A class used to represent the decorators of an operator variable, i.e.
|
||||||
|
/// argument or result.
|
||||||
|
struct VariableDecorator {
|
||||||
|
public:
|
||||||
|
explicit VariableDecorator(const llvm::Record *def) : def(def) {}
|
||||||
|
const llvm::Record &getDef() const { return *def; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// The TableGen definition of this decorator.
|
||||||
|
const llvm::Record *def;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A utility iterator over a list of variable decorators.
|
||||||
|
struct VariableDecoratorIterator
|
||||||
|
: public llvm::mapped_iterator<llvm::Init *const *,
|
||||||
|
VariableDecorator (*)(llvm::Init *)> {
|
||||||
|
using reference = VariableDecorator;
|
||||||
|
|
||||||
|
/// Initializes the iterator to the specified iterator.
|
||||||
|
VariableDecoratorIterator(llvm::Init *const *it)
|
||||||
|
: llvm::mapped_iterator<llvm::Init *const *,
|
||||||
|
VariableDecorator (*)(llvm::Init *)>(it,
|
||||||
|
&unwrap) {}
|
||||||
|
static VariableDecorator unwrap(llvm::Init *init);
|
||||||
|
};
|
||||||
|
using var_decorator_iterator = VariableDecoratorIterator;
|
||||||
|
using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
|
||||||
|
|
||||||
|
using value_iterator = NamedTypeConstraint *;
|
||||||
|
using value_range = llvm::iterator_range<value_iterator>;
|
||||||
|
|
||||||
|
// Returns true if this op has variable length operands or results.
|
||||||
|
bool isVariadic() const;
|
||||||
|
|
||||||
|
// Returns true if default builders should not be generated.
|
||||||
|
bool skipDefaultBuilders() const;
|
||||||
|
|
||||||
|
// Op result iterators.
|
||||||
|
value_iterator result_begin();
|
||||||
|
value_iterator result_end();
|
||||||
|
value_range getResults();
|
||||||
|
|
||||||
|
// Returns the number of results this op produces.
|
||||||
|
int getNumResults() const;
|
||||||
|
|
||||||
|
// Returns the op result at the given `index`.
|
||||||
|
NamedTypeConstraint &getResult(int index) { return results[index]; }
|
||||||
|
const NamedTypeConstraint &getResult(int index) const {
|
||||||
|
return results[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the `index`-th result's type constraint.
|
||||||
|
TypeConstraint getResultTypeConstraint(int index) const;
|
||||||
|
// Returns the `index`-th result's name.
|
||||||
|
StringRef getResultName(int index) const;
|
||||||
|
// Returns the `index`-th result's decorators.
|
||||||
|
var_decorator_range getResultDecorators(int index) const;
|
||||||
|
|
||||||
|
// Returns the number of variable length results in this operation.
|
||||||
|
unsigned getNumVariableLengthResults() const;
|
||||||
|
|
||||||
|
// Op attribute iterators.
|
||||||
|
using attribute_iterator = const NamedAttribute *;
|
||||||
|
attribute_iterator attribute_begin() const;
|
||||||
|
attribute_iterator attribute_end() const;
|
||||||
|
llvm::iterator_range<attribute_iterator> getAttributes() const;
|
||||||
|
|
||||||
|
int getNumAttributes() const { return attributes.size(); }
|
||||||
|
int getNumNativeAttributes() const { return numNativeAttributes; }
|
||||||
|
|
||||||
|
// Op attribute accessors.
|
||||||
|
NamedAttribute &getAttribute(int index) { return attributes[index]; }
|
||||||
|
|
||||||
|
// Op operand iterators.
|
||||||
|
value_iterator operand_begin();
|
||||||
|
value_iterator operand_end();
|
||||||
|
value_range getOperands();
|
||||||
|
|
||||||
|
int getNumOperands() const { return operands.size(); }
|
||||||
|
NamedTypeConstraint &getOperand(int index) { return operands[index]; }
|
||||||
|
const NamedTypeConstraint &getOperand(int index) const {
|
||||||
|
return operands[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the number of variadic operands in this operation.
|
||||||
|
unsigned getNumVariableLengthOperands() const;
|
||||||
|
|
||||||
|
// Returns the total number of arguments.
|
||||||
|
int getNumArgs() const { return arguments.size(); }
|
||||||
|
|
||||||
|
// Returns true of the operation has a single variadic arg.
|
||||||
|
bool hasSingleVariadicArg() const;
|
||||||
|
|
||||||
|
// Returns true if the operation has a single variadic result.
|
||||||
|
bool hasSingleVariadicResult() const {
|
||||||
|
return getNumResults() == 1 && getResult(0).isVariadic();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true of the operation has no variadic regions.
|
||||||
|
bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; }
|
||||||
|
|
||||||
|
using arg_iterator = const Argument *;
|
||||||
|
using arg_range = llvm::iterator_range<arg_iterator>;
|
||||||
|
|
||||||
|
// Op argument (attribute or operand) iterators.
|
||||||
|
arg_iterator arg_begin() const;
|
||||||
|
arg_iterator arg_end() const;
|
||||||
|
arg_range getArgs() const;
|
||||||
|
|
||||||
|
// Op argument (attribute or operand) accessors.
|
||||||
|
Argument getArg(int index) const;
|
||||||
|
StringRef getArgName(int index) const;
|
||||||
|
var_decorator_range getArgDecorators(int index) const;
|
||||||
|
|
||||||
|
// Returns the trait wrapper for the given MLIR C++ `trait`.
|
||||||
|
const Trait *getTrait(llvm::StringRef trait) const;
|
||||||
|
|
||||||
|
// Regions.
|
||||||
|
using const_region_iterator = const NamedRegion *;
|
||||||
|
const_region_iterator region_begin() const;
|
||||||
|
const_region_iterator region_end() const;
|
||||||
|
llvm::iterator_range<const_region_iterator> getRegions() const;
|
||||||
|
|
||||||
|
// Returns the number of regions.
|
||||||
|
unsigned getNumRegions() const;
|
||||||
|
// Returns the `index`-th region.
|
||||||
|
const NamedRegion &getRegion(unsigned index) const;
|
||||||
|
|
||||||
|
// Returns the number of variadic regions in this operation.
|
||||||
|
unsigned getNumVariadicRegions() const;
|
||||||
|
|
||||||
|
// Successors.
|
||||||
|
using const_successor_iterator = const NamedSuccessor *;
|
||||||
|
const_successor_iterator successor_begin() const;
|
||||||
|
const_successor_iterator successor_end() const;
|
||||||
|
llvm::iterator_range<const_successor_iterator> getSuccessors() const;
|
||||||
|
|
||||||
|
// Returns the number of successors.
|
||||||
|
unsigned getNumSuccessors() const;
|
||||||
|
// Returns the `index`-th successor.
|
||||||
|
const NamedSuccessor &getSuccessor(unsigned index) const;
|
||||||
|
|
||||||
|
// Returns the number of variadic successors in this operation.
|
||||||
|
unsigned getNumVariadicSuccessors() const;
|
||||||
|
|
||||||
|
// Trait.
|
||||||
|
using const_trait_iterator = const Trait *;
|
||||||
|
const_trait_iterator trait_begin() const;
|
||||||
|
const_trait_iterator trait_end() const;
|
||||||
|
llvm::iterator_range<const_trait_iterator> getTraits() const;
|
||||||
|
|
||||||
|
ArrayRef<llvm::SMLoc> getLoc() const;
|
||||||
|
|
||||||
|
// Query functions for the documentation of the operator.
|
||||||
|
bool hasDescription() const;
|
||||||
|
StringRef getDescription() const;
|
||||||
|
bool hasSummary() const;
|
||||||
|
StringRef getSummary() const;
|
||||||
|
|
||||||
|
// Query functions for the assembly format of the operator.
|
||||||
|
bool hasAssemblyFormat() const;
|
||||||
|
StringRef getAssemblyFormat() const;
|
||||||
|
|
||||||
|
// Returns this op's extra class declaration code.
|
||||||
|
StringRef getExtraClassDeclaration() const;
|
||||||
|
|
||||||
|
// Returns the Tablegen definition this operator was constructed from.
|
||||||
|
// TODO: do not expose the TableGen record, this is a temporary solution to
|
||||||
|
// OpEmitter requiring a Record because Operator does not provide enough
|
||||||
|
// methods.
|
||||||
|
const llvm::Record &getDef() const;
|
||||||
|
|
||||||
|
// Returns the dialect of the op.
|
||||||
|
const Dialect &getDialect() const { return dialect; }
|
||||||
|
|
||||||
|
// Prints the contents in this operator to the given `os`. This is used for
|
||||||
|
// debugging purposes.
|
||||||
|
void print(llvm::raw_ostream &os) const;
|
||||||
|
|
||||||
|
// Return whether all the result types are known.
|
||||||
|
bool allResultTypesKnown() const { return allResultsHaveKnownTypes; };
|
||||||
|
|
||||||
|
// Pair representing either a index to an argument or a type constraint. Only
|
||||||
|
// one of these entries should have the non-default value.
|
||||||
|
struct ArgOrType {
|
||||||
|
explicit ArgOrType(int index) : index(index), constraint(None) {}
|
||||||
|
explicit ArgOrType(TypeConstraint constraint)
|
||||||
|
: index(None), constraint(constraint) {}
|
||||||
|
bool isArg() const {
|
||||||
|
assert(constraint.hasValue() ^ index.hasValue());
|
||||||
|
return index.hasValue();
|
||||||
|
}
|
||||||
|
bool isType() const {
|
||||||
|
assert(constraint.hasValue() ^ index.hasValue());
|
||||||
|
return constraint.hasValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
int getArg() const { return *index; }
|
||||||
|
TypeConstraint getType() const { return *constraint; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
Optional<int> index;
|
||||||
|
Optional<TypeConstraint> constraint;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Return all arguments or type constraints with same type as result[index].
|
||||||
|
// Requires: all result types are known.
|
||||||
|
ArrayRef<ArgOrType> getSameTypeAsResult(int index) const;
|
||||||
|
|
||||||
|
// Pair consisting kind of argument and index into operands or attributes.
|
||||||
|
struct OperandOrAttribute {
|
||||||
|
enum class Kind { Operand, Attribute };
|
||||||
|
OperandOrAttribute(Kind kind, int index) {
|
||||||
|
packed = (index << 1) & (kind == Kind::Attribute);
|
||||||
|
}
|
||||||
|
int operandOrAttributeIndex() const { return (packed >> 1); }
|
||||||
|
Kind kind() { return (packed & 0x1) ? Kind::Attribute : Kind::Operand; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int packed;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns the OperandOrAttribute corresponding to the index.
|
||||||
|
OperandOrAttribute getArgToOperandOrAttribute(int index) const;
|
||||||
|
|
||||||
|
// Returns the builders of this operation.
|
||||||
|
ArrayRef<Builder> getBuilders() const { return builders; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Populates the vectors containing operands, attributes, results and traits.
|
||||||
|
void populateOpStructure();
|
||||||
|
|
||||||
|
// Populates type inference info (mostly equality) with input a mapping from
|
||||||
|
// names to indices for arguments and results.
|
||||||
|
void populateTypeInferenceInfo(
|
||||||
|
const llvm::StringMap<int> &argumentsAndResultsIndex);
|
||||||
|
|
||||||
|
// The dialect of this op.
|
||||||
|
Dialect dialect;
|
||||||
|
|
||||||
|
// The unqualified C++ class name of the op.
|
||||||
|
StringRef cppClassName;
|
||||||
|
|
||||||
|
// The C++ namespace for this op.
|
||||||
|
StringRef cppNamespace;
|
||||||
|
|
||||||
|
// The operands of the op.
|
||||||
|
SmallVector<NamedTypeConstraint, 4> operands;
|
||||||
|
|
||||||
|
// The attributes of the op. Contains native attributes (corresponding to the
|
||||||
|
// actual stored attributed of the operation) followed by derived attributes
|
||||||
|
// (corresponding to dynamic properties of the operation that are computed
|
||||||
|
// upon request).
|
||||||
|
SmallVector<NamedAttribute, 4> attributes;
|
||||||
|
|
||||||
|
// The arguments of the op (operands and native attributes).
|
||||||
|
SmallVector<Argument, 4> arguments;
|
||||||
|
|
||||||
|
// The results of the op.
|
||||||
|
SmallVector<NamedTypeConstraint, 4> results;
|
||||||
|
|
||||||
|
// The successors of this op.
|
||||||
|
SmallVector<NamedSuccessor, 0> successors;
|
||||||
|
|
||||||
|
// The traits of the op.
|
||||||
|
SmallVector<Trait, 4> traits;
|
||||||
|
|
||||||
|
// The regions of this op.
|
||||||
|
SmallVector<NamedRegion, 1> regions;
|
||||||
|
|
||||||
|
// The argument with the same type as the result.
|
||||||
|
SmallVector<SmallVector<ArgOrType, 2>, 4> resultTypeMapping;
|
||||||
|
|
||||||
|
// Map from argument to attribute or operand number.
|
||||||
|
SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
|
||||||
|
|
||||||
|
// The builders of this operator.
|
||||||
|
SmallVector<Builder> builders;
|
||||||
|
|
||||||
|
// The number of native attributes stored in the leading positions of
|
||||||
|
// `attributes`.
|
||||||
|
int numNativeAttributes;
|
||||||
|
|
||||||
|
// The TableGen definition of this op.
|
||||||
|
const llvm::Record &def;
|
||||||
|
|
||||||
|
// Whether the type of all results are known.
|
||||||
|
bool allResultsHaveKnownTypes;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_OPERATOR_H_
|
|
@ -0,0 +1,99 @@
|
||||||
|
//===- Pass.cpp - Pass related classes ------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Pass.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PassOption
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
StringRef PassOption::getCppVariableName() const {
|
||||||
|
return def->getValueAsString("cppName");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef PassOption::getArgument() const {
|
||||||
|
return def->getValueAsString("argument");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef PassOption::getType() const { return def->getValueAsString("type"); }
|
||||||
|
|
||||||
|
Optional<StringRef> PassOption::getDefaultValue() const {
|
||||||
|
StringRef defaultVal = def->getValueAsString("defaultValue");
|
||||||
|
return defaultVal.empty() ? Optional<StringRef>() : defaultVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef PassOption::getDescription() const {
|
||||||
|
return def->getValueAsString("description");
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<StringRef> PassOption::getAdditionalFlags() const {
|
||||||
|
StringRef additionalFlags = def->getValueAsString("additionalOptFlags");
|
||||||
|
return additionalFlags.empty() ? Optional<StringRef>() : additionalFlags;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PassOption::isListOption() const {
|
||||||
|
return def->isSubClassOf("ListOption");
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PassStatistic
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
StringRef PassStatistic::getCppVariableName() const {
|
||||||
|
return def->getValueAsString("cppName");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef PassStatistic::getName() const {
|
||||||
|
return def->getValueAsString("name");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef PassStatistic::getDescription() const {
|
||||||
|
return def->getValueAsString("description");
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pass
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Pass::Pass(const llvm::Record *def) : def(def) {
|
||||||
|
for (auto *init : def->getValueAsListOfDefs("options"))
|
||||||
|
options.push_back(PassOption(init));
|
||||||
|
for (auto *init : def->getValueAsListOfDefs("statistics"))
|
||||||
|
statistics.push_back(PassStatistic(init));
|
||||||
|
for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
|
||||||
|
dependentDialects.push_back(dialect);
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Pass::getArgument() const {
|
||||||
|
return def->getValueAsString("argument");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Pass::getBaseClass() const {
|
||||||
|
return def->getValueAsString("baseClass");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Pass::getSummary() const { return def->getValueAsString("summary"); }
|
||||||
|
|
||||||
|
StringRef Pass::getDescription() const {
|
||||||
|
return def->getValueAsString("description");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef Pass::getConstructor() const {
|
||||||
|
return def->getValueAsString("constructor");
|
||||||
|
}
|
||||||
|
ArrayRef<StringRef> Pass::getDependentDialects() const {
|
||||||
|
return dependentDialects;
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<PassOption> Pass::getOptions() const { return options; }
|
||||||
|
|
||||||
|
ArrayRef<PassStatistic> Pass::getStatistics() const { return statistics; }
|
|
@ -0,0 +1,118 @@
|
||||||
|
//===- Pass.h - TableGen pass definitions -----------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_PASS_H_
|
||||||
|
#define MLIR_TABLEGEN_PASS_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class Record;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PassOption
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
class PassOption {
|
||||||
|
public:
|
||||||
|
explicit PassOption(const llvm::Record *def) : def(def) {}
|
||||||
|
|
||||||
|
/// Return the name for the C++ option variable.
|
||||||
|
StringRef getCppVariableName() const;
|
||||||
|
|
||||||
|
/// Return the command line argument to use for this option.
|
||||||
|
StringRef getArgument() const;
|
||||||
|
|
||||||
|
/// Return the C++ type of the option.
|
||||||
|
StringRef getType() const;
|
||||||
|
|
||||||
|
/// Return the default value of the option.
|
||||||
|
Optional<StringRef> getDefaultValue() const;
|
||||||
|
|
||||||
|
/// Return the description for this option.
|
||||||
|
StringRef getDescription() const;
|
||||||
|
|
||||||
|
/// Return the additional flags passed to the option constructor.
|
||||||
|
Optional<StringRef> getAdditionalFlags() const;
|
||||||
|
|
||||||
|
/// Flag indicating if this is a list option.
|
||||||
|
bool isListOption() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const llvm::Record *def;
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PassStatistic
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
class PassStatistic {
|
||||||
|
public:
|
||||||
|
explicit PassStatistic(const llvm::Record *def) : def(def) {}
|
||||||
|
|
||||||
|
/// Return the name for the C++ statistic variable.
|
||||||
|
StringRef getCppVariableName() const;
|
||||||
|
|
||||||
|
/// Return the name of the statistic.
|
||||||
|
StringRef getName() const;
|
||||||
|
|
||||||
|
/// Return the description for this statistic.
|
||||||
|
StringRef getDescription() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const llvm::Record *def;
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pass
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Wrapper class providing helper methods for Passes defined in TableGen.
|
||||||
|
class Pass {
|
||||||
|
public:
|
||||||
|
explicit Pass(const llvm::Record *def);
|
||||||
|
|
||||||
|
/// Return the command line argument of the pass.
|
||||||
|
StringRef getArgument() const;
|
||||||
|
|
||||||
|
/// Return the name for the C++ base class.
|
||||||
|
StringRef getBaseClass() const;
|
||||||
|
|
||||||
|
/// Return the short 1-line summary of the pass.
|
||||||
|
StringRef getSummary() const;
|
||||||
|
|
||||||
|
/// Return the description of the pass.
|
||||||
|
StringRef getDescription() const;
|
||||||
|
|
||||||
|
/// Return the C++ constructor call to create an instance of this pass.
|
||||||
|
StringRef getConstructor() const;
|
||||||
|
|
||||||
|
/// Return the dialects this pass needs to be registered.
|
||||||
|
ArrayRef<StringRef> getDependentDialects() const;
|
||||||
|
|
||||||
|
/// Return the options provided by this pass.
|
||||||
|
ArrayRef<PassOption> getOptions() const;
|
||||||
|
|
||||||
|
/// Return the statistics provided by this pass.
|
||||||
|
ArrayRef<PassStatistic> getStatistics() const;
|
||||||
|
|
||||||
|
const llvm::Record *getDef() const { return def; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
const llvm::Record *def;
|
||||||
|
std::vector<StringRef> dependentDialects;
|
||||||
|
std::vector<PassOption> options;
|
||||||
|
std::vector<PassStatistic> statistics;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_PASS_H_
|
|
@ -0,0 +1,739 @@
|
||||||
|
//===- Pattern.cpp - Pattern wrapper class --------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Pattern wrapper class to simplify using TableGen Record defining a MLIR
|
||||||
|
// Pattern.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Pattern.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include "llvm/ADT/Twine.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/TableGen/Error.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "mlir-tblgen-pattern"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace tblgen;
|
||||||
|
|
||||||
|
using llvm::formatv;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// DagLeaf
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
bool DagLeaf::isUnspecified() const {
|
||||||
|
return dyn_cast_or_null<llvm::UnsetInit>(def);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DagLeaf::isOperandMatcher() const {
|
||||||
|
// Operand matchers specify a type constraint.
|
||||||
|
return isSubClassOf("TypeConstraint");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DagLeaf::isAttrMatcher() const {
|
||||||
|
// Attribute matchers specify an attribute constraint.
|
||||||
|
return isSubClassOf("AttrConstraint");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DagLeaf::isNativeCodeCall() const {
|
||||||
|
return isSubClassOf("NativeCodeCall");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
|
||||||
|
|
||||||
|
bool DagLeaf::isEnumAttrCase() const {
|
||||||
|
return isSubClassOf("EnumAttrCaseInfo");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DagLeaf::isStringAttr() const {
|
||||||
|
return isa<llvm::StringInit>(def);
|
||||||
|
}
|
||||||
|
|
||||||
|
Constraint DagLeaf::getAsConstraint() const {
|
||||||
|
assert((isOperandMatcher() || isAttrMatcher()) &&
|
||||||
|
"the DAG leaf must be operand or attribute");
|
||||||
|
return Constraint(cast<llvm::DefInit>(def)->getDef());
|
||||||
|
}
|
||||||
|
|
||||||
|
ConstantAttr DagLeaf::getAsConstantAttr() const {
|
||||||
|
assert(isConstantAttr() && "the DAG leaf must be constant attribute");
|
||||||
|
return ConstantAttr(cast<llvm::DefInit>(def));
|
||||||
|
}
|
||||||
|
|
||||||
|
EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
|
||||||
|
assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
|
||||||
|
return EnumAttrCase(cast<llvm::DefInit>(def));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string DagLeaf::getConditionTemplate() const {
|
||||||
|
return getAsConstraint().getConditionTemplate();
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
|
||||||
|
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
|
||||||
|
return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string DagLeaf::getStringAttr() const {
|
||||||
|
assert(isStringAttr() && "the DAG leaf must be string attribute");
|
||||||
|
return def->getAsUnquotedString();
|
||||||
|
}
|
||||||
|
bool DagLeaf::isSubClassOf(StringRef superclass) const {
|
||||||
|
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
|
||||||
|
return defInit->getDef()->isSubClassOf(superclass);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DagLeaf::print(raw_ostream &os) const {
|
||||||
|
if (def)
|
||||||
|
def->print(os);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// DagNode
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
bool DagNode::isNativeCodeCall() const {
|
||||||
|
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
|
||||||
|
return defInit->getDef()->isSubClassOf("NativeCodeCall");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DagNode::isOperation() const {
|
||||||
|
return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::StringRef DagNode::getNativeCodeTemplate() const {
|
||||||
|
assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
|
||||||
|
return cast<llvm::DefInit>(node->getOperator())
|
||||||
|
->getDef()
|
||||||
|
->getValueAsString("expression");
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
|
||||||
|
|
||||||
|
Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
|
||||||
|
llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
|
||||||
|
auto it = mapper->find(opDef);
|
||||||
|
if (it != mapper->end())
|
||||||
|
return *it->second;
|
||||||
|
return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
|
||||||
|
.first->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
int DagNode::getNumOps() const {
|
||||||
|
int count = isReplaceWithValue() ? 0 : 1;
|
||||||
|
for (int i = 0, e = getNumArgs(); i != e; ++i) {
|
||||||
|
if (auto child = getArgAsNestedDag(i))
|
||||||
|
count += child.getNumOps();
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
int DagNode::getNumArgs() const { return node->getNumArgs(); }
|
||||||
|
|
||||||
|
bool DagNode::isNestedDagArg(unsigned index) const {
|
||||||
|
return isa<llvm::DagInit>(node->getArg(index));
|
||||||
|
}
|
||||||
|
|
||||||
|
DagNode DagNode::getArgAsNestedDag(unsigned index) const {
|
||||||
|
return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
|
||||||
|
}
|
||||||
|
|
||||||
|
DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
|
||||||
|
assert(!isNestedDagArg(index));
|
||||||
|
return DagLeaf(node->getArg(index));
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef DagNode::getArgName(unsigned index) const {
|
||||||
|
return node->getArgNameStr(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DagNode::isReplaceWithValue() const {
|
||||||
|
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
|
||||||
|
return dagOpDef->getName() == "replaceWithValue";
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DagNode::isLocationDirective() const {
|
||||||
|
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
|
||||||
|
return dagOpDef->getName() == "location";
|
||||||
|
}
|
||||||
|
|
||||||
|
void DagNode::print(raw_ostream &os) const {
|
||||||
|
if (node)
|
||||||
|
node->print(os);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// SymbolInfoMap
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
|
||||||
|
StringRef name, indexStr;
|
||||||
|
int idx = -1;
|
||||||
|
std::tie(name, indexStr) = symbol.rsplit("__");
|
||||||
|
|
||||||
|
if (indexStr.consumeInteger(10, idx)) {
|
||||||
|
// The second part is not an index; we return the whole symbol as-is.
|
||||||
|
return symbol;
|
||||||
|
}
|
||||||
|
if (index) {
|
||||||
|
*index = idx;
|
||||||
|
}
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
|
||||||
|
Optional<int> index)
|
||||||
|
: op(op), kind(kind), argIndex(index) {}
|
||||||
|
|
||||||
|
int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
|
||||||
|
switch (kind) {
|
||||||
|
case Kind::Attr:
|
||||||
|
case Kind::Operand:
|
||||||
|
case Kind::Value:
|
||||||
|
return 1;
|
||||||
|
case Kind::Result:
|
||||||
|
return op->getNumResults();
|
||||||
|
}
|
||||||
|
llvm_unreachable("unknown kind");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
|
||||||
|
return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
|
||||||
|
switch (kind) {
|
||||||
|
case Kind::Attr: {
|
||||||
|
if (op) {
|
||||||
|
auto type =
|
||||||
|
op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
|
||||||
|
return std::string(formatv("{0} {1};\n", type, name));
|
||||||
|
}
|
||||||
|
// TODO(suderman): Use a more exact type when available.
|
||||||
|
return std::string(formatv("Attribute {0};\n", name));
|
||||||
|
}
|
||||||
|
case Kind::Operand: {
|
||||||
|
// Use operand range for captured operands (to support potential variadic
|
||||||
|
// operands).
|
||||||
|
return std::string(
|
||||||
|
formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
|
||||||
|
getVarName(name)));
|
||||||
|
}
|
||||||
|
case Kind::Value: {
|
||||||
|
return std::string(formatv("::mlir::Value {0};\n", name));
|
||||||
|
}
|
||||||
|
case Kind::Result: {
|
||||||
|
// Use the op itself for captured results.
|
||||||
|
return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
llvm_unreachable("unknown kind");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
||||||
|
StringRef name, int index, const char *fmt, const char *separator) const {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
|
||||||
|
switch (kind) {
|
||||||
|
case Kind::Attr: {
|
||||||
|
assert(index < 0);
|
||||||
|
auto repl = formatv(fmt, name);
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
|
||||||
|
return std::string(repl);
|
||||||
|
}
|
||||||
|
case Kind::Operand: {
|
||||||
|
assert(index < 0);
|
||||||
|
auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
|
||||||
|
// If this operand is variadic, then return a range. Otherwise, return the
|
||||||
|
// value itself.
|
||||||
|
if (operand->isVariableLength()) {
|
||||||
|
auto repl = formatv(fmt, name);
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
|
||||||
|
return std::string(repl);
|
||||||
|
}
|
||||||
|
auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
|
||||||
|
return std::string(repl);
|
||||||
|
}
|
||||||
|
case Kind::Result: {
|
||||||
|
// If `index` is greater than zero, then we are referencing a specific
|
||||||
|
// result of a multi-result op. The result can still be variadic.
|
||||||
|
if (index >= 0) {
|
||||||
|
std::string v =
|
||||||
|
std::string(formatv("{0}.getODSResults({1})", name, index));
|
||||||
|
if (!op->getResult(index).isVariadic())
|
||||||
|
v = std::string(formatv("(*{0}.begin())", v));
|
||||||
|
auto repl = formatv(fmt, v);
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
|
||||||
|
return std::string(repl);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this op has no result at all but still we bind a symbol to it, it
|
||||||
|
// means we want to capture the op itself.
|
||||||
|
if (op->getNumResults() == 0) {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
|
||||||
|
return std::string(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// We are referencing all results of the multi-result op. A specific result
|
||||||
|
// can either be a value or a range. Then join them with `separator`.
|
||||||
|
SmallVector<std::string, 4> values;
|
||||||
|
values.reserve(op->getNumResults());
|
||||||
|
|
||||||
|
for (int i = 0, e = op->getNumResults(); i < e; ++i) {
|
||||||
|
std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
|
||||||
|
if (!op->getResult(i).isVariadic()) {
|
||||||
|
v = std::string(formatv("(*{0}.begin())", v));
|
||||||
|
}
|
||||||
|
values.push_back(std::string(formatv(fmt, v)));
|
||||||
|
}
|
||||||
|
auto repl = llvm::join(values, separator);
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
|
||||||
|
return repl;
|
||||||
|
}
|
||||||
|
case Kind::Value: {
|
||||||
|
assert(index < 0);
|
||||||
|
assert(op == nullptr);
|
||||||
|
auto repl = formatv(fmt, name);
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
|
||||||
|
return std::string(repl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
llvm_unreachable("unknown kind");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
|
||||||
|
StringRef name, int index, const char *fmt, const char *separator) const {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
|
||||||
|
switch (kind) {
|
||||||
|
case Kind::Attr:
|
||||||
|
case Kind::Operand: {
|
||||||
|
assert(index < 0 && "only allowed for symbol bound to result");
|
||||||
|
auto repl = formatv(fmt, name);
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
|
||||||
|
return std::string(repl);
|
||||||
|
}
|
||||||
|
case Kind::Result: {
|
||||||
|
if (index >= 0) {
|
||||||
|
auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
|
||||||
|
return std::string(repl);
|
||||||
|
}
|
||||||
|
|
||||||
|
// We are referencing all results of the multi-result op. Each result should
|
||||||
|
// have a value range, and then join them with `separator`.
|
||||||
|
SmallVector<std::string, 4> values;
|
||||||
|
values.reserve(op->getNumResults());
|
||||||
|
|
||||||
|
for (int i = 0, e = op->getNumResults(); i < e; ++i) {
|
||||||
|
values.push_back(std::string(
|
||||||
|
formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
|
||||||
|
}
|
||||||
|
auto repl = llvm::join(values, separator);
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
|
||||||
|
return repl;
|
||||||
|
}
|
||||||
|
case Kind::Value: {
|
||||||
|
assert(index < 0 && "only allowed for symbol bound to result");
|
||||||
|
assert(op == nullptr);
|
||||||
|
auto repl = formatv(fmt, formatv("{{{0}}", name));
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
|
||||||
|
return std::string(repl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
llvm_unreachable("unknown kind");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
|
||||||
|
int argIndex) {
|
||||||
|
StringRef name = getValuePackName(symbol);
|
||||||
|
if (name != symbol) {
|
||||||
|
auto error = formatv(
|
||||||
|
"symbol '{0}' with trailing index cannot bind to op argument", symbol);
|
||||||
|
PrintFatalError(loc, error);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
|
||||||
|
? SymbolInfo::getAttr(&op, argIndex)
|
||||||
|
: SymbolInfo::getOperand(&op, argIndex);
|
||||||
|
|
||||||
|
std::string key = symbol.str();
|
||||||
|
if (symbolInfoMap.count(key)) {
|
||||||
|
// Only non unique name for the operand is supported.
|
||||||
|
if (symInfo.kind != SymbolInfo::Kind::Operand) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cannot add new operand if there is already non operand with the same
|
||||||
|
// name.
|
||||||
|
if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
symbolInfoMap.emplace(key, symInfo);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
|
||||||
|
std::string name = getValuePackName(symbol).str();
|
||||||
|
auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
|
||||||
|
|
||||||
|
return symbolInfoMap.count(inserted->first) == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SymbolInfoMap::bindValue(StringRef symbol) {
|
||||||
|
auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
|
||||||
|
return symbolInfoMap.count(inserted->first) == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SymbolInfoMap::bindAttr(StringRef symbol) {
|
||||||
|
auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
|
||||||
|
return symbolInfoMap.count(inserted->first) == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SymbolInfoMap::contains(StringRef symbol) const {
|
||||||
|
return find(symbol) != symbolInfoMap.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
|
||||||
|
std::string name = getValuePackName(key).str();
|
||||||
|
|
||||||
|
return symbolInfoMap.find(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
SymbolInfoMap::const_iterator
|
||||||
|
SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
|
||||||
|
int argIndex) const {
|
||||||
|
std::string name = getValuePackName(key).str();
|
||||||
|
auto range = symbolInfoMap.equal_range(name);
|
||||||
|
|
||||||
|
for (auto it = range.first; it != range.second; ++it) {
|
||||||
|
if (it->second.op == &op && it->second.argIndex == argIndex) {
|
||||||
|
return it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return symbolInfoMap.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
|
||||||
|
SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
|
||||||
|
std::string name = getValuePackName(key).str();
|
||||||
|
|
||||||
|
return symbolInfoMap.equal_range(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
int SymbolInfoMap::count(StringRef key) const {
|
||||||
|
std::string name = getValuePackName(key).str();
|
||||||
|
return symbolInfoMap.count(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
|
||||||
|
StringRef name = getValuePackName(symbol);
|
||||||
|
if (name != symbol) {
|
||||||
|
// If there is a trailing index inside symbol, it references just one
|
||||||
|
// static value.
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
// Otherwise, find how many it represents by querying the symbol's info.
|
||||||
|
return find(name)->second.getStaticValueCount();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
|
||||||
|
const char *fmt,
|
||||||
|
const char *separator) const {
|
||||||
|
int index = -1;
|
||||||
|
StringRef name = getValuePackName(symbol, &index);
|
||||||
|
|
||||||
|
auto it = symbolInfoMap.find(name.str());
|
||||||
|
if (it == symbolInfoMap.end()) {
|
||||||
|
auto error = formatv("referencing unbound symbol '{0}'", symbol);
|
||||||
|
PrintFatalError(loc, error);
|
||||||
|
}
|
||||||
|
|
||||||
|
return it->second.getValueAndRangeUse(name, index, fmt, separator);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
|
||||||
|
const char *separator) const {
|
||||||
|
int index = -1;
|
||||||
|
StringRef name = getValuePackName(symbol, &index);
|
||||||
|
|
||||||
|
auto it = symbolInfoMap.find(name.str());
|
||||||
|
if (it == symbolInfoMap.end()) {
|
||||||
|
auto error = formatv("referencing unbound symbol '{0}'", symbol);
|
||||||
|
PrintFatalError(loc, error);
|
||||||
|
}
|
||||||
|
|
||||||
|
return it->second.getAllRangeUse(name, index, fmt, separator);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SymbolInfoMap::assignUniqueAlternativeNames() {
|
||||||
|
llvm::StringSet<> usedNames;
|
||||||
|
|
||||||
|
for (auto symbolInfoIt = symbolInfoMap.begin();
|
||||||
|
symbolInfoIt != symbolInfoMap.end();) {
|
||||||
|
auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
|
||||||
|
auto startRange = range.first;
|
||||||
|
auto endRange = range.second;
|
||||||
|
|
||||||
|
auto operandName = symbolInfoIt->first;
|
||||||
|
int startSearchIndex = 0;
|
||||||
|
for (++startRange; startRange != endRange; ++startRange) {
|
||||||
|
// Current operand name is not unique, find a unique one
|
||||||
|
// and set the alternative name.
|
||||||
|
for (int i = startSearchIndex;; ++i) {
|
||||||
|
std::string alternativeName = operandName + std::to_string(i);
|
||||||
|
if (!usedNames.contains(alternativeName) &&
|
||||||
|
symbolInfoMap.count(alternativeName) == 0) {
|
||||||
|
usedNames.insert(alternativeName);
|
||||||
|
startRange->second.alternativeName = alternativeName;
|
||||||
|
startSearchIndex = i + 1;
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
symbolInfoIt = endRange;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pattern
|
||||||
|
//==----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
|
||||||
|
: def(*def), recordOpMap(mapper) {}
|
||||||
|
|
||||||
|
DagNode Pattern::getSourcePattern() const {
|
||||||
|
return DagNode(def.getValueAsDag("sourcePattern"));
|
||||||
|
}
|
||||||
|
|
||||||
|
int Pattern::getNumResultPatterns() const {
|
||||||
|
auto *results = def.getValueAsListInit("resultPatterns");
|
||||||
|
return results->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
DagNode Pattern::getResultPattern(unsigned index) const {
|
||||||
|
auto *results = def.getValueAsListInit("resultPatterns");
|
||||||
|
return DagNode(cast<llvm::DagInit>(results->getElement(index)));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
|
||||||
|
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
|
||||||
|
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
|
||||||
|
infoMap.assignUniqueAlternativeNames();
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
|
||||||
|
for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
|
||||||
|
auto pattern = getResultPattern(i);
|
||||||
|
collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
|
||||||
|
}
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
const Operator &Pattern::getSourceRootOp() {
|
||||||
|
return getSourcePattern().getDialectOp(recordOpMap);
|
||||||
|
}
|
||||||
|
|
||||||
|
Operator &Pattern::getDialectOp(DagNode node) {
|
||||||
|
return node.getDialectOp(recordOpMap);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AppliedConstraint> Pattern::getConstraints() const {
|
||||||
|
auto *listInit = def.getValueAsListInit("constraints");
|
||||||
|
std::vector<AppliedConstraint> ret;
|
||||||
|
ret.reserve(listInit->size());
|
||||||
|
|
||||||
|
for (auto it : *listInit) {
|
||||||
|
auto *dagInit = dyn_cast<llvm::DagInit>(it);
|
||||||
|
if (!dagInit)
|
||||||
|
PrintFatalError(&def, "all elements in Pattern multi-entity "
|
||||||
|
"constraints should be DAG nodes");
|
||||||
|
|
||||||
|
std::vector<std::string> entities;
|
||||||
|
entities.reserve(dagInit->arg_size());
|
||||||
|
for (auto *argName : dagInit->getArgNames()) {
|
||||||
|
if (!argName) {
|
||||||
|
PrintFatalError(
|
||||||
|
&def,
|
||||||
|
"operands to additional constraints can only be symbol references");
|
||||||
|
}
|
||||||
|
entities.push_back(std::string(argName->getValue()));
|
||||||
|
}
|
||||||
|
|
||||||
|
ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
|
||||||
|
dagInit->getNameStr(), std::move(entities));
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
int Pattern::getBenefit() const {
|
||||||
|
// The initial benefit value is a heuristic with number of ops in the source
|
||||||
|
// pattern.
|
||||||
|
int initBenefit = getSourcePattern().getNumOps();
|
||||||
|
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
|
||||||
|
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
|
||||||
|
PrintFatalError(&def,
|
||||||
|
"The 'addBenefit' takes and only takes one integer value");
|
||||||
|
}
|
||||||
|
return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
|
||||||
|
std::vector<std::pair<StringRef, unsigned>> result;
|
||||||
|
result.reserve(def.getLoc().size());
|
||||||
|
for (auto loc : def.getLoc()) {
|
||||||
|
unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
|
||||||
|
assert(buf && "invalid source location");
|
||||||
|
result.emplace_back(
|
||||||
|
llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
|
||||||
|
llvm::SrcMgr.getLineAndColumn(loc, buf).first);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Pattern::verifyBind(bool result, StringRef symbolName) {
|
||||||
|
if (!result) {
|
||||||
|
auto err = formatv("symbol '{0}' bound more than once", symbolName);
|
||||||
|
PrintFatalError(&def, err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||||
|
bool isSrcPattern) {
|
||||||
|
auto treeName = tree.getSymbol();
|
||||||
|
auto numTreeArgs = tree.getNumArgs();
|
||||||
|
|
||||||
|
if (tree.isNativeCodeCall()) {
|
||||||
|
if (!treeName.empty()) {
|
||||||
|
if (!isSrcPattern) {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
|
||||||
|
<< treeName << '\n');
|
||||||
|
verifyBind(infoMap.bindValue(treeName), treeName);
|
||||||
|
} else {
|
||||||
|
PrintFatalError(&def,
|
||||||
|
formatv("binding symbol '{0}' to NativecodeCall in "
|
||||||
|
"MatchPattern is not supported",
|
||||||
|
treeName));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i != numTreeArgs; ++i) {
|
||||||
|
if (auto treeArg = tree.getArgAsNestedDag(i)) {
|
||||||
|
// This DAG node argument is a DAG node itself. Go inside recursively.
|
||||||
|
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isSrcPattern)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// We can only bind symbols to arguments in source pattern. Those
|
||||||
|
// symbols are referenced in result patterns.
|
||||||
|
auto treeArgName = tree.getArgName(i);
|
||||||
|
|
||||||
|
// `$_` is a special symbol meaning ignore the current argument.
|
||||||
|
if (!treeArgName.empty() && treeArgName != "_") {
|
||||||
|
DagLeaf leaf = tree.getArgAsLeaf(i);
|
||||||
|
|
||||||
|
// In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
|
||||||
|
if (leaf.isUnspecified()) {
|
||||||
|
// This is case of $c, a Value without any constraints.
|
||||||
|
verifyBind(infoMap.bindValue(treeArgName), treeArgName);
|
||||||
|
} else {
|
||||||
|
auto constraint = leaf.getAsConstraint();
|
||||||
|
bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
|
||||||
|
leaf.isConstantAttr() ||
|
||||||
|
constraint.getKind() == Constraint::Kind::CK_Attr;
|
||||||
|
|
||||||
|
if (isAttr) {
|
||||||
|
// This is case of $a, a binding to a certain attribute.
|
||||||
|
verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is case of $b, a binding to a certain type.
|
||||||
|
verifyBind(infoMap.bindValue(treeArgName), treeArgName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tree.isOperation()) {
|
||||||
|
auto &op = getDialectOp(tree);
|
||||||
|
auto numOpArgs = op.getNumArgs();
|
||||||
|
|
||||||
|
// The pattern might have the last argument specifying the location.
|
||||||
|
bool hasLocDirective = false;
|
||||||
|
if (numTreeArgs != 0) {
|
||||||
|
if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
|
||||||
|
hasLocDirective = lastArg.isLocationDirective();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (numOpArgs != numTreeArgs - hasLocDirective) {
|
||||||
|
auto err = formatv("op '{0}' argument number mismatch: "
|
||||||
|
"{1} in pattern vs. {2} in definition",
|
||||||
|
op.getOperationName(), numTreeArgs, numOpArgs);
|
||||||
|
PrintFatalError(&def, err);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The name attached to the DAG node's operator is for representing the
|
||||||
|
// results generated from this op. It should be remembered as bound results.
|
||||||
|
if (!treeName.empty()) {
|
||||||
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
|
<< "found symbol bound to op result: " << treeName << '\n');
|
||||||
|
verifyBind(infoMap.bindOpResult(treeName, op), treeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i != numTreeArgs; ++i) {
|
||||||
|
if (auto treeArg = tree.getArgAsNestedDag(i)) {
|
||||||
|
// This DAG node argument is a DAG node itself. Go inside recursively.
|
||||||
|
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isSrcPattern) {
|
||||||
|
// We can only bind symbols to op arguments in source pattern. Those
|
||||||
|
// symbols are referenced in result patterns.
|
||||||
|
auto treeArgName = tree.getArgName(i);
|
||||||
|
// `$_` is a special symbol meaning ignore the current argument.
|
||||||
|
if (!treeArgName.empty() && treeArgName != "_") {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
|
||||||
|
<< treeArgName << '\n');
|
||||||
|
verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!treeName.empty()) {
|
||||||
|
PrintFatalError(
|
||||||
|
&def, formatv("binding symbol '{0}' to non-operation/native code call "
|
||||||
|
"unsupported right now",
|
||||||
|
treeName));
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
|
@ -0,0 +1,451 @@
|
||||||
|
//===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Pattern wrapper class to simplify using TableGen Record defining a MLIR
|
||||||
|
// Pattern.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_PATTERN_H_
|
||||||
|
#define MLIR_TABLEGEN_PATTERN_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "Argument.h"
|
||||||
|
#include "Operator.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/StringMap.h"
|
||||||
|
#include "llvm/ADT/StringSet.h"
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class DagInit;
|
||||||
|
class Init;
|
||||||
|
class Record;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
// Mapping from TableGen Record to Operator wrapper object.
|
||||||
|
//
|
||||||
|
// We allocate each wrapper object in heap to make sure the pointer to it is
|
||||||
|
// valid throughout the lifetime of this map. This is important because this map
|
||||||
|
// is shared among multiple patterns to avoid creating the wrapper object for
|
||||||
|
// the same op again and again. But this map will continuously grow.
|
||||||
|
using RecordOperatorMap =
|
||||||
|
DenseMap<const llvm::Record *, std::unique_ptr<Operator>>;
|
||||||
|
|
||||||
|
class Pattern;
|
||||||
|
|
||||||
|
// Wrapper class providing helper methods for accessing TableGen DAG leaves
|
||||||
|
// used inside Patterns. This class is lightweight and designed to be used like
|
||||||
|
// values.
|
||||||
|
//
|
||||||
|
// A TableGen DAG construct is of the syntax
|
||||||
|
// `(operator, arg0, arg1, ...)`.
|
||||||
|
//
|
||||||
|
// This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
|
||||||
|
// for handy helper methods. It only works on `arg*`s that are not nested DAG
|
||||||
|
// constructs.
|
||||||
|
class DagLeaf {
|
||||||
|
public:
|
||||||
|
explicit DagLeaf(const llvm::Init *def) : def(def) {}
|
||||||
|
|
||||||
|
// Returns true if this DAG leaf is not specified in the pattern. That is, it
|
||||||
|
// places no further constraints/transforms and just carries over the original
|
||||||
|
// value.
|
||||||
|
bool isUnspecified() const;
|
||||||
|
|
||||||
|
// Returns true if this DAG leaf is matching an operand. That is, it specifies
|
||||||
|
// a type constraint.
|
||||||
|
bool isOperandMatcher() const;
|
||||||
|
|
||||||
|
// Returns true if this DAG leaf is matching an attribute. That is, it
|
||||||
|
// specifies an attribute constraint.
|
||||||
|
bool isAttrMatcher() const;
|
||||||
|
|
||||||
|
// Returns true if this DAG leaf is wrapping native code call.
|
||||||
|
bool isNativeCodeCall() const;
|
||||||
|
|
||||||
|
// Returns true if this DAG leaf is specifying a constant attribute.
|
||||||
|
bool isConstantAttr() const;
|
||||||
|
|
||||||
|
// Returns true if this DAG leaf is specifying an enum attribute case.
|
||||||
|
bool isEnumAttrCase() const;
|
||||||
|
|
||||||
|
// Returns true if this DAG leaf is specifying a string attribute.
|
||||||
|
bool isStringAttr() const;
|
||||||
|
|
||||||
|
// Returns this DAG leaf as a constraint. Asserts if fails.
|
||||||
|
Constraint getAsConstraint() const;
|
||||||
|
|
||||||
|
// Returns this DAG leaf as an constant attribute. Asserts if fails.
|
||||||
|
ConstantAttr getAsConstantAttr() const;
|
||||||
|
|
||||||
|
// Returns this DAG leaf as an enum attribute case.
|
||||||
|
// Precondition: isEnumAttrCase()
|
||||||
|
EnumAttrCase getAsEnumAttrCase() const;
|
||||||
|
|
||||||
|
// Returns the matching condition template inside this DAG leaf. Assumes the
|
||||||
|
// leaf is an operand/attribute matcher and asserts otherwise.
|
||||||
|
std::string getConditionTemplate() const;
|
||||||
|
|
||||||
|
// Returns the native code call template inside this DAG leaf.
|
||||||
|
// Precondition: isNativeCodeCall()
|
||||||
|
StringRef getNativeCodeTemplate() const;
|
||||||
|
|
||||||
|
// Returns the string associated with the leaf.
|
||||||
|
// Precondition: isStringAttr()
|
||||||
|
std::string getStringAttr() const;
|
||||||
|
|
||||||
|
void print(raw_ostream &os) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
|
||||||
|
// also a subclass of the given `superclass`.
|
||||||
|
bool isSubClassOf(StringRef superclass) const;
|
||||||
|
|
||||||
|
const llvm::Init *def;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wrapper class providing helper methods for accessing TableGen DAG constructs
|
||||||
|
// used inside Patterns. This class is lightweight and designed to be used like
|
||||||
|
// values.
|
||||||
|
//
|
||||||
|
// A TableGen DAG construct is of the syntax
|
||||||
|
// `(operator, arg0, arg1, ...)`.
|
||||||
|
//
|
||||||
|
// When used inside Patterns, `operator` corresponds to some dialect op, or
|
||||||
|
// a known list of verbs that defines special transformation actions. This
|
||||||
|
// `arg*` can be a nested DAG construct. This class provides getters to
|
||||||
|
// retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
|
||||||
|
// methods.
|
||||||
|
//
|
||||||
|
// A null DagNode contains a nullptr and converts to false implicitly.
|
||||||
|
class DagNode {
|
||||||
|
public:
|
||||||
|
explicit DagNode(const llvm::DagInit *node) : node(node) {}
|
||||||
|
|
||||||
|
// Implicit bool converter that returns true if this DagNode is not a null
|
||||||
|
// DagNode.
|
||||||
|
operator bool() const { return node != nullptr; }
|
||||||
|
|
||||||
|
// Returns the symbol bound to this DAG node.
|
||||||
|
StringRef getSymbol() const;
|
||||||
|
|
||||||
|
// Returns the operator wrapper object corresponding to the dialect op matched
|
||||||
|
// by this DAG. The operator wrapper will be queried from the given `mapper`
|
||||||
|
// and created in it if not existing.
|
||||||
|
Operator &getDialectOp(RecordOperatorMap *mapper) const;
|
||||||
|
|
||||||
|
// Returns the number of operations recursively involved in the DAG tree
|
||||||
|
// rooted from this node.
|
||||||
|
int getNumOps() const;
|
||||||
|
|
||||||
|
// Returns the number of immediate arguments to this DAG node.
|
||||||
|
int getNumArgs() const;
|
||||||
|
|
||||||
|
// Returns true if the `index`-th argument is a nested DAG construct.
|
||||||
|
bool isNestedDagArg(unsigned index) const;
|
||||||
|
|
||||||
|
// Gets the `index`-th argument as a nested DAG construct if possible. Returns
|
||||||
|
// null DagNode otherwise.
|
||||||
|
DagNode getArgAsNestedDag(unsigned index) const;
|
||||||
|
|
||||||
|
// Gets the `index`-th argument as a DAG leaf.
|
||||||
|
DagLeaf getArgAsLeaf(unsigned index) const;
|
||||||
|
|
||||||
|
// Returns the specified name of the `index`-th argument.
|
||||||
|
StringRef getArgName(unsigned index) const;
|
||||||
|
|
||||||
|
// Returns true if this DAG construct means to replace with an existing SSA
|
||||||
|
// value.
|
||||||
|
bool isReplaceWithValue() const;
|
||||||
|
|
||||||
|
// Returns whether this DAG represents the location of an op creation.
|
||||||
|
bool isLocationDirective() const;
|
||||||
|
|
||||||
|
// Returns true if this DAG node is wrapping native code call.
|
||||||
|
bool isNativeCodeCall() const;
|
||||||
|
|
||||||
|
// Returns true if this DAG node is an operation.
|
||||||
|
bool isOperation() const;
|
||||||
|
|
||||||
|
// Returns the native code call template inside this DAG node.
|
||||||
|
// Precondition: isNativeCodeCall()
|
||||||
|
StringRef getNativeCodeTemplate() const;
|
||||||
|
|
||||||
|
void print(raw_ostream &os) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const llvm::DagInit *node; // nullptr means null DagNode
|
||||||
|
};
|
||||||
|
|
||||||
|
// A class for maintaining information for symbols bound in patterns and
|
||||||
|
// provides methods for resolving them according to specific use cases.
|
||||||
|
//
|
||||||
|
// Symbols can be bound to
|
||||||
|
//
|
||||||
|
// * Op arguments and op results in the source pattern and
|
||||||
|
// * Op results in result patterns.
|
||||||
|
//
|
||||||
|
// Symbols can be referenced in result patterns and additional constraints to
|
||||||
|
// the pattern.
|
||||||
|
//
|
||||||
|
// For example, in
|
||||||
|
//
|
||||||
|
// ```
|
||||||
|
// def : Pattern<
|
||||||
|
// (SrcOp:$results1 $arg0, %arg1),
|
||||||
|
// [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
|
||||||
|
// ```
|
||||||
|
//
|
||||||
|
// `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
|
||||||
|
// `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build
|
||||||
|
// `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
|
||||||
|
//
|
||||||
|
// If a symbol binds to a multi-result op and it does not have the `__N`
|
||||||
|
// suffix, the symbol is expanded to represent all results generated by the
|
||||||
|
// multi-result op. If the symbol has a `__N` suffix, then it will expand to
|
||||||
|
// only the N-th *static* result as declared in ODS, and that can still
|
||||||
|
// corresponds to multiple *dynamic* values if the N-th *static* result is
|
||||||
|
// variadic.
|
||||||
|
//
|
||||||
|
// This class keeps track of such symbols and resolves them into their bound
|
||||||
|
// values in a suitable way.
|
||||||
|
class SymbolInfoMap {
|
||||||
|
public:
|
||||||
|
explicit SymbolInfoMap(ArrayRef<llvm::SMLoc> loc) : loc(loc) {}
|
||||||
|
|
||||||
|
// Class for information regarding a symbol.
|
||||||
|
class SymbolInfo {
|
||||||
|
public:
|
||||||
|
// Returns a string for defining a variable named as `name` to store the
|
||||||
|
// value bound by this symbol.
|
||||||
|
std::string getVarDecl(StringRef name) const;
|
||||||
|
|
||||||
|
// Returns a variable name for the symbol named as `name`.
|
||||||
|
std::string getVarName(StringRef name) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Allow SymbolInfoMap to access private methods.
|
||||||
|
friend class SymbolInfoMap;
|
||||||
|
|
||||||
|
// What kind of entity this symbol represents:
|
||||||
|
// * Attr: op attribute
|
||||||
|
// * Operand: op operand
|
||||||
|
// * Result: op result
|
||||||
|
// * Value: a value not attached to an op (e.g., from NativeCodeCall)
|
||||||
|
enum class Kind : uint8_t { Attr, Operand, Result, Value };
|
||||||
|
|
||||||
|
// Creates a SymbolInfo instance. `index` is only used for `Attr` and
|
||||||
|
// `Operand` so should be negative for `Result` and `Value` kind.
|
||||||
|
SymbolInfo(const Operator *op, Kind kind, Optional<int> index);
|
||||||
|
|
||||||
|
// Static methods for creating SymbolInfo.
|
||||||
|
static SymbolInfo getAttr(const Operator *op, int index) {
|
||||||
|
return SymbolInfo(op, Kind::Attr, index);
|
||||||
|
}
|
||||||
|
static SymbolInfo getAttr() {
|
||||||
|
return SymbolInfo(nullptr, Kind::Attr, llvm::None);
|
||||||
|
}
|
||||||
|
static SymbolInfo getOperand(const Operator *op, int index) {
|
||||||
|
return SymbolInfo(op, Kind::Operand, index);
|
||||||
|
}
|
||||||
|
static SymbolInfo getResult(const Operator *op) {
|
||||||
|
return SymbolInfo(op, Kind::Result, llvm::None);
|
||||||
|
}
|
||||||
|
static SymbolInfo getValue() {
|
||||||
|
return SymbolInfo(nullptr, Kind::Value, llvm::None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the number of static values this symbol corresponds to.
|
||||||
|
// A static value is an operand/result declared in ODS. Normally a symbol
|
||||||
|
// only represents one static value, but symbols bound to op results can
|
||||||
|
// represent more than one if the op is a multi-result op.
|
||||||
|
int getStaticValueCount() const;
|
||||||
|
|
||||||
|
// Returns a string containing the C++ expression for referencing this
|
||||||
|
// symbol as a value (if this symbol represents one static value) or a value
|
||||||
|
// range (if this symbol represents multiple static values). `name` is the
|
||||||
|
// name of the C++ variable that this symbol bounds to. `index` should only
|
||||||
|
// be used for indexing results. `fmt` is used to format each value.
|
||||||
|
// `separator` is used to separate values if this is a value range.
|
||||||
|
std::string getValueAndRangeUse(StringRef name, int index, const char *fmt,
|
||||||
|
const char *separator) const;
|
||||||
|
|
||||||
|
// Returns a string containing the C++ expression for referencing this
|
||||||
|
// symbol as a value range regardless of how many static values this symbol
|
||||||
|
// represents. `name` is the name of the C++ variable that this symbol
|
||||||
|
// bounds to. `index` should only be used for indexing results. `fmt` is
|
||||||
|
// used to format each value. `separator` is used to separate values in the
|
||||||
|
// range.
|
||||||
|
std::string getAllRangeUse(StringRef name, int index, const char *fmt,
|
||||||
|
const char *separator) const;
|
||||||
|
|
||||||
|
const Operator *op; // The op where the bound entity belongs
|
||||||
|
Kind kind; // The kind of the bound entity
|
||||||
|
// The argument index (for `Attr` and `Operand` only)
|
||||||
|
Optional<int> argIndex;
|
||||||
|
// Alternative name for the symbol. It is used in case the name
|
||||||
|
// is not unique. Applicable for `Operand` only.
|
||||||
|
Optional<std::string> alternativeName;
|
||||||
|
};
|
||||||
|
|
||||||
|
using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
|
||||||
|
|
||||||
|
// Iterators for accessing all symbols.
|
||||||
|
using iterator = BaseT::iterator;
|
||||||
|
iterator begin() { return symbolInfoMap.begin(); }
|
||||||
|
iterator end() { return symbolInfoMap.end(); }
|
||||||
|
|
||||||
|
// Const iterators for accessing all symbols.
|
||||||
|
using const_iterator = BaseT::const_iterator;
|
||||||
|
const_iterator begin() const { return symbolInfoMap.begin(); }
|
||||||
|
const_iterator end() const { return symbolInfoMap.end(); }
|
||||||
|
|
||||||
|
// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
|
||||||
|
// Returns false if `symbol` is already bound and symbols are not operands.
|
||||||
|
bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
|
||||||
|
|
||||||
|
// Binds the given `symbol` to the results the given `op`. Returns false if
|
||||||
|
// `symbol` is already bound.
|
||||||
|
bool bindOpResult(StringRef symbol, const Operator &op);
|
||||||
|
|
||||||
|
// Registers the given `symbol` as bound to a value. Returns false if `symbol`
|
||||||
|
// is already bound.
|
||||||
|
bool bindValue(StringRef symbol);
|
||||||
|
|
||||||
|
// Registers the given `symbol` as bound to an attr. Returns false if `symbol`
|
||||||
|
// is already bound.
|
||||||
|
bool bindAttr(StringRef symbol);
|
||||||
|
|
||||||
|
// Returns true if the given `symbol` is bound.
|
||||||
|
bool contains(StringRef symbol) const;
|
||||||
|
|
||||||
|
// Returns an iterator to the information of the given symbol named as `key`.
|
||||||
|
const_iterator find(StringRef key) const;
|
||||||
|
|
||||||
|
// Returns an iterator to the information of the given symbol named as `key`,
|
||||||
|
// with index `argIndex` for operator `op`.
|
||||||
|
const_iterator findBoundSymbol(StringRef key, const Operator &op,
|
||||||
|
int argIndex) const;
|
||||||
|
|
||||||
|
// Returns the bounds of a range that includes all the elements which
|
||||||
|
// bind to the `key`.
|
||||||
|
std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
|
||||||
|
|
||||||
|
// Returns number of times symbol named as `key` was used.
|
||||||
|
int count(StringRef key) const;
|
||||||
|
|
||||||
|
// Returns the number of static values of the given `symbol` corresponds to.
|
||||||
|
// A static value is an operand/result declared in ODS. Normally a symbol only
|
||||||
|
// represents one static value, but symbols bound to op results can represent
|
||||||
|
// more than one if the op is a multi-result op.
|
||||||
|
int getStaticValueCount(StringRef symbol) const;
|
||||||
|
|
||||||
|
// Returns a string containing the C++ expression for referencing this
|
||||||
|
// symbol as a value (if this symbol represents one static value) or a value
|
||||||
|
// range (if this symbol represents multiple static values). `fmt` is used to
|
||||||
|
// format each value. `separator` is used to separate values if `symbol`
|
||||||
|
// represents a value range.
|
||||||
|
std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}",
|
||||||
|
const char *separator = ", ") const;
|
||||||
|
|
||||||
|
// Returns a string containing the C++ expression for referencing this
|
||||||
|
// symbol as a value range regardless of how many static values this symbol
|
||||||
|
// represents. `fmt` is used to format each value. `separator` is used to
|
||||||
|
// separate values in the range.
|
||||||
|
std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
|
||||||
|
const char *separator = ", ") const;
|
||||||
|
|
||||||
|
// Assign alternative unique names to Operands that have equal names.
|
||||||
|
void assignUniqueAlternativeNames();
|
||||||
|
|
||||||
|
// Splits the given `symbol` into a value pack name and an index. Returns the
|
||||||
|
// value pack name and writes the index to `index` on success. Returns
|
||||||
|
// `symbol` itself if it does not contain an index.
|
||||||
|
//
|
||||||
|
// We can use `name__N` to access the `N`-th value in the value pack bound to
|
||||||
|
// `name`. `name` is typically the results of an multi-result op.
|
||||||
|
static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
BaseT symbolInfoMap;
|
||||||
|
|
||||||
|
// Pattern instantiation location. This is intended to be used as parameter
|
||||||
|
// to PrintFatalError() to report errors.
|
||||||
|
ArrayRef<llvm::SMLoc> loc;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wrapper class providing helper methods for accessing MLIR Pattern defined
|
||||||
|
// in TableGen. This class should closely reflect what is defined as class
|
||||||
|
// `Pattern` in TableGen. This class contains maps so it is not intended to be
|
||||||
|
// used as values.
|
||||||
|
class Pattern {
|
||||||
|
public:
|
||||||
|
explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
|
||||||
|
|
||||||
|
// Returns the source pattern to match.
|
||||||
|
DagNode getSourcePattern() const;
|
||||||
|
|
||||||
|
// Returns the number of result patterns generated by applying this rewrite
|
||||||
|
// rule.
|
||||||
|
int getNumResultPatterns() const;
|
||||||
|
|
||||||
|
// Returns the DAG tree root node of the `index`-th result pattern.
|
||||||
|
DagNode getResultPattern(unsigned index) const;
|
||||||
|
|
||||||
|
// Collects all symbols bound in the source pattern into `infoMap`.
|
||||||
|
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
|
||||||
|
|
||||||
|
// Collects all symbols bound in result patterns into `infoMap`.
|
||||||
|
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
|
||||||
|
|
||||||
|
// Returns the op that the root node of the source pattern matches.
|
||||||
|
const Operator &getSourceRootOp();
|
||||||
|
|
||||||
|
// Returns the operator wrapper object corresponding to the given `node`'s DAG
|
||||||
|
// operator.
|
||||||
|
Operator &getDialectOp(DagNode node);
|
||||||
|
|
||||||
|
// Returns the constraints.
|
||||||
|
std::vector<AppliedConstraint> getConstraints() const;
|
||||||
|
|
||||||
|
// Returns the benefit score of the pattern.
|
||||||
|
int getBenefit() const;
|
||||||
|
|
||||||
|
using IdentifierLine = std::pair<StringRef, unsigned>;
|
||||||
|
|
||||||
|
// Returns the file location of the pattern (buffer identifier + line number
|
||||||
|
// pair).
|
||||||
|
std::vector<IdentifierLine> getLocation() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Helper function to verify variabld binding.
|
||||||
|
void verifyBind(bool result, StringRef symbolName);
|
||||||
|
|
||||||
|
// Recursively collects all bound symbols inside the DAG tree rooted
|
||||||
|
// at `tree` and updates the given `infoMap`.
|
||||||
|
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||||
|
bool isSrcPattern);
|
||||||
|
|
||||||
|
// The TableGen definition of this pattern.
|
||||||
|
const llvm::Record &def;
|
||||||
|
|
||||||
|
// All operators.
|
||||||
|
// TODO: we need a proper context manager, like MLIRContext, for managing the
|
||||||
|
// lifetime of shared entities.
|
||||||
|
RecordOperatorMap *recordOpMap;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_PATTERN_H_
|
|
@ -0,0 +1,376 @@
|
||||||
|
//===- Predicate.cpp - Predicate class ------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Wrapper around predicates defined in TableGen.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Predicate.h"
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/TableGen/Error.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace tblgen;
|
||||||
|
|
||||||
|
// Construct a Predicate from a record.
|
||||||
|
Pred::Pred(const llvm::Record *record) : def(record) {
|
||||||
|
assert(def->isSubClassOf("Pred") &&
|
||||||
|
"must be a subclass of TableGen 'Pred' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct a Predicate from an initializer.
|
||||||
|
Pred::Pred(const llvm::Init *init) : def(nullptr) {
|
||||||
|
if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
|
||||||
|
def = defInit->getDef();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Pred::getCondition() const {
|
||||||
|
// Static dispatch to subclasses.
|
||||||
|
if (def->isSubClassOf("CombinedPred"))
|
||||||
|
return static_cast<const CombinedPred *>(this)->getConditionImpl();
|
||||||
|
if (def->isSubClassOf("CPred"))
|
||||||
|
return static_cast<const CPred *>(this)->getConditionImpl();
|
||||||
|
llvm_unreachable("Pred::getCondition must be overridden in subclasses");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Pred::isCombined() const {
|
||||||
|
return def && def->isSubClassOf("CombinedPred");
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<llvm::SMLoc> Pred::getLoc() const { return def->getLoc(); }
|
||||||
|
|
||||||
|
CPred::CPred(const llvm::Record *record) : Pred(record) {
|
||||||
|
assert(def->isSubClassOf("CPred") &&
|
||||||
|
"must be a subclass of Tablegen 'CPred' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
CPred::CPred(const llvm::Init *init) : Pred(init) {
|
||||||
|
assert((!def || def->isSubClassOf("CPred")) &&
|
||||||
|
"must be a subclass of Tablegen 'CPred' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get condition of the C Predicate.
|
||||||
|
std::string CPred::getConditionImpl() const {
|
||||||
|
assert(!isNull() && "null predicate does not have a condition");
|
||||||
|
return std::string(def->getValueAsString("predExpr"));
|
||||||
|
}
|
||||||
|
|
||||||
|
CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
|
||||||
|
assert(def->isSubClassOf("CombinedPred") &&
|
||||||
|
"must be a subclass of Tablegen 'CombinedPred' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
|
||||||
|
assert((!def || def->isSubClassOf("CombinedPred")) &&
|
||||||
|
"must be a subclass of Tablegen 'CombinedPred' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
const llvm::Record *CombinedPred::getCombinerDef() const {
|
||||||
|
assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
|
||||||
|
return def->getValueAsDef("kind");
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<llvm::Record *> CombinedPred::getChildren() const {
|
||||||
|
assert(def->getValue("children") &&
|
||||||
|
"CombinedPred must have a value 'children'");
|
||||||
|
return def->getValueAsListOfDefs("children");
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Kinds of nodes in a logical predicate tree.
|
||||||
|
enum class PredCombinerKind {
|
||||||
|
Leaf,
|
||||||
|
And,
|
||||||
|
Or,
|
||||||
|
Not,
|
||||||
|
SubstLeaves,
|
||||||
|
Concat,
|
||||||
|
// Special kinds that are used in simplification.
|
||||||
|
False,
|
||||||
|
True
|
||||||
|
};
|
||||||
|
|
||||||
|
// A node in a logical predicate tree.
|
||||||
|
struct PredNode {
|
||||||
|
PredCombinerKind kind;
|
||||||
|
const Pred *predicate;
|
||||||
|
SmallVector<PredNode *, 4> children;
|
||||||
|
std::string expr;
|
||||||
|
|
||||||
|
// Prefix and suffix are used by ConcatPred.
|
||||||
|
std::string prefix;
|
||||||
|
std::string suffix;
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
// Get a predicate tree node kind based on the kind used in the predicate
|
||||||
|
// TableGen record.
|
||||||
|
static PredCombinerKind getPredCombinerKind(const Pred &pred) {
|
||||||
|
if (!pred.isCombined())
|
||||||
|
return PredCombinerKind::Leaf;
|
||||||
|
|
||||||
|
const auto &combinedPred = static_cast<const CombinedPred &>(pred);
|
||||||
|
return StringSwitch<PredCombinerKind>(
|
||||||
|
combinedPred.getCombinerDef()->getName())
|
||||||
|
.Case("PredCombinerAnd", PredCombinerKind::And)
|
||||||
|
.Case("PredCombinerOr", PredCombinerKind::Or)
|
||||||
|
.Case("PredCombinerNot", PredCombinerKind::Not)
|
||||||
|
.Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
|
||||||
|
.Case("PredCombinerConcat", PredCombinerKind::Concat);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Substitution<pattern, replacement>.
|
||||||
|
using Subst = std::pair<StringRef, StringRef>;
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
/// Perform the given substitutions on 'str' in-place.
|
||||||
|
static void performSubstitutions(std::string &str,
|
||||||
|
ArrayRef<Subst> substitutions) {
|
||||||
|
// Apply all parent substitutions from innermost to outermost.
|
||||||
|
for (const auto &subst : llvm::reverse(substitutions)) {
|
||||||
|
auto pos = str.find(std::string(subst.first));
|
||||||
|
while (pos != std::string::npos) {
|
||||||
|
str.replace(pos, subst.first.size(), std::string(subst.second));
|
||||||
|
// Skip the newly inserted substring, which itself may consider the
|
||||||
|
// pattern to match.
|
||||||
|
pos += subst.second.size();
|
||||||
|
// Find the next possible match position.
|
||||||
|
pos = str.find(std::string(subst.first), pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the predicate tree starting from the top-level predicate, which may
|
||||||
|
// have children, and perform leaf substitutions inplace. Note that after
|
||||||
|
// substitution, nodes are still pointing to the original TableGen record.
|
||||||
|
// All nodes are created within "allocator".
|
||||||
|
static PredNode *
|
||||||
|
buildPredicateTree(const Pred &root,
|
||||||
|
llvm::SpecificBumpPtrAllocator<PredNode> &allocator,
|
||||||
|
ArrayRef<Subst> substitutions) {
|
||||||
|
auto *rootNode = allocator.Allocate();
|
||||||
|
new (rootNode) PredNode;
|
||||||
|
rootNode->kind = getPredCombinerKind(root);
|
||||||
|
rootNode->predicate = &root;
|
||||||
|
if (!root.isCombined()) {
|
||||||
|
rootNode->expr = root.getCondition();
|
||||||
|
performSubstitutions(rootNode->expr, substitutions);
|
||||||
|
return rootNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the current combined predicate is a leaf substitution, append it to the
|
||||||
|
// list before continuing.
|
||||||
|
auto allSubstitutions = llvm::to_vector<4>(substitutions);
|
||||||
|
if (rootNode->kind == PredCombinerKind::SubstLeaves) {
|
||||||
|
const auto &substPred = static_cast<const SubstLeavesPred &>(root);
|
||||||
|
allSubstitutions.push_back(
|
||||||
|
{substPred.getPattern(), substPred.getReplacement()});
|
||||||
|
|
||||||
|
// If the current predicate is a ConcatPred, record the prefix and suffix.
|
||||||
|
} else if (rootNode->kind == PredCombinerKind::Concat) {
|
||||||
|
const auto &concatPred = static_cast<const ConcatPred &>(root);
|
||||||
|
rootNode->prefix = std::string(concatPred.getPrefix());
|
||||||
|
performSubstitutions(rootNode->prefix, substitutions);
|
||||||
|
rootNode->suffix = std::string(concatPred.getSuffix());
|
||||||
|
performSubstitutions(rootNode->suffix, substitutions);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build child subtrees.
|
||||||
|
auto combined = static_cast<const CombinedPred &>(root);
|
||||||
|
for (const auto *record : combined.getChildren()) {
|
||||||
|
auto childTree =
|
||||||
|
buildPredicateTree(Pred(record), allocator, allSubstitutions);
|
||||||
|
rootNode->children.push_back(childTree);
|
||||||
|
}
|
||||||
|
return rootNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simplify a predicate tree rooted at "node" using the predicates that are
|
||||||
|
// known to be true(false). For AND(OR) combined predicates, if any of the
|
||||||
|
// children is known to be false(true), the result is also false(true).
|
||||||
|
// Furthermore, for AND(OR) combined predicates, children that are known to be
|
||||||
|
// true(false) don't have to be checked dynamically.
|
||||||
|
static PredNode *
|
||||||
|
propagateGroundTruth(PredNode *node,
|
||||||
|
const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
|
||||||
|
const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
|
||||||
|
// If the current predicate is known to be true or false, change the kind of
|
||||||
|
// the node and return immediately.
|
||||||
|
if (knownTruePreds.count(node->predicate) != 0) {
|
||||||
|
node->kind = PredCombinerKind::True;
|
||||||
|
node->children.clear();
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
if (knownFalsePreds.count(node->predicate) != 0) {
|
||||||
|
node->kind = PredCombinerKind::False;
|
||||||
|
node->children.clear();
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the current node is a substitution, stop recursion now.
|
||||||
|
// The expressions in the leaves below this node were rewritten, but the nodes
|
||||||
|
// still point to the original predicate records. While the original
|
||||||
|
// predicate may be known to be true or false, it is not necessarily the case
|
||||||
|
// after rewriting.
|
||||||
|
// TODO: we can support ground truth for rewritten
|
||||||
|
// predicates by either (a) having our own unique'ing of the predicates
|
||||||
|
// instead of relying on TableGen record pointers or (b) taking ground truth
|
||||||
|
// values optionally prefixed with a list of substitutions to apply, e.g.
|
||||||
|
// "predX is true by itself as well as predSubY leaf substitution had been
|
||||||
|
// applied to it".
|
||||||
|
if (node->kind == PredCombinerKind::SubstLeaves) {
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, look at child nodes.
|
||||||
|
|
||||||
|
// Move child nodes into some local variable so that they can be optimized
|
||||||
|
// separately and re-added if necessary.
|
||||||
|
llvm::SmallVector<PredNode *, 4> children;
|
||||||
|
std::swap(node->children, children);
|
||||||
|
|
||||||
|
for (auto &child : children) {
|
||||||
|
// First, simplify the child. This maintains the predicate as it was.
|
||||||
|
auto simplifiedChild =
|
||||||
|
propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
|
||||||
|
|
||||||
|
// Just add the child if we don't know how to simplify the current node.
|
||||||
|
if (node->kind != PredCombinerKind::And &&
|
||||||
|
node->kind != PredCombinerKind::Or) {
|
||||||
|
node->children.push_back(simplifiedChild);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second, based on the type define which known values of child predicates
|
||||||
|
// immediately collapse this predicate to a known value, and which others
|
||||||
|
// may be safely ignored.
|
||||||
|
// OR(..., True, ...) = True
|
||||||
|
// OR(..., False, ...) = OR(..., ...)
|
||||||
|
// AND(..., False, ...) = False
|
||||||
|
// AND(..., True, ...) = AND(..., ...)
|
||||||
|
auto collapseKind = node->kind == PredCombinerKind::And
|
||||||
|
? PredCombinerKind::False
|
||||||
|
: PredCombinerKind::True;
|
||||||
|
auto eraseKind = node->kind == PredCombinerKind::And
|
||||||
|
? PredCombinerKind::True
|
||||||
|
: PredCombinerKind::False;
|
||||||
|
const auto &collapseList =
|
||||||
|
node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
|
||||||
|
const auto &eraseList =
|
||||||
|
node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
|
||||||
|
if (simplifiedChild->kind == collapseKind ||
|
||||||
|
collapseList.count(simplifiedChild->predicate) != 0) {
|
||||||
|
node->kind = collapseKind;
|
||||||
|
node->children.clear();
|
||||||
|
return node;
|
||||||
|
} else if (simplifiedChild->kind == eraseKind ||
|
||||||
|
eraseList.count(simplifiedChild->predicate) != 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
node->children.push_back(simplifiedChild);
|
||||||
|
}
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine a list of predicate expressions using a binary combiner. If a list
|
||||||
|
// is empty, return "init".
|
||||||
|
static std::string combineBinary(ArrayRef<std::string> children,
|
||||||
|
std::string combiner, std::string init) {
|
||||||
|
if (children.empty())
|
||||||
|
return init;
|
||||||
|
|
||||||
|
auto size = children.size();
|
||||||
|
if (size == 1)
|
||||||
|
return children.front();
|
||||||
|
|
||||||
|
std::string str;
|
||||||
|
llvm::raw_string_ostream os(str);
|
||||||
|
os << '(' << children.front() << ')';
|
||||||
|
for (unsigned i = 1; i < size; ++i) {
|
||||||
|
os << ' ' << combiner << " (" << children[i] << ')';
|
||||||
|
}
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepend negation to the only condition in the predicate expression list.
|
||||||
|
static std::string combineNot(ArrayRef<std::string> children) {
|
||||||
|
assert(children.size() == 1 && "expected exactly one child predicate of Neg");
|
||||||
|
return (Twine("!(") + children.front() + Twine(')')).str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively traverse the predicate tree in depth-first post-order and build
|
||||||
|
// the final expression.
|
||||||
|
static std::string getCombinedCondition(const PredNode &root) {
|
||||||
|
// Immediately return for non-combiner predicates that don't have children.
|
||||||
|
if (root.kind == PredCombinerKind::Leaf)
|
||||||
|
return root.expr;
|
||||||
|
if (root.kind == PredCombinerKind::True)
|
||||||
|
return "true";
|
||||||
|
if (root.kind == PredCombinerKind::False)
|
||||||
|
return "false";
|
||||||
|
|
||||||
|
// Recurse into children.
|
||||||
|
llvm::SmallVector<std::string, 4> childExpressions;
|
||||||
|
childExpressions.reserve(root.children.size());
|
||||||
|
for (const auto &child : root.children)
|
||||||
|
childExpressions.push_back(getCombinedCondition(*child));
|
||||||
|
|
||||||
|
// Combine the expressions based on the predicate node kind.
|
||||||
|
if (root.kind == PredCombinerKind::And)
|
||||||
|
return combineBinary(childExpressions, "&&", "true");
|
||||||
|
if (root.kind == PredCombinerKind::Or)
|
||||||
|
return combineBinary(childExpressions, "||", "false");
|
||||||
|
if (root.kind == PredCombinerKind::Not)
|
||||||
|
return combineNot(childExpressions);
|
||||||
|
if (root.kind == PredCombinerKind::Concat) {
|
||||||
|
assert(childExpressions.size() == 1 &&
|
||||||
|
"ConcatPred should only have one child");
|
||||||
|
return root.prefix + childExpressions.front() + root.suffix;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Substitutions were applied before so just ignore them.
|
||||||
|
if (root.kind == PredCombinerKind::SubstLeaves) {
|
||||||
|
assert(childExpressions.size() == 1 &&
|
||||||
|
"substitution predicate must have one child");
|
||||||
|
return childExpressions[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string CombinedPred::getConditionImpl() const {
|
||||||
|
llvm::SpecificBumpPtrAllocator<PredNode> allocator;
|
||||||
|
auto predicateTree = buildPredicateTree(*this, allocator, {});
|
||||||
|
predicateTree =
|
||||||
|
propagateGroundTruth(predicateTree,
|
||||||
|
/*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
|
||||||
|
/*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
|
||||||
|
|
||||||
|
return getCombinedCondition(*predicateTree);
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef SubstLeavesPred::getPattern() const {
|
||||||
|
return def->getValueAsString("pattern");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef SubstLeavesPred::getReplacement() const {
|
||||||
|
return def->getValueAsString("replacement");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef ConcatPred::getPrefix() const {
|
||||||
|
return def->getValueAsString("prefix");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef ConcatPred::getSuffix() const {
|
||||||
|
return def->getValueAsString("suffix");
|
||||||
|
}
|
|
@ -0,0 +1,119 @@
|
||||||
|
//===- Predicate.h - Predicate class ----------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Wrapper around predicates defined in TableGen.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_PREDICATE_H_
|
||||||
|
#define MLIR_TABLEGEN_PREDICATE_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class Init;
|
||||||
|
class ListInit;
|
||||||
|
class Record;
|
||||||
|
class SMLoc;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
// A logical predicate. This class must closely follow the definition of
|
||||||
|
// TableGen class 'Pred'.
|
||||||
|
class Pred {
|
||||||
|
public:
|
||||||
|
// Constructs the null Predicate (e.g., always true).
|
||||||
|
explicit Pred() : def(nullptr) {}
|
||||||
|
// Construct a Predicate from a record.
|
||||||
|
explicit Pred(const llvm::Record *record);
|
||||||
|
// Construct a Predicate from an initializer.
|
||||||
|
explicit Pred(const llvm::Init *init);
|
||||||
|
|
||||||
|
// Check if the predicate is defined. Callers may use this to interpret the
|
||||||
|
// missing predicate as either true (e.g. in filters) or false (e.g. in
|
||||||
|
// precondition verification).
|
||||||
|
bool isNull() const { return def == nullptr; }
|
||||||
|
|
||||||
|
// Get the predicate condition. This may dispatch to getConditionImpl() of
|
||||||
|
// the underlying predicate type.
|
||||||
|
std::string getCondition() const;
|
||||||
|
|
||||||
|
// Whether the predicate is a combination of other predicates, i.e. an
|
||||||
|
// record of type CombinedPred.
|
||||||
|
bool isCombined() const;
|
||||||
|
|
||||||
|
// Records are pointer-comparable.
|
||||||
|
bool operator==(const Pred &other) const { return def == other.def; }
|
||||||
|
|
||||||
|
// Get the location of the predicate.
|
||||||
|
ArrayRef<llvm::SMLoc> getLoc() const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// The TableGen definition of this predicate.
|
||||||
|
const llvm::Record *def;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A logical predicate wrapping a C expression. This class must closely follow
|
||||||
|
// the definition of TableGen class 'CPred'.
|
||||||
|
class CPred : public Pred {
|
||||||
|
public:
|
||||||
|
// Construct a CPred from a record.
|
||||||
|
explicit CPred(const llvm::Record *record);
|
||||||
|
// Construct a CPred an initializer.
|
||||||
|
explicit CPred(const llvm::Init *init);
|
||||||
|
|
||||||
|
// Get the predicate condition.
|
||||||
|
std::string getConditionImpl() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A logical predicate that is a combination of other predicates. This class
|
||||||
|
// must closely follow the definition of TableGen class 'CombinedPred'.
|
||||||
|
class CombinedPred : public Pred {
|
||||||
|
public:
|
||||||
|
// Construct a CombinedPred from a record.
|
||||||
|
explicit CombinedPred(const llvm::Record *record);
|
||||||
|
// Construct a CombinedPred from an initializer.
|
||||||
|
explicit CombinedPred(const llvm::Init *init);
|
||||||
|
|
||||||
|
// Get the predicate condition.
|
||||||
|
std::string getConditionImpl() const;
|
||||||
|
|
||||||
|
// Get the definition of the combiner used in this predicate.
|
||||||
|
const llvm::Record *getCombinerDef() const;
|
||||||
|
|
||||||
|
// Get the predicates that are combined by this predicate.
|
||||||
|
const std::vector<llvm::Record *> getChildren() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A combined predicate that requires all child predicates of 'CPred' type to
|
||||||
|
// have their expression rewritten with a simple string substitution rule.
|
||||||
|
class SubstLeavesPred : public CombinedPred {
|
||||||
|
public:
|
||||||
|
// Get the replacement pattern.
|
||||||
|
StringRef getPattern() const;
|
||||||
|
// Get the string used to replace the pattern.
|
||||||
|
StringRef getReplacement() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A combined predicate that prepends a prefix and appends a suffix to the
|
||||||
|
// predicate string composed from a child predicate.
|
||||||
|
class ConcatPred : public CombinedPred {
|
||||||
|
public:
|
||||||
|
StringRef getPrefix() const;
|
||||||
|
StringRef getSuffix() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_PREDICATE_H_
|
|
@ -0,0 +1,20 @@
|
||||||
|
//===- Region.cpp - Region class ------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Region wrapper to simplify using TableGen Record defining a MLIR Region.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Region.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
// Returns true if this region is variadic.
|
||||||
|
bool Region::isVariadic() const { return def->isSubClassOf("VariadicRegion"); }
|
|
@ -0,0 +1,42 @@
|
||||||
|
//===- TGRegion.h - TableGen region definitions -----------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_REGION_H_
|
||||||
|
#define MLIR_TABLEGEN_REGION_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "Constraint.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
// Wrapper class providing helper methods for accessing Region defined in
|
||||||
|
// TableGen.
|
||||||
|
class Region : public Constraint {
|
||||||
|
public:
|
||||||
|
using Constraint::Constraint;
|
||||||
|
|
||||||
|
static bool classof(const Constraint *c) { return c->getKind() == CK_Region; }
|
||||||
|
|
||||||
|
// Returns true if this region is variadic.
|
||||||
|
bool isVariadic() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A struct bundling a region's constraint and its name.
|
||||||
|
struct NamedRegion {
|
||||||
|
// Returns true if this region is variadic.
|
||||||
|
bool isVariadic() const { return constraint.isVariadic(); }
|
||||||
|
|
||||||
|
StringRef name;
|
||||||
|
Region constraint;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_REGION_H_
|
|
@ -0,0 +1,58 @@
|
||||||
|
//===- SideEffects.cpp - SideEffect classes -------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "SideEffects.h"
|
||||||
|
#include "llvm/ADT/Twine.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// SideEffect
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
StringRef SideEffect::getName() const {
|
||||||
|
return def->getValueAsString("effect");
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef SideEffect::getBaseEffectName() const {
|
||||||
|
return def->getValueAsString("baseEffectName");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SideEffect::getInterfaceTrait() const {
|
||||||
|
StringRef trait = def->getValueAsString("interfaceTrait");
|
||||||
|
StringRef cppNamespace = def->getValueAsString("cppNamespace");
|
||||||
|
return cppNamespace.empty() ? trait.str()
|
||||||
|
: (cppNamespace + "::" + trait).str();
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef SideEffect::getResource() const {
|
||||||
|
return def->getValueAsString("resource");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SideEffect::classof(const Operator::VariableDecorator *var) {
|
||||||
|
return var->getDef().isSubClassOf("SideEffect");
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// SideEffectsTrait
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Operator::var_decorator_range SideEffectTrait::getEffects() const {
|
||||||
|
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("effects"));
|
||||||
|
return {listInit->begin(), listInit->end()};
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef SideEffectTrait::getBaseEffectName() const {
|
||||||
|
return def->getValueAsString("baseEffectName");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SideEffectTrait::classof(const Trait *t) {
|
||||||
|
return t->getDef().isSubClassOf("SideEffectsTraitBase");
|
||||||
|
}
|
|
@ -0,0 +1,58 @@
|
||||||
|
//===- SideEffects.h - Side Effects classes ---------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Wrapper around side effect related classes defined in TableGen.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_SIDEEFFECTS_H_
|
||||||
|
#define MLIR_TABLEGEN_SIDEEFFECTS_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "Operator.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
// This class represents a specific instance of an effect that is being
|
||||||
|
// exhibited.
|
||||||
|
class SideEffect : public Operator::VariableDecorator {
|
||||||
|
public:
|
||||||
|
// Return the name of the C++ effect.
|
||||||
|
StringRef getName() const;
|
||||||
|
|
||||||
|
// Return the name of the base C++ effect.
|
||||||
|
StringRef getBaseEffectName() const;
|
||||||
|
|
||||||
|
// Return the name of the Interface that the effect belongs to.
|
||||||
|
std::string getInterfaceTrait() const;
|
||||||
|
|
||||||
|
// Return the name of the resource class.
|
||||||
|
StringRef getResource() const;
|
||||||
|
|
||||||
|
static bool classof(const Operator::VariableDecorator *var);
|
||||||
|
};
|
||||||
|
|
||||||
|
// This class represents an instance of a side effect interface applied to an
|
||||||
|
// operation. This is a wrapper around an OpInterfaceTrait that also includes
|
||||||
|
// the effects that are applied.
|
||||||
|
class SideEffectTrait : public InterfaceTrait {
|
||||||
|
public:
|
||||||
|
// Return the effects that are attached to the side effect interface.
|
||||||
|
Operator::var_decorator_range getEffects() const;
|
||||||
|
|
||||||
|
// Return the name of the base C++ effect.
|
||||||
|
StringRef getBaseEffectName() const;
|
||||||
|
|
||||||
|
static bool classof(const Trait *t);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_SIDEEFFECTS_H_
|
|
@ -0,0 +1,24 @@
|
||||||
|
//===- Successor.cpp - Successor class ------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Successor wrapper to simplify using TableGen Record defining a MLIR
|
||||||
|
// Successor.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Successor.h"
|
||||||
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
// Returns true if this successor is variadic.
|
||||||
|
bool Successor::isVariadic() const {
|
||||||
|
return def->isSubClassOf("VariadicSuccessor");
|
||||||
|
}
|
|
@ -0,0 +1,44 @@
|
||||||
|
//===- Successor.h - TableGen successor definitions -------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_SUCCESSOR_H_
|
||||||
|
#define MLIR_TABLEGEN_SUCCESSOR_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "Constraint.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
// Wrapper class providing helper methods for accessing Successor defined in
|
||||||
|
// TableGen.
|
||||||
|
class Successor : public Constraint {
|
||||||
|
public:
|
||||||
|
using Constraint::Constraint;
|
||||||
|
|
||||||
|
static bool classof(const Constraint *c) {
|
||||||
|
return c->getKind() == CK_Successor;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if this successor is variadic.
|
||||||
|
bool isVariadic() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A struct bundling a successor's constraint and its name.
|
||||||
|
struct NamedSuccessor {
|
||||||
|
// Returns true if this successor is variadic.
|
||||||
|
bool isVariadic() const { return constraint.isVariadic(); }
|
||||||
|
|
||||||
|
StringRef name;
|
||||||
|
Successor constraint;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_SUCCESSOR_H_
|
|
@ -0,0 +1,93 @@
|
||||||
|
//===- Trait.cpp ----------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Trait wrapper to simplify using TableGen Record defining a MLIR Trait.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Trait.h"
|
||||||
|
#include "Interfaces.h"
|
||||||
|
#include "Predicate.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/TableGen/Error.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Trait
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Trait Trait::create(const llvm::Init *init) {
|
||||||
|
auto def = cast<llvm::DefInit>(init)->getDef();
|
||||||
|
if (def->isSubClassOf("PredTrait"))
|
||||||
|
return Trait(Kind::Pred, def);
|
||||||
|
if (def->isSubClassOf("GenInternalTrait"))
|
||||||
|
return Trait(Kind::Internal, def);
|
||||||
|
if (def->isSubClassOf("InterfaceTrait"))
|
||||||
|
return Trait(Kind::Interface, def);
|
||||||
|
assert(def->isSubClassOf("NativeTrait"));
|
||||||
|
return Trait(Kind::Native, def);
|
||||||
|
}
|
||||||
|
|
||||||
|
Trait::Trait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// NativeTrait
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
std::string NativeTrait::getFullyQualifiedTraitName() const {
|
||||||
|
llvm::StringRef trait = def->getValueAsString("trait");
|
||||||
|
llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
|
||||||
|
return cppNamespace.empty() ? trait.str()
|
||||||
|
: (cppNamespace + "::" + trait).str();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// InternalTrait
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
llvm::StringRef InternalTrait::getFullyQualifiedTraitName() const {
|
||||||
|
return def->getValueAsString("trait");
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PredTrait
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
std::string PredTrait::getPredTemplate() const {
|
||||||
|
auto pred = Pred(def->getValueInit("predicate"));
|
||||||
|
return pred.getCondition();
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::StringRef PredTrait::getSummary() const {
|
||||||
|
return def->getValueAsString("summary");
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// InterfaceTrait
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Interface InterfaceTrait::getInterface() const { return Interface(def); }
|
||||||
|
|
||||||
|
std::string InterfaceTrait::getFullyQualifiedTraitName() const {
|
||||||
|
llvm::StringRef trait = def->getValueAsString("trait");
|
||||||
|
llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace");
|
||||||
|
return cppNamespace.empty() ? trait.str()
|
||||||
|
: (cppNamespace + "::" + trait).str();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool InterfaceTrait::shouldDeclareMethods() const {
|
||||||
|
return def->isSubClassOf("DeclareInterfaceMethods");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<StringRef> InterfaceTrait::getAlwaysDeclaredMethods() const {
|
||||||
|
return def->getValueAsListOfStrings("alwaysOverriddenMethods");
|
||||||
|
}
|
|
@ -0,0 +1,116 @@
|
||||||
|
//===- Trait.h - Trait wrapper class ----------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Trait wrapper to simplify using TableGen Record defining an MLIR Trait.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_TRAIT_H_
|
||||||
|
#define MLIR_TABLEGEN_TRAIT_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class Init;
|
||||||
|
class Record;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
class Interface;
|
||||||
|
|
||||||
|
// Wrapper class with helper methods for accessing Trait constraints defined in
|
||||||
|
// TableGen.
|
||||||
|
class Trait {
|
||||||
|
public:
|
||||||
|
// Discriminator for kinds of traits.
|
||||||
|
enum class Kind {
|
||||||
|
// Trait corresponding to C++ class.
|
||||||
|
Native,
|
||||||
|
// Trait corresponding to a predicate.
|
||||||
|
Pred,
|
||||||
|
// Trait controlling definition generator internals.
|
||||||
|
Internal,
|
||||||
|
// Trait corresponding to an Interface.
|
||||||
|
Interface
|
||||||
|
};
|
||||||
|
|
||||||
|
explicit Trait(Kind kind, const llvm::Record *def);
|
||||||
|
|
||||||
|
// Returns an Trait corresponding to the init provided.
|
||||||
|
static Trait create(const llvm::Init *init);
|
||||||
|
|
||||||
|
Kind getKind() const { return kind; }
|
||||||
|
|
||||||
|
// Returns the Tablegen definition this operator was constructed from.
|
||||||
|
const llvm::Record &getDef() const { return *def; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// The TableGen definition of this trait.
|
||||||
|
const llvm::Record *def;
|
||||||
|
Kind kind;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Trait corresponding to a native C++ Trait.
|
||||||
|
class NativeTrait : public Trait {
|
||||||
|
public:
|
||||||
|
// Returns the trait corresponding to a C++ trait class.
|
||||||
|
std::string getFullyQualifiedTraitName() const;
|
||||||
|
|
||||||
|
static bool classof(const Trait *t) { return t->getKind() == Kind::Native; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Trait corresponding to a predicate on the operation.
|
||||||
|
class PredTrait : public Trait {
|
||||||
|
public:
|
||||||
|
// Returns the template for constructing the predicate.
|
||||||
|
std::string getPredTemplate() const;
|
||||||
|
|
||||||
|
// Returns the description of what the predicate is verifying.
|
||||||
|
StringRef getSummary() const;
|
||||||
|
|
||||||
|
static bool classof(const Trait *t) { return t->getKind() == Kind::Pred; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Trait controlling op definition generator internals.
|
||||||
|
class InternalTrait : public Trait {
|
||||||
|
public:
|
||||||
|
// Returns the trait controlling op definition generator internals.
|
||||||
|
StringRef getFullyQualifiedTraitName() const;
|
||||||
|
|
||||||
|
static bool classof(const Trait *t) { return t->getKind() == Kind::Internal; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Trait corresponding to an OpInterface on the operation.
|
||||||
|
class InterfaceTrait : public Trait {
|
||||||
|
public:
|
||||||
|
// Returns interface corresponding to the trait.
|
||||||
|
Interface getInterface() const;
|
||||||
|
|
||||||
|
// Returns the trait corresponding to a C++ trait class.
|
||||||
|
std::string getFullyQualifiedTraitName() const;
|
||||||
|
|
||||||
|
static bool classof(const Trait *t) {
|
||||||
|
return t->getKind() == Kind::Interface;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Whether the declaration of methods for this trait should be emitted.
|
||||||
|
bool shouldDeclareMethods() const;
|
||||||
|
|
||||||
|
// Returns the methods that should always be declared if this interface is
|
||||||
|
// emitting declarations.
|
||||||
|
std::vector<StringRef> getAlwaysDeclaredMethods() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_TRAIT_H_
|
|
@ -0,0 +1,82 @@
|
||||||
|
//===- Type.cpp - Type class ----------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Type wrapper to simplify using TableGen Record defining a MLIR Type.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "Type.h"
|
||||||
|
#include "Dialect.h"
|
||||||
|
#include "llvm/ADT/Twine.h"
|
||||||
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::tblgen;
|
||||||
|
|
||||||
|
TypeConstraint::TypeConstraint(const llvm::Record *record)
|
||||||
|
: Constraint(Constraint::CK_Type, record) {
|
||||||
|
assert(def->isSubClassOf("TypeConstraint") &&
|
||||||
|
"must be subclass of TableGen 'TypeConstraint' class");
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeConstraint::TypeConstraint(const llvm::DefInit *init)
|
||||||
|
: TypeConstraint(init->getDef()) {}
|
||||||
|
|
||||||
|
bool TypeConstraint::isOptional() const {
|
||||||
|
return def->isSubClassOf("Optional");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TypeConstraint::isVariadic() const {
|
||||||
|
return def->isSubClassOf("Variadic");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the builder call for this constraint if this is a buildable type,
|
||||||
|
// returns None otherwise.
|
||||||
|
Optional<StringRef> TypeConstraint::getBuilderCall() const {
|
||||||
|
const llvm::Record *baseType = def;
|
||||||
|
if (isVariableLength())
|
||||||
|
baseType = baseType->getValueAsDef("baseType");
|
||||||
|
|
||||||
|
// Check to see if this type constraint has a builder call.
|
||||||
|
const llvm::RecordVal *builderCall = baseType->getValue("builderCall");
|
||||||
|
if (!builderCall || !builderCall->getValue())
|
||||||
|
return llvm::None;
|
||||||
|
return TypeSwitch<llvm::Init *, Optional<StringRef>>(builderCall->getValue())
|
||||||
|
.Case<llvm::StringInit>([&](auto *init) {
|
||||||
|
StringRef value = init->getValue();
|
||||||
|
return value.empty() ? Optional<StringRef>() : value;
|
||||||
|
})
|
||||||
|
.Default([](auto *) { return llvm::None; });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the C++ class name for this type (which may just be ::mlir::Type).
|
||||||
|
std::string TypeConstraint::getCPPClassName() const {
|
||||||
|
StringRef className = def->getValueAsString("cppClassName");
|
||||||
|
|
||||||
|
// If the class name is already namespace resolved, use it.
|
||||||
|
if (className.contains("::"))
|
||||||
|
return className.str();
|
||||||
|
|
||||||
|
// Otherwise, check to see if there is a namespace from a dialect to prepend.
|
||||||
|
if (const llvm::RecordVal *value = def->getValue("dialect")) {
|
||||||
|
Dialect dialect(cast<const llvm::DefInit>(value->getValue())->getDef());
|
||||||
|
return (dialect.getCppNamespace() + "::" + className).str();
|
||||||
|
}
|
||||||
|
return className.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
|
||||||
|
|
||||||
|
StringRef Type::getDescription() const {
|
||||||
|
return def->getValueAsString("description");
|
||||||
|
}
|
||||||
|
|
||||||
|
Dialect Type::getDialect() const {
|
||||||
|
return Dialect(def->getValueAsDef("dialect"));
|
||||||
|
}
|
|
@ -0,0 +1,70 @@
|
||||||
|
//===- Type.h - Type class --------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Type wrapper to simplify using TableGen Record defining a MLIR Type.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TABLEGEN_TYPE_H_
|
||||||
|
#define MLIR_TABLEGEN_TYPE_H_
|
||||||
|
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "Constraint.h"
|
||||||
|
#include "Dialect.h"
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
class DefInit;
|
||||||
|
class Record;
|
||||||
|
} // end namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace tblgen {
|
||||||
|
|
||||||
|
// Wrapper class with helper methods for accessing Type constraints defined in
|
||||||
|
// TableGen.
|
||||||
|
class TypeConstraint : public Constraint {
|
||||||
|
public:
|
||||||
|
explicit TypeConstraint(const llvm::Record *record);
|
||||||
|
explicit TypeConstraint(const llvm::DefInit *init);
|
||||||
|
|
||||||
|
static bool classof(const Constraint *c) { return c->getKind() == CK_Type; }
|
||||||
|
|
||||||
|
// Returns true if this is an optional type constraint.
|
||||||
|
bool isOptional() const;
|
||||||
|
|
||||||
|
// Returns true if this is a variadic type constraint.
|
||||||
|
bool isVariadic() const;
|
||||||
|
|
||||||
|
// Returns true if this is a variable length type constraint. This is either
|
||||||
|
// variadic or optional.
|
||||||
|
bool isVariableLength() const { return isOptional() || isVariadic(); }
|
||||||
|
|
||||||
|
// Returns the builder call for this constraint if this is a buildable type,
|
||||||
|
// returns None otherwise.
|
||||||
|
Optional<StringRef> getBuilderCall() const;
|
||||||
|
|
||||||
|
// Return the C++ class name for this type (which may just be ::mlir::Type).
|
||||||
|
std::string getCPPClassName() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wrapper class with helper methods for accessing Types defined in TableGen.
|
||||||
|
class Type : public TypeConstraint {
|
||||||
|
public:
|
||||||
|
explicit Type(const llvm::Record *record);
|
||||||
|
|
||||||
|
// Returns the description of the type.
|
||||||
|
StringRef getDescription() const;
|
||||||
|
|
||||||
|
// Returns the dialect for the type if defined.
|
||||||
|
Dialect getDialect() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace tblgen
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TABLEGEN_TYPE_H_
|
|
@ -0,0 +1,83 @@
|
||||||
|
//===- mlir-tblgen.cpp - Top-Level TableGen implementation for MLIR -------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file contains the main function for MLIR's TableGen.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "TableGen/GenInfo.h"
|
||||||
|
#include "TableGen/GenNameParser.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include "llvm/Support/CommandLine.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/Support/InitLLVM.h"
|
||||||
|
#include "llvm/Support/ManagedStatic.h"
|
||||||
|
#include "llvm/Support/Signals.h"
|
||||||
|
#include "llvm/TableGen/Error.h"
|
||||||
|
#include "llvm/TableGen/Main.h"
|
||||||
|
#include "llvm/TableGen/Record.h"
|
||||||
|
#include "llvm/TableGen/TableGenBackend.h"
|
||||||
|
|
||||||
|
using namespace llvm;
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
static llvm::ManagedStatic<std::vector<GenInfo>> generatorRegistry;
|
||||||
|
|
||||||
|
mlir::GenRegistration::GenRegistration(StringRef arg, StringRef description,
|
||||||
|
GenFunction function) {
|
||||||
|
generatorRegistry->emplace_back(arg, description, function);
|
||||||
|
}
|
||||||
|
|
||||||
|
GenNameParser::GenNameParser(llvm::cl::Option &opt)
|
||||||
|
: llvm::cl::parser<const GenInfo *>(opt) {
|
||||||
|
for (const auto &kv : *generatorRegistry) {
|
||||||
|
addLiteralOption(kv.getGenArgument(), &kv, kv.getGenDescription());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GenNameParser::printOptionInfo(const llvm::cl::Option &O,
|
||||||
|
size_t GlobalWidth) const {
|
||||||
|
GenNameParser *TP = const_cast<GenNameParser *>(this);
|
||||||
|
llvm::array_pod_sort(TP->Values.begin(), TP->Values.end(),
|
||||||
|
[](const GenNameParser::OptionInfo *VT1,
|
||||||
|
const GenNameParser::OptionInfo *VT2) {
|
||||||
|
return VT1->Name.compare(VT2->Name);
|
||||||
|
});
|
||||||
|
using llvm::cl::parser;
|
||||||
|
parser<const GenInfo *>::printOptionInfo(O, GlobalWidth);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generator that prints records.
|
||||||
|
GenRegistration printRecords("print-records", "Print all records to stdout",
|
||||||
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
|
os << records;
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Generator to invoke.
|
||||||
|
const mlir::GenInfo *generator;
|
||||||
|
|
||||||
|
// TableGenMain requires a function pointer so this function is passed in which
|
||||||
|
// simply wraps the call to the generator.
|
||||||
|
static bool MlirTableGenMain(raw_ostream &os, RecordKeeper &records) {
|
||||||
|
if (!generator) {
|
||||||
|
os << records;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return generator->invoke(records, os);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
llvm::InitLLVM y(argc, argv);
|
||||||
|
llvm::cl::opt<const mlir::GenInfo *, false, mlir::GenNameParser> generator(
|
||||||
|
"", llvm::cl::desc("Generator to run"));
|
||||||
|
cl::ParseCommandLineOptions(argc, argv);
|
||||||
|
::generator = generator.getValue();
|
||||||
|
|
||||||
|
return TableGenMain(argv[0], &MlirTableGenMain);
|
||||||
|
}
|
Loading…
Reference in New Issue