Enable SetRoundingPolicy (#426)

Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>

Co-authored-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
liyuenan 2022-07-06 17:03:54 +08:00 committed by GitHub
parent e047fce59f
commit 24fa582a56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 19 additions and 11 deletions

View File

@ -43,13 +43,13 @@ class DirectMapOpImpl : public OpImpl {
DirectMapOpImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) override;
vsi_nn_node_t* node() override { return this->node_; }
void SetNode(vsi_nn_node_t* node) {this->node_ = node; }
void SetNode(vsi_nn_node_t* node) { this->node_ = node; }
void SetRoundingPolicy(
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
RoundType down_scale_size_rounding = RoundType::FLOOR,
uint32_t accumulator_bits = 0);
uint32_t accumulator_bits = 0) override;
std::vector<std::shared_ptr<Tensor>> InputsTensor() override {
return inputs_tensor_;

View File

@ -1,12 +1,9 @@
#!/usr/bin/env python
import subprocess
import sys
import os
import platform
import datetime
import re
import string
def checkFile(path):
return os.path.isfile(path)
@ -77,9 +74,7 @@ def main():
args = sys.argv
argc = len(args)
target_path = "."
tool = ""
source = ""
tool_opt = ["-s", "-V"]
if argc <= 1:
usage()
return

View File

@ -38,5 +38,14 @@ OpImpl::OpImpl(Graph* graph, DataLayout layout)
: graph_(reinterpret_cast<GraphImpl*>(graph)),
layout_(layout) {}
void OpImpl::SetRoundingPolicy(OverflowPolicy overflow_policy,
RoundingPolicy rounding_policy,
RoundType down_scale_size_roundin,
uint32_t accumulator_bits) {
(void)overflow_policy;
(void)rounding_policy;
(void)down_scale_size_roundin;
(void)accumulator_bits;
}
} // namespace vx
} // namespace tim

View File

@ -43,6 +43,11 @@ class OpImpl {
virtual std::vector<std::shared_ptr<Tensor>> InputsTensor() = 0;
virtual std::vector<std::shared_ptr<Tensor>> OutputsTensor() = 0;
virtual vsi_nn_node_t* node() = 0;
virtual void SetRoundingPolicy(
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
RoundType down_scale_size_rounding = RoundType::FLOOR,
uint32_t accumulator_bits = 0);
GraphImpl* graph_{nullptr};
uint32_t kind_{0};

View File

@ -54,9 +54,8 @@ Operation& Operation::BindOutput(const std::shared_ptr<Tensor>& tensor) {
Operation& Operation::SetRoundingPolicy(
OverflowPolicy overflow_policy, RoundingPolicy rounding_policy,
RoundType down_scale_size_rounding, uint32_t accumulator_bits) {
// impl_->SetRoundingPolicy(overflow_policy, rounding_policy,
// down_scale_size_rounding, accumulator_bits);
(void) overflow_policy;(void) rounding_policy;(void) down_scale_size_rounding;(void) accumulator_bits;
impl_->SetRoundingPolicy(overflow_policy, rounding_policy,
down_scale_size_rounding, accumulator_bits);
return *this;
}