Transpose using perm attribute.
This commit is contained in:
parent
8665ecd998
commit
9d1078540d
|
@ -407,11 +407,25 @@ void ONNXTransposeOp::inferShapes() {
|
|||
|
||||
// Naive transposition which handles the default case of
|
||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||
// TODO: Once attributes are supported we can handle the case where the
|
||||
// transposition uses a permutation vector to interchange the axes.
|
||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims;
|
||||
|
||||
if (auto permutation = getAttrOfType<ArrayAttr>(
|
||||
ONNXTransposeOp::getPermAttrName())) {
|
||||
// Perform transposition according to perm attribute.
|
||||
for (auto perm : permutation.getValue()) {
|
||||
int32_t index = perm.cast<IntegerAttr>().getInt();
|
||||
if (index < 0)
|
||||
emitError("Cannot tranpose when permutation contains negative index.");
|
||||
dims.emplace_back(arrayTy.getShape()[index]);
|
||||
}
|
||||
} else {
|
||||
// Default
|
||||
for (auto shape : llvm::reverse(arrayTy.getShape()))
|
||||
dims.emplace_back(shape);
|
||||
}
|
||||
|
||||
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -3098,6 +3098,10 @@ def ONNXTransposeOp:ONNX_Op<"Transpose",
|
|||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getPermAttrName() { return "perm"; }
|
||||
}];
|
||||
}
|
||||
|
||||
def ONNXUniqueOp:ONNX_Op<"Unique",
|
||||
|
|
Loading…
Reference in New Issue