Map OneHot & unit test (#258)
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com> Co-authored-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
8e4ab68213
commit
7c63ba621e
|
|
@ -54,6 +54,7 @@
|
|||
#include "tim/vx/ops/maxunpool2d.h"
|
||||
#include "tim/vx/ops/moments.h"
|
||||
#include "tim/vx/ops/nbg.h"
|
||||
#include "tim/vx/ops/onehot.h"
|
||||
#include "tim/vx/ops/pad.h"
|
||||
#include "tim/vx/ops/pool2d.h"
|
||||
#include "tim/vx/ops/reduce.h"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2020 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#ifndef TIM_VX_OPERATION_ONE_HOT_H_
|
||||
#define TIM_VX_OPERATION_ONE_HOT_H_
|
||||
#include "tim/vx/direct_map_op.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
|
||||
/**
|
||||
* ## OneHot
|
||||
*
|
||||
* Create a one-hot tensor.
|
||||
*
|
||||
* - depth : A scalar defining the depth of the one hot dimension.
|
||||
* - on_value : A scalar defining the value to fill in output.
|
||||
* - off_value : A scalar defining the value to fill in output.
|
||||
* - axis : The axis to fill.
|
||||
*/
|
||||
|
||||
class OneHot : public DirectMapOp {
|
||||
public:
|
||||
OneHot(Graph* graph, int32_t depth, float on_value = 1, float off_value = 0,
|
||||
int32_t axis = 0);
|
||||
|
||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||
|
||||
protected:
|
||||
int32_t depth_;
|
||||
float on_value_;
|
||||
float off_value_;
|
||||
int32_t axis_;
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
#endif
|
||||
|
|
@ -131,7 +131,7 @@ GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow.
|
|||
||CEIL|Planned 21Q4|[tf.math.ceil](https://tensorflow.google.cn/api_docs/python/tf/math/ceil)
|
||||
||SEQUENCE_MASK|Planned 21Q4|[tf.math.ceil](https://tensorflow.google.cn/api_docs/python/tf/sequence_mask)
|
||||
||REPEAT|Planned 21Q4|[tf.repeat](https://tensorflow.google.cn/api_docs/python/tf/repeat)
|
||||
||ONE_HOT|Planned 21Q4|[tf.one_hot](https://tensorflow.google.cn/api_docs/python/tf/one_hot)
|
||||
OneHot|ONE_HOT|Mapped|[tf.one_hot](https://tensorflow.google.cn/api_docs/python/tf/one_hot)
|
||||
||NMS|Planned 21Q4|[tf.image.non_max_suppression](https://tensorflow.google.cn/api_docs/python/tf/image/non_max_suppression)
|
||||
||SCATTER_ND_UPDATE|Planned 21Q4|[tf.compat.v1.scatter_nd_update](https://tensorflow.google.cn/api_docs/python/tf/compat/v1/scatter_nd_update)
|
||||
||GELU|Planned 21Q4|[tf.nn.gelu](https://tensorflow.google.cn/api_docs/python/tf/nn/gelu)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,52 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2020 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#include "tim/vx/ops/onehot.h"
|
||||
|
||||
#include "direct_map_op_impl.h"
|
||||
#include "vsi_nn_pub.h"
|
||||
|
||||
namespace tim {
|
||||
namespace vx {
|
||||
namespace ops {
|
||||
OneHot::OneHot(Graph* graph, int32_t depth, float on_value, float off_value,
|
||||
int32_t axis)
|
||||
: DirectMapOp(graph, VSI_NN_OP_ONE_HOT),
|
||||
depth_(depth),
|
||||
on_value_(on_value),
|
||||
off_value_(off_value),
|
||||
axis_(axis) {
|
||||
this->impl()->node()->nn_param.one_hot.depth = depth_;
|
||||
this->impl()->node()->nn_param.one_hot.on_value = on_value_;
|
||||
this->impl()->node()->nn_param.one_hot.off_value = off_value_;
|
||||
this->impl()->node()->nn_param.one_hot.axis = axis_;
|
||||
}
|
||||
|
||||
std::shared_ptr<Operation> OneHot::Clone(std::shared_ptr<Graph>& graph) const {
|
||||
return graph->CreateOperation<OneHot>(this->depth_, this->on_value_,
|
||||
this->off_value_, this->axis_);
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
|
@ -0,0 +1,326 @@
|
|||
/****************************************************************************
|
||||
*
|
||||
* Copyright (c) 2021 Vivante Corporation
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a
|
||||
* copy of this software and associated documentation files (the "Software"),
|
||||
* to deal in the Software without restriction, including without limitation
|
||||
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
* and/or sell copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
* DEALINGS IN THE SOFTWARE.
|
||||
*
|
||||
*****************************************************************************/
|
||||
#include "tim/vx/context.h"
|
||||
#include "tim/vx/graph.h"
|
||||
#include "tim/vx/ops/onehot.h"
|
||||
#include "tim/vx/types.h"
|
||||
#include "test_utils.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
TEST(OneHot, shape_3_out_flaot_depth_3) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
int32_t depth = 3;
|
||||
|
||||
tim::vx::ShapeType input_shape({3});//AKA: indices
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
|
||||
{3, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<int32_t> input_data = {0, 1, 2};
|
||||
|
||||
std::vector<float> golden = {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth);
|
||||
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
std::vector<float> output(9);
|
||||
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(OneHot, shape_3_out_int32_depth_3) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
int32_t depth = 3;
|
||||
|
||||
tim::vx::ShapeType input_shape({3});//AKA: indices
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT32,
|
||||
{3, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<int32_t> input_data = {0, 1, 2};
|
||||
|
||||
std::vector<int32_t> golden = {1, 0, 0, 0, 1, 0, 0, 0, 1};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth);
|
||||
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
std::vector<int32_t> output(9);
|
||||
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(OneHot, shape_3_out_int8_depth_3) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
int32_t depth = 3;
|
||||
|
||||
tim::vx::ShapeType input_shape({3});//AKA: indices
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT8,
|
||||
{3, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<int32_t> input_data = {0, 1, 2};
|
||||
|
||||
std::vector<int8_t> golden = {1, 0, 0, 0, 1, 0, 0, 0, 1};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size()));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth);
|
||||
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
std::vector<int8_t> output(9);
|
||||
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(OneHot, shape_3_out_uint8_depth_3) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
int32_t depth = 3;
|
||||
|
||||
tim::vx::ShapeType input_shape({3});//AKA: indices
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::UINT8,
|
||||
{3, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<int32_t> input_data = {0, 1, 2};
|
||||
|
||||
std::vector<uint8_t> golden = {1, 0, 0, 0, 1, 0, 0, 0, 1};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size()));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth);
|
||||
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
std::vector<uint8_t> output(9);
|
||||
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(OneHot, shape_3_out_int32_depth_1) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
int32_t depth = 1;
|
||||
|
||||
tim::vx::ShapeType input_shape({3});//AKA: indices
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT32,
|
||||
{3, 1}, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<int32_t> input_data = {0, 1, 2};
|
||||
|
||||
std::vector<int32_t> golden = {1, 0, 0};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth);
|
||||
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
std::vector<int32_t> output(3);
|
||||
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(OneHot, shape_3_out_int32_depth_4) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
int32_t depth = 4;
|
||||
|
||||
tim::vx::ShapeType input_shape({3});//AKA: indices
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT32,
|
||||
{3, 4}, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<int32_t> input_data = {0, 1, 2};
|
||||
|
||||
std::vector<int32_t> golden = {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth);
|
||||
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
std::vector<int32_t> output(12);
|
||||
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(OneHot, shape_3_out_int32_depth_3_on_6_off_N1) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
int32_t depth = 3;
|
||||
float on = 6;
|
||||
float off = -1;
|
||||
|
||||
tim::vx::ShapeType input_shape({4});//AKA: indices
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT32,
|
||||
{4, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<int32_t> input_data = {0, 2, -1, 1};
|
||||
|
||||
std::vector<int32_t> golden = {6, -1, -1, -1, -1, 6, -1, -1, -1, -1, 6, -1};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth, on, off);
|
||||
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
std::vector<int32_t> output(12);
|
||||
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(OneHot, shape_3_out_int32_depth_3_on_5_off_0_axis_1) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
int32_t depth = 3;
|
||||
float on = 5;
|
||||
float off = 0;
|
||||
int32_t axis = 1;
|
||||
|
||||
tim::vx::ShapeType input_shape({4});//AKA: indices
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT32,
|
||||
{4, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<int32_t> input_data = {0, 2, -1, 1};
|
||||
|
||||
std::vector<int32_t> golden = {5, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth, on, off, axis);
|
||||
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
std::vector<int32_t> output(12);
|
||||
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
|
||||
TEST(OneHot, shape_2_2_out_int32_depth_3_on_2_off_0) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
int32_t depth = 3;
|
||||
float on = 2;
|
||||
float off = 0;
|
||||
int32_t axis = 0;
|
||||
|
||||
tim::vx::ShapeType input_shape({2, 2});//AKA: indices
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::INT32,
|
||||
input_shape, tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::INT32,
|
||||
{2, 2, 3}, tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<int32_t> input_data = {0, 2, 1, -1};
|
||||
|
||||
std::vector<int32_t> golden = {2, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0};
|
||||
|
||||
EXPECT_TRUE(input_tensor->CopyDataToTensor(input_data.data(), input_data.size() * 4));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::OneHot>(depth, on, off, axis);
|
||||
(*op).BindInput(input_tensor).BindOutput(output_tensor);
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
std::vector<int32_t> output(12);
|
||||
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_EQ(golden, output);
|
||||
}
|
||||
Loading…
Reference in New Issue