452 lines
17 KiB
C
452 lines
17 KiB
C
|
//===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
//
|
||
|
// Pattern wrapper class to simplify using a TableGen Record defining an MLIR
// Pattern.
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#ifndef MLIR_TABLEGEN_PATTERN_H_
|
||
|
#define MLIR_TABLEGEN_PATTERN_H_
|
||
|
|
||
|
#include "mlir/Support/LLVM.h"
|
||
|
#include "Argument.h"
|
||
|
#include "Operator.h"
|
||
|
#include "llvm/ADT/DenseMap.h"
|
||
|
#include "llvm/ADT/StringMap.h"
|
||
|
#include "llvm/ADT/StringSet.h"
|
||
|
|
||
|
#include <unordered_map>
|
||
|
|
||
|
namespace llvm {
|
||
|
class DagInit;
|
||
|
class Init;
|
||
|
class Record;
|
||
|
} // end namespace llvm
|
||
|
|
||
|
namespace mlir {
|
||
|
namespace tblgen {
|
||
|
|
||
|
// Mapping from TableGen Record to Operator wrapper object.
|
||
|
//
|
||
|
// We allocate each wrapper object in heap to make sure the pointer to it is
|
||
|
// valid throughout the lifetime of this map. This is important because this map
|
||
|
// is shared among multiple patterns to avoid creating the wrapper object for
|
||
|
// the same op again and again. But this map will continuously grow.
|
||
|
using RecordOperatorMap =
|
||
|
DenseMap<const llvm::Record *, std::unique_ptr<Operator>>;
|
||
|
|
||
|
class Pattern;
|
||
|
|
||
|
// Wrapper class providing helper methods for accessing TableGen DAG leaves
|
||
|
// used inside Patterns. This class is lightweight and designed to be used like
|
||
|
// values.
|
||
|
//
|
||
|
// A TableGen DAG construct is of the syntax
|
||
|
// `(operator, arg0, arg1, ...)`.
|
||
|
//
|
||
|
// This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
|
||
|
// for handy helper methods. It only works on `arg*`s that are not nested DAG
|
||
|
// constructs.
|
||
|
class DagLeaf {
|
||
|
public:
|
||
|
explicit DagLeaf(const llvm::Init *def) : def(def) {}
|
||
|
|
||
|
// Returns true if this DAG leaf is not specified in the pattern. That is, it
|
||
|
// places no further constraints/transforms and just carries over the original
|
||
|
// value.
|
||
|
bool isUnspecified() const;
|
||
|
|
||
|
// Returns true if this DAG leaf is matching an operand. That is, it specifies
|
||
|
// a type constraint.
|
||
|
bool isOperandMatcher() const;
|
||
|
|
||
|
// Returns true if this DAG leaf is matching an attribute. That is, it
|
||
|
// specifies an attribute constraint.
|
||
|
bool isAttrMatcher() const;
|
||
|
|
||
|
// Returns true if this DAG leaf is wrapping native code call.
|
||
|
bool isNativeCodeCall() const;
|
||
|
|
||
|
// Returns true if this DAG leaf is specifying a constant attribute.
|
||
|
bool isConstantAttr() const;
|
||
|
|
||
|
// Returns true if this DAG leaf is specifying an enum attribute case.
|
||
|
bool isEnumAttrCase() const;
|
||
|
|
||
|
// Returns true if this DAG leaf is specifying a string attribute.
|
||
|
bool isStringAttr() const;
|
||
|
|
||
|
// Returns this DAG leaf as a constraint. Asserts if fails.
|
||
|
Constraint getAsConstraint() const;
|
||
|
|
||
|
// Returns this DAG leaf as an constant attribute. Asserts if fails.
|
||
|
ConstantAttr getAsConstantAttr() const;
|
||
|
|
||
|
// Returns this DAG leaf as an enum attribute case.
|
||
|
// Precondition: isEnumAttrCase()
|
||
|
EnumAttrCase getAsEnumAttrCase() const;
|
||
|
|
||
|
// Returns the matching condition template inside this DAG leaf. Assumes the
|
||
|
// leaf is an operand/attribute matcher and asserts otherwise.
|
||
|
std::string getConditionTemplate() const;
|
||
|
|
||
|
// Returns the native code call template inside this DAG leaf.
|
||
|
// Precondition: isNativeCodeCall()
|
||
|
StringRef getNativeCodeTemplate() const;
|
||
|
|
||
|
// Returns the string associated with the leaf.
|
||
|
// Precondition: isStringAttr()
|
||
|
std::string getStringAttr() const;
|
||
|
|
||
|
void print(raw_ostream &os) const;
|
||
|
|
||
|
private:
|
||
|
// Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
|
||
|
// also a subclass of the given `superclass`.
|
||
|
bool isSubClassOf(StringRef superclass) const;
|
||
|
|
||
|
const llvm::Init *def;
|
||
|
};
|
||
|
|
||
|
// Wrapper class providing helper methods for accessing TableGen DAG constructs
|
||
|
// used inside Patterns. This class is lightweight and designed to be used like
|
||
|
// values.
|
||
|
//
|
||
|
// A TableGen DAG construct is of the syntax
|
||
|
// `(operator, arg0, arg1, ...)`.
|
||
|
//
|
||
|
// When used inside Patterns, `operator` corresponds to some dialect op, or
|
||
|
// a known list of verbs that defines special transformation actions. This
|
||
|
// `arg*` can be a nested DAG construct. This class provides getters to
|
||
|
// retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
|
||
|
// methods.
|
||
|
//
|
||
|
// A null DagNode contains a nullptr and converts to false implicitly.
|
||
|
class DagNode {
|
||
|
public:
|
||
|
explicit DagNode(const llvm::DagInit *node) : node(node) {}
|
||
|
|
||
|
// Implicit bool converter that returns true if this DagNode is not a null
|
||
|
// DagNode.
|
||
|
operator bool() const { return node != nullptr; }
|
||
|
|
||
|
// Returns the symbol bound to this DAG node.
|
||
|
StringRef getSymbol() const;
|
||
|
|
||
|
// Returns the operator wrapper object corresponding to the dialect op matched
|
||
|
// by this DAG. The operator wrapper will be queried from the given `mapper`
|
||
|
// and created in it if not existing.
|
||
|
Operator &getDialectOp(RecordOperatorMap *mapper) const;
|
||
|
|
||
|
// Returns the number of operations recursively involved in the DAG tree
|
||
|
// rooted from this node.
|
||
|
int getNumOps() const;
|
||
|
|
||
|
// Returns the number of immediate arguments to this DAG node.
|
||
|
int getNumArgs() const;
|
||
|
|
||
|
// Returns true if the `index`-th argument is a nested DAG construct.
|
||
|
bool isNestedDagArg(unsigned index) const;
|
||
|
|
||
|
// Gets the `index`-th argument as a nested DAG construct if possible. Returns
|
||
|
// null DagNode otherwise.
|
||
|
DagNode getArgAsNestedDag(unsigned index) const;
|
||
|
|
||
|
// Gets the `index`-th argument as a DAG leaf.
|
||
|
DagLeaf getArgAsLeaf(unsigned index) const;
|
||
|
|
||
|
// Returns the specified name of the `index`-th argument.
|
||
|
StringRef getArgName(unsigned index) const;
|
||
|
|
||
|
// Returns true if this DAG construct means to replace with an existing SSA
|
||
|
// value.
|
||
|
bool isReplaceWithValue() const;
|
||
|
|
||
|
// Returns whether this DAG represents the location of an op creation.
|
||
|
bool isLocationDirective() const;
|
||
|
|
||
|
// Returns true if this DAG node is wrapping native code call.
|
||
|
bool isNativeCodeCall() const;
|
||
|
|
||
|
// Returns true if this DAG node is an operation.
|
||
|
bool isOperation() const;
|
||
|
|
||
|
// Returns the native code call template inside this DAG node.
|
||
|
// Precondition: isNativeCodeCall()
|
||
|
StringRef getNativeCodeTemplate() const;
|
||
|
|
||
|
void print(raw_ostream &os) const;
|
||
|
|
||
|
private:
|
||
|
const llvm::DagInit *node; // nullptr means null DagNode
|
||
|
};
|
||
|
|
||
|
// A class for maintaining information for symbols bound in patterns and
|
||
|
// provides methods for resolving them according to specific use cases.
|
||
|
//
|
||
|
// Symbols can be bound to
|
||
|
//
|
||
|
// * Op arguments and op results in the source pattern and
|
||
|
// * Op results in result patterns.
|
||
|
//
|
||
|
// Symbols can be referenced in result patterns and additional constraints to
|
||
|
// the pattern.
|
||
|
//
|
||
|
// For example, in
|
||
|
//
|
||
|
// ```
|
||
|
// def : Pattern<
|
||
|
// (SrcOp:$results1 $arg0, %arg1),
|
||
|
// [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
|
||
|
// ```
|
||
|
//
|
||
|
// `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
|
||
|
// `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build
|
||
|
// `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
|
||
|
//
|
||
|
// If a symbol binds to a multi-result op and it does not have the `__N`
|
||
|
// suffix, the symbol is expanded to represent all results generated by the
|
||
|
// multi-result op. If the symbol has a `__N` suffix, then it will expand to
|
||
|
// only the N-th *static* result as declared in ODS, and that can still
|
||
|
// corresponds to multiple *dynamic* values if the N-th *static* result is
|
||
|
// variadic.
|
||
|
//
|
||
|
// This class keeps track of such symbols and resolves them into their bound
|
||
|
// values in a suitable way.
|
||
|
class SymbolInfoMap {
|
||
|
public:
|
||
|
explicit SymbolInfoMap(ArrayRef<llvm::SMLoc> loc) : loc(loc) {}
|
||
|
|
||
|
// Class for information regarding a symbol.
|
||
|
class SymbolInfo {
|
||
|
public:
|
||
|
// Returns a string for defining a variable named as `name` to store the
|
||
|
// value bound by this symbol.
|
||
|
std::string getVarDecl(StringRef name) const;
|
||
|
|
||
|
// Returns a variable name for the symbol named as `name`.
|
||
|
std::string getVarName(StringRef name) const;
|
||
|
|
||
|
private:
|
||
|
// Allow SymbolInfoMap to access private methods.
|
||
|
friend class SymbolInfoMap;
|
||
|
|
||
|
// What kind of entity this symbol represents:
|
||
|
// * Attr: op attribute
|
||
|
// * Operand: op operand
|
||
|
// * Result: op result
|
||
|
// * Value: a value not attached to an op (e.g., from NativeCodeCall)
|
||
|
enum class Kind : uint8_t { Attr, Operand, Result, Value };
|
||
|
|
||
|
// Creates a SymbolInfo instance. `index` is only used for `Attr` and
|
||
|
// `Operand` so should be negative for `Result` and `Value` kind.
|
||
|
SymbolInfo(const Operator *op, Kind kind, Optional<int> index);
|
||
|
|
||
|
// Static methods for creating SymbolInfo.
|
||
|
static SymbolInfo getAttr(const Operator *op, int index) {
|
||
|
return SymbolInfo(op, Kind::Attr, index);
|
||
|
}
|
||
|
static SymbolInfo getAttr() {
|
||
|
return SymbolInfo(nullptr, Kind::Attr, llvm::None);
|
||
|
}
|
||
|
static SymbolInfo getOperand(const Operator *op, int index) {
|
||
|
return SymbolInfo(op, Kind::Operand, index);
|
||
|
}
|
||
|
static SymbolInfo getResult(const Operator *op) {
|
||
|
return SymbolInfo(op, Kind::Result, llvm::None);
|
||
|
}
|
||
|
static SymbolInfo getValue() {
|
||
|
return SymbolInfo(nullptr, Kind::Value, llvm::None);
|
||
|
}
|
||
|
|
||
|
// Returns the number of static values this symbol corresponds to.
|
||
|
// A static value is an operand/result declared in ODS. Normally a symbol
|
||
|
// only represents one static value, but symbols bound to op results can
|
||
|
// represent more than one if the op is a multi-result op.
|
||
|
int getStaticValueCount() const;
|
||
|
|
||
|
// Returns a string containing the C++ expression for referencing this
|
||
|
// symbol as a value (if this symbol represents one static value) or a value
|
||
|
// range (if this symbol represents multiple static values). `name` is the
|
||
|
// name of the C++ variable that this symbol bounds to. `index` should only
|
||
|
// be used for indexing results. `fmt` is used to format each value.
|
||
|
// `separator` is used to separate values if this is a value range.
|
||
|
std::string getValueAndRangeUse(StringRef name, int index, const char *fmt,
|
||
|
const char *separator) const;
|
||
|
|
||
|
// Returns a string containing the C++ expression for referencing this
|
||
|
// symbol as a value range regardless of how many static values this symbol
|
||
|
// represents. `name` is the name of the C++ variable that this symbol
|
||
|
// bounds to. `index` should only be used for indexing results. `fmt` is
|
||
|
// used to format each value. `separator` is used to separate values in the
|
||
|
// range.
|
||
|
std::string getAllRangeUse(StringRef name, int index, const char *fmt,
|
||
|
const char *separator) const;
|
||
|
|
||
|
const Operator *op; // The op where the bound entity belongs
|
||
|
Kind kind; // The kind of the bound entity
|
||
|
// The argument index (for `Attr` and `Operand` only)
|
||
|
Optional<int> argIndex;
|
||
|
// Alternative name for the symbol. It is used in case the name
|
||
|
// is not unique. Applicable for `Operand` only.
|
||
|
Optional<std::string> alternativeName;
|
||
|
};
|
||
|
|
||
|
using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
|
||
|
|
||
|
// Iterators for accessing all symbols.
|
||
|
using iterator = BaseT::iterator;
|
||
|
iterator begin() { return symbolInfoMap.begin(); }
|
||
|
iterator end() { return symbolInfoMap.end(); }
|
||
|
|
||
|
// Const iterators for accessing all symbols.
|
||
|
using const_iterator = BaseT::const_iterator;
|
||
|
const_iterator begin() const { return symbolInfoMap.begin(); }
|
||
|
const_iterator end() const { return symbolInfoMap.end(); }
|
||
|
|
||
|
// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
|
||
|
// Returns false if `symbol` is already bound and symbols are not operands.
|
||
|
bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
|
||
|
|
||
|
// Binds the given `symbol` to the results the given `op`. Returns false if
|
||
|
// `symbol` is already bound.
|
||
|
bool bindOpResult(StringRef symbol, const Operator &op);
|
||
|
|
||
|
// Registers the given `symbol` as bound to a value. Returns false if `symbol`
|
||
|
// is already bound.
|
||
|
bool bindValue(StringRef symbol);
|
||
|
|
||
|
// Registers the given `symbol` as bound to an attr. Returns false if `symbol`
|
||
|
// is already bound.
|
||
|
bool bindAttr(StringRef symbol);
|
||
|
|
||
|
// Returns true if the given `symbol` is bound.
|
||
|
bool contains(StringRef symbol) const;
|
||
|
|
||
|
// Returns an iterator to the information of the given symbol named as `key`.
|
||
|
const_iterator find(StringRef key) const;
|
||
|
|
||
|
// Returns an iterator to the information of the given symbol named as `key`,
|
||
|
// with index `argIndex` for operator `op`.
|
||
|
const_iterator findBoundSymbol(StringRef key, const Operator &op,
|
||
|
int argIndex) const;
|
||
|
|
||
|
// Returns the bounds of a range that includes all the elements which
|
||
|
// bind to the `key`.
|
||
|
std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
|
||
|
|
||
|
// Returns number of times symbol named as `key` was used.
|
||
|
int count(StringRef key) const;
|
||
|
|
||
|
// Returns the number of static values of the given `symbol` corresponds to.
|
||
|
// A static value is an operand/result declared in ODS. Normally a symbol only
|
||
|
// represents one static value, but symbols bound to op results can represent
|
||
|
// more than one if the op is a multi-result op.
|
||
|
int getStaticValueCount(StringRef symbol) const;
|
||
|
|
||
|
// Returns a string containing the C++ expression for referencing this
|
||
|
// symbol as a value (if this symbol represents one static value) or a value
|
||
|
// range (if this symbol represents multiple static values). `fmt` is used to
|
||
|
// format each value. `separator` is used to separate values if `symbol`
|
||
|
// represents a value range.
|
||
|
std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}",
|
||
|
const char *separator = ", ") const;
|
||
|
|
||
|
// Returns a string containing the C++ expression for referencing this
|
||
|
// symbol as a value range regardless of how many static values this symbol
|
||
|
// represents. `fmt` is used to format each value. `separator` is used to
|
||
|
// separate values in the range.
|
||
|
std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
|
||
|
const char *separator = ", ") const;
|
||
|
|
||
|
// Assign alternative unique names to Operands that have equal names.
|
||
|
void assignUniqueAlternativeNames();
|
||
|
|
||
|
// Splits the given `symbol` into a value pack name and an index. Returns the
|
||
|
// value pack name and writes the index to `index` on success. Returns
|
||
|
// `symbol` itself if it does not contain an index.
|
||
|
//
|
||
|
// We can use `name__N` to access the `N`-th value in the value pack bound to
|
||
|
// `name`. `name` is typically the results of an multi-result op.
|
||
|
static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
|
||
|
|
||
|
private:
|
||
|
BaseT symbolInfoMap;
|
||
|
|
||
|
// Pattern instantiation location. This is intended to be used as parameter
|
||
|
// to PrintFatalError() to report errors.
|
||
|
ArrayRef<llvm::SMLoc> loc;
|
||
|
};
|
||
|
|
||
|
// Wrapper class providing helper methods for accessing MLIR Pattern defined
|
||
|
// in TableGen. This class should closely reflect what is defined as class
|
||
|
// `Pattern` in TableGen. This class contains maps so it is not intended to be
|
||
|
// used as values.
|
||
|
class Pattern {
|
||
|
public:
|
||
|
explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
|
||
|
|
||
|
// Returns the source pattern to match.
|
||
|
DagNode getSourcePattern() const;
|
||
|
|
||
|
// Returns the number of result patterns generated by applying this rewrite
|
||
|
// rule.
|
||
|
int getNumResultPatterns() const;
|
||
|
|
||
|
// Returns the DAG tree root node of the `index`-th result pattern.
|
||
|
DagNode getResultPattern(unsigned index) const;
|
||
|
|
||
|
// Collects all symbols bound in the source pattern into `infoMap`.
|
||
|
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
|
||
|
|
||
|
// Collects all symbols bound in result patterns into `infoMap`.
|
||
|
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
|
||
|
|
||
|
// Returns the op that the root node of the source pattern matches.
|
||
|
const Operator &getSourceRootOp();
|
||
|
|
||
|
// Returns the operator wrapper object corresponding to the given `node`'s DAG
|
||
|
// operator.
|
||
|
Operator &getDialectOp(DagNode node);
|
||
|
|
||
|
// Returns the constraints.
|
||
|
std::vector<AppliedConstraint> getConstraints() const;
|
||
|
|
||
|
// Returns the benefit score of the pattern.
|
||
|
int getBenefit() const;
|
||
|
|
||
|
using IdentifierLine = std::pair<StringRef, unsigned>;
|
||
|
|
||
|
// Returns the file location of the pattern (buffer identifier + line number
|
||
|
// pair).
|
||
|
std::vector<IdentifierLine> getLocation() const;
|
||
|
|
||
|
private:
|
||
|
// Helper function to verify variabld binding.
|
||
|
void verifyBind(bool result, StringRef symbolName);
|
||
|
|
||
|
// Recursively collects all bound symbols inside the DAG tree rooted
|
||
|
// at `tree` and updates the given `infoMap`.
|
||
|
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||
|
bool isSrcPattern);
|
||
|
|
||
|
// The TableGen definition of this pattern.
|
||
|
const llvm::Record &def;
|
||
|
|
||
|
// All operators.
|
||
|
// TODO: we need a proper context manager, like MLIRContext, for managing the
|
||
|
// lifetime of shared entities.
|
||
|
RecordOperatorMap *recordOpMap;
|
||
|
};
|
||
|
|
||
|
} // end namespace tblgen
|
||
|
} // end namespace mlir
|
||
|
|
||
|
#endif // MLIR_TABLEGEN_PATTERN_H_
|