diff options
author | Christopher Olston <olston@google.com> | 2018-07-26 21:04:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-26 21:08:24 -0700 |
commit | 19defc68c99483049e14ec18718cb037f1ca050b (patch) | |
tree | aee27a93b6fb86e0c2d359d503bdeb8519f0b80f /tensorflow/cc | |
parent | 82293a9cb0606b4dfbe45fc4e2f4aa8778ff7e9a (diff) |
Have the SavedModel loader use Session's Make/Run/ReleaseCallable() API instead of Run(), to avoid leaving behind non-GC'ed state after model initialization.
PiperOrigin-RevId: 206266841
Diffstat (limited to 'tensorflow/cc')
-rw-r--r-- | tensorflow/cc/saved_model/loader.cc | 56 |
1 files changed, 52 insertions, 4 deletions
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index d47b025743..98be66a6ad 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -74,6 +74,54 @@ void AddAssetsTensorsToInputs(const StringPiece export_dir, } } +// Like Session::Run(), but uses the Make/Run/ReleaseCallable() API to avoid +// leaving behind non-GC'ed state. +// +// Detailed motivation behind this approach, from ashankar@: +// +// Each call to Session::Run() that identifies a new subgraph (based on feeds +// and fetches) creates some datastructures that live as long as the session +// (the partitioned graph, associated executors etc.). +// +// A pathological case of this would be if say the initialization op +// (main_op/legacy_init_op) involves the use of a large constant. Then we +// allocate memory for that large constant that will just stick around till the +// session dies. With this Callable mechanism, that memory will be released +// right after ReleaseCallable returns. +// +// However, the resource manager state remains. +Status RunOnce(const RunOptions& run_options, + const std::vector<std::pair<string, Tensor>>& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs, RunMetadata* run_metadata, + Session* session) { + CallableOptions callable_options; + std::vector<Tensor> feed_tensors; + *callable_options.mutable_run_options() = run_options; + for (const auto& input : inputs) { + const string& name = input.first; + const Tensor& tensor = input.second; + callable_options.add_feed(name); + feed_tensors.push_back(tensor); + } + for (const string& output_tensor_name : output_tensor_names) { + callable_options.add_fetch(output_tensor_name); + } + for (const string& target_node_name : target_node_names) { + callable_options.add_target(target_node_name); + } + + Session::CallableHandle callable_handle; + TF_RETURN_IF_ERROR(session->MakeCallable(callable_options, &callable_handle)); + const Status run_status = session->RunCallable(callable_handle, feed_tensors, + outputs, run_metadata); + // Be sure to call ReleaseCallable() regardless of the outcome of + // RunCallable(). + session->ReleaseCallable(callable_handle).IgnoreError(); + return run_status; +} + bool HasMainOp(const MetaGraphDef& meta_graph_def) { const auto& collection_def_map = meta_graph_def.collection_def(); if (collection_def_map.find(kSavedModelMainOpKey) != @@ -100,8 +148,8 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; const StringPiece main_op_name = main_op_it->second.node_list().value(0); - return session->Run(run_options, inputs, {}, {main_op_name.ToString()}, - nullptr /* outputs */, &run_metadata); + return RunOnce(run_options, inputs, {}, {main_op_name.ToString()}, + nullptr /* outputs */, &run_metadata, session); } return Status::OK(); } @@ -138,8 +186,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; - return session->Run(run_options, inputs, {}, {restore_op_name.ToString()}, - nullptr /* outputs */, &run_metadata); + return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()}, + nullptr /* outputs */, &run_metadata, session); } Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, |