diff options
author | 2018-09-27 12:51:52 -0700 | |
---|---|---|
committer | 2018-09-27 12:55:44 -0700 | |
commit | 1084594657a5d139102ac794f84d1427a710e39a (patch) | |
tree | 108ac1628f966a213cf63b9f24e0f8b4add20d1e /tensorflow | |
parent | 750466c6e6624d279de7f9a43accd682d487509c (diff) |
TFLite: Rename ResetVariableTensorsToZero -> ResetVariableTensors
PiperOrigin-RevId: 214820383
Diffstat (limited to 'tensorflow')
10 files changed, 20 insertions, 19 deletions
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc index 0f16595811..29f8701f53 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc @@ -21,9 +21,8 @@ limitations under the License. extern "C" { #endif // __cplusplus -TFL_Status TFL_InterpreterResetVariableTensorsToZero( - TFL_Interpreter* interpreter) { - return interpreter->impl->ResetVariableTensorsToZero(); +TFL_Status TFL_InterpreterResetVariableTensors(TFL_Interpreter* interpreter) { + return interpreter->impl->ResetVariableTensors(); } void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options, diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h index b8de7b9964..fca5d92f77 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h @@ -25,7 +25,7 @@ extern "C" { typedef TfLiteBuiltinOperator TFL_BuiltinOperator; // Resets all variable tensors to zero. -TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensorsToZero( +TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensors( TFL_Interpreter* interpreter); // Adds an op registration for a builtin operator. diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc index d86ad00d6d..1b1bedb754 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc @@ -44,7 +44,7 @@ TEST(CApiExperimentalSimple, Smoke) { TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options); ASSERT_NE(interpreter, nullptr); ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk); - EXPECT_EQ(TFL_InterpreterResetVariableTensorsToZero(interpreter), kTfLiteOk); + EXPECT_EQ(TFL_InterpreterResetVariableTensors(interpreter), kTfLiteOk); EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk); TFL_DeleteInterpreter(interpreter); diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 2657bcd42b..88e41ffc55 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -451,16 +451,15 @@ TfLiteStatus Interpreter::AllocateTensors() { // Reset the variable tensors to zero after (re)allocating the tensors. // Developers shouldn't rely on the side effect of this function to reset - // variable tesnsors. They should call `ResetVariableTensorsToZero` directly + // variable tesnsors. They should call `ResetVariableTensors` directly // instead. - ResetVariableTensorsToZero(); + ResetVariableTensors(); return kTfLiteOk; } -// TODO(ycling): Consider to provide other functions to initialize variable -// tensors to non-zero values. -TfLiteStatus Interpreter::ResetVariableTensorsToZero() { +// TODO(ycling): Support non-zero default values. +TfLiteStatus Interpreter::ResetVariableTensors() { for (auto& tensor : tensors_) { if (!tensor.is_variable) { continue; diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index aa2bc4def6..7ef736d01b 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -421,9 +421,12 @@ class Interpreter { allow_buffer_handle_output_ = allow_buffer_handle_output; } - // Reset all variable tensors to zero. + // Reset all variable tensors to the default value. + // If a variable tensor doesn't have a buffer, reset it to zero. + // TODO(b/115961645): Implement - If a variable tensor has a buffer, reset it + // to the value of the buffer. // WARNING: This is an experimental API and subject to change. - TfLiteStatus ResetVariableTensorsToZero(); + TfLiteStatus ResetVariableTensors(); // Retrieve an operator's description of its work, for profiling purposes. const char* OpProfilingString(const TfLiteRegistration& op_reg, diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index 0fdb0a3935..05a7c23ba1 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -122,7 +122,7 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes, CHECK(interpreter_->AllocateTensors() == kTfLiteOk) << "Cannot allocate tensors"; - interpreter_->ResetVariableTensorsToZero(); + interpreter_->ResetVariableTensors(); } void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); } diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index 1be61fe053..5700bf7892 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -253,5 +253,5 @@ class Interpreter(object): self._ensure_safe() self._interpreter.Invoke() - def reset_all_variables_to_zero(self): - return self._interpreter.ResetVariableTensorsToZero() + def reset_all_variables(self): + return self._interpreter.ResetVariableTensors() diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 9ab05f3068..418f19a179 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -466,9 +466,9 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( error_msg); } -PyObject* InterpreterWrapper::ResetVariableTensorsToZero() { +PyObject* InterpreterWrapper::ResetVariableTensors() { TFLITE_PY_ENSURE_VALID_INTERPRETER(); - TFLITE_PY_CHECK(interpreter_->ResetVariableTensorsToZero()); + TFLITE_PY_CHECK(interpreter_->ResetVariableTensors()); Py_RETURN_NONE; } diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index 641dd93db5..f5ca81e62a 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -65,7 +65,7 @@ class InterpreterWrapper { PyObject* TensorQuantization(int i) const; PyObject* SetTensor(int i, PyObject* value); PyObject* GetTensor(int i) const; - PyObject* ResetVariableTensorsToZero(); + PyObject* ResetVariableTensors(); // Returns a reference to tensor index i as a numpy array. The base_object // should be the interpreter object providing the memory. diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 1836eb53b9..17aa8cb293 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -301,7 +301,7 @@ bool TfLiteDriver::CheckResults() { } void TfLiteDriver::ResetLSTMStateTensors() { - interpreter_->ResetVariableTensorsToZero(); + interpreter_->ResetVariableTensors(); } } // namespace testing |