Fix maxpoolgrad, hide unused pool value output
Type: Bug Fix
This commit is contained in:
parent
b7478f7872
commit
13da73bbe3
|
|
@ -37,7 +37,8 @@ namespace ops {
|
|||
* ## MaxpoolGrad
|
||||
*
|
||||
* Computes the gradient of a 2-D max pooling operation with respect to its input tensor. \
|
||||
* Like the tensorflow_XLA op SelectAndScatter, see https://tensorflow.google.cn/xla/operation_semantics?hl=en#selectandscatter.
|
||||
* Like the tensorflow_XLA op SelectAndScatter, see \
|
||||
* https://tensorflow.google.cn/xla/operation_semantics?hl=en#selectandscatter.
|
||||
*
|
||||
* - padding : AUTO, VALID or SAME.
|
||||
* - ksize : filter size.
|
||||
|
|
@ -48,6 +49,10 @@ namespace ops {
|
|||
*
|
||||
* - 0 : input tensor of 2-D Max pooling.
|
||||
* - 1 : gradient of 2-D Max pooling output tensor.
|
||||
*
|
||||
* * Outputs:
|
||||
*
|
||||
* - 0 : updated tensor of 2-D Max pooling input.
|
||||
*/
|
||||
|
||||
class MaxpoolGrad: public Operation {
|
||||
|
|
|
|||
|
|
@ -36,10 +36,11 @@ namespace ops {
|
|||
class MaxpoolGradImpl : public OpImpl {
|
||||
public:
|
||||
enum {
|
||||
TENSOR_BEFORE_POOL = 0,
|
||||
UPDATES_TENSOR,
|
||||
INPUT_CNT,
|
||||
OUT_CNT = 1,
|
||||
POOL_INPUT_TENSOR = 0,
|
||||
GRADIENT_TENSOR = 1,
|
||||
INPUT_CNT = 2,
|
||||
UPDATED_TENSOR = 0,
|
||||
OUTPUT_CNT = 1,
|
||||
};
|
||||
MaxpoolGradImpl(Graph* graph, PadType padding,
|
||||
const std::array<uint32_t, 2>& ksize,
|
||||
|
|
@ -52,52 +53,73 @@ class MaxpoolGradImpl : public OpImpl {
|
|||
ksize_(ksize),
|
||||
stride_(stride),
|
||||
round_type_(round_type) {
|
||||
maxpoolwithargmax2_ = graph->CreateOperation<tim::vx::ops::MaxpoolWithArgmax2>(
|
||||
padding_, ksize_, stride_, round_type_, layout_);
|
||||
maxpoolwithargmax2_ =
|
||||
graph->CreateOperation<tim::vx::ops::MaxpoolWithArgmax2>(
|
||||
padding_, ksize_, stride_, round_type_, layout_);
|
||||
}
|
||||
~MaxpoolGradImpl() {}
|
||||
|
||||
|
||||
|
||||
MaxpoolGradImpl& BindInput(const std::shared_ptr<Tensor>& tensor) override {
|
||||
in_tensors_[input_tensor_index] = tensor;
|
||||
if (this->input_tensor_index == INPUT_CNT - 1) {
|
||||
tim::vx::ShapeType in_shape = in_tensors_[TENSOR_BEFORE_POOL]->GetShape();
|
||||
tim::vx::ShapeType updates_shape = in_tensors_[UPDATES_TENSOR]->GetShape();
|
||||
tim::vx::ShapeType idx_flattened_shape({CalFlattenedShape(updates_shape)});
|
||||
tim::vx::ShapeType in_shape = in_tensors_[POOL_INPUT_TENSOR]->GetShape();
|
||||
tim::vx::ShapeType grad_shape = in_tensors_[GRADIENT_TENSOR]->GetShape();
|
||||
tim::vx::ShapeType idx_flattened_shape({CalFlattenedShape(grad_shape)});
|
||||
tim::vx::ShapeType out_flattened_shape({CalFlattenedShape(in_shape)});
|
||||
|
||||
tim::vx::TensorSpec pool_out_spec_indices(tim::vx::DataType::INT32,
|
||||
updates_shape, tim::vx::TensorAttribute::TRANSIENT);
|
||||
tim::vx::TensorSpec pool_out_spec_values(tim::vx::DataType::FLOAT32,
|
||||
updates_shape, tim::vx::TensorAttribute::OUTPUT);
|
||||
grad_shape, tim::vx::TensorAttribute::TRANSIENT);
|
||||
tim::vx::TensorSpec pool_out_spec_indices(tim::vx::DataType::INT32,
|
||||
grad_shape, tim::vx::TensorAttribute::TRANSIENT);
|
||||
tim::vx::TensorSpec idx_flattened_spec(tim::vx::DataType::INT32,
|
||||
idx_flattened_shape, tim::vx::TensorAttribute::TRANSIENT);
|
||||
idx_flattened_shape,tim::vx::TensorAttribute::TRANSIENT);
|
||||
tim::vx::TensorSpec upd_flattened_spec(tim::vx::DataType::FLOAT32,
|
||||
idx_flattened_shape, tim::vx::TensorAttribute::TRANSIENT);
|
||||
idx_flattened_shape, tim::vx::TensorAttribute::TRANSIENT);
|
||||
tim::vx::TensorSpec out_flattened_spec(tim::vx::DataType::FLOAT32,
|
||||
out_flattened_shape, tim::vx::TensorAttribute::TRANSIENT);
|
||||
|
||||
auto pool_out_indices_tensor = graph_->CreateTensor(pool_out_spec_indices);
|
||||
out_flattened_shape, tim::vx::TensorAttribute::TRANSIENT);
|
||||
|
||||
auto pool_out_values_tensor = graph_->CreateTensor(pool_out_spec_values);
|
||||
auto pool_out_indices_tensor = graph_->CreateTensor(pool_out_spec_indices);
|
||||
auto idx_flattened_tensor = graph_->CreateTensor(idx_flattened_spec);
|
||||
auto upd_flattened_tensor = graph_->CreateTensor(upd_flattened_spec);
|
||||
auto out_flattened_tensor = graph_->CreateTensor(out_flattened_spec);
|
||||
|
||||
(*maxpoolwithargmax2_).BindInput(in_tensors_[TENSOR_BEFORE_POOL])
|
||||
.BindOutputs({pool_out_values_tensor, pool_out_indices_tensor});
|
||||
|
||||
flatten_idx = graph_->CreateOperation<tim::vx::ops::Reshape>(idx_flattened_shape);
|
||||
(*flatten_idx).BindInput(pool_out_indices_tensor).BindOutput(idx_flattened_tensor);
|
||||
(*maxpoolwithargmax2_).BindInput(in_tensors_[POOL_INPUT_TENSOR])
|
||||
.BindOutputs({pool_out_values_tensor, pool_out_indices_tensor});
|
||||
|
||||
flatten_upd = graph_->CreateOperation<tim::vx::ops::Reshape>(idx_flattened_shape);
|
||||
(*flatten_upd).BindInput(in_tensors_[UPDATES_TENSOR]).BindOutput(upd_flattened_tensor);
|
||||
// eliminate pool out of maxpoolwithargmax begin
|
||||
tim::vx::TensorSpec sliced_spec(tim::vx::DataType::FLOAT32,
|
||||
{1, 1, 1, 1}, tim::vx::TensorAttribute::TRANSIENT);
|
||||
auto sliced_tensor = graph_->CreateTensor(sliced_spec);
|
||||
auto one_zero_tensor = graph_->CreateTensor(sliced_spec);
|
||||
auto grad_tensor = graph_->CreateTensor(pool_out_spec_values);
|
||||
|
||||
std::vector<int32_t> start = {0, 0, 0, 0};
|
||||
std::vector<int32_t> length = {1, 1, 1, 1};
|
||||
slice_one_ = graph_->CreateOperation<tim::vx::ops::Slice>(0, start, length);
|
||||
(*slice_one_).BindInput(pool_out_values_tensor).BindOutput(sliced_tensor);
|
||||
|
||||
self_sub_ = graph_->CreateOperation<tim::vx::ops::Sub>();
|
||||
(*self_sub_).BindInputs({sliced_tensor, sliced_tensor}).BindOutput(one_zero_tensor);
|
||||
|
||||
add_zeros_ = graph_->CreateOperation<tim::vx::ops::Add>();
|
||||
(*add_zeros_).BindInputs({one_zero_tensor, in_tensors_[GRADIENT_TENSOR]})
|
||||
.BindOutput(grad_tensor);
|
||||
// eliminate pool out of maxpoolwithargmax end
|
||||
|
||||
flatten_idx_ = graph_->CreateOperation<tim::vx::ops::Reshape>(idx_flattened_shape);
|
||||
(*flatten_idx_).BindInput(pool_out_indices_tensor).BindOutput(idx_flattened_tensor);
|
||||
|
||||
flatten_upd_ = graph_->CreateOperation<tim::vx::ops::Reshape>(idx_flattened_shape);
|
||||
(*flatten_upd_).BindInput(grad_tensor).BindOutput(upd_flattened_tensor);
|
||||
|
||||
scatternd_ = graph_->CreateOperation<tim::vx::ops::ScatterND>(out_flattened_shape);
|
||||
(*scatternd_).BindInputs({idx_flattened_tensor, upd_flattened_tensor}).BindOutput(out_flattened_tensor);
|
||||
(*scatternd_).BindInputs({idx_flattened_tensor, upd_flattened_tensor})
|
||||
.BindOutput(out_flattened_tensor);
|
||||
|
||||
reshape_like_input_ = graph_->CreateOperation<tim::vx::ops::Reshape>(in_shape);
|
||||
(*reshape_like_input_).BindInput(out_flattened_tensor);
|
||||
|
||||
}
|
||||
this->input_tensor_index++;
|
||||
return *this;
|
||||
|
|
@ -105,7 +127,9 @@ class MaxpoolGradImpl : public OpImpl {
|
|||
|
||||
MaxpoolGradImpl& BindOutput(const std::shared_ptr<Tensor>& tensor) override {
|
||||
out_tensors_[output_tensor_index] = tensor;
|
||||
(*reshape_like_input_).BindOutput(tensor);
|
||||
if (this->output_tensor_index == OUTPUT_CNT - 1) {
|
||||
(*reshape_like_input_).BindOutput(out_tensors_[UPDATED_TENSOR]);
|
||||
}
|
||||
this->output_tensor_index++;
|
||||
return *this;
|
||||
}
|
||||
|
|
@ -126,12 +150,16 @@ class MaxpoolGradImpl : public OpImpl {
|
|||
const RoundType round_type_;
|
||||
|
||||
std::shared_ptr<tim::vx::Operation> maxpoolwithargmax2_;
|
||||
std::shared_ptr<tim::vx::Operation> flatten_idx;
|
||||
std::shared_ptr<tim::vx::Operation> flatten_upd;
|
||||
std::shared_ptr<tim::vx::Operation> slice_one_;
|
||||
std::shared_ptr<tim::vx::Operation> self_sub_;
|
||||
std::shared_ptr<tim::vx::Operation> add_zeros_;
|
||||
std::shared_ptr<tim::vx::Operation> flatten_idx_;
|
||||
std::shared_ptr<tim::vx::Operation> flatten_upd_;
|
||||
std::shared_ptr<tim::vx::Operation> scatternd_;
|
||||
std::shared_ptr<tim::vx::Operation> reshape_like_input_;
|
||||
std::shared_ptr<tim::vx::Operation> reshape_pool_output_;
|
||||
std::array<std::shared_ptr<tim::vx::Tensor>, INPUT_CNT> in_tensors_;
|
||||
std::array<std::shared_ptr<tim::vx::Tensor>, OUT_CNT> out_tensors_;
|
||||
std::array<std::shared_ptr<tim::vx::Tensor>, OUTPUT_CNT> out_tensors_;
|
||||
uint32_t CalFlattenedShape(const tim::vx::ShapeType& shape) {
|
||||
uint32_t out = 1;
|
||||
for(auto& x: shape) {
|
||||
|
|
@ -150,7 +178,7 @@ MaxpoolGrad::MaxpoolGrad(Graph* graph, PadType padding,
|
|||
ksize_(ksize),
|
||||
stride_(stride),
|
||||
round_type_(round_type) {
|
||||
impl_ = std::make_unique<MaxpoolGradImpl>(graph, padding, ksize, stride, 0, 0, round_type, layout);
|
||||
impl_ = std::make_unique<MaxpoolGradImpl>(graph, padding, ksize, stride, 2, 1, round_type, layout);
|
||||
}
|
||||
|
||||
std::shared_ptr<Operation> MaxpoolGrad::Clone(
|
||||
|
|
@ -164,4 +192,4 @@ std::shared_ptr<Operation> MaxpoolGrad::Clone(
|
|||
} // namespace vx
|
||||
} // namespace tim
|
||||
|
||||
#endif //(VSI_FEAT_OP_MAXPOOLWITHARGMAX)
|
||||
#endif //(VSI_FEAT_OP_MAXPOOLWITHARGMAX)
|
||||
|
|
|
|||
Loading…
Reference in New Issue