aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/test_util.h
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-06-19 12:35:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-19 12:38:27 -0700
commit5fab6df2788937bee1cce3a4e8f5b9d1db7497ec (patch)
treeba18594841593a0b2a3eda55c076ca78c7bf0d4e /tensorflow/contrib/lite/kernels/test_util.h
parent8f19772410ec20010e9930f9765dbd3aaeb06111 (diff)
Support Variable Tensor API in LSTM Full kernel.
TFLite LSTM now supports 5 inputs, 18 inputs and 20 inputs. PiperOrigin-RevId: 201222516
Diffstat (limited to 'tensorflow/contrib/lite/kernels/test_util.h')
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.h11
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index db80c0082c..6dcece4af6 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -126,8 +126,10 @@ class SingleOpModel {
SingleOpModel& operator=(const SingleOpModel&) = delete;
// Add a TensorType input tensor and return its index.
- int AddInput(TensorType type) { return AddInput(TensorData{type}); }
- int AddInput(const TensorData& t);
+ int AddInput(TensorType type, bool is_variable = false) {
+ return AddInput(TensorData{type}, is_variable);
+ }
+ int AddInput(const TensorData& t, bool is_variable = false);
// Templated version of AddConstInput().
template <typename T>
@@ -260,7 +262,8 @@ class SingleOpModel {
}
template <typename T>
- int AddTensor(TensorData t, std::initializer_list<T> data) {
+ int AddTensor(TensorData t, std::initializer_list<T> data,
+ bool is_variable = false) {
int id = tensors_.size();
// This is slightly different depending on whether we are adding a
@@ -309,7 +312,7 @@ class SingleOpModel {
tensors_.push_back(CreateTensor(builder_,
builder_.CreateVector<int>(t.shape), t.type,
/*buffer=*/buffer_id,
- /*name=*/0, q_params));
+ /*name=*/0, q_params, is_variable));
tensor_data_[id] = t;