Merge pull request #35 from clang-ykt/perm-transpose
Infer shape of transposition operations using the perm attribute
This commit is contained in:
commit
0aaab0d2d2
|
@ -13,6 +13,7 @@
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
#include "mlir/IR/IntegerSet.h"
|
#include "mlir/IR/IntegerSet.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/Module.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
@ -407,13 +408,38 @@ void ONNXTransposeOp::inferShapes() {
|
||||||
|
|
||||||
// Naive transposition which handles the default case of
|
// Naive transposition which handles the default case of
|
||||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||||
// TODO: Once attributes are supported we can handle the case where the
|
|
||||||
// transposition uses a permutation vector to interchange the axes.
|
|
||||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
SmallVector<int64_t, 2> dims;
|
||||||
|
|
||||||
|
if (auto permutation = getAttrOfType<ArrayAttr>(
|
||||||
|
ONNXTransposeOp::getPermAttrName())) {
|
||||||
|
// Perform transposition according to perm attribute.
|
||||||
|
for (auto perm : permutation.getValue())
|
||||||
|
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
|
||||||
|
} else {
|
||||||
|
// Default
|
||||||
|
for (auto dim : llvm::reverse(arrayTy.getShape()))
|
||||||
|
dims.emplace_back(dim);
|
||||||
|
}
|
||||||
|
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult verify(ONNXTransposeOp op) {
|
||||||
|
auto module = op.getParentOfType<ModuleOp>();
|
||||||
|
if (!module)
|
||||||
|
op.emitError("Expected to belong to a module.");
|
||||||
|
|
||||||
|
if (auto permutation = op.getAttrOfType<ArrayAttr>(
|
||||||
|
ONNXTransposeOp::getPermAttrName())) {
|
||||||
|
for (auto perm : permutation.getValue())
|
||||||
|
if (perm.cast<IntegerAttr>().getInt() < 0)
|
||||||
|
op.emitError("Cannot tranpose, permuation contains negative index.");
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -3098,6 +3098,12 @@ def ONNXTransposeOp:ONNX_Op<"Transpose",
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static StringRef getPermAttrName() { return "perm"; }
|
||||||
|
}];
|
||||||
|
|
||||||
|
let verifier = [{ return ::verify(*this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXUniqueOp:ONNX_Op<"Unique",
|
def ONNXUniqueOp:ONNX_Op<"Unique",
|
||||||
|
|
|
@ -9,4 +9,14 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||||
|
|
||||||
// CHECK-LABEL: test_default_transpose
|
// CHECK-LABEL: test_default_transpose
|
||||||
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
|
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
|
||||||
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
|
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
|
||||||
|
|
||||||
|
/// Test shape inference for transposition when perm attribute is specified.
|
||||||
|
func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_transpose
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32>
|
Loading…
Reference in New Issue