From 9d1078540d37210644681bb3cd47acd2eb080e5f Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Mon, 13 Jan 2020 18:08:19 -0500 Subject: [PATCH 1/4] Transpose using perm attribute. --- src/dialect/onnx/onnx_ops.cpp | 24 +++++++++++++++++++----- src/dialect/onnx/onnxop.inc | 4 ++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 985f63d..73d76b5 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -407,11 +407,25 @@ void ONNXTransposeOp::inferShapes() { // Naive transposition which handles the default case of // reversing the shape of the tensor (similar to numpy.transpose). - // TODO: Once attributes are supported we can handle the case where the - // transposition uses a permutation vector to interchange the axes. - auto arrayTy = getOperand().getType().cast(); - SmallVector dims(llvm::reverse(arrayTy.getShape())); - getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); + auto arrayTy = getOperand()->getType().cast(); + SmallVector dims; + + if (auto permutation = getAttrOfType( + ONNXTransposeOp::getPermAttrName())) { + // Perform transposition according to perm attribute. + for (auto perm : permutation.getValue()) { + int32_t index = perm.cast().getInt(); + if (index < 0) + emitError("Cannot tranpose when permutation contains negative index."); + dims.emplace_back(arrayTy.getShape()[index]); + } + } else { + // Default + for (auto shape : llvm::reverse(arrayTy.getShape())) + dims.emplace_back(shape); + } + + getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); } //===----------------------------------------------------------------------===// diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index 16ad979..5d22346 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -3098,6 +3098,10 @@ def ONNXTransposeOp:ONNX_Op<"Transpose", }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); + + let extraClassDeclaration = [{ + static StringRef getPermAttrName() { return "perm"; } + }]; } def ONNXUniqueOp:ONNX_Op<"Unique", From f0b484c0bc59f0f39a93e111cee428fb6f596ccf Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Tue, 14 Jan 2020 10:37:05 -0500 Subject: [PATCH 2/4] Add test for transpose with permutation. --- test/mlir/onnx/onnx_shape_inference.mlir | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 4cb9bec..aaa08a7 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -9,4 +9,14 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_default_transpose // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32> -// CHECK: return [[RES]] : tensor<32x1x5x5xf32> \ No newline at end of file +// CHECK: return [[RES]] : tensor<32x1x5x5xf32> + +/// Test shape inference for transposition when perm attribute is specified. +func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () +} + +// CHECK-LABEL: test_transpose +// CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32> +// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32> \ No newline at end of file From bd44d8402e09c8b1c6068550abc640fdf5c41ef9 Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Mon, 20 Jan 2020 14:46:54 -0500 Subject: [PATCH 3/4] Add verifier function for checking negative perms. --- src/dialect/onnx/onnx_ops.cpp | 28 ++++++++++++++++++++-------- src/dialect/onnx/onnxop.inc | 2 ++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 73d76b5..cef90cb 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/SetVector.h" @@ -413,21 +414,32 @@ void ONNXTransposeOp::inferShapes() { if (auto permutation = getAttrOfType( ONNXTransposeOp::getPermAttrName())) { // Perform transposition according to perm attribute. - for (auto perm : permutation.getValue()) { - int32_t index = perm.cast().getInt(); - if (index < 0) - emitError("Cannot tranpose when permutation contains negative index."); - dims.emplace_back(arrayTy.getShape()[index]); - } + for (auto perm : permutation.getValue()) + dims.emplace_back(arrayTy.getShape()[perm.cast().getInt()]); } else { // Default - for (auto shape : llvm::reverse(arrayTy.getShape())) - dims.emplace_back(shape); + for (auto dim : llvm::reverse(arrayTy.getShape())) + dims.emplace_back(dim); } getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); } +LogicalResult verify(ONNXTransposeOp op) { + auto module = op.getParentOfType(); + if (!module) + op.emitError("Expected to belong to a module."); + + if (auto permutation = op.getAttrOfType( + ONNXTransposeOp::getPermAttrName())) { + for (auto perm : permutation.getValue()) + if (perm.cast().getInt() < 0) + op.emitError("Cannot tranpose, permuation contains negative index."); + } + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index 5d22346..fc2714e 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -3102,6 +3102,8 @@ def ONNXTransposeOp:ONNX_Op<"Transpose", let extraClassDeclaration = [{ static StringRef getPermAttrName() { return "perm"; } }]; + + let verifier = [{ return ::verify(*this); }]; } def ONNXUniqueOp:ONNX_Op<"Unique", From 6b55bb43c7fb1c7eab0104b29139caf4dedd2aa3 Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Mon, 20 Jan 2020 15:48:16 -0500 Subject: [PATCH 4/4] Fix operand type access. --- src/dialect/onnx/onnx_ops.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index cef90cb..7e1675d 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -408,7 +408,7 @@ void ONNXTransposeOp::inferShapes() { // Naive transposition which handles the default case of // reversing the shape of the tensor (similar to numpy.transpose). - auto arrayTy = getOperand()->getType().cast(); + auto arrayTy = getOperand().getType().cast(); SmallVector dims; if (auto permutation = getAttrOfType( @@ -422,7 +422,7 @@ void ONNXTransposeOp::inferShapes() { dims.emplace_back(dim); } - getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); + getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); } LogicalResult verify(ONNXTransposeOp op) {