NFC: Attribute cleanup (remove references of attributes) (#286)

* Define krnl.permute op.

* Support krnl.permute operation.

* Properly remove loop references.

* Re-push, Github was down.

* Need to debug interpretOp error.

* Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code.

* Introduce permute, unroll operations.

* More debug.

* Remove std::set.

* krnl.terminate fails to be converted.

* Pass all tests, need to add legal ops as well as part of the conversion target.

* Change test format to new permute spec.

* Bug fix for nested iterate op lowering.

* Simplify error reporting.

* Fix compilation error.

* Increase comments coverage.

* Remove unnecessary imports.

* Re-trigger Jenkins

* Add permute/unroll tests.

* Retrigger Jenkins

* remove & (ref) for Attributes

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Alexandre Eichenberger 2020-08-31 14:28:16 -04:00 committed by GitHub
parent 8bfde7de4b
commit c1262c184e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 23 deletions

View File

@ -111,7 +111,7 @@ private:
} }
mlir::NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute( mlir::NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
onnx::AttributeProto &attr) { onnx::AttributeProto attr) {
mlir::Attribute mlirAttr; mlir::Attribute mlirAttr;
switch (attr.type()) { switch (attr.type()) {
case onnx::AttributeProto::FLOAT: case onnx::AttributeProto::FLOAT:

View File

@ -25,7 +25,7 @@ namespace {
/// Compute the combined permute pattern from a pair of permute patterns. /// Compute the combined permute pattern from a pair of permute patterns.
ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter, ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter,
ArrayAttr &firstPermAttr, ArrayAttr &secondPermAttr) { ArrayAttr firstPermAttr, ArrayAttr secondPermAttr) {
// Read first permute vectors. // Read first permute vectors.
SmallVector<int64_t, 4> initialPerm; SmallVector<int64_t, 4> initialPerm;
for (auto firstPermVal : firstPermAttr.getValue()) for (auto firstPermVal : firstPermAttr.getValue())
@ -44,7 +44,7 @@ ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter,
/// Test if the permute pattern correspond to an identity pattern. /// Test if the permute pattern correspond to an identity pattern.
/// Identity patterns are {0, 1, 2, ... , rank -1}. /// Identity patterns are {0, 1, 2, ... , rank -1}.
bool IsIdentityPermuteVector(ArrayAttr &permAttr) { bool IsIdentityPermuteVector(ArrayAttr permAttr) {
int64_t currentIndex = 0; int64_t currentIndex = 0;
for (auto permVal : permAttr.getValue()) for (auto permVal : permAttr.getValue())
if (permVal.cast<IntegerAttr>().getInt() != currentIndex++) if (permVal.cast<IntegerAttr>().getInt() != currentIndex++)

View File

@ -59,14 +59,14 @@ namespace {
template <typename OP> template <typename OP>
Attribute ComputeConstPropElementwiseBinary(PatternRewriter &rewriter, Attribute ComputeConstPropElementwiseBinary(PatternRewriter &rewriter,
Type elementType, Attribute &lhsAttr, Attribute &secondAttr) { Type elementType, Attribute lhsAttr, Attribute secondAttr) {
llvm_unreachable("unkonwn operation"); llvm_unreachable("unkonwn operation");
} }
template <> template <>
Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>( Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr, PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
Attribute &secondAttr) { Attribute secondAttr) {
if (elementType.isa<FloatType>()) { if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble(); double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble(); double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
@ -86,8 +86,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
template <> template <>
Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>( Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr, PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
Attribute &secondAttr) { Attribute secondAttr) {
if (elementType.isa<FloatType>()) { if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble(); double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble(); double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
@ -105,8 +105,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
template <> template <>
Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>( Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr, PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
Attribute &secondAttr) { Attribute secondAttr) {
if (elementType.isa<FloatType>()) { if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble(); double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble(); double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
@ -124,8 +124,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
template <> template <>
Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>( Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr, PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
Attribute &secondAttr) { Attribute secondAttr) {
if (elementType.isa<FloatType>()) { if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble(); double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble(); double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
@ -154,8 +154,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
template <typename ElementwiseBinaryOp> template <typename ElementwiseBinaryOp>
void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter, void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr, std::vector<Attribute> &resVector, DenseElementsAttr lhsAttr,
DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices, DenseElementsAttr rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) { SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
if (lhsFreeRank == 0) { if (lhsFreeRank == 0) {
// Fully defined ranks. // Fully defined ranks.
@ -222,7 +222,7 @@ void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
// generate the new constant operation. // generate the new constant operation.
template <typename ElementwiseBinaryOp> template <typename ElementwiseBinaryOp>
DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter, DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
Value resOperand, Attribute &lhsAttr, Attribute &rhsAttr) { Value resOperand, Attribute lhsAttr, Attribute rhsAttr) {
DenseElementsAttr lhsDenseAttr = DenseElementsAttr lhsDenseAttr =
lhsAttr.dyn_cast_or_null<mlir::DenseElementsAttr>(); lhsAttr.dyn_cast_or_null<mlir::DenseElementsAttr>();
DenseElementsAttr rhsDenseAttr = DenseElementsAttr rhsDenseAttr =
@ -248,13 +248,13 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
template <typename OP> template <typename OP>
Attribute ComputeConstPropElementwiseUnary( Attribute ComputeConstPropElementwiseUnary(
PatternRewriter &rewriter, Type elementType, Attribute &attr) { PatternRewriter &rewriter, Type elementType, Attribute attr) {
llvm_unreachable("unkonwn operation"); llvm_unreachable("unkonwn operation");
} }
template <> template <>
Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>( Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
PatternRewriter &rewriter, Type elementType, Attribute &attr) { PatternRewriter &rewriter, Type elementType, Attribute attr) {
if (elementType.isa<FloatType>()) { if (elementType.isa<FloatType>()) {
double val = attr.cast<FloatAttr>().getValueAsDouble(); double val = attr.cast<FloatAttr>().getValueAsDouble();
double res = -val; double res = -val;
@ -270,7 +270,7 @@ Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
template <> template <>
Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>( Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
PatternRewriter &rewriter, Type elementType, Attribute &attr) { PatternRewriter &rewriter, Type elementType, Attribute attr) {
if (elementType.isa<FloatType>()) { if (elementType.isa<FloatType>()) {
double val = attr.cast<FloatAttr>().getValueAsDouble(); double val = attr.cast<FloatAttr>().getValueAsDouble();
double res = sqrt(val); double res = sqrt(val);
@ -281,7 +281,7 @@ Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
template <typename ElementwiseUnaryOp> template <typename ElementwiseUnaryOp>
void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter, void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
std::vector<Attribute> &resVector, DenseElementsAttr &attr, std::vector<Attribute> &resVector, DenseElementsAttr attr,
SmallVector<uint64_t, 4> &indices, int freeRank) { SmallVector<uint64_t, 4> &indices, int freeRank) {
if (freeRank == 0) { if (freeRank == 0) {
// Fully defined ranks. // Fully defined ranks.
@ -308,7 +308,7 @@ void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
// generate the new constant operation. // generate the new constant operation.
template <typename ElementwiseUnaryOp> template <typename ElementwiseUnaryOp>
DenseElementsAttr ConstPropElementwiseUnary( DenseElementsAttr ConstPropElementwiseUnary(
PatternRewriter &rewriter, Value resOperand, Attribute &attr) { PatternRewriter &rewriter, Value resOperand, Attribute attr) {
DenseElementsAttr denseAttr = DenseElementsAttr denseAttr =
attr.dyn_cast_or_null<mlir::DenseElementsAttr>(); attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
assert(denseAttr && "expected dense attribute"); assert(denseAttr && "expected dense attribute");
@ -329,7 +329,7 @@ DenseElementsAttr ConstPropElementwiseUnary(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void RecurseConstPropTranspose(PatternRewriter &rewriter, void RecurseConstPropTranspose(PatternRewriter &rewriter,
std::vector<Attribute> &resVector, DenseElementsAttr &attr, std::vector<Attribute> &resVector, DenseElementsAttr attr,
SmallVector<uint64_t, 4> &indices, SmallVector<uint64_t, 4> &perm, SmallVector<uint64_t, 4> &indices, SmallVector<uint64_t, 4> &perm,
int freeRank) { int freeRank) {
if (freeRank == 0) { if (freeRank == 0) {
@ -351,7 +351,7 @@ void RecurseConstPropTranspose(PatternRewriter &rewriter,
} }
DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter, DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
Value resOperand, Attribute &attr, ArrayAttr &permAttr) { Value resOperand, Attribute attr, ArrayAttr permAttr) {
// Read dense attribute, the constant tensor we are transforming. // Read dense attribute, the constant tensor we are transforming.
DenseElementsAttr denseAttr = DenseElementsAttr denseAttr =
attr.dyn_cast_or_null<mlir::DenseElementsAttr>(); attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
@ -378,7 +378,7 @@ DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
DenseElementsAttr ConstPropUnsqueeze( DenseElementsAttr ConstPropUnsqueeze(
PatternRewriter &rewriter, Value resOperand, Attribute &attr) { PatternRewriter &rewriter, Value resOperand, Attribute attr) {
// Read dense attribute, the constant tensor we are transforming. // Read dense attribute, the constant tensor we are transforming.
DenseElementsAttr denseAttr = DenseElementsAttr denseAttr =
attr.dyn_cast_or_null<mlir::DenseElementsAttr>(); attr.dyn_cast_or_null<mlir::DenseElementsAttr>();