From 24fa582a560983cba107ce08b85df18ac373423d Mon Sep 17 00:00:00 2001 From: liyuenan <37231553+liyuenan2333@users.noreply.github.com> Date: Wed, 6 Jul 2022 17:03:54 +0800 Subject: [PATCH] Enable SetRoundingPolicy (#426) Signed-off-by: yuenan.li Co-authored-by: yuenan.li --- src/tim/vx/direct_map_op_impl.h | 4 ++-- src/tim/vx/internal/ConvertPGMToH.py | 7 +------ src/tim/vx/op_impl.cc | 9 +++++++++ src/tim/vx/op_impl.h | 5 +++++ src/tim/vx/operation.cc | 5 ++--- 5 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/tim/vx/direct_map_op_impl.h b/src/tim/vx/direct_map_op_impl.h index 06aa723..0d0d6d9 100644 --- a/src/tim/vx/direct_map_op_impl.h +++ b/src/tim/vx/direct_map_op_impl.h @@ -43,13 +43,13 @@ class DirectMapOpImpl : public OpImpl { DirectMapOpImpl& BindOutput(const std::shared_ptr& tensor) override; vsi_nn_node_t* node() override { return this->node_; } - void SetNode(vsi_nn_node_t* node) {this->node_ = node; } + void SetNode(vsi_nn_node_t* node) { this->node_ = node; } void SetRoundingPolicy( OverflowPolicy overflow_policy = OverflowPolicy::SATURATE, RoundingPolicy rounding_policy = RoundingPolicy::RTNE, RoundType down_scale_size_rounding = RoundType::FLOOR, - uint32_t accumulator_bits = 0); + uint32_t accumulator_bits = 0) override; std::vector> InputsTensor() override { return inputs_tensor_; diff --git a/src/tim/vx/internal/ConvertPGMToH.py b/src/tim/vx/internal/ConvertPGMToH.py index 046faed..1c1cce0 100644 --- a/src/tim/vx/internal/ConvertPGMToH.py +++ b/src/tim/vx/internal/ConvertPGMToH.py @@ -1,12 +1,9 @@ #!/usr/bin/env python -import subprocess import sys import os -import platform -import datetime import re -import string + def checkFile(path): return os.path.isfile(path) @@ -77,9 +74,7 @@ def main(): args = sys.argv argc = len(args) target_path = "." - tool = "" source = "" - tool_opt = ["-s", "-V"] if argc <= 1: usage() return diff --git a/src/tim/vx/op_impl.cc b/src/tim/vx/op_impl.cc index 68f280f..6430841 100644 --- a/src/tim/vx/op_impl.cc +++ b/src/tim/vx/op_impl.cc @@ -38,5 +38,14 @@ OpImpl::OpImpl(Graph* graph, DataLayout layout) : graph_(reinterpret_cast(graph)), layout_(layout) {} +void OpImpl::SetRoundingPolicy(OverflowPolicy overflow_policy, + RoundingPolicy rounding_policy, + RoundType down_scale_size_roundin, + uint32_t accumulator_bits) { + (void)overflow_policy; + (void)rounding_policy; + (void)down_scale_size_roundin; + (void)accumulator_bits; +} } // namespace vx } // namespace tim diff --git a/src/tim/vx/op_impl.h b/src/tim/vx/op_impl.h index b27f320..2ac0825 100644 --- a/src/tim/vx/op_impl.h +++ b/src/tim/vx/op_impl.h @@ -43,6 +43,11 @@ class OpImpl { virtual std::vector> InputsTensor() = 0; virtual std::vector> OutputsTensor() = 0; virtual vsi_nn_node_t* node() = 0; + virtual void SetRoundingPolicy( + OverflowPolicy overflow_policy = OverflowPolicy::SATURATE, + RoundingPolicy rounding_policy = RoundingPolicy::RTNE, + RoundType down_scale_size_rounding = RoundType::FLOOR, + uint32_t accumulator_bits = 0); GraphImpl* graph_{nullptr}; uint32_t kind_{0}; diff --git a/src/tim/vx/operation.cc b/src/tim/vx/operation.cc index 3369d92..fd3d7e9 100644 --- a/src/tim/vx/operation.cc +++ b/src/tim/vx/operation.cc @@ -54,9 +54,8 @@ Operation& Operation::BindOutput(const std::shared_ptr& tensor) { Operation& Operation::SetRoundingPolicy( OverflowPolicy overflow_policy, RoundingPolicy rounding_policy, RoundType down_scale_size_rounding, uint32_t accumulator_bits) { - // impl_->SetRoundingPolicy(overflow_policy, rounding_policy, - // down_scale_size_rounding, accumulator_bits); - (void) overflow_policy;(void) rounding_policy;(void) down_scale_size_rounding;(void) accumulator_bits; + impl_->SetRoundingPolicy(overflow_policy, rounding_policy, + down_scale_size_rounding, accumulator_bits); return *this; }