Enable SetRoundingPolicy (#426)
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com> Co-authored-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
e047fce59f
commit
24fa582a56
|
|
@ -43,13 +43,13 @@ class DirectMapOpImpl : public OpImpl {
|
||||||
DirectMapOpImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) override;
|
DirectMapOpImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) override;
|
||||||
|
|
||||||
vsi_nn_node_t* node() override { return this->node_; }
|
vsi_nn_node_t* node() override { return this->node_; }
|
||||||
void SetNode(vsi_nn_node_t* node) {this->node_ = node; }
|
void SetNode(vsi_nn_node_t* node) { this->node_ = node; }
|
||||||
|
|
||||||
void SetRoundingPolicy(
|
void SetRoundingPolicy(
|
||||||
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
||||||
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
|
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
|
||||||
RoundType down_scale_size_rounding = RoundType::FLOOR,
|
RoundType down_scale_size_rounding = RoundType::FLOOR,
|
||||||
uint32_t accumulator_bits = 0);
|
uint32_t accumulator_bits = 0) override;
|
||||||
|
|
||||||
std::vector<std::shared_ptr<Tensor>> InputsTensor() override {
|
std::vector<std::shared_ptr<Tensor>> InputsTensor() override {
|
||||||
return inputs_tensor_;
|
return inputs_tensor_;
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,9 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import platform
|
|
||||||
import datetime
|
|
||||||
import re
|
import re
|
||||||
import string
|
|
||||||
|
|
||||||
def checkFile(path):
|
def checkFile(path):
|
||||||
return os.path.isfile(path)
|
return os.path.isfile(path)
|
||||||
|
|
@ -77,9 +74,7 @@ def main():
|
||||||
args = sys.argv
|
args = sys.argv
|
||||||
argc = len(args)
|
argc = len(args)
|
||||||
target_path = "."
|
target_path = "."
|
||||||
tool = ""
|
|
||||||
source = ""
|
source = ""
|
||||||
tool_opt = ["-s", "-V"]
|
|
||||||
if argc <= 1:
|
if argc <= 1:
|
||||||
usage()
|
usage()
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -38,5 +38,14 @@ OpImpl::OpImpl(Graph* graph, DataLayout layout)
|
||||||
: graph_(reinterpret_cast<GraphImpl*>(graph)),
|
: graph_(reinterpret_cast<GraphImpl*>(graph)),
|
||||||
layout_(layout) {}
|
layout_(layout) {}
|
||||||
|
|
||||||
|
void OpImpl::SetRoundingPolicy(OverflowPolicy overflow_policy,
|
||||||
|
RoundingPolicy rounding_policy,
|
||||||
|
RoundType down_scale_size_roundin,
|
||||||
|
uint32_t accumulator_bits) {
|
||||||
|
(void)overflow_policy;
|
||||||
|
(void)rounding_policy;
|
||||||
|
(void)down_scale_size_roundin;
|
||||||
|
(void)accumulator_bits;
|
||||||
|
}
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,11 @@ class OpImpl {
|
||||||
virtual std::vector<std::shared_ptr<Tensor>> InputsTensor() = 0;
|
virtual std::vector<std::shared_ptr<Tensor>> InputsTensor() = 0;
|
||||||
virtual std::vector<std::shared_ptr<Tensor>> OutputsTensor() = 0;
|
virtual std::vector<std::shared_ptr<Tensor>> OutputsTensor() = 0;
|
||||||
virtual vsi_nn_node_t* node() = 0;
|
virtual vsi_nn_node_t* node() = 0;
|
||||||
|
virtual void SetRoundingPolicy(
|
||||||
|
OverflowPolicy overflow_policy = OverflowPolicy::SATURATE,
|
||||||
|
RoundingPolicy rounding_policy = RoundingPolicy::RTNE,
|
||||||
|
RoundType down_scale_size_rounding = RoundType::FLOOR,
|
||||||
|
uint32_t accumulator_bits = 0);
|
||||||
|
|
||||||
GraphImpl* graph_{nullptr};
|
GraphImpl* graph_{nullptr};
|
||||||
uint32_t kind_{0};
|
uint32_t kind_{0};
|
||||||
|
|
|
||||||
|
|
@ -54,9 +54,8 @@ Operation& Operation::BindOutput(const std::shared_ptr<Tensor>& tensor) {
|
||||||
Operation& Operation::SetRoundingPolicy(
|
Operation& Operation::SetRoundingPolicy(
|
||||||
OverflowPolicy overflow_policy, RoundingPolicy rounding_policy,
|
OverflowPolicy overflow_policy, RoundingPolicy rounding_policy,
|
||||||
RoundType down_scale_size_rounding, uint32_t accumulator_bits) {
|
RoundType down_scale_size_rounding, uint32_t accumulator_bits) {
|
||||||
// impl_->SetRoundingPolicy(overflow_policy, rounding_policy,
|
impl_->SetRoundingPolicy(overflow_policy, rounding_policy,
|
||||||
// down_scale_size_rounding, accumulator_bits);
|
down_scale_size_rounding, accumulator_bits);
|
||||||
(void) overflow_policy;(void) rounding_policy;(void) down_scale_size_rounding;(void) accumulator_bits;
|
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue