aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-03-06 17:19:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-06 17:23:35 -0800
commit9c3cf322a3051339899ffb74c33533f60c0c2d8e (patch)
treebe3c72e843a9178be50c0e5a791686d1253cea82 /tensorflow/c/c_api.cc
parent6e99d56489b4e6c3176fa1199d4270b6439a22fe (diff)
Make graph construction work while graph is being concurrently run.
The overall approach is to use Graph._lock to synchronize Session.run calls and construction methods that rely on graph mutation. We don't want to synchronize the actual running of the graph, only the Extend call, so this change exposes an ExtendSession method to the Python API and disables extending automatically in TF_SessionRun. PiperOrigin-RevId: 188106818
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r--tensorflow/c/c_api.cc134
1 files changed, 75 insertions, 59 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 3d0e886476..e3a95a0577 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -710,6 +710,58 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
+// TODO(josh11b,mrry): Change Session to be able to use a Graph*
+// directly, instead of requiring us to serialize to a GraphDef and
+// call Session::Extend().
+bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
+ EXCLUSIVE_LOCKS_REQUIRED(session->mu) {
+ if (session->graph != nullptr) {
+ session->graph->mu.lock();
+ const Graph& graph = session->graph->graph;
+
+ status->status = session->graph->sessions[session];
+ if (!status->status.ok()) {
+ session->graph->mu.unlock();
+ return false;
+ }
+
+ const auto num_nodes = graph.num_node_ids();
+ if (session->last_num_graph_nodes < num_nodes) {
+ status->status = tensorflow::ValidateNoCycles(session->graph->graph);
+ if (!status->status.ok()) {
+ session->graph->mu.unlock();
+ return false;
+ }
+
+ GraphDef graph_def;
+ *graph_def.mutable_versions() = graph.versions();
+ // Fill graph_def with nodes with ids in the range
+ // [session->last_num_graph_nodes, num_nodes), that is the nodes
+ // added since the last TF_SessionRun() call.
+ for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
+ Node* const node = graph.FindNodeId(id);
+ if (node != nullptr && node->IsOp()) {
+ NodeDef* const node_def = graph_def.add_node();
+ *node_def = node->def();
+ }
+ }
+ *graph_def.mutable_library() = graph.flib_def().ToProto();
+ session->graph->mu.unlock();
+ status->status = session->session->Extend(graph_def);
+ if (!status->status.ok()) {
+ // Contract is we always delete input_values[i].
+ return false;
+ }
+ // Note: session->session is not modified if Extend() fails, so
+ // we only set last_num_graph_nodes if it succeeds.
+ session->last_num_graph_nodes = num_nodes;
+ } else {
+ session->graph->mu.unlock();
+ }
+ }
+ return true;
+}
+
} // namespace tensorflow
static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
@@ -2410,7 +2462,11 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
// TF_Session functions ----------------------------------------------
TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
- : session(s), graph(g), last_num_graph_nodes(0), device_mgr(nullptr) {
+ : session(s),
+ graph(g),
+ last_num_graph_nodes(0),
+ device_mgr(nullptr),
+ extend_before_run(true) {
if (s->LocalDeviceManager(&device_mgr).ok()) {
devices = device_mgr->ListDevices();
}
@@ -2514,58 +2570,6 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) {
delete s;
}
-// TODO(josh11b,mrry): Change Session to be able to use a Graph*
-// directly, instead of requiring us to serialize to a GraphDef and
-// call Session::Extend().
-static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
- if (session->graph != nullptr) {
- mutex_lock session_lock(session->mu);
- session->graph->mu.lock();
- const Graph& graph = session->graph->graph;
-
- status->status = session->graph->sessions[session];
- if (!status->status.ok()) {
- session->graph->mu.unlock();
- return false;
- }
-
- const auto num_nodes = graph.num_node_ids();
- if (session->last_num_graph_nodes < num_nodes) {
- status->status = tensorflow::ValidateNoCycles(session->graph->graph);
- if (!status->status.ok()) {
- session->graph->mu.unlock();
- return false;
- }
-
- GraphDef graph_def;
- *graph_def.mutable_versions() = graph.versions();
- // Fill graph_def with nodes with ids in the range
- // [session->last_num_graph_nodes, num_nodes), that is the nodes
- // added since the last TF_SessionRun() call.
- for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
- Node* const node = graph.FindNodeId(id);
- if (node != nullptr && node->IsOp()) {
- NodeDef* const node_def = graph_def.add_node();
- *node_def = node->def();
- }
- }
- *graph_def.mutable_library() = graph.flib_def().ToProto();
- session->graph->mu.unlock();
- status->status = session->session->Extend(graph_def);
- if (!status->status.ok()) {
- // Contract is we always delete input_values[i].
- return false;
- }
- // Note: session->session is not modified if Extend() fails, so
- // we only set last_num_graph_nodes if it succeeds.
- session->last_num_graph_nodes = num_nodes;
- } else {
- session->graph->mu.unlock();
- }
- }
- return true;
-}
-
void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
const TF_Output* inputs, TF_Tensor* const* input_values,
int ninputs, const TF_Output* outputs,
@@ -2575,8 +2579,12 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
- if (!ExtendSessionGraphHelper(session, status)) {
- return;
+ {
+ mutex_lock l(session->mu);
+ if (session->extend_before_run &&
+ !tensorflow::ExtendSessionGraphHelper(session, status)) {
+ return;
+ }
}
TF_Run_Setup(noutputs, output_values, status);
@@ -2612,8 +2620,12 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
const char** handle, TF_Status* status) {
*handle = nullptr;
- if (!ExtendSessionGraphHelper(session, status)) {
- return;
+ {
+ mutex_lock l(session->mu);
+ if (session->extend_before_run &&
+ !tensorflow::ExtendSessionGraphHelper(session, status)) {
+ return;
+ }
}
std::vector<string> input_names(ninputs);
@@ -2655,8 +2667,12 @@ void TF_SessionPRun(TF_Session* session, const char* handle,
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
- if (!ExtendSessionGraphHelper(session, status)) {
- return;
+ {
+ mutex_lock l(session->mu);
+ if (session->extend_before_run &&
+ !tensorflow::ExtendSessionGraphHelper(session, status)) {
+ return;
+ }
}
TF_Run_Setup(noutputs, output_values, status);