author     Suharsh Sivakumar <suharshs@google.com>     2017-04-28 17:44:46 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-04-28 19:12:47 -0700
commit     dae9329b0adb1628cd9665543da0c96f7a1fcbce (patch)
tree       dd49ecf0a6f84cf40b374aa718dbc50ab5440e69 /tensorflow
parent     ad3c84b58bb42c87ae8f38b81f75447afcc86d5f (diff)
Distinguish between duplicate feed/fetch and unspecified feed/fetch errors.
Change: 154606429
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc       44
-rw-r--r--  tensorflow/core/common_runtime/direct_session.h          7
-rw-r--r--  tensorflow/core/distributed_runtime/master_session.cc   46
-rw-r--r--  tensorflow/core/distributed_runtime/master_session.h     6
-rw-r--r--  tensorflow/python/client/session_test.py                77
5 files changed, 149 insertions, 31 deletions
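
Before the per-file diffs, here is a minimal, standalone C++ sketch of the pattern this commit introduces (it is not the TensorFlow code itself; PendingState and FeedOnce are illustrative names): pending feeds and fetches move from an unordered_set to an unordered_map<string, bool>, so a name that was never registered in partial_run_setup and a name that has already been consumed raise two distinct errors, and a PendingDone() helper reports completion only once every entry has been marked.

// Standalone sketch of the pending-feed/fetch bookkeeping; any C++11 compiler will do.
#include <iostream>
#include <string>
#include <unordered_map>

struct PendingState {
  // Value is true once the feed has been fed (or the fetch has been fetched).
  std::unordered_map<std::string, bool> pending_inputs;
  std::unordered_map<std::string, bool> pending_outputs;

  // Mirrors RunState::PendingDone(): done only when every entry is marked true.
  bool PendingDone() const {
    for (const auto& it : pending_inputs) {
      if (!it.second) return false;
    }
    for (const auto& it : pending_outputs) {
      if (!it.second) return false;
    }
    return true;
  }
};

// Returns one of the two distinct error messages, or marks the feed as
// consumed and returns an empty string on success.
std::string FeedOnce(PendingState* state, const std::string& feed) {
  auto it = state->pending_inputs.find(feed);
  if (it == state->pending_inputs.end()) {
    return "The feed " + feed + " was not specified in partial_run_setup.";
  }
  if (it->second) {
    return "The feed " + feed + " has already been fed.";
  }
  it->second = true;
  return "";
}

int main() {
  PendingState state;
  state.pending_inputs = {{"a", false}, {"b", false}};
  state.pending_outputs = {{"r1", false}};

  std::cout << FeedOnce(&state, "a") << "\n";  // success: empty line
  std::cout << FeedOnce(&state, "a") << "\n";  // "has already been fed."
  std::cout << FeedOnce(&state, "c") << "\n";  // "was not specified in partial_run_setup."
  std::cout << std::boolalpha << state.PendingDone() << "\n";  // false: b and r1 still pending
  return 0;
}

The actual DirectSession and MasterSession code below applies the same check to fetches via pending_outputs, and the Python tests added to session_test.py assert on the two distinct message suffixes.
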
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 002e246b80..f208e4b78e 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -720,16 +720,21 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
if (it == run_state->pending_inputs.end()) {
return errors::InvalidArgument(
"The feed ", input.first,
- " has already been fed or was not specified in partial_run_setup.");
+ " was not specified in partial_run_setup.");
+ } else if (it->second) {
+ return errors::InvalidArgument("The feed ", input.first,
+ " has already been fed.");
}
}
// Check that this is a new set of fetches that are still pending.
for (const auto& output : output_names) {
auto it = run_state->pending_outputs.find(output);
if (it == run_state->pending_outputs.end()) {
+ return errors::InvalidArgument(
+ "The fetch ", output, " was not specified in partial_run_setup.");
+ } else if (it->second) {
return errors::InvalidArgument("The fetch ", output,
- " has already been fetched or was not "
- "specified in partial_run_setup.");
+ " has already been fetched.");
}
}
}
@@ -764,14 +769,15 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
<< run_state->status;
}
}
- for (const auto& it : inputs) {
- run_state->pending_inputs.erase(it.first);
+ for (const auto& input : inputs) {
+ auto it = run_state->pending_inputs.find(input.first);
+ it->second = true;
}
for (const auto& name : output_names) {
- run_state->pending_outputs.erase(name);
+ auto it = run_state->pending_outputs.find(name);
+ it->second = true;
}
- done = (run_state->pending_inputs.size() == 0 &&
- run_state->pending_outputs.size() == 0);
+ done = run_state->PendingDone();
}
if (done) {
WaitForNotification(run_state, cancellation_manager_,
@@ -900,11 +906,13 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
{
mutex_lock l(executor_lock_);
- for (const string& feed : run_state->pending_inputs) {
- TensorId id(ParseTensorName(feed));
+ for (const auto& input : run_state->pending_inputs) {
+ // Skip if the feed has already been fed.
+ if (input.second) continue;
+ TensorId id(ParseTensorName(input.first));
auto it = name_to_node->find(id.first);
if (it == name_to_node->end()) {
- return errors::NotFound("Feed ", feed, ": not found");
+ return errors::NotFound("Feed ", input.first, ": not found");
}
pending_feeds.insert(id);
}
@@ -1351,10 +1359,10 @@ DirectSession::RunState::RunState(
}) {
// Initially all the feeds and fetches are pending.
for (auto& name : pending_input_names) {
- pending_inputs.emplace(name);
+ pending_inputs[name] = false;
}
for (auto& name : pending_output_names) {
- pending_outputs.emplace(name);
+ pending_outputs[name] = false;
}
}
@@ -1372,6 +1380,16 @@ DirectSession::RunState::~RunState() {
}
}
+bool DirectSession::RunState::PendingDone() const {
+ for (const auto& it : pending_inputs) {
+ if (!it.second) return false;
+ }
+ for (const auto& it : pending_outputs) {
+ if (!it.second) return false;
+ }
+ return true;
+}
+
void DirectSession::WaitForNotification(RunState* run_state,
CancellationManager* cm,
int64 timeout_in_ms) {
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 848ef3bc62..061a7fa787 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -151,8 +151,8 @@ class DirectSession : public Session {
IntraProcessRendezvous* rendez = nullptr;
std::unique_ptr<StepStatsCollector> collector;
Notification executors_done;
- std::unordered_set<string> pending_inputs;
- std::unordered_set<string> pending_outputs;
+ std::unordered_map<string, bool> pending_inputs; // true if fed
+ std::unordered_map<string, bool> pending_outputs; // true if fetched
TensorStore tensor_store;
ScopedStepContainer step_container;
@@ -162,6 +162,9 @@ class DirectSession : public Session {
const std::vector<string>& pending_output_names, int64 step_id,
const std::vector<Device*>* devices);
+ // Returns true if all pending inputs and outputs have been completed.
+ bool PendingDone() const;
+
~RunState();
};
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index cec956ba49..73d4e6ab00 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -803,11 +803,13 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
SimpleGraphExecutionState* execution_state) {
// Build the set of pending feeds that we haven't seen.
std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
- for (const string& feed : run_state->pending_inputs) {
- TensorId id(ParseTensorName(feed));
+ for (const auto& input : run_state->pending_inputs) {
+ // Skip if already fed.
+ if (input.second) continue;
+ TensorId id(ParseTensorName(input.first));
auto it = name_to_node_.find(id.first);
if (it == name_to_node_.end()) {
- return errors::NotFound("Feed ", feed, ": not found");
+ return errors::NotFound("Feed ", input.first, ": not found");
}
pending_feeds.insert(id);
}
@@ -1247,11 +1249,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
// Make sure that this is a new set of feeds that are still pending.
for (size_t i = 0; i < req.num_feeds(); ++i) {
- auto it = run_state->pending_inputs.find(req.feed_name(i));
+ const string& feed = req.feed_name(i);
+ auto it = run_state->pending_inputs.find(feed);
if (it == run_state->pending_inputs.end()) {
return errors::InvalidArgument(
- "The feed ", req.feed_name(i),
- " has already been fed or was not specified in partial_run_setup.");
+ "The feed ", feed, " was not specified in partial_run_setup.");
+ } else if (it->second) {
+ return errors::InvalidArgument("The feed ", feed,
+ " has already been fed.");
}
}
// Check that this is a new set of fetches that are still pending.
@@ -1259,9 +1264,11 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
const string& fetch = req.fetch_name(i);
auto it = run_state->pending_outputs.find(fetch);
if (it == run_state->pending_outputs.end()) {
+ return errors::InvalidArgument(
+ "The fetch ", fetch, " was not specified in partial_run_setup.");
+ } else if (it->second) {
return errors::InvalidArgument("The fetch ", fetch,
- " had already been fetched or was not "
- "specified in partial_run_setup.");
+ " has already been fetched.");
}
}
@@ -1274,13 +1281,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
// Determine if this partial run satisfies all the pending inputs and outputs.
for (size_t i = 0; i < req.num_feeds(); ++i) {
- run_state->pending_inputs.erase(req.feed_name(i));
+ auto it = run_state->pending_inputs.find(req.feed_name(i));
+ it->second = true;
}
for (size_t i = 0; i < req.num_fetches(); ++i) {
- run_state->pending_outputs.erase(req.fetch_name(i));
+ auto it = run_state->pending_outputs.find(req.fetch_name(i));
+ it->second = true;
}
- bool is_last_partial_run =
- (run_state->pending_inputs.empty() && run_state->pending_outputs.empty());
+ bool is_last_partial_run = run_state->PendingDone();
Status s = run_state->rcg->RunPartitions(
env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
@@ -1418,10 +1426,10 @@ MasterSession::RunState::RunState(const std::vector<string>& input_names,
: rcg(rcg), step_id(step_id), count(count) {
// Initially all the feeds and fetches are pending.
for (auto& name : input_names) {
- pending_inputs.emplace(name);
+ pending_inputs[name] = false;
}
for (auto& name : output_names) {
- pending_outputs.emplace(name);
+ pending_outputs[name] = false;
}
}
@@ -1429,4 +1437,14 @@ MasterSession::RunState::~RunState() {
if (rcg) rcg->Unref();
}
+bool MasterSession::RunState::PendingDone() const {
+ for (const auto& it : pending_inputs) {
+ if (!it.second) return false;
+ }
+ for (const auto& it : pending_outputs) {
+ if (!it.second) return false;
+ }
+ return true;
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index c593f84f03..ee1a340c8e 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -141,8 +141,8 @@ class MasterSession : public core::RefCounted {
};
struct RunState {
- std::unordered_set<string> pending_inputs;
- std::unordered_set<string> pending_outputs;
+ std::unordered_map<string, bool> pending_inputs; // true if fed
+ std::unordered_map<string, bool> pending_outputs; // true if fetched
ReffedClientGraph* rcg = nullptr;
uint64 step_id;
int64 count = 0;
@@ -154,6 +154,8 @@ class MasterSession : public core::RefCounted {
const std::vector<string>& output_names, ReffedClientGraph* rcg,
const uint64 step_id, const int64 count);
+ bool PendingDone() const;
+
~RunState();
};
std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index e6f1c57c7d..9add5bd3cd 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -1431,6 +1431,55 @@ class SessionTest(test_util.TensorFlowTestCase):
'You must feed a value for placeholder'):
sess.partial_run(handle, fetches[0])
+ def runTestPartialRunUnspecifiedFeed(self, sess):
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+
+ h = sess.partial_run_setup([r1], [a, b])
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ 'was not specified in partial_run_setup.$'):
+ sess.partial_run(h, r1, feed_dict={a: 1, b: 2, c: 3})
+
+ def runTestPartialRunUnspecifiedFetch(self, sess):
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ r2 = math_ops.multiply(a, c)
+
+ h = sess.partial_run_setup([r1], [a, b, c])
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ 'was not specified in partial_run_setup.$'):
+ sess.partial_run(h, r2, feed_dict={a: 1, c: 3})
+
+ def runTestPartialRunAlreadyFed(self, sess):
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ r2 = math_ops.multiply(a, c)
+
+ h = sess.partial_run_setup([r1, r2], [a, b, c])
+ sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ 'has already been fed.$'):
+ sess.partial_run(h, r2, feed_dict={a: 1, c: 3})
+
+ def runTestPartialRunAlreadyFetched(self, sess):
+ a = array_ops.placeholder(dtypes.float32, shape=[])
+ b = array_ops.placeholder(dtypes.float32, shape=[])
+ c = array_ops.placeholder(dtypes.float32, shape=[])
+ r1 = math_ops.add(a, b)
+ r2 = math_ops.multiply(a, c)
+
+ h = sess.partial_run_setup([r1, r2], [a, b, c])
+ sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ 'has already been fetched.$'):
+ sess.partial_run(h, r1, feed_dict={c: 3})
+
def testInvalidPartialRunSetup(self):
sess = session.Session()
x = array_ops.placeholder(dtypes.float32, shape=[])
@@ -1457,6 +1506,18 @@ class SessionTest(test_util.TensorFlowTestCase):
def testPartialRunMissingPlaceholderFeedExceptionDirect(self):
self.runTestPartialRunMissingPlaceholderFeedException(session.Session())
+ def testPartialRunUnspecifiedFeedDirect(self):
+ self.runTestPartialRunUnspecifiedFeed(session.Session())
+
+ def testPartialRunUnspecifiedFetchDirect(self):
+ self.runTestPartialRunUnspecifiedFetch(session.Session())
+
+ def testPartialRunAlreadyFedDirect(self):
+ self.runTestPartialRunAlreadyFed(session.Session())
+
+ def testPartialRunAlreadyFetchedDirect(self):
+ self.runTestPartialRunAlreadyFetched(session.Session())
+
def testPartialRunDist(self):
server = server_lib.Server.create_local_server()
self.runTestPartialRun(session.Session(server.target))
@@ -1482,6 +1543,22 @@ class SessionTest(test_util.TensorFlowTestCase):
self.runTestPartialRunMissingPlaceholderFeedException(
session.Session(server.target))
+ def testPartialRunUnspecifiedFeedDist(self):
+ server = server_lib.Server.create_local_server()
+ self.runTestPartialRunUnspecifiedFeed(session.Session(server.target))
+
+ def testPartialRunUnspecifiedFetchDist(self):
+ server = server_lib.Server.create_local_server()
+ self.runTestPartialRunUnspecifiedFetch(session.Session(server.target))
+
+ def testPartialRunAlreadyFedDist(self):
+ server = server_lib.Server.create_local_server()
+ self.runTestPartialRunAlreadyFed(session.Session(server.target))
+
+ def testPartialRunAlreadyFetchedDist(self):
+ server = server_lib.Server.create_local_server()
+ self.runTestPartialRunAlreadyFetched(session.Session(server.target))
+
def testFeedDictKeyException(self):
with session.Session() as sess:
a = constant_op.constant(1.0, dtypes.float32, name='a')