Support that op's all inputs are constant (#264)
Signed-off-by: yuenan.li <yuenan.li@verisilicon.com> Co-authored-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
parent
36e6afa567
commit
e2180a6341
|
|
@ -76,11 +76,13 @@ class Graph {
|
|||
virtual const std::vector<std::shared_ptr<Operation>> GetConsumersOp(
|
||||
std::shared_ptr<Tensor> tensor) const = 0;
|
||||
|
||||
virtual std::vector<std::shared_ptr<Operation>> GetProducerOp(
|
||||
virtual std::shared_ptr<Operation> GetProducerOp(
|
||||
std::shared_ptr<Tensor> tensor) = 0;
|
||||
|
||||
virtual void PrintGraph() const = 0;
|
||||
|
||||
const std::vector<std::shared_ptr<Tensor>> GetConstantInputs() const;
|
||||
|
||||
protected:
|
||||
std::vector<std::shared_ptr<tim::vx::Operation>> op_vector_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -48,8 +48,9 @@ class Operation {
|
|||
uint32_t accumulator_bits = 0);
|
||||
std::unique_ptr<OpImpl>& impl();
|
||||
const std::unique_ptr<OpImpl>& impl() const;
|
||||
|
||||
virtual const std::vector<std::shared_ptr<Tensor>> ConstantInputsTensor() const;
|
||||
protected:
|
||||
bool IsAllInputsConst() const;
|
||||
std::unique_ptr<OpImpl> impl_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ class Conv2d : public DirectMapOp {
|
|||
|
||||
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
|
||||
|
||||
const std::vector<std::shared_ptr<Tensor>> ConstantInputsTensor() const override;
|
||||
protected:
|
||||
const uint32_t weights_;
|
||||
const PadType padding_;
|
||||
|
|
|
|||
|
|
@ -298,6 +298,16 @@ std::pair<std::shared_ptr<vx::Graph>,
|
|||
MakeShared(t_src->GetShape().size()));
|
||||
}
|
||||
|
||||
auto const_inputs = src_graph->GetConstantInputs();
|
||||
for (auto const_in : const_inputs) {
|
||||
auto input =
|
||||
infer_graph->CreateTensor(const_in->GetSpec(), const_in->GetDataRef());
|
||||
layout_infer_ctx->UpdateTensorMap(const_in, input);
|
||||
tensor_queue.push_back(const_in);
|
||||
layout_infer_ctx->SetPermuteVector(const_in,
|
||||
MakeShared(const_in->GetShape().size()));
|
||||
}
|
||||
|
||||
while (!tensor_queue.empty()) {
|
||||
const auto& tensor = tensor_queue.front();
|
||||
tensor_queue.pop_front();
|
||||
|
|
|
|||
|
|
@ -35,6 +35,15 @@
|
|||
namespace tim {
|
||||
namespace vx {
|
||||
|
||||
const std::vector<std::shared_ptr<Tensor>> Graph::GetConstantInputs() const {
|
||||
std::vector<std::shared_ptr<Tensor>> const_inputs;
|
||||
for (auto op : op_vector_) {
|
||||
auto const_i = op->ConstantInputsTensor();
|
||||
const_inputs.insert(const_inputs.end(), const_i.begin(), const_i.end());
|
||||
}
|
||||
return const_inputs;
|
||||
}
|
||||
|
||||
GraphImpl::GraphImpl(ContextImpl* context)
|
||||
: context_(context),
|
||||
graph_(vsi_nn_CreateGraph(context_->context(), 0, 0)),
|
||||
|
|
@ -91,7 +100,7 @@ void GraphImpl::UpdateTensorProducerMap(const std::shared_ptr<Tensor>& tensor,
|
|||
const Operation* op) {
|
||||
for (const auto& added_op : op_vector_) {
|
||||
if (added_op.get() == op) {
|
||||
tensor_producer_[tensor].push_back(added_op);
|
||||
tensor_producer_[tensor] = added_op;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -107,7 +116,7 @@ const std::vector<std::shared_ptr<Operation>> GraphImpl::GetConsumersOp(
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<Operation>> GraphImpl::GetProducerOp(
|
||||
std::shared_ptr<Operation> GraphImpl::GetProducerOp(
|
||||
std::shared_ptr<Tensor> tensor) {
|
||||
auto producer = tensor_producer_.find(tensor);
|
||||
if (tensor_producer_.end() != producer) {
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class GraphImpl : public Graph {
|
|||
const Operation* op) override;
|
||||
const std::vector<std::shared_ptr<Operation>> GetConsumersOp(
|
||||
std::shared_ptr<Tensor> tensor) const override;
|
||||
std::vector<std::shared_ptr<Operation>> GetProducerOp(
|
||||
std::shared_ptr<Operation> GetProducerOp(
|
||||
std::shared_ptr<Tensor> tensor) override;
|
||||
|
||||
void PrintGraph() const override;
|
||||
|
|
@ -87,7 +87,7 @@ class GraphImpl : public Graph {
|
|||
std::vector<std::shared_ptr<Tensor>> inputs_tensor_;
|
||||
std::vector<std::shared_ptr<Tensor>> outputs_tensor_;
|
||||
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>> tensor_consumers_;
|
||||
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>> tensor_producer_;
|
||||
std::map<std::shared_ptr<Tensor>, std::shared_ptr<Operation>> tensor_producer_;
|
||||
};
|
||||
|
||||
} // namespace vx
|
||||
|
|
|
|||
|
|
@ -76,5 +76,20 @@ Operation& Operation::BindOutputs(
|
|||
return *this;
|
||||
}
|
||||
|
||||
bool Operation::IsAllInputsConst() const{
|
||||
for (auto tensor : impl_->inputs_tensor_) {
|
||||
if (!tensor->IsConstTensor()) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<Tensor>> Operation::ConstantInputsTensor() const{
|
||||
if (this->IsAllInputsConst()) {
|
||||
return impl_->inputs_tensor_;
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
|
@ -88,6 +88,14 @@ std::shared_ptr<Operation> Conv2d::Clone(std::shared_ptr<Graph>& graph) const {
|
|||
this->kernel_layout_);
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<Tensor>> Conv2d::ConstantInputsTensor() const {
|
||||
if (this->IsAllInputsConst()) {
|
||||
return {this->impl_->inputs_tensor_[0]};
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace vx
|
||||
} // namespace tim
|
||||
Loading…
Reference in New Issue