fix const tensor align bug in AlignPermuteVectorForElementWise (#666)
* fix const tensor align bug in AlignPermuteVectorForElementWise Signed-off-by: Chen <jack.chen@verisilicon.com> * fix build issue use android ndk Type: Bug fix Signed-off-by: Chen <jack.chen@verisilicon.com> * Fix inappropriate comments for reduce layoutinfer Type: Code refine Signed-off-by: Chen <jack.chen@verisilicon.com> --------- Signed-off-by: Chen <jack.chen@verisilicon.com> Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
parent
720f0a485a
commit
0dc7a3465e
|
|
@ -59,7 +59,11 @@ std::shared_ptr<IExecutable> CreateExecutableSet(
|
|||
class IDevice {
|
||||
public:
|
||||
using device_id_t = uint32_t;
|
||||
#ifdef __ANDROID_NDK__
|
||||
typedef bool (*async_callback)(const void*);
|
||||
#else
|
||||
using async_callback = std::function<bool(const void*)>;
|
||||
#endif
|
||||
using data_t = const void*;
|
||||
virtual ~IDevice(){};
|
||||
virtual bool Submit(const std::shared_ptr<Graph>& graph) = 0;
|
||||
|
|
|
|||
|
|
@ -231,11 +231,13 @@ OpLayoutInfer::AlignPermuteVectorForElementWise() {
|
|||
auto src_inputs = op_->impl()->InputsTensor();
|
||||
std::shared_ptr<IPermuteVector> required_pv = nullptr;
|
||||
std::shared_ptr<vx::Tensor> ref_input;
|
||||
int32_t ref_rank = 0;
|
||||
for (const auto& in : src_inputs) {
|
||||
if (!in->IsConstTensor()) {
|
||||
int32_t rank = in->GetShape().size();
|
||||
if (!in->IsConstTensor() && rank > ref_rank) {
|
||||
required_pv = context_->GetPermuteVector(in);
|
||||
ref_input = in;
|
||||
break;
|
||||
ref_rank = rank;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -297,14 +299,11 @@ void OpLayoutInfer::ReverseInputsPermuteVector() {
|
|||
std::vector<uint32_t> OpLayoutInfer::GetExpandedShape(
|
||||
const std::vector<uint32_t>& ref_shape,
|
||||
const std::vector<uint32_t>& origin_shape) {
|
||||
std::vector<uint32_t> expanded_shape;
|
||||
for (uint32_t i = 0, j = 0; i < ref_shape.size(); ++i) {
|
||||
if (ref_shape[i] == origin_shape[j] && j < origin_shape.size()) {
|
||||
expanded_shape.push_back(origin_shape[j]);
|
||||
++j;
|
||||
} else {
|
||||
expanded_shape.push_back(1);
|
||||
}
|
||||
std::vector<uint32_t> expanded_shape(origin_shape);
|
||||
int32_t ref_rank = ref_shape.size();
|
||||
int32_t origin_rank = origin_shape.size();
|
||||
for (int32_t i = 0; i < ref_rank; ++i) {
|
||||
if (i >= origin_rank) expanded_shape.push_back(1);
|
||||
}
|
||||
return expanded_shape;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,21 +45,23 @@ class ReduceLayoutInfer : public OpLayoutInfer {
|
|||
std::vector<std::shared_ptr<vx::Tensor>>& next_tensor) override {
|
||||
auto t_src = op_->impl()->InputsTensor()[0];
|
||||
auto pv = context_->GetPermuteVector(op_->impl()->InputsTensor()[0]);
|
||||
std::set<int32_t> axis_set; //Same value as new_axis, convenient for searching
|
||||
std::set<int32_t> axis_set; // Save unique axis values
|
||||
std::vector<int32_t> new_axis, pv_reduced;
|
||||
for (uint32_t i = 0; i < op_->impl()->node()->nn_param.reduce.axis_num;
|
||||
++i) {
|
||||
uint32_t axis_num = op_->impl()->node()->nn_param.reduce.axis_num;
|
||||
for (uint32_t i = 0; i < axis_num; ++i) {
|
||||
int32_t axis = op_->impl()->node()->nn_param.reduce.axis[i];
|
||||
if (axis < 0) {
|
||||
axis += pv->Rank();
|
||||
}
|
||||
axis = MapAxis(pv->AsStdVec(), axis);
|
||||
// Save unique axis values for calculating pv length
|
||||
axis_set.insert(axis);
|
||||
new_axis.push_back(axis);
|
||||
}
|
||||
auto reduce = context_->infer_graph_->CreateOperation<OpType>(
|
||||
new_axis, op_->impl()->node()->nn_param.reduce.keep_dim);
|
||||
(*reduce).BindInput(context_->GetMapedTensor(t_src));
|
||||
|
||||
if (op_->impl()->node()->nn_param.reduce.keep_dim) {
|
||||
auto otensor_infer = CreateOutputsTensor(pv);
|
||||
(*reduce).BindOutput(otensor_infer[0]);
|
||||
|
|
|
|||
Loading…
Reference in New Issue