diff options
author | 2018-06-19 12:35:44 -0700 | |
---|---|---|
committer | 2018-06-19 12:38:27 -0700 | |
commit | 5fab6df2788937bee1cce3a4e8f5b9d1db7497ec (patch) | |
tree | ba18594841593a0b2a3eda55c076ca78c7bf0d4e /tensorflow/contrib/lite/kernels/test_util.h | |
parent | 8f19772410ec20010e9930f9765dbd3aaeb06111 (diff) |
Support Variable Tensor API in LSTM Full kernel.
TFLite LSTM now supports 5 inputs, 18 inputs and 20 inputs.
PiperOrigin-RevId: 201222516
Diffstat (limited to 'tensorflow/contrib/lite/kernels/test_util.h')
-rw-r--r-- | tensorflow/contrib/lite/kernels/test_util.h | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index db80c0082c..6dcece4af6 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -126,8 +126,10 @@ class SingleOpModel { SingleOpModel& operator=(const SingleOpModel&) = delete; // Add a TensorType input tensor and return its index. - int AddInput(TensorType type) { return AddInput(TensorData{type}); } - int AddInput(const TensorData& t); + int AddInput(TensorType type, bool is_variable = false) { + return AddInput(TensorData{type}, is_variable); + } + int AddInput(const TensorData& t, bool is_variable = false); // Templated version of AddConstInput(). template <typename T> @@ -260,7 +262,8 @@ class SingleOpModel { } template <typename T> - int AddTensor(TensorData t, std::initializer_list<T> data) { + int AddTensor(TensorData t, std::initializer_list<T> data, + bool is_variable = false) { int id = tensors_.size(); // This is slightly different depending on whether we are adding a @@ -309,7 +312,7 @@ class SingleOpModel { tensors_.push_back(CreateTensor(builder_, builder_.CreateVector<int>(t.shape), t.type, /*buffer=*/buffer_id, - /*name=*/0, q_params)); + /*name=*/0, q_params, is_variable)); tensor_data_[id] = t; |