diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2017-04-28 17:44:46 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-28 19:12:47 -0700 |
commit | dae9329b0adb1628cd9665543da0c96f7a1fcbce (patch) | |
tree | dd49ecf0a6f84cf40b374aa718dbc50ab5440e69 /tensorflow | |
parent | ad3c84b58bb42c87ae8f38b81f75447afcc86d5f (diff) |
Distinguish between duplicate feed/fetch and unspecified feed/fetch errors.
Change: 154606429
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 44 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.h | 7 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/master_session.cc | 46 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/master_session.h | 6 | ||||
-rw-r--r-- | tensorflow/python/client/session_test.py | 77 |
5 files changed, 149 insertions, 31 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 002e246b80..f208e4b78e 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -720,16 +720,21 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, if (it == run_state->pending_inputs.end()) { return errors::InvalidArgument( "The feed ", input.first, - " has already been fed or was not specified in partial_run_setup."); + " was not specified in partial_run_setup."); + } else if (it->second) { + return errors::InvalidArgument("The feed ", input.first, + " has already been fed."); } } // Check that this is a new set of fetches that are still pending. for (const auto& output : output_names) { auto it = run_state->pending_outputs.find(output); if (it == run_state->pending_outputs.end()) { + return errors::InvalidArgument( + "The fetch ", output, " was not specified in partial_run_setup."); + } else if (it->second) { return errors::InvalidArgument("The fetch ", output, - " has already been fetched or was not " - "specified in partial_run_setup."); + " has already been fetched."); } } } @@ -764,14 +769,15 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, << run_state->status; } } - for (const auto& it : inputs) { - run_state->pending_inputs.erase(it.first); + for (const auto& input : inputs) { + auto it = run_state->pending_inputs.find(input.first); + it->second = true; } for (const auto& name : output_names) { - run_state->pending_outputs.erase(name); + auto it = run_state->pending_outputs.find(name); + it->second = true; } - done = (run_state->pending_inputs.size() == 0 && - run_state->pending_outputs.size() == 0); + done = run_state->PendingDone(); } if (done) { WaitForNotification(run_state, cancellation_manager_, @@ -900,11 +906,13 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds, std::unordered_set<TensorId, TensorId::Hasher> pending_feeds; { mutex_lock l(executor_lock_); - for (const string& feed : run_state->pending_inputs) { - TensorId id(ParseTensorName(feed)); + for (const auto& input : run_state->pending_inputs) { + // Skip if the feed has already been fed. + if (input.second) continue; + TensorId id(ParseTensorName(input.first)); auto it = name_to_node->find(id.first); if (it == name_to_node->end()) { - return errors::NotFound("Feed ", feed, ": not found"); + return errors::NotFound("Feed ", input.first, ": not found"); } pending_feeds.insert(id); } @@ -1351,10 +1359,10 @@ DirectSession::RunState::RunState( }) { // Initially all the feeds and fetches are pending. for (auto& name : pending_input_names) { - pending_inputs.emplace(name); + pending_inputs[name] = false; } for (auto& name : pending_output_names) { - pending_outputs.emplace(name); + pending_outputs[name] = false; } } @@ -1372,6 +1380,16 @@ DirectSession::RunState::~RunState() { } } +bool DirectSession::RunState::PendingDone() const { + for (const auto& it : pending_inputs) { + if (!it.second) return false; + } + for (const auto& it : pending_outputs) { + if (!it.second) return false; + } + return true; +} + void DirectSession::WaitForNotification(RunState* run_state, CancellationManager* cm, int64 timeout_in_ms) { diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index 848ef3bc62..061a7fa787 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -151,8 +151,8 @@ class DirectSession : public Session { IntraProcessRendezvous* rendez = nullptr; std::unique_ptr<StepStatsCollector> collector; Notification executors_done; - std::unordered_set<string> pending_inputs; - std::unordered_set<string> pending_outputs; + std::unordered_map<string, bool> pending_inputs; // true if fed + std::unordered_map<string, bool> pending_outputs; // true if fetched TensorStore tensor_store; ScopedStepContainer step_container; @@ -162,6 +162,9 @@ class DirectSession : public Session { const std::vector<string>& pending_output_names, int64 step_id, const std::vector<Device*>* devices); + // Returns true if all pending inputs and outputs have been completed. + bool PendingDone() const; + ~RunState(); }; diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index cec956ba49..73d4e6ab00 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -803,11 +803,13 @@ Status MasterSession::ReffedClientGraph::CheckFetches( SimpleGraphExecutionState* execution_state) { // Build the set of pending feeds that we haven't seen. std::unordered_set<TensorId, TensorId::Hasher> pending_feeds; - for (const string& feed : run_state->pending_inputs) { - TensorId id(ParseTensorName(feed)); + for (const auto& input : run_state->pending_inputs) { + // Skip if already fed. + if (input.second) continue; + TensorId id(ParseTensorName(input.first)); auto it = name_to_node_.find(id.first); if (it == name_to_node_.end()) { - return errors::NotFound("Feed ", feed, ": not found"); + return errors::NotFound("Feed ", input.first, ": not found"); } pending_feeds.insert(id); } @@ -1247,11 +1249,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts, // Make sure that this is a new set of feeds that are still pending. for (size_t i = 0; i < req.num_feeds(); ++i) { - auto it = run_state->pending_inputs.find(req.feed_name(i)); + const string& feed = req.feed_name(i); + auto it = run_state->pending_inputs.find(feed); if (it == run_state->pending_inputs.end()) { return errors::InvalidArgument( - "The feed ", req.feed_name(i), - " has already been fed or was not specified in partial_run_setup."); + "The feed ", feed, " was not specified in partial_run_setup."); + } else if (it->second) { + return errors::InvalidArgument("The feed ", feed, + " has already been fed."); } } // Check that this is a new set of fetches that are still pending. @@ -1259,9 +1264,11 @@ Status MasterSession::DoPartialRun(CallOptions* opts, const string& fetch = req.fetch_name(i); auto it = run_state->pending_outputs.find(fetch); if (it == run_state->pending_outputs.end()) { + return errors::InvalidArgument( + "The fetch ", fetch, " was not specified in partial_run_setup."); + } else if (it->second) { return errors::InvalidArgument("The fetch ", fetch, - " had already been fetched or was not " - "specified in partial_run_setup."); + " has already been fetched."); } } @@ -1274,13 +1281,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts, // Determine if this partial run satisfies all the pending inputs and ouputs. for (size_t i = 0; i < req.num_feeds(); ++i) { - run_state->pending_inputs.erase(req.feed_name(i)); + auto it = run_state->pending_inputs.find(req.feed_name(i)); + it->second = true; } for (size_t i = 0; i < req.num_fetches(); ++i) { - run_state->pending_outputs.erase(req.fetch_name(i)); + auto it = run_state->pending_outputs.find(req.fetch_name(i)); + it->second = true; } - bool is_last_partial_run = - (run_state->pending_inputs.empty() && run_state->pending_outputs.empty()); + bool is_last_partial_run = run_state->PendingDone(); Status s = run_state->rcg->RunPartitions( env_, run_state->step_id, run_state->count, &run_state->pss, opts, req, @@ -1418,10 +1426,10 @@ MasterSession::RunState::RunState(const std::vector<string>& input_names, : rcg(rcg), step_id(step_id), count(count) { // Initially all the feeds and fetches are pending. for (auto& name : input_names) { - pending_inputs.emplace(name); + pending_inputs[name] = false; } for (auto& name : output_names) { - pending_outputs.emplace(name); + pending_outputs[name] = false; } } @@ -1429,4 +1437,14 @@ MasterSession::RunState::~RunState() { if (rcg) rcg->Unref(); } +bool MasterSession::RunState::PendingDone() const { + for (const auto& it : pending_inputs) { + if (!it.second) return false; + } + for (const auto& it : pending_outputs) { + if (!it.second) return false; + } + return true; +} + } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index c593f84f03..ee1a340c8e 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -141,8 +141,8 @@ class MasterSession : public core::RefCounted { }; struct RunState { - std::unordered_set<string> pending_inputs; - std::unordered_set<string> pending_outputs; + std::unordered_map<string, bool> pending_inputs; // true if fed + std::unordered_map<string, bool> pending_outputs; // true if fetched ReffedClientGraph* rcg = nullptr; uint64 step_id; int64 count = 0; @@ -154,6 +154,8 @@ class MasterSession : public core::RefCounted { const std::vector<string>& output_names, ReffedClientGraph* rcg, const uint64 step_id, const int64 count); + bool PendingDone() const; + ~RunState(); }; std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_ diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index e6f1c57c7d..9add5bd3cd 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -1431,6 +1431,55 @@ class SessionTest(test_util.TensorFlowTestCase): 'You must feed a value for placeholder'): sess.partial_run(handle, fetches[0]) + def runTestPartialRunUnspecifiedFeed(self, sess): + a = array_ops.placeholder(dtypes.float32, shape=[]) + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + + h = sess.partial_run_setup([r1], [a, b]) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + 'was not specified in partial_run_setup.$'): + sess.partial_run(h, r1, feed_dict={a: 1, b: 2, c: 3}) + + def runTestPartialRunUnspecifiedFetch(self, sess): + a = array_ops.placeholder(dtypes.float32, shape=[]) + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + r2 = math_ops.multiply(a, c) + + h = sess.partial_run_setup([r1], [a, b, c]) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + 'was not specified in partial_run_setup.$'): + sess.partial_run(h, r2, feed_dict={a: 1, c: 3}) + + def runTestPartialRunAlreadyFed(self, sess): + a = array_ops.placeholder(dtypes.float32, shape=[]) + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + r2 = math_ops.multiply(a, c) + + h = sess.partial_run_setup([r1, r2], [a, b, c]) + sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + 'has already been fed.$'): + sess.partial_run(h, r2, feed_dict={a: 1, c: 3}) + + def runTestPartialRunAlreadyFetched(self, sess): + a = array_ops.placeholder(dtypes.float32, shape=[]) + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + r2 = math_ops.multiply(a, c) + + h = sess.partial_run_setup([r1, r2], [a, b, c]) + sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + 'has already been fetched.$'): + sess.partial_run(h, r1, feed_dict={c: 3}) + def testInvalidPartialRunSetup(self): sess = session.Session() x = array_ops.placeholder(dtypes.float32, shape=[]) @@ -1457,6 +1506,18 @@ class SessionTest(test_util.TensorFlowTestCase): def testPartialRunMissingPlaceholderFeedExceptionDirect(self): self.runTestPartialRunMissingPlaceholderFeedException(session.Session()) + def testPartialRunUnspecifiedFeedDirect(self): + self.runTestPartialRunUnspecifiedFeed(session.Session()) + + def testPartialRunUnspecifiedFetchDirect(self): + self.runTestPartialRunUnspecifiedFetch(session.Session()) + + def testPartialRunAlreadyFedDirect(self): + self.runTestPartialRunAlreadyFed(session.Session()) + + def testPartialRunAlreadyFetchedDirect(self): + self.runTestPartialRunAlreadyFetched(session.Session()) + def testPartialRunDist(self): server = server_lib.Server.create_local_server() self.runTestPartialRun(session.Session(server.target)) @@ -1482,6 +1543,22 @@ class SessionTest(test_util.TensorFlowTestCase): self.runTestPartialRunMissingPlaceholderFeedException( session.Session(server.target)) + def testPartialRunUnspecifiedFeedDist(self): + server = server_lib.Server.create_local_server() + self.runTestPartialRunUnspecifiedFeed(session.Session(server.target)) + + def testPartialRunUnspecifiedFetchDist(self): + server = server_lib.Server.create_local_server() + self.runTestPartialRunUnspecifiedFetch(session.Session(server.target)) + + def testPartialRunAlreadyFedDist(self): + server = server_lib.Server.create_local_server() + self.runTestPartialRunAlreadyFed(session.Session(server.target)) + + def testPartialRunAlreadyFetchedDist(self): + server = server_lib.Server.create_local_server() + self.runTestPartialRunAlreadyFetched(session.Session(server.target)) + def testFeedDictKeyException(self): with session.Session() as sess: a = constant_op.constant(1.0, dtypes.float32, name='a') |