NFC: Attribute cleanup (remove references of attributes) (#286)
* Define krnl.permute op. * Support krnl.permute operation. * Properly remove loop references. * Re-push, Github was down. * Need to debug interpretOp error. * Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code. * Introduce permute, unroll operations. * More debug. * Remove std::set. * krnl.terminate fails to be converted. * Pass all tests, need to add legal ops as well as part of the conversion target. * Change test format to new permute spec. * Bug fix for nested iterate op lowering. * Simplify error reporting. * Fix compilation error. * Increase comments coverage. * Remove unnecessary imports. * Re-trigger Jenkins * Add permute/unroll tests. * Retrigger Jenkins * remove & (ref) for Attributes Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
8bfde7de4b
commit
c1262c184e
|
@ -111,7 +111,7 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
|
mlir::NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
|
||||||
onnx::AttributeProto &attr) {
|
onnx::AttributeProto attr) {
|
||||||
mlir::Attribute mlirAttr;
|
mlir::Attribute mlirAttr;
|
||||||
switch (attr.type()) {
|
switch (attr.type()) {
|
||||||
case onnx::AttributeProto::FLOAT:
|
case onnx::AttributeProto::FLOAT:
|
||||||
|
|
|
@ -25,7 +25,7 @@ namespace {
|
||||||
|
|
||||||
/// Compute the combined permute pattern from a pair of permute patterns.
|
/// Compute the combined permute pattern from a pair of permute patterns.
|
||||||
ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter,
|
ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter,
|
||||||
ArrayAttr &firstPermAttr, ArrayAttr &secondPermAttr) {
|
ArrayAttr firstPermAttr, ArrayAttr secondPermAttr) {
|
||||||
// Read first permute vectors.
|
// Read first permute vectors.
|
||||||
SmallVector<int64_t, 4> initialPerm;
|
SmallVector<int64_t, 4> initialPerm;
|
||||||
for (auto firstPermVal : firstPermAttr.getValue())
|
for (auto firstPermVal : firstPermAttr.getValue())
|
||||||
|
@ -44,7 +44,7 @@ ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter,
|
||||||
|
|
||||||
/// Test if the permute pattern correspond to an identity pattern.
|
/// Test if the permute pattern correspond to an identity pattern.
|
||||||
/// Identity patterns are {0, 1, 2, ... , rank -1}.
|
/// Identity patterns are {0, 1, 2, ... , rank -1}.
|
||||||
bool IsIdentityPermuteVector(ArrayAttr &permAttr) {
|
bool IsIdentityPermuteVector(ArrayAttr permAttr) {
|
||||||
int64_t currentIndex = 0;
|
int64_t currentIndex = 0;
|
||||||
for (auto permVal : permAttr.getValue())
|
for (auto permVal : permAttr.getValue())
|
||||||
if (permVal.cast<IntegerAttr>().getInt() != currentIndex++)
|
if (permVal.cast<IntegerAttr>().getInt() != currentIndex++)
|
||||||
|
|
|
@ -59,14 +59,14 @@ namespace {
|
||||||
|
|
||||||
template <typename OP>
|
template <typename OP>
|
||||||
Attribute ComputeConstPropElementwiseBinary(PatternRewriter &rewriter,
|
Attribute ComputeConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||||
Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
|
Type elementType, Attribute lhsAttr, Attribute secondAttr) {
|
||||||
llvm_unreachable("unkonwn operation");
|
llvm_unreachable("unkonwn operation");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
|
Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
|
||||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
|
||||||
Attribute &secondAttr) {
|
Attribute secondAttr) {
|
||||||
if (elementType.isa<FloatType>()) {
|
if (elementType.isa<FloatType>()) {
|
||||||
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
@ -86,8 +86,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
|
Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
|
||||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
|
||||||
Attribute &secondAttr) {
|
Attribute secondAttr) {
|
||||||
if (elementType.isa<FloatType>()) {
|
if (elementType.isa<FloatType>()) {
|
||||||
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
@ -105,8 +105,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
|
Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
|
||||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
|
||||||
Attribute &secondAttr) {
|
Attribute secondAttr) {
|
||||||
if (elementType.isa<FloatType>()) {
|
if (elementType.isa<FloatType>()) {
|
||||||
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
@ -124,8 +124,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
|
Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
|
||||||
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
|
||||||
Attribute &secondAttr) {
|
Attribute secondAttr) {
|
||||||
if (elementType.isa<FloatType>()) {
|
if (elementType.isa<FloatType>()) {
|
||||||
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
@ -154,8 +154,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
|
||||||
|
|
||||||
template <typename ElementwiseBinaryOp>
|
template <typename ElementwiseBinaryOp>
|
||||||
void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
|
void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||||
std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr,
|
std::vector<Attribute> &resVector, DenseElementsAttr lhsAttr,
|
||||||
DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
|
DenseElementsAttr rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
|
||||||
SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
|
SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
|
||||||
if (lhsFreeRank == 0) {
|
if (lhsFreeRank == 0) {
|
||||||
// Fully defined ranks.
|
// Fully defined ranks.
|
||||||
|
@ -222,7 +222,7 @@ void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||||
// generate the new constant operation.
|
// generate the new constant operation.
|
||||||
template <typename ElementwiseBinaryOp>
|
template <typename ElementwiseBinaryOp>
|
||||||
DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
|
DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||||
Value resOperand, Attribute &lhsAttr, Attribute &rhsAttr) {
|
Value resOperand, Attribute lhsAttr, Attribute rhsAttr) {
|
||||||
DenseElementsAttr lhsDenseAttr =
|
DenseElementsAttr lhsDenseAttr =
|
||||||
lhsAttr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
lhsAttr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||||
DenseElementsAttr rhsDenseAttr =
|
DenseElementsAttr rhsDenseAttr =
|
||||||
|
@ -248,13 +248,13 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||||
|
|
||||||
template <typename OP>
|
template <typename OP>
|
||||||
Attribute ComputeConstPropElementwiseUnary(
|
Attribute ComputeConstPropElementwiseUnary(
|
||||||
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
PatternRewriter &rewriter, Type elementType, Attribute attr) {
|
||||||
llvm_unreachable("unkonwn operation");
|
llvm_unreachable("unkonwn operation");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
|
Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
|
||||||
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
PatternRewriter &rewriter, Type elementType, Attribute attr) {
|
||||||
if (elementType.isa<FloatType>()) {
|
if (elementType.isa<FloatType>()) {
|
||||||
double val = attr.cast<FloatAttr>().getValueAsDouble();
|
double val = attr.cast<FloatAttr>().getValueAsDouble();
|
||||||
double res = -val;
|
double res = -val;
|
||||||
|
@ -270,7 +270,7 @@ Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
|
Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
|
||||||
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
PatternRewriter &rewriter, Type elementType, Attribute attr) {
|
||||||
if (elementType.isa<FloatType>()) {
|
if (elementType.isa<FloatType>()) {
|
||||||
double val = attr.cast<FloatAttr>().getValueAsDouble();
|
double val = attr.cast<FloatAttr>().getValueAsDouble();
|
||||||
double res = sqrt(val);
|
double res = sqrt(val);
|
||||||
|
@ -281,7 +281,7 @@ Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
|
||||||
|
|
||||||
template <typename ElementwiseUnaryOp>
|
template <typename ElementwiseUnaryOp>
|
||||||
void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
|
void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
|
||||||
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
|
std::vector<Attribute> &resVector, DenseElementsAttr attr,
|
||||||
SmallVector<uint64_t, 4> &indices, int freeRank) {
|
SmallVector<uint64_t, 4> &indices, int freeRank) {
|
||||||
if (freeRank == 0) {
|
if (freeRank == 0) {
|
||||||
// Fully defined ranks.
|
// Fully defined ranks.
|
||||||
|
@ -308,7 +308,7 @@ void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
|
||||||
// generate the new constant operation.
|
// generate the new constant operation.
|
||||||
template <typename ElementwiseUnaryOp>
|
template <typename ElementwiseUnaryOp>
|
||||||
DenseElementsAttr ConstPropElementwiseUnary(
|
DenseElementsAttr ConstPropElementwiseUnary(
|
||||||
PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
|
PatternRewriter &rewriter, Value resOperand, Attribute attr) {
|
||||||
DenseElementsAttr denseAttr =
|
DenseElementsAttr denseAttr =
|
||||||
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||||
assert(denseAttr && "expected dense attribute");
|
assert(denseAttr && "expected dense attribute");
|
||||||
|
@ -329,7 +329,7 @@ DenseElementsAttr ConstPropElementwiseUnary(
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void RecurseConstPropTranspose(PatternRewriter &rewriter,
|
void RecurseConstPropTranspose(PatternRewriter &rewriter,
|
||||||
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
|
std::vector<Attribute> &resVector, DenseElementsAttr attr,
|
||||||
SmallVector<uint64_t, 4> &indices, SmallVector<uint64_t, 4> &perm,
|
SmallVector<uint64_t, 4> &indices, SmallVector<uint64_t, 4> &perm,
|
||||||
int freeRank) {
|
int freeRank) {
|
||||||
if (freeRank == 0) {
|
if (freeRank == 0) {
|
||||||
|
@ -351,7 +351,7 @@ void RecurseConstPropTranspose(PatternRewriter &rewriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
|
DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
|
||||||
Value resOperand, Attribute &attr, ArrayAttr &permAttr) {
|
Value resOperand, Attribute attr, ArrayAttr permAttr) {
|
||||||
// Read dense attribute, the constant tensor we are transforming.
|
// Read dense attribute, the constant tensor we are transforming.
|
||||||
DenseElementsAttr denseAttr =
|
DenseElementsAttr denseAttr =
|
||||||
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||||
|
@ -378,7 +378,7 @@ DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
DenseElementsAttr ConstPropUnsqueeze(
|
DenseElementsAttr ConstPropUnsqueeze(
|
||||||
PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
|
PatternRewriter &rewriter, Value resOperand, Attribute attr) {
|
||||||
// Read dense attribute, the constant tensor we are transforming.
|
// Read dense attribute, the constant tensor we are transforming.
|
||||||
DenseElementsAttr denseAttr =
|
DenseElementsAttr denseAttr =
|
||||||
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||||
|
|
Loading…
Reference in New Issue