diff --git a/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp b/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp
index 53da219..a51a935 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp
@@ -68,17 +68,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
     // Read perm attribute.
     SmallVector<int64_t, 4> perm;
     auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
-    if (permAttribute) {
-      for (auto permVal : permAttribute.getValue())
-        perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
-    } else {
-      // TODO: Remove when perm is guaranteed to be present (even for
-      // the default case). This means that perm was added by shape
-      // inference or another pass to contain the values corresponding
-      // to the default behavior of Transpose.
-      for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
-        perm.emplace_back(i);
-    }
+    assert(permAttribute && "permute attribute expected to be defined here");
+    for (auto permVal : permAttribute.getValue())
+      perm.emplace_back(permVal.cast<IntegerAttr>().getInt());

     SmallVector<Value, 4> inLoopIVs;
     for (auto arg : iterationBlock.getArguments())
diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index 4e22262..b1126d6 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -1217,16 +1217,21 @@ LogicalResult ONNXTransposeOp::inferShapes() {
   auto arrayTy = data().getType().cast<RankedTensorType>();
   SmallVector<int64_t, 2> dims;
   auto permutation = ONNXTransposeOp::permAttr();
-  if (permutation) {
-    // Perform transposition according to perm attribute.
-    for (auto perm : permutation.getValue())
-      dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
-  } else {
-    // Default
-    for (auto dim : llvm::reverse(arrayTy.getShape()))
-      dims.emplace_back(dim);
+  if (!permutation) {
+    // Generate reverse order for default transpose operation.
+    SmallVector<int64_t, 4> defaultVals;
+    auto builder = mlir::Builder(getContext());
+    auto rank = arrayTy.getShape().size();
+    for (int i = rank - 1; i >= 0; --i)
+      defaultVals.emplace_back(i);
+    // Set default attribute.
+    ArrayRef<int64_t> defaultRefs(defaultVals);
+    permAttr(builder.getI64ArrayAttr(defaultRefs));
+    permutation = permAttr();
   }
-
+  // Perform transposition according to perm attribute.
+  for (auto perm : permutation.getValue())
+    dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
   return success();
 }
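A note on the `inferShapes()` change above: instead of applying the default permutation implicitly, shape inference now materializes it as an explicit `perm` attribute, so downstream consumers (the Krnl lowering above, the constant propagation below) can simply assert that `perm` is present. A before/after sketch in MLIR, mirroring the updated `test_default_transpose` check at the bottom of this patch:

```mlir
// Before shape inference: no perm attribute, unranked result type.
%0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32>

// After shape inference: the default (dimension-reversing) permutation is
// attached explicitly and the result type is refined.
%0 = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]}
       : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
```

diff --git a/src/Transform/ONNX/ConstProp.cpp b/src/Transform/ONNX/ConstProp.cpp
index 262cbcb..599c9db 100644
--- a/src/Transform/ONNX/ConstProp.cpp
+++ b/src/Transform/ONNX/ConstProp.cpp
@@ -39,7 +39,7 @@ namespace {
 // The methods are:
 //
-// ComputeConstProppElementwiseBinary and ComputeConstProppElementwiseUnary
+// ComputeConstPropElementwiseBinary and ComputeConstPropElementwiseUnary
 // and they need to be tempalted wtih an ONNX Operation (presuably).
 //
 // Then you need to add rules on how to transform the patterns; look into
@@ -56,20 +56,19 @@ namespace {
 // attribute.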

 template <typename ElementwiseBinaryOp>
-Attribute ComputeConstProppElementwiseBinary(PatternRewriter &rewriter,
+Attribute ComputeConstPropElementwiseBinary(PatternRewriter &rewriter,
     Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
   llvm_unreachable("unkonwn operation");
 }

 template <>
-Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
+Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
     PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
     Attribute &secondAttr) {
   if (elementType.isa<FloatType>()) {
     double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
     double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
     double res = lhsVal + rhsVal;
-    // printf(" %f + %f -> %f\n", lhsVal, rhsVal, res);
     // Could use the APFloat interface to emulate the results, are ok to simply
     // perform them in the highest possible precision.
     return rewriter.getFloatAttr(elementType, res);
@@ -78,14 +77,13 @@ Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
     uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
     uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
     uint64_t res = lhsVal + rhsVal;
-    // printf(" %llu + %llu -> %llu\n", lhsVal, rhsVal, res);
     return rewriter.getIntegerAttr(elementType, res);
   }
   llvm_unreachable("constant propagation for AddOp: unkonwn data type");
 }

 template <>
-Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
+Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
     PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
     Attribute &secondAttr) {
   if (elementType.isa<FloatType>()) {
@@ -104,7 +102,7 @@ Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
 }

 template <>
-Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
+Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
     PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
     Attribute &secondAttr) {
   if (elementType.isa<FloatType>()) {
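As the comment block at the top of ConstProp.cpp says, support for a new elementwise op is added by specializing these compute templates. A purely hypothetical sketch of what that looks like, using `ONNXDivOp` only as an illustration (this patch adds no such specialization, and the snippet assumes the includes and namespace of the surrounding file):

```cpp
// Hypothetical example, NOT part of this patch: a float-only division fold
// following the shape of the Add/Sub/Mul specializations above.
template <>
Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
    PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
    Attribute &secondAttr) {
  if (elementType.isa<FloatType>()) {
    double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
    double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
    // As in the other specializations, fold in double precision.
    return rewriter.getFloatAttr(elementType, lhsVal / rhsVal);
  }
  llvm_unreachable("constant propagation for DivOp: unknown data type");
}
```

A matching `NativeCodeCall` and `Pat` entry in ConstProp.td would then expose the fold to the rewriter.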
@@ -133,11 +131,10 @@ Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
 // dimension size is equal to 1.

 template <typename ElementwiseBinaryOp>
-void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
+void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
     std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr,
     DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
     SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
-  // printf("recurse with free %d/%d\n", lhsFreeRank, rhsFreeRank);
   if (lhsFreeRank == 0) {
     // Fully defined ranks.
     assert(
@@ -145,7 +142,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
     auto lhsElementAttr = lhsAttr.getValue(ArrayRef<uint64_t>(lhsIndices));
     auto rhsElementAttr = rhsAttr.getValue(ArrayRef<uint64_t>(rhsIndices));
     auto elementaryType = lhsAttr.getType().getElementType();
-    auto res = ComputeConstProppElementwiseBinary<ElementwiseBinaryOp>(
+    auto res = ComputeConstPropElementwiseBinary<ElementwiseBinaryOp>(
         rewriter, elementaryType, lhsElementAttr, rhsElementAttr);
     resVector.emplace_back(res);
   } else if (lhsFreeRank > rhsFreeRank) {
@@ -156,7 +153,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
     int lhsSize = lhsAttr.getType().getShape()[lhsIndex];
     for (int i = 0; i < lhsSize; ++i) {
       lhsIndices[lhsIndex] = i;
-      RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
+      RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
           resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
           rhsFreeRank);
     }
@@ -168,7 +165,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
     int rhsSize = rhsAttr.getType().getShape()[rhsIndex];
     for (int i = 0; i < rhsSize; ++i) {
       rhsIndices[rhsIndex] = i;
-      RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
+      RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
           resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank,
           rhsFreeRank - 1);
     }
@@ -192,7 +189,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
       lhsIndices[lhsIndex] = i;
       if (rhsSize > 1)
         rhsIndices[rhsIndex] = i;
-      RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
+      RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
          resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
          rhsFreeRank - 1);
     }
@@ -217,7 +214,7 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
   SmallVector<uint64_t, 4> lhsIndices(lhsRank, 0);
   SmallVector<uint64_t, 4> rhsIndices(rhsRank, 0);
   std::vector<Attribute> resVector;
-  RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
+  RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
      lhsDenseAttr, rhsDenseAttr, lhsIndices, rhsIndices, lhsRank, rhsRank);
   ArrayRef<Attribute> resRef(resVector);
   return DenseElementsAttr::get(resType, resRef);
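For readers following the recursion above: results are appended in the row-major order of the broadcasted output, and a source dimension of size 1 is broadcast by leaving its index pinned at 0 while the other side iterates. A self-contained sketch of the same idea on plain `std::vector`s (simplified to equal ranks, `int` data, and addition; none of this code is in the patch):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

using Shape = std::vector<int>;

// Row-major flattening of a multi-index.
static int flatIndex(const Shape &shape, const Shape &idx) {
  int flat = 0;
  for (size_t d = 0; d < shape.size(); ++d)
    flat = flat * shape[d] + idx[d];
  return flat;
}

// Mirrors RecurseConstPropElementwiseBinary for the equal-rank case: walk
// the output dimensions outermost-first; a size-1 dimension keeps index 0.
static void recurse(const Shape &lhsShape, const std::vector<int> &lhs,
    const Shape &rhsShape, const std::vector<int> &rhs, Shape &lhsIdx,
    Shape &rhsIdx, int freeRank, std::vector<int> &res) {
  if (freeRank == 0) {
    res.push_back(lhs[flatIndex(lhsShape, lhsIdx)] +
                  rhs[flatIndex(rhsShape, rhsIdx)]);
    return;
  }
  int d = lhsShape.size() - freeRank;
  int size = std::max(lhsShape[d], rhsShape[d]);
  for (int i = 0; i < size; ++i) {
    if (lhsShape[d] > 1)
      lhsIdx[d] = i; // broadcast: the index stays 0 when the dimension is 1
    if (rhsShape[d] > 1)
      rhsIdx[d] = i;
    recurse(lhsShape, lhs, rhsShape, rhs, lhsIdx, rhsIdx, freeRank - 1, res);
  }
}

int main() {
  // lhs is 2x1 {{10},{20}}, rhs is 1x3 {{1,2,3}}; the result is 2x3.
  Shape lhsShape{2, 1}, rhsShape{1, 3};
  std::vector<int> lhs{10, 20}, rhs{1, 2, 3};
  Shape lhsIdx{0, 0}, rhsIdx{0, 0};
  std::vector<int> res;
  recurse(lhsShape, lhs, rhsShape, rhs, lhsIdx, rhsIdx, 2, res);
  for (int v : res)
    printf("%d ", v); // prints: 11 12 13 21 22 23
  printf("\n");
  return 0;
}
```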
@@ -228,13 +225,13 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
 //===----------------------------------------------------------------------===//

 template <typename ElementwiseUnaryOp>
-Attribute ComputeConstProppElementwiseUnary(
+Attribute ComputeConstPropElementwiseUnary(
     PatternRewriter &rewriter, Type elementType, Attribute &attr) {
   llvm_unreachable("unkonwn operation");
 }

 template <>
-Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
+Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
     PatternRewriter &rewriter, Type elementType, Attribute &attr) {
   if (elementType.isa<FloatType>()) {
     double val = attr.cast<FloatAttr>().getValueAsDouble();
@@ -250,15 +247,14 @@ Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
 }

 template <typename ElementwiseUnaryOp>
-void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
+void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
     std::vector<Attribute> &resVector, DenseElementsAttr &attr,
     SmallVector<uint64_t, 4> &indices, int freeRank) {
-  // printf("recurse with free %d\n", freeRank);
   if (freeRank == 0) {
     // Fully defined ranks.
     auto elementAttr = attr.getValue(ArrayRef<uint64_t>(indices));
     auto elementaryType = attr.getType().getElementType();
-    auto res = ComputeConstProppElementwiseUnary<ElementwiseUnaryOp>(
+    auto res = ComputeConstPropElementwiseUnary<ElementwiseUnaryOp>(
         rewriter, elementaryType, elementAttr);
     resVector.emplace_back(res);
   } else {
@@ -269,7 +265,7 @@ void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
     int size = attr.getType().getShape()[index];
     for (int i = 0; i < size; ++i) {
       indices[index] = i;
-      RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
+      RecurseConstPropElementwiseUnary<ElementwiseUnaryOp>(
           rewriter, resVector, attr, indices, freeRank - 1);
     }
   }
@@ -289,12 +285,61 @@ DenseElementsAttr ConstPropElementwiseUnary(
   auto rank = denseAttr.getType().getShape().size();
   SmallVector<uint64_t, 4> indices(rank, 0);
   std::vector<Attribute> resVector;
-  RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
+  RecurseConstPropElementwiseUnary<ElementwiseUnaryOp>(
       rewriter, resVector, denseAttr, indices, rank);
   ArrayRef<Attribute> resRef(resVector);
   return DenseElementsAttr::get(resType, resRef);
 }

+//===----------------------------------------------------------------------===//
+// Code to perform constant propagation for transpose.
+//===----------------------------------------------------------------------===//
+
+void RecurseConstPropTranspose(PatternRewriter &rewriter,
+    std::vector<Attribute> &resVector, DenseElementsAttr &attr,
+    SmallVector<uint64_t, 4> &indices, SmallVector<uint64_t, 4> &perm,
+    int freeRank) {
+  if (freeRank == 0) {
+    // Fully defined ranks.
+    auto res = attr.getValue(ArrayRef<uint64_t>(indices));
+    resVector.emplace_back(res);
+  } else {
+    // Recurse.
+    auto shape = attr.getType().getShape();
+    int rank = shape.size();
+    int index = perm[rank - freeRank];
+    int size = attr.getType().getShape()[index];
+    for (int i = 0; i < size; ++i) {
+      indices[index] = i;
+      RecurseConstPropTranspose(
+          rewriter, resVector, attr, indices, perm, freeRank - 1);
+    }
+  }
+}
+
+DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
+    Value resOperand, Attribute &attr, ArrayAttr &permAttr) {
+  // Read dense attribute, the constant tensor we are transforming.
+  DenseElementsAttr denseAttr =
+      attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
+  assert(denseAttr && "expected dense attribute");
+  ShapedType resType = resOperand.getType().cast<ShapedType>();
+  auto rank = denseAttr.getType().getShape().size();
+  // Read permute vector.
+  SmallVector<uint64_t, 4> perm;
+  assert(permAttr && "permute attribute expected to be defined here");
+  for (auto permVal : permAttr.getValue())
+    perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
+  // Init indices vector.
+  SmallVector<uint64_t, 4> indices(rank, 0);
+  std::vector<Attribute> resVector;
+  // Copy using permute order.
+  RecurseConstPropTranspose(
+      rewriter, resVector, denseAttr, indices, perm, rank);
+  ArrayRef<Attribute> resRef(resVector);
+  return DenseElementsAttr::get(resType, resRef);
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern definition.
 //===----------------------------------------------------------------------===//
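The new `ConstPropTranspose` above advances the source indices in `perm` order, which makes the leaf values come out in the row-major order of the transposed result: for output position (i0, ..., in-1), the source index vector satisfies indices[perm[d]] = i_d. A self-contained sketch of that iteration on plain vectors (illustration only, not code from the patch):

```cpp
#include <cstdio>
#include <vector>

// Mirrors RecurseConstPropTranspose: advance the source dimension that maps
// to the next (outermost remaining) result dimension.
static void recurseTranspose(const std::vector<int> &shape,
    const std::vector<int> &data, const std::vector<int> &perm,
    std::vector<int> &indices, int freeRank, std::vector<int> &res) {
  int rank = shape.size();
  if (freeRank == 0) {
    // Row-major flat offset into the source data.
    int flat = 0;
    for (int d = 0; d < rank; ++d)
      flat = flat * shape[d] + indices[d];
    res.push_back(data[flat]);
    return;
  }
  int index = perm[rank - freeRank];
  for (int i = 0; i < shape[index]; ++i) {
    indices[index] = i;
    recurseTranspose(shape, data, perm, indices, freeRank - 1, res);
  }
}

int main() {
  // data is 2x3 {{1,2,3},{4,5,6}}; perm = [1, 0] yields the 3x2 transpose.
  std::vector<int> shape{2, 3}, data{1, 2, 3, 4, 5, 6}, perm{1, 0};
  std::vector<int> indices{0, 0}, res;
  recurseTranspose(shape, data, perm, indices, 2, res);
  for (int v : res)
    printf("%d ", v); // prints: 1 4 2 5 3 6, i.e. {{1,4},{2,5},{3,6}}
  printf("\n");
  return 0;
}
```

Appending the leaves in output order is what lets `ConstPropTranspose` hand `DenseElementsAttr::get` a single flat `ArrayRef<Attribute>`.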
diff --git a/src/Transform/ONNX/ConstProp.td b/src/Transform/ONNX/ConstProp.td
index 122b713..231a94a 100644
--- a/src/Transform/ONNX/ConstProp.td
+++ b/src/Transform/ONNX/ConstProp.td
@@ -63,6 +63,9 @@ def CreateNegOfConst :
 def CreateMulOfTwoConst :
   NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;

+def CreateTransposeOfConst :
+  NativeCodeCall<"ConstPropTranspose($_builder, $0, $1, $2)">;
+
 //===----------------------------------------------------------------------===//
 // Patterns to enable opportunities with elementwise ADD operations.
 //===----------------------------------------------------------------------===//
@@ -229,4 +232,17 @@ def MulConstProp : Pat<
   // Mulitional constraints (no sparse)
   [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;

+//===----------------------------------------------------------------------===//
+// Patterns to enable opportunities with Transpose operations.
+//===----------------------------------------------------------------------===//
+
+// Transpose of a constant is a new constant holding the transposed value.
+def TransposeofConst : Pat<
+  // From TransposeOp(c, p)
+  (ONNXTransposeOp:$resOp (ONNXConstantOp $s, $v), $p),
+  // To c' where c' is transposed attribute
+  (ONNXConstantOp (GetNullAttr), (CreateTransposeOfConst $resOp, $v, $p)),
+  [(AttributeIsNull:$s)]>;
+
+
 #endif // ONNX_CONSTPROP
diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir
index aba651f..7934aaf 100644
--- a/test/mlir/onnx/onnx_constprop.mlir
+++ b/test/mlir/onnx/onnx_constprop.mlir
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --constprop-onnx %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --shape-inference --constprop-onnx %s -split-input-file | FileCheck %s

 //===----------------------------------------------------------------------===//

@@ -195,3 +195,35 @@ func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
   // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
 }

+//===----------------------------------------------------------------------===//
+/// Transpose tests.
+
+// -----
+// CHECK-LABEL: test_default_transpose_const_1
+func @test_default_transpose_const_1() -> tensor<*xi32> {
+  %0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
+  %1 = "onnx.Transpose"(%0) : (tensor<2x3x4xi32>) -> tensor<*xi32>
+  "std.return"(%1) : (tensor<*xi32>) -> ()
+  // CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 211], [121, 221], [131, 231]{{.}}, [{{.}}112, 212], [122, 222], [132, 232]{{.}}, [{{.}}113, 213], [123, 223], [133, 233]{{.}}, [{{.}}114, 214], [124, 224], [134, 234]{{.}}]> : tensor<4x3x2xi32>} : () -> tensor<4x3x2xi32>
+  // CHECK: return [[RES]] : tensor<4x3x2xi32>
+}
+
+// -----
+// CHECK-LABEL: test_default_transpose_const_2
+func @test_default_transpose_const_2() -> tensor<*xi32> {
+  %0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
+  %1 = "onnx.Transpose"(%0) {perm = [0, 2, 1]} : (tensor<2x3x4xi32>) -> tensor<*xi32>
+  "std.return"(%1) : (tensor<*xi32>) -> ()
+  // CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 121, 131], [112, 122, 132], [113, 123, 133], [114, 124, 134]{{.}}, [{{.}}211, 221, 231], [212, 222, 232], [213, 223, 233], [214, 224, 234]{{.}}]> : tensor<2x4x3xi32>} : () -> tensor<2x4x3xi32>
+  // CHECK: return [[RES]] : tensor<2x4x3xi32>
+}
+
+// -----
+// CHECK-LABEL: test_default_transpose_const_3
+func @test_default_transpose_const_3() -> tensor<*xi32> {
+  %0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
+  %1 = "onnx.Transpose"(%0) {perm = [1, 0, 2]} : (tensor<2x3x4xi32>) -> tensor<*xi32>
+  "std.return"(%1) : (tensor<*xi32>) -> ()
+  // CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 112, 113, 114], [211, 212, 213, 214]{{.}}, [{{.}}121, 122, 123, 124], [221, 222, 223, 224]{{.}}, [{{.}}131, 132, 133, 134], [231, 232, 233, 234]{{.}}]> : tensor<3x2x4xi32>} : () -> tensor<3x2x4xi32>
+  // CHECK: return [[RES]] : tensor<3x2x4xi32>
+}
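A sanity check on the CHECK lines above, using test 2: with `perm = [0, 2, 1]` the fold computes

```
result[i, j, k] = input[i, k, j]        (perm = [0, 2, 1])
e.g. input[0][1][2] = 123  ->  result[0][2][1]
shape: 2x3x4 -> 2x4x3
```

Note also the updated RUN line: `--shape-inference` now runs before `--constprop-onnx`. Test 1 omits `perm`, and the `TransposeofConst` pattern binds the attribute as `$p`, so the pattern relies on shape inference having materialized it first.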
"std.return"(%1) : (tensor<*xi32>) -> () + // CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 112, 113, 114], [211, 212, 213, 214]{{.}}, [{{.}}121, 122, 123, 124], [221, 222, 223, 224]{{.}}, [{{.}}131, 132, 133, 134], [231, 232, 233, 234]{{.}}]> : tensor<3x2x4xi32>} : () -> tensor<3x2x4xi32> + // CHECK: return [[RES]] : tensor<3x2x4xi32> +} diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 67a695a..d302418 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -12,7 +12,7 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_default_transpose - // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32> + // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32> // CHECK: return [[RES]] : tensor<32x1x5x5xf32> }