diff options
author | 2018-03-12 11:02:29 -0700 | |
---|---|---|
committer | 2018-03-12 11:06:44 -0700 | |
commit | 1d6a57edc0be0dcc0c92eb2610b88420a7b7be51 (patch) | |
tree | beec27e175dc508ad0340d57b89676a805e5c79d /tensorflow/c/c_api_internal.h | |
parent | 89177f289e9467e04b205a1a3e705ad67d9854d2 (diff) |
Fix race in C API.
RecordMutation could race with ExtendSessionGraphHelper, which would
release the graph lock and only keep the session lock when extending
the session.
Also makes sure thread annotations are on declarations, not definitions
(otherwise they have no effect).
PiperOrigin-RevId: 188747158
Diffstat (limited to 'tensorflow/c/c_api_internal.h')
-rw-r--r-- | tensorflow/c/c_api_internal.h | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 25233931de..e885a69927 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -124,16 +124,16 @@ struct TF_Session { TF_Session(tensorflow::Session* s, TF_Graph* g); tensorflow::Session* session; - TF_Graph* graph; + TF_Graph* const graph; - tensorflow::mutex mu; + tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu); int last_num_graph_nodes; // If true, TF_SessionRun and similar methods will call // ExtendSessionGraphHelper before running the graph (this is the default // public behavior). Can be set to false if the caller needs to call // ExtendSessionGraphHelper manually. - bool extend_before_run GUARDED_BY(mu); + std::atomic<bool> extend_before_run; }; struct TF_ImportGraphDefOptions { @@ -211,9 +211,11 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, TF_Status* status); void RecordMutation(TF_Graph* graph, const TF_Operation& op, - const char* mutation_type); + const char* mutation_type) + EXCLUSIVE_LOCKS_REQUIRED(graph->mu); -bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status); +bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) + LOCKS_EXCLUDED(session->graph->mu, session->mu); } // end namespace tensorflow |