From fe31a47bf9c6f9359fb27947ac4a8a6a6354851c Mon Sep 17 00:00:00 2001 From: liyuenan <37231553+liyuenan2333@users.noreply.github.com> Date: Mon, 21 Feb 2022 19:09:38 +0800 Subject: [PATCH] enable no bias in FC layout inference (#294) Signed-off-by: yuenan.li Co-authored-by: yuenan.li --- .../ops/fullyconnected_layout_inference.h | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/tim/transform/ops/fullyconnected_layout_inference.h b/src/tim/transform/ops/fullyconnected_layout_inference.h index b86c53d..ac5cd96 100644 --- a/src/tim/transform/ops/fullyconnected_layout_inference.h +++ b/src/tim/transform/ops/fullyconnected_layout_inference.h @@ -53,19 +53,15 @@ class FullyConnectedLayoutInfer : public OpLayoutInfer { context_->SetPermuteVector(in, trans_pv); } } - uint32_t axis = op_->impl()->node()->nn_param.fcl.axis; - uint32_t weight = op_->impl()->node()->nn_param.fcl.weights; - auto fcl = context_->infer_graph_->CreateOperation( - axis, weight); + auto fcl = op_->Clone(context_->infer_graph_); auto required_pv = MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size()); auto out_infer = CreateOutputsTensor(required_pv); - (*fcl) - .BindInputs({context_->GetMapedTensor(op_->impl()->InputsTensor()[0]), - context_->GetMapedTensor(op_->impl()->InputsTensor()[1]), - context_->GetMapedTensor(op_->impl()->InputsTensor()[2])}) - .BindOutput(out_infer[0]); + for (auto in : op_->impl()->InputsTensor()) { + (*fcl).BindInput(context_->GetMapedTensor(in)); + } + (*fcl).BindOutput(out_infer[0]); context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], required_pv); next_tensors.push_back(op_->impl()->OutputsTensor()[0]); }