constant folding for transpose of constant tensors (#171)
* added constant folding for transpose of constant tensors
* format
* responding to reviews
parent 742e817722
commit 82d2caa542
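The change folds an onnx.Transpose of a constant tensor into a new constant computed at compile time. Below is a minimal standalone sketch of that computation, assuming plain row-major storage instead of MLIR's DenseElementsAttr; the function name and signature are illustrative, not onnx-mlir APIs:

#include <cstdint>
#include <vector>

// Fold a transpose of a constant: given row-major `data` with dimensions
// `shape`, return the elements of the transposed tensor in row-major order.
// Output dimension d has size shape[perm[d]].
std::vector<int64_t> foldTranspose(const std::vector<int64_t> &data,
    const std::vector<int64_t> &shape, const std::vector<int64_t> &perm) {
  int64_t rank = shape.size();
  // Row-major strides of the input.
  std::vector<int64_t> strides(rank, 1);
  for (int64_t d = rank - 2; d >= 0; --d)
    strides[d] = strides[d + 1] * shape[d + 1];
  std::vector<int64_t> outShape(rank), outIdx(rank, 0), res;
  for (int64_t d = 0; d < rank; ++d)
    outShape[d] = shape[perm[d]];
  res.reserve(data.size());
  for (size_t n = 0; n < data.size(); ++n) {
    // Output coordinate outIdx[d] indexes input dimension perm[d].
    int64_t offset = 0;
    for (int64_t d = 0; d < rank; ++d)
      offset += outIdx[d] * strides[perm[d]];
    res.push_back(data[offset]);
    // Advance the output coordinate like an odometer (row-major order).
    for (int64_t d = rank - 1; d >= 0; --d) {
      if (++outIdx[d] < outShape[d])
        break;
      outIdx[d] = 0;
    }
  }
  return res;
}

The RecurseConstPropTranspose helper introduced in this commit produces elements in the same output row-major order, but recursively and directly over attribute values.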
@@ -68,17 +68,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
     // Read perm attribute.
     SmallVector<int, 4> perm;
     auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
-    if (permAttribute) {
-      for (auto permVal : permAttribute.getValue())
-        perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
-    } else {
-      // TODO: Remove when perm is guaranteed to be present (even for
-      // the default case). This means that perm was added by shape
-      // inference or another pass to contain the values corresponding
-      // to the default behavior of Transpose.
-      for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
-        perm.emplace_back(i);
-    }
+    assert(permAttribute && "permute attribute expected to be defined here");
+    for (auto permVal : permAttribute.getValue())
+      perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
 
     SmallVector<Value, 4> inLoopIVs;
     for (auto arg : iterationBlock.getArguments())
@@ -1217,16 +1217,21 @@ LogicalResult ONNXTransposeOp::inferShapes() {
   auto arrayTy = data().getType().cast<RankedTensorType>();
   SmallVector<int64_t, 2> dims;
   auto permutation = ONNXTransposeOp::permAttr();
-  if (permutation) {
-    // Perform transposition according to perm attribute.
-    for (auto perm : permutation.getValue())
-      dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
-  } else {
-    // Default
-    for (auto dim : llvm::reverse(arrayTy.getShape()))
-      dims.emplace_back(dim);
+  if (!permutation) {
+    // Generate reverse order for the default transpose operation.
+    SmallVector<int64_t, 4> defaultVals;
+    auto builder = mlir::Builder(getContext());
+    auto rank = arrayTy.getShape().size();
+    for (int i = rank - 1; i >= 0; --i)
+      defaultVals.emplace_back(i);
+    // Set default attribute.
+    ArrayRef<int64_t> defaultRefs(defaultVals);
+    permAttr(builder.getI64ArrayAttr(defaultRefs));
+    permutation = permAttr();
   }
 
+  // Perform transposition according to perm attribute.
+  for (auto perm : permutation.getValue())
+    dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
   return success();
 }
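With this change, shape inference materializes the default perm (reverse dimension order, e.g. [3, 2, 1, 0] for a rank-4 tensor) whenever the attribute is absent. That invariant is what lets the lowering above, and the constant-folding helper below, simply assert that perm is defined instead of re-deriving the default.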
@@ -39,7 +39,7 @@ namespace {
 
 // The methods are:
 //
-// ComputeConstProppElementwiseBinary and ComputeConstProppElementwiseUnary
+// ComputeConstPropElementwiseBinary and ComputeConstPropElementwiseUnary
 // and they need to be templated with an ONNX Operation (presumably).
 //
 // Then you need to add rules on how to transform the patterns; look into
@@ -56,20 +56,19 @@ namespace {
 // attribute.
 
 template <typename OP>
-Attribute ComputeConstProppElementwiseBinary(PatternRewriter &rewriter,
+Attribute ComputeConstPropElementwiseBinary(PatternRewriter &rewriter,
     Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
   llvm_unreachable("unknown operation");
 }
 
 template <>
-Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
+Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
     PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
     Attribute &secondAttr) {
   if (elementType.isa<FloatType>()) {
     double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
     double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
     double res = lhsVal + rhsVal;
-    // printf("  %f + %f -> %f\n", lhsVal, rhsVal, res);
     // Could use the APFloat interface to emulate the results; it is OK to
     // simply perform them in the highest possible precision.
     return rewriter.getFloatAttr(elementType, res);
@@ -78,14 +77,13 @@ Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
     uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
     uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
     uint64_t res = lhsVal + rhsVal;
-    // printf("  %llu + %llu -> %llu\n", lhsVal, rhsVal, res);
     return rewriter.getIntegerAttr(elementType, res);
   }
   llvm_unreachable("constant propagation for AddOp: unknown data type");
 }
 
 template <>
-Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
+Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
     PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
     Attribute &secondAttr) {
   if (elementType.isa<FloatType>()) {
@@ -104,7 +102,7 @@ Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
 }
 
 template <>
-Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
+Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
     PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
     Attribute &secondAttr) {
   if (elementType.isa<FloatType>()) {
@@ -133,11 +131,10 @@ Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
 // dimension size is equal to 1.
 
 template <typename ElementwiseBinaryOp>
-void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
+void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
     std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr,
     DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
     SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
-  // printf("recurse with free %d/%d\n", lhsFreeRank, rhsFreeRank);
   if (lhsFreeRank == 0) {
     // Fully defined ranks.
     assert(
@@ -145,7 +142,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
     auto lhsElementAttr = lhsAttr.getValue(ArrayRef<uint64_t>(lhsIndices));
     auto rhsElementAttr = rhsAttr.getValue(ArrayRef<uint64_t>(rhsIndices));
     auto elementaryType = lhsAttr.getType().getElementType();
-    auto res = ComputeConstProppElementwiseBinary<ElementwiseBinaryOp>(
+    auto res = ComputeConstPropElementwiseBinary<ElementwiseBinaryOp>(
         rewriter, elementaryType, lhsElementAttr, rhsElementAttr);
     resVector.emplace_back(res);
   } else if (lhsFreeRank > rhsFreeRank) {
@@ -156,7 +153,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
     int lhsSize = lhsAttr.getType().getShape()[lhsIndex];
     for (int i = 0; i < lhsSize; ++i) {
       lhsIndices[lhsIndex] = i;
-      RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
+      RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
           resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
           rhsFreeRank);
     }
@@ -168,7 +165,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
     int rhsSize = rhsAttr.getType().getShape()[rhsIndex];
     for (int i = 0; i < rhsSize; ++i) {
       rhsIndices[rhsIndex] = i;
-      RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
+      RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
           resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank,
           rhsFreeRank - 1);
     }
@@ -192,7 +189,7 @@ void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
         lhsIndices[lhsIndex] = i;
       if (rhsSize > 1)
         rhsIndices[rhsIndex] = i;
-      RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
+      RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter,
          resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
          rhsFreeRank - 1);
     }
@@ -217,7 +214,7 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
   SmallVector<uint64_t, 4> lhsIndices(lhsRank, 0);
   SmallVector<uint64_t, 4> rhsIndices(rhsRank, 0);
   std::vector<Attribute> resVector;
-  RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
+  RecurseConstPropElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
       lhsDenseAttr, rhsDenseAttr, lhsIndices, rhsIndices, lhsRank, rhsRank);
   ArrayRef<Attribute> resRef(resVector);
   return DenseElementsAttr::get(resType, resRef);
@@ -228,13 +225,13 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
 //===----------------------------------------------------------------------===//
 
 template <typename OP>
-Attribute ComputeConstProppElementwiseUnary(
+Attribute ComputeConstPropElementwiseUnary(
     PatternRewriter &rewriter, Type elementType, Attribute &attr) {
   llvm_unreachable("unknown operation");
 }
 
 template <>
-Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
+Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
     PatternRewriter &rewriter, Type elementType, Attribute &attr) {
   if (elementType.isa<FloatType>()) {
     double val = attr.cast<FloatAttr>().getValueAsDouble();
@@ -250,15 +247,14 @@ Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
 }
 
 template <typename ElementwiseUnaryOp>
-void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
+void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
     std::vector<Attribute> &resVector, DenseElementsAttr &attr,
     SmallVector<uint64_t, 4> &indices, int freeRank) {
-  // printf("recurse with free %d\n", freeRank);
   if (freeRank == 0) {
     // Fully defined ranks.
     auto elementAttr = attr.getValue(ArrayRef<uint64_t>(indices));
     auto elementaryType = attr.getType().getElementType();
-    auto res = ComputeConstProppElementwiseUnary<ElementwiseUnaryOp>(
+    auto res = ComputeConstPropElementwiseUnary<ElementwiseUnaryOp>(
         rewriter, elementaryType, elementAttr);
     resVector.emplace_back(res);
   } else {
@@ -269,7 +265,7 @@ void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
     int size = attr.getType().getShape()[index];
     for (int i = 0; i < size; ++i) {
       indices[index] = i;
-      RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
+      RecurseConstPropElementwiseUnary<ElementwiseUnaryOp>(
           rewriter, resVector, attr, indices, freeRank - 1);
     }
   }
@@ -289,12 +285,61 @@ DenseElementsAttr ConstPropElementwiseUnary(
   auto rank = denseAttr.getType().getShape().size();
   SmallVector<uint64_t, 4> indices(rank, 0);
   std::vector<Attribute> resVector;
-  RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
+  RecurseConstPropElementwiseUnary<ElementwiseUnaryOp>(
       rewriter, resVector, denseAttr, indices, rank);
   ArrayRef<Attribute> resRef(resVector);
   return DenseElementsAttr::get(resType, resRef);
 }
 
+//===----------------------------------------------------------------------===//
+// Code to perform constant propagation for transpose.
+//===----------------------------------------------------------------------===//
+
+void RecurseConstPropTranspose(PatternRewriter &rewriter,
+    std::vector<Attribute> &resVector, DenseElementsAttr &attr,
+    SmallVector<uint64_t, 4> &indices, SmallVector<uint64_t, 4> &perm,
+    int freeRank) {
+  if (freeRank == 0) {
+    // Fully defined ranks.
+    auto res = attr.getValue(ArrayRef<uint64_t>(indices));
+    resVector.emplace_back(res);
+  } else {
+    // Recurse.
+    auto shape = attr.getType().getShape();
+    int rank = shape.size();
+    int index = perm[rank - freeRank];
+    int size = attr.getType().getShape()[index];
+    for (int i = 0; i < size; ++i) {
+      indices[index] = i;
+      RecurseConstPropTranspose(
+          rewriter, resVector, attr, indices, perm, freeRank - 1);
+    }
+  }
+}
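A note on the recursion above: it loops over the input dimensions in perm order (index = perm[rank - freeRank]), so elements are appended to resVector in row-major order of the output shape. For example, with perm = [1, 0, 2] on a 2x3x4 tensor, the outer level iterates input dimension 1 (size 3), the middle level dimension 0 (size 2), and the inner level dimension 2 (size 4), matching the 3x2x4 result checked in test_default_transpose_const_3 below.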
+
+DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
+    Value resOperand, Attribute &attr, ArrayAttr &permAttr) {
+  // Read dense attribute, the constant tensor we are transforming.
+  DenseElementsAttr denseAttr =
+      attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
+  assert(denseAttr && "expected dense attribute");
+  ShapedType resType = resOperand.getType().cast<RankedTensorType>();
+  auto rank = denseAttr.getType().getShape().size();
+  // Read permute vector.
+  SmallVector<uint64_t, 4> perm;
+  assert(permAttr && "permute attribute expected to be defined here");
+  for (auto permVal : permAttr.getValue())
+    perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
+  // Init indices vector.
+  SmallVector<uint64_t, 4> indices(rank, 0);
+  std::vector<Attribute> resVector;
+  // Copy using permute order.
+  RecurseConstPropTranspose(
+      rewriter, resVector, denseAttr, indices, perm, rank);
+  ArrayRef<Attribute> resRef(resVector);
+  return DenseElementsAttr::get(resType, resRef);
+}
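ConstPropTranspose requires a dense (non-sparse) constant and a defined perm; the latter is guaranteed by the shape-inference change above. In the TableGen pattern below, $_builder, $0, $1, and $2 of the NativeCodeCall bind to rewriter, resOperand, attr, and permAttr respectively.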
 
 //===----------------------------------------------------------------------===//
 // Pattern definition.
 //===----------------------------------------------------------------------===//
@@ -63,6 +63,9 @@ def CreateNegOfConst :
 def CreateMulOfTwoConst :
    NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
 
+def CreateTransposeOfConst :
+   NativeCodeCall<"ConstPropTranspose($_builder, $0, $1, $2)">;
+
 //===----------------------------------------------------------------------===//
 // Patterns to enable opportunities with elementwise ADD operations.
 //===----------------------------------------------------------------------===//
@@ -229,4 +232,17 @@ def MulConstProp : Pat<
     // Additional constraints (no sparse)
     [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
 
+//===----------------------------------------------------------------------===//
+// Patterns to enable opportunities with Transpose operations.
+//===----------------------------------------------------------------------===//
+
+// Transpose of a constant is a new constant holding the transposed values.
+def TransposeofConst : Pat<
+    // From TransposeOp(c, p)
+    (ONNXTransposeOp:$resOp (ONNXConstantOp $s, $v), $p),
+    // To c' where c' is the transposed attribute
+    (ONNXConstantOp (GetNullAttr), (CreateTransposeOfConst $resOp, $v, $p)),
+    [(AttributeIsNull:$s)]>;
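Reading the pattern: it matches a TransposeOp whose operand is a ConstantOp carrying a dense value $v (the AttributeIsNull:$s constraint rules out sparse constants, mirroring the "(no sparse)" guards on the elementwise patterns above) and rewrites it to a fresh ConstantOp whose value is computed at rewrite time by CreateTransposeOfConst.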
+
 
 #endif // ONNX_CONSTPROP
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --constprop-onnx %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --shape-inference --constprop-onnx %s -split-input-file | FileCheck %s
 
 
 //===----------------------------------------------------------------------===//
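The RUN line now invokes --shape-inference first because the transpose folding asserts that perm is defined; shape inference supplies the default perm for tests such as test_default_transpose_const_1 below, which omit the attribute.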
@@ -195,3 +195,35 @@ func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
   // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
 }
 
+//===----------------------------------------------------------------------===//
+/// Transpose tests.
+
+// -----
+// CHECK-LABEL: test_default_transpose_const_1
+func @test_default_transpose_const_1() -> tensor<*xi32> {
+  %0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
+  %1 = "onnx.Transpose"(%0) : (tensor<2x3x4xi32>) -> tensor<*xi32>
+  "std.return"(%1) : (tensor<*xi32>) -> ()
+  // CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 211], [121, 221], [131, 231]{{.}}, [{{.}}112, 212], [122, 222], [132, 232]{{.}}, [{{.}}113, 213], [123, 223], [133, 233]{{.}}, [{{.}}114, 214], [124, 224], [134, 234]{{.}}]> : tensor<4x3x2xi32>} : () -> tensor<4x3x2xi32>
+  // CHECK: return [[RES]] : tensor<4x3x2xi32>
+}
+
+// -----
+// CHECK-LABEL: test_default_transpose_const_2
+func @test_default_transpose_const_2() -> tensor<*xi32> {
+  %0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
+  %1 = "onnx.Transpose"(%0) {perm = [0, 2, 1]} : (tensor<2x3x4xi32>) -> tensor<*xi32>
+  "std.return"(%1) : (tensor<*xi32>) -> ()
+  // CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 121, 131], [112, 122, 132], [113, 123, 133], [114, 124, 134]{{.}}, [{{.}}211, 221, 231], [212, 222, 232], [213, 223, 233], [214, 224, 234]{{.}}]> : tensor<2x4x3xi32>} : () -> tensor<2x4x3xi32>
+  // CHECK: return [[RES]] : tensor<2x4x3xi32>
+}
+
+// -----
+// CHECK-LABEL: test_default_transpose_const_3
+func @test_default_transpose_const_3() -> tensor<*xi32> {
+  %0 = "onnx.Constant"() {value = dense<[[[111, 112, 113, 114], [121, 122, 123, 124], [131, 132, 133, 134]], [[211, 212, 213, 214], [221, 222, 223, 224], [231, 232, 233, 234]]]> : tensor<2x3x4xi32>} : () -> tensor<2x3x4xi32>
+  %1 = "onnx.Transpose"(%0) {perm = [1, 0, 2]} : (tensor<2x3x4xi32>) -> tensor<*xi32>
+  "std.return"(%1) : (tensor<*xi32>) -> ()
+  // CHECK: [[RES:%.+]] = "onnx.Constant"() {value = dense<[{{.}}[111, 112, 113, 114], [211, 212, 213, 214]{{.}}, [{{.}}121, 122, 123, 124], [221, 222, 223, 224]{{.}}, [{{.}}131, 132, 133, 134], [231, 232, 233, 234]{{.}}]> : tensor<3x2x4xi32>} : () -> tensor<3x2x4xi32>
+  // CHECK: return [[RES]] : tensor<3x2x4xi32>
+}
@@ -12,7 +12,7 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_default_transpose
-  // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
+  // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
   // CHECK: return [[RES]] : tensor<32x1x5x5xf32>
 }