Added case for hardswish
Signed-off-by: Chen Xin <jack.chen@verisilicon.com>
This commit is contained in:
parent
e71d537042
commit
f0a0f1728a
|
|
@ -457,4 +457,48 @@ TEST(SoftSign, shape_5_1_fp32) {
|
|||
std::vector<float> output(5, 0);
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
|
||||
}
|
||||
|
||||
TEST(HardSwish, 40_f32) {
|
||||
auto ctx = tim::vx::Context::Create();
|
||||
auto graph = ctx->CreateGraph();
|
||||
|
||||
tim::vx::ShapeType io_shape({40});
|
||||
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, io_shape,
|
||||
tim::vx::TensorAttribute::INPUT);
|
||||
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32, io_shape,
|
||||
tim::vx::TensorAttribute::OUTPUT);
|
||||
|
||||
auto input_tensor = graph->CreateTensor(input_spec);
|
||||
auto output_tensor = graph->CreateTensor(output_spec);
|
||||
|
||||
std::vector<float> in_data = {
|
||||
4.53125f, 3.90625f, 3.046875f, -8.59375f, -1.328125f, 1.328125f,
|
||||
0.0f, -8.515625f, -8.984375f, -0.234375f, 0.859375f, 9.84375f,
|
||||
-0.15625f, -8.515625f, 8.671875f, 4.609375f, 9.21875f, -1.796875f,
|
||||
1.171875f, 9.375f, -8.75f, 2.421875f, -8.125f, -1.09375f,
|
||||
-9.609375f, -1.015625f, -9.84375f, 2.578125f, 4.921875f, -5.078125f,
|
||||
5.0f, -0.859375f, 1.953125f, -6.640625f, -7.8125f, 4.453125f,
|
||||
-4.453125f, -6.875f, 0.78125f, 0.859375f};
|
||||
std::vector<float> golden = {
|
||||
4.53125f, 3.90625f, 3.046875f, 0.0f, -0.3700765f,
|
||||
0.9580485f, 0.0f, 0.0f, 0.0f, -0.1080322f,
|
||||
0.5527751f, 9.84375f, -0.074056f, 0.0f, 8.671875f,
|
||||
4.609375f, 9.21875f, -0.3603109f, 0.8148193f, 9.375f,
|
||||
0.0f, 2.1885173f, 0.0f, -0.3474935f, 0.0f,
|
||||
-0.3358968f, 0.0f, 2.3968506f, 4.921875f, 0.0f,
|
||||
5.0f, -0.3065999f, 1.6123454f, 0.0f, 0.0f,
|
||||
4.453125f, 0.0f, 0.0f, 0.4923503f, 0.5527751f};
|
||||
|
||||
EXPECT_TRUE(
|
||||
input_tensor->CopyDataToTensor(in_data.data(), in_data.size() * 4));
|
||||
|
||||
auto op = graph->CreateOperation<tim::vx::ops::HardSwish>();
|
||||
(*op).BindInputs({input_tensor}).BindOutputs({output_tensor});
|
||||
|
||||
EXPECT_TRUE(graph->Compile());
|
||||
EXPECT_TRUE(graph->Run());
|
||||
std::vector<float> output(40);
|
||||
EXPECT_TRUE(output_tensor->CopyDataFromTensor(output.data()));
|
||||
EXPECT_TRUE(ArraysMatch(golden, output, 1e-5f));
|
||||
}
|
||||
Loading…
Reference in New Issue