modified cumsum header && resolve conflict in README.md

This commit is contained in:
Feiyue Chen 2022-09-22 12:45:52 +08:00 committed by Sven
parent 264e491d2a
commit 1802e558ad
5 changed files with 17 additions and 6 deletions

View File

@ -49,10 +49,16 @@ class Operation {
std::unique_ptr<OpImpl>& impl();
const std::unique_ptr<OpImpl>& impl() const;
virtual const std::vector<std::shared_ptr<Tensor>> ConstantInputsTensor() const;
virtual void HandleAfterBindInput(const std::shared_ptr<Tensor>& tensor, int32_t input_idx);
protected:
bool IsAllInputsConst() const;
std::unique_ptr<OpImpl> impl_;
private:
// Post processing at the final step on BindInput func
// - tensor : input tensor
// - input_idx: the index of input tensor
virtual void OnBindInputPostProc(const std::shared_ptr<Tensor>& tensor, int32_t input_idx);
};
} // namespace vx

View File

@ -47,11 +47,14 @@ namespace ops {
class CumSum : public BuiltinOp {
public:
CumSum(Graph* Graph, int32_t axis=0, int32_t exclusive=0, int32_t reverse=0);
std::shared_ptr<Operation> Clone(std::shared_ptr<Graph>& graph) const override;
void HandleAfterBindInput(const std::shared_ptr<Tensor>& tensor, int32_t input_idx) override;
protected:
int32_t axis_, exclusive_, reverse_;
private:
void OnBindInputPostProc(const std::shared_ptr<Tensor>& tensor, int32_t input_idx) override;
};
} // namespace ops

View File

@ -42,7 +42,7 @@ const std::unique_ptr<OpImpl>& Operation::impl() const { return impl_; }
Operation& Operation::BindInput(const std::shared_ptr<Tensor>& tensor) {
impl_->BindInput(tensor);
impl_->graph_->UpdateTensorConsumersMap(tensor, this);
HandleAfterBindInput(tensor, impl_->input_tensor_index - 1);
OnBindInputPostProc(tensor, impl_->input_tensor_index - 1);
return *this;
}
@ -90,7 +90,7 @@ const std::vector<std::shared_ptr<Tensor>> Operation::ConstantInputsTensor() con
return {};
}
}
void Operation::HandleAfterBindInput(const std::shared_ptr<Tensor>& tensor, int32_t input_idx){
void Operation::OnBindInputPostProc(const std::shared_ptr<Tensor>& tensor, int32_t input_idx){
(void) tensor;
(void) input_idx;
}

View File

@ -116,10 +116,12 @@ GroupedConv1d|GROUPED_CONV1D|Mapped|[tf.keras.layers.Conv1D](https://tensorflow.
Mod|MOD|Mapped|[Onnx.Mod](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Mod)
Selu|SELU|Mapped|[tf.keras.activations.selu](https://www.tensorflow.org/api_docs/python/tf/keras/activations/selu)
Celu|CELU|Mapped|[Onnx.celu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Celu)
Sign|SIGN|Mapped|[tf.math.sign](https://www.tensorflow.org/api_docs/python/tf/math/sign)
SoftSign|SOFTSIGN|Mapped|[tf.keras.activations.softsign](https://www.tensorflow.org/api_docs/python/tf/keras/activations/softsign)
CumSum|CUMSUM|Mapped|[tf.math.cumsum](https://www.tensorflow.org/api_docs/python/tf/math/cumsum)
Rcp|RCP|Mapped|[tf.math.reciprocal](https://www.tensorflow.org/api_docs/python/tf/math/reciprocal)
MaxPool3d|MAX_POOL3D|Mapped|[Onnx.MaxPool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool)
|UnidirectionalSequenceRNN|UNIDIRECTIONAL_SEQUENCE_RNN|Planned 22Q3|[ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0ae11aa1d461d2abaa117f6ee2cb503dd8)
CumSum|CUMSUM|Mapped|[tf.math.cumsum](https://www.tensorflow.org/api_docs/python/tf/math/cumsum)
|BidirectionalSequenceRNN|BIDIRECTIONAL_SEQUENCE_RNN|Planned 22Q3|[ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_RNN](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a487fc5ae247de828f13e62b99f259f3c)
|BidirectionalSequenceLSTM|BIDIRECTIONAL_SEQUENCE_LSTM|Mapped|[ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0a492a71cb7aa50b9a1a834a3cb269d778)
|UnidirectionalSequenceLSTM|LSTM_OVXLIB|Mapped|[ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_LSTM](https://developer.android.com/ndk/reference/group/neural-networks#group___neural_networks_1ggaabbe492c60331b13038e39d4207940e0aaf30e491ad0b1fc7602cbde695b2c859)

View File

@ -37,7 +37,7 @@ CumSum::CumSum(Graph* graph, int32_t axis, int32_t exclusive, int32_t reverse)
this->impl()->node()->nn_param.cumsum.reverse = reverse_;
}
void CumSum::HandleAfterBindInput(const std::shared_ptr<Tensor>& tensor, int32_t input_idx){
void CumSum::OnBindInputPostProc(const std::shared_ptr<Tensor>& tensor, int32_t input_idx){
if (axis_ < 0){
axis_ += tensor->GetShape().size();
(void) input_idx;