added cumsum op & added handle api after BindInput
This commit is contained in:
parent
9cb37b920f
commit
264e491d2a
|
|
@ -49,6 +49,7 @@ class Operation {
|
|||
std::unique_ptr<OpImpl>& impl();
|
||||
const std::unique_ptr<OpImpl>& impl() const;
|
||||
virtual const std::vector<std::shared_ptr<Tensor>> ConstantInputsTensor() const;
|
||||
virtual void HandleAfterBindInput(const std::shared_ptr<Tensor>& tensor, int32_t input_idx);
|
||||
protected:
|
||||
bool IsAllInputsConst() const;
|
||||
std::unique_ptr<OpImpl> impl_;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2022 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#ifdef VSI_FEAT_OP_CUMSUM
|
||||
#ifndef TIM_VX_OPS_CUMSUM_H_
|
||||
#define TIM_VX_OPS_CUMSUM_H_
|
||||
|
||||
#include "tim/vx/builtin_op.h"
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
|
||||
/**
|
||||
* ## Cumsum
|
||||
*
|
||||
* Compute the cumulative sum of the tensor along the giveb axis. By default, it
|
||||
* will do the sum inclusively meaning the first element is copied as is. Through
|
||||
* an exclusive attribute, this behavior can change to exclude the first element.
|
||||
* It can also perform summation in the opposite direction of the axis by setting
|
||||
* reverse atrribution to 1.
|
||||
* All the attributes can be combined.
|
||||
* - axis : Specify the cumsum eperforming along which axis.Default = 0.
|
||||
* - exclusive : If exclusive = 1, perform exclusive cumsum.
|
||||
* - reverse : If reverse = 1, the cumsum is performed in the opposite direction.
|
||||
*/
|
||||
|
||||
class CumSum : public BuiltinOp {
|
||||
public:
|
||||
CumSum(Graph* Graph, int32_t axis=0, int32_t exclusive=0, int32_t reverse=0);
|
||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||
void HandleAfterBindInput(const std::shared_ptr<Tensor>& tensor, int32_t input_idx) override;
|
||||
|
||||
protected:
|
||||
int32_t axis_, exclusive_, reverse_;
|
||||
};
|
||||
|
||||
} // namespace ops
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
#endif /* TIM_VX_OPS_CUMSUM_H_ */
|
||||
#endif //(VSI_FEAT_OP_CUMSUM)
|
||||
|
|
@ -42,6 +42,7 @@ const std::unique_ptr<OpImpl>& Operation::impl() const { return impl_; }
|
|||
Operation& Operation::BindInput(const std::shared_ptr<Tensor>& tensor) {
|
||||
impl_->BindInput(tensor);
|
||||
impl_->graph_->UpdateTensorConsumersMap(tensor, this);
|
||||
HandleAfterBindInput(tensor, impl_->input_tensor_index - 1);
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
@ -89,6 +90,10 @@ const std::vector<std::shared_ptr<Tensor>> Operation::ConstantInputsTensor() con
|
|||
return {};
|
||||
}
|
||||
}
|
||||
void Operation::HandleAfterBindInput(const std::shared_ptr<Tensor>& tensor, int32_t input_idx){
|
||||
(void) tensor;
|
||||
(void) input_idx;
|
||||
}
|
||||
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
|
@ -119,6 +119,7 @@ Celu|CELU|Mapped|[Onnx.celu](https://github.com/onnx/onnx/blob/main/docs/Operato
|
|||
Rcp|RCP|Mapped|[tf.math.reciprocal](https://www.tensorflow.org/api_docs/python/tf/math/reciprocal)
|
||||
MaxPool3d|MAX_POOL3D|Mapped|[Onnx.MaxPool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool)
|
||||
|UnidirectionalSequenceRNN|UNIDIRECTIONAL_SEQUENCE_RNN|Planned 22Q3|[ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0ae11aa1d461d2abaa117f6ee2cb503dd8)
|
||||
CumSum|CUMSUM|Mapped|[tf.math.cumsum](https://www.tensorflow.org/api_docs/python/tf/math/cumsum)
|
||||
|BidirectionalSequenceRNN|BIDIRECTIONAL_SEQUENCE_RNN|Planned 22Q3|[ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a487fc5ae247de828f13e62b99f259f3c)
|
||||
|BidirectionalSequenceLSTM|BIDIRECTIONAL_SEQUENCE_LSTM|Mapped|[ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a492a71cb7aa50b9a1a834a3cb269d778)
|
||||
|UnidirectionalSequenceLSTM|LSTM_OVXLIB|Mapped|[ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_LSTM](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0aaf30e491ad0b1fc7602cbde695b2c859)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,56 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2022 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#ifdef VSI_FEAT_OP_CUMSUM
|
||||
#include "tim/vx/ops/cumsum.h"
|
||||
|
||||
#include "builtin_op_impl.h"
|
||||
#include "vsi_nn_pub.h"
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
|
||||
CumSum::CumSum(Graph* graph, int32_t axis, int32_t exclusive, int32_t reverse)
|
||||
: BuiltinOp(graph, VSI_NN_OP_CUMSUM), axis_(axis), exclusive_(exclusive), reverse_(reverse){
|
||||
this->impl()->node()->nn_param.cumsum.axis = axis_;
|
||||
this->impl()->node()->nn_param.cumsum.exclusive = exclusive_;
|
||||
this->impl()->node()->nn_param.cumsum.reverse = reverse_;
|
||||
}
|
||||
|
||||
void CumSum::HandleAfterBindInput(const std::shared_ptr<Tensor>& tensor, int32_t input_idx){
|
||||
if (axis_ < 0){
|
||||
axis_ += tensor->GetShape().size();
|
||||
(void) input_idx;
|
||||
this->impl()->node()->nn_param.cumsum.axis = axis_;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Operation> CumSum::Clone(std::shared_ptr<Graph>& graph) const {
|
||||
return graph->CreateOperation<CumSum>(this->axis_, this->exclusive_, this->reverse_);
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
||||
#endif //(VSI_FEAT_OP_CUMSUM)
|
||||
|
|
@ -0,0 +1,206 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2022 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#ifdef VSI_FEAT_OP_CUMSUM
|
||||
#include "tim/vx/context.h"
|
||||
#include "tim/vx/graph.h"
|
||||
#include "tim/vx/ops/cumsum.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
TEST(CumSum, shape_4_2_fp32_axis_0_exclusive_0_reverse_0) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
tim::vx::ShapeType io_shape({4, 2});
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
|
||||
io_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
|
||||
io_shape, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<float> in_data = {
|
||||
2, 4, 6, 8,
|
||||
1, 3, 5, 7,
|
||||
|
||||
};
|
||||
std::vector<float> golden = {
|
||||
2, 6, 12, 20,
|
||||
1, 4, 9, 16,
|
||||
};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float)));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::CumSum>(0,0,0);
|
||||
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
|
||||
std::vector<float> output(golden.size());
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(CumSum, shape_4_2_fp32_axis_1_exclusive_0_reverse_0) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
tim::vx::ShapeType io_shape({4, 2});
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
|
||||
io_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
|
||||
io_shape, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<float> in_data = {
|
||||
2, 4, 6, 8,
|
||||
1, 3, 5, 7,
|
||||
|
||||
};
|
||||
std::vector<float> golden = {
|
||||
2, 4, 6, 8,
|
||||
3, 7, 11,15,
|
||||
};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float)));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::CumSum>(1,0,0);
|
||||
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
|
||||
std::vector<float> output(golden.size());
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(CumSum, shape_4_1_fp32_axis_0_exclusive_1_reverse_0) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
tim::vx::ShapeType io_shape({4, 1});
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
|
||||
io_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
|
||||
io_shape, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<float> in_data = {
|
||||
2, 4, 6, 8,
|
||||
|
||||
};
|
||||
std::vector<float> golden = {
|
||||
0, 2, 6, 12,
|
||||
};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float)));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::CumSum>(0,1,0);
|
||||
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
|
||||
std::vector<float> output(golden.size());
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(CumSum, shape_4_1_fp32_axis_0_exclusive_1_reverse_1) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
tim::vx::ShapeType io_shape({4, 1});
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
|
||||
io_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
|
||||
io_shape, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<float> in_data = {
|
||||
2, 4, 6, 8,
|
||||
|
||||
};
|
||||
std::vector<float> golden = {
|
||||
18, 14, 8, 0,
|
||||
};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float)));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::CumSum>(0,1,1);
|
||||
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
|
||||
std::vector<float> output(golden.size());
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(CumSum, shape_4_2_fp32_axis_minus1_exclusive_1_reverse_1) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
tim::vx::ShapeType io_shape({4, 2});
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
|
||||
io_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
|
||||
io_shape, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<float> in_data = {
|
||||
2, 4, 6, 8,
|
||||
1, 3, 5, 7,
|
||||
|
||||
};
|
||||
std::vector<float> golden = {
|
||||
1, 3, 5, 7,
|
||||
0, 0, 0, 0,
|
||||
};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * sizeof(float)));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::CumSum>(-1,1,1);
|
||||
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
|
||||
std::vector<float> output(golden.size());
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
#endif //(VSI_FEAT_OP_CUMSUM)
|
||||
Loading…
Reference in New Issue