mlir-hlo/tools/mlir-tblgen-builder/BuilderDefinitionsGen.cpp

2298 lines
87 KiB
C++
Raw Permalink Normal View History

2021-07-23 11:38:34 +08:00
//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpDefinitionsGen uses the description of operations to generate C++
// definitions for ops.
//
//===----------------------------------------------------------------------===//
#include "OpFormatGen.h"
#include "OpGenHelpers.h"
#include "TableGen/CodeGenHelpers.h"
#include "TableGen/Format.h"
#include "TableGen/GenInfo.h"
#include "TableGen/Interfaces.h"
#include "TableGen/OpClass.h"
#include "TableGen/Operator.h"
#include "TableGen/SideEffects.h"
#include "TableGen/Trait.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
2021-08-11 10:46:07 +08:00
#include <iostream>
2021-08-04 20:24:07 +08:00
2021-07-23 11:38:34 +08:00
// Debug category name for this generator.
#define DEBUG_TYPE "mlir-tblgen-opdefgen"
using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
// Prefix put in front of an attribute name to form the local variable name used
// in generated attribute verification code, so that the local does not hide the
// attribute accessor of the same name.
static const char *const tblgenNamePrefix = "tblgen_";
// Base name for operands that have no name of their own; getArgumentName
// produces "<base>_<index>" from it.
static const char *const generatedArgName = "odsArg";
// NOTE(review): presumably the names of the builder and operation-state
// parameters of generated build() methods; their uses are not in this part of
// the file, confirm against the builder generation code.
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";
2021-08-11 10:46:07 +08:00
namespace {
// Describes how one MLIR attribute type is exposed through the builder API.
struct mlirTypeWrap {
  // Signature of a conversion emitter: given the name of a builder-side
  // variable and the method body being generated, it appends the statements
  // that produce the MLIR attribute and returns the name of the variable that
  // holds the result.
  using ConvertFn = std::string (*)(std::string &, OpMethodBody &);

