NFC: Attribute cleanup (remove references of attributes) (#286)
* Define krnl.permute op. * Support krnl.permute operation. * Properly remove loop references. * Re-push, Github was down. * Need to debug interpretOp error. * Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code. * Introduce permute, unroll operations. * More debug. * Remove std::set. * krnl.terminate fails to be converted. * Pass all tests, need to add legal ops as well as part of the conversion target. * Change test format to new permute spec. * Bug fix for nested iterate op lowering. * Simplify error reporting. * Fix compilation error. * Increase comments coverage. * Remove unnecessary imports. * Re-trigger Jenkins * Add permute/unroll tests. * Retrigger Jenkins * remove & (ref) for Attributes Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
8bfde7de4b
commit
c1262c184e
|
@ -111,7 +111,7 @@ private:
|
|||
}
|
||||
|
||||
mlir::NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
|
||||
onnx::AttributeProto &attr) {
|
||||
onnx::AttributeProto attr) {
|
||||
mlir::Attribute mlirAttr;
|
||||
switch (attr.type()) {
|
||||
case onnx::AttributeProto::FLOAT:
|
||||
|
|
|
@ -25,7 +25,7 @@ namespace {
|
|||
|
||||
/// Compute the combined permute pattern from a pair of permute patterns.
|
||||
ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter,
|
||||
ArrayAttr &firstPermAttr, ArrayAttr &secondPermAttr) {
|
||||
ArrayAttr firstPermAttr, ArrayAttr secondPermAttr) {
|
||||
// Read first permute vectors.
|
||||
SmallVector<int64_t, 4> initialPerm;
|
||||
for (auto firstPermVal : firstPermAttr.getValue())
|
||||
|
@ -44,7 +44,7 @@ ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter,
|
|||
|
||||
/// Test if the permute pattern correspond to an identity pattern.
|
||||
/// Identity patterns are {0, 1, 2, ... , rank -1}.
|
||||
bool IsIdentityPermuteVector(ArrayAttr &permAttr) {
|
||||
bool IsIdentityPermuteVector(ArrayAttr permAttr) {
|
||||
int64_t currentIndex = 0;
|
||||
for (auto permVal : permAttr.getValue())
|
||||
if (permVal.cast<IntegerAttr>().getInt() != currentIndex++)
|
||||
|
|
|
@ -59,14 +59,14 @@ namespace {
|
|||
|
||||
template <typename OP>
|
||||
Attribute ComputeConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||
Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
|
||||
Type elementType, Attribute lhsAttr, Attribute secondAttr) {
|
||||
llvm_unreachable("unkonwn operation");
|
||||
}
|
||||
|
||||
template <>
|
||||
Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
|
||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||
Attribute &secondAttr) {
|
||||
PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
|
||||
Attribute secondAttr) {
|
||||
if (elementType.isa<FloatType>()) {
|
||||
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||
|
@ -86,8 +86,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
|
|||
|
||||
template <>
|
||||
Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
|
||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||
Attribute &secondAttr) {
|
||||
PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
|
||||
Attribute secondAttr) {
|
||||
if (elementType.isa<FloatType>()) {
|
||||
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||
|
@ -105,8 +105,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
|
|||
|
||||
template <>
|
||||
Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
|
||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||
Attribute &secondAttr) {
|
||||
PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
|
||||
Attribute secondAttr) {
|
||||
if (elementType.isa<FloatType>()) {
|
||||
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||
|
@ -124,8 +124,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
|
|||
|
||||
template <>
|
||||
Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
|
||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||
Attribute &secondAttr) {
|
||||
PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
|
||||
Attribute secondAttr) {
|
||||
if (elementType.isa<FloatType>()) {
|
||||
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||
|
@ -154,8 +154,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
|
|||
|
||||
template <typename ElementwiseBinaryOp>
|
||||
void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||
std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr,
|
||||
DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
|
||||
std::vector<Attribute> &resVector, DenseElementsAttr lhsAttr,
|
||||
DenseElementsAttr rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
|
||||
SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
|
||||
if (lhsFreeRank == 0) {
|
||||
// Fully defined ranks.
|
||||
|
@ -222,7 +222,7 @@ void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
|
|||
// generate the new constant operation.
|
||||
template <typename ElementwiseBinaryOp>
|
||||
DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||
Value resOperand, Attribute &lhsAttr, Attribute &rhsAttr) {
|
||||
Value resOperand, Attribute lhsAttr, Attribute rhsAttr) {
|
||||
DenseElementsAttr lhsDenseAttr =
|
||||
lhsAttr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||
DenseElementsAttr rhsDenseAttr =
|
||||
|
@ -248,13 +248,13 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
|
|||
|
||||
template <typename OP>
|
||||
Attribute ComputeConstPropElementwiseUnary(
|
||||
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
||||
PatternRewriter &rewriter, Type elementType, Attribute attr) {
|
||||
llvm_unreachable("unkonwn operation");
|
||||
}
|
||||
|
||||
template <>
|
||||
Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
|
||||
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
||||
PatternRewriter &rewriter, Type elementType, Attribute attr) {
|
||||
if (elementType.isa<FloatType>()) {
|
||||
double val = attr.cast<FloatAttr>().getValueAsDouble();
|
||||
double res = -val;
|
||||
|
@ -270,7 +270,7 @@ Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
|
|||
|
||||
template <>
|
||||
Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
|
||||
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
||||
PatternRewriter &rewriter, Type elementType, Attribute attr) {
|
||||
if (elementType.isa<FloatType>()) {
|
||||
double val = attr.cast<FloatAttr>().getValueAsDouble();
|
||||
double res = sqrt(val);
|
||||
|
@ -281,7 +281,7 @@ Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
|
|||
|
||||
template <typename ElementwiseUnaryOp>
|
||||
void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
|
||||
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
|
||||
std::vector<Attribute> &resVector, DenseElementsAttr attr,
|
||||
SmallVector<uint64_t, 4> &indices, int freeRank) {
|
||||
if (freeRank == 0) {
|
||||
// Fully defined ranks.
|
||||
|
@ -308,7 +308,7 @@ void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
|
|||
// generate the new constant operation.
|
||||
template <typename ElementwiseUnaryOp>
|
||||
DenseElementsAttr ConstPropElementwiseUnary(
|
||||
PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
|
||||
PatternRewriter &rewriter, Value resOperand, Attribute attr) {
|
||||
DenseElementsAttr denseAttr =
|
||||
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||
assert(denseAttr && "expected dense attribute");
|
||||
|
@ -329,7 +329,7 @@ DenseElementsAttr ConstPropElementwiseUnary(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void RecurseConstPropTranspose(PatternRewriter &rewriter,
|
||||
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
|
||||
std::vector<Attribute> &resVector, DenseElementsAttr attr,
|
||||
SmallVector<uint64_t, 4> &indices, SmallVector<uint64_t, 4> &perm,
|
||||
int freeRank) {
|
||||
if (freeRank == 0) {
|
||||
|
@ -351,7 +351,7 @@ void RecurseConstPropTranspose(PatternRewriter &rewriter,
|
|||
}
|
||||
|
||||
DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
|
||||
Value resOperand, Attribute &attr, ArrayAttr &permAttr) {
|
||||
Value resOperand, Attribute attr, ArrayAttr permAttr) {
|
||||
// Read dense attribute, the constant tensor we are transforming.
|
||||
DenseElementsAttr denseAttr =
|
||||
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||
|
@ -378,7 +378,7 @@ DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DenseElementsAttr ConstPropUnsqueeze(
|
||||
PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
|
||||
PatternRewriter &rewriter, Value resOperand, Attribute attr) {
|
||||
// Read dense attribute, the constant tensor we are transforming.
|
||||
DenseElementsAttr denseAttr =
|
||||
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||
|
|
Loading…
Reference in New Issue