author     Alexandre Passos <apassos@google.com>            2018-09-21 10:37:23 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-21 10:41:09 -0700
commit     eafd43b7e47508fd0eddfd389ea206be79f5dbe6 (patch)
tree       63cc4f06c238f0e9e8a49d9b5e8a97647db9f723 /tensorflow/contrib/eager
parent     010e8ed731d0e10c82fccbf6c119180ca1a36efd (diff)
Simple scaffold for parameter-server training with eager execution
PiperOrigin-RevId: 214007470
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--  tensorflow/contrib/eager/python/BUILD                |  13
-rw-r--r--  tensorflow/contrib/eager/python/parameter_server.py  | 289
-rw-r--r--  tensorflow/contrib/eager/python/remote_test.py       |  20
3 files changed, 322 insertions, 0 deletions
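
At its core, the change adds a `parameter_server_scope` context manager that intercepts variable creation, builds `SharedVariable` objects, and places them round-robin across the tasks of a ps job. As a quick orientation before the full diff, here is a minimal usage sketch based on the docstring in the new file; the job name, task count, and model are hypothetical placeholders, and the cluster/server setup is assumed to have happened already (remote_test.py below shows one way to do it):

```
# Minimal sketch (not part of this commit); names are placeholders.
import tensorflow as tf
from tensorflow.contrib.eager.python import parameter_server

tf.enable_eager_execution()

# Variables created inside the scope become SharedVariables, placed
# round-robin on /job:ps/task:0, /job:ps/task:1, /job:ps/task:2, ...
with parameter_server.parameter_server_scope(
    is_chief=True, ps_job_name="ps", num_ps_tasks=3):
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(10, activation="relu"),
      tf.keras.layers.Dense(1),
  ])
```

On non-chief workers the same scope is entered with `is_chief=False`, which creates the variables with `initialize=False` and polls until the chief has initialized them.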
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index 84517b57c7..9c3676629d 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -97,6 +97,18 @@ py_library(
     ],
 )
 
+py_library(
+    name = "parameter_server",
+    srcs = ["parameter_server.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:framework",
+        "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/eager:context",
+    ],
+)
+
 cuda_py_test(
     name = "saver_test",
     srcs = ["saver_test.py"],
@@ -241,6 +253,7 @@ py_test(
     srcs = ["remote_test.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":parameter_server",
         ":remote",
         "//tensorflow/contrib/eager/python:tfe",
         "//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/eager/python/parameter_server.py b/tensorflow/contrib/eager/python/parameter_server.py
new file mode 100644
index 0000000000..3a9e7b027e
--- /dev/null
+++ b/tensorflow/contrib/eager/python/parameter_server.py
@@ -0,0 +1,289 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""EXPERIMENTAL utilities for parameter server training with eager execution.
+
+Note: this should eventually be merged with the distribution strategy for
+ParameterServer.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import time
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training.checkpointable import base as checkpointable
+
+
+def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
+  """Creates a variable handle with information to do shape inference."""
+  container = ops.get_default_graph()._container  # pylint: disable=protected-access
+  if container is None:
+    container = ""
+  handle = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
+                                               shared_name=shared_name,
+                                               name=name,
+                                               container=container)
+  if graph_mode:
+    return handle
+
+  with context.graph_mode(), ops.Graph().as_default() as graph:
+    h = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
+                                            shared_name=shared_name,
+                                            name=name,
+                                            container=container)
+
+    # Tensor._handle_data contains information for the shape-inference code to
+    # know the shape and dtype of the variable pointed to by a handle. Since
+    # shape inference doesn't run in eager mode we copy this data here for when
+    # the handle is captured by an eager mode function.
+    # pylint: disable=protected-access
+    if ops._USE_C_SHAPES:
+      handle._handle_data = resource_variable_ops.get_resource_handle_data(h)
+    else:
+      if h._handle_data is None:
+        ops.set_shape_and_handle_data_for_outputs(h.op)
+      handle._handle_data = h._handle_data
+    # pylint: enable=protected-access
+  # Clean up op->graph->op reference cycles.
+  ops.dismantle_graph(graph)
+  return handle
+
+
+class SharedVariable(resource_variable_ops.ResourceVariable):
+  """Experimental Variable designed for parameter server training.
+
+  A SharedVariable has a name and two instances of SharedVariable with the
+  same name will have the same value, even if they are in different Sessions,
+  as long as they are placed on the same device.
+
+  The storage associated with SharedVariables is also not deleted when they go
+  out of scope.
+  """
+
+  def __init__(self,  # pylint: disable=super-init-not-called
+               initial_value=None,
+               trainable=True,
+               name=None,
+               dtype=None,
+               constraint=None,
+               initialize=True,
+               **unused_kwargs):
+    """Creates a variable.
+
+    Args:
+      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+        which is the initial value for the Variable. The initial value must have
+        a shape specified unless `validate_shape` is set to False. Can also be a
+        callable with no argument that returns the initial value when called.
+        (Note that initializer functions from init_ops.py must first be bound
+        to a shape before being used here.)
+      trainable: If `True`, automatically watches this variable on GradientTape
+        whenever it's used.
+      name: Optional name for the variable. Defaults to `'Variable'` and gets
+        uniquified automatically.
+      dtype: If set, initial_value will be converted to the given type.
+        If None, either the datatype will be kept (if initial_value is
+        a Tensor) or float32 will be used (if it is a Python object convertible
+        to a Tensor).
+      constraint: An optional projection function to be applied to the variable
+        after being updated by an `Optimizer` (e.g. used to implement norm
+        constraints or value constraints for layer weights). The function must
+        take as input the unprojected Tensor representing the value of the
+        variable and return the Tensor for the projected value
+        (which must have the same shape). Constraints are not safe to
+        use when doing asynchronous distributed training.
+      initialize: if True, runs initialization in eager execution; leaves the
+        variable uninitialized otherwise.
+
+    Raises:
+      ValueError: If the initial value is not specified, or does not have a
+        shape and `validate_shape` is `True`.
+    """
+    if initial_value is None:
+      raise ValueError("initial_value must be specified.")
+    init_from_fn = callable(initial_value)
+
+    if isinstance(initial_value, ops.Tensor) and hasattr(
+        initial_value, "graph") and initial_value.graph.building_function:
+      raise ValueError("Tensor-typed variable initializers must either be "
+                       "wrapped in an init_scope or callable "
+                       "(e.g., `tf.Variable(lambda : "
+                       "tf.truncated_normal([10, 40]))`) when building "
+                       "functions. Please file a feature request if this "
+                       "restriction inconveniences you.")
+
+    if constraint is not None and not callable(constraint):
+      raise ValueError("The `constraint` argument must be a callable.")
+
+    if isinstance(initial_value, checkpointable.CheckpointInitialValue):
+      self._maybe_initialize_checkpointable()
+      self._update_uid = initial_value.checkpoint_position.restore_uid
+      initial_value = initial_value.wrapped_value
+
+    self._trainable = trainable
+    self._save_slice_info = None
+    # Store the graph key so optimizers know how to only retrieve variables
+    # from this graph.
+    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
+    with ops.init_scope():
+      self._in_graph_mode = not context.executing_eagerly()
+      with ops.name_scope(name, "Variable", []
+                          if init_from_fn else [initial_value]) as name:
+        # pylint: disable=protected-access
+        handle_name = ops._name_from_scope_name(name)
+        shared_name = handle_name
+        if init_from_fn:
+          # Use attr_scope and device(None) to simulate the behavior of
+          # colocate_with when the variable we want to colocate with doesn't
+          # yet exist.
+          if self._in_graph_mode:
+            with ops.name_scope("Initializer"), ops.device(None):
+              initial_value = ops.convert_to_tensor(
+                  initial_value(), name="initial_value", dtype=dtype)
+            self._handle = _eager_safe_variable_handle(
+                shape=initial_value.get_shape(),
+                dtype=initial_value.dtype.base_dtype,
+                shared_name=shared_name,
+                name=name,
+                graph_mode=self._in_graph_mode)
+            self._shape = initial_value.get_shape()
+          else:
+            initial_value = initial_value()
+            with ops.name_scope("Initializer"):
+              initial_value = ops.convert_to_tensor(
+                  initial_value, name="initial_value", dtype=dtype)
+            self._handle = _eager_safe_variable_handle(
+                shape=initial_value.get_shape(),
+                dtype=initial_value.dtype.base_dtype,
+                shared_name=shared_name,
+                name=name,
+                graph_mode=False)
+            self._shape = initial_value.get_shape()
+        # pylint: enable=protected-access
+
+        # Or get the initial value from a Tensor or Python object.
+        else:
+          with ops.name_scope("Initializer"):
+            initial_value = ops.convert_to_tensor(
+                initial_value, name="initial_value", dtype=dtype)
+          # pylint: disable=protected-access
+          if (self._in_graph_mode and initial_value is not None and
+              initial_value.op._get_control_flow_context() is not None):
+            raise ValueError(
+                "Initializer for variable %s is from inside a control-flow "
+                "construct, such as a loop or conditional. When creating a "
+                "variable inside a loop or conditional, use a lambda as the "
+                "initializer." % name)
+          # pylint: enable=protected-access
+          self._handle = _eager_safe_variable_handle(
+              shape=initial_value.get_shape(),
+              dtype=initial_value.dtype.base_dtype,
+              shared_name=shared_name,
+              name=name,
+              graph_mode=self._in_graph_mode)
+          self._shape = initial_value.get_shape()
+
+        self._unique_id = shared_name
+        self._initial_value = initial_value if self._in_graph_mode else None
+        self._handle_name = handle_name + ":0"
+        self._dtype = initial_value.dtype.base_dtype
+        self._constraint = constraint
+
+        if self._in_graph_mode:
+          with ops.name_scope("IsInitialized"):
+            self._is_initialized_op = (
+                resource_variable_ops.var_is_initialized_op(self._handle))
+          if initial_value is not None:
+            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
+              self._initializer_op = (
+                  resource_variable_ops.assign_variable_op(
+                      self._handle,
+                      self._try_guard_against_uninitialized_dependencies(
+                          initial_value),
+                      name=n))
+          with ops.name_scope("Read"), ops.colocate_with(self._handle):
+            # Manually assign reads to the handle's device to avoid log
+            # messages.
+            with ops.device(self._handle.device):
+              value = self._read_variable_op()
+            self._graph_element = value
+          self._cached_value = None
+        else:
+          if initialize:
+            resource_variable_ops.assign_variable_op(self._handle,
+                                                     initial_value)
+          self._is_initialized_op = None
+          self._initializer_op = None
+          self._graph_element = None
+          self._cached_value = None
+
+    self._handle_deleter = None
+    self._cached_shape_as_list = None
+
+
+@contextlib.contextmanager
+def parameter_server_scope(is_chief, ps_job_name, num_ps_tasks):
+  """Strategy to use parameter servers in eager.
+
+  Creates SharedVariable objects for variables created in this scope. These
+  SharedVariable objects will be placed round-robin on the parameter servers
+  specified by the ps_job_name and num_ps_tasks arguments.
+
+  To use parameter servers you need only to wrap your model initialization in
+  this scope:
+
+  ```
+  with tf.contrib.eager.parameter_server_scope(
+      is_chief, ps_job_name, num_ps_tasks):
+    my_model = tf.keras.Sequential([...])  # Or
+    input = tf.keras.Input(...)
+    ....
+    my_model = tf.keras.Model(input, output)
+  my_model.compile(...)
+  # or other usages of the model.
+  ```
+
+  Args:
+    is_chief: Boolean. Whether this worker is responsible for initializing
+      variables.
+    ps_job_name: The name of the ps job in this cluster.
+    num_ps_tasks: The number of ps tasks to use.
+
+  Yields:
+    a context manager.
+  """
+  # Note: capturing in a list to allow assignment.
+  ps_index = [0]
+
+  def variable_creator_scope(unused_next_creator, **kwargs):
+    kwargs["initialize"] = is_chief
+    with ops.device(
+        "/job:%s/task:%s" % (ps_job_name, ps_index[0] % num_ps_tasks)):
+      ps_index[0] += 1
+      v = SharedVariable(**kwargs)
+      if not is_chief:
+        while not resource_variable_ops.var_is_initialized_op(v.handle):
+          time.sleep(10)
+    return v
+
+  with variable_scope.variable_creator_scope(variable_creator_scope):
+    yield
diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py
index 13029db975..ba6fe9701d 100644
--- a/tensorflow/contrib/eager/python/remote_test.py
+++ b/tensorflow/contrib/eager/python/remote_test.py
@@ -23,6 +23,7 @@ import os
 
 import numpy as np
 
+from tensorflow.contrib.eager.python import parameter_server
 from tensorflow.contrib.eager.python import remote
 from tensorflow.core.protobuf import cluster_pb2
 from tensorflow.core.protobuf import tensorflow_server_pb2
@@ -33,6 +34,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.training import server_lib
 
@@ -120,6 +122,24 @@ class RemoteExecutionTest(test.TestCase):
       y = math_ops.matmul(x1, x2)
       np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
 
+  def testParameterServer(self):
+    with parameter_server.parameter_server_scope(
+        is_chief=True, ps_job_name=JOB_NAME, num_ps_tasks=3):
+      v0 = variables.Variable([1.0], name="v0")
+      v1 = variables.Variable([2.0], name="v1")
+      v0.assign(v0 * v1)
+      self.assertAllEqual(v0.read_value(), [2.0])
+      self.assertAllEqual(v0.device,
+                          "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME)
+      self.assertAllEqual(v1.device,
+                          "/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME)
+      v1.assign_add(v1)
+      # Simulate aliasing another variable of the same name as v1
+      with ops.device("/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
+        v1_replica = parameter_server.SharedVariable(
+            [1.0], name="v1", initialize=False)
+      self.assertAllEqual(v1_replica.read_value(), [4.0])
+
   @run_sync_and_async
   def testSimpleWeightRead(self):
     """Basic remote eager weight read."""
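
For reference, the aliasing behavior exercised by `testParameterServer` can also be written standalone: two `SharedVariable`s created with the same name on the same device share storage, and `initialize=False` makes the second attach to the existing value instead of overwriting it. A sketch under the same hypothetical cluster assumptions as above:

```
# Sketch of SharedVariable aliasing (mirrors testParameterServer); the
# device string is a placeholder for a real ps task in a live cluster.
from tensorflow.python.framework import ops
from tensorflow.contrib.eager.python import parameter_server

with ops.device("/job:ps/replica:0/task:1/device:CPU:0"):
  v = parameter_server.SharedVariable([2.0], name="v1")  # initializes storage
  v_alias = parameter_server.SharedVariable(
      [1.0], name="v1", initialize=False)  # aliases v; [1.0] is never written
# Both handles point at the same storage, so updates through one are visible
# through the other: v.assign_add(v) makes v_alias read [4.0].
```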