Fix maxpoolgrad, hide unused pool value output

Type: Bug Fix
Authored by Qin.Chen, 2022-12-01 05:43:27 +00:00; committed by Sven
parent b7478f7872
commit 13da73bbe3
2 changed files with 67 additions and 34 deletions

View File

@@ -37,7 +37,8 @@ namespace ops {
  * ## MaxpoolGrad
  *
  * Acquire the gradient of the 2-D Max pooling operation's input tensor. \
- * Like the tensorflow_XLA op SelectAndScatter, see https://tensorflow.google.cn/xla/operation_semantics?hl=en#selectandscatter.
+ * Like the tensorflow_XLA op SelectAndScatter, see \
+ * https://tensorflow.google.cn/xla/operation_semantics?hl=en#selectandscatter.
  *
  * - padding : AUTO, VALID or SAME.
  * - ksize : filter size.
@@ -48,6 +49,10 @@ namespace ops {
  *
  * - 0 : input tensor of 2-D Max pooling.
  * - 1 : gradient of 2-D Max pooling output tensor.
+ *
+ * * Outputs:
+ *
+ * - 0 : updated tensor of 2-D Max pooling input.
  */
 class MaxpoolGrad: public Operation {
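
For reference, a minimal usage sketch of the interface documented above. It is illustrative only, not code from this commit: the Context/Graph setup, the tensor shapes, and leaving round type and layout at their defaults are assumptions; the two inputs are bound in the documented order (pool input first, then the incoming gradient).

    // Hedged usage sketch (illustrative; shapes and setup are assumed).
    auto context = tim::vx::Context::Create();
    auto graph = context->CreateGraph();

    // WHCN layout assumed: 4x4 input, 2x2 pooling with stride 2 -> 2x2 output.
    tim::vx::TensorSpec in_spec(tim::vx::DataType::FLOAT32,
        {4, 4, 1, 1}, tim::vx::TensorAttribute::INPUT);   // pool input
    tim::vx::TensorSpec grad_spec(tim::vx::DataType::FLOAT32,
        {2, 2, 1, 1}, tim::vx::TensorAttribute::INPUT);   // dL/d(pool output)
    tim::vx::TensorSpec out_spec(tim::vx::DataType::FLOAT32,
        {4, 4, 1, 1}, tim::vx::TensorAttribute::OUTPUT);  // dL/d(pool input)
    auto input = graph->CreateTensor(in_spec);
    auto grad = graph->CreateTensor(grad_spec);
    auto output = graph->CreateTensor(out_spec);

    // Inputs: 0 = pool input, 1 = gradient; output: 0 = updated input tensor.
    auto op = graph->CreateOperation<tim::vx::ops::MaxpoolGrad>(
        tim::vx::PadType::VALID,
        /*ksize=*/std::array<uint32_t, 2>({2, 2}),
        /*stride=*/std::array<uint32_t, 2>({2, 2}));
    (*op).BindInputs({input, grad}).BindOutput(output);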

View File

@@ -36,10 +36,11 @@ namespace ops {
 class MaxpoolGradImpl : public OpImpl {
  public:
   enum {
-    TENSOR_BEFORE_POOL = 0,
-    UPDATES_TENSOR,
-    INPUT_CNT,
-    OUT_CNT = 1,
+    POOL_INPUT_TENSOR = 0,
+    GRADIENT_TENSOR = 1,
+    INPUT_CNT = 2,
+    UPDATED_TENSOR = 0,
+    OUTPUT_CNT = 1,
   };
   MaxpoolGradImpl(Graph* graph, PadType padding,
                   const std::array<uint32_t, 2>& ksize,
@@ -52,52 +53,73 @@ class MaxpoolGradImpl : public OpImpl {
         ksize_(ksize),
         stride_(stride),
         round_type_(round_type) {
-    maxpoolwithargmax2_ = graph->CreateOperation<tim::vx::ops::MaxpoolWithArgmax2>(
-        padding_, ksize_, stride_, round_type_, layout_);
+    maxpoolwithargmax2_ =
+        graph->CreateOperation<tim::vx::ops::MaxpoolWithArgmax2>(
+            padding_, ksize_, stride_, round_type_, layout_);
   }
   ~MaxpoolGradImpl() {}
   MaxpoolGradImpl& BindInput(const std::shared_ptr<Tensor>& tensor) override {
     in_tensors_[input_tensor_index] = tensor;
     if (this->input_tensor_index == INPUT_CNT - 1) {
-      tim::vx::ShapeType in_shape = in_tensors_[TENSOR_BEFORE_POOL]->GetShape();
-      tim::vx::ShapeType updates_shape = in_tensors_[UPDATES_TENSOR]->GetShape();
-      tim::vx::ShapeType idx_flattened_shape({CalFlattenedShape(updates_shape)});
+      tim::vx::ShapeType in_shape = in_tensors_[POOL_INPUT_TENSOR]->GetShape();
+      tim::vx::ShapeType grad_shape = in_tensors_[GRADIENT_TENSOR]->GetShape();
+      tim::vx::ShapeType idx_flattened_shape({CalFlattenedShape(grad_shape)});
       tim::vx::ShapeType out_flattened_shape({CalFlattenedShape(in_shape)});
-      tim::vx::TensorSpec pool_out_spec_indices(tim::vx::DataType::INT32,
-          updates_shape, tim::vx::TensorAttribute::TRANSIENT);
       tim::vx::TensorSpec pool_out_spec_values(tim::vx::DataType::FLOAT32,
-          updates_shape, tim::vx::TensorAttribute::OUTPUT);
+          grad_shape, tim::vx::TensorAttribute::TRANSIENT);
+      tim::vx::TensorSpec pool_out_spec_indices(tim::vx::DataType::INT32,
+          grad_shape, tim::vx::TensorAttribute::TRANSIENT);
       tim::vx::TensorSpec idx_flattened_spec(tim::vx::DataType::INT32,
          idx_flattened_shape, tim::vx::TensorAttribute::TRANSIENT);
       tim::vx::TensorSpec upd_flattened_spec(tim::vx::DataType::FLOAT32,
          idx_flattened_shape, tim::vx::TensorAttribute::TRANSIENT);
       tim::vx::TensorSpec out_flattened_spec(tim::vx::DataType::FLOAT32,
          out_flattened_shape, tim::vx::TensorAttribute::TRANSIENT);
-      auto pool_out_indices_tensor = graph_->CreateTensor(pool_out_spec_indices);
       auto pool_out_values_tensor = graph_->CreateTensor(pool_out_spec_values);
+      auto pool_out_indices_tensor = graph_->CreateTensor(pool_out_spec_indices);
       auto idx_flattened_tensor = graph_->CreateTensor(idx_flattened_spec);
       auto upd_flattened_tensor = graph_->CreateTensor(upd_flattened_spec);
       auto out_flattened_tensor = graph_->CreateTensor(out_flattened_spec);
-      (*maxpoolwithargmax2_).BindInput(in_tensors_[TENSOR_BEFORE_POOL])
-          .BindOutputs({pool_out_values_tensor, pool_out_indices_tensor});
-      flatten_idx = graph_->CreateOperation<tim::vx::ops::Reshape>(idx_flattened_shape);
-      (*flatten_idx).BindInput(pool_out_indices_tensor).BindOutput(idx_flattened_tensor);
-      flatten_upd = graph_->CreateOperation<tim::vx::ops::Reshape>(idx_flattened_shape);
-      (*flatten_upd).BindInput(in_tensors_[UPDATES_TENSOR]).BindOutput(upd_flattened_tensor);
+      (*maxpoolwithargmax2_).BindInput(in_tensors_[POOL_INPUT_TENSOR])
+          .BindOutputs({pool_out_values_tensor, pool_out_indices_tensor});
+      // eliminate pool out of maxpoolwithargmax begin
+      tim::vx::TensorSpec sliced_spec(tim::vx::DataType::FLOAT32,
+          {1, 1, 1, 1}, tim::vx::TensorAttribute::TRANSIENT);
+      auto sliced_tensor = graph_->CreateTensor(sliced_spec);
+      auto one_zero_tensor = graph_->CreateTensor(sliced_spec);
+      auto grad_tensor = graph_->CreateTensor(pool_out_spec_values);
+      std::vector<int32_t> start = {0, 0, 0, 0};
+      std::vector<int32_t> length = {1, 1, 1, 1};
+      slice_one_ = graph_->CreateOperation<tim::vx::ops::Slice>(0, start, length);
+      (*slice_one_).BindInput(pool_out_values_tensor).BindOutput(sliced_tensor);
+      self_sub_ = graph_->CreateOperation<tim::vx::ops::Sub>();
+      (*self_sub_).BindInputs({sliced_tensor, sliced_tensor}).BindOutput(one_zero_tensor);
+      add_zeros_ = graph_->CreateOperation<tim::vx::ops::Add>();
+      (*add_zeros_).BindInputs({one_zero_tensor, in_tensors_[GRADIENT_TENSOR]})
+          .BindOutput(grad_tensor);
+      // eliminate pool out of maxpoolwithargmax end
+      flatten_idx_ = graph_->CreateOperation<tim::vx::ops::Reshape>(idx_flattened_shape);
+      (*flatten_idx_).BindInput(pool_out_indices_tensor).BindOutput(idx_flattened_tensor);
+      flatten_upd_ = graph_->CreateOperation<tim::vx::ops::Reshape>(idx_flattened_shape);
+      (*flatten_upd_).BindInput(grad_tensor).BindOutput(upd_flattened_tensor);
       scatternd_ = graph_->CreateOperation<tim::vx::ops::ScatterND>(out_flattened_shape);
-      (*scatternd_).BindInputs({idx_flattened_tensor, upd_flattened_tensor}).BindOutput(out_flattened_tensor);
+      (*scatternd_).BindInputs({idx_flattened_tensor, upd_flattened_tensor})
+          .BindOutput(out_flattened_tensor);
       reshape_like_input_ = graph_->CreateOperation<tim::vx::ops::Reshape>(in_shape);
       (*reshape_like_input_).BindInput(out_flattened_tensor);
     }
     this->input_tensor_index++;
     return *this;
@@ -105,7 +127,9 @@ class MaxpoolGradImpl : public OpImpl {
   MaxpoolGradImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) override {
     out_tensors_[output_tensor_index] = tensor;
-    (*reshape_like_input_).BindOutput(tensor);
+    if (this->output_tensor_index == OUTPUT_CNT - 1) {
+      (*reshape_like_input_).BindOutput(out_tensors_[UPDATED_TENSOR]);
+    }
     this->output_tensor_index++;
     return *this;
   }
@@ -126,12 +150,16 @@ class MaxpoolGradImpl : public OpImpl {
   const RoundType round_type_;
   std::shared_ptr<tim::vx::Operation> maxpoolwithargmax2_;
-  std::shared_ptr<tim::vx::Operation> flatten_idx;
-  std::shared_ptr<tim::vx::Operation> flatten_upd;
+  std::shared_ptr<tim::vx::Operation> slice_one_;
+  std::shared_ptr<tim::vx::Operation> self_sub_;
+  std::shared_ptr<tim::vx::Operation> add_zeros_;
+  std::shared_ptr<tim::vx::Operation> flatten_idx_;
+  std::shared_ptr<tim::vx::Operation> flatten_upd_;
   std::shared_ptr<tim::vx::Operation> scatternd_;
   std::shared_ptr<tim::vx::Operation> reshape_like_input_;
   std::shared_ptr<tim::vx::Operation> reshape_pool_output_;
   std::array<std::shared_ptr<tim::vx::Tensor>, INPUT_CNT> in_tensors_;
-  std::array<std::shared_ptr<tim::vx::Tensor>, OUT_CNT> out_tensors_;
+  std::array<std::shared_ptr<tim::vx::Tensor>, OUTPUT_CNT> out_tensors_;
   uint32_t CalFlattenedShape(const tim::vx::ShapeType& shape) {
     uint32_t out = 1;
     for(auto& x: shape) {
@@ -150,7 +178,7 @@ MaxpoolGrad::MaxpoolGrad(Graph* graph, PadType padding,
       ksize_(ksize),
       stride_(stride),
       round_type_(round_type) {
-  impl_ = std::make_unique<MaxpoolGradImpl>(graph, padding, ksize, stride, 0, 0, round_type, layout);
+  impl_ = std::make_unique<MaxpoolGradImpl>(graph, padding, ksize, stride, 2, 1, round_type, layout);
 }
 std::shared_ptr<Operation> MaxpoolGrad::Clone(
@@ -164,4 +192,4 @@ std::shared_ptr<Operation> MaxpoolGrad::Clone(
 } // namespace vx
 } // namespace tim
-#endif //(VSI_FEAT_OP_MAXPOOLWITHARGMAX)
\ No newline at end of file
+#endif //(VSI_FEAT_OP_MAXPOOLWITHARGMAX)
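
The heart of the fix is the Slice/Sub/Add chain added in BindInput. Previously, pool_out_spec_values carried TensorAttribute::OUTPUT, so every MaxpoolGrad exposed the pooled values of the internal MaxpoolWithArgmax2 as an extra, unused graph output. The commit demotes that tensor to TRANSIENT and, presumably because a transient tensor still needs a consumer inside the graph, routes it through an arithmetic no-op: one element is sliced out, subtracted from itself to give zero, and the {1, 1, 1, 1} zero is broadcast-added onto the incoming gradient, so grad_tensor equals the incoming gradient exactly (x - x = 0, then 0 + g = g). A condensed sketch of just that subgraph follows; `values`, `incoming_grad`, and `grad_spec` are stand-ins for the tensors created in BindInput, not code from this commit.

    // Sketch of the "hide the values output" trick (names are stand-ins).
    tim::vx::TensorSpec one_spec(tim::vx::DataType::FLOAT32,
        {1, 1, 1, 1}, tim::vx::TensorAttribute::TRANSIENT);
    auto sliced = graph->CreateTensor(one_spec);     // one element of `values`
    auto zero = graph->CreateTensor(one_spec);       // holds sliced - sliced == 0
    auto fwd_grad = graph->CreateTensor(grad_spec);  // numerically == incoming_grad

    std::vector<int32_t> start = {0, 0, 0, 0};
    std::vector<int32_t> length = {1, 1, 1, 1};
    auto slice = graph->CreateOperation<tim::vx::ops::Slice>(0, start, length);
    (*slice).BindInput(values).BindOutput(sliced);           // consume `values`
    auto sub = graph->CreateOperation<tim::vx::ops::Sub>();
    (*sub).BindInputs({sliced, sliced}).BindOutput(zero);    // x - x = 0
    auto add = graph->CreateOperation<tim::vx::ops::Add>();  // broadcasts the zero
    (*add).BindInputs({zero, incoming_grad}).BindOutput(fwd_grad);  // 0 + g = g

The rest of the pipeline (Reshape -> ScatterND -> Reshape) then consumes fwd_grad in place of the raw gradient input, which is why flatten_upd_ now binds grad_tensor rather than in_tensors_[GRADIENT_TENSOR].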