Map OneHot & unit test (#258)
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com> Co-authored-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
8e4ab68213
commit
7c63ba621e
|
|
@ -54,6 +54,7 @@
|
||||||
#include "tim/vx/ops/maxunpool2d.h"
|
#include "tim/vx/ops/maxunpool2d.h"
|
||||||
#include "tim/vx/ops/moments.h"
|
#include "tim/vx/ops/moments.h"
|
||||||
#include "tim/vx/ops/nbg.h"
|
#include "tim/vx/ops/nbg.h"
|
||||||
|
#include "tim/vx/ops/onehot.h"
|
||||||
#include "tim/vx/ops/pad.h"
|
#include "tim/vx/ops/pad.h"
|
||||||
#include "tim/vx/ops/pool2d.h"
|
#include "tim/vx/ops/pool2d.h"
|
||||||
#include "tim/vx/ops/reduce.h"
|
#include "tim/vx/ops/reduce.h"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2020 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#ifndef TIM_VX_OPERATION_ONE_HOT_H_
|
||||||
|
#define TIM_VX_OPERATION_ONE_HOT_H_
|
||||||
|
#include "tim/vx/direct_map_op.h"
|
||||||
|
|
||||||
|
namespace tim {
|
||||||
|
namespace vx {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ## OneHot
|
||||||
|
*
|
||||||
|
* Create a one-hot tensor.
|
||||||
|
*
|
||||||
|
* - depth : A scalar defining the depth of the one hot dimension.
|
||||||
|
* - on_value : A scalar defining the value to fill in output.
|
||||||
|
* - off_value : A scalar defining the value to fill in output.
|
||||||
|
* - axis : The axis to fill.
|
||||||
|
*/
|
||||||
|
|
||||||
|
class OneHot : public DirectMapOp {
|
||||||
|
public:
|
||||||
|
OneHot(Graph* graph, int32_t depth, float on_value = 1, float off_value = 0,
|
||||||
|
int32_t axis = 0);
|
||||||
|
|
||||||
|
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int32_t depth_;
|
||||||
|
float on_value_;
|
||||||
|
float off_value_;
|
||||||
|
int32_t axis_;
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace vx
|
||||||
|
} // namespace tim
|
||||||
|
#endif
|
||||||
|
|
@ -131,7 +131,7 @@ GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow.
|
||||||
||CEIL|Planned 21Q4|[tf.math.ceil](https://tensorflow.google.cn/api_docs/python/tf/math/ceil)
|
||CEIL|Planned 21Q4|[tf.math.ceil](https://tensorflow.google.cn/api_docs/python/tf/math/ceil)
|
||||||
||SEQUENCE_MASK|Planned 21Q4|[tf.math.ceil](https://tensorflow.google.cn/api_docs/python/tf/sequence_mask)
|
||SEQUENCE_MASK|Planned 21Q4|[tf.math.ceil](https://tensorflow.google.cn/api_docs/python/tf/sequence_mask)
|
||||||
||REPEAT|Planned 21Q4|[tf.repeat](https://tensorflow.google.cn/api_docs/python/tf/repeat)
|
||REPEAT|Planned 21Q4|[tf.repeat](https://tensorflow.google.cn/api_docs/python/tf/repeat)
|
||||||
||ONE_HOT|Planned 21Q4|[tf.one_hot](https://tensorflow.google.cn/api_docs/python/tf/one_hot)
|
OneHot|ONE_HOT|Mapped|[tf.one_hot](https://tensorflow.google.cn/api_docs/python/tf/one_hot)
|
||||||
||NMS|Planned 21Q4|[tf.image.non_max_suppression](https://tensorflow.google.cn/api_docs/python/tf/image/non_max_suppression)
|
||NMS|Planned 21Q4|[tf.image.non_max_suppression](https://tensorflow.google.cn/api_docs/python/tf/image/non_max_suppression)
|
||||||
||SCATTER_ND_UPDATE|Planned 21Q4|[tf.compat.v1.scatter_nd_update](https://tensorflow.google.cn/api_docs/python/tf/compat/v1/scatter_nd_update)
|
||SCATTER_ND_UPDATE|Planned 21Q4|[tf.compat.v1.scatter_nd_update](https://tensorflow.google.cn/api_docs/python/tf/compat/v1/scatter_nd_update)
|
||||||
||GELU|Planned 21Q4|[tf.nn.gelu](https://tensorflow.google.cn/api_docs/python/tf/nn/gelu)
|
||GELU|Planned 21Q4|[tf.nn.gelu](https://tensorflow.google.cn/api_docs/python/tf/nn/gelu)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2020 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#include "tim/vx/ops/onehot.h"
|
||||||
|
|
||||||
|
#include "direct_map_op_impl.h"
|
||||||
|
#include "vsi_nn_pub.h"
|
||||||
|
|
||||||
|
namespace tim {
|
||||||
|
namespace vx {
|
||||||
|
namespace ops {
|
||||||
|
OneHot::OneHot(Graph* graph, int32_t depth, float on_value, float off_value,
|
||||||
|
int32_t axis)
|
||||||
|
: DirectMapOp(graph, VSI_NN_OP_ONE_HOT),
|
||||||
|
depth_(depth),
|
||||||
|
on_value_(on_value),
|
||||||
|
off_value_(off_value),
|
||||||
|
axis_(axis) {
|
||||||
|
this->impl()->node()->nn_param.one_hot.depth = depth_;
|
||||||
|
this->impl()->node()->nn_param.one_hot.on_value = on_value_;
|
||||||
|
this->impl()->node()->nn_param.one_hot.off_value = off_value_;
|
||||||
|
this->impl()->node()->nn_param.one_hot.axis = axis_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<Operation> OneHot::Clone(std::shared_ptr<Graph>& graph) const {
|
||||||
|
return graph->CreateOperation<OneHot>(this->depth_, this->on_value_,
|
||||||
|
this->off_value_, this->axis_);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace vx
|
||||||
|
} // namespace tim
|
||||||
|
|
@ -0,0 +1,326 @@
|
||||||
|
/****************************************************************************
|
||||||
|
*
|
||||||
|
* Copyright (c) 2021 Vivante Corporation
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
* copy of this software and associated documentation files (the "Software"),
|
||||||
|
* to deal in the Software without restriction, including without limitation
|
||||||
|
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
* and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
* Software is furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
* DEALINGS IN THE SOFTWARE.
|
||||||
|
*
|
||||||
|
*****************************************************************************/
|
||||||
|
#include "tim/vx/context.h"
|
||||||
|
#include "tim/vx/graph.h"
|
||||||
|
#include "tim/vx/ops/onehot.h"
|
||||||
|
#include "tim/vx/types.h"
|
||||||
|
#include "test_utils.h"
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
TEST(OneHot, shape_3_out_flaot_depth_3) {
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
int32_t depth = 3;
|
||||||
|
|
||||||
|
tim::vx::ShapeType input_shape({3});//AKA: indices
|
||||||
|
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||||
|
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
|
||||||
|
{3, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto input_tensor = graph->CreateTensor(input_spec);
|
||||||
|
auto output_tensor = graph->CreateTensor(output_spec);
|
||||||
|
|
||||||
|
std::vector<int32_t> input_data = {0, 1, 2};
|
||||||
|
|
||||||
|
std::vector<float> golden = {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f};
|
||||||
|
|
||||||
|
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||||
|
|
||||||
|
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth);
|
||||||
|
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||||
|
|
||||||
|
EXPECT_TRUE(graph->Compile());
|
||||||
|
EXPECT_TRUE(graph->Run());
|
||||||
|
std::vector<float> output(9);
|
||||||
|
|
||||||
|
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||||
|
EXPECT_EQ(golden, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHot, shape_3_out_int32_depth_3) {
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
int32_t depth = 3;
|
||||||
|
|
||||||
|
tim::vx::ShapeType input_shape({3});//AKA: indices
|
||||||
|
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||||
|
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT32,
|
||||||
|
{3, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto input_tensor = graph->CreateTensor(input_spec);
|
||||||
|
auto output_tensor = graph->CreateTensor(output_spec);
|
||||||
|
|
||||||
|
std::vector<int32_t> input_data = {0, 1, 2};
|
||||||
|
|
||||||
|
std::vector<int32_t> golden = {1, 0, 0, 0, 1, 0, 0, 0, 1};
|
||||||
|
|
||||||
|
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||||
|
|
||||||
|
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth);
|
||||||
|
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||||
|
|
||||||
|
EXPECT_TRUE(graph->Compile());
|
||||||
|
EXPECT_TRUE(graph->Run());
|
||||||
|
std::vector<int32_t> output(9);
|
||||||
|
|
||||||
|
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||||
|
EXPECT_EQ(golden, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHot, shape_3_out_int8_depth_3) {
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
int32_t depth = 3;
|
||||||
|
|
||||||
|
tim::vx::ShapeType input_shape({3});//AKA: indices
|
||||||
|
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||||
|
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT8,
|
||||||
|
{3, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto input_tensor = graph->CreateTensor(input_spec);
|
||||||
|
auto output_tensor = graph->CreateTensor(output_spec);
|
||||||
|
|
||||||
|
std::vector<int32_t> input_data = {0, 1, 2};
|
||||||
|
|
||||||
|
std::vector<int8_t> golden = {1, 0, 0, 0, 1, 0, 0, 0, 1};
|
||||||
|
|
||||||
|
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size()));
|
||||||
|
|
||||||
|
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth);
|
||||||
|
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||||
|
|
||||||
|
EXPECT_TRUE(graph->Compile());
|
||||||
|
EXPECT_TRUE(graph->Run());
|
||||||
|
std::vector<int8_t> output(9);
|
||||||
|
|
||||||
|
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||||
|
EXPECT_EQ(golden, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHot, shape_3_out_uint8_depth_3) {
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
int32_t depth = 3;
|
||||||
|
|
||||||
|
tim::vx::ShapeType input_shape({3});//AKA: indices
|
||||||
|
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||||
|
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec output_spec(tim::vx::DataType::UINT8,
|
||||||
|
{3, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto input_tensor = graph->CreateTensor(input_spec);
|
||||||
|
auto output_tensor = graph->CreateTensor(output_spec);
|
||||||
|
|
||||||
|
std::vector<int32_t> input_data = {0, 1, 2};
|
||||||
|
|
||||||
|
std::vector<uint8_t> golden = {1, 0, 0, 0, 1, 0, 0, 0, 1};
|
||||||
|
|
||||||
|
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size()));
|
||||||
|
|
||||||
|
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth);
|
||||||
|
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||||
|
|
||||||
|
EXPECT_TRUE(graph->Compile());
|
||||||
|
EXPECT_TRUE(graph->Run());
|
||||||
|
std::vector<uint8_t> output(9);
|
||||||
|
|
||||||
|
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||||
|
EXPECT_EQ(golden, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHot, shape_3_out_int32_depth_1) {
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
int32_t depth = 1;
|
||||||
|
|
||||||
|
tim::vx::ShapeType input_shape({3});//AKA: indices
|
||||||
|
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||||
|
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT32,
|
||||||
|
{3, 1}, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto input_tensor = graph->CreateTensor(input_spec);
|
||||||
|
auto output_tensor = graph->CreateTensor(output_spec);
|
||||||
|
|
||||||
|
std::vector<int32_t> input_data = {0, 1, 2};
|
||||||
|
|
||||||
|
std::vector<int32_t> golden = {1, 0, 0};
|
||||||
|
|
||||||
|
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||||
|
|
||||||
|
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth);
|
||||||
|
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||||
|
|
||||||
|
EXPECT_TRUE(graph->Compile());
|
||||||
|
EXPECT_TRUE(graph->Run());
|
||||||
|
std::vector<int32_t> output(3);
|
||||||
|
|
||||||
|
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||||
|
EXPECT_EQ(golden, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHot, shape_3_out_int32_depth_4) {
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
int32_t depth = 4;
|
||||||
|
|
||||||
|
tim::vx::ShapeType input_shape({3});//AKA: indices
|
||||||
|
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||||
|
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT32,
|
||||||
|
{3, 4}, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto input_tensor = graph->CreateTensor(input_spec);
|
||||||
|
auto output_tensor = graph->CreateTensor(output_spec);
|
||||||
|
|
||||||
|
std::vector<int32_t> input_data = {0, 1, 2};
|
||||||
|
|
||||||
|
std::vector<int32_t> golden = {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0};
|
||||||
|
|
||||||
|
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||||
|
|
||||||
|
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth);
|
||||||
|
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||||
|
|
||||||
|
EXPECT_TRUE(graph->Compile());
|
||||||
|
EXPECT_TRUE(graph->Run());
|
||||||
|
std::vector<int32_t> output(12);
|
||||||
|
|
||||||
|
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||||
|
EXPECT_EQ(golden, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHot, shape_3_out_int32_depth_3_on_6_off_N1) {
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
int32_t depth = 3;
|
||||||
|
float on = 6;
|
||||||
|
float off = -1;
|
||||||
|
|
||||||
|
tim::vx::ShapeType input_shape({4});//AKA: indices
|
||||||
|
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||||
|
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT32,
|
||||||
|
{4, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto input_tensor = graph->CreateTensor(input_spec);
|
||||||
|
auto output_tensor = graph->CreateTensor(output_spec);
|
||||||
|
|
||||||
|
std::vector<int32_t> input_data = {0, 2, -1, 1};
|
||||||
|
|
||||||
|
std::vector<int32_t> golden = {6, -1, -1, -1, -1, 6, -1, -1, -1, -1, 6, -1};
|
||||||
|
|
||||||
|
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||||
|
|
||||||
|
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth, on, off);
|
||||||
|
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||||
|
|
||||||
|
EXPECT_TRUE(graph->Compile());
|
||||||
|
EXPECT_TRUE(graph->Run());
|
||||||
|
std::vector<int32_t> output(12);
|
||||||
|
|
||||||
|
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||||
|
EXPECT_EQ(golden, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHot, shape_3_out_int32_depth_3_on_5_off_0_axis_1) {
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
int32_t depth = 3;
|
||||||
|
float on = 5;
|
||||||
|
float off = 0;
|
||||||
|
int32_t axis = 1;
|
||||||
|
|
||||||
|
tim::vx::ShapeType input_shape({4});//AKA: indices
|
||||||
|
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||||
|
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT32,
|
||||||
|
{4, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto input_tensor = graph->CreateTensor(input_spec);
|
||||||
|
auto output_tensor = graph->CreateTensor(output_spec);
|
||||||
|
|
||||||
|
std::vector<int32_t> input_data = {0, 2, -1, 1};
|
||||||
|
|
||||||
|
std::vector<int32_t> golden = {5, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0};
|
||||||
|
|
||||||
|
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||||
|
|
||||||
|
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth, on, off, axis);
|
||||||
|
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||||
|
|
||||||
|
EXPECT_TRUE(graph->Compile());
|
||||||
|
EXPECT_TRUE(graph->Run());
|
||||||
|
std::vector<int32_t> output(12);
|
||||||
|
|
||||||
|
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||||
|
EXPECT_EQ(golden, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHot, shape_2_2_out_int32_depth_3_on_2_off_0) {
|
||||||
|
auto ctx = tim::vx::Context::Create();
|
||||||
|
auto graph = ctx->CreateGraph();
|
||||||
|
|
||||||
|
int32_t depth = 3;
|
||||||
|
float on = 2;
|
||||||
|
float off = 0;
|
||||||
|
int32_t axis = 0;
|
||||||
|
|
||||||
|
tim::vx::ShapeType input_shape({2, 2});//AKA: indices
|
||||||
|
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||||
|
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||||
|
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT32,
|
||||||
|
{2, 2, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||||
|
|
||||||
|
auto input_tensor = graph->CreateTensor(input_spec);
|
||||||
|
auto output_tensor = graph->CreateTensor(output_spec);
|
||||||
|
|
||||||
|
std::vector<int32_t> input_data = {0, 2, 1, -1};
|
||||||
|
|
||||||
|
std::vector<int32_t> golden = {2, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0};
|
||||||
|
|
||||||
|
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||||
|
|
||||||
|
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth, on, off, axis);
|
||||||
|
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||||
|
|
||||||
|
EXPECT_TRUE(graph->Compile());
|
||||||
|
EXPECT_TRUE(graph->Run());
|
||||||
|
std::vector<int32_t> output(12);
|
||||||
|
|
||||||
|
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||||
|
EXPECT_EQ(golden, output);
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue