aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc71
-rw-r--r--tensorflow/core/common_runtime/direct_session.h3
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc16
-rw-r--r--tensorflow/core/common_runtime/simple_graph_execution_state.cc111
-rw-r--r--tensorflow/core/common_runtime/simple_graph_execution_state.h75
-rw-r--r--tensorflow/core/distributed_runtime/master.cc2
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc20
-rw-r--r--tensorflow/core/distributed_runtime/master_session.h2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc34
9 files changed, 200 insertions, 134 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 52a810d76b..848a71c474 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -285,28 +285,47 @@ DirectSession::~DirectSession() {
flib_def_.reset(nullptr);
}
-void DirectSession::MaybeInitializeExecutionState(const GraphDef& graph) {
+Status DirectSession::MaybeInitializeExecutionState(
+ const GraphDef& graph, bool* out_already_initialized) {
// If already initialized, do nothing.
if (flib_def_ && execution_state_) {
- return;
+ *out_already_initialized = true;
+ return Status::OK();
}
// Set up the per-session execution state.
+ // NOTE(mrry): The function library created here will be used for
+ // all subsequent extensions of the graph.
flib_def_.reset(
new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
SimpleGraphExecutionStateOptions options;
options.device_set = &device_set_;
options.session_options = &options_;
- execution_state_.reset(
- new SimpleGraphExecutionState(graph.library(), options));
+ // TODO(mrry,suharshs): We explicitly copy `graph` so that
+ // `MakeForBaseGraph()` can take ownership of its
+ // contents. Previously this happened implicitly in calls to the
+ // `SimpleGraphExecutionState`. Other sessions call
+ // `MakeForBaseGraph` in such a way that we can destructively read
+ // the passed-in `GraphDef`. In principle we could do the same here,
+ // with a wider refactoring; we might revise the direct session so
+ // that it copies the graph fewer times.
+ GraphDef temp(graph);
+ TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
+ &temp, options, &execution_state_));
+ graph_created_ = true;
+ *out_already_initialized = false;
+ return Status::OK();
}
Status DirectSession::Create(const GraphDef& graph) {
- mutex_lock l(graph_def_lock_);
- if (graph_created_) {
- return errors::AlreadyExists(
- "A Graph has already been created for this session.");
+ if (graph.node_size() > 0) {
+ mutex_lock l(graph_def_lock_);
+ if (graph_created_) {
+ return errors::AlreadyExists(
+ "A Graph has already been created for this session.");
+ }
+ return ExtendLocked(graph);
}
- return ExtendLocked(graph);
+ return Status::OK();
}
Status DirectSession::Extend(const GraphDef& graph) {
@@ -316,12 +335,16 @@ Status DirectSession::Extend(const GraphDef& graph) {
}
Status DirectSession::ExtendLocked(const GraphDef& graph) {
- MaybeInitializeExecutionState(graph);
- std::unique_ptr<SimpleGraphExecutionState> state;
- TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
- execution_state_.swap(state);
-
- graph_created_ = true; // In case this is first call
+ bool already_initialized;
+ // If this is the first call, we can initialize the execution state
+ // with `graph` and do not need to call `Extend()`.
+ TF_RETURN_IF_ERROR(
+ MaybeInitializeExecutionState(graph, &already_initialized));
+ if (already_initialized) {
+ std::unique_ptr<SimpleGraphExecutionState> state;
+ TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
+ execution_state_.swap(state);
+ }
return Status::OK();
}
@@ -949,7 +972,7 @@ Status DirectSession::GetOrCreateExecutors(
}
Status DirectSession::CreateGraphs(
- const BuildGraphOptions& options,
+ const BuildGraphOptions& subgraph_options,
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args) {
@@ -960,23 +983,23 @@ Status DirectSession::CreateGraphs(
SimpleGraphExecutionState* execution_state = nullptr;
if (options_.config.graph_options().place_pruned_graph()) {
// Because we are placing pruned graphs, we need to create a
- // new SimpleGraphExecutorState for every new unseen graph,
+ // new SimpleGraphExecutionState for every new unseen graph,
// and then place it.
SimpleGraphExecutionStateOptions prune_options;
prune_options.device_set = &device_set_;
prune_options.session_options = &options_;
- temp_exec_state_holder.reset(new SimpleGraphExecutionState(
- execution_state_->original_graph_def().library(), prune_options));
- temp_exec_state_holder->SetStatefulPlacements(stateful_placements_);
-
- TF_RETURN_IF_ERROR(temp_exec_state_holder->Extend(
- execution_state_->original_graph_def(), &temp_exec_state_holder));
+ prune_options.stateful_placements = stateful_placements_;
+ TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForPrunedGraph(
+ execution_state_->original_graph_def().library(), prune_options,
+ execution_state_->original_graph_def(), subgraph_options,
+ &temp_exec_state_holder, &client_graph));
execution_state = temp_exec_state_holder.get();
} else {
execution_state = execution_state_.get();
+ TF_RETURN_IF_ERROR(
+ execution_state->BuildGraph(subgraph_options, &client_graph));
}
- TF_RETURN_IF_ERROR(execution_state->BuildGraph(options, &client_graph));
auto current_stateful_placements = execution_state->GetStatefulPlacements();
// Update our current state based on the execution_state's
// placements. If there are any mismatches for a node,
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 8e757e9273..a428911253 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -164,7 +164,8 @@ class DirectSession : public Session {
// Initializes the base execution state given the 'graph',
// if not already initialized.
- void MaybeInitializeExecutionState(const GraphDef& graph)
+ Status MaybeInitializeExecutionState(const GraphDef& graph,
+ bool* out_already_initialized)
EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
// Retrieves an already existing set of executors to run 'inputs' and
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 4eb48c7bcf..38dd627da0 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -244,13 +244,8 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
(*options.config.mutable_device_count())["CPU"] = 2;
std::unique_ptr<Session> session(NewSession(options));
ASSERT_TRUE(session != nullptr);
- TF_ASSERT_OK(session->Create(def));
- std::vector<std::pair<string, Tensor>> inputs;
- std::vector<string> output_names = {y->name() + ":0"};
- std::vector<Tensor> outputs;
-
// Should return an error.
- ASSERT_FALSE(session->Run(inputs, output_names, {}, &outputs).ok());
+ ASSERT_FALSE(session->Create(def).ok());
// Fix placement and run again
def.Clear();
@@ -258,7 +253,8 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
test::graph::ToGraphDef(&graph, &def);
session.reset(NewSession(options));
TF_ASSERT_OK(session->Create(def));
- TF_ASSERT_OK(session->Run(inputs, output_names, {}, &outputs));
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(session->Run({}, {y->name() + ":0"}, {}, &outputs));
}
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) {
@@ -454,12 +450,10 @@ TEST(DirectSessionTest, PlacePrunedGraph) {
test::graph::ToGraphDef(&g, &def);
// By default, we place the entire graph, so we should fail the
- // call to Run, even if we don't run the bad op.
+ // call to Create.
SessionOptions options;
std::unique_ptr<Session> sess(NewSession(options));
- TF_ASSERT_OK(sess->Create(def));
- std::vector<Tensor> outputs;
- auto s = sess->Run({}, {x->name() + ":0"}, {}, &outputs);
+ auto s = sess->Create(def);
EXPECT_TRUE(errors::IsInvalidArgument(s));
}
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
index 365120dd0e..2be12e7ad8 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
@@ -40,33 +40,68 @@ limitations under the License.
namespace tensorflow {
SimpleGraphExecutionState::SimpleGraphExecutionState(
- const FunctionDefLibrary& func_def_lib,
- const SimpleGraphExecutionStateOptions& options)
- : device_set_(options.device_set),
+ GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options)
+ : stateful_placements_(options.stateful_placements),
+ device_set_(options.device_set),
session_options_(options.session_options),
costs_(true /*is_global*/),
- flib_def_(
- new FunctionLibraryDefinition(OpRegistry::Global(), func_def_lib)),
+ flib_def_(new FunctionLibraryDefinition(OpRegistry::Global(),
+ graph_def->library())),
graph_(nullptr) {
+ // NOTE(mrry): GraphDef does not have a move constructor, so we pass
+ // a non-const pointer and use `Swap()` to transfer the contents
+ // without copying.
+ original_graph_def_.Swap(graph_def);
// TODO(mrry): Publish placement visualizations or handle the log
// placement option.
}
SimpleGraphExecutionState::~SimpleGraphExecutionState() {
- mutex_lock l(mu_);
node_name_to_cost_id_map_.clear();
delete graph_;
}
-Status SimpleGraphExecutionState::Create(GraphDef* graph_def) {
- if (original_graph_def_.node_size() > 0) {
- return errors::InvalidArgument(
- "Cannot call Create on SimpleGraphExecutionState twice");
+/* static */ Status SimpleGraphExecutionState::MakeForBaseGraph(
+ GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options,
+ std::unique_ptr<SimpleGraphExecutionState>* out_state) {
+ std::unique_ptr<SimpleGraphExecutionState> ret(
+ new SimpleGraphExecutionState(graph_def, options));
+
+ TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&ret->original_graph_def_,
+ *ret->flib_def_.get(), 0));
+ // TODO(mrry): Refactor InitBaseGraph() so that we don't have to
+ // pass an empty BuildGraphOptions (that isn't going to be used when
+ // place_pruned_graph is false).
+ if (!ret->session_options_->config.graph_options().place_pruned_graph()) {
+ TF_RETURN_IF_ERROR(ret->InitBaseGraph(BuildGraphOptions()));
}
+ *out_state = std::move(ret);
+ return Status::OK();
+}
- original_graph_def_.Swap(graph_def);
- VLOG(2) << "Incoming def: " << ProtoDebugString(original_graph_def_);
- return AddDefaultAttrsToGraphDef(&original_graph_def_, *flib_def_.get(), 0);
+/* static */ Status SimpleGraphExecutionState::MakeForPrunedGraph(
+ const FunctionDefLibrary& func_def_lib,
+ const SimpleGraphExecutionStateOptions& options, const GraphDef& graph_def,
+ const BuildGraphOptions& subgraph_options,
+ std::unique_ptr<SimpleGraphExecutionState>* out_state,
+ std::unique_ptr<SimpleClientGraph>* out_client_graph) {
+ DCHECK(options.session_options->config.graph_options().place_pruned_graph());
+ // NOTE(mrry): This makes a copy of `graph_def`, which is
+ // regrettable. We could make `GraphDef` objects sharable between
+ // execution states to optimize pruned graph execution, but since
+ // this case is primarily used for interactive sessions, we make the
+ // bet that graph construction is not performance-critical. (Note
+ // also that the previous version used `Extend()`, which is strictly
+ // more expensive than copying a `GraphDef`.)
+ GraphDef temp(graph_def);
+ std::unique_ptr<SimpleGraphExecutionState> ret(
+ new SimpleGraphExecutionState(&temp, options));
+ TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&ret->original_graph_def_,
+ *ret->flib_def_.get(), 0));
+ TF_RETURN_IF_ERROR(ret->InitBaseGraph(subgraph_options));
+ TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph));
+ *out_state = std::move(ret);
+ return Status::OK();
}
Status SimpleGraphExecutionState::Extend(
@@ -129,6 +164,12 @@ Status SimpleGraphExecutionState::Extend(
gdef.mutable_versions()->CopyFrom(extension_def.versions());
}
+ // 4. Copy the function library from this execution state.
+ // NOTE(mrry): To match the previous behavior, the first GraphDef
+ // passed to a session will contain the function library that is
+ // used for all subsequent execution states.
+ *gdef.mutable_library() = flib_def_->ToProto();
+
// 5. Validate that the final graphdef is valid.
if (gdef.versions().producer() >= 5) {
// Validate the graph: we assume that merging two valid graphs
@@ -140,11 +181,21 @@ Status SimpleGraphExecutionState::Extend(
SimpleGraphExecutionStateOptions combined_options;
combined_options.device_set = device_set_;
combined_options.session_options = session_options_;
+ combined_options.stateful_placements = stateful_placements_;
+ // NOTE(mrry): `gdef` is no longer valid after the constructor
+ // executes.
std::unique_ptr<SimpleGraphExecutionState> new_execution_state(
- new SimpleGraphExecutionState(flib_def_->ToProto(), combined_options));
- TF_RETURN_IF_ERROR(new_execution_state->Create(&gdef));
- new_execution_state->SetStatefulPlacements(GetStatefulPlacements());
+ new SimpleGraphExecutionState(&gdef, combined_options));
+
+ TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
+ &new_execution_state->original_graph_def_, *flib_def_.get(), 0));
+ if (!session_options_->config.graph_options().place_pruned_graph()) {
+ // TODO(mrry): Refactor InitBaseGraph() so that we don't have to
+ // pass an empty BuildGraphOptions (that isn't going to be used
+ // when place_pruned_graph is false).
+ TF_RETURN_IF_ERROR(new_execution_state->InitBaseGraph(BuildGraphOptions()));
+ }
*out = std::move(new_execution_state);
// TODO(mrry): This is likely to be used for non-throughput-sensitive
@@ -196,8 +247,11 @@ Status SimpleGraphExecutionState::InitBaseGraph(
RestoreStatefulNodes(new_graph.get());
CostModel costs(true /*is_global*/);
- costs_.InitFromGraph(*new_graph.get());
- costs.MergeFromGlobal(costs_);
+ {
+ mutex_lock l(mu_);
+ costs_.InitFromGraph(*new_graph.get());
+ costs.MergeFromGlobal(costs_);
+ }
GraphOptimizationPassOptions optimization_options;
optimization_options.session_options = session_options_;
@@ -230,30 +284,9 @@ void SimpleGraphExecutionState::MergeCostsFromGlobal(CostModel* costs) {
costs->MergeFromGlobal(costs_);
}
-Status SimpleGraphExecutionState::GlobalNodeDefByName(const string& name,
- NodeDef* out) {
- NodeNameToCostIdMap::const_iterator iter =
- node_name_to_cost_id_map_.find(name);
- if (iter != node_name_to_cost_id_map_.end()) {
- mutex_lock l(mu_); // could use reader lock
- const Node* node = graph_->FindNodeId(iter->second);
- if (node) {
- *out = node->def();
- return Status::OK();
- }
- }
- return errors::NotFound("Node name: ", name);
-}
-
Status SimpleGraphExecutionState::BuildGraph(
const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) {
VLOG(1) << "BuildGraph";
- mutex_lock l(mu_);
- // Lazily initialize the base graph.
- if (!graph_) {
- TF_RETURN_IF_ERROR(InitBaseGraph(options));
- }
-
std::unique_ptr<Graph> ng(new Graph(flib_def_.get()));
CopyGraph(*graph_, ng.get());
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/simple_graph_execution_state.h
index 5fe16f0f42..2a33d9e298 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.h
+++ b/tensorflow/core/common_runtime/simple_graph_execution_state.h
@@ -42,6 +42,9 @@ class Timeline;
struct SimpleGraphExecutionStateOptions {
const DeviceSet* device_set = nullptr;
const SessionOptions* session_options = nullptr;
+ // A map from node name to device name, representing the unchangeable
+ // placement of stateful nodes.
+ std::unordered_map<string, string> stateful_placements;
};
// A SimpleClientGraph is simply a sub-graph of the full graph as induced by
@@ -82,15 +85,29 @@ struct SimpleClientGraph {
class SimpleGraphExecutionState {
public:
- SimpleGraphExecutionState(const FunctionDefLibrary& func_def_lib,
- const SimpleGraphExecutionStateOptions& options);
-
virtual ~SimpleGraphExecutionState();
- // Initializes the SimpleGraphExecutionState with 'graph_def'. Can only be
- // called once on an original SimpleGraphExecutionState. Callee may modify
- // 'graph_def'.
- Status Create(GraphDef* graph_def);
+ // Creates a new `SimpleGraphExecutionState` for the given
+ // `graph_def`, which represents the entire graph for a session.
+ //
+ // N.B. This method uses `GraphDef::Swap()` and leaves `graph_def`
+ // in an undefined state. If it is necessary to use `*graph_def`
+ // after this call, make an explicit copy of the graph before
+ // calling this method.
+ static Status MakeForBaseGraph(
+ GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options,
+ std::unique_ptr<SimpleGraphExecutionState>* out_state);
+
+ // Creates a new `SimpleGraphExecutionState` and `SimpleClientGraph`
+ // for the subgraph of `original_graph_def` defined by
+ // `subgraph_options`.
+ static Status MakeForPrunedGraph(
+ const FunctionDefLibrary& func_def_lib,
+ const SimpleGraphExecutionStateOptions& options,
+ const GraphDef& original_graph_def,
+ const BuildGraphOptions& subgraph_options,
+ std::unique_ptr<SimpleGraphExecutionState>* out_state,
+ std::unique_ptr<SimpleClientGraph>* out_client_graph);
// Creates a new SimpleGraphExecutionState representing the
// concatenation of this graph, and the graph defined by
@@ -100,6 +117,9 @@ class SimpleGraphExecutionState {
// If successful, returns OK and the caller takes ownership of "*out".
// Otherwise returns an error and does not modify "*out".
//
+ // After calling `old_state->Extend()`, `old_state` may no longer be
+ // used.
+ //
// NOTE(mrry): This method respects the placement of stateful nodes in
// in *this, but currently does not transfer any other placement
// or cost model information to the new graph.
@@ -113,12 +133,6 @@ class SimpleGraphExecutionState {
Status BuildGraph(const BuildGraphOptions& options,
std::unique_ptr<SimpleClientGraph>* out);
- // Returns OK if the named node is found in the placed full graph owned
- // by this execution_state, and sets *out to the NodeDef for that node.
- // It may not exist if name is of a Node added for a particular subgraph
- // execution, e.g. a send, recv or feed node.
- Status GlobalNodeDefByName(const string& name, NodeDef* out);
-
// Sums execution statistics in "ss" into the CostModel.
void UpdateCostsFromStats(const StepStats& ss);
@@ -138,10 +152,20 @@ class SimpleGraphExecutionState {
// The graph returned by BuildGraph may contain only the pruned
// graph, whereas some clients may want access to the full graph.
const Graph* full_graph() {
- mutex_lock l(mu_);
return graph_;
}
+ // Returns the node with the given name, or null if it does not exist.
+ const Node* get_node_by_name(const string& name) const {
+ NodeNameToCostIdMap::const_iterator iter =
+ node_name_to_cost_id_map_.find(name);
+ if (iter != node_name_to_cost_id_map_.end()) {
+ return graph_->FindNodeId(iter->second);
+ } else {
+ return nullptr;
+ }
+ }
+
// Returns a reference to the current graph_def. Use must
// not extend beyond lifetime of SimpleGrahExecutionState object.
const GraphDef& original_graph_def() { return original_graph_def_; }
@@ -153,31 +177,26 @@ class SimpleGraphExecutionState {
return stateful_placements_;
}
- // Restores the map of stateful placements as a map of
- // node name to placement string.
- void SetStatefulPlacements(const std::unordered_map<string, string>& sp) {
- mutex_lock l(mu_);
- stateful_placements_ = sp;
- }
-
private:
- mutable mutex mu_;
+ SimpleGraphExecutionState(GraphDef* graph_def,
+ const SimpleGraphExecutionStateOptions& options);
- Status InitBaseGraph(const BuildGraphOptions& options)
- EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ Status InitBaseGraph(const BuildGraphOptions& options);
// Map of placed stateful nodes, i.e. nodes for which is_stateful()
// is true, such as "params" and "queue" nodes. Once placed these
// nodes can not be moved to a different device. Maps node names to
// device names.
- std::unordered_map<string, string> stateful_placements_ GUARDED_BY(mu_);
- void SaveStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
- void RestoreStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ std::unordered_map<string, string> stateful_placements_; // Immutable after
+ // ctor.
+ void SaveStatefulNodes(Graph* graph);
+ void RestoreStatefulNodes(Graph* graph);
GraphDef original_graph_def_; // Immutable after ctor.
const DeviceSet* device_set_; // Not owned
const SessionOptions* session_options_; // Not owned
+ mutable mutex mu_;
CostModel costs_ GUARDED_BY(mu_);
// Map from name to Node for the full graph in placed_.
@@ -188,7 +207,7 @@ class SimpleGraphExecutionState {
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
// The dataflow graph owned by this object.
- Graph* graph_ GUARDED_BY(mu_);
+ Graph* graph_;
TF_DISALLOW_COPY_AND_ASSIGN(SimpleGraphExecutionState);
};
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index 8a6e1322ac..cf9deaabd8 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -256,6 +256,8 @@ void Master::CreateSession(const CreateSessionRequest* req,
const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
Status create_status = session->Create(gdef);
if (!create_status.ok()) {
+ // Takes ownership of `session` (and destroys it).
+ session->Close();
done(create_status);
return;
}
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 4a7b7c072a..29d2959426 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -180,7 +180,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
SimpleGraphExecutionState* execution_state,
ProfileHandler* ph, RunStepResponse* resp);
void ProcessDeviceStats(ProfileHandler* ph,
- SimpleGraphExecutionState* execution_state,
+ const SimpleGraphExecutionState* execution_state,
const DeviceStepStats& ds, bool is_rpc);
string DetailText(const NodeDef& def, const NodeExecStats& ns) {
@@ -724,7 +724,7 @@ void MasterSession::ReffedClientGraph::ProcessStats(
}
void MasterSession::ReffedClientGraph::ProcessDeviceStats(
- ProfileHandler* ph, SimpleGraphExecutionState* execution_state,
+ ProfileHandler* ph, const SimpleGraphExecutionState* execution_state,
const DeviceStepStats& ds, bool is_rpc) {
const string& dev_name = ds.device();
VLOG(1) << "Device " << dev_name << " reports stats for "
@@ -736,9 +736,8 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats(
ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
ns.timeline_label());
} else {
- NodeDef ndef;
- Status s = execution_state->GlobalNodeDefByName(ns.node_name(), &ndef);
- const bool found_node_in_graph = s.ok();
+ const Node* node = execution_state->get_node_by_name(ns.node_name());
+ const bool found_node_in_graph = node != nullptr;
if (!found_node_in_graph && ns.timeline_label().empty()) {
// The counter incrementing is not thread-safe. But we don't really
// care.
@@ -752,12 +751,13 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats(
}
continue;
}
- string optype = found_node_in_graph ? ndef.op() : ns.node_name();
+ string optype =
+ found_node_in_graph ? node->type_string() : ns.node_name();
string details;
if (!ns.timeline_label().empty()) {
details = ns.timeline_label();
} else if (found_node_in_graph) {
- details = DetailText(ndef, ns);
+ details = DetailText(node->def(), ns);
} else {
// Leave details string empty
}
@@ -892,10 +892,8 @@ Status MasterSession::Create(GraphDef* graph_def) {
SimpleGraphExecutionStateOptions options;
options.device_set = &devices_;
options.session_options = &session_opts_;
- execution_state_.reset(
- new SimpleGraphExecutionState(graph_def->library(), options));
- TF_RETURN_IF_ERROR(execution_state_->Create(graph_def));
-
+ TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
+ graph_def, options, &execution_state_));
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index 7f99fd2981..e17614c819 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -50,7 +50,7 @@ class MasterSession {
// Initialize the MasterSession for "def". Must be called before Extend(),
// Run(), or Close().
//
- // The callee may clear "def".
+ // After this method returns, `def` will no longer be valid.
Status Create(GraphDef* def);
// Returns the session handle.
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index 86a09551fd..e5d28f8450 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -179,10 +179,8 @@ TEST(GrpcSessionTest, NonLocalWithFilters) {
{
GraphDef graph_copy(graph);
graph::SetDefaultDevice(cluster->devices()[1].name(), &graph_copy);
- TF_CHECK_OK(session->Create(graph_copy));
- auto status = session->Run({}, {}, {node_names[2]}, nullptr);
+ auto status = session->Create(graph_copy);
EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, status.code());
- TF_CHECK_OK(session->Close());
}
}
@@ -415,25 +413,23 @@ TEST(GrpcSessionTest, MultiDevices_String) {
SetDevice(&def, a->name(), a_dev.name());
SetDevice(&def, b->name(), b_dev.name());
- TF_CHECK_OK(session->Create(def));
- {
+ Status s = session->Create(def);
+ if (s.ok()) {
std::vector<Tensor> outputs;
- Status s = session->Run({}, {b->name()}, {}, &outputs);
- if (s.ok()) {
- ASSERT_EQ(1, outputs.size());
- ASSERT_EQ(outputs[0].dtype(), DT_STRING);
- ASSERT_EQ(outputs[0].NumElements(), 4);
- for (int i = 0; i < outputs[0].NumElements(); ++i) {
- EXPECT_EQ(outputs[0].flat<string>()(i), "hello, world");
- }
- } else {
- LOG(ERROR) << "Error: " << s;
- ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) ||
- (b_dev.device_type() == DEVICE_GPU));
- ASSERT_FALSE(s.ok());
+ TF_CHECK_OK(session->Run({}, {b->name()}, {}, &outputs));
+ ASSERT_EQ(1, outputs.size());
+ ASSERT_EQ(outputs[0].dtype(), DT_STRING);
+ ASSERT_EQ(outputs[0].NumElements(), 4);
+ for (int i = 0; i < outputs[0].NumElements(); ++i) {
+ EXPECT_EQ(outputs[0].flat<string>()(i), "hello, world");
}
+ TF_CHECK_OK(session->Close());
+ } else {
+ LOG(ERROR) << "Error: " << s;
+ ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) ||
+ (b_dev.device_type() == DEVICE_GPU));
+ ASSERT_FALSE(s.ok());
}
- TF_CHECK_OK(session->Close());
}
}
}