diff options
author | 2017-09-07 22:11:25 -0700 | |
---|---|---|
committer | 2017-09-07 22:15:42 -0700 | |
commit | b5c1d0f8977e0f05c9aeeb9e5105500bf83972bb (patch) | |
tree | e98407a18e1f7e8f3719ba58557be6be3c4deb59 | |
parent | d27db78cd0168f10b308f7508c11dfaa3c6707e9 (diff) |
SimpleGraphExecutionState -> GraphExecutionState
SimplePlacer -> Placer
And clean up a couple unneeded headers.
PiperOrigin-RevId: 167955883
-rw-r--r-- | tensorflow/core/BUILD | 10 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 23 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.h | 4 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/graph_execution_state.cc (renamed from tensorflow/core/common_runtime/simple_graph_execution_state.cc) | 67 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/graph_execution_state.h (renamed from tensorflow/core/common_runtime/simple_graph_execution_state.h) | 65 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/placer.cc (renamed from tensorflow/core/common_runtime/simple_placer.cc) | 25 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/placer.h (renamed from tensorflow/core/common_runtime/simple_placer.h) | 23 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/placer_test.cc (renamed from tensorflow/core/common_runtime/simple_placer_test.cc) | 106 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/master_session.cc | 22 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/master_session.h | 4 |
10 files changed, 171 insertions, 178 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 9319928307..4b80d2c543 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1773,12 +1773,14 @@ tf_cuda_library( "common_runtime/device_set.cc", "common_runtime/executor.cc", "common_runtime/function.cc", + "common_runtime/graph_execution_state.cc", "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", "common_runtime/local_device.cc", "common_runtime/memory_types.cc", "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", + "common_runtime/placer.cc", "common_runtime/process_function_library_runtime.cc", "common_runtime/process_util.cc", "common_runtime/renamed_device.cc", @@ -1788,8 +1790,6 @@ tf_cuda_library( "common_runtime/session_factory.cc", "common_runtime/session_options.cc", "common_runtime/session_state.cc", - "common_runtime/simple_graph_execution_state.cc", - "common_runtime/simple_placer.cc", "common_runtime/stats_publisher_interface.cc", "common_runtime/step_stats_collector.cc", "common_runtime/threadpool_device.cc", @@ -1829,8 +1829,8 @@ tf_cuda_library( "common_runtime/renamed_device.h", "common_runtime/rendezvous_mgr.h", "common_runtime/session_factory.h", - "common_runtime/simple_graph_execution_state.h", - "common_runtime/simple_placer.h", + "common_runtime/graph_execution_state.h", + "common_runtime/placer.h", "common_runtime/stats_publisher_interface.h", "common_runtime/step_stats_collector.h", "common_runtime/threadpool_device.h", @@ -2321,9 +2321,9 @@ tf_cc_tests( "common_runtime/device_set_test.cc", "common_runtime/optimization_registry_test.cc", "common_runtime/pending_counts_test.cc", + "common_runtime/placer_test.cc", "common_runtime/resource_variable_read_optimizer_test.cc", "common_runtime/session_test.cc", - "common_runtime/simple_placer_test.cc", "example/feature_util_test.cc", "framework/allocator_test.cc", "framework/attr_value_util_test.cc", diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index a6630f38a5..d28857bb9f 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/common_runtime/simple_placer.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb_text.h" @@ -345,20 +344,20 @@ Status DirectSession::MaybeInitializeExecutionState( // all subsequent extensions of the graph. flib_def_.reset( new FunctionLibraryDefinition(OpRegistry::Global(), graph.library())); - SimpleGraphExecutionStateOptions options; + GraphExecutionStateOptions options; options.device_set = &device_set_; options.session_options = &options_; // TODO(mrry,suharshs): We explicitly copy `graph` so that // `MakeForBaseGraph()` can take ownership of its // contents. Previously this happened implicitly in calls to the - // `SimpleGraphExecutionState`. Other sessions call + // `GraphExecutionState`. Other sessions call // `MakeForBaseGraph` in such a way that we can destructively read // the passed-in `GraphDef`. In principle we could do the same here, // with a wider refactoring; we might revise the direct session so // that it copies the graph fewer times. GraphDef temp(graph); - TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph( - &temp, options, &execution_state_)); + TF_RETURN_IF_ERROR( + GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_)); graph_created_ = true; *out_already_initialized = false; return Status::OK(); @@ -391,7 +390,7 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) { MaybeInitializeExecutionState(graph, &already_initialized)); if (already_initialized) { TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library())); - std::unique_ptr<SimpleGraphExecutionState> state; + std::unique_ptr<GraphExecutionState> state; TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state)); execution_state_.swap(state); } @@ -1280,19 +1279,19 @@ Status DirectSession::CreateGraphs( RunStateArgs* run_state_args, DataTypeVector* input_types, DataTypeVector* output_types) { mutex_lock l(graph_def_lock_); - std::unique_ptr<SimpleClientGraph> client_graph; + std::unique_ptr<ClientGraph> client_graph; - std::unique_ptr<SimpleGraphExecutionState> temp_exec_state_holder; - SimpleGraphExecutionState* execution_state = nullptr; + std::unique_ptr<GraphExecutionState> temp_exec_state_holder; + GraphExecutionState* execution_state = nullptr; if (options_.config.graph_options().place_pruned_graph()) { // Because we are placing pruned graphs, we need to create a - // new SimpleGraphExecutionState for every new unseen graph, + // new GraphExecutionState for every new unseen graph, // and then place it. - SimpleGraphExecutionStateOptions prune_options; + GraphExecutionStateOptions prune_options; prune_options.device_set = &device_set_; prune_options.session_options = &options_; prune_options.stateful_placements = stateful_placements_; - TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForPrunedGraph( + TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph( execution_state_->original_graph_def().library(), prune_options, execution_state_->original_graph_def(), subgraph_options, &temp_exec_state_holder, &client_graph)); diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index 020831d6cc..7fbabf6d81 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -28,10 +28,10 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/graph_execution_state.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/session_factory.h" -#include "tensorflow/core/common_runtime/simple_graph_execution_state.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/session_state.h" @@ -310,7 +310,7 @@ class DirectSession : public Session { GUARDED_BY(graph_def_lock_); // Execution_state; used when placing the entire graph. - std::unique_ptr<SimpleGraphExecutionState> execution_state_ + std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(graph_def_lock_); // The function library, before any rewrites or optimizations have been diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index c66dc568f6..4bd40c7978 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/simple_graph_execution_state.h" +#include "tensorflow/core/common_runtime/graph_execution_state.h" #include <memory> #include <string> @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/common_runtime/simple_placer.h" +#include "tensorflow/core/common_runtime/placer.h" #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -49,8 +49,8 @@ limitations under the License. namespace tensorflow { -SimpleGraphExecutionState::SimpleGraphExecutionState( - GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options) +GraphExecutionState::GraphExecutionState( + GraphDef* graph_def, const GraphExecutionStateOptions& options) : stateful_placements_(options.stateful_placements), device_set_(options.device_set), session_options_(options.session_options), @@ -65,16 +65,16 @@ SimpleGraphExecutionState::SimpleGraphExecutionState( // placement option. } -SimpleGraphExecutionState::~SimpleGraphExecutionState() { +GraphExecutionState::~GraphExecutionState() { node_name_to_cost_id_map_.clear(); delete graph_; } -/* static */ Status SimpleGraphExecutionState::MakeForBaseGraph( - GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options, - std::unique_ptr<SimpleGraphExecutionState>* out_state) { - std::unique_ptr<SimpleGraphExecutionState> ret( - new SimpleGraphExecutionState(graph_def, options)); +/* static */ Status GraphExecutionState::MakeForBaseGraph( + GraphDef* graph_def, const GraphExecutionStateOptions& options, + std::unique_ptr<GraphExecutionState>* out_state) { + std::unique_ptr<GraphExecutionState> ret( + new GraphExecutionState(graph_def, options)); TF_RETURN_IF_ERROR( AddDefaultAttrsToGraphDef(&ret->original_graph_def_, *ret->flib_def_, 0)); @@ -88,12 +88,12 @@ SimpleGraphExecutionState::~SimpleGraphExecutionState() { return Status::OK(); } -/* static */ Status SimpleGraphExecutionState::MakeForPrunedGraph( +/* static */ Status GraphExecutionState::MakeForPrunedGraph( const FunctionDefLibrary& func_def_lib, - const SimpleGraphExecutionStateOptions& options, const GraphDef& graph_def, + const GraphExecutionStateOptions& options, const GraphDef& graph_def, const BuildGraphOptions& subgraph_options, - std::unique_ptr<SimpleGraphExecutionState>* out_state, - std::unique_ptr<SimpleClientGraph>* out_client_graph) { + std::unique_ptr<GraphExecutionState>* out_state, + std::unique_ptr<ClientGraph>* out_client_graph) { DCHECK(options.session_options->config.graph_options().place_pruned_graph()); // NOTE(mrry): This makes a copy of `graph_def`, which is // regrettable. We could make `GraphDef` objects sharable between @@ -103,8 +103,8 @@ SimpleGraphExecutionState::~SimpleGraphExecutionState() { // also that the previous version used `Extend()`, which is strictly // more expensive than copying a `GraphDef`.) GraphDef temp(graph_def); - std::unique_ptr<SimpleGraphExecutionState> ret( - new SimpleGraphExecutionState(&temp, options)); + std::unique_ptr<GraphExecutionState> ret( + new GraphExecutionState(&temp, options)); TF_RETURN_IF_ERROR( AddDefaultAttrsToGraphDef(&ret->original_graph_def_, *ret->flib_def_, 0)); TF_RETURN_IF_ERROR(ret->InitBaseGraph(subgraph_options)); @@ -113,9 +113,9 @@ SimpleGraphExecutionState::~SimpleGraphExecutionState() { return Status::OK(); } -Status SimpleGraphExecutionState::Extend( +Status GraphExecutionState::Extend( const GraphDef& extension_def, - std::unique_ptr<SimpleGraphExecutionState>* out) const { + std::unique_ptr<GraphExecutionState>* out) const { GraphDef gdef; // 1. Copy the function library. @@ -186,15 +186,15 @@ Status SimpleGraphExecutionState::Extend( } // 6. Add the extension. - SimpleGraphExecutionStateOptions combined_options; + GraphExecutionStateOptions combined_options; combined_options.device_set = device_set_; combined_options.session_options = session_options_; combined_options.stateful_placements = stateful_placements_; // NOTE(mrry): `gdef` is no longer valid after the constructor // executes. - std::unique_ptr<SimpleGraphExecutionState> new_execution_state( - new SimpleGraphExecutionState(&gdef, combined_options)); + std::unique_ptr<GraphExecutionState> new_execution_state( + new GraphExecutionState(&gdef, combined_options)); TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef( &new_execution_state->original_graph_def_, *flib_def_, 0)); @@ -212,7 +212,7 @@ Status SimpleGraphExecutionState::Extend( return Status::OK(); } -void SimpleGraphExecutionState::SaveStatefulNodes(Graph* graph) { +void GraphExecutionState::SaveStatefulNodes(Graph* graph) { for (Node* n : graph->nodes()) { if (n->op_def().is_stateful()) { VLOG(2) << "Saving " << n->DebugString(); @@ -221,7 +221,7 @@ void SimpleGraphExecutionState::SaveStatefulNodes(Graph* graph) { } } -void SimpleGraphExecutionState::RestoreStatefulNodes(Graph* graph) { +void GraphExecutionState::RestoreStatefulNodes(Graph* graph) { for (Node* n : graph->nodes()) { if (n->op_def().is_stateful()) { auto iter = stateful_placements_.find(n->name()); @@ -233,8 +233,7 @@ void SimpleGraphExecutionState::RestoreStatefulNodes(Graph* graph) { } } -Status SimpleGraphExecutionState::InitBaseGraph( - const BuildGraphOptions& options) { +Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) { const GraphDef* graph_def = &original_graph_def_; std::unique_ptr<Graph> new_graph(new Graph(OpRegistry::Global())); @@ -266,8 +265,8 @@ Status SimpleGraphExecutionState::InitBaseGraph( TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::PRE_PLACEMENT, optimization_options)); - SimplePlacer placer(new_graph.get(), device_set_, session_options_); - // TODO(mrry): Consider making the SimplePlacer cancelable. + Placer placer(new_graph.get(), device_set_, session_options_); + // TODO(mrry): Consider making the Placer cancelable. TF_RETURN_IF_ERROR(placer.Run()); TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( @@ -278,7 +277,7 @@ Status SimpleGraphExecutionState::InitBaseGraph( return Status::OK(); } -Status SimpleGraphExecutionState::OptimizeGraph( +Status GraphExecutionState::OptimizeGraph( const BuildGraphOptions& options, std::unique_ptr<Graph>* optimized_graph) { #ifndef IS_MOBILE_PLATFORM if (session_options_->config.graph_options().place_pruned_graph()) { @@ -378,8 +377,8 @@ Status SimpleGraphExecutionState::OptimizeGraph( #endif // IS_MOBILE_PLATFORM } -Status SimpleGraphExecutionState::BuildGraph( - const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) { +Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options, + std::unique_ptr<ClientGraph>* out) { VLOG(1) << "BuildGraph"; if (!graph_) { // It is only valid to call this method directly when the original graph @@ -406,7 +405,7 @@ Status SimpleGraphExecutionState::BuildGraph( options.target_nodes, device_set_->client_device()->attributes(), options.use_function_convention, &rewrite_metadata)); } else { - // This SimpleGraphExecutionState represents a graph that was + // This GraphExecutionState represents a graph that was // pruned when this was constructed, so we copy the metadata from // a member variable. CHECK(rewrite_metadata_); @@ -433,9 +432,9 @@ Status SimpleGraphExecutionState::BuildGraph( // Copy the extracted graph in order to make its node ids dense, // since the local CostModel used to record its stats is sized by // the largest node id. - std::unique_ptr<SimpleClientGraph> dense_copy( - new SimpleClientGraph(std::move(flib), rewrite_metadata.feed_types, - rewrite_metadata.fetch_types)); + std::unique_ptr<ClientGraph> dense_copy( + new ClientGraph(std::move(flib), rewrite_metadata.feed_types, + rewrite_metadata.fetch_types)); CopyGraph(*ng, &dense_copy->graph); // TODO(vrv): We should check invariants of the graph here. diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h index 53eef8a07d..db2686ce2c 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.h +++ b/tensorflow/core/common_runtime/graph_execution_state.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ #include <functional> #include <memory> @@ -38,7 +38,7 @@ namespace subgraph { struct RewriteGraphMetadata; } -struct SimpleGraphExecutionStateOptions { +struct GraphExecutionStateOptions { const DeviceSet* device_set = nullptr; const SessionOptions* session_options = nullptr; // A map from node name to device name, representing the unchangeable @@ -46,12 +46,11 @@ struct SimpleGraphExecutionStateOptions { std::unordered_map<string, string> stateful_placements; }; -// A SimpleClientGraph is simply a sub-graph of the full graph as induced by +// A ClientGraph is simply a sub-graph of the full graph as induced by // BuildGraphOptions. -struct SimpleClientGraph { - explicit SimpleClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib, - DataTypeVector feed_types, - DataTypeVector fetch_types) +struct ClientGraph { + explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib, + DataTypeVector feed_types, DataTypeVector fetch_types) : flib_def(std::move(flib)), graph(flib_def.get()), feed_types(std::move(feed_types)), @@ -64,8 +63,8 @@ struct SimpleClientGraph { DataTypeVector fetch_types; }; -// SimpleGraphExecutionState is responsible for generating an -// executable SimpleClientGraph from the original GraphDef that specifies +// GraphExecutionState is responsible for generating an +// executable ClientGraph from the original GraphDef that specifies // the complete graph and from BuildGraphOptions which specifies // input/output nodes. // @@ -73,7 +72,7 @@ struct SimpleClientGraph { // meaning that each Node is assigned to a single Device in the // available set. // -// When SimpleGraphExecutionState is first constructed it instantiates +// When GraphExecutionState is first constructed it instantiates // a full Graph from the provided GraphDef, and places it, using only // the static device assignments from the GraphDef. Nodes without are // currently placed in a very naive way. Since stateful Nodes cannot @@ -81,18 +80,18 @@ struct SimpleClientGraph { // Nodes get sensible initial device assignments in the graph // definition. // -// Subsequently, SimpleGraphExecutionState generates a SimpleClientGraph on +// Subsequently, GraphExecutionState generates a SimpleClientGraph on // demand, which is a sub-graph of the latest placement of the full -// Graph. MasterSession uses such a SimpleClientGraph to execute one or +// Graph. MasterSession uses such a ClientGraph to execute one or // more similar client requests. // -// SimpleGraphExecutionState is thread-safe. +// GraphExecutionState is thread-safe. -class SimpleGraphExecutionState { +class GraphExecutionState { public: - virtual ~SimpleGraphExecutionState(); + virtual ~GraphExecutionState(); - // Creates a new `SimpleGraphExecutionState` for the given + // Creates a new `GraphExecutionState` for the given // `graph_def`, which represents the entire graph for a session. // // N.B. This method uses `GraphDef::Swap()` and leaves `graph_def` @@ -100,21 +99,21 @@ class SimpleGraphExecutionState { // after this call, make an explicit copy of the graph before // calling this method. static Status MakeForBaseGraph( - GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options, - std::unique_ptr<SimpleGraphExecutionState>* out_state); + GraphDef* graph_def, const GraphExecutionStateOptions& options, + std::unique_ptr<GraphExecutionState>* out_state); - // Creates a new `SimpleGraphExecutionState` and `SimpleClientGraph` + // Creates a new `GraphExecutionState` and `SimpleClientGraph` // for the subgraph of `original_graph_def` defined by // `subgraph_options`. static Status MakeForPrunedGraph( const FunctionDefLibrary& func_def_lib, - const SimpleGraphExecutionStateOptions& options, + const GraphExecutionStateOptions& options, const GraphDef& original_graph_def, const BuildGraphOptions& subgraph_options, - std::unique_ptr<SimpleGraphExecutionState>* out_state, - std::unique_ptr<SimpleClientGraph>* out_client_graph); + std::unique_ptr<GraphExecutionState>* out_state, + std::unique_ptr<ClientGraph>* out_client_graph); - // Creates a new SimpleGraphExecutionState representing the + // Creates a new GraphExecutionState representing the // concatenation of this graph, and the graph defined by // "extension_def". The same name may not be used to define a node // in both this graph and "extension_def". @@ -129,14 +128,14 @@ class SimpleGraphExecutionState { // in *this, but currently does not transfer any other placement // or cost model information to the new graph. Status Extend(const GraphDef& extension_def, - std::unique_ptr<SimpleGraphExecutionState>* out) const; + std::unique_ptr<GraphExecutionState>* out) const; - // Builds a SimpleClientGraph (a sub-graph of the full graph as induced by + // Builds a ClientGraph (a sub-graph of the full graph as induced by // the Node set specified in "options"). If successful, returns OK // and the caller takes the ownership of "*out". Otherwise, returns // an error. Status BuildGraph(const BuildGraphOptions& options, - std::unique_ptr<SimpleClientGraph>* out); + std::unique_ptr<ClientGraph>* out); // The graph returned by BuildGraph may contain only the pruned // graph, whereas some clients may want access to the full graph. @@ -156,7 +155,7 @@ class SimpleGraphExecutionState { } // Returns a reference to the current graph_def. Use must - // not extend beyond lifetime of SimpleGrahExecutionState object. + // not extend beyond lifetime of GrahExecutionState object. const GraphDef& original_graph_def() { return original_graph_def_; } // Returns the map of stateful placements as a map of @@ -166,8 +165,8 @@ class SimpleGraphExecutionState { } private: - SimpleGraphExecutionState(GraphDef* graph_def, - const SimpleGraphExecutionStateOptions& options); + GraphExecutionState(GraphDef* graph_def, + const GraphExecutionStateOptions& options); Status InitBaseGraph(const BuildGraphOptions& options); @@ -194,16 +193,16 @@ class SimpleGraphExecutionState { // and may be updated by a graph optimization pass. std::unique_ptr<FunctionLibraryDefinition> flib_def_; - // `rewrite_metadata_` is only set for SimpleGraphExecutionState + // `rewrite_metadata_` is only set for GraphExecutionState // objects created by `MakeForPrunedGraph()`. std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_; // The dataflow graph owned by this object. Graph* graph_; - TF_DISALLOW_COPY_AND_ASSIGN(SimpleGraphExecutionState); + TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState); }; } // namespace tensorflow -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/placer.cc index 663e62a765..73fdf60fd5 100644 --- a/tensorflow/core/common_runtime/simple_placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/simple_placer.h" +#include "tensorflow/core/common_runtime/placer.h" #include <memory> #include <set> @@ -636,20 +636,20 @@ bool IsGeneratorNode(const Node* node) { } // namespace -SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices, - const SessionOptions* options) +Placer::Placer(Graph* graph, const DeviceSet* devices, + const SessionOptions* options) : graph_(graph), devices_(devices), options_(options), log_device_placement_(options != nullptr && options->config.log_device_placement()) {} -SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices) - : SimplePlacer(graph, devices, nullptr) {} +Placer::Placer(Graph* graph, const DeviceSet* devices) + : Placer(graph, devices, nullptr) {} -SimplePlacer::~SimplePlacer() {} +Placer::~Placer() {} -Status SimplePlacer::Run() { +Status Placer::Run() { if (devices_->devices().empty()) { return errors::FailedPrecondition("No devices are registered"); } @@ -771,7 +771,7 @@ Status SimplePlacer::Run() { // choose the same device. // // TODO(vrv): Factor this assignment out into a pluggable - // algorithm, so that SimplePlacer is responsible for enforcing + // algorithm, so that Placer is responsible for enforcing // preconditions and we can experiment with other algorithms when // given a choice of devices. Once we have a better idea of the // types of heuristics we want to use and the information needed @@ -844,9 +844,8 @@ Status SimplePlacer::Run() { return Status::OK(); } -bool SimplePlacer::CanAssignToDevice( - const string& candidate_device_name, - const std::vector<Device*>& devices) const { +bool Placer::CanAssignToDevice(const string& candidate_device_name, + const std::vector<Device*>& devices) const { if (!candidate_device_name.empty()) { // 'devices' lists the set of devices that the placer or the user has // constrained the operation to. "candidate_device_name" must @@ -862,12 +861,12 @@ bool SimplePlacer::CanAssignToDevice( return false; } -void SimplePlacer::AssignAndLog(int assigned_device, Node* node) const { +void Placer::AssignAndLog(int assigned_device, Node* node) const { node->set_assigned_device_name_index(assigned_device); LogDeviceAssignment(node); } -void SimplePlacer::LogDeviceAssignment(const Node* node) const { +void Placer::LogDeviceAssignment(const Node* node) const { // Log placement if log_device_placement is set. if (log_device_placement_) { printf("%s: (%s): %s\n", node->name().c_str(), node->type_string().c_str(), diff --git a/tensorflow/core/common_runtime/simple_placer.h b/tensorflow/core/common_runtime/placer.h index 9c63cef40b..c5b76592e1 100644 --- a/tensorflow/core/common_runtime/simple_placer.h +++ b/tensorflow/core/common_runtime/placer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_ -#define TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_ +#ifndef TENSORFLOW_COMMON_RUNTIME_PLACER_H_ +#define TENSORFLOW_COMMON_RUNTIME_PLACER_H_ #include <string> #include <unordered_map> @@ -53,25 +53,24 @@ namespace tensorflow { // TODO(mrry): Create a common interface for this and the other // placement algorithms so that they may be injected into the graph // builder. -class SimplePlacer { +class Placer { public: // A map from graph node names to numerical IDs (in a Graph object). typedef std::unordered_map<string, int> NodeNameToIdMap; - // Creates an instance of the SimplePlacer algorithm for the given + // Creates an instance of the Placer algorithm for the given // Graph "graph" (nodes in which may or may not be assigned) on the // given DeviceSet "devices". // // The "graph", and "devices" pointer arguments - // are borrowed by this SimplePlacer, and must outlive it. - SimplePlacer(Graph* graph, const DeviceSet* devices, - const SessionOptions* options); + // are borrowed by this Placer, and must outlive it. + Placer(Graph* graph, const DeviceSet* devices, const SessionOptions* options); - SimplePlacer(Graph* graph, const DeviceSet* devices); + Placer(Graph* graph, const DeviceSet* devices); - ~SimplePlacer(); + ~Placer(); - // Assigns each node in this SimplePlacer's graph to a device in its + // Assigns each node in this Placer's graph to a device in its // set of devices. // // This method is not thread-safe. @@ -94,9 +93,9 @@ class SimplePlacer { const SessionOptions* options_; // Not owned. const bool log_device_placement_; - TF_DISALLOW_COPY_AND_ASSIGN(SimplePlacer); + TF_DISALLOW_COPY_AND_ASSIGN(Placer); }; } // namespace tensorflow -#endif // TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_ +#endif // TENSORFLOW_COMMON_RUNTIME_PLACER_H_ diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 967bee63a1..5d87b1e279 100644 --- a/tensorflow/core/common_runtime/simple_placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/simple_placer.h" +#include "tensorflow/core/common_runtime/placer.h" #include <memory> #include <string> @@ -44,7 +44,7 @@ namespace { // // Op, kernel, and device registrations to set up the environment. // -// The SimplePlacer uses information about the op (input types), +// The Placer uses information about the op (input types), // kernel (device constraints), and available devices to make // placement decisions. To avoid depending on the full runtime, we // define dummy implementations of these, and register them with the @@ -164,17 +164,17 @@ REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeGPU"), DummyOp); //////////////////////////////////////////////////////////////////////////////// // -// A SimplePlacerTest method has three phases: +// A PlacerTest method has three phases: // // 1. Build a TensorFlow graph, with no (or partial) device assignments. -// 2. Attempt to compute a placement using the SimplePlacer. +// 2. Attempt to compute a placement using the Placer. // 3. EITHER: test that the constraints implied by the graph are respected; // or that an appropriate error was reported. // //////////////////////////////////////////////////////////////////////////////// -class SimplePlacerTest : public ::testing::Test { +class PlacerTest : public ::testing::Test { protected: - SimplePlacerTest() { + PlacerTest() { // Build a set of 10 GPU and 10 CPU devices. // NOTE: this->local_devices_ owns the device objects; // this->devices_ contains borrowed pointers to the device @@ -201,12 +201,12 @@ class SimplePlacerTest : public ::testing::Test { return Status::OK(); } - // Invokes the SimplePlacer on "graph". If no DeviceSet is specified, the + // Invokes the Placer on "graph". If no DeviceSet is specified, the // placement will use the default DeviceSet (of 10 CPU and 10 GPU devices). // // REQUIRES: "*graph" was produced by the most recent call to BuildGraph. Status Place(Graph* graph, DeviceSet* devices, SessionOptions* options) { - SimplePlacer placer(graph, devices, options); + Placer placer(graph, devices, options); return placer.Run(); } @@ -232,7 +232,7 @@ class SimplePlacerTest : public ::testing::Test { protected: std::vector<std::unique_ptr<Device>> local_devices_; DeviceSet devices_; - SimplePlacer::NodeNameToIdMap nodes_by_name_; + Placer::NodeNameToIdMap nodes_by_name_; Status ReferenceTestHelper(const string& variable_op_type, const string& assign_op_type, @@ -267,7 +267,7 @@ class SimplePlacerTest : public ::testing::Test { // Test that a graph with no constraints will successfully assign nodes to the // "best available" device (i.e. prefer GPU over CPU). -TEST_F(SimplePlacerTest, TestNoConstraints) { +TEST_F(PlacerTest, TestNoConstraints) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -286,7 +286,7 @@ TEST_F(SimplePlacerTest, TestNoConstraints) { // Test that a graph with device type and reference constraints on // some of the ops will successfully assign nodes to the constrained // device, and colocate nodes with reference connections. -TEST_F(SimplePlacerTest, TestDeviceTypeConstraints) { +TEST_F(PlacerTest, TestDeviceTypeConstraints) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -308,7 +308,7 @@ TEST_F(SimplePlacerTest, TestDeviceTypeConstraints) { EXPECT_COLOCATED(g, "var_gpu", "assign_gpu"); } -TEST_F(SimplePlacerTest, TestMetadataColocatedWithInput) { +TEST_F(PlacerTest, TestMetadataColocatedWithInput) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -330,7 +330,7 @@ TEST_F(SimplePlacerTest, TestMetadataColocatedWithInput) { // Heuristic A implements "Island fusing": if a node only generates // an output and it has only one consumer, we place the node // with its consumer. -TEST_F(SimplePlacerTest, TestHeuristicGeneratorFollowsSingleConsumer) { +TEST_F(PlacerTest, TestHeuristicGeneratorFollowsSingleConsumer) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -355,7 +355,7 @@ TEST_F(SimplePlacerTest, TestHeuristicGeneratorFollowsSingleConsumer) { EXPECT_COLOCATED(g, "assign", "in"); } -TEST_F(SimplePlacerTest, TestIgnoreGeneratorHeuristicIfWrongDevice) { +TEST_F(PlacerTest, TestIgnoreGeneratorHeuristicIfWrongDevice) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -382,7 +382,7 @@ TEST_F(SimplePlacerTest, TestIgnoreGeneratorHeuristicIfWrongDevice) { EXPECT_COLOCATED(g, "var_cpu", "assign"); } -TEST_F(SimplePlacerTest, TestIgnoreGeneratorHeuristicIfWrongPartialDevice) { +TEST_F(PlacerTest, TestIgnoreGeneratorHeuristicIfWrongPartialDevice) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -416,7 +416,7 @@ TEST_F(SimplePlacerTest, TestIgnoreGeneratorHeuristicIfWrongPartialDevice) { // Test that a graph with partial device specifications on the ops // will successfully -TEST_F(SimplePlacerTest, TestPartialSpec) { +TEST_F(PlacerTest, TestPartialSpec) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -434,7 +434,7 @@ TEST_F(SimplePlacerTest, TestPartialSpec) { } // Test that a node with a pre-assigned device is not relocated. -TEST_F(SimplePlacerTest, TestAssignedDevicePreserved) { +TEST_F(PlacerTest, TestAssignedDevicePreserved) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -452,7 +452,7 @@ TEST_F(SimplePlacerTest, TestAssignedDevicePreserved) { // Test that a graph with partial device specifications for CPU-only ops // will be relocated to CPU. -TEST_F(SimplePlacerTest, TestPartialSpecGpuToCpu) { +TEST_F(PlacerTest, TestPartialSpecGpuToCpu) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -474,7 +474,7 @@ TEST_F(SimplePlacerTest, TestPartialSpecGpuToCpu) { // Test that a node with an assigned GPU device but has not registered // OpKernel will fail. -TEST_F(SimplePlacerTest, TestAssignedGpuDeviceToCpuDevice) { +TEST_F(PlacerTest, TestAssignedGpuDeviceToCpuDevice) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -499,9 +499,9 @@ TEST_F(SimplePlacerTest, TestAssignedGpuDeviceToCpuDevice) { // Build a graph containing a Variable op of "variable_op_type" and an // Assign op of "assign_op_type", and expect all of the ops to be // placed on a device of type "expected_device_type". -Status SimplePlacerTest::ReferenceTestHelper( - const string& variable_op_type, const string& assign_op_type, - const DeviceType& expected_device_type) { +Status PlacerTest::ReferenceTestHelper(const string& variable_op_type, + const string& assign_op_type, + const DeviceType& expected_device_type) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -530,7 +530,7 @@ Status SimplePlacerTest::ReferenceTestHelper( // Test all 2^3 combinations of Variable and Assignment op types // (unconstrained, CPU-only, and GPU-only). -TEST_F(SimplePlacerTest, TestReferenceConnection) { +TEST_F(PlacerTest, TestReferenceConnection) { Status s; TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "TestAssign", "FakeGPU")); TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignCPU", "FakeCPU")); @@ -575,7 +575,7 @@ REGISTER_OP("HandleAssignGPU").Input("i: resource").Input("v: float"); REGISTER_KERNEL_BUILDER(Name("HandleAssignGPU").Device("FakeGPU"), DummyOp); // Tests all combinations of resource handles and ops using them. -TEST_F(SimplePlacerTest, TestResourceHandle) { +TEST_F(PlacerTest, TestResourceHandle) { auto handle_test = [this](const string& var_op_name, const string& use_op_name, DeviceType device) { Graph g(OpRegistry::Global()); @@ -611,7 +611,7 @@ TEST_F(SimplePlacerTest, TestResourceHandle) { // Test that an assignment of an operator to the wrong device // is ignored when it could never be satisfied (due to reference // edges, for example). -TEST_F(SimplePlacerTest, TestReferenceConnectionIgnoreInfeasible) { +TEST_F(PlacerTest, TestReferenceConnectionIgnoreInfeasible) { Status s; Graph g(OpRegistry::Global()); { @@ -640,8 +640,7 @@ TEST_F(SimplePlacerTest, TestReferenceConnectionIgnoreInfeasible) { // Test that an assignment of an operator to the a more specified device // causes the device to maintain its more specific placement. -TEST_F(SimplePlacerTest, - TestReferenceConnectionMoreSpecificDestinationSourceWins) { +TEST_F(PlacerTest, TestReferenceConnectionMoreSpecificDestinationSourceWins) { Status s; Graph g(OpRegistry::Global()); { @@ -675,7 +674,7 @@ TEST_F(SimplePlacerTest, // where the assign has a device but the variable does not. In this // case, the variable gets placed on the location of the assign // operation. -TEST_F(SimplePlacerTest, TestReferenceConnectionNoSourceDevice) { +TEST_F(PlacerTest, TestReferenceConnectionNoSourceDevice) { Status s; Graph g(OpRegistry::Global()); { @@ -697,7 +696,7 @@ TEST_F(SimplePlacerTest, TestReferenceConnectionNoSourceDevice) { EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU"); } -TEST_F(SimplePlacerTest, TestColocationGroup) { +TEST_F(PlacerTest, TestColocationGroup) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -720,7 +719,7 @@ TEST_F(SimplePlacerTest, TestColocationGroup) { EXPECT_NOT_COLOCATED(g, "in", "foo"); } -TEST_F(SimplePlacerTest, TestMultipleColocationGroups) { +TEST_F(PlacerTest, TestMultipleColocationGroups) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -742,7 +741,7 @@ TEST_F(SimplePlacerTest, TestMultipleColocationGroups) { EXPECT_COLOCATED(g, "in", "foo"); } -TEST_F(SimplePlacerTest, TestInvalidMultipleColocationGroups) { +TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -766,7 +765,7 @@ TEST_F(SimplePlacerTest, TestInvalidMultipleColocationGroups) { "other nodes colocated with them")); } -TEST_F(SimplePlacerTest, TestColocationGroupWithReferenceConnections) { +TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -793,8 +792,7 @@ TEST_F(SimplePlacerTest, TestColocationGroupWithReferenceConnections) { EXPECT_COLOCATED(g, "var2", "assign1"); } -TEST_F(SimplePlacerTest, - TestColocationGroupWithUnsatisfiableReferenceConnections) { +TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -832,7 +830,7 @@ TEST_F(SimplePlacerTest, "nodes colocated with them.")); } -TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) { +TEST_F(PlacerTest, TestColocationAndReferenceConnections) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -877,7 +875,7 @@ TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) { } // Test that placement fails when no devices are registered. -TEST_F(SimplePlacerTest, TestEmptyDeviceSet) { +TEST_F(PlacerTest, TestEmptyDeviceSet) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -894,7 +892,7 @@ TEST_F(SimplePlacerTest, TestEmptyDeviceSet) { // Test that placement fails when the requested device forces an // indirect constraint to be violated. -TEST_F(SimplePlacerTest, TestHeterogeneousDeviceSetFailure) { +TEST_F(PlacerTest, TestHeterogeneousDeviceSetFailure) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -928,7 +926,7 @@ TEST_F(SimplePlacerTest, TestHeterogeneousDeviceSetFailure) { } // Test that placement fails when an unknown device is requested. -TEST_F(SimplePlacerTest, TestUnknownDevice) { +TEST_F(PlacerTest, TestUnknownDevice) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -943,7 +941,7 @@ TEST_F(SimplePlacerTest, TestUnknownDevice) { // Test that placement fails when the combination of partial // constraints leads to an unknown device. -TEST_F(SimplePlacerTest, TestUnknownMergedDevice) { +TEST_F(PlacerTest, TestUnknownMergedDevice) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -958,7 +956,7 @@ TEST_F(SimplePlacerTest, TestUnknownMergedDevice) { // Test that placement fails when the previously-assigned device for a // node is unknown. -TEST_F(SimplePlacerTest, TestUnknownAssignedDevice) { +TEST_F(PlacerTest, TestUnknownAssignedDevice) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -977,7 +975,7 @@ TEST_F(SimplePlacerTest, TestUnknownAssignedDevice) { // Test that placement fails when an op with no registered kernels is // requested. -TEST_F(SimplePlacerTest, TestNoKernelsRegistered) { +TEST_F(PlacerTest, TestNoKernelsRegistered) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -997,7 +995,7 @@ TEST_F(SimplePlacerTest, TestNoKernelsRegistered) { // Test that placement fails when a kernel is registered but no known // device supports it. -TEST_F(SimplePlacerTest, TestNoDevicesRegistered) { +TEST_F(PlacerTest, TestNoDevicesRegistered) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -1019,7 +1017,7 @@ TEST_F(SimplePlacerTest, TestNoDevicesRegistered) { } // Test that placement fails when a requested device is malformed. -TEST_F(SimplePlacerTest, TestMalformedDeviceSpecification) { +TEST_F(PlacerTest, TestMalformedDeviceSpecification) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -1034,7 +1032,7 @@ TEST_F(SimplePlacerTest, TestMalformedDeviceSpecification) { } // Test that placement fails when a previously-assigned device is malformed. -TEST_F(SimplePlacerTest, TestMalformedAssignedDevice) { +TEST_F(PlacerTest, TestMalformedAssignedDevice) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -1052,7 +1050,7 @@ TEST_F(SimplePlacerTest, TestMalformedAssignedDevice) { // Test that placement fails when a device was previously assigned to // a node, but it does not uniquely identify a particular device. -TEST_F(SimplePlacerTest, TestNonUniqueAssignedDevice) { +TEST_F(PlacerTest, TestNonUniqueAssignedDevice) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -1071,7 +1069,7 @@ TEST_F(SimplePlacerTest, TestNonUniqueAssignedDevice) { // Test that ops request to be placed on non-existent devices will be relocated // to existing device of the same type if allow_soft_placement is set. -TEST_F(SimplePlacerTest, TestNonexistentGpuAllowSoftPlacement) { +TEST_F(PlacerTest, TestNonexistentGpuAllowSoftPlacement) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -1088,7 +1086,7 @@ TEST_F(SimplePlacerTest, TestNonexistentGpuAllowSoftPlacement) { // Test that ops request to be placed on non-existent devices will fail if // allow_soft_placement is not set. -TEST_F(SimplePlacerTest, TestNonexistentGpuNoAllowSoftPlacement) { +TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacement) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -1105,7 +1103,7 @@ TEST_F(SimplePlacerTest, TestNonexistentGpuNoAllowSoftPlacement) { // Test that placement fails when a node requests an explicit device that is not // supported by the registered kernels if allow_soft_placement is no set. -TEST_F(SimplePlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) { +TEST_F(PlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -1125,7 +1123,7 @@ TEST_F(SimplePlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) { // Test that placement fails when a node requests an explicit device that is not // supported by the registered kernels if allow_soft_placement is no set. -TEST_F(SimplePlacerTest, TestNonExistentDevice) { +TEST_F(PlacerTest, TestNonExistentDevice) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -1143,7 +1141,7 @@ TEST_F(SimplePlacerTest, TestNonExistentDevice) { "but available devices")); } -TEST_F(SimplePlacerTest, TestUnsupportedDeviceAllowSoftPlacement) { +TEST_F(PlacerTest, TestUnsupportedDeviceAllowSoftPlacement) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -1160,7 +1158,7 @@ TEST_F(SimplePlacerTest, TestUnsupportedDeviceAllowSoftPlacement) { // Test that a graph with device type and reference constraints on // some of the ops will successfully assign nodes to the constrained // device, and colocate nodes with reference connections. -TEST_F(SimplePlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) { +TEST_F(PlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -1194,7 +1192,7 @@ TEST_F(SimplePlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) { // Test that placement fails when two nodes have a reference connection // constraint, and each node requires a mutually incompatible device. -TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) { +TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -1212,7 +1210,7 @@ TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) { // Test that a generator node follows its consumers (where there are several // consumer nodes on the same devices). -TEST_F(SimplePlacerTest, TestGeneratorNodeFollowsConsumerNode) { +TEST_F(PlacerTest, TestGeneratorNodeFollowsConsumerNode) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -1245,7 +1243,7 @@ TEST_F(SimplePlacerTest, TestGeneratorNodeFollowsConsumerNode) { // Test that a generator node does not follow its consumers (where there are // several consumers on different devices). -TEST_F(SimplePlacerTest, TestGeneratorNodeDoesntFollowNonColocatedConsumers) { +TEST_F(PlacerTest, TestGeneratorNodeDoesntFollowNonColocatedConsumers) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index caa60e58e6..bfdf967333 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -55,7 +55,7 @@ limitations under the License. namespace tensorflow { -// MasterSession wraps SimpleClientGraph in a reference counted object. +// MasterSession wraps ClientGraph in a reference counted object. // This way, MasterSession can clear up the cache mapping Run requests to // compiled graphs while the compiled graph is still being used. // @@ -63,10 +63,10 @@ namespace tensorflow { class MasterSession::ReffedClientGraph : public core::RefCounted { public: ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts, - std::unique_ptr<SimpleClientGraph> cg, + std::unique_ptr<ClientGraph> cg, const SessionOptions& session_opts, const StatsPublisherFactory& stats_publisher_factory, - SimpleGraphExecutionState* execution_state, bool is_partial, + GraphExecutionState* execution_state, bool is_partial, WorkerCacheInterface* worker_cache) : session_handle_(handle), client_graph_(std::move(cg)), @@ -87,7 +87,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { ~ReffedClientGraph() override { DeregisterPartitions(); } - const SimpleClientGraph* client_graph() { return client_graph_.get(); } + const ClientGraph* client_graph() { return client_graph_.get(); } std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step, int64 execution_count, @@ -186,7 +186,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { // Checks that the requested fetches can be computed from the provided feeds. Status CheckFetches(const RunStepRequestWrapper& req, const RunState* run_state, - SimpleGraphExecutionState* execution_state); + GraphExecutionState* execution_state); string DetailText(const Node& node, const NodeExecStats& ns) { int64 tot = 0; @@ -203,7 +203,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { private: const string session_handle_; - const std::unique_ptr<SimpleClientGraph> client_graph_; + const std::unique_ptr<ClientGraph> client_graph_; const SessionOptions session_opts_; const bool is_partial_; const DebugOptions& debug_opts_; @@ -813,7 +813,7 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats( // contention. Status MasterSession::ReffedClientGraph::CheckFetches( const RunStepRequestWrapper& req, const RunState* run_state, - SimpleGraphExecutionState* execution_state) { + GraphExecutionState* execution_state) { // Build the set of pending feeds that we haven't seen. std::unordered_set<TensorId, TensorId::Hasher> pending_feeds; for (const auto& input : run_state->pending_inputs) { @@ -1028,12 +1028,12 @@ Status MasterSession::Create(GraphDef* graph_def, session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false); } - SimpleGraphExecutionStateOptions execution_options; + GraphExecutionStateOptions execution_options; execution_options.device_set = devices_.get(); execution_options.session_options = &session_opts_; { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph( + TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph( graph_def, execution_options, &execution_state_)); } if (options.cluster_def != nullptr) { @@ -1142,7 +1142,7 @@ Status MasterSession::ListDevices(ListDevicesResponse* resp) const { Status MasterSession::Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp) { UpdateLastAccessTime(); - std::unique_ptr<SimpleGraphExecutionState> extended_execution_state; + std::unique_ptr<GraphExecutionState> extended_execution_state; { mutex_lock l(mu_); if (closed_) { @@ -1195,7 +1195,7 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, VLOG(1) << "Unseen hash " << hash << " for " << BuildGraphOptionsString(opts) << " is_partial = " << is_partial << "\n"; - std::unique_ptr<SimpleClientGraph> client_graph; + std::unique_ptr<ClientGraph> client_graph; TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph)); WorkerCacheInterface* worker_cache = get_worker_cache(); auto entry = new ReffedClientGraph( diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index 33b9bfe631..51ea92da68 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/debugger_state_interface.h" #include "tensorflow/core/common_runtime/device_set.h" -#include "tensorflow/core/common_runtime/simple_graph_execution_state.h" +#include "tensorflow/core/common_runtime/graph_execution_state.h" #include "tensorflow/core/common_runtime/stats_publisher_interface.h" #include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/master_env.h" @@ -128,7 +128,7 @@ class MasterSession : public core::RefCounted { std::atomic<int64> partial_run_handle_counter_ = {0}; mutex mu_; - std::unique_ptr<SimpleGraphExecutionState> execution_state_ GUARDED_BY(mu_); + std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(mu_); int64 graph_version_; // We keep a map from a signature of a run request to the |