Transpose using perm attribute.
This commit is contained in:
parent
8665ecd998
commit
9d1078540d
|
@ -407,11 +407,25 @@ void ONNXTransposeOp::inferShapes() {
|
||||||
|
|
||||||
// Naive transposition which handles the default case of
|
// Naive transposition which handles the default case of
|
||||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||||
// TODO: Once attributes are supported we can handle the case where the
|
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
|
||||||
// transposition uses a permutation vector to interchange the axes.
|
SmallVector<int64_t, 2> dims;
|
||||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
|
||||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
if (auto permutation = getAttrOfType<ArrayAttr>(
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
ONNXTransposeOp::getPermAttrName())) {
|
||||||
|
// Perform transposition according to perm attribute.
|
||||||
|
for (auto perm : permutation.getValue()) {
|
||||||
|
int32_t index = perm.cast<IntegerAttr>().getInt();
|
||||||
|
if (index < 0)
|
||||||
|
emitError("Cannot tranpose when permutation contains negative index.");
|
||||||
|
dims.emplace_back(arrayTy.getShape()[index]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Default
|
||||||
|
for (auto shape : llvm::reverse(arrayTy.getShape()))
|
||||||
|
dims.emplace_back(shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -3098,6 +3098,10 @@ def ONNXTransposeOp:ONNX_Op<"Transpose",
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static StringRef getPermAttrName() { return "perm"; }
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXUniqueOp:ONNX_Op<"Unique",
|
def ONNXUniqueOp:ONNX_Op<"Unique",
|
||||||
|
|
Loading…
Reference in New Issue