fix const tensor align bug in AlignPermuteVectorForElementWise (#666)

* fix const tensor align bug in AlignPermuteVectorForElementWise

Signed-off-by: Chen <jack.chen@verisilicon.com>

* fix build issue when using the Android NDK

Type: Bug fix

Signed-off-by: Chen <jack.chen@verisilicon.com>

* Fix inappropriate comments for reduce layoutinfer

Type: Code refine

Signed-off-by: Chen <jack.chen@verisilicon.com>

---------

Signed-off-by: Chen <jack.chen@verisilicon.com>
Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
chxin66 2023-12-11 16:59:37 +08:00 committed by GitHub
parent 720f0a485a
commit 0dc7a3465e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 13 deletions

View File

@ -59,7 +59,11 @@ std::shared_ptr<IExecutable> CreateExecutableSet(
class IDevice { class IDevice {
public: public:
using device_id_t = uint32_t; using device_id_t = uint32_t;
#ifdef __ANDROID_NDK__
typedef bool (*async_callback)(const void*);
#else
using async_callback = std::function<bool(const void*)>; using async_callback = std::function<bool(const void*)>;
#endif
using data_t = const void*; using data_t = const void*;
virtual ~IDevice(){}; virtual ~IDevice(){};
virtual bool Submit(const std::shared_ptr<Graph>& graph) = 0; virtual bool Submit(const std::shared_ptr<Graph>& graph) = 0;

View File

@ -231,11 +231,13 @@ OpLayoutInfer::AlignPermuteVectorForElementWise() {
auto src_inputs = op_->impl()->InputsTensor(); auto src_inputs = op_->impl()->InputsTensor();
std::shared_ptr<IPermuteVector> required_pv = nullptr; std::shared_ptr<IPermuteVector> required_pv = nullptr;
std::shared_ptr<vx::Tensor> ref_input; std::shared_ptr<vx::Tensor> ref_input;
int32_t ref_rank = 0;
for (const auto& in : src_inputs) { for (const auto& in : src_inputs) {
if (!in->IsConstTensor()) { int32_t rank = in->GetShape().size();
if (!in->IsConstTensor() && rank > ref_rank) {
required_pv = context_->GetPermuteVector(in); required_pv = context_->GetPermuteVector(in);
ref_input = in; ref_input = in;
break; ref_rank = rank;
} }
} }
@ -297,14 +299,11 @@ void OpLayoutInfer::ReverseInputsPermuteVector() {
std::vector<uint32_t> OpLayoutInfer::GetExpandedShape( std::vector<uint32_t> OpLayoutInfer::GetExpandedShape(
const std::vector<uint32_t>& ref_shape, const std::vector<uint32_t>& ref_shape,
const std::vector<uint32_t>& origin_shape) { const std::vector<uint32_t>& origin_shape) {
std::vector<uint32_t> expanded_shape; std::vector<uint32_t> expanded_shape(origin_shape);
for (uint32_t i = 0, j = 0; i < ref_shape.size(); ++i) { int32_t ref_rank = ref_shape.size();
if (ref_shape[i] == origin_shape[j] && j < origin_shape.size()) { int32_t origin_rank = origin_shape.size();
expanded_shape.push_back(origin_shape[j]); for (int32_t i = 0; i < ref_rank; ++i) {
++j; if (i >= origin_rank) expanded_shape.push_back(1);
} else {
expanded_shape.push_back(1);
}
} }
return expanded_shape; return expanded_shape;
} }

View File

@ -45,21 +45,23 @@ class ReduceLayoutInfer : public OpLayoutInfer {
std::vector<std::shared_ptr<vx::Tensor>>& next_tensor) override { std::vector<std::shared_ptr<vx::Tensor>>& next_tensor) override {
auto t_src = op_->impl()->InputsTensor()[0]; auto t_src = op_->impl()->InputsTensor()[0];
auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]); auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]);
std::set<int32_t> axis_set; //Same value as new_axis, convenient for searching std::set<int32_t> axis_set; // Save unique axis values
std::vector<int32_t> new_axis, pv_reduced; std::vector<int32_t> new_axis, pv_reduced;
for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reduce.axis_num; uint32_t axis_num = op_->impl()->node()->nn_param.reduce.axis_num;
++i) { for (uint32_t i = 0; i < axis_num; ++i) {
int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i]; int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i];
if (axis < 0) { if (axis < 0) {
axis += pv->Rank(); axis += pv->Rank();
} }
axis = MapAxis(pv->AsStdVec(), axis); axis = MapAxis(pv->AsStdVec(), axis);
// Save unique axis values for calculating pv length
axis_set.insert(axis); axis_set.insert(axis);
new_axis.push_back(axis); new_axis.push_back(axis);
} }
auto reduce = context_->infer_graph_->CreateOperation<OpType>( auto reduce = context_->infer_graph_->CreateOperation<OpType>(
new_axis, op_->impl()->node()->nn_param.reduce.keep_dim); new_axis, op_->impl()->node()->nn_param.reduce.keep_dim);
(*reduce).BindInput(context_->GetMapedTensor(t_src)); (*reduce).BindInput(context_->GetMapedTensor(t_src));
if (op_->impl()->node()->nn_param.reduce.keep_dim) { if (op_->impl()->node()->nn_param.reduce.keep_dim) {
auto otensor_infer = CreateOutputsTensor(pv); auto otensor_infer = CreateOutputsTensor(pv);
(*reduce).BindOutput(otensor_infer[0]); (*reduce).BindOutput(otensor_infer[0]);