aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-09-27 12:51:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 12:55:44 -0700
commit1084594657a5d139102ac794f84d1427a710e39a (patch)
tree108ac1628f966a213cf63b9f24e0f8b4add20d1e /tensorflow
parent750466c6e6624d279de7f9a43accd682d487509c (diff)
TFLite: Rename ResetVariableTensorsToZero -> ResetVariableTensors
PiperOrigin-RevId: 214820383
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental.cc5
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental.h2
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc2
-rw-r--r--tensorflow/contrib/lite/interpreter.cc9
-rw-r--r--tensorflow/contrib/lite/interpreter.h7
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.cc2
-rw-r--r--tensorflow/contrib/lite/python/interpreter.py4
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc4
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h2
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc2
10 files changed, 20 insertions, 19 deletions
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
index 0f16595811..29f8701f53 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
@@ -21,9 +21,8 @@ limitations under the License.
extern "C" {
#endif // __cplusplus
-TFL_Status TFL_InterpreterResetVariableTensorsToZero(
- TFL_Interpreter* interpreter) {
- return interpreter->impl->ResetVariableTensorsToZero();
+TFL_Status TFL_InterpreterResetVariableTensors(TFL_Interpreter* interpreter) {
+ return interpreter->impl->ResetVariableTensors();
}
void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options,
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
index b8de7b9964..fca5d92f77 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
@@ -25,7 +25,7 @@ extern "C" {
typedef TfLiteBuiltinOperator TFL_BuiltinOperator;
// Resets all variable tensors to zero.
-TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensorsToZero(
+TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensors(
TFL_Interpreter* interpreter);
// Adds an op registration for a builtin operator.
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
index d86ad00d6d..1b1bedb754 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
@@ -44,7 +44,7 @@ TEST(CApiExperimentalSimple, Smoke) {
TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);
ASSERT_NE(interpreter, nullptr);
ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk);
- EXPECT_EQ(TFL_InterpreterResetVariableTensorsToZero(interpreter), kTfLiteOk);
+ EXPECT_EQ(TFL_InterpreterResetVariableTensors(interpreter), kTfLiteOk);
EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk);
TFL_DeleteInterpreter(interpreter);
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 2657bcd42b..88e41ffc55 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -451,16 +451,15 @@ TfLiteStatus Interpreter::AllocateTensors() {
// Reset the variable tensors to zero after (re)allocating the tensors.
// Developers shouldn't rely on the side effect of this function to reset
- // variable tesnsors. They should call `ResetVariableTensorsToZero` directly
+ // variable tesnsors. They should call `ResetVariableTensors` directly
// instead.
- ResetVariableTensorsToZero();
+ ResetVariableTensors();
return kTfLiteOk;
}
-// TODO(ycling): Consider to provide other functions to initialize variable
-// tensors to non-zero values.
-TfLiteStatus Interpreter::ResetVariableTensorsToZero() {
+// TODO(ycling): Support non-zero default values.
+TfLiteStatus Interpreter::ResetVariableTensors() {
for (auto& tensor : tensors_) {
if (!tensor.is_variable) {
continue;
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index aa2bc4def6..7ef736d01b 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -421,9 +421,12 @@ class Interpreter {
allow_buffer_handle_output_ = allow_buffer_handle_output;
}
- // Reset all variable tensors to zero.
+ // Reset all variable tensors to the default value.
+ // If a variable tensor doesn't have a buffer, reset it to zero.
+ // TODO(b/115961645): Implement - If a variable tensor has a buffer, reset it
+ // to the value of the buffer.
// WARNING: This is an experimental API and subject to change.
- TfLiteStatus ResetVariableTensorsToZero();
+ TfLiteStatus ResetVariableTensors();
// Retrieve an operator's description of its work, for profiling purposes.
const char* OpProfilingString(const TfLiteRegistration& op_reg,
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index 0fdb0a3935..05a7c23ba1 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -122,7 +122,7 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
- interpreter_->ResetVariableTensorsToZero();
+ interpreter_->ResetVariableTensors();
}
void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py
index 1be61fe053..5700bf7892 100644
--- a/tensorflow/contrib/lite/python/interpreter.py
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -253,5 +253,5 @@ class Interpreter(object):
self._ensure_safe()
self._interpreter.Invoke()
- def reset_all_variables_to_zero(self):
- return self._interpreter.ResetVariableTensorsToZero()
+ def reset_all_variables(self):
+ return self._interpreter.ResetVariableTensors()
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 9ab05f3068..418f19a179 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -466,9 +466,9 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
error_msg);
}
-PyObject* InterpreterWrapper::ResetVariableTensorsToZero() {
+PyObject* InterpreterWrapper::ResetVariableTensors() {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
- TFLITE_PY_CHECK(interpreter_->ResetVariableTensorsToZero());
+ TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
Py_RETURN_NONE;
}
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
index 641dd93db5..f5ca81e62a 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -65,7 +65,7 @@ class InterpreterWrapper {
PyObject* TensorQuantization(int i) const;
PyObject* SetTensor(int i, PyObject* value);
PyObject* GetTensor(int i) const;
- PyObject* ResetVariableTensorsToZero();
+ PyObject* ResetVariableTensors();
// Returns a reference to tensor index i as a numpy array. The base_object
// should be the interpreter object providing the memory.
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 1836eb53b9..17aa8cb293 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -301,7 +301,7 @@ bool TfLiteDriver::CheckResults() {
}
void TfLiteDriver::ResetLSTMStateTensors() {
- interpreter_->ResetVariableTensorsToZero();
+ interpreter_->ResetVariableTensors();
}
} // namespace testing