aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api.cc
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-11-29 14:01:29 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-29 14:06:17 -0800
commitcb5a63d8d2b6e049a0a128ba47560f842497db8b (patch)
treeced4a0647f9bab632a2d9a80895e7cc9bca0c78c /tensorflow/c/c_api.cc
parent1d0b07351d901334b33565595d4c23607f11cc27 (diff)
Check when session cannot run because its graph was modified
With current tensorflow code, if user modifies some operation after session.run() was called, this modification will never make it to the C++ runtime and no errors will be raised leading to silent wrong results. This change adds checks for such cases when C API is enabled. We don't change the code path for C API being disabled because C API should be enabled by default soon. PiperOrigin-RevId: 177359630
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r--tensorflow/c/c_api.cc38
1 files changed, 31 insertions, 7 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 4fb8ec8e4b..c8b4bfffd4 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -624,6 +624,23 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
return Status::OK();
}
+void RecordMutation(TF_Graph* graph, const TF_Operation& op,
+ const char* mutation_type)
+ EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
+ // If any session has already run this node_id, mark this session as
+ // unrunnable.
+ for (auto it : graph->sessions) {
+ if (it.first->last_num_graph_nodes > op.node.id()) {
+ it.second = FailedPrecondition(
+ "Operation '", op.node.DebugString(), "' was changed by ",
+ mutation_type,
+ " after it was run by a session. Nodes can be mutated "
+ "only before they are executed by a session. Either don't modify "
+ "nodes after running them or create a new session.");
+ }
+ }
+}
+
// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
@@ -1744,7 +1761,6 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
TF_Graph::TF_Graph()
: graph(tensorflow::OpRegistry::Global()),
refiner(graph.versions().producer(), graph.op_registry()),
- num_sessions(0),
delete_requested(false),
parent(nullptr),
parent_inputs(nullptr) {}
@@ -1754,7 +1770,7 @@ TF_Graph* TF_NewGraph() { return new TF_Graph; }
void TF_DeleteGraph(TF_Graph* g) {
g->mu.lock();
g->delete_requested = true;
- const bool del = g->num_sessions == 0;
+ const bool del = g->sessions.empty();
g->mu.unlock();
if (del) delete g;
}
@@ -2324,11 +2340,12 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
Session* session;
status->status = NewSession(opt->options, &session);
if (status->status.ok()) {
+ TF_Session* new_session = new TF_Session(session, graph);
if (graph != nullptr) {
mutex_lock l(graph->mu);
- graph->num_sessions += 1;
+ graph->sessions[new_session] = Status::OK();
}
- return new TF_Session(session, graph);
+ return new_session;
} else {
DCHECK_EQ(nullptr, session);
return nullptr;
@@ -2392,7 +2409,7 @@ TF_Session* TF_LoadSessionFromSavedModel(
TF_Session* session = new TF_Session(bundle.session.release(), graph);
- graph->num_sessions += 1;
+ graph->sessions[session] = Status::OK();
session->last_num_graph_nodes = graph->graph.num_node_ids();
return session;
#endif // __ANDROID__
@@ -2407,8 +2424,8 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) {
TF_Graph* const graph = s->graph;
if (graph != nullptr) {
graph->mu.lock();
- graph->num_sessions -= 1;
- const bool del = graph->delete_requested && graph->num_sessions == 0;
+ graph->sessions.erase(s);
+ const bool del = graph->delete_requested && graph->sessions.empty();
graph->mu.unlock();
if (del) delete graph;
}
@@ -2424,6 +2441,13 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
mutex_lock session_lock(session->mu);
session->graph->mu.lock();
const Graph& graph = session->graph->graph;
+
+ status->status = session->graph->sessions[session];
+ if (!status->status.ok()) {
+ session->graph->mu.unlock();
+ return false;
+ }
+
const auto num_nodes = graph.num_node_ids();
if (session->last_num_graph_nodes < num_nodes) {
status->status = tensorflow::ValidateNoCycles(session->graph->graph);