diff options
author | 2017-07-09 13:09:51 -0700 | |
---|---|---|
committer | 2017-07-09 13:13:43 -0700 | |
commit | 23caae7408cf6bb961d1bab0c8b819d88ccf5f3c (patch) | |
tree | 4ea5472cc514bd976f4de7f251f3467d0982c637 /tensorflow | |
parent | d22c433216254d2e6ab33a849f922b4545d1f7dd (diff) |
Allow adding new functions to an already-run session
PiperOrigin-RevId: 161337922
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 1 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/simple_graph_execution_state.cc | 19 | ||||
-rw-r--r-- | tensorflow/python/client/session_test.py | 23 |
3 files changed, 33 insertions, 10 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index f6eac2d698..ada2e32f38 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -390,6 +390,7 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) { TF_RETURN_IF_ERROR( MaybeInitializeExecutionState(graph, &already_initialized)); if (already_initialized) { + TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library())); std::unique_ptr<SimpleGraphExecutionState> state; TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state)); execution_state_.swap(state); diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc index 5a69d2440a..41e685bdc7 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc @@ -117,16 +117,21 @@ SimpleGraphExecutionState::~SimpleGraphExecutionState() { Status SimpleGraphExecutionState::Extend( const GraphDef& extension_def, std::unique_ptr<SimpleGraphExecutionState>* out) const { + GraphDef gdef; + + // 1. Copy the function library. + TF_RETURN_IF_ERROR(flib_def_->AddLibrary(extension_def.library())); + *gdef.mutable_library() = flib_def_->ToProto(); + + // 2. Build an index of the new node names. std::unordered_set<string> new_names; - // 1. Build an index of the new node names. for (const NodeDef& node : extension_def.node()) { new_names.insert(node.name()); } - // 2. Add the non-duplicates from the old graph to the new graph. + // 3. Add the non-duplicates from the old graph to the new graph. // Return an error if the same node name appears in both the // old graph and the extension. - GraphDef gdef; for (const NodeDef& node : original_graph_def_.node()) { if (new_names.count(node.name()) == 0) { *gdef.add_node() = node; @@ -138,7 +143,7 @@ Status SimpleGraphExecutionState::Extend( } } - // 3. Merge the versions field. + // 4. Merge the versions field. int old_node_size = gdef.node_size(); gdef.mutable_node()->MergeFrom(extension_def.node()); TF_RETURN_IF_ERROR( @@ -174,12 +179,6 @@ Status SimpleGraphExecutionState::Extend( gdef.mutable_versions()->CopyFrom(extension_def.versions()); } - // 4. Copy the function library from this execution state. - // NOTE(mrry): To match the previous behavior, the first GraphDef - // passed to a session will contain the function library that is - // used for all subsequent execution states. - *gdef.mutable_library() = flib_def_->ToProto(); - // 5. Validate that the final graphdef is valid. if (gdef.versions().producer() >= 5) { // Validate the graph: we assume that merging two valid graphs diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 571b9a4668..83c52b7cf7 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util @@ -1752,6 +1753,28 @@ class SessionTest(test_util.TensorFlowTestCase): str_repr = '%s' % attrs self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr) + def runTestAddFunctionToSession(self, target=''): + """Add a function to a session after the graph has already been run.""" + @function.Defun(dtypes.float32) + def foo(x): + return x + 1 + + x = constant_op.constant(1.0) + with session.Session(target=target) as sess: + sess.run(x) + f = foo(x) + result = sess.run(f) + self.assertEqual(result, 2.0) + + @test_util.disable_c_api # functions don't work with C API + def testAddFunctionToSession(self): + self.runTestAddFunctionToSession() + + @test_util.disable_c_api # functions don't work with C API + def testAddFunctionToGrpcSession(self): + server = server_lib.Server.create_local_server() + self.runTestAddFunctionToSession(server.target) + if __name__ == '__main__': googletest.main() |