  // Builder-facing C++ type name used in place of the MLIR attribute type.
  std::string Name;
  // Emitter converting a value of type `Name` into the MLIR attribute.
  ConvertFn ConvertToMlir;
};
static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
{"::mlir::StringAttr",
{"std::string",
[](std::string &var, OpMethodBody &body) -> std::string {
2021-08-13 15:05:10 +08:00
body << " mlir::StringAttr " << var
<< "_mlir = mlir::StringAttr::get(ctx, mlir::Twine(" << var
<< "));\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
{"::mlir::IntegerAttr",
{"builder::Integer",
[](std::string &var, OpMethodBody &body) -> std::string {
body << " mlir::IntegerAttr " << var << "_mlir = " << var
2021-08-13 15:05:10 +08:00
<< ".GetImpl()->GetAttr(ctx);\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
2021-08-13 15:05:10 +08:00
{"::mlir::FloatAttr",
{"builder::Float",
2021-08-11 10:46:07 +08:00
[](std::string &var, OpMethodBody &body) -> std::string {
2021-08-13 15:05:10 +08:00
body << " mlir::FloatAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
2021-08-13 15:05:10 +08:00
{"::mlir::DenseIntElementsAttr",
{"builder::TensorInt",
2021-08-11 10:46:07 +08:00
[](std::string &var, OpMethodBody &body) -> std::string {
2021-08-13 15:05:10 +08:00
body << " mlir::DenseIntElementsAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
2021-08-13 15:05:10 +08:00
{"::mlir::mhlo::ChannelHandle",
{"builder::ChannelHandle",
2021-08-11 10:46:07 +08:00
[](std::string &var, OpMethodBody &body) -> std::string {
2021-08-13 15:05:10 +08:00
body << " mlir::mhlo::ChannelHandle " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
{"::mlir::BoolAttr",
{"bool",
[](std::string &var, OpMethodBody &body) -> std::string {
2021-08-13 15:05:10 +08:00
body << " mlir::BoolAttr " << var
<< "_mlir = mlir::BoolAttr::get(ctx, " << var << ");\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
{"::mlir::ElementsAttr",
2021-08-13 15:05:10 +08:00
{"builder::Tensor",
2021-08-11 10:46:07 +08:00
[](std::string &var, OpMethodBody &body) -> std::string {
2021-08-13 15:05:10 +08:00
body << " mlir::DenseElementsAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
{"::mlir::DenseElementsAttr",
2021-08-13 15:05:10 +08:00
{"builder::Tensor",
2021-08-11 10:46:07 +08:00
[](std::string &var, OpMethodBody &body) -> std::string {
2021-08-13 15:05:10 +08:00
body << " mlir::DenseElementsAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
{"::mlir::ArrayAttr",
2021-08-13 15:05:10 +08:00
{"builder::Array",
2021-08-11 10:46:07 +08:00
[](std::string &var, OpMethodBody &body) -> std::string {
2021-08-13 15:05:10 +08:00
body << " mlir::ArrayAttr " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
{"::mlir::mhlo::ConvDimensionNumbers",
2021-08-13 15:05:10 +08:00
{"builder::ConvDimensionNumbers",
2021-08-11 10:46:07 +08:00
[](std::string &var, OpMethodBody &body) -> std::string {
2021-08-13 15:05:10 +08:00
body << " mlir::mhlo::ConvDimensionNumbers " << var
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
{"::mlir::mhlo::DotDimensionNumbers",
2021-08-13 15:05:10 +08:00
{"builder::DotDimensionNumbers",
2021-08-11 10:46:07 +08:00
[](std::string &var, OpMethodBody &body) -> std::string {
2021-08-13 15:05:10 +08:00
body << " mlir::mhlo::DotDimensionNumbers " << var << "_mlir = " << var
<< ".GetImpl()->GetAttr(ctx);\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
2021-08-04 20:24:07 +08:00
{"::mlir::mhlo::GatherDimensionNumbers",
2021-08-13 15:05:10 +08:00
{"builder::GatherDimensionNumbers",
2021-08-11 10:46:07 +08:00
[](std::string &var, OpMethodBody &body) -> std::string {
2021-08-13 15:05:10 +08:00
body << " mlir::mhlo::GatherDimensionNumbers " << var
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
2021-08-04 20:24:07 +08:00
{"::mlir::mhlo::ScatterDimensionNumbers",
2021-08-13 15:05:10 +08:00
{"builder::ScatterDimensionNumbers",
2021-08-11 10:46:07 +08:00
[](std::string &var, OpMethodBody &body) -> std::string {
2021-08-13 15:05:10 +08:00
body << " mlir::mhlo::ScatterDimensionNumbers " << var
<< "_mlir = " << var << ".GetImpl()->GetAttr(ctx);\n";
2021-08-11 10:46:07 +08:00
return var + "_mlir";
}}},
2021-08-04 20:24:07 +08:00
};
2021-08-11 10:46:07 +08:00
// static const std::map<std::string, mlirTypeWrap> typeMapMLIR = {
// {"::mlir::StringAttr", "std::string"},
// {"::mlir::IntegerAttr", "int"},
// {"::mlir::DenseIntElementsAttr", "std::vector<int>"},
// {"::mlir::mhlo::ChannelHandle", "ChannelHandle"},
// {"::mlir::FloatAttr", "float"},
// {"::mlir::BoolAttr", "bool"},
2021-08-13 15:05:10 +08:00
// {"::mlir::ElementsAttr", "builder::Array"},
// {"::mlir::DenseElementsAttr", "builder::Tensor"},
2021-08-11 10:46:07 +08:00
// // currently only supports string arrays.
// {"::mlir::ArrayAttr", "std::vector<std::string>"},
2021-08-13 15:05:10 +08:00
// {"::mlir::mhlo::ConvDimensionNumbers", "builder::ConvDimensionNumbers"},
// {"::mlir::mhlo::DotDimensionNumbers", "builder::DotDimensionNumbers"},
2021-08-11 10:46:07 +08:00
// {"::mlir::mhlo::GatherDimensionNumbers",
2021-08-13 15:05:10 +08:00
// "builder::GatherDimensionNumbers"},
2021-08-11 10:46:07 +08:00
// {"::mlir::mhlo::ScatterDimensionNumbers",
2021-08-13 15:05:10 +08:00
// "builder::ScatterDimensionNumbers"},
2021-08-11 10:46:07 +08:00
// };
2021-08-04 20:24:07 +08:00
// Translates an MLIR attribute type name into the builder-facing type name
// registered in `typeMapMLIR`. A name without an entry is returned unchanged.
StringRef typeConvertFromMLIR(StringRef type) {
  const auto entry = typeMapMLIR.find(type.str());
  if (entry == typeMapMLIR.end())
    return type;
  return StringRef(entry->second.Name);
}
// Returns the builder-facing type name for the storage type of `att`.
StringRef getStorageType(const Attribute &att) {
  return typeConvertFromMLIR(att.getStorageType());
}
// Returns the builder-facing type name used as the return type for `att`.
// NOTE(review): this is derived from the attribute's storage type, exactly
// like getStorageType - confirm that no separate return type is intended.
StringRef getReturnType(const Attribute &att) {
  return typeConvertFromMLIR(att.getStorageType());
}
2021-08-11 10:46:07 +08:00
} // namespace
2021-08-04 20:24:07 +08:00
2021-07-23 11:38:34 +08:00
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for
// general use; it assumes all variadic operands/results must have the same
// number of values.
//
// The emitted code reads a variable named `index` (the static index of the
// declared operand/result) and returns a {start, size} pair.
// NOTE(review): presumably expanded with llvm::formatv, where `{{` produces a
// literal `{` - the call site is not in this part of the file.
//
// {0}: The list of whether each declared operand/result is variadic.
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
// {4}: "operand" or "result".
const char *sameVariadicSizeValueRangeCalcCode = R"(
bool isVariadic[] = {{{0}};
int prevVariadicCount = 0;
for (unsigned i = 0; i < index; ++i)
if (isVariadic[i]) ++prevVariadicCount;
// Calculate how many dynamic values a static variadic {4} corresponds to.
// This assumes all static variadic {4}s have the same dynamic value count.
int variadicSize = ({3} - {1}) / {2};
// `index` passed in as the parameter is the static index which counts each
// {4} (variadic or not) as size 1. So here for each previous static variadic
// {4}, we need to offset by (variadicSize - 1) to get where the dynamic
// value pack for this static {4} starts.
int start = index + (variadicSize - 1) * prevVariadicCount;
int size = isVariadic[index] ? variadicSize : 1;
return {{start, size};
)";
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic assumes
// the op has an attribute specifying the size of each operand/result segment
// (variadic or not).
//
// {0}: The name of the attribute specifying the segment sizes.
//
// Adapter flavor: defines `sizeAttr` from `odsAttrs`, asserting first that
// `odsAttrs` is set.
const char *adapterSegmentSizeAttrInitCode = R"(
assert(odsAttrs && "missing segment size attribute for op");
auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
)";
// Op flavor: defines `sizeAttr` from the attribute named {0} on the op itself.
const char *opSegmentSizeAttrInitCode = R"(
auto sizeAttr = (*this)->getAttr("{0}").cast<::mlir::DenseIntElementsAttr>();
)";
// Computes the {start, size} pair for the segment at `index` from `sizeAttr`:
// `start` is the sum of the sizes of all earlier segments and `size` is the
// entry at `index`. Both `sizeAttr` and `index` must already be in scope in the
// emitted code.
const char *attrSizedSegmentValueRangeCalcCode = R"(
auto sizeAttrValues = sizeAttr.getValues<uint32_t>();
unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
start += *(sizeAttrValues.begin() + i);
unsigned size = *(sizeAttrValues.begin() + index);
return {start, size};
)";
// The logic to build a range of either operand or result values.
//
// The emitted code stores the result of {1} in `valueRange` and uses its
// `first` member as the start offset and its `second` member as the length,
// returning the iterator pair [begin + first, begin + first + second).
//
// {0}: The begin iterator of the actual values.
// {1}: The call to generate the start and length of the value range.
const char *valueRangeReturnCode = R"(
auto valueRange = {1};
return {{std::next({0}, valueRange.first),
std::next({0}, valueRange.first + valueRange.second)};
)";
// Banner comment template for a section of the generated file.
//
// {0}, {1}: The two parts of the section title, printed separated by a space.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//
)";
//===----------------------------------------------------------------------===//
// StaticVerifierFunctionEmitter
//===----------------------------------------------------------------------===//
namespace {
/// Emits file-local static functions that hold verification code shared by
/// several operations, so that each op's verify method can call the shared
/// function instead of repeating the check. The functions are local to the
/// definition file. Sketch of the generated code:
///
///   static LogicalResult localVerify(...)
///
///   LogicalResult OpA::verify(...) {
///     if (failed(localVerify(...)))
///       return failure();
///     ...
///   }
///
///   LogicalResult OpB::verify(...) {
///     if (failed(localVerify(...)))
///       return failure();
///     ...
///   }
///
class StaticVerifierFunctionEmitter {
public:
  StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
                                ArrayRef<llvm::Record *> opDefs,
                                raw_ostream &os, bool emitDecl);

  /// Returns the name of the local function generated for `constraint`. Such
  /// functions check operand and result types and have the form:
  ///   LogicalResult(Operation *op, Type type, StringRef valueKind,
  ///                 unsigned valueGroupStartIndex);
  /// The constraint must have been registered; this is asserted.
  StringRef getTypeConstraintFn(const Constraint &constraint) const {
    const void *key = constraint.getAsOpaquePointer();
    auto found = localTypeConstraints.find(key);
    assert(found != localTypeConstraints.end() &&
           "expected valid constraint fn");
    return found->second;
  }

private:
  /// Builds the label that makes generated local function names unique.
  static std::string getUniqueName(const llvm::RecordKeeper &records);

  /// Emits the local functions for the type constraints used by the given op
  /// definitions.
  void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
                                 raw_ostream &os, bool emitDecl);

  /// Label identifying the file being generated; it is embedded in the names
  /// of the local functions so that they are unique.
  std::string uniqueOutputLabel;

  /// Names of the local type-constraint functions, keyed by the opaque pointer
  /// of the constraint; used for operand and result verification.
  llvm::DenseMap<const void *, std::string> localTypeConstraints;
};
} // namespace
/// Creates the emitter for the ops in `opDefs` and derives `uniqueOutputLabel`
/// from `records`.
///
/// When definitions are generated (`emitDecl` is false) and there is at least
/// one op, a NamespaceEmitter for the C++ namespace of the first op is created
/// on `os`; it is destroyed again when the constructor returns. The emission of
/// the type constraint functions themselves is disabled in this generator.
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
    const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
    raw_ostream &os, bool emitDecl)
    : uniqueOutputLabel(getUniqueName(records)) {
  llvm::Optional<NamespaceEmitter> namespaceEmitter;
  // `opDefs[0]` is only valid for a non-empty list, so an empty list of op
  // definitions must not reach the namespace lookup.
  if (!emitDecl && !opDefs.empty()) {
    // Disabled:
    // os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
    namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
  }
  // Disabled:
  // emitTypeConstraintMethods(opDefs, os, emitDecl);
}
/// Derives a label from the name of the TableGen input file: the base file
/// name without a trailing ".td", with every character that is neither
/// alphanumeric nor '_' replaced by its hexadecimal character code.
std::string StaticVerifierFunctionEmitter::getUniqueName(
    const llvm::RecordKeeper &records) {
  // Only the base file name, minus the ".td" suffix, takes part in the label.
  std::string inputPath = records.getInputFilename();
  StringRef baseName = llvm::sys::path::filename(inputPath);
  baseName.consume_back(".td");

  // Copy identifier-safe characters and spell all others as hex digits.
  std::string label;
  for (char ch : baseName) {
    const bool isIdentifierChar = llvm::isAlnum(ch) || ch == '_';
    if (isIdentifierChar)
      label.push_back(ch);
    else
      label.append(llvm::utohexstr(static_cast<unsigned char>(ch)));
  }
  return label;
}
/// Assigns a local function name to every type constraint with a predicate that
/// is used by an operand or result of `opDefs`, and records the name in
/// `localTypeConstraints`. Unless `emitDecl` is set, the function definitions
/// are also written to `os`.
void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
    ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
  // Gather the constraints in first-seen order, each one only once: for every
  // op, first those of its operands and then those of its results.
  llvm::SetVector<const void *> usedConstraints;
  auto collectFrom = [&usedConstraints](auto &&values) {
    for (NamedTypeConstraint &value : values)
      if (value.hasPredicate())
        usedConstraints.insert(value.constraint.getAsOpaquePointer());
  };
  for (Record *def : opDefs) {
    Operator op(*def);
    collectFrom(op.getOperands());
    collectFrom(op.getResults());
  }

  FmtContext fctx;
  for (auto indexed : llvm::enumerate(usedConstraints)) {
    const void *opaqueConstraint = indexed.value();

    // The name combines a fixed prefix, the per-file label and the position of
    // the constraint in the collected set.
    std::string fnName = (Twine("__mlir_ods_local_type_constraint_") +
                          uniqueOutputLabel + Twine(indexed.index()))
                             .str();
    localTypeConstraints.try_emplace(opaqueConstraint, fnName);

    // The function bodies are only written when generating definitions.
    if (!emitDecl) {
      Constraint constraint =
          Constraint::getFromOpaquePointer(opaqueConstraint);
      os << "static ::mlir::LogicalResult " << fnName
         << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
            "valueKind, unsigned valueGroupStartIndex) {\n"
         << " if (!("
         << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
         << ")) {\n"
         << formatv(
                " return op->emitOpError(valueKind) << \" #\" << "
                "valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
                constraint.getSummary())
         << " }\n"
         << " return ::mlir::success();\n"
         << "}\n\n";
    }
  }
}
//===----------------------------------------------------------------------===//
// Utility structs and functions
//===----------------------------------------------------------------------===//
// Replaces all occurrences of `match` in `str` with `substitute` and returns
// the result. Scanning resumes right after each inserted `substitute`, so text
// introduced by a replacement is never matched again. An empty `match` leaves
// `str` unchanged.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  // An empty pattern is found at every position, so the scan below would never
  // terminate for it.
  if (match.empty())
    return str;
  for (std::string::size_type matchLoc = str.find(match);
       matchLoc != std::string::npos;
       matchLoc = str.find(match, matchLoc + substitute.size()))
    str.replace(matchLoc, match.size(), substitute);
  return str;
}
// Tells whether the value of field `fieldName` in `record` is a StringInit,
// i.e. whether it can be read with getValueAsString.
static inline bool hasStringAttribute(const Record &record,
                                      StringRef fieldName) {
  return isa<StringInit>(record.getValueInit(fieldName));
}
// Returns the name of operand `index` of `op`. An operand without a name gets
// the generated name "<generatedArgName>_<index>".
static std::string getArgumentName(const Operator &op, int index) {
  const auto &operand = op.getOperand(index);
  if (operand.name.empty())
    return std::string(formatv("{0}_{1}", generatedArgName, index));
  return std::string(operand.name);
}
//===----------------------------------------------------------------------===//
// Op emitter
//===----------------------------------------------------------------------===//
namespace {
// Helper class to emit a record into the given output stream.
class OpEmitter {
public:
static void
emitDecl(const Operator &op, raw_ostream &os,
const StaticVerifierFunctionEmitter &staticVerifierEmitter);
static void
emitDef(const Operator &op, raw_ostream &os,
const StaticVerifierFunctionEmitter &staticVerifierEmitter);
private:
OpEmitter(const Operator &op,
const StaticVerifierFunctionEmitter &staticVerifierEmitter);
void emitDecl(raw_ostream &os);
void emitDef(raw_ostream &os);
// Generates the OpAsmOpInterface for this operation if possible.
void genOpAsmInterface();
// Generates the `getOperationName` method for this op.
void genOpNameGetter();
// Generates getters for the attributes.
void genAttrGetters();
// Generates setter for the attributes.
void genAttrSetters();
// Generates removers for optional attributes.
void genOptionalAttrRemovers();
// Generates getters for named operands.
void genNamedOperandGetters();
// Generates setters for named operands.
void genNamedOperandSetters();
// Generates getters for named results.
void genNamedResultGetters();
// Generates getters for named regions.
void genNamedRegionGetters();
// Generates getters for named successors.
void genNamedSuccessorGetters();
// Generates builder methods for the operation.
void genBuilder();
// Generates the build() method that takes each operand/attribute
// as a stand-alone parameter.
void genSeparateArgParamBuilder();
// Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. The generated build() method uses first operand's
// type as all results' types.
2021-08-04 20:24:07 +08:00
// void genUseOperandAsResultTypeSeparateParamBuilder();
2021-07-23 11:38:34 +08:00
// Generates the build() method that takes all operands/attributes
// collectively as one parameter. The generated build() method uses first
// operand's type as all results' types.
2021-08-04 20:24:07 +08:00
// void genUseOperandAsResultTypeCollectiveParamBuilder();
2021-07-23 11:38:34 +08:00
// Generates the build() method that takes aggregate operands/attributes
// parameters. This build() method uses inferred types as result types.
// Requires: The type needs to be inferable via InferTypeOpInterface.
void genInferredTypeCollectiveParamBuilder();
// Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. The generated build() method uses first attribute's
// type as all result's types.
2021-08-04 20:24:07 +08:00
// void genUseAttrAsResultTypeBuilder();
2021-07-23 11:38:34 +08:00
// Generates the build() method that takes all result types collectively as
// one parameter. Similarly for operands and attributes.
2021-08-04 20:24:07 +08:00
// void genCollectiveParamBuilder();
2021-07-23 11:38:34 +08:00
// The kind of parameter to generate for result types in builders.
enum class TypeParamKind {
None, // No result type in parameter list.
Separate, // A separate parameter for each result type.
Collective, // An ArrayRef<Type> for all result types.
};
// The kind of parameter to generate for attributes in builders.
enum class AttrParamKind {
WrappedAttr, // A wrapped MLIR Attribute instance.
UnwrappedValue, // A raw value without MLIR Attribute wrapper.
};
// Builds the parameter list for build() method of this op. This method writes
// to `paramList` the comma-separated parameter list and updates
// `resultTypeNames` with the names for parameters for specifying result
// types. The given `typeParamKind` and `attrParamKind` controls how result
// types and attributes are placed in the parameter list.
void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> &paramList,
SmallVectorImpl<std::string> &resultTypeNames,
TypeParamKind typeParamKind,
AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
// Adds op arguments and regions into operation state for build() methods.
2021-08-11 10:46:07 +08:00
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
2021-07-23 11:38:34 +08:00
bool isRawValueAttr = false);
// Generates canonicalizer declaration for the operation.
void genCanonicalizerDecls();
// Generates the folder declaration for the operation.
void genFolderDecls();
// Generates the parser for the operation.
void genParser();
// Generates the printer for the operation.
void genPrinter();
// Generates verify method for the operation.
void genVerifier();
// Generates verify statements for operands and results in the operation.
// The generated code will be attached to `body`.
void genOperandResultVerifier(OpMethodBody &body,
Operator::value_range values,
StringRef valueKind);
// Generates verify statements for regions in the operation.
// The generated code will be attached to `body`.
void genRegionVerifier(OpMethodBody &body);
// Generates verify statements for successors in the operation.
// The generated code will be attached to `body`.
void genSuccessorVerifier(OpMethodBody &body);
// Generates the traits used by the object.
void genTraits();
// Generate the OpInterface methods for all interfaces.
void genOpInterfaceMethods();
// Generate op interface methods for the given interface.
void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);
// Generate op interface method for the given interface method. If
// 'declaration' is true, generates a declaration, else a definition.
OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
bool declaration = true);
// Generate the side effect interface methods.
void genSideEffectInterfaceMethods();
// Generate the type inference interface methods.
void genTypeInterfaceMethods();
2021-08-11 10:46:07 +08:00
Operator GetOp() { return op; }
private:
2021-07-23 11:38:34 +08:00
// The TableGen record for this op.
// TODO: OpEmitter should not have a Record directly,
// it should rather go through the Operator for better abstraction.
const Record &def;
// The wrapper operator class for querying information from this op.
Operator op;
// The C++ code builder for this op
OpClass opClass;
// The format context for verification code generation.
FmtContext verifyCtx;
// The emitter containing all of the locally emitted verification functions.
const StaticVerifierFunctionEmitter &staticVerifierEmitter;
};
} // end anonymous namespace
// Populate the format context `ctx` with substitutions of attributes, operands
// and results.
// - attrGet corresponds to the name of the function to call to get value of
// attribute (the generated function call returns an Attribute);
// - operandGet corresponds to the name of the function with which to retrieve
// an operand (the generated function call returns an OperandRange);
// - resultGet corresponds to the name of the function to get an result (the
// generated function call returns a ValueRange);
static void populateSubstitutions(const Operator &op, const char *attrGet,
const char *operandGet, const char *resultGet,
FmtContext &ctx) {
// Populate substitutions for attributes and named operands.
for (const auto &namedAttr : op.getAttributes())
ctx.addSubst(namedAttr.name,
formatv("{0}(\"{1}\")", attrGet, namedAttr.name));
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
auto &value = op.getOperand(i);
if (value.name.empty())
continue;
if (value.isVariadic())
ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i));
else
ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i));
}
// Populate substitutions for results.
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
auto &value = op.getResult(i);
if (value.name.empty())
continue;
if (value.isVariadic())
ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i));
else
ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i));
}
}
// Generate attribute verification. If emitVerificationRequiringOp is set then
// only verification for attributes whose value depend on op being known are
// emitted, else only verification that doesn't depend on the op being known are
// generated.
// - attrGet is the name of the function that the generated code calls with the
// attribute name to obtain the attribute;
// - emitErrorPrefix is the prefix for the error emitting call which consists
// of the entire function call up to start of error message fragment;
// - emitVerificationRequiringOp specifies whether verification should be
// emitted for verification that require the op to exist;
// - ctx is the format context used to expand the attribute predicates;
// - body receives the generated verification statements.
static void genAttributeVerifier(const Operator &op, const char *attrGet,
const Twine &emitErrorPrefix,
bool emitVerificationRequiringOp,
FmtContext &ctx, OpMethodBody &body) {
for (const auto &namedAttr : op.getAttributes()) {
const auto &attr = namedAttr.attr;
// Derived attributes are skipped entirely.
if (attr.isDerivedAttr())
continue;
auto attrName = namedAttr.name;
bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
auto attrPred = attr.getPredicate();
auto condition = attrPred.isNull() ? "" : attrPred.getCondition();
// There is a condition to emit only if the use of $_op and whether to
// emit verifications for op matches, i.e. the predicate mentions `$_op`
// exactly when emitVerificationRequiringOp is set.
bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^
emitVerificationRequiringOp);
// Prefix with `tblgen_` to avoid hiding the attribute accessor.
auto varName = tblgenNamePrefix + attrName;
// If the attribute is
// 1. Required (not allowed missing) and not in op verification, or
// 2. Has a condition that will get verified
// then the variable will be used.
//
// Therefore, for optional attributes whose verification requires that an
// op already exists for verification/emitVerificationRequiringOp is set
// has nothing that can be verified here.
if ((allowMissingAttr || emitVerificationRequiringOp) &&
!hasConditionToEmit)
continue;
// Open a scope and fetch the attribute into the local variable.
body << formatv(" {\n auto {0} = {1}(\"{2}\");\n", varName, attrGet,
attrName);
// A required attribute must be present; this check is only emitted in the
// pass that does not require the op.
if (!emitVerificationRequiringOp && !allowMissingAttr) {
body << " if (!" << varName << ") return " << emitErrorPrefix
<< "\"requires attribute '" << attrName << "'\");\n";
}
// Without a predicate to check in this pass, close the scope and move on.
if (!hasConditionToEmit) {
body << " }\n";
continue;
}
if (allowMissingAttr) {
// If the attribute has a default value, then only verify the predicate if
// set. This does effectively assume that the default value is valid.
// TODO: verify the default value is valid (perhaps in debug mode only).
body << " if (" << varName << ") {\n";
}
// Emit the predicate check with the local variable substituted for $_self.
body << tgfmt(" if (!($0)) return $1\"attribute '$2' "
"failed to satisfy constraint: $3\");\n",
/*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)),
emitErrorPrefix, attrName, attr.getSummary());
// Close the presence check opened above, then the attribute's scope.
if (allowMissingAttr)
body << " }\n";
body << " }\n";
}
}
// Sets up the emitter for `op` and generates the methods of the op class. Only
// the builder methods are generated here; the other generators stay disabled.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), op.getExtraClassDeclaration()),
      staticVerifierEmitter(staticVerifierEmitter) {
  // Substitutions for the op and the context in verification code.
  verifyCtx.withOp("(*this->getOperation())");
  verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");

  // Traits are not needed in the builder:
  //   genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file. Disabled generators that precede the
  // builder:
  //   genOpAsmInterface();
  //   genOpNameGetter();
  //   genNamedOperandGetters();
  //   genNamedOperandSetters();
  //   genNamedResultGetters();
  //   genNamedRegionGetters();
  //   genNamedSuccessorGetters();
  //   genAttrGetters();
  //   genAttrSetters();
  //   genOptionalAttrRemovers();
  genBuilder();
  // Disabled generators that follow the builder:
  //   genParser();
  //   genPrinter();
  //   genVerifier();
  //   genCanonicalizerDecls();
  //   genFolderDecls();
  //   genTypeInterfaceMethods();
  //   genOpInterfaceMethods();
  //   generateOpFormat(op, opClass);
  //   genSideEffectInterfaceMethods();
}
void OpEmitter::emitDecl(
const Operator &op, raw_ostream &os,
const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
OpEmitter(op, staticVerifierEmitter).emitDecl(os);
}
void OpEmitter::emitDef(
const Operator &op, raw_ostream &os,
const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
OpEmitter(op, staticVerifierEmitter).emitDef(os);
}
// Writes the declaration of the generated class to `os`.
void OpEmitter::emitDecl(raw_ostream &os) {
  opClass.writeDeclTo(os);
}

// Writes the definition of the generated class to `os`.
void OpEmitter::emitDef(raw_ostream &os) {
  opClass.writeDefTo(os);
}
void OpEmitter::genAttrGetters() {
FmtContext fctx;
fctx.withBuilder("::mlir::Builder((*this)->getContext())");
Dialect opDialect = op.getDialect();
// Emit the derived attribute body.
auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
2021-08-04 20:24:07 +08:00
auto *method = opClass.addMethodAndPrune(getReturnType(attr), name);
2021-07-23 11:38:34 +08:00
if (!method)
return;
auto &body = method->body();
body << " " << attr.getDerivedCodeBody() << "\n";
};
// Emit with return type specified.
auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
2021-08-04 20:24:07 +08:00
auto *method = opClass.addMethodAndPrune(getReturnType(attr), name);
2021-07-23 11:38:34 +08:00
auto &body = method->body();
body << " auto attr = " << name << "Attr();\n";
if (attr.hasDefaultValue()) {
// Returns the default value if not set.
// TODO: this is inefficient, we are recreating the attribute for every
// call. This should be set instead.
std::string defaultValue = std::string(
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
body << " if (!attr)\n return "
<< tgfmt(attr.getConvertFromStorageCall(),
&fctx.withSelf(defaultValue))
<< ";\n";
}
body << " return "
<< tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
<< ";\n";
};
// Generate raw named accessor type. This is a wrapper class that allows
// referring to the attributes via accessors instead of having to use
// the string interface for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
auto *method =
2021-08-04 20:24:07 +08:00
opClass.addMethodAndPrune(getStorageType(attr), (name + "Attr").str());
2021-07-23 11:38:34 +08:00
if (!method)
return;
auto &body = method->body();
body << " return (*this)->getAttr(\"" << name << "\").template ";
if (attr.isOptional() || attr.hasDefaultValue())
body << "dyn_cast_or_null<";
else
body << "cast<";
2021-08-04 20:24:07 +08:00
body << getStorageType(attr) << ">();";
2021-07-23 11:38:34 +08:00
};
for (auto &namedAttr : op.getAttributes()) {
const auto &name = namedAttr.name;
const auto &attr = namedAttr.attr;
if (attr.isDerivedAttr()) {
emitDerivedAttr(name, attr);
} else {
emitAttrWithStorageType(name, attr);
emitAttrWithReturnType(name, attr);
}
}
auto derivedAttrs = make_filter_range(op.getAttributes(),
[](const NamedAttribute &namedAttr) {
return namedAttr.attr.isDerivedAttr();
});
if (!derivedAttrs.empty()) {
opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
// Generate helper method to query whether a named attribute is a derived
// attribute. This enables, for example, avoiding adding an attribute that
// overlaps with a derived attribute.
{
auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute",
OpMethod::MP_Static,
"::llvm::StringRef", "name");
auto &body = method->body();
for (auto namedAttr : derivedAttrs)
body << " if (name == \"" << namedAttr.name << "\") return true;\n";
body << " return false;";
}
// Generate method to materialize derived attributes as a DictionaryAttr.
{
auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr",
"materializeDerivedAttributes");
auto &body = method->body();
auto nonMaterializable =
make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
return namedAttr.attr.getConvertFromStorageCall().empty();
});
if (!nonMaterializable.empty()) {
std::string attrs;
llvm::raw_string_ostream os(attrs);
interleaveComma(nonMaterializable, os,
[&](const NamedAttribute &attr) { os << attr.name; });
PrintWarning(
op.getLoc(),
formatv(
"op has non-materializable derived attributes '{0}', skipping",
os.str()));
body << formatv(" emitOpError(\"op has non-materializable derived "
"attributes '{0}'\");\n",
attrs);
body << " return nullptr;";
return;
}
body << " ::mlir::MLIRContext* ctx = getContext();\n";
body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
body << " return ::mlir::DictionaryAttr::get(";
body << " ctx, {\n";
interleave(
derivedAttrs, body,
[&](const NamedAttribute &namedAttr) {
auto tmpl = namedAttr.attr.getConvertFromStorageCall();
body << " {::mlir::Identifier::get(\"" << namedAttr.name
<< "\", ctx),\n"
<< tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
.withBuilder("odsBuilder")
.addSubst("_ctx", "ctx"))
<< "}";
},
",\n");
body << "});";
}
}
}
void OpEmitter::genAttrSetters() {
// Generate raw named setter type. This is a wrapper class that allows setting
// to the attributes via setters instead of having to use the string interface
// for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
2021-08-04 20:24:07 +08:00
getStorageType(attr), "attr");
2021-07-23 11:38:34 +08:00
if (!method)
return;
auto &body = method->body();
body << " (*this)->setAttr(\"" << name << "\", attr);";
};
for (auto &namedAttr : op.getAttributes()) {
const auto &name = namedAttr.name;
const auto &attr = namedAttr.attr;
if (!attr.isDerivedAttr())
emitAttrWithStorageType(name, attr);
}
}
void OpEmitter::genOptionalAttrRemovers() {
// Generate methods for removing optional attributes, instead of having to
// use the string interface. Enables better compile time verification.
auto emitRemoveAttr = [&](StringRef name) {
auto upperInitial = name.take_front().upper();
auto suffix = name.drop_front();
auto *method = opClass.addMethodAndPrune(
"::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
if (!method)
return;
auto &body = method->body();
body << " return (*this)->removeAttr(\"" << name << "\");";
};
for (const auto &namedAttr : op.getAttributes()) {
const auto &name = namedAttr.name;
const auto &attr = namedAttr.attr;
if (attr.isOptional())
emitRemoveAttr(name);
}
}
// Emits `methodName(unsigned index)`, a method returning a
// `std::pair<unsigned, unsigned>` for the operand or result group `index`.
// The emitted body depends on how group sizes are known:
//  * no variable-length group: the pair is `{index, 1}`;
//  * `hasAttrSegmentSize`: `sizeAttrInit` followed by the attribute-sized
//    segment calculation snippet;
//  * otherwise: the same-variadic-size calculation snippet, parameterized
//    with a per-group "is variable length" list built from `odsValues`.
template <typename RangeT>
static void
generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
                              int numVariadic, int numNonVariadic,
                              StringRef rangeSizeCall, bool hasAttrSegmentSize,
                              StringRef sizeAttrInit, RangeT &&odsValues) {
  auto *method = opClass.addMethodAndPrune("std::pair<unsigned, unsigned>",
                                           methodName, "unsigned", "index");
  if (!method)
    return;
  auto &out = method->body();

  if (numVariadic == 0) {
    out << " return {index, 1};\n";
    return;
  }

  if (hasAttrSegmentSize) {
    out << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
    return;
  }

  // Variadic and non-variadic values can be interleaved arbitrarily, so the
  // generated method carries a list of flags and computes the range at
  // run-time.
  llvm::SmallVector<StringRef, 4> variadicFlags;
  variadicFlags.reserve(llvm::size(odsValues));
  for (auto &value : odsValues) {
    if (value.isVariableLength())
      variadicFlags.push_back("true");
    else
      variadicFlags.push_back("false");
  }
  std::string flagList = llvm::join(variadicFlags, ", ");
  out << formatv(sameVariadicSizeValueRangeCalcCode, flagList, numNonVariadic,
                 numVariadic, rangeSizeCall, "operand");
}
// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`. Uses `rangeType` as the return type of getters that
// return a range of operands (individual operands are `Value ` and each
// element in the range must also be `Value `); use `rangeBeginCall` to get
// an iterator to the beginning of the operand range; use `rangeSizeCall` to
// obtain the number of operands. `getOperandCallPattern` contains the code
// necessary to obtain a single operand whose position will be substituted
// instead of
// "{0}" marker in the pattern. Note that the pattern should work for any kind
// of ops, in particular for one-operand ops that may not have the
// `getOperand(unsigned)` method.
static void generateNamedOperandGetters(const Operator &op, Class &opClass,
StringRef sizeAttrInit,
StringRef rangeType,
StringRef rangeBeginCall,
StringRef rangeSizeCall,
StringRef getOperandCallPattern) {
const int numOperands = op.getNumOperands();
const int numVariadicOperands = op.getNumVariableLengthOperands();
const int numNormalOperands = numOperands - numVariadicOperands;
const auto *sameVariadicSize =
op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
const auto *attrSizedOperands =
op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
"specification over their sizes");
}
if (numVariadicOperands < 2 && attrSizedOperands) {
PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
"to use 'AttrSizedOperandSegments' trait");
}
if (attrSizedOperands && sameVariadicSize) {
PrintFatalError(op.getLoc(),
"op cannot have both 'AttrSizedOperandSegments' and "
"'SameVariadicOperandSize' traits");
}
// First emit a few "sink" getter methods upon which we layer all nicer named
// getter methods.
generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
numVariadicOperands, numNormalOperands,
rangeSizeCall, attrSizedOperands, sizeAttrInit,
const_cast<Operator &>(op).getOperands());
auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned",
"index");
auto &body = m->body();
body << formatv(valueRangeReturnCode, rangeBeginCall,
"getODSOperandIndexAndLength(index)");
// Then we emit nicer named getter methods by redirecting to the "sink" getter
// method.
for (int i = 0; i != numOperands; ++i) {
const auto &operand = op.getOperand(i);
if (operand.name.empty())
continue;
if (operand.isOptional()) {
m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
m->body()
<< " auto operands = getODSOperands(" << i << ");\n"
<< " return operands.empty() ? ::mlir::Value() : *operands.begin();";
} else if (operand.isVariadic()) {
m = opClass.addMethodAndPrune(rangeType, operand.name);
m->body() << " return getODSOperands(" << i << ");";
} else {
m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
m->body() << " return *getODSOperands(" << i << ").begin();";
}
}
}
void OpEmitter::genNamedOperandGetters() {
generateNamedOperandGetters(
op, opClass,
/*sizeAttrInit=*/
formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(),
/*rangeType=*/"::mlir::Operation::operand_range",
/*rangeBeginCall=*/"getOperation()->operand_begin()",
/*rangeSizeCall=*/"getOperation()->getNumOperands()",
/*getOperandCallPattern=*/"getOperation()->getOperand({0})");
}
void OpEmitter::genNamedOperandSetters() {
auto *attrSizedOperands =
op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
const auto &operand = op.getOperand(i);
if (operand.name.empty())
continue;
auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange",
(operand.name + "Mutable").str());
auto &body = m->body();
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
<< " return ::mlir::MutableOperandRange(getOperation(), "
"range.first, range.second";
if (attrSizedOperands)
body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
<< "u, *getOperation()->getAttrDictionary().getNamed("
"\"operand_segment_sizes\"))";
body << ");\n";
}
}
// Emits the result accessors of the op:
//  * `getODSResultIndexAndLength(index)` and `getODSResults(index)`, the
//    "sink" methods locating a result group;
//  * one getter per named result: `::mlir::Value` for single and optional
//    results (a null Value when an optional result is absent), and
//    `::mlir::Operation::result_range` for variadic results.
//
// Reports a fatal error if the op's variadic results are not sized
// consistently with its traits, or if a generated method overlaps with one
// that already exists in the class (addMethodAndPrune returns null).
void OpEmitter::genNamedResultGetters() {
  const int numResults = op.getNumResults();
  const int numVariadicResults = op.getNumVariableLengthResults();
  const int numNormalResults = numResults - numVariadicResults;

  // With more than one variadic result, a trait must say how the value range
  // of each result is calculated.
  const auto *sameVariadicSize =
      op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
  const auto *attrSizedResults =
      op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");

  if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
    PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
                                 "specification over their sizes");
  }

  if (numVariadicResults < 2 && attrSizedResults) {
    PrintFatalError(op.getLoc(), "op must have at least two variadic results "
                                 "to use 'AttrSizedResultSegments' trait");
  }

  if (attrSizedResults && sameVariadicSize) {
    PrintFatalError(op.getLoc(),
                    "op cannot have both 'AttrSizedResultSegments' and "
                    "'SameVariadicResultSize' traits");
  }

  generateValueRangeStartAndEnd(
      opClass, "getODSResultIndexAndLength", numVariadicResults,
      numNormalResults, "getOperation()->getNumResults()",
      attrSizedResults != nullptr,
      formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(),
      op.getResults());

  auto *rangeGetter =
      opClass.addMethodAndPrune("::mlir::Operation::result_range",
                                "getODSResults", "unsigned", "index");
  // A pruned method would otherwise be dereferenced below and crash the
  // generator without any diagnostic.
  if (!rangeGetter)
    PrintFatalError(op.getLoc(),
                    "unexpected overlap when generating `getODSResults`");
  rangeGetter->body() << formatv(valueRangeReturnCode,
                                 "getOperation()->result_begin()",
                                 "getODSResultIndexAndLength(index)");

  for (int i = 0; i != numResults; ++i) {
    const auto &result = op.getResult(i);
    if (result.name.empty())
      continue;

    // Optional takes precedence over variadic; only a non-optional variadic
    // result is exposed as a range.
    const bool isOptional = result.isOptional();
    const bool returnsRange = !isOptional && result.isVariadic();

    auto *getter = opClass.addMethodAndPrune(
        returnsRange ? "::mlir::Operation::result_range" : "::mlir::Value",
        result.name);
    if (!getter)
      PrintFatalError(op.getLoc(), "unexpected overlap when generating `" +
                                       result.name + "` result getter");

    if (isOptional) {
      getter->body()
          << " auto results = getODSResults(" << i << ");\n"
          << " return results.empty() ? ::mlir::Value() : *results.begin();";
    } else if (returnsRange) {
      getter->body() << " return getODSResults(" << i << ");";
    } else {
      getter->body() << " return *getODSResults(" << i << ").begin();";
    }
  }
}
void OpEmitter::genNamedRegionGetters() {
unsigned numRegions = op.getNumRegions();
for (unsigned i = 0; i < numRegions; ++i) {
const auto &region = op.getRegion(i);
if (region.name.empty())
continue;
// Generate the accessors for a variadic region.
if (region.isVariadic()) {
auto *m = opClass.addMethodAndPrune(
"::mlir::MutableArrayRef<::mlir::Region>", region.name);
m->body() << formatv(" return (*this)->getRegions().drop_front({0});",
i);
continue;
}
auto *m = opClass.addMethodAndPrune("::mlir::Region &", region.name);
m->body() << formatv(" return (*this)->getRegion({0});", i);
}
}
void OpEmitter::genNamedSuccessorGetters() {
unsigned numSuccessors = op.getNumSuccessors();
for (unsigned i = 0; i < numSuccessors; ++i) {
const NamedSuccessor &successor = op.getSuccessor(i);
if (successor.name.empty())
continue;
// Generate the accessors for a variadic successor list.
if (successor.isVariadic()) {
auto *m =
opClass.addMethodAndPrune("::mlir::SuccessorRange", successor.name);
m->body() << formatv(
" return {std::next((*this)->successor_begin(), {0}), "
"(*this)->successor_end()};",
i);
continue;
}
auto *m = opClass.addMethodAndPrune("::mlir::Block *", successor.name);
m->body() << formatv(" return (*this)->getSuccessor({0});", i);
}
}
// Returns true if `op` carries the InferTypeOpInterface trait and has no
// regions.
static bool canInferType(Operator &op) {
  if (op.getNumRegions() != 0)
    return false;
  return op.getTrait("::mlir::InferTypeOpInterface::Trait") != nullptr;
}
void OpEmitter::genSeparateArgParamBuilder() {
SmallVector<AttrParamKind, 2> attrBuilderType;
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
// Emit with separate builders with or without unwrapped attributes and/or
// inferring result type.
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
bool inferType) {
llvm::SmallVector<OpMethodParameter, 4> paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, paramKind, attrType);
2021-08-04 20:24:07 +08:00
auto *m = opClass.addMethodAndPrune(
2021-08-13 15:05:10 +08:00
"builder::Op", "build", OpMethod::MP_Static, std::move(paramList));
2021-07-23 11:38:34 +08:00
// If the builder is redundant, skip generating the method.
if (!m)
return;
auto &body = m->body();
genCodeForAddingArgAndRegionForBuilder(
2021-08-11 10:46:07 +08:00
body, attrType == AttrParamKind::UnwrappedValue);
2021-07-23 11:38:34 +08:00
// Push all result types to the operation state
2021-08-11 10:46:07 +08:00
2021-08-04 20:24:07 +08:00
// if (inferType) {
// // Generate builder that infers type too.
// // TODO: Subsume this with general checking if type can be
// // inferred automatically.
// // TODO: Expand to handle regions.
// body << formatv(R"(
// ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
// if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
// {1}.location, {1}.operands,
// {1}.attributes.getDictionary({1}.getContext()),
// /*regions=*/{{}, inferredReturnTypes)))
// {1}.addTypes(inferredReturnTypes);
// else
// ::llvm::report_fatal_error("Failed to infer result type(s).");)",
// opClass.getClassName(), builderOpState);
// return;
// }
// switch (paramKind) {
// case TypeParamKind::None:
// return;
// case TypeParamKind::Separate:
// for (int i = 0, e = op.getNumResults(); i < e; ++i) {
// if (op.getResult(i).isOptional())
// body << " if (" << resultNames[i] << ")\n ";
// body << " " << builderOpState << ".addTypes(" << resultNames[i]
// << ");\n";
// }
// return;
// case TypeParamKind::Collective: {
// int numResults = op.getNumResults();
// int numVariadicResults = op.getNumVariableLengthResults();
// int numNonVariadicResults = numResults - numVariadicResults;
// bool hasVariadicResult = numVariadicResults != 0;
// // Avoid emitting "resultTypes.size() >= 0u" which is always true.
// if (!(hasVariadicResult && numNonVariadicResults == 0))
// body << " "
// << "assert(resultTypes.size() "
// << (hasVariadicResult ? ">=" : "==") << " "
// << numNonVariadicResults
// << "u && \"mismatched number of results\");\n";
// body << " " << builderOpState << ".addTypes(resultTypes);\n";
// }
// return;
// }
// llvm_unreachable("unhandled TypeParamKind");
2021-07-23 11:38:34 +08:00
};
// Some of the build methods generated here may be ambiguous, but TableGen's
// ambiguous function detection will elide those ones.
for (auto attrType : attrBuilderType) {
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
2021-08-04 20:24:07 +08:00
// if (canInferType(op))
// emit(attrType, TypeParamKind::None, /*inferType=*/true);
// emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
2021-07-23 11:38:34 +08:00
}
}
/// Returns a signature of the builder. Updates the context `fctx` to enable
/// replacement of $_builder and $_state in the body.
static std::string getBuilderSignature(const Builder &builder) {
ArrayRef<Builder::Parameter> params(builder.getParameters());
// Inject builder and state arguments.
llvm::SmallVector<std::string, 8> arguments;
arguments.reserve(params.size() + 2);
arguments.push_back(
llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str());
arguments.push_back(
llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());
for (unsigned i = 0, e = params.size(); i < e; ++i) {
// If no name is provided, generate one.
Optional<StringRef> paramName = params[i].getName();
std::string name =
paramName ? paramName->str() : "odsArg" + std::to_string(i);
std::string defaultValue;
if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str();
arguments.push_back(
llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue)
.str());
}
return llvm::join(arguments, ", ");
}
void OpEmitter::genBuilder() {
// Handle custom builders if provided.
2021-08-04 20:24:07 +08:00
// for (const Builder &builder : op.getBuilders()) {
// std::string paramStr = getBuilderSignature(builder);
// Optional<StringRef> body = builder.getBody();
// OpMethod::Property properties =
// body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
// auto *method =
// opClass.addMethodAndPrune("void", "build", properties, paramStr);
// FmtContext fctx;
// fctx.withBuilder(odsBuilder);
// fctx.addSubst("_state", builderOpState);
// if (body)
// method->body() << tgfmt(*body, &fctx);
// }
2021-07-23 11:38:34 +08:00
// Generate default builders that requires all result type, operands, and
// attributes as parameters.
if (op.skipDefaultBuilders())
return;
// We generate three classes of builders here:
// 1. one having a stand-alone parameter for each operand / attribute, and
genSeparateArgParamBuilder();
// 2. one having an aggregated parameter for all result types / operands /
// attributes, and
2021-08-04 20:24:07 +08:00
// genCollectiveParamBuilder();
// // 3. one having a stand-alone parameter for each operand and attribute,
// // use the first operand or attribute's type as all result types
// // to facilitate different call patterns.
// if (op.getNumVariableLengthResults() == 0) {
// if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
// genUseOperandAsResultTypeSeparateParamBuilder();
// genUseOperandAsResultTypeCollectiveParamBuilder();
// }
// if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
// genUseAttrAsResultTypeBuilder();
// }
2021-07-23 11:38:34 +08:00
}
// Fills `paramList` with the parameters of a generated `build` method and
// `resultTypeNames` with the names of its result type parameters.
//
// The parameters are, in order:
//  1. `builder::Builder &builder`;
//  2. the result types, according to `typeParamKind`: none, one parameter per
//     result (`builder::Type`, or `std::vector<builder::Type>` for variadic
//     results), or a single collective `::mlir::TypeRange resultTypes`;
//  3. one parameter per op argument, in declaration order: operands as
//     `builder::Op` (`std::vector<builder::Op>` when variadic), attributes
//     with the type returned by `getStorageType`;
//  4. one parameter per successor;
//  5. one `unsigned <name>Count` parameter per variadic region.
//
// Optional results, operands and attributes are marked PP_Optional.
void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
                               SmallVectorImpl<std::string> &resultTypeNames,
                               TypeParamKind typeParamKind,
                               AttrParamKind attrParamKind) {
  resultTypeNames.clear();
  auto numResults = op.getNumResults();
  resultTypeNames.reserve(numResults);

  // The generated builders take the `builder::Builder` wrapper instead of an
  // ::mlir::OpBuilder / ::mlir::OperationState pair.
  paramList.emplace_back("builder::Builder &", "builder");

  switch (typeParamKind) {
  case TypeParamKind::None:
    break;
  case TypeParamKind::Separate: {
    // Add parameters for all return types.
    for (int i = 0; i < numResults; ++i) {
      const auto &result = op.getResult(i);
      std::string resultName = std::string(result.name);
      if (resultName.empty())
        resultName = std::string(formatv("resultType{0}", i));

      StringRef type =
          result.isVariadic() ? "std::vector<builder::Type>" : "builder::Type";
      OpMethodParameter::Property properties = OpMethodParameter::PP_None;
      if (result.isOptional())
        properties = OpMethodParameter::PP_Optional;

      paramList.emplace_back(type, resultName, properties);
      resultTypeNames.emplace_back(std::move(resultName));
    }
  } break;
  case TypeParamKind::Collective: {
    paramList.emplace_back("::mlir::TypeRange", "resultTypes");
    resultTypeNames.push_back("resultTypes");
  } break;
  }

  // Add parameters for all arguments (operands and attributes).
  int numOperands = 0;
  int numAttrs = 0;

  // NOTE(review): this index is never lowered below the argument count, so
  // the default-value branch further down never attaches a default value.
  // Confirm whether that is intended before relying on default values.
  int defaultValuedAttrStartIndex = op.getNumArgs();

  for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
    auto argument = op.getArg(i);
    if (argument.is<tblgen::NamedTypeConstraint *>()) {
      const auto &operand = op.getOperand(numOperands);
      StringRef type =
          operand.isVariadic() ? "std::vector<builder::Op>" : "builder::Op";
      OpMethodParameter::Property properties = OpMethodParameter::PP_None;
      if (operand.isOptional())
        properties = OpMethodParameter::PP_Optional;

      paramList.emplace_back(type, getArgumentName(op, numOperands),
                             properties);
      ++numOperands;
    } else {
      const auto &namedAttr = op.getAttribute(numAttrs);
      const auto &attr = namedAttr.attr;

      OpMethodParameter::Property properties = OpMethodParameter::PP_None;
      if (attr.isOptional())
        properties = OpMethodParameter::PP_Optional;

      // Hold the type in an owned string: a StringRef here would dangle if
      // getStorageType hands back a temporary string.
      std::string type;
      switch (attrParamKind) {
      case AttrParamKind::WrappedAttr:
      case AttrParamKind::UnwrappedValue:
        // Raw (unwrapped) value parameter types are not produced by this
        // generator, so both kinds use the storage type. This guarantees the
        // parameter is never emitted without a type.
        type = std::string(getStorageType(attr));
        break;
      }

      std::string defaultValue;
      // Attach default value if requested and possible.
      if (attrParamKind == AttrParamKind::UnwrappedValue &&
          i >= defaultValuedAttrStartIndex) {
        bool isString = getReturnType(attr) == "::llvm::StringRef";
        if (isString)
          defaultValue.append("\"");
        defaultValue += attr.getDefaultValue();
        if (isString)
          defaultValue.append("\"");
      }
      paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
      ++numAttrs;
    }
  }

  /// Insert parameters for each successor.
  for (const NamedSuccessor &succ : op.getSuccessors()) {
    StringRef type =
        succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
    paramList.emplace_back(type, succ.name);
  }

  /// Insert parameters for variadic regions.
  for (const NamedRegion &region : op.getRegions())
    if (region.isVariadic())
      paramList.emplace_back("unsigned",
                             llvm::formatv("{0}Count", region.name).str());
}
2021-08-11 10:46:07 +08:00
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
bool isRawValueAttr) {
auto op = GetOp();
auto operands = op.getOperands();
auto attrs = op.getAttributes();
2021-08-16 15:35:37 +08:00
auto numResults = op.getNumResults();
auto numOperands = op.getNumOperands();
2021-08-11 10:46:07 +08:00
SmallVector<std::string, 4> newAttrs;
2021-08-04 20:24:07 +08:00
2021-08-13 15:05:10 +08:00
// if (attrType == "::mlir::DenseIntElementsAttr") {
2021-08-16 15:35:37 +08:00
// body << " // BBBBBBBB getStorageType:" <<
// a.attr.getStorageType().str()
2021-08-13 15:05:10 +08:00
// << "\n";
// body << " // BBBBBBBB getReturnType:" << a.attr.getReturnType().str()
// << "\n";
2021-08-11 10:46:07 +08:00
// }
2021-08-16 15:35:37 +08:00
// for(const auto& a : op.getArgs()){
// body << " // BBBBBBBB argument.is:"
// << a.is<tblgen::NamedTypeConstraint *>() << "\n";
// }
// // if (argument.is<tblgen::NamedTypeConstraint *>())
body << " auto bPtr = builder.GetImpl();\n";
body << " auto loc = bPtr->GetLoc();\n";
body << " auto opBuilder = bPtr->GetBuilder();\n";
body << " auto ctx = bPtr->GetContext();\n";
2021-08-11 10:46:07 +08:00
for (auto a : attrs) {
std::string attrType = a.attr.getStorageType().str();
auto typePair = typeMapMLIR.find(attrType);
if (typePair != typeMapMLIR.end()) {
std::string attrName = a.name.str();
auto mlirName = typePair->second.ConvertToMlir(attrName, body);
newAttrs.emplace_back(mlirName);
}
}
2021-08-16 15:35:37 +08:00
for (int i = 0; i < numResults; ++i) {
const auto &result = op.getResult(i);
std::string resultName = std::string(result.name);
if (resultName.empty())
resultName = std::string(formatv("resultType{0}", i));
bool isVec = result.isVariadic();
if (isVec) {
body << " std::vector<mlir::Type> " << resultName << "_v;\n";
body << " for(auto r : " << resultName << "){\n " << resultName
<< "_v.push_back(r.GetImpl()->GetMlirType(ctx));\n }"
<< "\n";
}
}
for (int i = 0; i < numOperands; i++) {
const auto &v = op.getOperand(i);
2021-08-13 15:05:10 +08:00
std::string name =
2021-08-16 15:35:37 +08:00
v.name.empty() ? "odsArg" + std::to_string(i) : v.name.str();
2021-08-13 15:05:10 +08:00
if (v.isVariadic()) {
body << " std::vector<mlir::Value> " << name << "_v;\n";
body << " for(auto v : " << name << "){\n " << name
<< "_v.push_back(v.GetImpl()->GetResult());\n }"
<< "\n";
}
}
2021-08-04 20:24:07 +08:00
body << " mlir::" << op.getDialectName() << "::" << op.getCppClassName()
<< " currentOp =\n";
body << " opBuilder.create<mlir::" << op.getDialectName()
2021-08-13 15:05:10 +08:00
<< "::" << op.getCppClassName() << ">(\n";
2021-08-11 10:46:07 +08:00
body << " loc";
2021-08-13 15:05:10 +08:00
2021-08-16 15:35:37 +08:00
for (int i = 0; i < numResults; ++i) {
const auto &result = op.getResult(i);
std::string resultName = std::string(result.name);
if (resultName.empty())
resultName = std::string(formatv("resultType{0}", i));
bool isVec = result.isVariadic();
if (isVec) {
body << ",\n mlir::TypeRange(" << resultName << "_v)";
2021-08-13 15:05:10 +08:00
} else {
2021-08-16 15:35:37 +08:00
body << ",\n " << resultName << ".GetImpl()->GetMlirType(ctx)";
2021-08-13 15:05:10 +08:00
}
2021-08-16 15:35:37 +08:00
}
int operandIndex = 0;
int attributeIndex = 0;
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
// if true Operands else Attribute
if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &v = op.getOperand(operandIndex);
std::string name = v.name.empty()
? "odsArg_" + std::to_string(operandIndex)
: v.name.str();
if (v.isVariadic()) {
body << ",\n mlir::ValueRange(" << name << "_v)";
} else {
body << ",\n " << name << ".GetImpl()->GetResult()";
}
operandIndex++;
} else {
body << ",\n " << newAttrs[attributeIndex];
attributeIndex++;
}
}
for (const NamedRegion &region : op.getRegions())
if (region.isVariadic())
body << ",\n " << llvm::formatv("{0}Count", region.name).str();
2021-08-11 10:46:07 +08:00
body << "\n );\n";
2021-08-13 15:05:10 +08:00
body << " builder::mhlo::" << op.getCppClassName() << " builderOp;\n";
2021-08-04 20:24:07 +08:00
body << " auto opImpl = builderOp.GetImpl();\n";
2021-08-13 15:05:10 +08:00
body << " opImpl->SetOperation(currentOp.getOperation());\n";
2021-08-04 20:24:07 +08:00
body << " return builderOp;\n";
2021-07-23 11:38:34 +08:00
2021-08-04 20:24:07 +08:00
// // Push all operands to the result.
// for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
// std::string argName = getArgumentName(op, i);
// if (op.getOperand(i).isOptional())
// body << " if (" << argName << ")\n ";
// body << " " << builderOpState << ".addOperands(" << argName << ");\n";
// }
// // If the operation has the operand segment size attribute, add it here.
// if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
// body << " " << builderOpState
// << ".addAttribute(\"operand_segment_sizes\", "
// "odsBuilder.getI32VectorAttr({";
2021-08-16 15:35:37 +08:00
// interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i)
// {
2021-08-04 20:24:07 +08:00
// if (op.getOperand(i).isOptional())
// body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
// else if (op.getOperand(i).isVariadic())
2021-08-16 15:35:37 +08:00
// body << "static_cast<int32_t>(" << getArgumentName(op, i) <<
// ".size())";
2021-08-04 20:24:07 +08:00
// else
// body << "1";
// });
// body << "}));\n";
// }
// // Push all attributes to the result.
// for (const auto &namedAttr : op.getAttributes()) {
// auto &attr = namedAttr.attr;
// if (!attr.isDerivedAttr()) {
// bool emitNotNullCheck = attr.isOptional();
// if (emitNotNullCheck) {
// body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
// }
// if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
// // If this is a raw value, then we need to wrap it in an Attribute
// // instance.
// FmtContext fctx;
// fctx.withBuilder("odsBuilder");
// std::string builderTemplate =
// std::string(attr.getConstBuilderTemplate());
// // For StringAttr, its constant builder call will wrap the input in
// // quotes, which is correct for normal string literals, but incorrect
// // here given we use function arguments. So we need to strip the
// // wrapping quotes.
// if (StringRef(builderTemplate).contains("\"$0\""))
2021-08-16 15:35:37 +08:00
// builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"",
// "$0");
2021-08-04 20:24:07 +08:00
// std::string value =
// std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
2021-08-16 15:35:37 +08:00
// body << formatv(" {0}.addAttribute(\"{1}\", {2});\n",
// builderOpState,
2021-08-04 20:24:07 +08:00
// namedAttr.name, value);
// } else {
2021-08-16 15:35:37 +08:00
// body << formatv(" {0}.addAttribute(\"{1}\", {1});\n",
// builderOpState,
2021-08-04 20:24:07 +08:00
// namedAttr.name);
// }
// if (emitNotNullCheck) {
// body << " }\n";
// }
// }
// }
// // Create the correct number of regions.
// for (const NamedRegion &region : op.getRegions()) {
// if (region.isVariadic())
// body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
// region.name);
// body << " (void)" << builderOpState << ".addRegion();\n";
// }
// // Push all successors to the result.
// for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
// body << formatv(" {0}.addSuccessors({1});\n", builderOpState,
// namedSuccessor.name);
// }
2021-07-23 11:38:34 +08:00
}
// Declares the canonicalization hooks requested by the op definition:
//  * with `hasCanonicalizeMethod`, a static declaration
//      ::mlir::LogicalResult canonicalize(FooOp op,
//                                         ::mlir::PatternRewriter &rewriter);
//  * with `hasCanonicalizeMethod` or `hasCanonicalizer`, a static
//    `getCanonicalizationPatterns(results, context)`. It is only declared
//    when the op sets `hasCanonicalizer`; otherwise it is synthesized with a
//    body that registers `canonicalize`.
void OpEmitter::genCanonicalizerDecls() {
  const bool hasCanonicalizeMethod =
      def.getValueAsBit("hasCanonicalizeMethod");
  if (hasCanonicalizeMethod) {
    SmallVector<OpMethodParameter, 2> paramList;
    paramList.emplace_back(op.getCppClassName(), "op");
    paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
    opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize",
                              OpMethod::MP_StaticDeclaration,
                              std::move(paramList));
  }

  // We get a prototype for 'getCanonicalizationPatterns' if requested directly
  // or if using a 'canonicalize' method.
  const bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer");
  if (!hasCanonicalizeMethod && !hasCanonicalizer)
    return;

  // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
  // method, but not implementing 'getCanonicalizationPatterns' manually.
  const bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;

  // Add a signature for getCanonicalizationPatterns if implemented by the
  // dialect or if synthesized to call 'canonicalize'.
  SmallVector<OpMethodParameter, 2> paramList;
  paramList.emplace_back("::mlir::RewritePatternSet &", "results");
  paramList.emplace_back("::mlir::MLIRContext *", "context");
  auto kind = hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
  auto *method = opClass.addMethodAndPrune(
      "void", "getCanonicalizationPatterns", kind, std::move(paramList));

  // If synthesizing the method, fill it in. The method is null when it was
  // pruned, in which case there is no body to fill.
  if (hasBody && method)
    method->body() << " results.add(canonicalize);\n";
}
void OpEmitter::genFolderDecls() {
bool hasSingleResult =
op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
if (def.getValueAsBit("hasFolder")) {
if (hasSingleResult) {
opClass.addMethodAndPrune(
"::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
"::llvm::ArrayRef<::mlir::Attribute>", "operands");
} else {
SmallVector<OpMethodParameter, 2> paramList;
paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
"results");
opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
OpMethod::MP_Declaration, std::move(paramList));
}
}
}
// Declares on the op class the methods of the interface behind `opTrait`.
// A method is declared only if the interface gives it no body, and either it
// has no default implementation or the trait lists it among the methods that
// must always be declared.
void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
  Interface interface = opTrait->getInterface();

  // Collect the names of the methods that should always be declared.
  llvm::StringSet<> forcedDeclarations;
  auto forcedNames = opTrait->getAlwaysDeclaredMethods();
  forcedDeclarations.insert(forcedNames.begin(), forcedNames.end());

  for (const InterfaceMethod &method : interface.getMethods()) {
    // A method with a body is never declared.
    if (method.getBody())
      continue;

    // A default implementation suppresses the declaration unless the op
    // explicitly asked for it.
    const bool hasDefault = static_cast<bool>(method.getDefaultImplementation());
    if (hasDefault && forcedDeclarations.count(method.getName()) == 0)
      continue;

    genOpInterfaceMethod(method);
  }
}
// Adds the interface method `method` to the op class, mirroring its return
// type, name and arguments. The method is static when the interface method
// is, and is emitted as a declaration when `declaration` is set. Returns the
// result of addMethodAndPrune.
OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
                                          bool declaration) {
  SmallVector<OpMethodParameter, 4> params;
  for (const InterfaceMethod::Argument &argument : method.getArguments())
    params.emplace_back(argument.type, argument.name);

  unsigned flags = OpMethod::MP_None;
  if (method.isStatic())
    flags = OpMethod::MP_Static;
  if (declaration)
    flags |= OpMethod::MP_Declaration;
  auto properties = static_cast<OpMethod::Property>(flags);

  return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
                                   properties, std::move(params));
}
/// Walks the op's traits and generates interface methods for every interface
/// trait that reports its methods should be declared.
void OpEmitter::genOpInterfaceMethods() {
  for (const auto &trait : op.getTraits()) {
    const auto *interfaceTrait = dyn_cast<tblgen::InterfaceTrait>(&trait);
    if (!interfaceTrait || !interfaceTrait->shouldDeclareMethods())
      continue;
    genOpInterfaceMethods(interfaceTrait);
  }
}
// Generates one `getEffects` method per side-effect interface used by the op.
//
// Effects are gathered from three places: side-effect traits on the op
// (recorded as static, i.e. not tied to a value), decorators on the op's
// arguments (operands, and attributes whose base attribute is a symbol
// reference), and decorators on the op's results. They are grouped by base
// effect name, and each group yields one `getEffects` method whose body
// appends an effect instance for every recorded location.
void OpEmitter::genSideEffectInterfaceMethods() {
enum EffectKind { Operand, Result, Symbol, Static };
struct EffectLocation {
/// The effect applied.
SideEffect effect;
/// The index if the kind is not static.
unsigned index : 30;
/// The kind of the location.
unsigned kind : 2;
};
// Recorded effect locations, keyed by base effect name.
StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
// Records every side-effect decorator in `decorators` at (`index`, `kind`)
// and adds the effect's interface trait to the op class. Decorators that are
// not side effects are ignored.
auto resolveDecorators = [&](Operator::var_decorator_range decorators,
unsigned index, unsigned kind) {
for (auto decorator : decorators)
if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
opClass.addTrait(effect->getInterfaceTrait());
interfaceEffects[effect->getBaseEffectName()].push_back(
EffectLocation{*effect, index, kind});
}
};
// Collect effects that were specified via:
/// Traits.
for (const auto &trait : op.getTraits()) {
const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
if (!opTrait)
continue;
auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
for (auto decorator : opTrait->getEffects())
effects.push_back(EffectLocation{cast<SideEffect>(decorator),
/*index=*/0, EffectKind::Static});
}
/// Attributes and Operands.
// `i` walks all arguments while `operandIt` counts only operands, so operand
// effects are indexed by operand position and symbol effects by argument
// position.
for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
Argument arg = op.getArg(i);
if (arg.is<NamedTypeConstraint *>()) {
resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
++operandIt;
continue;
}
// Only symbol reference attributes can carry effects; decorators on other
// attributes are not inspected.
const NamedAttribute *attr = arg.get<NamedAttribute *>();
if (attr->attr.getBaseAttr().isSymbolRefAttr())
resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
}
/// Results.
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
// The code used to add an effect instance.
// {0}: The effect class.
// {1}: Optional value or symbol reference (including its trailing ", ").
// {2}: The resource class.
const char *addEffectCode =
" effects.emplace_back({0}::get(), {1}{2}::get());\n";
for (auto &it : interfaceEffects) {
// Generate the 'getEffects' method.
std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
"SideEffects::EffectInstance<{0}>> &",
it.first())
.str();
auto *getEffects =
opClass.addMethodAndPrune("void", "getEffects", type, "effects");
auto &body = getEffects->body();
// Add effect instances for each of the locations marked on the operation.
for (auto &location : it.second) {
StringRef effect = location.effect.getName();
StringRef resource = location.effect.getResource();
if (location.kind == EffectKind::Static) {
// A static instance has no attached value.
body << llvm::formatv(addEffectCode, effect, "", resource).str();
} else if (location.kind == EffectKind::Symbol) {
// A symbol reference requires adding the proper attribute.
const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
if (attr->attr.isOptional()) {
// An optional attribute may be absent, so the emitted code guards the
// effect behind a check of the `<name>Attr()` accessor.
body << " if (auto symbolRef = " << attr->name << "Attr())\n "
<< llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
.str();
} else {
body << llvm::formatv(addEffectCode, effect, attr->name + "(), ",
resource)
.str();
}
} else {
// Otherwise this is an operand/result, so we need to attach the Value.
// The emitted loop adds one effect per value in the ODS group.
body << " for (::mlir::Value value : getODS"
<< (location.kind == EffectKind::Operand ? "Operands" : "Results")
<< "(" << location.index << "))\n "
<< llvm::formatv(addEffectCode, effect, "value, ", resource).str();
}
}
}
}
void OpEmitter::genTypeInterfaceMethods() {
if (!op.allResultTypesKnown())
return;
// Generate 'inferReturnTypes' method declaration using the interface method
// declared in 'InferTypeOpInterface' op interface.
const auto *trait = dyn_cast<InterfaceTrait>(
op.getTrait("::mlir::InferTypeOpInterface::Trait"));
Interface interface = trait->getInterface();
OpMethod *method = [&]() -> OpMethod * {
for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
if (interfaceMethod.getName() == "inferReturnTypes") {
return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
}
}
assert(0 && "unable to find inferReturnTypes interface method");
return nullptr;
}();
auto &body = method->body();
body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
FmtContext fctx;
fctx.withBuilder("odsBuilder");
body << " ::mlir::Builder odsBuilder(context);\n";
auto emitType =
[&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
if (type.isArg()) {
auto argIndex = type.getArg();
assert(!op.getArg(argIndex).is<NamedAttribute *>());
auto arg = op.getArgToOperandOrAttribute(argIndex);
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
return body << "operands[" << arg.operandOrAttributeIndex()
<< "].getType()";
return body << "attributes[" << arg.operandOrAttributeIndex()
<< "].getType()";
} else {
return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
}
};
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
body << " inferredReturnTypes[" << i << "] = ";
auto types = op.getSameTypeAsResult(i);
emitType(types[0]) << ";\n";
if (types.size() == 1)
continue;
// TODO: We could verify equality here, but skipping that for verification.
}
body << " return ::mlir::success();";
}
void OpEmitter::genParser() {
if (!hasStringAttribute(def, "parser") ||
hasStringAttribute(def, "assemblyFormat"))
return;
SmallVector<OpMethodParameter, 2> paramList;
paramList.emplace_back("::mlir::OpAsmParser &", "parser");
paramList.emplace_back("::mlir::OperationState &", "result");
auto *method =
opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
OpMethod::MP_Static, std::move(paramList));
FmtContext fctx;
fctx.addSubst("cppClass", opClass.getClassName());
auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
method->body() << " " << tgfmt(parser, &fctx);
}
void OpEmitter::genPrinter() {
if (hasStringAttribute(def, "assemblyFormat"))
return;
auto valueInit = def.getValueInit("printer");
StringInit *stringInit = dyn_cast<StringInit>(valueInit);
if (!stringInit)
return;
auto *method =
opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
FmtContext fctx;
fctx.addSubst("cppClass", opClass.getClassName());
auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
method->body() << " " << tgfmt(printer, &fctx);
}
void OpEmitter::genVerifier() {
auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
auto &body = method->body();
// body << " if (failed(" << op.getAdaptorName()
// << "(*this).verify((*this)->getLoc()))) "
// << "return ::mlir::failure();\n";
auto *valueInit = def.getValueInit("verifier");
StringInit *stringInit = dyn_cast<StringInit>(valueInit);
bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
populateSubstitutions(op, "(*this)->getAttr", "this->getODSOperands",
"this->getODSResults", verifyCtx);
genAttributeVerifier(op, "(*this)->getAttr", "emitOpError(",
/*emitVerificationRequiringOp=*/true, verifyCtx, body);
genOperandResultVerifier(body, op.getOperands(), "operand");
genOperandResultVerifier(body, op.getResults(), "result");
for (auto &trait : op.getTraits()) {
if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
body << tgfmt(" if (!($0))\n "
"return emitOpError(\"failed to verify that $1\");\n",
&verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
t->getSummary());
}
}
genRegionVerifier(body);
genSuccessorVerifier(body);
if (hasCustomVerify) {
FmtContext fctx;
fctx.addSubst("cppClass", opClass.getClassName());
auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
body << " " << tgfmt(printer, &fctx);
} else {
body << " return ::mlir::success();\n";
}
}
/// Emits verification code for the op's operands or results into `body`.
///
/// `values` are the static operand/result definitions to check. `valueKind`
/// is "operand" or "result"; it selects the `getODS<Kind>s` accessor used in
/// the emitted code and appears in the diagnostics.
///
/// Definitions that are neither optional nor constrained by a predicate emit
/// nothing, and do not advance the emitted `index` counter.
void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
                                         Operator::value_range values,
                                         StringRef valueKind) {
  body << " {\n";
  body << " unsigned index = 0; (void)index;\n";
  for (auto staticValue : llvm::enumerate(values)) {
    bool hasPredicate = staticValue.value().hasPredicate();
    bool isOptional = staticValue.value().isOptional();
    if (!hasPredicate && !isOptional)
      continue;
    body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
                    // Capitalize the first letter to match the function name
                    valueKind.substr(0, 1).upper(), valueKind.substr(1),
                    staticValue.index());
    // If the constraint is optional check that the value group has at most 1
    // value.
    if (isOptional) {
      body << formatv(" if (valueGroup{0}.size() > 1)\n"
                      " return emitOpError(\"{1} group starting at #\") "
                      "<< index << \" requires 0 or 1 element, but found \" << "
                      "valueGroup{0}.size();\n",
                      staticValue.index(), valueKind);
    }
    // Otherwise, if there is no predicate there is nothing left to do.
    if (!hasPredicate)
      continue;
    // Emit a loop to check all the dynamic values in the pack.
    StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(
        staticValue.value().constraint);
    body << " for (::mlir::Value v : valueGroup" << staticValue.index()
         << ") {\n"
         << " if (::mlir::failed(" << constraintFn
         << "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n"
         << " return ::mlir::failure();\n"
         << " ++index;\n"
         << " }\n";
  }
  body << " }\n";
}
/// Emits verification code for the op's regions into `body`.
///
/// Nothing is emitted for an op without regions. Regions whose constraint has
/// a null predicate are skipped. Every other region gets a loop that checks
/// the constraint condition and reports an op error naming the region index,
/// the region name (when non-empty) and the constraint summary.
void OpEmitter::genRegionVerifier(OpMethodBody &body) {
  const unsigned regionCount = op.getNumRegions();
  if (!regionCount)
    return;

  body << "{\n";
  body << " unsigned index = 0; (void)index;\n";
  for (unsigned regionIdx = 0; regionIdx != regionCount; ++regionIdx) {
    const auto &region = op.getRegion(regionIdx);
    if (region.constraint.getPredicate().isNull())
      continue;

    // A variadic region is iterated through its accessor; a single region is
    // wrapped so the same loop shape can be used.
    body << " for (::mlir::Region &region : ";
    if (region.isVariadic())
      body << formatv("{0}()", region.name, regionIdx);
    else
      body << formatv(
          "::mlir::MutableArrayRef<::mlir::Region>((*this)->getRegion({1}))",
          region.name, regionIdx);
    body << ") {\n";

    // Instantiate the constraint with `region` as the self value.
    auto condition = tgfmt(region.constraint.getConditionTemplate(),
                           &verifyCtx.withSelf("region"))
                         .str();
    body << formatv(" (void)region;\n"
                    " if (!({0})) {\n "
                    "return emitOpError(\"region #\") << index << \" {1}"
                    "failed to verify constraint: {2}\";\n }\n",
                    condition,
                    region.name.empty() ? "" : "('" + region.name + "') ",
                    region.constraint.getSummary())
         << " ++index;\n"
         << " }\n";
  }
  body << " }\n";
}
/// Emits verification code for the op's successors into `body`.
///
/// Nothing is emitted for an op without successors. Successors whose
/// constraint has a null predicate are skipped. Every other successor gets a
/// scope (a loop for variadic successors, a plain block otherwise) that checks
/// the constraint condition and reports an op error naming the successor
/// index, the successor name and the constraint summary.
void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
  // If we have no successors, there is nothing more to do.
  unsigned numSuccessors = op.getNumSuccessors();
  if (numSuccessors == 0)
    return;
  body << "{\n";
  body << " unsigned index = 0; (void)index;\n";
  for (unsigned i = 0; i < numSuccessors; ++i) {
    const auto &successor = op.getSuccessor(i);
    if (successor.constraint.getPredicate().isNull())
      continue;
    if (successor.isVariadic()) {
      // The literal '{' that opens the emitted loop body is written as "{{":
      // formatv treats a lone, unmatched '{' as an unterminated replacement
      // sequence. The emitted text is still a single '{'.
      body << formatv(" for (::mlir::Block *successor : {0}()) {{\n",
                      successor.name);
    } else {
      body << " {\n";
      body << formatv(" ::mlir::Block *successor = {0}();\n",
                      successor.name);
    }
    // Instantiate the constraint with `successor` as the self value.
    auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
                            &verifyCtx.withSelf("successor"))
                          .str();
    body << formatv(" (void)successor;\n"
                    " if (!({0})) {\n "
                    "return emitOpError(\"successor #\") << index << \"('{1}') "
                    "failed to "
                    "verify constraint: {2}\";\n }\n",
                    constraint, successor.name,
                    successor.constraint.getSummary())
         << " ++index;\n"
         << " }\n";
  }
  body << " }\n";
}
/// Adds the trait describing how many `traitKind` entities an op has.
///
/// `numTotal` is the number of declared entities and `numVariadic` how many
/// of them are variadic. The trait added is:
///   - `Variadic<Kind>s` when every entity is variadic,
///   - `AtLeastN<Kind>s<N>::Impl` when only some are (N = non-variadic count),
///   - `Zero<Kind>` / `One<Kind>` for exactly zero or one fixed entity,
///   - `N<Kind>s<N>::Impl` for any other fixed count.
static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
                              int numTotal, int numVariadic) {
  const bool hasVariadic = numVariadic != 0;

  if (hasVariadic && numTotal == numVariadic) {
    opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
  } else if (hasVariadic) {
    opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
                     Twine(numTotal - numVariadic) + ">::Impl");
  } else if (numTotal == 0) {
    opClass.addTrait("::mlir::OpTrait::Zero" + traitKind);
  } else if (numTotal == 1) {
    opClass.addTrait("::mlir::OpTrait::One" + traitKind);
  } else {
    opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
                     ">::Impl");
  }
}
/// Adds the structural and user-specified traits to the op class.
///
/// Traits are added in this order: region count, result count, the typed
/// single-result trait (when applicable), successor count, operand count, and
/// finally the op's native and interface traits.
void OpEmitter::genTraits() {
  // Region count.
  unsigned regionCount = op.getNumRegions();
  unsigned variadicRegionCount = op.getNumVariadicRegions();
  addSizeCountTrait(opClass, "Region", regionCount, variadicRegionCount);

  // Result count.
  int resultCount = op.getNumResults();
  int variableLengthResultCount = op.getNumVariableLengthResults();
  addSizeCountTrait(opClass, "Result", resultCount, variableLengthResultCount);

  // A lone, non-variable-length result also gets a OneTypedResult trait
  // parameterized on the C++ class name of its constraint.
  if (resultCount == 1 && variableLengthResultCount == 0) {
    auto cppName = op.getResults().begin()->constraint.getCPPClassName();
    opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
  }

  // Successor count.
  unsigned successorCount = op.getNumSuccessors();
  unsigned variadicSuccessorCount = op.getNumVariadicSuccessors();
  addSizeCountTrait(opClass, "Successor", successorCount,
                    variadicSuccessorCount);

  // Operand count. This is spelled out here rather than delegated to
  // addSizeCountTrait: the zero-operand trait is named "ZeroOperands", which
  // does not follow the helper's "Zero<Kind>" pattern.
  int operandCount = op.getNumOperands();
  int variableLengthOperandCount = op.getNumVariableLengthOperands();
  const bool hasVariableLengthOperands = variableLengthOperandCount != 0;
  if (hasVariableLengthOperands &&
      operandCount == variableLengthOperandCount) {
    opClass.addTrait("::mlir::OpTrait::VariadicOperands");
  } else if (hasVariableLengthOperands) {
    opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" +
                     Twine(operandCount - variableLengthOperandCount) +
                     ">::Impl");
  } else if (operandCount == 0) {
    opClass.addTrait("::mlir::OpTrait::ZeroOperands");
  } else if (operandCount == 1) {
    opClass.addTrait("::mlir::OpTrait::OneOperand");
  } else {
    opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(operandCount) +
                     ">::Impl");
  }

  // Native and interface traits declared on the op; other trait kinds are
  // not added here.
  for (const auto &trait : op.getTraits()) {
    if (auto nativeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
      opClass.addTrait(nativeTrait->getFullyQualifiedTraitName());
      continue;
    }
    if (auto interfaceTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
      opClass.addTrait(interfaceTrait->getFullyQualifiedTraitName());
  }
}
void OpEmitter::genOpNameGetter() {
auto *method = opClass.addMethodAndPrune(
"std::string", "getOperationName",
2021-08-13 15:05:10 +08:00
OpMethod::Property(OpMethod::MP_Static));
// OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
2021-08-04 20:24:07 +08:00
method->body() << " return std::string(\"" << op.getOperationName()
2021-07-23 11:38:34 +08:00
<< "\");";
}
void OpEmitter::genOpAsmInterface() {
// If the user only has one results or specifically added the Asm trait,
// then don't generate it for them. We specifically only handle multi result
// operations, because the name of a single result in the common case is not
// interesting(generally 'result'/'output'/etc.).
// TODO: We could also add a flag to allow operations to opt in to this
// generation, even if they only have a single operation.
int numResults = op.getNumResults();
if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
return;
SmallVector<StringRef, 4> resultNames(numResults);
for (int i = 0; i != numResults; ++i)
resultNames[i] = op.getResultName(i);
// Don't add the trait if none of the results have a valid name.
if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
return;
opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
// Generate the right accessor for the number of results.
auto *method = opClass.addMethodAndPrune(
"void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
auto &body = method->body();
for (int i = 0; i != numResults; ++i) {
body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
<< " if (!llvm::empty(resultGroup" << i << "))\n"
<< " setNameFn(*resultGroup" << i << ".begin(), \""
<< resultNames[i] << "\");\n";
}
}
// Emits the op classes for `defs` to `os`.
//
// With `emitDecl` set, a forward declaration of every class is written first
// (guarded by GET_OP_CLASSES / GET_OP_FWD_DEFINES), followed by the class
// declarations; otherwise the class definitions are written. The classes
// themselves are emitted inside a GET_OP_CLASSES scope.
static void emitOpClasses(const RecordKeeper &recordKeeper,
                          const std::vector<Record *> &defs, raw_ostream &os,
                          bool emitDecl) {
  // First emit forward declaration for each class, this allows them to refer
  // to each others in traits for example.
  if (emitDecl) {
    os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
    os << "#undef GET_OP_FWD_DEFINES\n";
    for (auto *def : defs) {
      Operator op(*def);
      NamespaceEmitter emitter(os, op.getCppNamespace());
      os << "class " << op.getCppClassName() << ";\n";
    }
    os << "#endif\n\n";
  }

  IfDefScope scope("GET_OP_CLASSES", os);
  if (defs.empty())
    return;

  // Generate all of the locally instantiated methods first.
  StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
                                                      emitDecl);
  for (auto *def : defs) {
    Operator op(*def);
    NamespaceEmitter emitter(os, op.getCppNamespace());
    if (emitDecl) {
      // The per-op comment header is currently disabled:
      // os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
      OpEmitter::emitDecl(op, os, staticVerifierEmitter);
    } else {
      // The per-op comment header is currently disabled:
      // os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
      OpEmitter::emitDef(op, os, staticVerifierEmitter);
    }
  }
}
// Emits a comma-separated list of the ops inside a GET_OP_LIST scope.
//
// Each entry is the last component of the op's qualified C++ class name,
// prefixed with `builder::mhlo::`. Entries are separated by ",\n".
static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
  IfDefScope scope("GET_OP_LIST", os);
  interleave(
      // TODO: We are constructing the Operator wrapper instance just for
      // getting its qualified class name here. Reduce the overhead by having a
      // lightweight version of Operator class just for that purpose.
      defs,
      [&os](Record *def) {
        SmallVector<StringRef, 4> namespaces;
        // `className` owns the storage the split pieces point into.
        std::string className = Operator(def).getQualCppClassName();
        llvm::SplitString(StringRef(className), namespaces, StringRef("::"));
        if (!namespaces.empty())
          os << "builder::mhlo::" << namespaces.back().str();
      },
      [&os]() { os << ",\n"; });
}
// Writes the source file header followed by the declarations of all requested
// ops. Always returns false.
static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
  emitSourceFileHeader("Op Declarations", os);
  emitOpClasses(recordKeeper, getRequestedOpDefinitions(recordKeeper), os,
                /*emitDecl=*/true);
  return false;
}
// Writes the source file header, then the op list, then the definitions of
// all requested ops. Always returns false.
static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
  emitSourceFileHeader("Op Definitions", os);
  const std::vector<Record *> requestedDefs =
      getRequestedOpDefinitions(recordKeeper);
  emitOpList(requestedDefs, os);
  emitOpClasses(recordKeeper, requestedDefs, os, /*emitDecl=*/false);
  return false;
}
// Registers the "gen-builder-decls" generator, which forwards to emitOpDecls.
static mlir::GenRegistration genOpDecls(
    "gen-builder-decls", "Generate op declarations",
    [](const RecordKeeper &recordKeeper, raw_ostream &out) {
      return emitOpDecls(recordKeeper, out);
    });
// Registers the "gen-builder-defs" generator, which forwards to emitOpDefs.
static mlir::GenRegistration genOpDefs(
    "gen-builder-defs", "Generate op definitions",
    [](const RecordKeeper &recordKeeper, raw_ostream &out) {
      return emitOpDefs(recordKeeper, out);
    });