Support that op's all inputs are constant (#264)

Signed-off-by: yuenan.li <yuenan.li@verisilicon.com>

Co-authored-by: yuenan.li <yuenan.li@verisilicon.com>
This commit is contained in:
liyuenan 2022-01-14 12:34:38 +08:00 committed by GitHub
parent 36e6afa567
commit e2180a6341
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 52 additions and 6 deletions

View File

@ -76,11 +76,13 @@ class Graph {
virtual const std::vector<std::shared_ptr<Operation>> GetConsumersOp(
std::shared_ptr<Tensor> tensor) const = 0;
virtual std::vector<std::shared_ptr<Operation>> GetProducerOp(
virtual std::shared_ptr<Operation> GetProducerOp(
std::shared_ptr<Tensor> tensor) = 0;
virtual void PrintGraph() const = 0;
const std::vector<std::shared_ptr<Tensor>> GetConstantInputs() const;
protected:
std::vector<std::shared_ptr<tim::vx::Operation>> op_vector_;
};

View File

@ -48,8 +48,9 @@ class Operation {
uint32_t accumulator_bits = 0);
std::unique_ptr<OpImpl>& impl();
const std::unique_ptr<OpImpl>& impl() const;
virtual const std::vector<std::shared_ptr<Tensor>> ConstantInputsTensor() const;
protected:
bool IsAllInputsConst() const;
std::unique_ptr<OpImpl> impl_;
};

View File

@ -85,6 +85,7 @@ class Conv2d : public DirectMapOp {
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
const std::vector<std::shared_ptr<Tensor>> ConstantInputsTensor() const override;
protected:
const uint32_t weights_;
const PadType padding_;

View File

@ -298,6 +298,16 @@ std::pair<std::shared_ptr<vx::Graph>,
MakeShared(t_src->GetShape().size()));
}
auto const_inputs = src_graph->GetConstantInputs();
for (auto const_in : const_inputs) {
auto input =
infer_graph->CreateTensor(const_in->GetSpec(), const_in->GetDataRef());
layout_infer_ctx->UpdateTensorMap(const_in, input);
tensor_queue.push_back(const_in);
layout_infer_ctx->SetPermuteVector(const_in,
MakeShared(const_in->GetShape().size()));
}
while (!tensor_queue.empty()) {
const auto& tensor = tensor_queue.front();
tensor_queue.pop_front();

View File

@ -35,6 +35,15 @@
namespace tim {
namespace vx {
const std::vector<std::shared_ptr<Tensor>> Graph::GetConstantInputs() const {
std::vector<std::shared_ptr<Tensor>> const_inputs;
for (auto op : op_vector_) {
auto const_i = op->ConstantInputsTensor();
const_inputs.insert(const_inputs.end(), const_i.begin(), const_i.end());
}
return const_inputs;
}
GraphImpl::GraphImpl(ContextImpl* context)
: context_(context),
graph_(vsi_nn_CreateGraph(context_->context(), 0, 0)),
@ -91,7 +100,7 @@ void GraphImpl::UpdateTensorProducerMap(const std::shared_ptr<Tensor>& tensor,
const Operation* op) {
for (const auto& added_op : op_vector_) {
if (added_op.get() == op) {
tensor_producer_[tensor].push_back(added_op);
tensor_producer_[tensor] = added_op;
}
}
}
@ -107,7 +116,7 @@ const std::vector<std::shared_ptr<Operation>> GraphImpl::GetConsumersOp(
}
}
std::vector<std::shared_ptr<Operation>> GraphImpl::GetProducerOp(
std::shared_ptr<Operation> GraphImpl::GetProducerOp(
std::shared_ptr<Tensor> tensor) {
auto producer = tensor_producer_.find(tensor);
if (tensor_producer_.end() != producer) {

View File

@ -60,7 +60,7 @@ class GraphImpl : public Graph {
const Operation* op) override;
const std::vector<std::shared_ptr<Operation>> GetConsumersOp(
std::shared_ptr<Tensor> tensor) const override;
std::vector<std::shared_ptr<Operation>> GetProducerOp(
std::shared_ptr<Operation> GetProducerOp(
std::shared_ptr<Tensor> tensor) override;
void PrintGraph() const override;
@ -87,7 +87,7 @@ class GraphImpl : public Graph {
std::vector<std::shared_ptr<Tensor>> inputs_tensor_;
std::vector<std::shared_ptr<Tensor>> outputs_tensor_;
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>> tensor_consumers_;
std::map<std::shared_ptr<Tensor>, std::vector<std::shared_ptr<Operation>>> tensor_producer_;
std::map<std::shared_ptr<Tensor>, std::shared_ptr<Operation>> tensor_producer_;
};
} // namespace vx

View File

@ -76,5 +76,20 @@ Operation& Operation::BindOutputs(
return *this;
}
bool Operation::IsAllInputsConst() const{
for (auto tensor : impl_->inputs_tensor_) {
if (!tensor->IsConstTensor()) return false;
}
return true;
}
const std::vector<std::shared_ptr<Tensor>> Operation::ConstantInputsTensor() const{
if (this->IsAllInputsConst()) {
return impl_->inputs_tensor_;
} else {
return {};
}
}
} // namespace vx
} // namespace tim

View File

@ -88,6 +88,14 @@ std::shared_ptr<Operation> Conv2d::Clone(std::shared_ptr<Graph>& graph) const {
this->kernel_layout_);
}
const std::vector<std::shared_ptr<Tensor>> Conv2d::ConstantInputsTensor() const {
if (this->IsAllInputsConst()) {
return {this->impl_->inputs_tensor_[0]};
} else {
return {};
}
}
} // namespace ops
} // namespace vx
} // namespace tim