From 9d1078540d37210644681bb3cd47acd2eb080e5f Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Mon, 13 Jan 2020 18:08:19 -0500 Subject: [PATCH] Transpose using perm attribute. --- src/dialect/onnx/onnx_ops.cpp | 24 +++++++++++++++++++----- src/dialect/onnx/onnxop.inc | 4 ++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 985f63d..73d76b5 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -407,11 +407,25 @@ void ONNXTransposeOp::inferShapes() { // Naive transposition which handles the default case of // reversing the shape of the tensor (similar to numpy.transpose). - // TODO: Once attributes are supported we can handle the case where the - // transposition uses a permutation vector to interchange the axes. - auto arrayTy = getOperand().getType().cast(); - SmallVector dims(llvm::reverse(arrayTy.getShape())); - getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); + auto arrayTy = getOperand()->getType().cast(); + SmallVector dims; + + if (auto permutation = getAttrOfType( + ONNXTransposeOp::getPermAttrName())) { + // Perform transposition according to perm attribute. + for (auto perm : permutation.getValue()) { + int32_t index = perm.cast().getInt(); + if (index < 0) + emitError("Cannot tranpose when permutation contains negative index."); + dims.emplace_back(arrayTy.getShape()[index]); + } + } else { + // Default + for (auto shape : llvm::reverse(arrayTy.getShape())) + dims.emplace_back(shape); + } + + getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); } //===----------------------------------------------------------------------===// diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index 16ad979..5d22346 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -3098,6 +3098,10 @@ def ONNXTransposeOp:ONNX_Op<"Transpose", }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); + + let extraClassDeclaration = [{ + static StringRef getPermAttrName() { return "perm"; } + }]; } def ONNXUniqueOp:ONNX_Op<"Unique",