aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-07-09 13:09:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-09 13:13:43 -0700
commit23caae7408cf6bb961d1bab0c8b819d88ccf5f3c (patch)
tree4ea5472cc514bd976f4de7f251f3467d0982c637 /tensorflow
parentd22c433216254d2e6ab33a849f922b4545d1f7dd (diff)
Allow adding new functions to an already-run session
PiperOrigin-RevId: 161337922
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc1
-rw-r--r--tensorflow/core/common_runtime/simple_graph_execution_state.cc19
-rw-r--r--tensorflow/python/client/session_test.py23
3 files changed, 33 insertions, 10 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index f6eac2d698..ada2e32f38 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -390,6 +390,7 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) {
TF_RETURN_IF_ERROR(
MaybeInitializeExecutionState(graph, &already_initialized));
if (already_initialized) {
+ TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
std::unique_ptr<SimpleGraphExecutionState> state;
TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
execution_state_.swap(state);
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
index 5a69d2440a..41e685bdc7 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
@@ -117,16 +117,21 @@ SimpleGraphExecutionState::~SimpleGraphExecutionState() {
Status SimpleGraphExecutionState::Extend(
const GraphDef& extension_def,
std::unique_ptr<SimpleGraphExecutionState>* out) const {
+ GraphDef gdef;
+
+ // 1. Copy the function library.
+ TF_RETURN_IF_ERROR(flib_def_->AddLibrary(extension_def.library()));
+ *gdef.mutable_library() = flib_def_->ToProto();
+
+ // 2. Build an index of the new node names.
std::unordered_set<string> new_names;
- // 1. Build an index of the new node names.
for (const NodeDef& node : extension_def.node()) {
new_names.insert(node.name());
}
- // 2. Add the non-duplicates from the old graph to the new graph.
+ // 3. Add the non-duplicates from the old graph to the new graph.
// Return an error if the same node name appears in both the
// old graph and the extension.
- GraphDef gdef;
for (const NodeDef& node : original_graph_def_.node()) {
if (new_names.count(node.name()) == 0) {
*gdef.add_node() = node;
@@ -138,7 +143,7 @@ Status SimpleGraphExecutionState::Extend(
}
}
- // 3. Merge the versions field.
+ // 4. Merge the versions field.
int old_node_size = gdef.node_size();
gdef.mutable_node()->MergeFrom(extension_def.node());
TF_RETURN_IF_ERROR(
@@ -174,12 +179,6 @@ Status SimpleGraphExecutionState::Extend(
gdef.mutable_versions()->CopyFrom(extension_def.versions());
}
- // 4. Copy the function library from this execution state.
- // NOTE(mrry): To match the previous behavior, the first GraphDef
- // passed to a session will contain the function library that is
- // used for all subsequent execution states.
- *gdef.mutable_library() = flib_def_->ToProto();
-
// 5. Validate that the final graphdef is valid.
if (gdef.versions().producer() >= 5) {
// Validate the graph: we assume that merging two valid graphs
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 571b9a4668..83c52b7cf7 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
@@ -1752,6 +1753,28 @@ class SessionTest(test_util.TensorFlowTestCase):
str_repr = '%s' % attrs
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
+ def runTestAddFunctionToSession(self, target=''):
+ """Add a function to a session after the graph has already been run."""
+ @function.Defun(dtypes.float32)
+ def foo(x):
+ return x + 1
+
+ x = constant_op.constant(1.0)
+ with session.Session(target=target) as sess:
+ sess.run(x)
+ f = foo(x)
+ result = sess.run(f)
+ self.assertEqual(result, 2.0)
+
+ @test_util.disable_c_api # functions don't work with C API
+ def testAddFunctionToSession(self):
+ self.runTestAddFunctionToSession()
+
+ @test_util.disable_c_api # functions don't work with C API
+ def testAddFunctionToGrpcSession(self):
+ server = server_lib.Server.create_local_server()
+ self.runTestAddFunctionToSession(server.target)
+
if __name__ == '__main__':
googletest.main()