Transpose using perm attribute.

This commit is contained in:
Doru Bercea 2020-01-13 18:08:19 -05:00
parent 8665ecd998
commit 9d1078540d
2 changed files with 23 additions and 5 deletions

View File

@ -407,11 +407,25 @@ void ONNXTransposeOp::inferShapes() {
// Naive transposition which handles the default case of // Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose). // reversing the shape of the tensor (similar to numpy.transpose).
// TODO: Once attributes are supported we can handle the case where the auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
// transposition uses a permutation vector to interchange the axes. SmallVector<int64_t, 2> dims;
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape())); if (auto permutation = getAttrOfType<ArrayAttr>(
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); ONNXTransposeOp::getPermAttrName())) {
// Perform transposition according to perm attribute.
for (auto perm : permutation.getValue()) {
int32_t index = perm.cast<IntegerAttr>().getInt();
if (index < 0)
emitError("Cannot tranpose when permutation contains negative index.");
dims.emplace_back(arrayTy.getShape()[index]);
}
} else {
// Default
for (auto shape : llvm::reverse(arrayTy.getShape()))
dims.emplace_back(shape);
}
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -3098,6 +3098,10 @@ def ONNXTransposeOp:ONNX_Op<"Transpose",
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
let extraClassDeclaration = [{
static StringRef getPermAttrName() { return "perm"; }
}];
} }
def ONNXUniqueOp:ONNX_Op<"Unique", def ONNXUniqueOp:ONNX_Op<"Unique",