aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/master_session.cc
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-04-05 22:37:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-05 22:40:16 -0700
commitc2d6faafc48b251faa24a342dc063d9fa624421e (patch)
tree38cca5a204916b78772a8f6ef6aaf66c41d82cb2 /tensorflow/core/distributed_runtime/master_session.cc
parent1b4f2c51b668dbc1952cabdaf61773b7cff2a0c3 (diff)
Fix StringPiece use-after-free in MasterSession::ReffedClientGraph.
Use the owned ClientGraph as the source for the node_to_name_ map, rather than the borrowed GraphExecutionState (which can be deleted while the ReffedClientGraph is in use). PiperOrigin-RevId: 191847023
Diffstat (limited to 'tensorflow/core/distributed_runtime/master_session.cc')
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc24
1 files changed, 11 insertions, 13 deletions
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 01da54fcb3..64adf35c5e 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -66,8 +66,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
std::unique_ptr<ClientGraph> cg,
const SessionOptions& session_opts,
const StatsPublisherFactory& stats_publisher_factory,
- GraphExecutionState* execution_state, bool is_partial,
- WorkerCacheInterface* worker_cache, bool should_deregister)
+ bool is_partial, WorkerCacheInterface* worker_cache,
+ bool should_deregister)
: session_handle_(handle),
client_graph_(std::move(cg)),
session_opts_(session_opts),
@@ -80,8 +80,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
- // Initialize a name to node map for testing that fetches are reachable.
- for (Node* n : execution_state->full_graph()->nodes()) {
+ // Initialize a name to node map for processing device stats.
+ for (Node* n : client_graph_->graph.nodes()) {
name_to_node_.insert({n->name(), n});
}
}
@@ -829,8 +829,6 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats(
// TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
// on once at setup time to prevent us from computing the dependencies
// everytime.
-// TODO(suharshs,mrry): Consider removing the need for execution_state to reduce
-// contention.
Status MasterSession::ReffedClientGraph::CheckFetches(
const RunStepRequestWrapper& req, const RunState* run_state,
GraphExecutionState* execution_state) {
@@ -840,8 +838,8 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
// Skip if already fed.
if (input.second) continue;
TensorId id(ParseTensorName(input.first));
- const auto it = name_to_node_.find(id.first);
- if (it == name_to_node_.end()) {
+ const Node* n = execution_state->get_node_by_name(id.first.ToString());
+ if (n == nullptr) {
return errors::NotFound("Feed ", input.first, ": not found");
}
pending_feeds.insert(id);
@@ -856,11 +854,11 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
for (size_t i = 0; i < req.num_fetches(); ++i) {
const string& fetch = req.fetch_name(i);
const TensorId id(ParseTensorName(fetch));
- auto it = name_to_node_.find(id.first);
- if (it == name_to_node_.end()) {
+ const Node* n = execution_state->get_node_by_name(id.first.ToString());
+ if (n == nullptr) {
return errors::NotFound("Fetch ", fetch, ": not found");
}
- stack.push_back(it->second);
+ stack.push_back(n);
}
// Any tensor needed for fetches can't be in pending_feeds.
@@ -1293,8 +1291,8 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
WorkerCacheInterface* worker_cache = get_worker_cache();
auto entry = new ReffedClientGraph(
handle_, opts, std::move(client_graph), session_opts_,
- stats_publisher_factory_, execution_state_.get(), is_partial,
- worker_cache, !should_delete_worker_sessions_);
+ stats_publisher_factory_, is_partial, worker_cache,
+ !should_delete_worker_sessions_);
iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph";
}