Merge pull request #35 from clang-ykt/perm-transpose

Infer shape of transposition operations using the perm attribute
This commit is contained in:
Gheorghe-Teodor Bercea 2020-01-20 15:53:14 -05:00 committed by GitHub
commit 0aaab0d2d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 4 deletions

View File

@ -13,6 +13,7 @@
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h" #include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
@ -407,13 +408,38 @@ void ONNXTransposeOp::inferShapes() {
// Naive transposition which handles the default case of // Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose). // reversing the shape of the tensor (similar to numpy.transpose).
// TODO: Once attributes are supported we can handle the case where the
// transposition uses a permutation vector to interchange the axes.
auto arrayTy = getOperand().getType().cast<RankedTensorType>(); auto arrayTy = getOperand().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape())); SmallVector<int64_t, 2> dims;
if (auto permutation = getAttrOfType<ArrayAttr>(
ONNXTransposeOp::getPermAttrName())) {
// Perform transposition according to perm attribute.
for (auto perm : permutation.getValue())
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
} else {
// Default
for (auto dim : llvm::reverse(arrayTy.getShape()))
dims.emplace_back(dim);
}
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
} }
LogicalResult verify(ONNXTransposeOp op) {
auto module = op.getParentOfType<ModuleOp>();
if (!module)
op.emitError("Expected to belong to a module.");
if (auto permutation = op.getAttrOfType<ArrayAttr>(
ONNXTransposeOp::getPermAttrName())) {
for (auto perm : permutation.getValue())
if (perm.cast<IntegerAttr>().getInt() < 0)
op.emitError("Cannot tranpose, permuation contains negative index.");
}
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -3098,6 +3098,12 @@ def ONNXTransposeOp:ONNX_Op<"Transpose",
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
let extraClassDeclaration = [{
static StringRef getPermAttrName() { return "perm"; }
}];
let verifier = [{ return ::verify(*this); }];
} }
def ONNXUniqueOp:ONNX_Op<"Unique", def ONNXUniqueOp:ONNX_Op<"Unique",

View File

@ -9,4 +9,14 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
// CHECK-LABEL: test_default_transpose // CHECK-LABEL: test_default_transpose
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32> // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
// CHECK: return [[RES]] : tensor<32x1x5x5xf32> // CHECK: return [[RES]] : tensor<32x1x5x5xf32>
/// Test shape inference for transposition when perm attribute is specified.
func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
%0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_transpose
// CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32>