Merge pull request #11 from clang-ykt/naive-transpose

Add default shape inference for the transposition operation.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-01-09 14:07:44 -05:00 committed by GitHub
commit 7607edefe9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 34 additions and 3 deletions

View File

@ -267,7 +267,7 @@ def gen_schema(schema) :
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log'] 'Identity', 'Cos', 'Log', 'Transpose']
CanonicalList=['Add', 'Identity'] CanonicalList=['Add', 'Identity']
line_indent = ' ' line_indent = ' '

View File

@ -396,6 +396,24 @@ void ONNXReshapeOp::inferShapes() {
RankedTensorType::get(dims, inputTensorTy.getElementType())); RankedTensorType::get(dims, inputTensorTy.getElementType()));
} }
//===----------------------------------------------------------------------===//
// Transpose
void ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand()->getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked.");
// Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose).
// TODO: Once attributes are supported we can handle the case where the
// transposition uses a permutation vector to interchange the axes.
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -3089,7 +3089,7 @@ def ONNXTopKOp:ONNX_Op<"TopK",
} }
def ONNXTransposeOp:ONNX_Op<"Transpose", def ONNXTransposeOp:ONNX_Op<"Transpose",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Transpose operation"; let summary = "ONNX Transpose operation";
let description = [{ let description = [{
"Transpose the input tensor similar to numpy.transpose. For example, when" "Transpose the input tensor similar to numpy.transpose. For example, when"

View File

@ -115,7 +115,8 @@ public:
op->getName().getStringRef() != "onnx.MatMul" && op->getName().getStringRef() != "onnx.MatMul" &&
op->getName().getStringRef() != "onnx.Gemm" && op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.FullGemm" && op->getName().getStringRef() != "onnx.FullGemm" &&
op->getName().getStringRef() != "onnx.Reshape") op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose")
return false; return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) { return llvm::any_of(op->getResultTypes(), [](Type result_type) {
return !result_type.isa<RankedTensorType>(); return !result_type.isa<RankedTensorType>();

View File

@ -0,0 +1,12 @@
// RUN: onnf-opt --shape-inference %s -split-input-file | FileCheck %s
/// Test the default behavior of transpose when no information for the
/// permutation of the axes is provided.
func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
%0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_default_transpose
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>