diff --git a/src/dialect/onnx/gen_doc.py b/src/dialect/onnx/gen_doc.py
index da8ec59..0c1f4bc 100644
--- a/src/dialect/onnx/gen_doc.py
+++ b/src/dialect/onnx/gen_doc.py
@@ -267,7 +267,7 @@ def gen_schema(schema) :
                    'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
                    'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
                    'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
-                   'Identity', 'Cos', 'Log']
+                   'Identity', 'Cos', 'Log', 'Transpose']
     CanonicalList=['Add', 'Identity']
     line_indent = '  '
 
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 902cbf6..6ff0fad 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -396,6 +396,24 @@ void ONNXReshapeOp::inferShapes() {
       RankedTensorType::get(dims, inputTensorTy.getElementType()));
 }
 
+//===----------------------------------------------------------------------===//
+
+// Transpose
+
+void ONNXTransposeOp::inferShapes() {
+  // Cannot infer shape if no shape exists.
+  if (!getOperand()->getType().isa<RankedTensorType>())
+    emitError("Shape tensor not ranked.");
+
+  // Naive transposition which handles the default case of
+  // reversing the shape of the tensor (similar to numpy.transpose).
+  // TODO: Once attributes are supported we can handle the case where the
+  // transposition uses a permutation vector to interchange the axes.
+  auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
+  SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
+  getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc
index e3e6a94..2129b60 100644
--- a/src/dialect/onnx/onnxop.inc
+++ b/src/dialect/onnx/onnxop.inc
@@ -3089,7 +3089,7 @@ def ONNXTopKOp:ONNX_Op<"TopK",
 }
 
 def ONNXTransposeOp:ONNX_Op<"Transpose",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Transpose operation";
   let description = [{
   "Transpose the input tensor similar to numpy.transpose. For example, when"
diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp
index f54feb4..cbdf04b 100644
--- a/src/pass/shape_inference_pass.cpp
+++ b/src/pass/shape_inference_pass.cpp
@@ -115,7 +115,8 @@ public:
         op->getName().getStringRef() != "onnx.MatMul" &&
         op->getName().getStringRef() != "onnx.Gemm" &&
         op->getName().getStringRef() != "onnx.FullGemm" &&
-        op->getName().getStringRef() != "onnx.Reshape")
+        op->getName().getStringRef() != "onnx.Reshape" &&
+        op->getName().getStringRef() != "onnx.Transpose")
       return false;
     return llvm::any_of(op->getResultTypes(), [](Type result_type) {
       return !result_type.isa<RankedTensorType>();
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
new file mode 100644
index 0000000..4cb9bec
--- /dev/null
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -0,0 +1,12 @@
+// RUN: onnf-opt --shape-inference %s -split-input-file | FileCheck %s
+
+/// Test the default behavior of transpose when no information for the
+/// permutation of the axes is provided.
+func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_default_transpose
+// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
+// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
\ No newline at end of file