Support that op's all inputs are constant (#264)
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com> Co-authored-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
36e6afa567
commit
e2180a6341
|
|
@ -76,11 +76,13 @@ class Graph {
|
||||||
virtual const std::vector<std::shared_ptr<Operation>> GetConsumersOp(
|
virtual const std::vector<std::shared_ptr<Operation>> GetConsumersOp(
|
||||||
std::shared_ptr<Tensor> tensor) const = 0;
|
std::shared_ptr<Tensor> tensor) const = 0;
|
||||||
|
|
||||||
virtual std::vector<std::shared_ptr<Operation>> GetProducerOp(
|
virtual std::shared_ptr<Operation> GetProducerOp(
|
||||||
std::shared_ptr<Tensor> tensor) = 0;
|
std::shared_ptr<Tensor> tensor) = 0;
|
||||||
|
|
||||||
virtual void PrintGraph() const = 0;
|
virtual void PrintGraph() const = 0;
|
||||||
|
|
||||||
|
const std::vector<std::shared_ptr<Tensor>> GetConstantInputs() const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::vector<std::shared_ptr<tim::vx::Operation>> op_vector_;
|
std::vector<std::shared_ptr<tim::vx::Operation>> op_vector_;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -48,8 +48,9 @@ class Operation {
|
||||||
uint32_t accumulator_bits = 0);
|
uint32_t accumulator_bits = 0);
|
||||||
std::unique_ptr<OpImpl>& impl();
|
std::unique_ptr<OpImpl>& impl();
|
||||||
const std::unique_ptr<OpImpl>& impl() const;
|
const std::unique_ptr<OpImpl>& impl() const;
|
||||||
|
virtual const std::vector<std::shared_ptr<Tensor>> ConstantInputsTensor() const;
|
||||||
protected:
|
protected:
|
||||||
|
bool IsAllInputsConst() const;
|
||||||
std::unique_ptr<OpImpl> impl_;
|
std::unique_ptr<OpImpl> impl_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -85,6 +85,7 @@ class Conv2d : public DirectMapOp {
|
||||||
|
|
||||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||||
|
|
||||||
|
const std::vector<std::shared_ptr<Tensor>> ConstantInputsTensor() const override;
|
||||||
protected:
|
protected:
|
||||||
const uint32_t weights_;
|
const uint32_t weights_;
|
||||||
const PadType padding_;
|
const PadType padding_;
|
||||||
|
|
|
||||||
|
|
@ -298,6 +298,16 @@ std::pair<std::shared_ptr<vx::Graph>,
|
||||||
MakeShared(t_src->GetShape().size()));
|
MakeShared(t_src->GetShape().size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto const_inputs = src_graph->GetConstantInputs();
|
||||||
|
for (auto const_in : const_inputs) {
|
||||||
|
auto input =
|
||||||
|
infer_graph->CreateTensor(const_in->GetSpec(), const_in->GetDataRef());
|
||||||
|
layout_infer_ctx->UpdateTensorMap(const_in, input);
|
||||||
|
tensor_queue.push_back(const_in);
|
||||||
|
layout_infer_ctx->SetPermuteVector(const_in,
|
||||||
|
MakeShared(const_in->GetShape().size()));
|
||||||
|
}
|
||||||
|
|
||||||
while (!tensor_queue.empty()) {
|
while (!tensor_queue.empty()) {
|
||||||
const auto& tensor = tensor_queue.front();
|
const auto& tensor = tensor_queue.front();
|
||||||
tensor_queue.pop_front();
|
tensor_queue.pop_front();
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,15 @@
|
||||||
namespace tim {
|
namespace tim {
|
||||||
namespace vx {
|
namespace vx {
|
||||||
|
|
||||||
|
const std::vector<std::shared_ptr<Tensor>> Graph::GetConstantInputs() const {
|
||||||
|
std::vector<std::shared_ptr<Tensor>> const_inputs;
|
||||||
|
for (auto op : op_vector_) {
|
||||||
|
auto const_i = op->ConstantInputsTensor();
|
||||||
|
const_inputs.insert(const_inputs.end(), const_i.begin(), const_i.end());
|
||||||
|
}
|
||||||
|
return const_inputs;
|
||||||
|
}
|
||||||
|
|
||||||
GraphImpl::GraphImpl(ContextImpl* context)
|
GraphImpl::GraphImpl(ContextImpl* context)
|
||||||
: context_(context),
|
: context_(context),
|
||||||
graph_(vsi_nn_CreateGraph(context_->context(), 0, 0)),
|
graph_(vsi_nn_CreateGraph(context_->context(), 0, 0)),
|
||||||
|
|
@ -91,7 +100,7 @@ void GraphImpl::UpdateTensorProducerMap(const std::shared_ptr<Tensor>& tensor,
|
||||||
const Operation* op) {
|
const Operation* op) {
|
||||||
for (const auto& added_op : op_vector_) {
|
for (const auto& added_op : op_vector_) {
|
||||||
if (added_op.get() == op) {
|
if (added_op.get() == op) {
|
||||||
tensor_producer_[tensor].push_back(added_op);
|
tensor_producer_[tensor] = added_op;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -107,7 +116,7 @@ const std::vector<std::shared_ptr<Operation>> GraphImpl::GetConsumersOp(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::shared_ptr<Operation>> GraphImpl::GetProducerOp(
|
std::shared_ptr<Operation> GraphImpl::GetProducerOp(
|
||||||
std::shared_ptr<Tensor> tensor) {
|
std::shared_ptr<Tensor> tensor) {
|
||||||
auto producer = tensor_producer_.find(tensor);
|
auto producer = tensor_producer_.find(tensor);
|
||||||
if (tensor_producer_.end() != producer) {
|
if (tensor_producer_.end() != producer) {
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ class GraphImpl : public Graph {
|
||||||
const Operation* op) override;
|
const Operation* op) override;
|
||||||
const std::vector<std::shared_ptr<Operation>> GetConsumersOp(
|
const std::vector<std::shared_ptr<Operation>> GetConsumersOp(
|
||||||
std::shared_ptr<Tensor> tensor) const override;
|
std::shared_ptr<Tensor> tensor) const override;
|
||||||
std::vector<std::shared_ptr<Operation>> GetProducerOp(
|
std::shared_ptr<Operation> GetProducerOp(
|
||||||
std::shared_ptr<Tensor> tensor) override;
|
std::shared_ptr<Tensor> tensor) override;
|
||||||
|
|
||||||
void PrintGraph() const override;
|
void PrintGraph() const override;
|
||||||
|
|
@ -87,7 +87,7 @@ class GraphImpl : public Graph {
|
||||||
std::vector<std::shared_ptr<Tensor>> inputs_tensor_;
|
std::vector<std::shared_ptr<Tensor>> inputs_tensor_;
|
||||||
std::vector<std::shared_ptr<Tensor>> outputs_tensor_;
|
std::vector<std::shared_ptr<Tensor>> outputs_tensor_;
|
||||||
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>> tensor_consumers_;
|
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>> tensor_consumers_;
|
||||||
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>> tensor_producer_;
|
std::map<std::shared_ptr<Tensor>, std::shared_ptr<Operation>> tensor_producer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
|
|
|
||||||
|
|
@ -76,5 +76,20 @@ Operation& Operation::BindOutputs(
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Operation::IsAllInputsConst() const{
|
||||||
|
for (auto tensor : impl_->inputs_tensor_) {
|
||||||
|
if (!tensor->IsConstTensor()) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<std::shared_ptr<Tensor>> Operation::ConstantInputsTensor() const{
|
||||||
|
if (this->IsAllInputsConst()) {
|
||||||
|
return impl_->inputs_tensor_;
|
||||||
|
} else {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
|
|
@ -88,6 +88,14 @@ std::shared_ptr<Operation> Conv2d::Clone(std::shared_ptr<Graph>& graph) const {
|
||||||
this->kernel_layout_);
|
this->kernel_layout_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const std::vector<std::shared_ptr<Tensor>> Conv2d::ConstantInputsTensor() const {
|
||||||
|
if (this->IsAllInputsConst()) {
|
||||||
|
return {this->impl_->inputs_tensor_[0]};
|
||||||
|
} else {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace vx
|
} // namespace vx
|
||||||
} // namespace tim
|
} // namespace tim
|
||||||
Loading…
Reference in New Issue