aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2017-09-07 22:11:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-07 22:15:42 -0700
commitb5c1d0f8977e0f05c9aeeb9e5105500bf83972bb (patch)
treee98407a18e1f7e8f3719ba58557be6be3c4deb59
parentd27db78cd0168f10b308f7508c11dfaa3c6707e9 (diff)
SimpleGraphExecutionState -> GraphExecutionState
SimplePlacer -> Placer And clean up a couple unneeded headers. PiperOrigin-RevId: 167955883
-rw-r--r--tensorflow/core/BUILD10
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc23
-rw-r--r--tensorflow/core/common_runtime/direct_session.h4
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.cc (renamed from tensorflow/core/common_runtime/simple_graph_execution_state.cc)67
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.h (renamed from tensorflow/core/common_runtime/simple_graph_execution_state.h)65
-rw-r--r--tensorflow/core/common_runtime/placer.cc (renamed from tensorflow/core/common_runtime/simple_placer.cc)25
-rw-r--r--tensorflow/core/common_runtime/placer.h (renamed from tensorflow/core/common_runtime/simple_placer.h)23
-rw-r--r--tensorflow/core/common_runtime/placer_test.cc (renamed from tensorflow/core/common_runtime/simple_placer_test.cc)106
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc22
-rw-r--r--tensorflow/core/distributed_runtime/master_session.h4
10 files changed, 171 insertions, 178 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 9319928307..4b80d2c543 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1773,12 +1773,14 @@ tf_cuda_library(
"common_runtime/device_set.cc",
"common_runtime/executor.cc",
"common_runtime/function.cc",
+ "common_runtime/graph_execution_state.cc",
"common_runtime/graph_optimizer.cc",
"common_runtime/graph_runner.cc",
"common_runtime/local_device.cc",
"common_runtime/memory_types.cc",
"common_runtime/optimization_registry.cc",
"common_runtime/parallel_concat_optimizer.cc",
+ "common_runtime/placer.cc",
"common_runtime/process_function_library_runtime.cc",
"common_runtime/process_util.cc",
"common_runtime/renamed_device.cc",
@@ -1788,8 +1790,6 @@ tf_cuda_library(
"common_runtime/session_factory.cc",
"common_runtime/session_options.cc",
"common_runtime/session_state.cc",
- "common_runtime/simple_graph_execution_state.cc",
- "common_runtime/simple_placer.cc",
"common_runtime/stats_publisher_interface.cc",
"common_runtime/step_stats_collector.cc",
"common_runtime/threadpool_device.cc",
@@ -1829,8 +1829,8 @@ tf_cuda_library(
"common_runtime/renamed_device.h",
"common_runtime/rendezvous_mgr.h",
"common_runtime/session_factory.h",
- "common_runtime/simple_graph_execution_state.h",
- "common_runtime/simple_placer.h",
+ "common_runtime/graph_execution_state.h",
+ "common_runtime/placer.h",
"common_runtime/stats_publisher_interface.h",
"common_runtime/step_stats_collector.h",
"common_runtime/threadpool_device.h",
@@ -2321,9 +2321,9 @@ tf_cc_tests(
"common_runtime/device_set_test.cc",
"common_runtime/optimization_registry_test.cc",
"common_runtime/pending_counts_test.cc",
+ "common_runtime/placer_test.cc",
"common_runtime/resource_variable_read_optimizer_test.cc",
"common_runtime/session_test.cc",
- "common_runtime/simple_placer_test.cc",
"example/feature_util_test.cc",
"framework/allocator_test.cc",
"framework/attr_value_util_test.cc",
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index a6630f38a5..d28857bb9f 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/common_runtime/simple_placer.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb_text.h"
@@ -345,20 +344,20 @@ Status DirectSession::MaybeInitializeExecutionState(
// all subsequent extensions of the graph.
flib_def_.reset(
new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
- SimpleGraphExecutionStateOptions options;
+ GraphExecutionStateOptions options;
options.device_set = &device_set_;
options.session_options = &options_;
// TODO(mrry,suharshs): We explicitly copy `graph` so that
// `MakeForBaseGraph()` can take ownership of its
// contents. Previously this happened implicitly in calls to the
- // `SimpleGraphExecutionState`. Other sessions call
+ // `GraphExecutionState`. Other sessions call
// `MakeForBaseGraph` in such a way that we can destructively read
// the passed-in `GraphDef`. In principle we could do the same here,
// with a wider refactoring; we might revise the direct session so
// that it copies the graph fewer times.
GraphDef temp(graph);
- TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
- &temp, options, &execution_state_));
+ TF_RETURN_IF_ERROR(
+ GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_));
graph_created_ = true;
*out_already_initialized = false;
return Status::OK();
@@ -391,7 +390,7 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) {
MaybeInitializeExecutionState(graph, &already_initialized));
if (already_initialized) {
TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
- std::unique_ptr<SimpleGraphExecutionState> state;
+ std::unique_ptr<GraphExecutionState> state;
TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
execution_state_.swap(state);
}
@@ -1280,19 +1279,19 @@ Status DirectSession::CreateGraphs(
RunStateArgs* run_state_args, DataTypeVector* input_types,
DataTypeVector* output_types) {
mutex_lock l(graph_def_lock_);
- std::unique_ptr<SimpleClientGraph> client_graph;
+ std::unique_ptr<ClientGraph> client_graph;
- std::unique_ptr<SimpleGraphExecutionState> temp_exec_state_holder;
- SimpleGraphExecutionState* execution_state = nullptr;
+ std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
+ GraphExecutionState* execution_state = nullptr;
if (options_.config.graph_options().place_pruned_graph()) {
// Because we are placing pruned graphs, we need to create a
- // new SimpleGraphExecutionState for every new unseen graph,
+ // new GraphExecutionState for every new unseen graph,
// and then place it.
- SimpleGraphExecutionStateOptions prune_options;
+ GraphExecutionStateOptions prune_options;
prune_options.device_set = &device_set_;
prune_options.session_options = &options_;
prune_options.stateful_placements = stateful_placements_;
- TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForPrunedGraph(
+ TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
execution_state_->original_graph_def().library(), prune_options,
execution_state_->original_graph_def(), subgraph_options,
&temp_exec_state_holder, &client_graph));
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 020831d6cc..7fbabf6d81 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -28,10 +28,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/session_factory.h"
-#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/session_state.h"
@@ -310,7 +310,7 @@ class DirectSession : public Session {
GUARDED_BY(graph_def_lock_);
// Execution_state; used when placing the entire graph.
- std::unique_ptr<SimpleGraphExecutionState> execution_state_
+ std::unique_ptr<GraphExecutionState> execution_state_
GUARDED_BY(graph_def_lock_);
// The function library, before any rewrites or optimizations have been
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index c66dc568f6..4bd40c7978 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
+#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include <memory>
#include <string>
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/common_runtime/simple_placer.h"
+#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -49,8 +49,8 @@ limitations under the License.
namespace tensorflow {
-SimpleGraphExecutionState::SimpleGraphExecutionState(
- GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options)
+GraphExecutionState::GraphExecutionState(
+ GraphDef* graph_def, const GraphExecutionStateOptions& options)
: stateful_placements_(options.stateful_placements),
device_set_(options.device_set),
session_options_(options.session_options),
@@ -65,16 +65,16 @@ SimpleGraphExecutionState::SimpleGraphExecutionState(
// placement option.
}
-SimpleGraphExecutionState::~SimpleGraphExecutionState() {
+GraphExecutionState::~GraphExecutionState() {
node_name_to_cost_id_map_.clear();
delete graph_;
}
-/* static */ Status SimpleGraphExecutionState::MakeForBaseGraph(
- GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options,
- std::unique_ptr<SimpleGraphExecutionState>* out_state) {
- std::unique_ptr<SimpleGraphExecutionState> ret(
- new SimpleGraphExecutionState(graph_def, options));
+/* static */ Status GraphExecutionState::MakeForBaseGraph(
+ GraphDef* graph_def, const GraphExecutionStateOptions& options,
+ std::unique_ptr<GraphExecutionState>* out_state) {
+ std::unique_ptr<GraphExecutionState> ret(
+ new GraphExecutionState(graph_def, options));
TF_RETURN_IF_ERROR(
AddDefaultAttrsToGraphDef(&ret->original_graph_def_, *ret->flib_def_, 0));
@@ -88,12 +88,12 @@ SimpleGraphExecutionState::~SimpleGraphExecutionState() {
return Status::OK();
}
-/* static */ Status SimpleGraphExecutionState::MakeForPrunedGraph(
+/* static */ Status GraphExecutionState::MakeForPrunedGraph(
const FunctionDefLibrary& func_def_lib,
- const SimpleGraphExecutionStateOptions& options, const GraphDef& graph_def,
+ const GraphExecutionStateOptions& options, const GraphDef& graph_def,
const BuildGraphOptions& subgraph_options,
- std::unique_ptr<SimpleGraphExecutionState>* out_state,
- std::unique_ptr<SimpleClientGraph>* out_client_graph) {
+ std::unique_ptr<GraphExecutionState>* out_state,
+ std::unique_ptr<ClientGraph>* out_client_graph) {
DCHECK(options.session_options->config.graph_options().place_pruned_graph());
// NOTE(mrry): This makes a copy of `graph_def`, which is
// regrettable. We could make `GraphDef` objects sharable between
@@ -103,8 +103,8 @@ SimpleGraphExecutionState::~SimpleGraphExecutionState() {
// also that the previous version used `Extend()`, which is strictly
// more expensive than copying a `GraphDef`.)
GraphDef temp(graph_def);
- std::unique_ptr<SimpleGraphExecutionState> ret(
- new SimpleGraphExecutionState(&temp, options));
+ std::unique_ptr<GraphExecutionState> ret(
+ new GraphExecutionState(&temp, options));
TF_RETURN_IF_ERROR(
AddDefaultAttrsToGraphDef(&ret->original_graph_def_, *ret->flib_def_, 0));
TF_RETURN_IF_ERROR(ret->InitBaseGraph(subgraph_options));
@@ -113,9 +113,9 @@ SimpleGraphExecutionState::~SimpleGraphExecutionState() {
return Status::OK();
}
-Status SimpleGraphExecutionState::Extend(
+Status GraphExecutionState::Extend(
const GraphDef& extension_def,
- std::unique_ptr<SimpleGraphExecutionState>* out) const {
+ std::unique_ptr<GraphExecutionState>* out) const {
GraphDef gdef;
// 1. Copy the function library.
@@ -186,15 +186,15 @@ Status SimpleGraphExecutionState::Extend(
}
// 6. Add the extension.
- SimpleGraphExecutionStateOptions combined_options;
+ GraphExecutionStateOptions combined_options;
combined_options.device_set = device_set_;
combined_options.session_options = session_options_;
combined_options.stateful_placements = stateful_placements_;
// NOTE(mrry): `gdef` is no longer valid after the constructor
// executes.
- std::unique_ptr<SimpleGraphExecutionState> new_execution_state(
- new SimpleGraphExecutionState(&gdef, combined_options));
+ std::unique_ptr<GraphExecutionState> new_execution_state(
+ new GraphExecutionState(&gdef, combined_options));
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
&new_execution_state->original_graph_def_, *flib_def_, 0));
@@ -212,7 +212,7 @@ Status SimpleGraphExecutionState::Extend(
return Status::OK();
}
-void SimpleGraphExecutionState::SaveStatefulNodes(Graph* graph) {
+void GraphExecutionState::SaveStatefulNodes(Graph* graph) {
for (Node* n : graph->nodes()) {
if (n->op_def().is_stateful()) {
VLOG(2) << "Saving " << n->DebugString();
@@ -221,7 +221,7 @@ void SimpleGraphExecutionState::SaveStatefulNodes(Graph* graph) {
}
}
-void SimpleGraphExecutionState::RestoreStatefulNodes(Graph* graph) {
+void GraphExecutionState::RestoreStatefulNodes(Graph* graph) {
for (Node* n : graph->nodes()) {
if (n->op_def().is_stateful()) {
auto iter = stateful_placements_.find(n->name());
@@ -233,8 +233,7 @@ void SimpleGraphExecutionState::RestoreStatefulNodes(Graph* graph) {
}
}
-Status SimpleGraphExecutionState::InitBaseGraph(
- const BuildGraphOptions& options) {
+Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) {
const GraphDef* graph_def = &original_graph_def_;
std::unique_ptr<Graph> new_graph(new Graph(OpRegistry::Global()));
@@ -266,8 +265,8 @@ Status SimpleGraphExecutionState::InitBaseGraph(
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
- SimplePlacer placer(new_graph.get(), device_set_, session_options_);
- // TODO(mrry): Consider making the SimplePlacer cancelable.
+ Placer placer(new_graph.get(), device_set_, session_options_);
+ // TODO(mrry): Consider making the Placer cancelable.
TF_RETURN_IF_ERROR(placer.Run());
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
@@ -278,7 +277,7 @@ Status SimpleGraphExecutionState::InitBaseGraph(
return Status::OK();
}
-Status SimpleGraphExecutionState::OptimizeGraph(
+Status GraphExecutionState::OptimizeGraph(
const BuildGraphOptions& options, std::unique_ptr<Graph>* optimized_graph) {
#ifndef IS_MOBILE_PLATFORM
if (session_options_->config.graph_options().place_pruned_graph()) {
@@ -378,8 +377,8 @@ Status SimpleGraphExecutionState::OptimizeGraph(
#endif // IS_MOBILE_PLATFORM
}
-Status SimpleGraphExecutionState::BuildGraph(
- const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) {
+Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
+ std::unique_ptr<ClientGraph>* out) {
VLOG(1) << "BuildGraph";
if (!graph_) {
// It is only valid to call this method directly when the original graph
@@ -406,7 +405,7 @@ Status SimpleGraphExecutionState::BuildGraph(
options.target_nodes, device_set_->client_device()->attributes(),
options.use_function_convention, &rewrite_metadata));
} else {
- // This SimpleGraphExecutionState represents a graph that was
+ // This GraphExecutionState represents a graph that was
// pruned when this was constructed, so we copy the metadata from
// a member variable.
CHECK(rewrite_metadata_);
@@ -433,9 +432,9 @@ Status SimpleGraphExecutionState::BuildGraph(
// Copy the extracted graph in order to make its node ids dense,
// since the local CostModel used to record its stats is sized by
// the largest node id.
- std::unique_ptr<SimpleClientGraph> dense_copy(
- new SimpleClientGraph(std::move(flib), rewrite_metadata.feed_types,
- rewrite_metadata.fetch_types));
+ std::unique_ptr<ClientGraph> dense_copy(
+ new ClientGraph(std::move(flib), rewrite_metadata.feed_types,
+ rewrite_metadata.fetch_types));
CopyGraph(*ng, &dense_copy->graph);
// TODO(vrv): We should check invariants of the graph here.
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h
index 53eef8a07d..db2686ce2c 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.h
+++ b/tensorflow/core/common_runtime/graph_execution_state.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
#include <functional>
#include <memory>
@@ -38,7 +38,7 @@ namespace subgraph {
struct RewriteGraphMetadata;
}
-struct SimpleGraphExecutionStateOptions {
+struct GraphExecutionStateOptions {
const DeviceSet* device_set = nullptr;
const SessionOptions* session_options = nullptr;
// A map from node name to device name, representing the unchangeable
@@ -46,12 +46,11 @@ struct SimpleGraphExecutionStateOptions {
std::unordered_map<string, string> stateful_placements;
};
-// A SimpleClientGraph is simply a sub-graph of the full graph as induced by
+// A ClientGraph is simply a sub-graph of the full graph as induced by
// BuildGraphOptions.
-struct SimpleClientGraph {
- explicit SimpleClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
- DataTypeVector feed_types,
- DataTypeVector fetch_types)
+struct ClientGraph {
+ explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
+ DataTypeVector feed_types, DataTypeVector fetch_types)
: flib_def(std::move(flib)),
graph(flib_def.get()),
feed_types(std::move(feed_types)),
@@ -64,8 +63,8 @@ struct SimpleClientGraph {
DataTypeVector fetch_types;
};
-// SimpleGraphExecutionState is responsible for generating an
-// executable SimpleClientGraph from the original GraphDef that specifies
+// GraphExecutionState is responsible for generating an
+// executable ClientGraph from the original GraphDef that specifies
// the complete graph and from BuildGraphOptions which specifies
// input/output nodes.
//
@@ -73,7 +72,7 @@ struct SimpleClientGraph {
// meaning that each Node is assigned to a single Device in the
// available set.
//
-// When SimpleGraphExecutionState is first constructed it instantiates
+// When GraphExecutionState is first constructed it instantiates
// a full Graph from the provided GraphDef, and places it, using only
// the static device assignments from the GraphDef. Nodes without are
// currently placed in a very naive way. Since stateful Nodes cannot
@@ -81,18 +80,18 @@ struct SimpleClientGraph {
// Nodes get sensible initial device assignments in the graph
// definition.
//
-// Subsequently, SimpleGraphExecutionState generates a SimpleClientGraph on
+// Subsequently, GraphExecutionState generates a ClientGraph on
// demand, which is a sub-graph of the latest placement of the full
-// Graph. MasterSession uses such a SimpleClientGraph to execute one or
+// Graph. MasterSession uses such a ClientGraph to execute one or
// more similar client requests.
//
-// SimpleGraphExecutionState is thread-safe.
+// GraphExecutionState is thread-safe.
-class SimpleGraphExecutionState {
+class GraphExecutionState {
public:
- virtual ~SimpleGraphExecutionState();
+ virtual ~GraphExecutionState();
- // Creates a new `SimpleGraphExecutionState` for the given
+ // Creates a new `GraphExecutionState` for the given
// `graph_def`, which represents the entire graph for a session.
//
// N.B. This method uses `GraphDef::Swap()` and leaves `graph_def`
@@ -100,21 +99,21 @@ class SimpleGraphExecutionState {
// after this call, make an explicit copy of the graph before
// calling this method.
static Status MakeForBaseGraph(
- GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options,
- std::unique_ptr<SimpleGraphExecutionState>* out_state);
+ GraphDef* graph_def, const GraphExecutionStateOptions& options,
+ std::unique_ptr<GraphExecutionState>* out_state);
- // Creates a new `SimpleGraphExecutionState` and `SimpleClientGraph`
+ // Creates a new `GraphExecutionState` and `ClientGraph`
// for the subgraph of `original_graph_def` defined by
// `subgraph_options`.
static Status MakeForPrunedGraph(
const FunctionDefLibrary& func_def_lib,
- const SimpleGraphExecutionStateOptions& options,
+ const GraphExecutionStateOptions& options,
const GraphDef& original_graph_def,
const BuildGraphOptions& subgraph_options,
- std::unique_ptr<SimpleGraphExecutionState>* out_state,
- std::unique_ptr<SimpleClientGraph>* out_client_graph);
+ std::unique_ptr<GraphExecutionState>* out_state,
+ std::unique_ptr<ClientGraph>* out_client_graph);
- // Creates a new SimpleGraphExecutionState representing the
+ // Creates a new GraphExecutionState representing the
// concatenation of this graph, and the graph defined by
// "extension_def". The same name may not be used to define a node
// in both this graph and "extension_def".
@@ -129,14 +128,14 @@ class SimpleGraphExecutionState {
// in *this, but currently does not transfer any other placement
// or cost model information to the new graph.
Status Extend(const GraphDef& extension_def,
- std::unique_ptr<SimpleGraphExecutionState>* out) const;
+ std::unique_ptr<GraphExecutionState>* out) const;
- // Builds a SimpleClientGraph (a sub-graph of the full graph as induced by
+ // Builds a ClientGraph (a sub-graph of the full graph as induced by
// the Node set specified in "options"). If successful, returns OK
// and the caller takes the ownership of "*out". Otherwise, returns
// an error.
Status BuildGraph(const BuildGraphOptions& options,
- std::unique_ptr<SimpleClientGraph>* out);
+ std::unique_ptr<ClientGraph>* out);
// The graph returned by BuildGraph may contain only the pruned
// graph, whereas some clients may want access to the full graph.
@@ -156,7 +155,7 @@ class SimpleGraphExecutionState {
}
// Returns a reference to the current graph_def. Use must
- // not extend beyond lifetime of SimpleGrahExecutionState object.
+ // not extend beyond lifetime of GraphExecutionState object.
const GraphDef& original_graph_def() { return original_graph_def_; }
// Returns the map of stateful placements as a map of
@@ -166,8 +165,8 @@ class SimpleGraphExecutionState {
}
private:
- SimpleGraphExecutionState(GraphDef* graph_def,
- const SimpleGraphExecutionStateOptions& options);
+ GraphExecutionState(GraphDef* graph_def,
+ const GraphExecutionStateOptions& options);
Status InitBaseGraph(const BuildGraphOptions& options);
@@ -194,16 +193,16 @@ class SimpleGraphExecutionState {
// and may be updated by a graph optimization pass.
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
- // `rewrite_metadata_` is only set for SimpleGraphExecutionState
+ // `rewrite_metadata_` is only set for GraphExecutionState
// objects created by `MakeForPrunedGraph()`.
std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;
// The dataflow graph owned by this object.
Graph* graph_;
- TF_DISALLOW_COPY_AND_ASSIGN(SimpleGraphExecutionState);
+ TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState);
};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/placer.cc
index 663e62a765..73fdf60fd5 100644
--- a/tensorflow/core/common_runtime/simple_placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/common_runtime/simple_placer.h"
+#include "tensorflow/core/common_runtime/placer.h"
#include <memory>
#include <set>
@@ -636,20 +636,20 @@ bool IsGeneratorNode(const Node* node) {
} // namespace
-SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices,
- const SessionOptions* options)
+Placer::Placer(Graph* graph, const DeviceSet* devices,
+ const SessionOptions* options)
: graph_(graph),
devices_(devices),
options_(options),
log_device_placement_(options != nullptr &&
options->config.log_device_placement()) {}
-SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices)
- : SimplePlacer(graph, devices, nullptr) {}
+Placer::Placer(Graph* graph, const DeviceSet* devices)
+ : Placer(graph, devices, nullptr) {}
-SimplePlacer::~SimplePlacer() {}
+Placer::~Placer() {}
-Status SimplePlacer::Run() {
+Status Placer::Run() {
if (devices_->devices().empty()) {
return errors::FailedPrecondition("No devices are registered");
}
@@ -771,7 +771,7 @@ Status SimplePlacer::Run() {
// choose the same device.
//
// TODO(vrv): Factor this assignment out into a pluggable
- // algorithm, so that SimplePlacer is responsible for enforcing
+ // algorithm, so that Placer is responsible for enforcing
// preconditions and we can experiment with other algorithms when
// given a choice of devices. Once we have a better idea of the
// types of heuristics we want to use and the information needed
@@ -844,9 +844,8 @@ Status SimplePlacer::Run() {
return Status::OK();
}
-bool SimplePlacer::CanAssignToDevice(
- const string& candidate_device_name,
- const std::vector<Device*>& devices) const {
+bool Placer::CanAssignToDevice(const string& candidate_device_name,
+ const std::vector<Device*>& devices) const {
if (!candidate_device_name.empty()) {
// 'devices' lists the set of devices that the placer or the user has
// constrained the operation to. "candidate_device_name" must
@@ -862,12 +861,12 @@ bool SimplePlacer::CanAssignToDevice(
return false;
}
-void SimplePlacer::AssignAndLog(int assigned_device, Node* node) const {
+void Placer::AssignAndLog(int assigned_device, Node* node) const {
node->set_assigned_device_name_index(assigned_device);
LogDeviceAssignment(node);
}
-void SimplePlacer::LogDeviceAssignment(const Node* node) const {
+void Placer::LogDeviceAssignment(const Node* node) const {
// Log placement if log_device_placement is set.
if (log_device_placement_) {
printf("%s: (%s): %s\n", node->name().c_str(), node->type_string().c_str(),
diff --git a/tensorflow/core/common_runtime/simple_placer.h b/tensorflow/core/common_runtime/placer.h
index 9c63cef40b..c5b76592e1 100644
--- a/tensorflow/core/common_runtime/simple_placer.h
+++ b/tensorflow/core/common_runtime/placer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_
-#define TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_
+#ifndef TENSORFLOW_COMMON_RUNTIME_PLACER_H_
+#define TENSORFLOW_COMMON_RUNTIME_PLACER_H_
#include <string>
#include <unordered_map>
@@ -53,25 +53,24 @@ namespace tensorflow {
// TODO(mrry): Create a common interface for this and the other
// placement algorithms so that they may be injected into the graph
// builder.
-class SimplePlacer {
+class Placer {
public:
// A map from graph node names to numerical IDs (in a Graph object).
typedef std::unordered_map<string, int> NodeNameToIdMap;
- // Creates an instance of the SimplePlacer algorithm for the given
+ // Creates an instance of the Placer algorithm for the given
// Graph "graph" (nodes in which may or may not be assigned) on the
// given DeviceSet "devices".
//
// The "graph", and "devices" pointer arguments
- // are borrowed by this SimplePlacer, and must outlive it.
- SimplePlacer(Graph* graph, const DeviceSet* devices,
- const SessionOptions* options);
+ // are borrowed by this Placer, and must outlive it.
+ Placer(Graph* graph, const DeviceSet* devices, const SessionOptions* options);
- SimplePlacer(Graph* graph, const DeviceSet* devices);
+ Placer(Graph* graph, const DeviceSet* devices);
- ~SimplePlacer();
+ ~Placer();
- // Assigns each node in this SimplePlacer's graph to a device in its
+ // Assigns each node in this Placer's graph to a device in its
// set of devices.
//
// This method is not thread-safe.
@@ -94,9 +93,9 @@ class SimplePlacer {
const SessionOptions* options_; // Not owned.
const bool log_device_placement_;
- TF_DISALLOW_COPY_AND_ASSIGN(SimplePlacer);
+ TF_DISALLOW_COPY_AND_ASSIGN(Placer);
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_
+#endif // TENSORFLOW_COMMON_RUNTIME_PLACER_H_
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 967bee63a1..5d87b1e279 100644
--- a/tensorflow/core/common_runtime/simple_placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/common_runtime/simple_placer.h"
+#include "tensorflow/core/common_runtime/placer.h"
#include <memory>
#include <string>
@@ -44,7 +44,7 @@ namespace {
//
// Op, kernel, and device registrations to set up the environment.
//
-// The SimplePlacer uses information about the op (input types),
+// The Placer uses information about the op (input types),
// kernel (device constraints), and available devices to make
// placement decisions. To avoid depending on the full runtime, we
// define dummy implementations of these, and register them with the
@@ -164,17 +164,17 @@ REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeGPU"), DummyOp);
////////////////////////////////////////////////////////////////////////////////
//
-// A SimplePlacerTest method has three phases:
+// A PlacerTest method has three phases:
//
// 1. Build a TensorFlow graph, with no (or partial) device assignments.
-// 2. Attempt to compute a placement using the SimplePlacer.
+// 2. Attempt to compute a placement using the Placer.
// 3. EITHER: test that the constraints implied by the graph are respected;
// or that an appropriate error was reported.
//
////////////////////////////////////////////////////////////////////////////////
-class SimplePlacerTest : public ::testing::Test {
+class PlacerTest : public ::testing::Test {
protected:
- SimplePlacerTest() {
+ PlacerTest() {
// Build a set of 10 GPU and 10 CPU devices.
// NOTE: this->local_devices_ owns the device objects;
// this->devices_ contains borrowed pointers to the device
@@ -201,12 +201,12 @@ class SimplePlacerTest : public ::testing::Test {
return Status::OK();
}
- // Invokes the SimplePlacer on "graph". If no DeviceSet is specified, the
+ // Invokes the Placer on "graph". If no DeviceSet is specified, the
// placement will use the default DeviceSet (of 10 CPU and 10 GPU devices).
//
// REQUIRES: "*graph" was produced by the most recent call to BuildGraph.
Status Place(Graph* graph, DeviceSet* devices, SessionOptions* options) {
- SimplePlacer placer(graph, devices, options);
+ Placer placer(graph, devices, options);
return placer.Run();
}
@@ -232,7 +232,7 @@ class SimplePlacerTest : public ::testing::Test {
protected:
std::vector<std::unique_ptr<Device>> local_devices_;
DeviceSet devices_;
- SimplePlacer::NodeNameToIdMap nodes_by_name_;
+ Placer::NodeNameToIdMap nodes_by_name_;
Status ReferenceTestHelper(const string& variable_op_type,
const string& assign_op_type,
@@ -267,7 +267,7 @@ class SimplePlacerTest : public ::testing::Test {
// Test that a graph with no constraints will successfully assign nodes to the
// "best available" device (i.e. prefer GPU over CPU).
-TEST_F(SimplePlacerTest, TestNoConstraints) {
+TEST_F(PlacerTest, TestNoConstraints) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -286,7 +286,7 @@ TEST_F(SimplePlacerTest, TestNoConstraints) {
// Test that a graph with device type and reference constraints on
// some of the ops will successfully assign nodes to the constrained
// device, and colocate nodes with reference connections.
-TEST_F(SimplePlacerTest, TestDeviceTypeConstraints) {
+TEST_F(PlacerTest, TestDeviceTypeConstraints) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -308,7 +308,7 @@ TEST_F(SimplePlacerTest, TestDeviceTypeConstraints) {
EXPECT_COLOCATED(g, "var_gpu", "assign_gpu");
}
-TEST_F(SimplePlacerTest, TestMetadataColocatedWithInput) {
+TEST_F(PlacerTest, TestMetadataColocatedWithInput) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -330,7 +330,7 @@ TEST_F(SimplePlacerTest, TestMetadataColocatedWithInput) {
// Heuristic A implements "Island fusing": if a node only generates
// an output and it has only one consumer, we place the node
// with its consumer.
-TEST_F(SimplePlacerTest, TestHeuristicGeneratorFollowsSingleConsumer) {
+TEST_F(PlacerTest, TestHeuristicGeneratorFollowsSingleConsumer) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -355,7 +355,7 @@ TEST_F(SimplePlacerTest, TestHeuristicGeneratorFollowsSingleConsumer) {
EXPECT_COLOCATED(g, "assign", "in");
}
-TEST_F(SimplePlacerTest, TestIgnoreGeneratorHeuristicIfWrongDevice) {
+TEST_F(PlacerTest, TestIgnoreGeneratorHeuristicIfWrongDevice) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -382,7 +382,7 @@ TEST_F(SimplePlacerTest, TestIgnoreGeneratorHeuristicIfWrongDevice) {
EXPECT_COLOCATED(g, "var_cpu", "assign");
}
-TEST_F(SimplePlacerTest, TestIgnoreGeneratorHeuristicIfWrongPartialDevice) {
+TEST_F(PlacerTest, TestIgnoreGeneratorHeuristicIfWrongPartialDevice) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -416,7 +416,7 @@ TEST_F(SimplePlacerTest, TestIgnoreGeneratorHeuristicIfWrongPartialDevice) {
// Test that a graph with partial device specifications on the ops
// will successfully be assigned to devices.
-TEST_F(SimplePlacerTest, TestPartialSpec) {
+TEST_F(PlacerTest, TestPartialSpec) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -434,7 +434,7 @@ TEST_F(SimplePlacerTest, TestPartialSpec) {
}
// Test that a node with a pre-assigned device is not relocated.
-TEST_F(SimplePlacerTest, TestAssignedDevicePreserved) {
+TEST_F(PlacerTest, TestAssignedDevicePreserved) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -452,7 +452,7 @@ TEST_F(SimplePlacerTest, TestAssignedDevicePreserved) {
// Test that a graph with partial device specifications for CPU-only ops
// will be relocated to CPU.
-TEST_F(SimplePlacerTest, TestPartialSpecGpuToCpu) {
+TEST_F(PlacerTest, TestPartialSpecGpuToCpu) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -474,7 +474,7 @@ TEST_F(SimplePlacerTest, TestPartialSpecGpuToCpu) {
// Test that a node with an assigned GPU device but has not registered
// OpKernel will fail.
-TEST_F(SimplePlacerTest, TestAssignedGpuDeviceToCpuDevice) {
+TEST_F(PlacerTest, TestAssignedGpuDeviceToCpuDevice) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -499,9 +499,9 @@ TEST_F(SimplePlacerTest, TestAssignedGpuDeviceToCpuDevice) {
// Build a graph containing a Variable op of "variable_op_type" and an
// Assign op of "assign_op_type", and expect all of the ops to be
// placed on a device of type "expected_device_type".
-Status SimplePlacerTest::ReferenceTestHelper(
- const string& variable_op_type, const string& assign_op_type,
- const DeviceType& expected_device_type) {
+Status PlacerTest::ReferenceTestHelper(const string& variable_op_type,
+ const string& assign_op_type,
+ const DeviceType& expected_device_type) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -530,7 +530,7 @@ Status SimplePlacerTest::ReferenceTestHelper(
// Test all 2^3 combinations of Variable and Assignment op types
// (unconstrained, CPU-only, and GPU-only).
-TEST_F(SimplePlacerTest, TestReferenceConnection) {
+TEST_F(PlacerTest, TestReferenceConnection) {
Status s;
TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "TestAssign", "FakeGPU"));
TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignCPU", "FakeCPU"));
@@ -575,7 +575,7 @@ REGISTER_OP("HandleAssignGPU").Input("i: resource").Input("v: float");
REGISTER_KERNEL_BUILDER(Name("HandleAssignGPU").Device("FakeGPU"), DummyOp);
// Tests all combinations of resource handles and ops using them.
-TEST_F(SimplePlacerTest, TestResourceHandle) {
+TEST_F(PlacerTest, TestResourceHandle) {
auto handle_test = [this](const string& var_op_name,
const string& use_op_name, DeviceType device) {
Graph g(OpRegistry::Global());
@@ -611,7 +611,7 @@ TEST_F(SimplePlacerTest, TestResourceHandle) {
// Test that an assignment of an operator to the wrong device
// is ignored when it could never be satisfied (due to reference
// edges, for example).
-TEST_F(SimplePlacerTest, TestReferenceConnectionIgnoreInfeasible) {
+TEST_F(PlacerTest, TestReferenceConnectionIgnoreInfeasible) {
Status s;
Graph g(OpRegistry::Global());
{
@@ -640,8 +640,7 @@ TEST_F(SimplePlacerTest, TestReferenceConnectionIgnoreInfeasible) {
// Test that an assignment of an operator to a more specific device
// causes the device to maintain its more specific placement.
-TEST_F(SimplePlacerTest,
- TestReferenceConnectionMoreSpecificDestinationSourceWins) {
+TEST_F(PlacerTest, TestReferenceConnectionMoreSpecificDestinationSourceWins) {
Status s;
Graph g(OpRegistry::Global());
{
@@ -675,7 +674,7 @@ TEST_F(SimplePlacerTest,
// where the assign has a device but the variable does not. In this
// case, the variable gets placed on the location of the assign
// operation.
-TEST_F(SimplePlacerTest, TestReferenceConnectionNoSourceDevice) {
+TEST_F(PlacerTest, TestReferenceConnectionNoSourceDevice) {
Status s;
Graph g(OpRegistry::Global());
{
@@ -697,7 +696,7 @@ TEST_F(SimplePlacerTest, TestReferenceConnectionNoSourceDevice) {
EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU");
}
-TEST_F(SimplePlacerTest, TestColocationGroup) {
+TEST_F(PlacerTest, TestColocationGroup) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -720,7 +719,7 @@ TEST_F(SimplePlacerTest, TestColocationGroup) {
EXPECT_NOT_COLOCATED(g, "in", "foo");
}
-TEST_F(SimplePlacerTest, TestMultipleColocationGroups) {
+TEST_F(PlacerTest, TestMultipleColocationGroups) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -742,7 +741,7 @@ TEST_F(SimplePlacerTest, TestMultipleColocationGroups) {
EXPECT_COLOCATED(g, "in", "foo");
}
-TEST_F(SimplePlacerTest, TestInvalidMultipleColocationGroups) {
+TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -766,7 +765,7 @@ TEST_F(SimplePlacerTest, TestInvalidMultipleColocationGroups) {
"other nodes colocated with them"));
}
-TEST_F(SimplePlacerTest, TestColocationGroupWithReferenceConnections) {
+TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -793,8 +792,7 @@ TEST_F(SimplePlacerTest, TestColocationGroupWithReferenceConnections) {
EXPECT_COLOCATED(g, "var2", "assign1");
}
-TEST_F(SimplePlacerTest,
- TestColocationGroupWithUnsatisfiableReferenceConnections) {
+TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -832,7 +830,7 @@ TEST_F(SimplePlacerTest,
"nodes colocated with them."));
}
-TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) {
+TEST_F(PlacerTest, TestColocationAndReferenceConnections) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -877,7 +875,7 @@ TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) {
}
// Test that placement fails when no devices are registered.
-TEST_F(SimplePlacerTest, TestEmptyDeviceSet) {
+TEST_F(PlacerTest, TestEmptyDeviceSet) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -894,7 +892,7 @@ TEST_F(SimplePlacerTest, TestEmptyDeviceSet) {
// Test that placement fails when the requested device forces an
// indirect constraint to be violated.
-TEST_F(SimplePlacerTest, TestHeterogeneousDeviceSetFailure) {
+TEST_F(PlacerTest, TestHeterogeneousDeviceSetFailure) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -928,7 +926,7 @@ TEST_F(SimplePlacerTest, TestHeterogeneousDeviceSetFailure) {
}
// Test that placement fails when an unknown device is requested.
-TEST_F(SimplePlacerTest, TestUnknownDevice) {
+TEST_F(PlacerTest, TestUnknownDevice) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -943,7 +941,7 @@ TEST_F(SimplePlacerTest, TestUnknownDevice) {
// Test that placement fails when the combination of partial
// constraints leads to an unknown device.
-TEST_F(SimplePlacerTest, TestUnknownMergedDevice) {
+TEST_F(PlacerTest, TestUnknownMergedDevice) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -958,7 +956,7 @@ TEST_F(SimplePlacerTest, TestUnknownMergedDevice) {
// Test that placement fails when the previously-assigned device for a
// node is unknown.
-TEST_F(SimplePlacerTest, TestUnknownAssignedDevice) {
+TEST_F(PlacerTest, TestUnknownAssignedDevice) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -977,7 +975,7 @@ TEST_F(SimplePlacerTest, TestUnknownAssignedDevice) {
// Test that placement fails when an op with no registered kernels is
// requested.
-TEST_F(SimplePlacerTest, TestNoKernelsRegistered) {
+TEST_F(PlacerTest, TestNoKernelsRegistered) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -997,7 +995,7 @@ TEST_F(SimplePlacerTest, TestNoKernelsRegistered) {
// Test that placement fails when a kernel is registered but no known
// device supports it.
-TEST_F(SimplePlacerTest, TestNoDevicesRegistered) {
+TEST_F(PlacerTest, TestNoDevicesRegistered) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1019,7 +1017,7 @@ TEST_F(SimplePlacerTest, TestNoDevicesRegistered) {
}
// Test that placement fails when a requested device is malformed.
-TEST_F(SimplePlacerTest, TestMalformedDeviceSpecification) {
+TEST_F(PlacerTest, TestMalformedDeviceSpecification) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1034,7 +1032,7 @@ TEST_F(SimplePlacerTest, TestMalformedDeviceSpecification) {
}
// Test that placement fails when a previously-assigned device is malformed.
-TEST_F(SimplePlacerTest, TestMalformedAssignedDevice) {
+TEST_F(PlacerTest, TestMalformedAssignedDevice) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1052,7 +1050,7 @@ TEST_F(SimplePlacerTest, TestMalformedAssignedDevice) {
// Test that placement fails when a device was previously assigned to
// a node, but it does not uniquely identify a particular device.
-TEST_F(SimplePlacerTest, TestNonUniqueAssignedDevice) {
+TEST_F(PlacerTest, TestNonUniqueAssignedDevice) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1071,7 +1069,7 @@ TEST_F(SimplePlacerTest, TestNonUniqueAssignedDevice) {
// Test that ops request to be placed on non-existent devices will be relocated
// to existing device of the same type if allow_soft_placement is set.
-TEST_F(SimplePlacerTest, TestNonexistentGpuAllowSoftPlacement) {
+TEST_F(PlacerTest, TestNonexistentGpuAllowSoftPlacement) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1088,7 +1086,7 @@ TEST_F(SimplePlacerTest, TestNonexistentGpuAllowSoftPlacement) {
// Test that ops request to be placed on non-existent devices will fail if
// allow_soft_placement is not set.
-TEST_F(SimplePlacerTest, TestNonexistentGpuNoAllowSoftPlacement) {
+TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacement) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1105,7 +1103,7 @@ TEST_F(SimplePlacerTest, TestNonexistentGpuNoAllowSoftPlacement) {
// Test that placement fails when a node requests an explicit device that is not
// supported by the registered kernels if allow_soft_placement is not set.
-TEST_F(SimplePlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) {
+TEST_F(PlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1125,7 +1123,7 @@ TEST_F(SimplePlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) {
// Test that placement fails when a node requests an explicit device that is not
// supported by the registered kernels if allow_soft_placement is not set.
-TEST_F(SimplePlacerTest, TestNonExistentDevice) {
+TEST_F(PlacerTest, TestNonExistentDevice) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1143,7 +1141,7 @@ TEST_F(SimplePlacerTest, TestNonExistentDevice) {
"but available devices"));
}
-TEST_F(SimplePlacerTest, TestUnsupportedDeviceAllowSoftPlacement) {
+TEST_F(PlacerTest, TestUnsupportedDeviceAllowSoftPlacement) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1160,7 +1158,7 @@ TEST_F(SimplePlacerTest, TestUnsupportedDeviceAllowSoftPlacement) {
// Test that a graph with device type and reference constraints on
// some of the ops will successfully assign nodes to the constrained
// device, and colocate nodes with reference connections.
-TEST_F(SimplePlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) {
+TEST_F(PlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1194,7 +1192,7 @@ TEST_F(SimplePlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) {
// Test that placement fails when two nodes have a reference connection
// constraint, and each node requires a mutually incompatible device.
-TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
+TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1212,7 +1210,7 @@ TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
// Test that a generator node follows its consumers (where there are several
// consumer nodes on the same devices).
-TEST_F(SimplePlacerTest, TestGeneratorNodeFollowsConsumerNode) {
+TEST_F(PlacerTest, TestGeneratorNodeFollowsConsumerNode) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1245,7 +1243,7 @@ TEST_F(SimplePlacerTest, TestGeneratorNodeFollowsConsumerNode) {
// Test that a generator node does not follow its consumers (where there are
// several consumers on different devices).
-TEST_F(SimplePlacerTest, TestGeneratorNodeDoesntFollowNonColocatedConsumers) {
+TEST_F(PlacerTest, TestGeneratorNodeDoesntFollowNonColocatedConsumers) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index caa60e58e6..bfdf967333 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -55,7 +55,7 @@ limitations under the License.
namespace tensorflow {
-// MasterSession wraps SimpleClientGraph in a reference counted object.
+// MasterSession wraps ClientGraph in a reference counted object.
// This way, MasterSession can clear up the cache mapping Run requests to
// compiled graphs while the compiled graph is still being used.
//
@@ -63,10 +63,10 @@ namespace tensorflow {
class MasterSession::ReffedClientGraph : public core::RefCounted {
public:
ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
- std::unique_ptr<SimpleClientGraph> cg,
+ std::unique_ptr<ClientGraph> cg,
const SessionOptions& session_opts,
const StatsPublisherFactory& stats_publisher_factory,
- SimpleGraphExecutionState* execution_state, bool is_partial,
+ GraphExecutionState* execution_state, bool is_partial,
WorkerCacheInterface* worker_cache)
: session_handle_(handle),
client_graph_(std::move(cg)),
@@ -87,7 +87,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
~ReffedClientGraph() override { DeregisterPartitions(); }
- const SimpleClientGraph* client_graph() { return client_graph_.get(); }
+ const ClientGraph* client_graph() { return client_graph_.get(); }
std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
int64 execution_count,
@@ -186,7 +186,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
// Checks that the requested fetches can be computed from the provided feeds.
Status CheckFetches(const RunStepRequestWrapper& req,
const RunState* run_state,
- SimpleGraphExecutionState* execution_state);
+ GraphExecutionState* execution_state);
string DetailText(const Node& node, const NodeExecStats& ns) {
int64 tot = 0;
@@ -203,7 +203,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
private:
const string session_handle_;
- const std::unique_ptr<SimpleClientGraph> client_graph_;
+ const std::unique_ptr<ClientGraph> client_graph_;
const SessionOptions session_opts_;
const bool is_partial_;
const DebugOptions& debug_opts_;
@@ -813,7 +813,7 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats(
// contention.
Status MasterSession::ReffedClientGraph::CheckFetches(
const RunStepRequestWrapper& req, const RunState* run_state,
- SimpleGraphExecutionState* execution_state) {
+ GraphExecutionState* execution_state) {
// Build the set of pending feeds that we haven't seen.
std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
for (const auto& input : run_state->pending_inputs) {
@@ -1028,12 +1028,12 @@ Status MasterSession::Create(GraphDef* graph_def,
session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
}
- SimpleGraphExecutionStateOptions execution_options;
+ GraphExecutionStateOptions execution_options;
execution_options.device_set = devices_.get();
execution_options.session_options = &session_opts_;
{
mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
+ TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
graph_def, execution_options, &execution_state_));
}
if (options.cluster_def != nullptr) {
@@ -1142,7 +1142,7 @@ Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
Status MasterSession::Extend(const ExtendSessionRequest* req,
ExtendSessionResponse* resp) {
UpdateLastAccessTime();
- std::unique_ptr<SimpleGraphExecutionState> extended_execution_state;
+ std::unique_ptr<GraphExecutionState> extended_execution_state;
{
mutex_lock l(mu_);
if (closed_) {
@@ -1195,7 +1195,7 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
VLOG(1) << "Unseen hash " << hash << " for "
<< BuildGraphOptionsString(opts) << " is_partial = " << is_partial
<< "\n";
- std::unique_ptr<SimpleClientGraph> client_graph;
+ std::unique_ptr<ClientGraph> client_graph;
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
WorkerCacheInterface* worker_cache = get_worker_cache();
auto entry = new ReffedClientGraph(
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index 33b9bfe631..51ea92da68 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_set.h"
-#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
+#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
@@ -128,7 +128,7 @@ class MasterSession : public core::RefCounted {
std::atomic<int64> partial_run_handle_counter_ = {0};
mutex mu_;
- std::unique_ptr<SimpleGraphExecutionState> execution_state_ GUARDED_BY(mu_);
+ std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(mu_);
int64 graph_version_;
// We keep a map from a signature of a run request to the