-rw-r--r--  tensorflow/core/BUILD                                            1
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc               197
-rw-r--r--  tensorflow/core/common_runtime/direct_session.h                 17
-rw-r--r--  tensorflow/core/common_runtime/direct_session_test.cc           50
-rw-r--r--  tensorflow/core/common_runtime/simple_graph_execution_state.cc  93
-rw-r--r--  tensorflow/core/common_runtime/simple_graph_execution_state.h   36
-rw-r--r--  tensorflow/core/common_runtime/simple_placer.cc                  2
-rw-r--r--  tensorflow/core/protobuf/config.proto                            9
-rw-r--r--  tensorflow/python/client/session.py                              6
-rw-r--r--  tensorflow/python/client/session_test.py                        42
10 files changed, 323 insertions(+), 130 deletions(-)
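
In brief: this change moves graph extension, version merging, and placement out of DirectSession and into SimpleGraphExecutionState, and adds a place_pruned_graph option so that only the subgraphs a step actually runs get placed. Condensed from the C++ tests below into a small helper (the function name is invented here; the API calls are the ones the tests use), the client-side switch looks like this:

#include <memory>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

// Sketch: a session that only places the subgraphs it actually runs,
// so an unplaceable node elsewhere in the graph does not poison
// unrelated steps.
std::unique_ptr<tensorflow::Session> MakePrunedPlacementSession() {
  tensorflow::SessionOptions options;
  options.config.mutable_graph_options()->set_place_pruned_graph(true);
  return std::unique_ptr<tensorflow::Session>(tensorflow::NewSession(options));
}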
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 044d732f30..bc97c96416 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1365,6 +1365,7 @@ tf_cc_test(
":test_main",
":testlib",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:dense_update_ops",
"//tensorflow/core/kernels:fifo_queue_op",
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 0c3ddaf2f7..5f81e8c38c 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -41,7 +41,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -178,6 +177,23 @@ DirectSession::~DirectSession() {
if (options_.config.use_per_session_threads()) {
delete thread_pool_;
}
+
+ execution_state_.reset(nullptr);
+ flib_def_.reset(nullptr);
+}
+
+void DirectSession::MaybeInitializeExecutionState(const GraphDef& graph) {
+ // If already initialized, do nothing.
+ if (flib_def_ && execution_state_) {
+ return;
+ }
+ // Set up the per-session execution state.
+ flib_def_.reset(new FunctionLibraryDefinition(graph.library()));
+ SimpleGraphExecutionStateOptions options;
+ options.device_set = &device_set_;
+ options.session_options = &options_;
+ execution_state_.reset(
+ new SimpleGraphExecutionState(flib_def_.get(), options));
}
Status DirectSession::Create(const GraphDef& graph) {
@@ -195,57 +211,14 @@ Status DirectSession::Extend(const GraphDef& graph) {
}
Status DirectSession::ExtendLocked(const GraphDef& graph) {
- // Merge versions
- if (graph_def_.has_versions()) {
- if (graph_def_.versions().producer() != graph.versions().producer()) {
- return errors::InvalidArgument(
- "Can't extend GraphDef at version ", graph_def_.versions().producer(),
- " with graph at version ", graph.versions().producer());
- }
- VersionDef* versions = graph_def_.mutable_versions();
- versions->set_min_consumer(
- std::max(versions->min_consumer(), graph.versions().min_consumer()));
- if (graph.versions().bad_consumers_size()) {
- // Add new bad_consumers that aren't already marked bad.
- //
- // Note: This implementation is quadratic time if there are many calls to
- // ExtendLocked with many bad consumers. Since this is unlikely, and
- // fixing it would require data structures outside of this routine,
- // quadratic time it is.
- auto* bad_consumers = versions->mutable_bad_consumers();
- const std::unordered_set<int> existing(bad_consumers->begin(),
- bad_consumers->end());
- for (const int v : graph.versions().bad_consumers()) {
- if (existing.find(v) == existing.end()) {
- bad_consumers->Add(v);
- }
- }
- }
- } else {
- graph_def_.mutable_versions()->CopyFrom(graph.versions());
- }
-
- const int node_size_before_merge = graph_def_.node_size();
- graph_def_.MergeFrom(graph);
+ MaybeInitializeExecutionState(graph);
+ std::unique_ptr<SimpleGraphExecutionState> old_state;
+ SimpleGraphExecutionState* new_state = nullptr;
+ TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &new_state));
- FunctionLibraryDefinition fdefs(graph_def_.library());
- // Add default attributes to all new nodes in the graph.
- Status s =
- AddDefaultAttrsToGraphDef(&graph_def_, fdefs, node_size_before_merge);
- if (!s.ok()) {
- // One of the nodes was invalid, return the state of graph_def_
- // to what it was before this function.
- const int nodes_added = graph_def_.node_size() - node_size_before_merge;
- graph_def_.mutable_node()->DeleteSubrange(node_size_before_merge,
- nodes_added);
- return s;
- }
-
- if (graph_def_.versions().producer() >= 5) {
- // Validate the graph: we assume that merging two valid graphs
- // should maintain graph validity.
- TF_RETURN_IF_ERROR(graph::ValidateGraphDef(graph_def_, fdefs));
- }
+ // Swap out the old state.
+ old_state = std::move(execution_state_);
+ execution_state_.reset(new_state);
graph_created_ = true; // In case this is the first call.
return Status::OK();
@@ -680,16 +653,19 @@ Status DirectSession::GetOrCreateExecutors(
}
}
+ BuildGraphOptions options;
+ options.feed_endpoints = inputs_sorted;
+ options.fetch_endpoints = outputs_sorted;
+ options.target_nodes = tn_sorted;
+
// The executor_lock_ is intentionally released while executor is
// being created.
- FunctionLibraryDefinition* fdefs;
std::unordered_map<string, Graph*> graphs;
- Status s = CreateGraphs(inputs, outputs, target_nodes, &fdefs, &graphs,
- run_state_args);
+ Status s = CreateGraphs(options, &graphs, run_state_args);
TF_RETURN_IF_ERROR(s);
std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
- ek->func_defs = fdefs;
+ ek->func_defs = flib_def_.get();
if (run_state_args->is_partial_run) {
ek->graph = run_state_args->graph;
ek->name_to_node = new NameNodeMap;
@@ -724,9 +700,9 @@ Status DirectSession::GetOrCreateExecutors(
ek->items.resize(ek->items.size() + 1);
auto* item = &(ek->items.back());
- item->flib =
- NewFunctionLibraryRuntime(device_mgr_.get(), device, runner,
- graph_def_version, fdefs, optimizer_opts);
+ item->flib = NewFunctionLibraryRuntime(device_mgr_.get(), device, runner,
+ graph_def_version, flib_def_.get(),
+ optimizer_opts);
LocalExecutorParams params;
params.device = device;
@@ -802,67 +778,67 @@ Status DirectSession::GetOrCreateExecutors(
return Status::OK();
}
-void DirectSession::SaveStatefulNodes(Graph* graph) {
- for (Node* n : graph->nodes()) {
- if (n->op_def().is_stateful()) {
- VLOG(2) << "Saving " << n->DebugString();
- stateful_placements_[n->name()] = n->assigned_device_name();
+Status DirectSession::CreateGraphs(const BuildGraphOptions& options,
+ std::unordered_map<string, Graph*>* outputs,
+ RunStateArgs* run_state_args) {
+ std::unique_ptr<SimpleClientGraph> client_graph;
+ SimpleClientGraph* cgraph = nullptr;
+
+ std::unique_ptr<SimpleGraphExecutionState> temp_exec_state_holder;
+ SimpleGraphExecutionState* execution_state = nullptr;
+ if (options_.config.graph_options().place_pruned_graph()) {
+ // Because we are placing pruned graphs, we need to create a
+ // new SimpleGraphExecutionState for every new unseen graph,
+ // and then place it.
+ SimpleGraphExecutionStateOptions prune_options;
+ prune_options.device_set = &device_set_;
+ prune_options.session_options = &options_;
+ temp_exec_state_holder.reset(
+ new SimpleGraphExecutionState(flib_def_.get(), prune_options));
+ {
+ mutex_lock l(mu_);
+ temp_exec_state_holder->SetStatefulPlacements(stateful_placements_);
}
+
+ TF_RETURN_IF_ERROR(temp_exec_state_holder->Extend(
+ execution_state_->original_graph_def(), &execution_state));
+ temp_exec_state_holder.reset(execution_state);
+ } else {
+ execution_state = execution_state_.get();
}
-}
-void DirectSession::RestoreStatefulNodes(Graph* graph) {
- for (Node* n : graph->nodes()) {
- if (n->op_def().is_stateful()) {
- auto iter = stateful_placements_.find(n->name());
- if (iter != stateful_placements_.end()) {
- n->set_assigned_device_name(iter->second);
- VLOG(2) << "Restored " << n->DebugString();
+ TF_RETURN_IF_ERROR(execution_state->BuildGraph(options, &cgraph));
+ {
+ auto current_stateful_placements = execution_state->GetStatefulPlacements();
+ mutex_lock l(mu_);
+ // Update our current state based on the execution_state's
+ // placements. A mismatch for an already-placed stateful node is
+ // an internal error: such nodes must never move.
+ for (auto placement_pair : current_stateful_placements) {
+ const string& node_name = placement_pair.first;
+ const string& placement = placement_pair.second;
+ auto iter = stateful_placements_.find(node_name);
+ if (iter == stateful_placements_.end()) {
+ stateful_placements_.insert(std::make_pair(node_name, placement));
+ } else if (iter->second != placement) {
+ return errors::Internal(
+ "Stateful placement mismatch. "
+ "Current assignment of ",
+ node_name, " to ", iter->second, " does not match ", placement);
}
}
- }
-}
-Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
- gtl::ArraySlice<string> fetches,
- gtl::ArraySlice<string> target_nodes,
- FunctionLibraryDefinition** func_defs,
- std::unordered_map<string, Graph*>* outputs,
- RunStateArgs* run_state_args) {
- std::unique_ptr<FunctionLibraryDefinition> fdefs;
- std::unique_ptr<Graph> graph;
- {
- mutex_lock l(graph_def_lock_);
- fdefs.reset(new FunctionLibraryDefinition(graph_def_.library()));
- graph.reset(new Graph(fdefs.get()));
- GraphConstructorOptions opts;
- TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def_, graph.get()));
+ stateful_placements_ = execution_state->GetStatefulPlacements();
}
+ client_graph.reset(cgraph);
// Remember the graph in run state if this is a partial run.
if (run_state_args->is_partial_run) {
- run_state_args->graph = new Graph(fdefs.get());
- CopyGraph(*graph.get(), run_state_args->graph);
- }
-
- TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
- graph.get(), feeds, fetches, target_nodes,
- device_set_.client_device()->attributes()));
-
- // Run the simple placer after rewriting the graph.
- SimplePlacer placer(graph.get(), &device_set_, &options_);
-
- {
- mutex_lock l(mu_);
- // Restore stateful nodes.
- RestoreStatefulNodes(graph.get());
- TF_RETURN_IF_ERROR(placer.Run());
- // Save stateful nodes.
- SaveStatefulNodes(graph.get());
+ run_state_args->graph = new Graph(flib_def_.get());
+ CopyGraph(*execution_state->full_graph(), run_state_args->graph);
}
// Partition the graph across devices.
- std::unordered_map<string, GraphDef> partitions;
PartitionOptions popts;
popts.node_to_loc = [](const Node* node) {
return node->assigned_device_name();
@@ -877,7 +853,9 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
return 1;
};
popts.control_flow_added = false;
- TF_RETURN_IF_ERROR(Partition(popts, graph.get(), &partitions));
+
+ std::unordered_map<string, GraphDef> partitions;
+ TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
std::vector<string> device_names;
for (auto device : devices_) {
@@ -917,12 +895,12 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
// may be possible use cases where a device may want to modify
// function definitions - in which case the library would need to be
// replicated per device.
- s = d->MaybeRewriteGraph(graph_def_.library(), graph_def);
+ s = d->MaybeRewriteGraph(flib_def_->ToProto(), graph_def);
if (!s.ok()) {
break;
}
}
- Graph* device_graph = new Graph(fdefs.get());
+ Graph* device_graph = new Graph(flib_def_.get());
GraphConstructorOptions device_opts;
// There are internal operations (e.g., send/recv) that we now
// allow.
@@ -940,7 +918,6 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
gtl::STLDeleteValues(outputs);
return s;
}
- *func_defs = fdefs.release();
return Status::OK();
}
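
The structural shift above: ExtendLocked no longer merges GraphDefs in place with rollback on failure; it asks the current SimpleGraphExecutionState for a successor state and swaps the pointer only once Extend has succeeded. A minimal standalone sketch of that copy-on-extend pattern, with toy stand-in types rather than the real classes:

#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins for GraphDef and SimpleGraphExecutionState.
struct ToyGraphDef {
  std::vector<std::string> nodes;
};

class ToyExecutionState {
 public:
  explicit ToyExecutionState(ToyGraphDef gdef) : gdef_(std::move(gdef)) {}

  // Build a *new* state holding the merged graph; the receiver is left
  // untouched, so a failure partway through cannot corrupt the
  // session's current state.
  std::unique_ptr<ToyExecutionState> Extend(const ToyGraphDef& ext) const {
    ToyGraphDef merged = gdef_;
    merged.nodes.insert(merged.nodes.end(), ext.nodes.begin(),
                        ext.nodes.end());
    return std::unique_ptr<ToyExecutionState>(
        new ToyExecutionState(std::move(merged)));
  }

 private:
  ToyGraphDef gdef_;
};

int main() {
  auto state =
      std::unique_ptr<ToyExecutionState>(new ToyExecutionState({{"a", "b"}}));
  // Mirrors ExtendLocked: the old state is swapped out and released
  // only after Extend succeeds.
  auto new_state = state->Extend({{"c"}});
  state = std::move(new_state);
  return 0;
}

Note how the old DeleteSubrange rollback disappears: since the current state is never mutated, there is nothing to undo on error.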
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index d64460bd1f..6a531c4ec8 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/session_state.h"
@@ -119,7 +120,6 @@ class DirectSession : public Session {
delete item.executor;
delete item.flib;
}
- delete func_defs;
delete graph;
delete name_to_node;
}
@@ -159,6 +159,10 @@ class DirectSession : public Session {
Graph* graph = nullptr;
};
+ // Initializes the base execution state given the 'graph',
+ // if not already initialized.
+ void MaybeInitializeExecutionState(const GraphDef& graph);
+
// Retrieves an already existing set of executors to run 'inputs' and
// 'outputs', or creates and caches them for future use.
::tensorflow::Status GetOrCreateExecutors(
@@ -168,10 +172,7 @@ class DirectSession : public Session {
// Creates several graphs given the existing graph_def_ and the
// input feeds and fetches, partitioned across 'devices'.
- ::tensorflow::Status CreateGraphs(gtl::ArraySlice<string> feeds,
- gtl::ArraySlice<string> fetches,
- gtl::ArraySlice<string> target_nodes,
- FunctionLibraryDefinition** func_defs,
+ ::tensorflow::Status CreateGraphs(const BuildGraphOptions& options,
std::unordered_map<string, Graph*>* outputs,
RunStateArgs* run_state_args);
@@ -239,14 +240,16 @@ class DirectSession : public Session {
// Guards the device placements saved for stateful nodes.
mutex mu_;
- void SaveStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
- void RestoreStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Map of placed stateful nodes, i.e. nodes for which is_stateful()
// is true, such as "params" and "queue" nodes. Once placed these
// nodes can not be moved to a different device. Maps node names to
// device names.
std::unordered_map<string, string> stateful_placements_ GUARDED_BY(mu_);
+ // Execution state; used when placing the entire graph.
+ std::unique_ptr<SimpleGraphExecutionState> execution_state_;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+
// For generating unique names.
int64 name_counter_ GUARDED_BY(mu_) = 0;
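
One subtlety in this header: execution_state_ is declared before flib_def_, and members are destroyed in reverse declaration order, so default destruction would free the function library while the execution state still holds a pointer into it. That is presumably why the destructor change at the top of direct_session.cc resets the two pointers explicitly, dependent object first. A standalone sketch of the hazard and the fix, with toy types:

#include <memory>

// Toy stand-ins: State keeps a raw pointer into Library, the way
// SimpleGraphExecutionState keeps one into FunctionLibraryDefinition.
struct Library {};
struct State {
  explicit State(Library* lib) : lib(lib) {}
  Library* lib;  // may still be dereferenced during teardown
};

class Owner {
 public:
  void Init() {
    lib_.reset(new Library);
    state_.reset(new State(lib_.get()));
  }

  ~Owner() {
    // Default member destruction runs in reverse declaration order,
    // which here would free lib_ while state_ still points into it.
    // Resetting explicitly pins a safe order, as
    // DirectSession::~DirectSession now does.
    state_.reset(nullptr);
    lib_.reset(nullptr);
  }

 private:
  // Same relative order as in direct_session.h: the dependent object
  // is declared first, the library it points into second.
  std::unique_ptr<State> state_;
  std::unique_ptr<Library> lib_;
};

int main() {
  Owner o;
  o.Init();
  return 0;
}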
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 75a1235f0b..7347f255d8 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -240,7 +240,9 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
test::graph::ToGraphDef(&graph, &def);
- std::unique_ptr<Session> session(CreateSession());
+ SessionOptions options;
+ (*options.config.mutable_device_count())["CPU"] = 2;
+ std::unique_ptr<Session> session(NewSession(options));
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
std::vector<std::pair<string, Tensor>> inputs;
@@ -254,7 +256,7 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
def.Clear();
y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
test::graph::ToGraphDef(&graph, &def);
- session.reset(CreateSession());
+ session.reset(NewSession(options));
TF_ASSERT_OK(session->Create(def));
TF_ASSERT_OK(session->Run(inputs, output_names, {}, &outputs));
}
@@ -431,6 +433,50 @@ TEST(DirectSessionTest, DarthKernel) {
delete sess;
}
+// Have the Darth op in the graph placed on GPU, but don't run it.
+TEST(DirectSessionTest, PlacePrunedGraph) {
+ {
+ Graph g(OpRegistry::Global());
+ Tensor vx(DT_FLOAT, TensorShape({}));
+ vx.scalar<float>()() = 1.0;
+ Node* x = test::graph::Constant(&g, vx);
+ Node* y = test::graph::Unary(&g, "Darth", x);
+ y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0");
+ GraphDef def;
+ test::graph::ToGraphDef(&g, &def);
+
+ // By default, we place the entire graph, so the call to Run
+ // should fail even though we never run the bad op.
+ SessionOptions options;
+ std::unique_ptr<Session> sess(NewSession(options));
+ TF_ASSERT_OK(sess->Create(def));
+ std::vector<Tensor> outputs;
+ auto s = sess->Run({}, {x->name() + ":0"}, {}, &outputs);
+ EXPECT_TRUE(errors::IsInvalidArgument(s));
+ }
+
+ {
+ Graph g(OpRegistry::Global());
+ Tensor vx(DT_FLOAT, TensorShape({}));
+ vx.scalar<float>()() = 1.0;
+ Node* x = test::graph::Constant(&g, vx);
+ Node* y = test::graph::Unary(&g, "Darth", x);
+ y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0");
+ GraphDef def;
+ test::graph::ToGraphDef(&g, &def);
+
+ SessionOptions options;
+ // Set the option to place pruned graphs; this time we expect
+ // the Run call to succeed.
+ options.config.mutable_graph_options()->set_place_pruned_graph(true);
+ std::unique_ptr<Session> sess(NewSession(options));
+ TF_ASSERT_OK(sess->Create(def));
+ std::vector<Tensor> outputs;
+ auto s = sess->Run({}, {x->name() + ":0"}, {}, &outputs);
+ TF_EXPECT_OK(s);
+ }
+}
+
TEST(DirectSessionTest, PartialRunTest) {
GraphDef def;
Graph g(OpRegistry::Global());
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
index 8fae2634de..6ba32814ac 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
@@ -87,13 +88,52 @@ Status SimpleGraphExecutionState::Extend(
}
}
+ // 3. Add the new nodes.
int old_node_size = gdef.node_size();
gdef.mutable_node()->MergeFrom(extension_def.node());
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&gdef, *ops_, old_node_size));
+ // 4. Merge the versions field.
+ if (gdef.has_versions()) {
+ if (gdef.versions().producer() != extension_def.versions().producer()) {
+ return errors::InvalidArgument(
+ "Can't extend GraphDef at version ", gdef.versions().producer(),
+ " with graph at version ", extension_def.versions().producer());
+ }
+ VersionDef* versions = gdef.mutable_versions();
+ versions->set_min_consumer(std::max(
+ versions->min_consumer(), extension_def.versions().min_consumer()));
+ if (extension_def.versions().bad_consumers_size()) {
+ // Add new bad_consumers that aren't already marked bad.
+ //
+ // Note: This implementation is quadratic time if there are many calls to
+ // Extend with many bad consumers. Since this is unlikely, and
+ // fixing it would require data structures outside of this routine,
+ // quadratic time it is.
+ auto* bad_consumers = versions->mutable_bad_consumers();
+ const std::unordered_set<int> existing(bad_consumers->begin(),
+ bad_consumers->end());
+ for (const int v : extension_def.versions().bad_consumers()) {
+ if (existing.find(v) == existing.end()) {
+ bad_consumers->Add(v);
+ }
+ }
+ }
- // 3. Add the extension.
+ } else {
+ gdef.mutable_versions()->CopyFrom(extension_def.versions());
+ }
+
+ // 5. Validate the resulting GraphDef.
+ if (gdef.versions().producer() >= 5) {
+ // Validate the graph: we assume that merging two valid graphs
+ // should maintain graph validity.
+ TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *ops_));
+ }
+
+ // 6. Add the extension.
SimpleGraphExecutionStateOptions combined_options;
combined_options.device_set = device_set_;
+ combined_options.session_options = session_options_;
SimpleGraphExecutionState* new_execution_state =
new SimpleGraphExecutionState(ops_, combined_options);
@@ -102,6 +142,7 @@ Status SimpleGraphExecutionState::Extend(
delete new_execution_state;
return new_execution_state_status;
}
+ new_execution_state->SetStatefulPlacements(GetStatefulPlacements());
*out = new_execution_state;
// TODO(mrry): This is likely to be used for non-throughput-sensitive
@@ -110,14 +151,47 @@ Status SimpleGraphExecutionState::Extend(
return Status::OK();
}
-Status SimpleGraphExecutionState::InitBaseGraph() {
+void SimpleGraphExecutionState::SaveStatefulNodes(Graph* graph) {
+ for (Node* n : graph->nodes()) {
+ if (n->op_def().is_stateful()) {
+ VLOG(2) << "Saving " << n->DebugString();
+ stateful_placements_[n->name()] = n->assigned_device_name();
+ }
+ }
+}
+
+void SimpleGraphExecutionState::RestoreStatefulNodes(Graph* graph) {
+ for (Node* n : graph->nodes()) {
+ if (n->op_def().is_stateful()) {
+ auto iter = stateful_placements_.find(n->name());
+ if (iter != stateful_placements_.end()) {
+ n->set_assigned_device_name(iter->second);
+ VLOG(2) << "Restored " << n->DebugString();
+ }
+ }
+ }
+}
+
+Status SimpleGraphExecutionState::InitBaseGraph(
+ const BuildGraphOptions& options) {
std::unique_ptr<Graph> new_graph(new Graph(ops_));
GraphConstructorOptions opts;
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(opts, original_graph_def_, new_graph.get()));
+ if (session_options_ &&
+ session_options_->config.graph_options().place_pruned_graph()) {
+ // Rewrite the graph before placement.
+ TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
+ new_graph.get(), options.feed_endpoints, options.fetch_endpoints,
+ options.target_nodes, device_set_->client_device()->attributes()));
+ }
+
+ // Restore saved stateful placements before placing; save them after.
+ RestoreStatefulNodes(new_graph.get());
SimplePlacer placer(new_graph.get(), device_set_, session_options_);
// TODO(mrry): Consider making the SimplePlacer cancelable.
TF_RETURN_IF_ERROR(placer.Run());
+ SaveStatefulNodes(new_graph.get());
graph_ = new_graph.release();
return Status::OK();
}
@@ -128,17 +202,20 @@ Status SimpleGraphExecutionState::BuildGraph(const BuildGraphOptions& options,
mutex_lock l(mu_);
// Lazily initialize the base graph.
if (!graph_) {
- TF_RETURN_IF_ERROR(InitBaseGraph());
+ TF_RETURN_IF_ERROR(InitBaseGraph(options));
}
std::unique_ptr<SimpleClientGraph> cgraph(new SimpleClientGraph(ops_));
CopyGraph(*graph_, &cgraph->graph);
- // Extract the subset of the graph that needs to be run, adding feed/fetch
- // ops as needed.
- TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
- &cgraph->graph, options.feed_endpoints, options.fetch_endpoints,
- options.target_nodes, device_set_->client_device()->attributes()));
+ if (session_options_ == nullptr ||
+ !session_options_->config.graph_options().place_pruned_graph()) {
+ // Extract the subset of the graph that needs to be run, adding feed/fetch
+ // ops as needed.
+ TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
+ &cgraph->graph, options.feed_endpoints, options.fetch_endpoints,
+ options.target_nodes, device_set_->client_device()->attributes()));
+ }
// Copy the extracted graph in order to make its node ids dense,
// since the local CostModel used to record its stats is sized by
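
The two hunks above carry the behavioral switch of the whole change: with place_pruned_graph set, InitBaseGraph prunes via RewriteGraphForExecution before the SimplePlacer runs, and BuildGraph skips its own rewrite; in the default mode the full graph is placed once and each BuildGraph call prunes a copy. A condensed standalone sketch of that control flow, with toy Prune and Place functions standing in for the real passes:

#include <iostream>
#include <string>

// Toy stand-ins for RewriteGraphForExecution and SimplePlacer::Run.
void Prune(std::string* graph) { *graph += " -> pruned"; }
void Place(std::string* graph) { *graph += " -> placed"; }

// The only difference between the two modes is whether pruning happens
// before or after placement.
std::string BuildClientGraph(bool place_pruned_graph) {
  std::string graph = "full";
  if (place_pruned_graph) {
    Prune(&graph);  // unreachable (possibly unplaceable) nodes drop out
  }
  Place(&graph);    // the placer only ever sees what is left
  if (!place_pruned_graph) {
    Prune(&graph);  // default mode: prune a copy of the placed graph
  }
  return graph;
}

int main() {
  std::cout << BuildClientGraph(false) << "\n";  // full -> placed -> pruned
  std::cout << BuildClientGraph(true) << "\n";   // full -> pruned -> placed
}

This ordering is exactly what the PlacePrunedGraph test relies on: a node with unsatisfiable constraints is pruned away before the placer can reject it.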
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/simple_graph_execution_state.h
index bf3a85b5a2..02357c8037 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.h
+++ b/tensorflow/core/common_runtime/simple_graph_execution_state.h
@@ -114,10 +114,44 @@ class SimpleGraphExecutionState {
// execution, e.g. a send, recv or feed node.
Status GlobalNodeDefByName(const string& name, NodeDef* out);
+ // The graph returned by BuildGraph may contain only the pruned
+ // graph, whereas some clients may want access to the full graph.
+ const Graph* full_graph() {
+ mutex_lock l(mu_);
+ return graph_;
+ }
+
+ // Returns a reference to the current graph_def. The reference must
+ // not outlive this SimpleGraphExecutionState object.
+ const GraphDef& original_graph_def() { return original_graph_def_; }
+
+ // Returns the map of stateful placements as a map of
+ // node name to placement string.
+ std::unordered_map<string, string> GetStatefulPlacements() const {
+ mutex_lock l(mu_);
+ return stateful_placements_;
+ }
+
+ // Sets the map of stateful placements, given as a map of
+ // node name to placement string.
+ void SetStatefulPlacements(const std::unordered_map<string, string>& sp) {
+ mutex_lock l(mu_);
+ stateful_placements_ = sp;
+ }
+
private:
mutable mutex mu_;
- Status InitBaseGraph() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ Status InitBaseGraph(const BuildGraphOptions& options)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Map of placed stateful nodes, i.e. nodes for which is_stateful()
+ // is true, such as "params" and "queue" nodes. Once placed these
+ // nodes can not be moved to a different device. Maps node names to
+ // device names.
+ std::unordered_map<string, string> stateful_placements_ GUARDED_BY(mu_);
+ void SaveStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ void RestoreStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
const OpRegistryInterface* const ops_; // Not owned
GraphDef original_graph_def_; // Immutable after ctor.
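
The stateful-placement bookkeeping added to this header follows two small patterns: the accessors copy the map in and out while holding the lock, so callers never observe a map that is being mutated, and DirectSession::CreateGraphs merges a returned copy back while failing on any conflicting assignment. A standalone sketch combining both (class and method names invented here):

#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>

using PlacementMap = std::unordered_map<std::string, std::string>;

class PlacementRegistry {
 public:
  // Copy out under the lock, like GetStatefulPlacements().
  PlacementMap Get() const {
    std::lock_guard<std::mutex> l(mu_);
    return placements_;
  }

  // Copy in under the lock, like SetStatefulPlacements().
  void Set(const PlacementMap& p) {
    std::lock_guard<std::mutex> l(mu_);
    placements_ = p;
  }

  // Merge, rejecting conflicts, mirroring the check in
  // DirectSession::CreateGraphs: a placed stateful node must not move.
  void Merge(const PlacementMap& updates) {
    std::lock_guard<std::mutex> l(mu_);
    for (const auto& kv : updates) {
      auto it = placements_.find(kv.first);
      if (it == placements_.end()) {
        placements_.insert(kv);
      } else if (it->second != kv.second) {
        throw std::runtime_error("stateful placement mismatch: " + kv.first);
      }
    }
  }

 private:
  mutable std::mutex mu_;
  PlacementMap placements_;
};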
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
index 19a44c6a98..c14c37a4c3 100644
--- a/tensorflow/core/common_runtime/simple_placer.cc
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -469,7 +469,6 @@ class ColocationGraph {
// If the NodeDef contains a device, then we interpret it as a
// (partial) device specification.
- string colocated_node_name;
if (!node.def().device().empty()) {
// The user has specified a device in the NodeDef, try to find a
// valid device matching their specification in the set of
@@ -659,7 +658,6 @@ Status SimplePlacer::Run() {
// and we can experiment with other algorithms when given a choice of
// devices.
node->set_assigned_device_name(devices[0]->name());
-
// Log placement if log_device_placement is set.
if (options_ && options_->config.log_device_placement()) {
printf("%s: %s\n", node->name().c_str(),
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 3bec167206..fe926a184d 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -80,6 +80,15 @@ message GraphOptions {
// Annotate each Node with Op output shape data, to the extent it can
// be statically inferred.
bool infer_shapes = 5;
+
+ // Only place the subgraphs that are run, rather than the entire graph.
+ //
+ // This is useful for interactive graph building, where one might
+ // produce graphs that cannot be placed during the debugging
+ // process. In particular, it allows the client to continue work in
+ // a session after adding a node to a graph whose placement
+ // constraints are unsatisfiable.
+ bool place_pruned_graph = 6;
};
// Session configuration parameters.
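
The new field defaults to false, so existing sessions keep the place-everything behavior until a client opts in. Through the generated C++ bindings (a sketch; both accessor spellings already appear elsewhere in this diff), the write and read sides look like:

#include "tensorflow/core/protobuf/config.pb.h"

int main() {
  tensorflow::ConfigProto config;
  // Write side, as the new DirectSession test does via SessionOptions.
  config.mutable_graph_options()->set_place_pruned_graph(true);
  // Read side, as SimpleGraphExecutionState does when deciding whether
  // to prune before placement.
  const bool pruned = config.graph_options().place_pruned_graph();
  return pruned ? 0 : 1;
}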
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index b0c82b8874..fbc48cdbba 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -24,6 +24,7 @@ import threading
import numpy as np
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -897,6 +898,11 @@ class InteractiveSession(BaseSession):
graph: (Optional.) The `Graph` to be launched (described above).
config: (Optional) `ConfigProto` proto used to configure the session.
"""
+ if not config:
+ config = config_pb2.ConfigProto()
+ # Interactive sessions always place pruned graphs.
+ config.graph_options.place_pruned_graph = True
+
super(InteractiveSession, self).__init__(target, graph, config)
self._default_session = self.as_default()
self._default_session.__enter__()
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index d60883464d..09698d2c78 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -707,6 +707,48 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[24.0]], e.eval())
sess.close()
+ def testInteractivePlacePrunedGraph(self):
+ sess = session.InteractiveSession()
+
+ # Build a graph that has a bad op in it (no kernel).
+ #
+ # This test currently does not link in any GPU kernels,
+ # which is why placing this is invalid. If at some point
+ # GPU kernels are added to this test, some other different
+ # op / device combo should be chosen.
+ with ops.device('/gpu:0'):
+ a = constant_op.constant(1.0, shape=[1, 2])
+
+ b = constant_op.constant(1.0, shape=[1, 2])
+
+ # Only run the valid op; this should work.
+ b.eval()
+
+ with self.assertRaises(errors.InvalidArgumentError):
+ a.eval()
+ sess.close()
+
+ def testDefaultSessionPlacePrunedGraph(self):
+ sess = session.Session()
+
+ # Build a graph that has a bad op in it (no kernel).
+ #
+ # This test currently does not link in any GPU kernels,
+ # which is why placing this is invalid. If at some point
+ # GPU kernels are added to this test, some other different
+ # op / device combo should be chosen.
+ with ops.device('/gpu:0'):
+ _ = constant_op.constant(1.0, shape=[1, 2])
+
+ b = constant_op.constant(1.0, shape=[1, 2])
+
+ with self.assertRaises(errors.InvalidArgumentError):
+ # Even though we don't run the bad op, we place the entire
+ # graph, which should fail with a non-interactive session.
+ sess.run(b)
+
+ sess.close()
+
def testSharedGraph(self):
with ops.Graph().as_default() as g, ops.device('/cpu:0'):
a = constant_op.constant(1.0, shape=[1, 2])