Enable SetRoundingPolicy (#426)
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com> Co-authored-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
e047fce59f
commit
24fa582a56
|
|
@ -43,13 +43,13 @@ class DirectMapOpImpl : public OpImpl {
|
|||
DirectMapOpImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) override;
|
||||
|
||||
vsi_nn_node_t* node() override { return this->node_; }
|
||||
void SetNode(vsi_nn_node_t* node) {this->node_ = node; }
|
||||
void SetNode(vsi_nn_node_t* node) { this->node_ = node; }
|
||||
|
||||
void SetRoundingPolicy(
|
||||
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
||||
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
|
||||
RoundType down_scale_size_rounding = RoundType::FLOOR,
|
||||
uint32_t accumulator_bits = 0);
|
||||
uint32_t accumulator_bits = 0) override;
|
||||
|
||||
std::vector<std::shared_ptr<Tensor>> InputsTensor() override {
|
||||
return inputs_tensor_;
|
||||
|
|
|
|||
|
|
@ -1,12 +1,9 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import platform
|
||||
import datetime
|
||||
import re
|
||||
import string
|
||||
|
||||
|
||||
def checkFile(path):
|
||||
return os.path.isfile(path)
|
||||
|
|
@ -77,9 +74,7 @@ def main():
|
|||
args = sys.argv
|
||||
argc = len(args)
|
||||
target_path = "."
|
||||
tool = ""
|
||||
source = ""
|
||||
tool_opt = ["-s", "-V"]
|
||||
if argc <= 1:
|
||||
usage()
|
||||
return
|
||||
|
|
|
|||
|
|
@ -38,5 +38,14 @@ OpImpl::OpImpl(Graph* graph, DataLayout layout)
|
|||
: graph_(reinterpret_cast<GraphImpl*>(graph)),
|
||||
layout_(layout) {}
|
||||
|
||||
void OpImpl::SetRoundingPolicy(OverflowPolicy overflow_policy,
|
||||
RoundingPolicy rounding_policy,
|
||||
RoundType down_scale_size_roundin,
|
||||
uint32_t accumulator_bits) {
|
||||
(void)overflow_policy;
|
||||
(void)rounding_policy;
|
||||
(void)down_scale_size_roundin;
|
||||
(void)accumulator_bits;
|
||||
}
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
|
|
|||
|
|
@ -43,6 +43,11 @@ class OpImpl {
|
|||
virtual std::vector<std::shared_ptr<Tensor>> InputsTensor() = 0;
|
||||
virtual std::vector<std::shared_ptr<Tensor>> OutputsTensor() = 0;
|
||||
virtual vsi_nn_node_t* node() = 0;
|
||||
virtual void SetRoundingPolicy(
|
||||
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
||||
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
|
||||
RoundType down_scale_size_rounding = RoundType::FLOOR,
|
||||
uint32_t accumulator_bits = 0);
|
||||
|
||||
GraphImpl* graph_{nullptr};
|
||||
uint32_t kind_{0};
|
||||
|
|
|
|||
|
|
@ -54,9 +54,8 @@ Operation& Operation::BindOutput(const std::shared_ptr<Tensor>& tensor) {
|
|||
Operation& Operation::SetRoundingPolicy(
|
||||
OverflowPolicy overflow_policy, RoundingPolicy rounding_policy,
|
||||
RoundType down_scale_size_rounding, uint32_t accumulator_bits) {
|
||||
// impl_->SetRoundingPolicy(overflow_policy, rounding_policy,
|
||||
// down_scale_size_rounding, accumulator_bits);
|
||||
(void) overflow_policy;(void) rounding_policy;(void) down_scale_size_rounding;(void) accumulator_bits;
|
||||
impl_->SetRoundingPolicy(overflow_policy, rounding_policy,
|
||||
down_scale_size_rounding, accumulator_bits);
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue