aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-03-06 17:19:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-06 17:23:35 -0800
commit9c3cf322a3051339899ffb74c33533f60c0c2d8e (patch)
treebe3c72e843a9178be50c0e5a791686d1253cea82 /tensorflow/c
parent6e99d56489b4e6c3176fa1199d4270b6439a22fe (diff)
Make graph construction work while graph is being concurrently run.
The overall approach is to use Graph._lock to synchronize Session.run calls and construction methods that rely on graph mutation. We don't want to synchronize the actual running of the graph, only the Extend call, so this change exposes an ExtendSession method to the Python API and disables extending automatically in TF_SessionRun. PiperOrigin-RevId: 188106818
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/c_api.cc134
-rw-r--r--tensorflow/c/c_api_internal.h8
-rw-r--r--tensorflow/c/python_api.cc6
-rw-r--r--tensorflow/c/python_api.h10
4 files changed, 99 insertions, 59 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 3d0e886476..e3a95a0577 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -710,6 +710,58 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
+// Sends ops added to session->graph since the last successful extend to
+// session->session->Extend(); returns false and sets status->status on error.
+// TODO(josh11b,mrry): let Session use a Graph* directly (no GraphDef/Extend).
+bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
+ EXCLUSIVE_LOCKS_REQUIRED(session->mu) {
+ if (session->graph != nullptr) {  // No graph: nothing to extend, returns true.
+ session->graph->mu.lock();  // Manual lock: every path below unlocks it.
+ const Graph& graph = session->graph->graph;
+ // Start from the status stored for this session in graph->sessions.
+ status->status = session->graph->sessions[session];
+ if (!status->status.ok()) {
+ session->graph->mu.unlock();
+ return false;
+ }
+
+ const auto num_nodes = graph.num_node_ids();
+ if (session->last_num_graph_nodes < num_nodes) {
+ status->status = tensorflow::ValidateNoCycles(session->graph->graph);
+ if (!status->status.ok()) {
+ session->graph->mu.unlock();
+ return false;
+ }
+
+ GraphDef graph_def;
+ *graph_def.mutable_versions() = graph.versions();
+ // Fill graph_def with nodes with ids in the range
+ // [session->last_num_graph_nodes, num_nodes), that is the nodes
+ // added since the last successful Extend() of this session.
+ for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
+ Node* const node = graph.FindNodeId(id);
+ if (node != nullptr && node->IsOp()) {
+ NodeDef* const node_def = graph_def.add_node();
+ *node_def = node->def();
+ }
+ }
+ *graph_def.mutable_library() = graph.flib_def().ToProto();
+ session->graph->mu.unlock();  // Graph lock is dropped before Extend().
+ status->status = session->session->Extend(graph_def);
+ if (!status->status.ok()) {
+ // Extend() failed: report via status, leave last_num_graph_nodes as is.
+ return false;
+ }
+ // Note: session->session is not modified if Extend() fails, so
+ // we only set last_num_graph_nodes if it succeeds.
+ session->last_num_graph_nodes = num_nodes;
+ } else {
+ session->graph->mu.unlock();  // No new node ids: nothing to send.
+ }
+ }
+ return true;
+}
+
} // namespace tensorflow
static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
@@ -2410,7 +2462,11 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
// TF_Session functions ----------------------------------------------
TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
- : session(s), graph(g), last_num_graph_nodes(0), device_mgr(nullptr) {
+ : session(s),
+ graph(g),
+ last_num_graph_nodes(0),
+ device_mgr(nullptr),
+ extend_before_run(true) {
if (s->LocalDeviceManager(&device_mgr).ok()) {
devices = device_mgr->ListDevices();
}
@@ -2514,58 +2570,6 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) {
delete s;
}
-// TODO(josh11b,mrry): Change Session to be able to use a Graph*
-// directly, instead of requiring us to serialize to a GraphDef and
-// call Session::Extend().
-static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
- if (session->graph != nullptr) {
- mutex_lock session_lock(session->mu);
- session->graph->mu.lock();
- const Graph& graph = session->graph->graph;
-
- status->status = session->graph->sessions[session];
- if (!status->status.ok()) {
- session->graph->mu.unlock();
- return false;
- }
-
- const auto num_nodes = graph.num_node_ids();
- if (session->last_num_graph_nodes < num_nodes) {
- status->status = tensorflow::ValidateNoCycles(session->graph->graph);
- if (!status->status.ok()) {
- session->graph->mu.unlock();
- return false;
- }
-
- GraphDef graph_def;
- *graph_def.mutable_versions() = graph.versions();
- // Fill graph_def with nodes with ids in the range
- // [session->last_num_graph_nodes, num_nodes), that is the nodes
- // added since the last TF_SessionRun() call.
- for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
- Node* const node = graph.FindNodeId(id);
- if (node != nullptr && node->IsOp()) {
- NodeDef* const node_def = graph_def.add_node();
- *node_def = node->def();
- }
- }
- *graph_def.mutable_library() = graph.flib_def().ToProto();
- session->graph->mu.unlock();
- status->status = session->session->Extend(graph_def);
- if (!status->status.ok()) {
- // Contract is we always delete input_values[i].
- return false;
- }
- // Note: session->session is not modified if Extend() fails, so
- // we only set last_num_graph_nodes if it succeeds.
- session->last_num_graph_nodes = num_nodes;
- } else {
- session->graph->mu.unlock();
- }
- }
- return true;
-}
-
void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
const TF_Output* inputs, TF_Tensor* const* input_values,
int ninputs, const TF_Output* outputs,
@@ -2575,8 +2579,12 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
- if (!ExtendSessionGraphHelper(session, status)) {
- return;
+ {
+ mutex_lock l(session->mu);
+ if (session->extend_before_run &&
+ !tensorflow::ExtendSessionGraphHelper(session, status)) {
+ return;
+ }
}
TF_Run_Setup(noutputs, output_values, status);
@@ -2612,8 +2620,12 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
const char** handle, TF_Status* status) {
*handle = nullptr;
- if (!ExtendSessionGraphHelper(session, status)) {
- return;
+ {
+ mutex_lock l(session->mu);
+ if (session->extend_before_run &&
+ !tensorflow::ExtendSessionGraphHelper(session, status)) {
+ return;
+ }
}
std::vector<string> input_names(ninputs);
@@ -2655,8 +2667,12 @@ void TF_SessionPRun(TF_Session* session, const char* handle,
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
- if (!ExtendSessionGraphHelper(session, status)) {
- return;
+ {
+ mutex_lock l(session->mu);
+ if (session->extend_before_run &&
+ !tensorflow::ExtendSessionGraphHelper(session, status)) {
+ return;
+ }
}
TF_Run_Setup(noutputs, output_values, status);
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index 91667056e0..027e2d2b15 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -133,6 +133,12 @@ struct TF_Session {
// buffers of a TF_Tensor pinned in device memory.
const tensorflow::DeviceMgr* device_mgr; // Owned by session.
std::vector<tensorflow::Device*> devices; // Owned by device_mgr.
+
+ // If true, TF_SessionRun and similar methods will call
+ // ExtendSessionGraphHelper before running the graph (this is the default
+ // public behavior). Can be set to false if the caller needs to call
+ // ExtendSessionGraphHelper manually.
+ bool extend_before_run GUARDED_BY(mu);
};
struct TF_ImportGraphDefOptions {
@@ -212,6 +218,8 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
const char* mutation_type);
+bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status);  // Caller must hold session->mu (see c_api.cc).
+
} // end namespace tensorflow
#endif // TENSORFLOW_C_C_API_INTERNAL_H_
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index f553142d15..26683f50ec 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -104,4 +104,10 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
graph->refiner.set_require_shape_inference_fns(require);
}
+void ExtendSession(TF_Session* session, TF_Status* status) {
+ mutex_lock l(session->mu);  // Held for the whole call, including the helper.
+ session->extend_before_run = false;  // Run/PRunSetup/PRun now skip their own extend.
+ ExtendSessionGraphHelper(session, status);  // Return value unused; failures are reported via status.
+}
+
} // namespace tensorflow
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index 542d70f42c..13b680b3a2 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -41,6 +41,16 @@ void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op);
// error. The default is true.
void SetRequireShapeInferenceFns(TF_Graph* graph, bool require);
+// Extends `session` with any new operations added to its associated graph.
+// Usually this happens automatically in TF_SessionRun. After this is called,
+// TF_SessionRun/TF_SessionPRunSetup/TF_SessionPRun stop extending the session.
+//
+// We expose this here to allow fine-grained synchronization in multi-threaded
+// workloads, which is required since the Python implementation depends on the
+// above mutation methods. This allows us to prevent modifications to nodes in
+// the graph after the session has been made aware of them.
+void ExtendSession(TF_Session* session, TF_Status* status);
+
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_