diff --git a/src/tim/vx/tensor.cc b/src/tim/vx/tensor.cc index d4c9ee5..06ac7e1 100644 --- a/src/tim/vx/tensor.cc +++ b/src/tim/vx/tensor.cc @@ -154,6 +154,12 @@ bool TensorImpl::Init() { attr.is_const = static_cast(spec_.attr_ & TensorAttribute::CONSTANT); attr.vtl = static_cast(spec_.attr_ & TensorAttribute::TRANSIENT); + // Use auto shape for virtual tensors so that tim-vx can perform it's own + // shape inference + if (attr.vtl) { + attr.dim_num = VSI_NN_DIM_AUTO; + } + for (ShapeType::size_type i = 0; i < spec_.shape_.size(); i++) { attr.size[i] = spec_.shape_[i]; } @@ -195,4 +201,4 @@ bool TensorImpl::IsReadable() { } } // namespace vx -} // namespace tim \ No newline at end of file +} // namespace tim