fix const tensor align bug in AlignPermuteVectorForElementWise (#666)

* fix const tensor align bug in AlignPermuteVectorForElementWise

Signed-off-by: Chen <jack.chen@verisilicon.com>

* fix build issue when building with the Android NDK

Type: Bug fix

Signed-off-by: Chen <jack.chen@verisilicon.com>

* Fix inappropriate comments in reduce layout inference

Type: Code refine

Signed-off-by: Chen <jack.chen@verisilicon.com>

---------

Signed-off-by: Chen <jack.chen@verisilicon.com>
Co-authored-by: Chen <jack.chen@verisilicon.com>
chxin66 2023-12-11 16:59:37 +08:00 committed by GitHub
parent 720f0a485a
commit 0dc7a3465e
3 changed files with 18 additions and 13 deletions

@@ -59,7 +59,11 @@ std::shared_ptr<IExecutable> CreateExecutableSet(
 class IDevice {
  public:
   using device_id_t = uint32_t;
+#ifdef __ANDROID_NDK__
+  typedef bool (*async_callback)(const void*);
+#else
   using async_callback = std::function<bool(const void*)>;
+#endif
   using data_t = const void*;
   virtual ~IDevice(){};
   virtual bool Submit(const std::shared_ptr<Graph>& graph) = 0;
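Note on the hunk above: under __ANDROID_NDK__ the async_callback alias degrades from std::function to a plain function pointer to sidestep the NDK build issue this commit mentions. The sketch below is illustrative only (FakeDevice, TriggerFn, and TriggerPtr are made-up names, not TIM-VX API); it shows that a capture-less lambda satisfies either definition, so call sites need no #ifdef of their own.

// Illustrative only: FakeDevice is a made-up stand-in, not the TIM-VX IDevice.
#include <functional>

struct FakeDevice {
  using callback_fn = std::function<bool(const void*)>;  // default alias
  using callback_ptr = bool (*)(const void*);            // NDK-style fallback
  bool TriggerFn(callback_fn cb, const void* data) { return cb(data); }
  bool TriggerPtr(callback_ptr cb, const void* data) { return cb(data); }
};

int main() {
  FakeDevice dev;
  // A capture-less lambda converts to both forms, so callers can stay
  // oblivious to which definition of async_callback was compiled in.
  auto cb = [](const void* data) -> bool { return data != nullptr; };
  int payload = 42;
  bool ok = dev.TriggerFn(cb, &payload) && dev.TriggerPtr(cb, &payload);
  return ok ? 0 : 1;
}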

@@ -231,11 +231,13 @@ OpLayoutInfer::AlignPermuteVectorForElementWise() {
   auto src_inputs = op_->impl()->InputsTensor();
   std::shared_ptr<IPermuteVector> required_pv = nullptr;
   std::shared_ptr<vx::Tensor> ref_input;
+  int32_t ref_rank = 0;
   for (const auto& in : src_inputs) {
-    if (!in->IsConstTensor()) {
+    int32_t rank = in->GetShape().size();
+    if (!in->IsConstTensor() && rank > ref_rank) {
       required_pv = context_->GetPermuteVector(in);
       ref_input = in;
-      break;
+      ref_rank = rank;
     }
   }
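For context on the fix above: the old loop stopped at the first non-const input, so a lower-rank tensor could become the layout reference and the const/broadcast inputs would be aligned against the wrong shape. The new loop keeps scanning and prefers the non-const input with the largest rank. A minimal standalone sketch of that selection rule (FakeInput and PickReferenceInput are illustrative names, not TIM-VX types):

// Illustrative only, not TIM-VX API.
#include <cstdint>
#include <vector>

struct FakeInput {
  bool is_const;
  std::vector<uint32_t> shape;
};

// Mirror of the new selection rule: prefer the non-const input with the
// largest rank as the layout reference instead of the first non-const one.
int32_t PickReferenceInput(const std::vector<FakeInput>& inputs) {
  int32_t ref_idx = -1;
  int32_t ref_rank = 0;
  for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
    int32_t rank = static_cast<int32_t>(inputs[i].shape.size());
    if (!inputs[i].is_const && rank > ref_rank) {
      ref_idx = i;
      ref_rank = rank;
    }
  }
  return ref_idx;
}

int main() {
  // Rank-1 non-const, rank-3 const, rank-3 non-const: index 2 wins now,
  // whereas the old "first non-const" rule would have picked index 0.
  std::vector<FakeInput> inputs{{false, {3}}, {true, {2, 3, 4}}, {false, {2, 3, 4}}};
  return PickReferenceInput(inputs) == 2 ? 0 : 1;
}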
@@ -297,14 +299,11 @@ void OpLayoutInfer::ReverseInputsPermuteVector() {
 std::vector<uint32_t> OpLayoutInfer::GetExpandedShape(
     const std::vector<uint32_t>& ref_shape,
     const std::vector<uint32_t>& origin_shape) {
-  std::vector<uint32_t> expanded_shape;
-  for (uint32_t i = 0, j = 0; i < ref_shape.size(); ++i) {
-    if (ref_shape[i] == origin_shape[j] && j < origin_shape.size()) {
-      expanded_shape.push_back(origin_shape[j]);
-      ++j;
-    } else {
-      expanded_shape.push_back(1);
-    }
+  std::vector<uint32_t> expanded_shape(origin_shape);
+  int32_t ref_rank = ref_shape.size();
+  int32_t origin_rank = origin_shape.size();
+  for (int32_t i = 0; i < ref_rank; ++i) {
+    if (i >= origin_rank) expanded_shape.push_back(1);
   }
   return expanded_shape;
 }
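GetExpandedShape now simply copies origin_shape and appends trailing 1s until the rank matches ref_shape, instead of walking both shapes in lockstep. An illustrative restatement with a worked example (ExpandShape is a local name, not the TIM-VX function): {3, 4} expanded against a rank-4 reference becomes {3, 4, 1, 1}.

// Illustrative restatement of the patched helper; ExpandShape is a local name.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<uint32_t> ExpandShape(const std::vector<uint32_t>& ref_shape,
                                  const std::vector<uint32_t>& origin_shape) {
  std::vector<uint32_t> expanded(origin_shape);
  // Pad trailing dimensions with 1 until the rank matches the reference.
  for (size_t i = origin_shape.size(); i < ref_shape.size(); ++i) {
    expanded.push_back(1);
  }
  return expanded;
}

int main() {
  std::vector<uint32_t> ref{2, 3, 4, 5};
  std::vector<uint32_t> origin{3, 4};
  assert((ExpandShape(ref, origin) == std::vector<uint32_t>{3, 4, 1, 1}));
  return 0;
}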

@@ -45,21 +45,23 @@ class ReduceLayoutInfer : public OpLayoutInfer {
       std::vector<std::shared_ptr<vx::Tensor>>& next_tensor) override {
     auto t_src = op_->impl()->InputsTensor()[0];
     auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]);
-    std::set<int32_t> axis_set; //Same value as new_axis, convenient for searching
+    std::set<int32_t> axis_set;  // Save unique axis values
     std::vector<int32_t> new_axis, pv_reduced;
-    for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reduce.axis_num;
-         ++i) {
+    uint32_t axis_num = op_->impl()->node()->nn_param.reduce.axis_num;
+    for (uint32_t i = 0; i < axis_num; ++i) {
       int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i];
       if (axis < 0) {
         axis += pv->Rank();
       }
       axis = MapAxis(pv->AsStdVec(), axis);
+      // Save unique axis values for calculating pv length
       axis_set.insert(axis);
       new_axis.push_back(axis);
     }
     auto reduce = context_->infer_graph_->CreateOperation<OpType>(
         new_axis, op_->impl()->node()->nn_param.reduce.keep_dim);
     (*reduce).BindInput(context_->GetMapedTensor(t_src));
     if (op_->impl()->node()->nn_param.reduce.keep_dim) {
       auto otensor_infer = CreateOutputsTensor(pv);
       (*reduce).BindOutput(otensor_infer[0]);
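The reduce hunk is mostly the comment refine named in the commit message: axis_num is hoisted out of the loop header and the comments now state what axis_set actually holds (unique mapped axes). For illustration, a standalone sketch of the negative-axis wrap and unique-axis collection (NormalizeAxes is a made-up helper; the real loop additionally remaps each axis through the permute vector via MapAxis, which is omitted here):

// Illustrative only: NormalizeAxes is a made-up helper, not TIM-VX API.
#include <cstdint>
#include <set>
#include <vector>

void NormalizeAxes(const std::vector<int32_t>& raw_axes, int32_t rank,
                   std::vector<int32_t>* new_axis, std::set<int32_t>* axis_set) {
  for (int32_t axis : raw_axes) {
    if (axis < 0) {
      axis += rank;  // wrap negative axes, e.g. -1 with rank 4 -> 3
    }
    axis_set->insert(axis);     // unique axis values (what the new comment documents)
    new_axis->push_back(axis);  // order-preserving list handed to the reduce op
  }
}

int main() {
  std::vector<int32_t> new_axis;
  std::set<int32_t> axis_set;
  NormalizeAxes({-1, 3}, 4, &new_axis, &axis_set);  // both entries name axis 3
  return (axis_set.size() == 1 && new_axis.size() == 2) ? 0 : 1;
}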