diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 985f63d..7e1675d 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/SetVector.h" @@ -407,13 +408,38 @@ void ONNXTransposeOp::inferShapes() { // Naive transposition which handles the default case of // reversing the shape of the tensor (similar to numpy.transpose). - // TODO: Once attributes are supported we can handle the case where the - // transposition uses a permutation vector to interchange the axes. auto arrayTy = getOperand().getType().cast(); - SmallVector dims(llvm::reverse(arrayTy.getShape())); + SmallVector dims; + + if (auto permutation = getAttrOfType( + ONNXTransposeOp::getPermAttrName())) { + // Perform transposition according to perm attribute. + for (auto perm : permutation.getValue()) + dims.emplace_back(arrayTy.getShape()[perm.cast().getInt()]); + } else { + // Default + for (auto dim : llvm::reverse(arrayTy.getShape())) + dims.emplace_back(dim); + } + getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); } +LogicalResult verify(ONNXTransposeOp op) { + auto module = op.getParentOfType(); + if (!module) + op.emitError("Expected to belong to a module."); + + if (auto permutation = op.getAttrOfType( + ONNXTransposeOp::getPermAttrName())) { + for (auto perm : permutation.getValue()) + if (perm.cast().getInt() < 0) + op.emitError("Cannot tranpose, permuation contains negative index."); + } + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index 16ad979..fc2714e 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -3098,6 +3098,12 @@ def ONNXTransposeOp:ONNX_Op<"Transpose", }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); + + let extraClassDeclaration = [{ + static StringRef getPermAttrName() { return "perm"; } + }]; + + let verifier = [{ return ::verify(*this); }]; } def ONNXUniqueOp:ONNX_Op<"Unique", diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 4cb9bec..aaa08a7 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -9,4 +9,14 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_default_transpose // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32> -// CHECK: return [[RES]] : tensor<32x1x5x5xf32> \ No newline at end of file +// CHECK: return [[RES]] : tensor<32x1x5x5xf32> + +/// Test shape inference for transposition when perm attribute is specified. +func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () +} + +// CHECK-LABEL: test_transpose +// CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32> +// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32> \ No newline at end of file