author    Igor Saprykin <isaprykin@google.com>              2018-03-09 15:28:15 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-03-09 15:36:43 -0800
commit    faea16caaf84b065ecf5fd6706a597308984df71
tree      76d6e40aca002de0b812ce66f71a92c1aadf4aaf
parent    be51a9fac97d1497f59ecfc3a9aec4b5f84c9b76
Copy `replicate_model_fn` to core.
PiperOrigin-RevId: 188547527
-rw-r--r--  tensorflow/python/estimator/BUILD                       |   65
-rw-r--r--  tensorflow/python/estimator/replicate_model_fn.py       |  823
-rw-r--r--  tensorflow/python/estimator/replicate_model_fn_test.py  | 1709
3 files changed, 2597 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index e3a6708d67..04fcbb0e87 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -7,6 +7,7 @@ package(
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
filegroup(
name = "all_files",
@@ -35,6 +36,7 @@ py_library(
":linear",
":model_fn",
":parsing_utils",
+ ":replicate_model_fn",
":run_config",
":training",
"//tensorflow/python:util",
@@ -866,3 +868,66 @@ py_test(
"//tensorflow/python:training",
],
)
+
+py_library(
+ name = "replicate_model_fn",
+ srcs = [
+ "replicate_model_fn.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":export_output",
+ ":model_fn",
+ ":util",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:device",
+ "//tensorflow/python:device_lib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/ops/losses",
+ "@six_archive//:six",
+ ],
+)
+
+cuda_py_test(
+ name = "replicate_model_fn_test",
+ size = "medium",
+ srcs = ["replicate_model_fn_test.py"],
+ additional_deps = [
+ "//tensorflow/python/estimator",
+ ":dnn",
+ ":export_export",
+ ":export_output",
+ ":model_fn",
+ ":numpy_io",
+ ":optimizers",
+ ":prediction_keys",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python/ops/losses",
+ "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ":replicate_model_fn",
+ ],
+ tags = [
+ "multi_gpu",
+ ],
+)
diff --git a/tensorflow/python/estimator/replicate_model_fn.py b/tensorflow/python/estimator/replicate_model_fn.py
new file mode 100644
index 0000000000..7418852096
--- /dev/null
+++ b/tensorflow/python/estimator/replicate_model_fn.py
@@ -0,0 +1,823 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities to replicate model_fn's over local GPUs.
+
+This file contains util that allow to replicate `Estimator.model_fn` over
+GPUs. Replicated version of a `model_fn` is returned that can subsequently
+be used with `Estimator`.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import defaultdict
+from contextlib import contextmanager
+import copy
+
+import six
+
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.client import device_lib
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator import util
+from tensorflow.python.estimator.export import export_output as export_output_lib
+from tensorflow.python.framework import device as framework_device
+from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.training import device_setter as device_setter_lib
+from tensorflow.python.training import optimizer as optimizer_lib
+
+
+def _replicate_model_fn(model_fn,
+ loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
+ devices=None):
+ """Replicate `Estimator.model_fn` over GPUs.
+
+  The given `model_fn` specifies a single forward pass of a model. To replicate
+  such a model over GPUs, each GPU gets its own instance of the forward pass
+  (a.k.a. a tower). The input features and labels get sharded into chunks, one
+  per GPU. Each tower computes a loss based on its shard of the input, and
+  gradients are computed for each such loss. The per-tower losses are then
+  aggregated into a single loss, the per-tower gradients are summed, and the
+  summed gradients update the weights using the specified optimizer.
+
+  If `devices` is `None`, then all available GPUs are going to be used for
+  replication. If no GPUs are available, then the model is going to be
+  placed on the CPU.
+
+ Two modes of local replication over available GPUs are supported:
+ 1) If exactly 1 GPU is detected, then variables and operations are placed
+ onto the GPU.
+ 2) If more than 1 GPU is detected, then variables are going to be placed on
+ the CPU. Replicas of operations are placed on each individual GPU.
+
+ Here is an example of how one might use their `model_fn` to run over GPUs:
+ ```python
+ ...
+ def model_fn(...): # See `model_fn` in `Estimator`.
+ loss = ...
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
+ optimizer = tf.contrib.estimator._TowerOptimizer(optimizer)
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ # See the section below on `EstimatorSpec.train_op`.
+ return EstimatorSpec(mode=mode, loss=loss,
+ train_op=optimizer.minimize(loss))
+
+ # No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`.
+ return EstimatorSpec(...)
+ ...
+ classifier = tf.estimator.Estimator(
+ model_fn=tf.contrib.estimator.replicate_model_fn(model_fn))
+ ```
+
+ Please see `DNNClassifierIntegrationTest` for an example with a canned
+ Estimator.
+
+  On `EstimatorSpec.train_op`:
+  `model_fn` returns `EstimatorSpec.train_op` for
+  `tf.estimator.ModeKeys.TRAIN`. It is typically derived using an optimizer.
+  Towers are expected to populate it in the same way. Gradients from all towers
+  are reduced and applied in the last tower. To achieve that in the case of
+  multiple towers, `_TowerOptimizer` needs to be used. See `_TowerOptimizer`.
+
+ On sharding input features and labels:
+ Input features and labels are split for consumption by each tower. They are
+ split across the dimension 0. Features and labels need to be batch major.
+
+ On reduction algorithms:
+ Certain algorithms were chosen for aggregating results of computations on
+ multiple towers:
+ - Losses from all towers are reduced according to `loss_reduction`.
+ - Gradients from all towers are reduced according to `loss_reduction`
+ for each trainable variable.
+ - `eval_metrics_ops` are reduced per metric using `reduce_mean`.
+ - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are
+ reduced using concatenation.
+ - For all other fields of `EstimatorSpec` the values of the first tower
+ are taken.
+
+ On distribution of variables:
+ Variables are not duplicated between towers. Instead, they are placed on a
+ single device as defined above and shared across towers.
+
+ On overhead:
+ If only one device is specified, then aggregation of loss and gradients
+ doesn't happen. Replication consists of placing `model_fn` onto the
+ specified device.
+
+  On current limitations:
+  - `predictions` are not supported for `ModeKeys.EVAL`. They are, however,
+    required for `tf.contrib.estimator.add_metrics`.
+
+ Args:
+ model_fn: `model_fn` as defined in `Estimator`. See the section above about
+ the train_op argument of `EstimatorSpec`.
+    loss_reduction: controls whether losses are summed or averaged.
+    devices: Optional list of devices to replicate the model across. This
+      argument can be used to replicate only on a subset of the available
+      GPUs. If `None`, then all available GPUs are going to be used for
+      replication. If no GPUs are available, then the model is going to be
+      placed on the CPU.
+
+  Raises:
+    ValueError: if there is no `loss_reduction` or if `_TowerOptimizer` is
+      misused.
+
+  Returns:
+    A replicated version of the supplied `model_fn`. The returned function
+    conforms to the requirements of `Estimator`'s `model_fn` and can be used
+    instead of the supplied `model_fn`.
+ """
+ return _replicate_model_fn_with_mode(
+ model_fn,
+ loss_reduction,
+ devices,
+ # TODO(isaprykin): Query the system configuration to choose modes other
+ # than `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often
+ # appropriate.
+ mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER)
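+
+# Example (an illustrative sketch, not code from this change): wrapping a
+# model_fn for an explicit subset of GPUs with summed tower losses.
+# `my_model_fn` is a hypothetical stand-in for any `Estimator`-compatible
+# model_fn whose optimizer is wrapped in `_TowerOptimizer`:
+#
+#   replicated = _replicate_model_fn(
+#       my_model_fn,
+#       loss_reduction=losses.Reduction.SUM,
+#       devices=['/gpu:0', '/gpu:1'])
+#   classifier = tf.estimator.Estimator(model_fn=replicated)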
+
+
+class _VariableDistributionMode(object):
+  """Modes for distributing variables, used to force a particular mode.
+
+ Forcing a mode is meant for performance experimentation purposes rather than
+ for general use cases.
+ """
+
+ SHARED_LOCAL_PARAMETER_SERVER = 1
+ """Variables are placed on a single device and shared across all devices.
+
+ Two ways to achieve this distribution over available GPUs are supported:
+ 1) If exactly 1 GPU is detected, then variables and operations are placed
+     onto the GPU.
+ 2) If more than 1 GPU is detected, then variables are going to be placed on
+ the CPU. Replicas of operations are placed on each individual GPU.
+ """
+
+ SHARED_ROUND_ROBIN = 2
+ """Variables are placed on all devices in a round-robin fashion.
+
+ Every subsequent variable is placed on the next device. There is only one
+ copy of each variable that is shared across all devices.
+ """
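+
+  # An illustrative sketch: under SHARED_ROUND_ROBIN with
+  # devices = ['/gpu:0', '/gpu:1'], variable creation alternates devices: the
+  # 1st variable lands on /gpu:0, the 2nd on /gpu:1, the 3rd on /gpu:0, and
+  # so on. Each variable still has exactly one copy, shared by all towers.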
+
+
+def _replicate_model_fn_with_mode(
+ model_fn,
+ loss_reduction,
+ devices=None,
+ mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER):
+  """A version of `replicate_model_fn` that allows specifying a `mode`."""
+ if loss_reduction == losses.Reduction.NONE:
+ raise ValueError('Tower losses need to be reduced in some way, yet {} '
+ 'reduction is specified.'.format(loss_reduction))
+ if not devices:
+ devices = _get_local_devices('GPU') or _get_local_devices('CPU')
+
+ is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0].upper()
+ consolidation_device = devices[0] if is_a_single_gpu_case else '/CPU:0'
+
+ ps_devices = [consolidation_device]
+ if mode == _VariableDistributionMode.SHARED_ROUND_ROBIN:
+ ps_devices = devices
+
+ tf_logging.info('Replicating the `model_fn` across {}. Variables are going '
+ 'to be placed on {}. Consolidation device is going to be {}.'
+ .format(devices, ps_devices, consolidation_device))
+
+ def single_device_model_fn(features, labels, mode, params=None, config=None):
+ """`model_fn` on a single device without reduction overhead."""
+ return _get_loss_towers(
+ model_fn=model_fn,
+ mode=mode,
+ features=[features],
+ labels=[labels],
+ params=params,
+ loss_reduction=loss_reduction,
+ config=config,
+ devices=devices,
+ local_ps_devices=ps_devices)[0] # One device, so one spec is out.
+
+ def replicated_model_fn(features, labels, mode, params=None, config=None):
+ """Replicated version of `model_fn` to be used instead."""
+ feature_shards, label_shards = _split_batch(
+ features, labels, len(devices), device=consolidation_device)
+ tower_specs = _get_loss_towers(
+ model_fn=model_fn,
+ mode=mode,
+ features=feature_shards,
+ labels=label_shards,
+ params=params,
+ loss_reduction=loss_reduction,
+ config=config,
+ devices=devices,
+ local_ps_devices=ps_devices)
+
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ train_op = _minimize_towers(tower_specs)
+ return _train_spec(
+ tower_specs, train_op, aggregation_device=consolidation_device)
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ return _eval_spec(tower_specs, aggregation_device=consolidation_device)
+ elif mode == model_fn_lib.ModeKeys.PREDICT:
+ return _predict_spec(tower_specs, aggregation_device=consolidation_device)
+
+ if len(devices) == 1:
+ return single_device_model_fn
+ else:
+ return replicated_model_fn
+
+
+class _TowerOptimizer(optimizer_lib.Optimizer):
+ """Gathers gradients from all towers and reduces them in the last one."""
+
+ COLLECTION_FOR_GRAPH_STATES = 'replicate_model_fn_graph_states'
+
+ def __init__(self, optimizer_or_optimizer_fn):
+ """Wrap an existing optimizer for gathering gradients across towers.
+
+ Each invocation of model_fn has to call the same optimizers in the same
+ order.
+
+ Multiple optimizers that use the same or different losses are supported.
+
+ If _TowerOptimizer is used but `replicate_model_fn` isn't, then no
+ aggregation will happen. All calls will simply be forwarded to the
+ underlying optimizer. The behavior is similar if there is only one tower.
+
+ If _TowerOptimizer is used together with SyncReplicasOptimizer that wraps
+ the user's optimizer, then it's the SyncReplicasOptimizer that needs to be
+ wrapped with _TowerOptimizer.
+
+ Args:
+ optimizer_or_optimizer_fn: an instance of optimizer to wrap. That
+ instance is going to be used for optimizer-specific logic. This can
+ also be a no-argument function that returns such an optimizer instance.
+ """
+ self._optimizer_or_optimizer_fn = optimizer_or_optimizer_fn
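+
+  # An illustrative sketch of the wrapping order described above, using a
+  # hypothetical learning rate. When SyncReplicasOptimizer is involved, it is
+  # the outermost optimizer that gets handed to _TowerOptimizer:
+  #
+  #   opt = tf.train.GradientDescentOptimizer(0.1)
+  #   opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=1)
+  #   opt = _TowerOptimizer(opt)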
+
+ @staticmethod
+ def has_been_used():
+ return _TowerOptimizer._graph_state().has_tower_optimizer_been_used
+
+ def get_slot(self, *args, **kwargs):
+ return self._get_optimizer().get_slot(*args, **kwargs)
+
+ def get_slot_names(self, *args, **kwargs):
+ return self._get_optimizer().get_slot_names(*args, **kwargs)
+
+ def get_name(self, *args, **kwargs):
+ return self._get_optimizer().get_name(*args, **kwargs)
+
+ def variables(self, *args, **kwargs):
+ return self._get_optimizer().variables(*args, **kwargs)
+
+ def compute_gradients(self, loss, *args, **kwargs):
+ """Compute gradients, but first, if needed, scale the loss."""
+ loss = _scale_loss(loss,
+ self._graph_state().loss_reduction,
+ self._graph_state().number_of_towers)
+ return self._get_optimizer().compute_gradients(loss, *args, **kwargs)
+
+ def apply_gradients(self, grads_and_vars, global_step=None, **kwargs):
+    """Collect gradient updates, to be applied by the last tower."""
+ if self._graph_state().number_of_towers == 1:
+ # Avoid the overhead of reduction if there's only one tower.
+ #
+      # There is assumed to be only one tower if aggregation-related methods
+      # were not called by `_get_loss_towers`, for example if the model_fn
+      # uses `_TowerOptimizer` but `replicate_model_fn` isn't used.
+ return self._get_optimizer().apply_gradients(grads_and_vars, global_step,
+ **kwargs)
+
+ self._graph_state().collect_gradients(grads_and_vars)
+
+ if not self._graph_state().is_the_last_tower:
+ with ops_lib.control_dependencies(_extract_tensors(grads_and_vars)):
+ return self._construct_no_op_train_op()
+ else:
+ # Gradients need to be gathered and applied in the scope of the first
+ # tower, so that the tensors are accessible via names without prefixes.
+ var_scope, name_scope = self._graph_state().scopes_of_the_first_tower
+ with variable_scope.variable_scope(var_scope):
+ with ops_lib.name_scope(name_scope):
+ return self._apply_gathered_gradients(global_step, **kwargs)
+
+ def _apply_gathered_gradients(self, global_step, **kwargs):
+ graph_state = self._graph_state()
+ optimizer = self._get_optimizer()
+
+ grad_lists = {}
+ for grad, var in graph_state.get_latest_gradients_from_all_towers():
+ if grad is not None:
+ grad_lists.setdefault(var, []).append(grad)
+
+ aggregated_grads = []
+ with ops_lib.name_scope('gradient_aggregating'):
+ for var, grads in six.iteritems(grad_lists):
+ grad = _compute_sum_on_device(grads, var.device)
+ aggregated_grads.append((grad, var))
+ return optimizer.apply_gradients(
+ aggregated_grads, global_step=global_step, **kwargs)
+
+ def _get_optimizer(self):
+ if callable(self._optimizer_or_optimizer_fn):
+ # If optimizer is given as a function then we need to wait till we are
+ # under the right graph context before constructing it. That's why the
+ # optimizer is constructed in _get_optimizer() rather than __init__().
+ self._optimizer_or_optimizer_fn = self._optimizer_or_optimizer_fn()
+ self._graph_state().has_tower_optimizer_been_used = True
+ return self._optimizer_or_optimizer_fn
+
+ def _construct_no_op_train_op(self):
+ return control_flow_ops.no_op(name='train_op_placeholder')
+
+ @staticmethod
+ def _graph_state():
+ graph_states = ops_lib.get_default_graph().get_collection_ref(
+ _TowerOptimizer.COLLECTION_FOR_GRAPH_STATES)
+ if not graph_states:
+ graph_states.append(_TowerOptimizer._PerGraphState())
+ return graph_states[-1]
+
+ @staticmethod
+ def _did_towers_have_same_optimizer_calls():
+ graph_state = _TowerOptimizer._graph_state()
+ return graph_state.did_towers_have_same_optimizer_calls()
+
+ @staticmethod
+ def _clear_graph_state():
+ # Clearing the Graph collection will prevent _PerGraphState from being
+ # serialized.
+ ops_lib.get_default_graph().clear_collection(
+ _TowerOptimizer.COLLECTION_FOR_GRAPH_STATES)
+
+ class _PerGraphState(object):
+    """Gradient-reduction-related state of a TensorFlow graph."""
+
+ def __init__(self):
+ self._collected_grads_and_vars = defaultdict(list)
+ self._current_tower_index = 0
+ self._number_of_towers = 1
+ self._loss_reduction = None
+ # Scopes of the first tower that don't have a prefix:
+ self._variable_scope = None
+ self._name_scope = None
+ # If needed, alert that _TowerOptimizer needs to be used with model_fn.
+ self._has_tower_optimizer_been_used = False
+
+ def collect_gradients(self, grads_and_vars):
+ self._collected_grads_and_vars[self._current_tower_index].append(
+ grads_and_vars)
+
+ def get_latest_gradients_from_all_towers(self):
+ """Get gradients across towers for the last called optimizer."""
+ grads_and_vars = []
+ index_of_last_gradients = len(
+ self._collected_grads_and_vars[self._current_tower_index]) - 1
+ for tower_id in range(self._current_tower_index + 1):
+ grads_and_vars.extend(
+ self._collected_grads_and_vars[tower_id][index_of_last_gradients])
+ return grads_and_vars
+
+ def set_reduction_across_towers(self, loss_reduction, number_of_towers):
+ self._loss_reduction = loss_reduction
+ self._number_of_towers = number_of_towers
+
+ @contextmanager
+ def tower(self, tower_id, var_scope, name_scope):
+ if tower_id == 0:
+ self._variable_scope = var_scope
+ self._name_scope = name_scope
+ self._current_tower_index = tower_id
+ yield
+
+ @property
+ def scopes_of_the_first_tower(self):
+ return self._variable_scope, self._name_scope
+
+ @property
+ def is_the_last_tower(self):
+ return self._current_tower_index == (self._number_of_towers - 1)
+
+ @property
+ def number_of_towers(self):
+ return self._number_of_towers
+
+ @property
+ def loss_reduction(self):
+ return self._loss_reduction
+
+ @property
+ def has_tower_optimizer_been_used(self):
+ return self._has_tower_optimizer_been_used
+
+ @has_tower_optimizer_been_used.setter
+ def has_tower_optimizer_been_used(self, value):
+ self._has_tower_optimizer_been_used = value
+
+ def did_towers_have_same_optimizer_calls(self):
+ total_number_of_grads = sum([
+ len(grads)
+ for _, grads in six.iteritems(self._collected_grads_and_vars)
+ ])
+ return total_number_of_grads % self._number_of_towers == 0
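+
+    # An illustrative sketch: with 2 towers that each call `apply_gradients`
+    # on the same 2 optimizers, 4 grads_and_vars lists are collected in total
+    # and 4 % 2 == 0, so the towers are deemed consistent. A tower that skips
+    # an optimizer call breaks this invariant and triggers the ValueError
+    # raised in `_get_loss_towers`.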
+
+
+def _get_local_devices(device_type):
+ local_device_protos = device_lib.list_local_devices()
+ return [
+ device.name
+ for device in local_device_protos
+ if device.device_type == device_type
+ ]
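+
+# An illustrative sketch: on a machine with two GPUs,
+#
+#   _get_local_devices('GPU')  # -> ['/device:GPU:0', '/device:GPU:1']
+#   _get_local_devices('CPU')  # -> ['/device:CPU:0']
+#
+# The exact name strings come from `device_lib.list_local_devices()`.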
+
+
+def _split_batch(features, labels, number_of_shards, device):
+  """Split input features and labels into shards."""
+
+ def ensure_divisible_by_shards(sequence):
+ batch_size = ops_lib.convert_to_tensor(sequence).get_shape()[0]
+ if batch_size % number_of_shards != 0:
+ raise ValueError(
+ 'Batch size {} needs to be divisible by the number of GPUs, which '
+ 'is {}.'.format(batch_size, number_of_shards))
+
+ def split_dictionary(dictionary):
+ """Split a dictionary into shards."""
+ shards = [{} for _ in range(number_of_shards)]
+ for name, tensor in six.iteritems(dictionary):
+ if isinstance(tensor, sparse_tensor.SparseTensor):
+ for i, shard in enumerate(
+ sparse_ops.sparse_split(
+ sp_input=tensor, num_split=number_of_shards, axis=0)):
+ shards[i][name] = shard
+ else:
+ ensure_divisible_by_shards(tensor)
+ for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
+ shards[i][name] = shard
+ return shards
+
+ with ops_lib.name_scope('split_inputs'):
+ with ops_lib.device(device):
+ if isinstance(features, dict):
+ feature_shards = split_dictionary(features)
+ else:
+ ensure_divisible_by_shards(features)
+ feature_shards = array_ops.split(features, number_of_shards)
+
+ if labels is None:
+ label_shards = None
+ elif isinstance(labels, dict):
+ label_shards = split_dictionary(labels)
+ else:
+ ensure_divisible_by_shards(labels)
+ label_shards = array_ops.split(labels, number_of_shards)
+ return feature_shards, label_shards
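+
+# An illustrative sketch: a batch-major tensor of shape [4, 1] split into 2
+# shards yields two [2, 1] tensors, while a batch of 3 rows would raise the
+# ValueError above:
+#
+#   features = ops_lib.convert_to_tensor([[1.], [2.], [3.], [4.]])
+#   feature_shards, _ = _split_batch(
+#       features, labels=None, number_of_shards=2, device='/CPU:0')
+#   # feature_shards[0] is [[1.], [2.]]; feature_shards[1] is [[3.], [4.]]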
+
+
+_DEFAULT_NAME_SCOPE_PATTERN = 'tower_{}'
+
+
+def _get_loss_towers(model_fn,
+ mode,
+ features,
+ labels,
+ params,
+ config,
+ devices,
+ local_ps_devices,
+ loss_reduction,
+ name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
+ """Replicate the loss computation across devices."""
+ tower_specs = []
+
+ model_fn_args = util.fn_args(model_fn)
+ optional_params = {}
+ if 'params' in model_fn_args:
+ optional_params['params'] = copy.deepcopy(params)
+ if 'config' in model_fn_args:
+ optional_params['config'] = copy.deepcopy(config)
+
+ # pylint: disable=protected-access
+ round_robin_strategy = device_setter_lib._RoundRobinStrategy(
+ num_tasks=len(local_ps_devices))
+ _TowerOptimizer._graph_state().set_reduction_across_towers(
+ loss_reduction, len(devices))
+
+ for i, device in enumerate(devices):
+ is_the_first_tower = (i == 0)
+
+ device_setter = _local_device_setter(
+ worker_device=device,
+ ps_devices=local_ps_devices,
+ ps_strategy=round_robin_strategy)
+
+ # We would like to preserve the names of the variables and ops that the user
+ # might be relying on. Names without a prefix are going to resolve to
+ # variables and ops of the first tower.
+ name_scope = name_scope_pattern
+ if is_the_first_tower:
+ name_scope = ''
+
+ with variable_scope.variable_scope(
+ '', reuse=not is_the_first_tower) as var_scope:
+ with ops_lib.name_scope(name_scope.format(i)) as name_scope:
+ with _TowerOptimizer._graph_state().tower(
+ tower_id=i, var_scope=var_scope, name_scope=name_scope):
+ with ops_lib.device(device_setter):
+ labels_shard = None
+ if labels:
+ labels_shard = labels[i]
+
+ tower_spec = model_fn(
+ mode=mode,
+ features=features[i],
+ labels=labels_shard,
+ **optional_params)
+
+ if (tower_spec.train_op is not None and len(devices) > 1 and
+ not _TowerOptimizer.has_been_used()):
+ raise ValueError('Please wrap optimizers with _TowerOptimizer'
+ ' in order to use replicate_model_fn with'
+ ' multiple `devices`.')
+
+ # Scaling the loss here doesn't actually affect gradients. Another
+ # instance of scaling happens inside the _TowerOptimizer.
+ tower_spec = _scale_tower_loss(
+ tower_spec, loss_reduction, number_of_towers=len(devices))
+ tower_specs.append(tower_spec)
+
+ if not _TowerOptimizer._did_towers_have_same_optimizer_calls():
+ raise ValueError('Each invocation of model_fn was supposed to make the same'
+ ' optimizer calls.')
+ _TowerOptimizer._clear_graph_state()
+ # pylint: enable=protected-access
+ return tower_specs
+
+
+def _local_device_setter(worker_device, ps_devices, ps_strategy):
+  """A device setter that distributes Vars/Ops across PS and workers."""
+ ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']
+
+ def local_device_chooser(op):
+ current_device = framework_device.DeviceSpec.from_string(op.device or '')
+
+ node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
+ if node_def.op in ps_ops:
+ ps_device_spec = framework_device.DeviceSpec.from_string(
+ '{}'.format(ps_devices[ps_strategy(op)]))
+
+ ps_device_spec.merge_from(current_device)
+ return ps_device_spec.to_string()
+ else:
+ worker_device_spec = framework_device.DeviceSpec.from_string(
+ worker_device or '')
+ worker_device_spec.merge_from(current_device)
+ return worker_device_spec.to_string()
+
+ return local_device_chooser
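+
+# An illustrative sketch: a setter built with worker_device='/gpu:1' and
+# ps_devices=['/CPU:0'] places 'Variable', 'VariableV2' and 'VarHandleOp'
+# nodes on '/CPU:0', while every other op is pinned to '/gpu:1'.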
+
+
+def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers):
+  """Produce an EstimatorSpec with an appropriately scaled loss."""
+ if tower_spec.loss is None:
+ return tower_spec
+
+ estimator_spec = _asdict(tower_spec)
+ estimator_spec['loss'] = _scale_loss(tower_spec.loss, loss_reduction,
+ number_of_towers)
+ return model_fn_lib.EstimatorSpec(**estimator_spec)
+
+
+def _scale_loss(loss, loss_reduction, number_of_towers):
+  """If needed, scale down per-tower losses so that their sum is an average."""
+ if loss is None:
+ return None
+ if number_of_towers == 1:
+ return loss
+
+ if loss_reduction != losses.Reduction.SUM:
+ return math_ops.div(loss, 1.0 * number_of_towers, name='averaged_loss')
+ else:
+ return loss
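+
+# An illustrative sketch: with `losses.Reduction.MEAN` and 4 towers, a tower
+# loss of 8.0 is scaled to 8.0 / 4 = 2.0, so summing the scaled tower losses
+# reproduces the cross-tower mean. With `losses.Reduction.SUM` (or a single
+# tower) the loss is returned unchanged.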
+
+
+def _minimize_towers(tower_specs):
+ """`train_op` of the last tower applies aggregated gradients."""
+ return tower_specs[-1].train_op
+
+
+def _compute_sum_on_device(values, device, name=None):
+ with ops_lib.device(device):
+ if isinstance(values[0], ops_lib.IndexedSlices):
+ if name:
+ raise ValueError('The name {} is not expected to be given to '
+ 'IndexedSlices {}'.format(name, values))
+
+ values_concat = array_ops.concat([v.values for v in values], axis=0)
+ indices_concat = array_ops.concat([v.indices for v in values], axis=0)
+ return ops_lib.IndexedSlices(values_concat, indices_concat,
+ values[0].dense_shape)
+ else:
+ return math_ops.add_n(values, name=name)
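+
+# An illustrative sketch: dense values are reduced with a single `add_n`
+# placed on the requested device, while `IndexedSlices` (e.g. gradients of
+# embedding lookups) are merged by concatenating values and indices rather
+# than densifying them:
+#
+#   g0 = ops_lib.convert_to_tensor([1., 2.])
+#   g1 = ops_lib.convert_to_tensor([3., 4.])
+#   total = _compute_sum_on_device([g0, g1], '/CPU:0')  # -> [4., 6.]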
+
+
+def _train_spec(tower_specs,
+ train_op,
+ aggregation_device,
+ aggregated_loss_name='loss'):
+  """Populate replicated EstimatorSpec for `ModeKeys.TRAIN`."""
+ # Spec of the last tower is used as the template for the final spec, because
+ # some `EstimatorSpec.training_hooks` rely on calls made in model_fn. For
+ # example, `SyncReplicasOptimizerHook` validates the
+  # `SyncReplicasOptimizer.apply_gradients` call. `_TowerOptimizer` makes that
+ # call only in the last tower.
+ estimator_spec = _asdict(tower_specs[-1])
+ estimator_spec['mode'] = model_fn_lib.ModeKeys.TRAIN
+ estimator_spec['train_op'] = train_op
+ estimator_spec['loss'] = _compute_sum_on_device(
+ [spec.loss for spec in tower_specs], aggregation_device,
+ aggregated_loss_name)
+ return model_fn_lib.EstimatorSpec(**estimator_spec)
+
+
+def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'):
+  """Populate replicated EstimatorSpec for `ModeKeys.EVAL`."""
+ estimator_spec = _asdict(tower_specs[0])
+ estimator_spec['mode'] = model_fn_lib.ModeKeys.EVAL
+ estimator_spec['loss'] = _compute_sum_on_device(
+ [spec.loss for spec in tower_specs], aggregation_device,
+ aggregated_loss_name)
+
+ update_ops = []
+ for tower_spec in tower_specs:
+ for name, (_, update_op) in six.iteritems(tower_spec.eval_metric_ops):
+ update_ops.append(update_op)
+
+ with ops_lib.control_dependencies(update_ops):
+ reduced_update_op = _reduce_metric_variables(len(tower_specs))
+
+ eval_metric_ops = {}
+ for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops):
+ eval_metric_ops[name] = (metric_tensor, reduced_update_op)
+ estimator_spec['eval_metric_ops'] = eval_metric_ops
+ return model_fn_lib.EstimatorSpec(**estimator_spec)
+
+
+def _reduce_metric_variables(number_of_towers):
+ """Aggregate local variables used in metrics into the first tower."""
+ if number_of_towers == 1:
+ return control_flow_ops.no_op(name='no_eval_metric_reduction')
+
+ metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)
+ variables_per_tower = len(metric_variables) // number_of_towers
+
+ if len(metric_variables) % number_of_towers != 0:
+ raise ValueError(
+ 'Different `EstimatorSpec.eval_metric_ops` across `model_fn()` calls.'
+ ' Expected {} local variables, but got {} instead.'.format(
+ variables_per_tower * number_of_towers, len(metric_variables)))
+
+  # `metric_variables` has the size of `variables_per_tower` x
+  # `number_of_towers`. Each tower is produced by calling the same model_fn.
+  # The first `variables_per_tower` variables correspond to the first tower.
+  # Variable `i` of the first tower has a replica at every position
+  # `i + variables_per_tower * k`, where `k` is in `[1, number_of_towers)`.
+  # We are going to add the values from the replicas to each variable of the
+  # first tower. We then zero out the replica values, so that the
+  # `_reduce_metric_variables` operation is idempotent. If a metric is then
+  # computed based on the local variables of the first tower, the resulting
+  # metric is an estimate for all `number_of_towers` towers.
+ ops = []
+ for i in range(0, variables_per_tower):
+ next_replica_id = i + variables_per_tower
+ replicas = [
+ metric_variables[replica_id]
+ for replica_id in range(next_replica_id, len(metric_variables),
+ variables_per_tower)
+ ] # `replicas` doesn't contain the first-tower variable.
+
+ reduce_op = state_ops.assign_add(metric_variables[i],
+ math_ops.add_n(replicas))
+
+ with ops_lib.control_dependencies([reduce_op]):
+ for replica in replicas:
+ zeros_for_replica = array_ops.zeros(
+ array_ops.shape(replica), dtype=replica.dtype)
+ zero_out_replica_op = state_ops.assign(replica, zeros_for_replica)
+ ops.append(zero_out_replica_op)
+
+ return control_flow_ops.group(*ops)
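+
+# An illustrative sketch: with 2 towers and 2 metric variables per tower, the
+# METRIC_VARIABLES collection is ordered [a0, b0, a1, b1]. The returned op
+# adds a1 into a0 and b1 into b0, then zeroes out a1 and b1, so running it a
+# second time leaves the totals unchanged (idempotence).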
+
+
+def _predict_spec(tower_specs, aggregation_device):
+  """Populate replicated EstimatorSpec for `ModeKeys.PREDICT`."""
+ estimator_spec = _asdict(tower_specs[0])
+ estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT
+
+ with ops_lib.device(aggregation_device):
+ estimator_spec['predictions'] = _concat_tensor_dicts(
+ *[tower_spec.predictions for tower_spec in tower_specs])
+
+ export_outputs_dict = _dict_concat(
+ *[tower_spec.export_outputs for tower_spec in tower_specs])
+
+ export_outputs = {}
+ for name, export_output_list in six.iteritems(export_outputs_dict):
+ if isinstance(export_output_list[0], export_output_lib.PredictOutput):
+ export_outputs[name] = export_output_lib.PredictOutput(
+ outputs=_concat_tensor_dicts(*[
+ export_output.outputs for export_output in export_output_list
+ ]))
+ elif isinstance(export_output_list[0],
+ export_output_lib.RegressionOutput):
+ export_outputs[name] = export_output_lib.RegressionOutput(
+ value=array_ops.concat(
+ [export_output.value for export_output in export_output_list],
+ axis=0))
+ elif isinstance(export_output_list[0],
+ export_output_lib.ClassificationOutput):
+ scores = None
+ if export_output_list[0].scores is not None:
+ scores = array_ops.concat(
+ [export_output.scores for export_output in export_output_list],
+ axis=0)
+
+ classes = None
+ if export_output_list[0].classes is not None:
+ classes = array_ops.stack(
+ [export_output.classes for export_output in export_output_list],
+ axis=0)
+
+ export_outputs[name] = export_output_lib.ClassificationOutput(
+ scores=scores, classes=classes)
+
+ estimator_spec['export_outputs'] = export_outputs
+ return model_fn_lib.EstimatorSpec(**estimator_spec)
+
+
+def _concat_tensor_dicts(*tensor_dicts):
+ return {
+ name: array_ops.concat(tensors, axis=0, name=name)
+ for name, tensors in six.iteritems(_dict_concat(*tensor_dicts))
+ }
+
+
+def _extract_tensors(tensors_and_vars):
+ tensors = []
+ for tensor_and_var in tensors_and_vars:
+ tensor, _ = tensor_and_var
+ if isinstance(tensor, ops_lib.IndexedSlices):
+ tensors.append(tensor.values)
+ elif tensor is not None:
+ tensors.append(tensor)
+ return tensors
+
+
+def _dict_concat(*dicts):
+ list_dict = {}
+ for d in dicts:
+ if d is None:
+ continue
+
+ for k, v in six.iteritems(d):
+ list_dict.setdefault(k, []).append(v)
+ return list_dict
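+
+# An illustrative sketch: `None` inputs are skipped and values are grouped
+# per key, preserving duplicates:
+#
+#   _dict_concat({'a': 1}, None, {'a': 2, 'b': 3})
+#   # -> {'a': [1, 2], 'b': [3]}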
+
+
+def _asdict(namedtuple):
+ """Returns a namedtuple as a dictionary.
+
+ This is required because `_asdict()` in Python 3.x.x is broken in classes
+ that inherit from `collections.namedtuple`. See
+ https://bugs.python.org/issue24931 for more details.
+
+ Args:
+ namedtuple: An object that inherits from `collections.namedtuple`.
+
+ Returns:
+ A dictionary version of the tuple.
+ """
+ return {k: getattr(namedtuple, k) for k in namedtuple._fields}
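+
+# An illustrative sketch (assuming `import collections`): this works for any
+# namedtuple subclass, including `EstimatorSpec`:
+#
+#   Point = collections.namedtuple('Point', ['x', 'y'])
+#   _asdict(Point(x=1, y=2))  # -> {'x': 1, 'y': 2}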
diff --git a/tensorflow/python/estimator/replicate_model_fn_test.py b/tensorflow/python/estimator/replicate_model_fn_test.py
new file mode 100644
index 0000000000..b6dd4e981f
--- /dev/null
+++ b/tensorflow/python/estimator/replicate_model_fn_test.py
@@ -0,0 +1,1709 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for utilities that replicate `Estimator.model_fn` over GPUs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+import shutil
+import tempfile
+import numpy as np
+import six
+
+from tensorflow.python.estimator import estimator as estimator_lib
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator import replicate_model_fn
+from tensorflow.python.estimator.canned import dnn
+from tensorflow.python.estimator.canned import optimizers
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.export import export_output
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import adam
+from tensorflow.python.training import device_setter
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import training
+
+
+# TODO(isaprykin): Parametrize all the tests on
+# replicate_model_fn._VariableDistributionMode when it's supported.
+class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def test_complete_flow_with_public_version(self):
+ return self._complete_flow_with_mode(mode=None)
+
+ def test_complete_flow_with_mode_local_ps_server(self):
+ return self._complete_flow_with_mode(
+ replicate_model_fn._VariableDistributionMode.
+ SHARED_LOCAL_PARAMETER_SERVER)
+
+ def test_complete_flow_with_mode_round_robin(self):
+ return self._complete_flow_with_mode(
+ replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN)
+
+ def _complete_flow_with_mode(self, mode):
+ n_classes = 3
+ input_dimension = 2
+ batch_size = 12
+
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+ x_data = data.reshape(batch_size, input_dimension)
+ categorical_data = np.random.random_integers(
+ 0, len(x_data), size=len(x_data))
+ y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data,
+ 'categories': categorical_data},
+ y=y_data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data,
+ 'categories': categorical_data},
+ y=y_data,
+ batch_size=batch_size,
+ shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data,
+ 'categories': categorical_data},
+ batch_size=batch_size,
+ shuffle=False)
+
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,)),
+ feature_column.embedding_column(
+ feature_column.categorical_column_with_vocabulary_list(
+ 'categories',
+ vocabulary_list=np.linspace(
+ 0., len(x_data), len(x_data), dtype=np.int64)), 1)
+ ]
+
+ def optimizer_fn():
+ return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05)
+
+ estimator = dnn.DNNClassifier(
+ hidden_units=(2, 2),
+ # Adagrad is configured with `get_optimizer_instance`, so the function
+ # form of `_TowerOptimizer.__init__` is used.
+ optimizer=replicate_model_fn._TowerOptimizer(optimizer_fn),
+ feature_columns=feature_columns,
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+
+ if not mode: # Use the public `replicate_model_fn`.
+ model_fn = replicate_model_fn._replicate_model_fn(
+ estimator.model_fn, devices=['/gpu:0', '/gpu:1', '/gpu:2'])
+ else:
+ model_fn = replicate_model_fn._replicate_model_fn_with_mode(
+ estimator.model_fn,
+ devices=['/gpu:0', '/gpu:1', '/gpu:2'],
+ loss_reduction=losses.Reduction.SUM,
+ mode=mode)
+
+ estimator = estimator_lib.Estimator(
+ model_fn=model_fn,
+ model_dir=estimator.model_dir,
+ config=estimator.config,
+ params=estimator.params)
+
+ num_steps = 10
+ estimator.train(train_input_fn, steps=num_steps)
+
+ scores = estimator.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops_lib.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+
+ predicted_proba = np.array([
+ x[prediction_keys.PredictionKeys.PROBABILITIES]
+ for x in estimator.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
+
+ feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = estimator.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ # Nothing should be left in the graph so that it doesn't get serialized.
+ self.assertFalse(ops_lib.get_default_graph().get_collection_ref(
+ replicate_model_fn._TowerOptimizer.COLLECTION_FOR_GRAPH_STATES))
+
+ def _as_label(self, data_in_float):
+ return np.rint(data_in_float).astype(np.int64)
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+
+class ReplicateModelTest(test_util.TensorFlowTestCase):
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(10, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = math_ops.multiply(features, c)
+
+ loss = losses.absolute_difference(
+ labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+ loss = math_ops.reduce_sum(loss)
+
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions),
+ 'auc': metrics_lib.auc(labels, predictions)
+ }
+
+ optimizer = replicate_model_fn._TowerOptimizer(
+ gradient_descent.GradientDescentOptimizer(params['learning_rate']))
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ eval_metric_ops=metrics,
+ predictions={'probabilities': predictions},
+ train_op=optimizer.minimize(loss))
+
+ @property
+ def params(self):
+ params = {}
+ params['learning_rate'] = 1.0
+ return params
+
+ def test_train(self):
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn,
+ loss_reduction=losses.Reduction.SUM,
+ devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+ session.run(variables.global_variables_initializer())
+
+ # loss = feature * c - label
+ total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
+ self.assertEqual(total_loss, session.run(estimator_spec.loss))
+
+      # d(loss)/dc = d((1*c - 1) + (2*c - 2))/dc = 1 + 2 = 3.
+ # new value of c = 10 - learning rate * 3 = 7.0.
+ session.run(estimator_spec.train_op)
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual(7.0, session.run(c))
+
+ def test_train_with_mean_reduction(self):
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ with self.test_session() as session:
+ # Add another trainable variable that doesn't produce a gradient to
+ # verify that None gradients are supported.
+ _ = variable_scope.get_variable(
+ 'another_variable',
+ initializer=constant_op.constant(1, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+ session.run(variables.global_variables_initializer())
+
+ # loss = feature * c - label
+ total_loss = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)) / 2.0
+ self.assertEqual(total_loss, session.run(estimator_spec.loss))
+
+      # d(loss)/dc = d((1*c - 1)/2 + (2*c - 2)/2)/dc = 0.5 + 1.0 = 1.5.
+ # It's the same computation as without mean reduction, but the
+ # loss from every tower is scaled by 1/<number of towers>.
+ # new value of c = 10 - learning rate * 1.5 = 8.5
+ session.run(estimator_spec.train_op)
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual(8.5, session.run(c))
+
+ def test_train_two_steps_collected_gradients_are_reset_between_steps(self):
+ with ops_lib.Graph().as_default():
+ features = array_ops.placeholder(dtypes.float64)
+ labels = array_ops.placeholder(dtypes.float64)
+
+ feature_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]])
+ label_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]])
+
+ # loss = feature * c - label
+ expected_losses = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0),
+ (1.5 * 7.0 - 1.5) + (2.5 * 7.0 - 2.5))
+      # The derivative of the loss w.r.t. c is 1.0 + 2.0 = 3.0 for the first
+      # step and 1.5 + 2.5 = 4.0 for the second.
+ expected_c = 10.0 - 3.0, 7.0 - 4.0
+
+ with self.test_session() as session, variable_scope.variable_scope(
+ '', reuse=variable_scope.AUTO_REUSE):
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn,
+ loss_reduction=losses.Reduction.SUM,
+ devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+ session.run(variables.global_variables_initializer())
+
+ for feature_input, label_input, loss, weight in zip(
+ feature_inputs, label_inputs, expected_losses, expected_c):
+ feeds = {features: feature_input, labels: label_input}
+
+ self.assertEqual(loss, session.run(estimator_spec.loss, feeds))
+
+ session.run(estimator_spec.train_op, feeds)
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual(weight, session.run(c, feeds))
+
+ def test_eval(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn,
+ loss_reduction=losses.Reduction.SUM,
+ devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
+ session.run(variables.local_variables_initializer())
+ session.run(variables.global_variables_initializer())
+
+ accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+ auc, b = estimator_spec.eval_metric_ops['auc']
+
+ session.run([a, b])
+ accuracy = session.run(accuracy)
+ auc = session.run(auc)
+
+ # loss[i] = features[i] * 10 - labels[i].
+ # Accuracy is 0.0 (no match) in the first tower.
+ # Accuracy is 1.0 (match) in the second tower, since the feature
+ # times weight "c" happened to be equal to the label.
+ total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
+
+ self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
+ self.assertEqual(0, auc)
+ self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
+
+ def test_eval_with_mean_reduction(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
+ session.run(variables.local_variables_initializer())
+ session.run(variables.global_variables_initializer())
+
+ accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+ auc, b = estimator_spec.eval_metric_ops['auc']
+
+ session.run([a, b])
+ accuracy = session.run(accuracy)
+ auc = session.run(auc)
+
+ # loss[i] = features[i] * 10 - labels[i].
+ # Accuracy is 0.0 (no match) in the first tower.
+ # Accuracy is 1.0 (match) in the second tower, since the feature
+ # times weight "c" happened to be equal to the label.
+ total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) / 2.0
+
+ self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
+ self.assertEqual(0, auc)
+ self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
+
+ def test_predict(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
+ session.run(variables.global_variables_initializer())
+
+ self.assertAllClose({
+ 'probabilities': np.array([[0.1], [0.02]])
+ }, session.run(estimator_spec.predictions))
+
+ def test_train_single_tower(self):
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, devices=['/gpu:0'])
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+ session.run(variables.global_variables_initializer())
+
+ # loss = feature * c - label
+ total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
+ self.assertEqual(total_loss, session.run(estimator_spec.loss))
+
+      # d(loss)/dc is 3.
+ # new value of c = 10 - learning rate * 3 = 7.0.
+ session.run(estimator_spec.train_op)
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual(7.0, session.run(c))
+
+ def test_eval_single_tower(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, devices=['/gpu:0'])
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
+ session.run(variables.local_variables_initializer())
+ session.run(variables.global_variables_initializer())
+
+ accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+ auc, b = estimator_spec.eval_metric_ops['auc']
+
+ session.run([a, b])
+ accuracy = session.run(accuracy)
+ auc = session.run(auc)
+
+ # Accuracy is 0.0 (no match) in the first tower.
+ # Accuracy is 1.0 (match) in the second tower, since the feature
+ # times weight "c" happened to be equal to the label.
+ total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
+
+ self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
+ self.assertEqual(0, auc)
+ self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
+
+ def test_predict_single_tower(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, devices=['/gpu:0'])
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
+ session.run(variables.global_variables_initializer())
+
+ self.assertAllClose({
+ 'probabilities': np.array([[0.1], [0.02]])
+ }, session.run(estimator_spec.predictions))
+
+ def test_batch_size_that_is_not_divisible_by_the_number_of_gpus(self):
+ features = np.array([[1.0], [2.0], [3.0]])
+ labels = np.array([[1.0], [2.0], [3.0]])
+
+ with self.assertRaisesRegexp(
+ ValueError, '.*Batch.+size.+needs.+to.+be.+divisible.+by.+GPUs.+'):
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, devices=['/gpu:0', '/gpu:1'])
+ _ = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+
+ def test_unsupported_loss_reduction(self):
+ with self.assertRaisesRegexp(ValueError,
+ '.+none.+reduction.+is.+specified.+'):
+ _ = replicate_model_fn._replicate_model_fn(self.model_fn,
+ losses.Reduction.NONE)
+
+ def test_places_on_gpu_with_upper_case_spelling(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session():
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, devices=['/GPU:0'])
+ _ = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual('/device:GPU:0', c.device)
+
+ def test_places_on_gpu_with_lower_case_spelling(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session():
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, devices=['/gpu:0'])
+ _ = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual('/device:GPU:0', c.device)
+
+
+class ReplicateAcrossASingleDeviceWithoutTowerOptimizer(
+ test_util.TensorFlowTestCase):
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(10, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = math_ops.multiply(features, c)
+
+ loss = losses.absolute_difference(
+ labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+ loss = math_ops.reduce_sum(loss)
+
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions),
+ 'auc': metrics_lib.auc(labels, predictions)
+ }
+
+ optimizer = gradient_descent.GradientDescentOptimizer(
+ params['learning_rate'])
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ eval_metric_ops=metrics,
+ predictions={'probabilities': predictions},
+ train_op=optimizer.minimize(loss))
+
+ @property
+ def params(self):
+ params = {}
+ params['learning_rate'] = 1.0
+ return params
+
+ def test_train_single_tower(self):
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, devices=['/gpu:0'])
+ estimator_spec = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+ session.run(variables.global_variables_initializer())
+
+ # loss = feature * c - label
+ total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
+ self.assertEqual(total_loss, session.run(estimator_spec.loss))
+
+      # d(loss)/dc is 3.
+ # new value of c = 10 - learning rate * 3 = 7.0.
+ session.run(estimator_spec.train_op)
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual(7.0, session.run(c))
+
+
+class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase):
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(10, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ features = features['features']
+ predictions = math_ops.multiply(features, c)
+
+ loss = losses.absolute_difference(
+ labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+ loss = math_ops.reduce_sum(loss)
+
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions),
+ 'auc': metrics_lib.auc(labels, predictions)
+ }
+
+ optimizer = replicate_model_fn._TowerOptimizer(
+ gradient_descent.GradientDescentOptimizer(params['learning_rate']))
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ eval_metric_ops=metrics,
+ predictions={'probabilities': predictions},
+ train_op=optimizer.minimize(loss))
+
+ @property
+ def params(self):
+ params = {}
+ params['learning_rate'] = 1.0
+ return params
+
+ def test_train_single_tower(self):
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'features': features}, y=labels, batch_size=2, shuffle=False)
+
+ with self.test_session():
+ estimator = estimator_lib.Estimator(
+ model_fn=self.model_fn,
+ model_dir=tempfile.mkdtemp(),
+ params=self.params)
+ estimator.train(train_input_fn, steps=1)
+
+ self.assertEqual(7.0, estimator.get_variable_value('c'))
+
+
+class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase):
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(10, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ features = features['features']
+ predictions = math_ops.multiply(features, c)
+
+ loss = losses.absolute_difference(
+ labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+ loss = math_ops.reduce_sum(loss)
+
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions),
+ 'auc': metrics_lib.auc(labels, predictions)
+ }
+
+ optimizer = gradient_descent.GradientDescentOptimizer(
+ params['learning_rate'])
+ optimizer = training.SyncReplicasOptimizer(
+ optimizer, replicas_to_aggregate=1)
+ sync_hook = optimizer.make_session_run_hook(True)
+ optimizer = replicate_model_fn._TowerOptimizer(optimizer)
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ eval_metric_ops=metrics,
+ training_hooks=[sync_hook],
+ predictions={'probabilities': predictions},
+ train_op=optimizer.minimize(
+ loss, global_step=training.get_global_step()))
+
+ @property
+ def params(self):
+ params = {}
+ params['learning_rate'] = 1.0
+ return params
+
+ def test_train_multiple_towers(self):
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'features': features}, y=labels, batch_size=2, shuffle=False)
+
+ model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn,
+ loss_reduction=losses.Reduction.SUM,
+ devices=['/gpu:0', '/gpu:1'])
+
+ estimator = estimator_lib.Estimator(
+ model_fn=model_fn, model_dir=tempfile.mkdtemp(), params=self.params)
+ estimator.train(train_input_fn, steps=1)
+
+ self.assertEqual(7.0, estimator.get_variable_value('c'))
+
+
+class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase):
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(10, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ side_effects = variable_scope.get_variable(
+ 'side_effects',
+ initializer=constant_op.constant(0, dtype=dtypes.float64),
+ dtype=dtypes.float64,
+ use_resource=True,
+ trainable=False)
+
+ predictions = math_ops.multiply(features, c)
+
+ loss = losses.absolute_difference(
+ labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+ loss = math_ops.reduce_sum(loss)
+
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions),
+ 'auc': metrics_lib.auc(labels, predictions)
+ }
+
+ first_optimizer = replicate_model_fn._TowerOptimizer(
+ gradient_descent.GradientDescentOptimizer(1.0))
+ second_optimizer = replicate_model_fn._TowerOptimizer(
+ adam.AdamOptimizer(1.0))
+
+ with ops_lib.control_dependencies([side_effects.assign_add(1.0)]):
+ first_grads_and_vars = first_optimizer.compute_gradients(loss)
+
+ train_op = control_flow_ops.group(
+ [first_optimizer.apply_gradients(first_grads_and_vars),
+ second_optimizer.minimize(loss)])
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ eval_metric_ops=metrics,
+ predictions={'probabilities': predictions},
+ train_op=train_op)
+
+ def test_train(self):
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn,
+ loss_reduction=losses.Reduction.SUM,
+ devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(features, labels,
+ model_fn_lib.ModeKeys.TRAIN, {})
+ session.run(variables.global_variables_initializer())
+
+ # loss = feature * c - label
+ total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
+ self.assertEqual(total_loss, session.run(estimator_spec.loss))
+
+ # loss' of c is 3.
+ # new value of c = 10 - learning rate * 3 = 7.0.
+ # Adam subtracts another ~1.
+ session.run(estimator_spec.train_op)
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertNear(6.0, session.run(c), 0.000001)
+
+ side_effects = variable_scope.get_variable(
+ 'side_effects', dtype=dtypes.float64)
+ self.assertNear(2.0, session.run(side_effects), 0.000001)
+
+
+class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._should_skip_optimizer = False
+ self._towers_left_before_skipping_optimizer = -1
+
+ def incorrectly_skip_optimizer_for_tower(self, tower_number):
+ self._should_skip_optimizer = True
+ self._towers_left_before_skipping_optimizer = tower_number
+
+ def should_skip_optimizer(self):
+ if not self._should_skip_optimizer:
+ return False
+ if self._towers_left_before_skipping_optimizer == 0:
+ return True
+ else:
+ self._towers_left_before_skipping_optimizer -= 1
+ return False
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(10, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+ d = variable_scope.get_variable(
+ 'd',
+ initializer=constant_op.constant(2, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = math_ops.multiply(features, c)
+
+ loss = losses.absolute_difference(
+ labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+ loss = math_ops.reduce_sum(loss)
+
+ another_predictions = math_ops.multiply(features, d)
+ another_loss = losses.absolute_difference(
+ labels=labels,
+ predictions=another_predictions,
+ reduction=losses.Reduction.SUM)
+ another_loss = math_ops.reduce_sum(another_loss)
+
+ total_loss = math_ops.add(loss, another_loss)
+
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions),
+ 'auc': metrics_lib.auc(labels, predictions)
+ }
+
+ train_ops = []
+
+ optimizer = replicate_model_fn._TowerOptimizer(
+ gradient_descent.GradientDescentOptimizer(1.0))
+ train_ops.append(optimizer.minimize(loss, var_list=[c]))
+ if not self.should_skip_optimizer():
+ another_optimizer = replicate_model_fn._TowerOptimizer(
+ gradient_descent.GradientDescentOptimizer(1.0))
+ train_ops.append(another_optimizer.minimize(another_loss, var_list=[d]))
+
+ train_op = control_flow_ops.group(train_ops)
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=total_loss,
+ eval_metric_ops=metrics,
+ predictions={'probabilities': predictions},
+ train_op=train_op)
+
+ def test_train(self):
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn,
+ loss_reduction=losses.Reduction.SUM,
+ devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(features, labels,
+ model_fn_lib.ModeKeys.TRAIN, {})
+ session.run(variables.global_variables_initializer())
+
+ # For each tower, loss = (feature * c - label) + (feature * d - label).
+ total_loss = (1.0 * 10 - 1.0 + 1.0 * 2.0 - 1.0) + (
+ 2.0 * 10 - 2.0 + 2.0 * 2.0 - 2.0)
+ self.assertEqual(total_loss, session.run(estimator_spec.loss))
+
+ session.run(estimator_spec.train_op)
+
+ # loss' of c or loss' of d is 3.
+ # new value of c = 10 - learning rate * 3 = 7.0.
+ # new value of d = 2 - learning rate * 3 = -1.0.
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertNear(7.0, session.run(c), 0.000001)
+ d = variable_scope.get_variable('d', dtype=dtypes.float64)
+ self.assertNear(-1.0, session.run(d), 0.000001)
+
+ def test_different_optimizer_calls_within_towers(self):
+ self.incorrectly_skip_optimizer_for_tower(1)
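+ # Tower 1 skips the second optimizer, so the towers no longer make the
+ # same sequence of _TowerOptimizer calls and replication must fail.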
+
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ with self.test_session(), ops_lib.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError, '.+was.+supposed.+to.+make.+same.+optimizer.+calls.+'):
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, devices=['/gpu:0', '/gpu:1'])
+ _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN,
+ {})
+
+
+class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase):
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(10, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = math_ops.multiply(features, c)
+
+ loss = losses.absolute_difference(
+ labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+ loss = math_ops.reduce_sum(loss)
+
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions),
+ 'auc': metrics_lib.auc(labels, predictions)
+ }
+
+ optimizer = gradient_descent.GradientDescentOptimizer(1.0)
+ train_op = optimizer.minimize(loss)
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ eval_metric_ops=metrics,
+ predictions={'probabilities': predictions},
+ train_op=train_op)
+
+ def test_train(self):
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ with self.test_session():
+ with self.assertRaisesRegexp(ValueError,
+ 'Please.+wrap.+with.+_TowerOptimizer'):
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, devices=['/gpu:0', '/gpu:1'])
+ _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN,
+ {})
+
+
+class GetLossTowersTest(test_util.TensorFlowTestCase):
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(0.25, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
+ labels = np.array([0.1, 0.2, 0.3, labels[0]])
+
+ loss = losses.absolute_difference(
+ labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+
+ return model_fn_lib.EstimatorSpec(mode=mode, loss=math_ops.reduce_sum(loss))
+
+ def test_gradients_are_computed(self):
+ with self.test_session() as session:
+ tower_specs = replicate_model_fn._get_loss_towers(
+ self.model_fn,
+ mode=None,
+ features=[[0.6], [1.6]],
+ labels=[[0.6], [0.6]],
+ params=None,
+ config=None,
+ loss_reduction=losses.Reduction.SUM,
+ devices=['/gpu:0', '/gpu:1'],
+ local_ps_devices=['/gpu:0'],
+ name_scope_pattern='test_tower_{}')
+ session.run(variables.global_variables_initializer())
+
+ self.assertEqual(len(tower_specs), 2)
+
+ self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
+ self.assertEqual('Sum:0', tower_specs[0].loss.name)
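+ # Every prediction is off by c = 0.25, so the SUM loss over the
+ # 4-element batch is 4 * 0.25 = 1.0.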
+ self.assertEqual(1.0, session.run(tower_specs[0].loss))
+
+ self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
+ self.assertEqual('test_tower_1/Sum:0', tower_specs[1].loss.name)
+ # The second tower's input feature is 1.0 bigger (1.6 vs 0.6), so its
+ # SUM loss is 1.0 bigger.
+ self.assertEqual(2.0, session.run(tower_specs[1].loss))
+
+ self.assertEqual(1, len(variables.global_variables()))
+ self.assertEqual(1, len(variables.trainable_variables()))
+
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual(0.25, session.run(c))
+
+ def test_gradients_are_computed_with_mean_reduction(self):
+ with self.test_session() as session:
+ tower_specs = replicate_model_fn._get_loss_towers(
+ self.model_fn,
+ mode=model_fn_lib.ModeKeys.EVAL,
+ features=[[0.6], [1.6]],
+ labels=[[0.6], [0.6]],
+ params=None,
+ loss_reduction=losses.Reduction.MEAN,
+ config=None,
+ devices=['/gpu:0', '/gpu:1'],
+ local_ps_devices=['/gpu:0'],
+ name_scope_pattern='test_tower_{}')
+ session.run(variables.global_variables_initializer())
+
+ self.assertEqual(len(tower_specs), 2)
+
+ self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
+ self.assertEqual('averaged_loss:0', tower_specs[0].loss.name)
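+ # Under MEAN loss_reduction each tower's SUM loss is divided by the
+ # number of towers: 1.0 / 2 = 0.5 here and 2.0 / 2 = 1.0 below.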
+ self.assertEqual(0.5, session.run(tower_specs[0].loss))
+
+ self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
+ self.assertEqual('test_tower_1/averaged_loss:0', tower_specs[1].loss.name)
+ # The second tower's raw SUM loss is 1.0 bigger (features 1.6 vs 0.6);
+ # after the per-tower averaging the losses are 0.5 and 1.0.
+ self.assertEqual(1.0, session.run(tower_specs[1].loss))
+
+ self.assertEqual(1, len(variables.global_variables()))
+ self.assertEqual(1, len(variables.trainable_variables()))
+
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual(0.25, session.run(c))
+
+ def test_variables_are_round_robined_correctly(self):
+ """Test that creates multiple variables and tests round-robin placement."""
+
+ def model_fn(mode, features, labels, params):
+ del params
+ for variable_name in ['a', 'b', 'c', 'd']:
+ c = variable_scope.get_variable(
+ variable_name,
+ initializer=constant_op.constant(0.25, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
+ labels = np.array([0.1, 0.2, 0.3, labels[0]])
+ loss = losses.absolute_difference(
+ labels=labels,
+ predictions=predictions,
+ reduction=losses.Reduction.SUM)
+ return model_fn_lib.EstimatorSpec(
+ mode=mode, loss=math_ops.reduce_sum(loss))
+
+ with self.test_session() as session:
+ tower_specs = replicate_model_fn._get_loss_towers(
+ model_fn,
+ mode=None,
+ features=[[0.6], [1.6], [2.6]],
+ labels=[[0.6], [0.6], [2.6]],
+ params=None,
+ loss_reduction=losses.Reduction.SUM,
+ config=None,
+ devices=['/gpu:0', '/gpu:1', '/gpu:3'],
+ local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'],
+ name_scope_pattern='test_tower_{}')
+ session.run(variables.global_variables_initializer())
+
+ self.assertEqual(len(tower_specs), 3)
+ self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
+ self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
+ self.assertEqual('/device:GPU:3', tower_specs[2].loss.device)
+
+ with variable_scope.variable_scope('', reuse=True):
+ a = variable_scope.get_variable('a', dtype=dtypes.float64)
+ self.assertEqual('/device:GPU:0', a.device)
+ b = variable_scope.get_variable('b', dtype=dtypes.float64)
+ self.assertEqual('/device:GPU:1', b.device)
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual('/device:GPU:3', c.device)
+ d = variable_scope.get_variable('d', dtype=dtypes.float64)
+ self.assertEqual('/device:GPU:0', d.device)
+
+
+class SplitBatchTest(test_util.TensorFlowTestCase):
+
+ def evaluate_shards(self, first_list, second_list):
+ evaluate_items = lambda x: x.eval()
+ return list(map(evaluate_items, first_list)), list(
+ map(evaluate_items, second_list))
+
+ def assertSparseValuesEqual(self, a, b):
+ self.assertAllEqual(a.indices, b.indices)
+ self.assertAllEqual(a.values, b.values)
+ self.assertAllEqual(a.dense_shape, b.dense_shape)
+
+ def test_simple_half_split(self):
+ with self.test_session():
+ features = [0.0, 1.0, 2.0, 3.0]
+ labels = [10.0, 11.0, 12.0, 13.0]
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 2, device='/gpu:0')
+
+ feature_shards, label_shards = self.evaluate_shards(
+ feature_shards, label_shards)
+
+ self.assertAllEqual([[0.0, 1.0], [2.0, 3.0]], feature_shards)
+ self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards)
+
+ def test_to_each_their_own(self):
+ with self.test_session():
+ features = [0.0, 1.0, 2.0, 3.0]
+ labels = [10.0, 11.0, 12.0, 13.0]
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 4, device='/gpu:0')
+
+ feature_shards, label_shards = self.evaluate_shards(
+ feature_shards, label_shards)
+
+ self.assertAllEqual([[0.0], [1.0], [2.0], [3.0]], feature_shards)
+ self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards)
+
+ def test_one_batch(self):
+ with self.test_session():
+ features = [0.0, 1.0, 2.0, 3.0]
+ labels = [10.0, 11.0, 12.0, 13.0]
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 1, device='/gpu:0')
+
+ feature_shards, label_shards = self.evaluate_shards(
+ feature_shards, label_shards)
+
+ self.assertAllEqual([[0.0, 1.0, 2.0, 3.0]], feature_shards)
+ self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards)
+
+ def test_half_split_in_dictionary(self):
+ with self.test_session():
+ features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
+ labels = [10.0, 11.0, 12.0, 13.0]
+
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 2, device='/gpu:0')
+
+ self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval())
+ self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval())
+ self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval())
+ self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval())
+ self.assertAllEqual([10.0, 11.0], label_shards[0].eval())
+ self.assertAllEqual([12.0, 13.0], label_shards[1].eval())
+
+ def test_sparse_tensor_can_be_split_unevenly(self):
+ with self.test_session():
+ features = {
+ 'x':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 2], [2, 2]],
+ values=[1.0, 2.0, 3.0],
+ dense_shape=[3, 4])
+ }
+ labels = np.array([[1.0], [2.0]])
+
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 2, device='/gpu:0')
+
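+ # SparseTensors are split by rows of the first dense dimension: rows
+ # [0, 2) form the first shard (dense_shape [2, 4]) and row 2 the second
+ # shard (dense_shape [1, 4]).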
+ self.assertSparseValuesEqual(
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[2, 4]),
+ feature_shards[0]['x'].eval())
+ self.assertSparseValuesEqual(
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 2]], values=[3.], dense_shape=[1, 4]),
+ feature_shards[1]['x'].eval())
+ self.assertAllEqual([[1.0]], label_shards[0].eval())
+ self.assertAllEqual([[2.0]], label_shards[1].eval())
+
+ def test_sparse_tensor_can_be_split_unevenly_repeated_row(self):
+ with self.test_session():
+ features = {
+ 'x':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[1.0, 2.0, 3.0],
+ dense_shape=[3, 4])
+ }
+ labels = np.array([[1.0], [2.0]])
+
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 2, device='/gpu:0')
+
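+ # All three values fall into rows 0-1, so the first shard receives
+ # every value and the second shard (row 2) is empty.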
+ self.assertSparseValuesEqual(
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[1., 2., 3.],
+ dense_shape=[2, 4]), feature_shards[0]['x'].eval())
+
+ second_batch = feature_shards[1]['x'].eval()
+ self.assertEqual(0, len(second_batch.indices))
+ self.assertEqual(0, len(second_batch.values))
+ self.assertAllEqual([1, 4], second_batch.dense_shape)
+ self.assertAllEqual([[1.0]], label_shards[0].eval())
+ self.assertAllEqual([[2.0]], label_shards[1].eval())
+
+ def test_one_batch_in_dictionary(self):
+ with self.test_session():
+ features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
+ labels = [10.0, 11.0, 12.0, 13.0]
+
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 1, device='/gpu:0')
+
+ self.assertAllEqual([0.0, 1.0, 2.0, 3.0],
+ feature_shards[0]['first'].eval())
+ self.assertAllEqual([4.0, 5.0, 6.0, 7.0],
+ feature_shards[0]['second'].eval())
+ self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval())
+
+ def test_feature_and_label_dictionaries(self):
+ with self.test_session():
+ features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
+ labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]}
+
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 2, device='/gpu:0')
+
+ self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval())
+ self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval())
+ self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval())
+ self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval())
+ self.assertAllEqual([10.0], label_shards[0]['first'].eval())
+ self.assertAllEqual([12.0], label_shards[0]['second'].eval())
+ self.assertAllEqual([11.0], label_shards[1]['first'].eval())
+ self.assertAllEqual([13.0], label_shards[1]['second'].eval())
+
+
+class TrainSpecTest(test_util.TensorFlowTestCase):
+
+ expected_predictions = {}
+
+ def create_estimator_spec(self, loss):
+ return model_fn_lib.EstimatorSpec(
+ mode=model_fn_lib.ModeKeys.TRAIN,
+ loss=loss,
+ train_op=loss, # Not used; currently required.
+ predictions=self.expected_predictions)
+
+ def create_constant_loss(self, loss_value):
+ return constant_op.constant(loss_value, dtype=dtypes.float64)
+
+ def test_example(self):
+ with self.test_session() as session:
+ tower_losses = list(map(self.create_constant_loss, [2, 4, 6]))
+ tower_specs = list(map(self.create_estimator_spec, tower_losses))
+
+ expected_train_op = tower_losses[1]
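+ # Any tensor can stand in for the train_op; _train_spec is expected to
+ # pass it through to the resulting EstimatorSpec unchanged.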
+
+ estimator_spec = replicate_model_fn._train_spec(
+ tower_specs, expected_train_op, aggregation_device='/gpu:0')
+
+ self.assertEqual(expected_train_op, estimator_spec.train_op)
+ self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
+ self.assertEqual(self.expected_predictions, estimator_spec.predictions)
+
+
+class EvalSpecTest(test_util.TensorFlowTestCase):
+
+ def create_estimator_spec(self, loss, metrics):
+ return model_fn_lib.EstimatorSpec(
+ mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics)
+
+ def create_constant_loss(self, loss_value):
+ return constant_op.constant(loss_value, dtype=dtypes.float64)
+
+ def create_eval_metrics(self, noise):
+ predictions = np.array([0.1, 0.2, 0.3, 0.6 + noise])
+ labels = np.array([0.1, 0.2, 0.3, 0.6])
+
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions),
+ 'auc': metrics_lib.auc(labels, predictions)
+ }
+ return metrics
+
+ def test_example(self):
+ with self.test_session() as session:
+ tower_losses = map(self.create_constant_loss, [2, 4, 6])
+ tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
+ tower_specs = [
+ self.create_estimator_spec(l, m)
+ for l, m in zip(tower_losses, tower_metrics)
+ ]
+ session.run(variables.local_variables_initializer())
+
+ estimator_spec = replicate_model_fn._eval_spec(
+ tower_specs, aggregation_device='/device:GPU:0')
+
+ accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+ auc, b = estimator_spec.eval_metric_ops['auc']
+
+ self.assertEqual('/device:CPU:0', accuracy.device)
+ self.assertEqual('/device:CPU:0', auc.device)
+
+ session.run([a, b])
+ accuracy, auc = session.run([accuracy, auc])
+
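+ # 3 towers x 4 examples = 12 predictions; the noise on the second and
+ # third towers perturbs one prediction each, hence (12 - 2) / 12.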
+ self.assertNear((12 - 2) / 12, accuracy, 0.01)
+ self.assertEqual(0, auc)
+ self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
+
+ def test_handles_single_tower(self):
+ with self.test_session() as session:
+ tower_losses = map(self.create_constant_loss, [5])
+ tower_metrics = map(self.create_eval_metrics, [0.2])
+ tower_specs = [
+ self.create_estimator_spec(l, m)
+ for l, m in zip(tower_losses, tower_metrics)
+ ]
+ session.run(variables.local_variables_initializer())
+
+ estimator_spec = replicate_model_fn._eval_spec(
+ tower_specs, aggregation_device='/device:GPU:0')
+
+ accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+ auc, b = estimator_spec.eval_metric_ops['auc']
+
+ self.assertEqual('/device:CPU:0', accuracy.device)
+ self.assertEqual('/device:CPU:0', auc.device)
+
+ session.run([a, b])
+ accuracy = session.run(accuracy)
+ auc = session.run(auc)
+
+ self.assertNear((4 - 1) / 4, accuracy, 0.01)
+ self.assertEqual(0, auc)
+ self.assertEqual(5, session.run(estimator_spec.loss))
+
+
+class PredictSpecTest(test_util.TensorFlowTestCase):
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(0.25, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = math_ops.add(np.array([features[0], features[0]]), c)
+
+ return model_fn_lib.EstimatorSpec(
+ mode=model_fn_lib.ModeKeys.PREDICT,
+ predictions={
+ 'probabilities': predictions
+ })
+
+ def test_example(self):
+ with self.test_session() as session:
+ tower_specs = replicate_model_fn._get_loss_towers(
+ self.model_fn,
+ mode=None,
+ features=[[0.1], [0.2]],
+ loss_reduction=losses.Reduction.SUM,
+ labels=[[], []],
+ params=None,
+ config=None,
+ devices=['/gpu:0', '/gpu:1'],
+ local_ps_devices=['/gpu:0'],
+ )
+ session.run(variables.global_variables_initializer())
+
+ estimator_spec = replicate_model_fn._predict_spec(
+ tower_specs, aggregation_device='/gpu:0')
+
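+ # Each tower duplicates its single feature and adds c = 0.25; the
+ # towers' predictions are then concatenated: [0.35, 0.35] + [0.45, 0.45].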
+ self.assertEqual('/device:GPU:0',
+ estimator_spec.predictions['probabilities'].device)
+ self.assertAllClose({
+ 'probabilities': np.array([0.35, 0.35, 0.45, 0.45])
+ }, session.run(estimator_spec.predictions))
+
+
+class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
+
+ def create_metric_variable(self, initial_value, name):
+ return variable_scope.variable(
+ initial_value,
+ trainable=False,
+ collections=[ops_lib.GraphKeys.METRIC_VARIABLES],
+ validate_shape=True,
+ name=name)
+
+ def create_tower_metrics(self, tower_id):
+ with variable_scope.variable_scope('', reuse=(tower_id != 0)):
+ self.create_metric_variable(1.3 * (tower_id + 1), 'total')
+ self.create_metric_variable(2.3 * (tower_id + 1), 'count')
+ self.create_metric_variable(
+ np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total')
+
+ def test_example(self):
+ with self.test_session() as session:
+ for tower_id in range(3):
+ self.create_tower_metrics(tower_id)
+
+ session.run(
+ variables.variables_initializer(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
+
+ session.run(
+ replicate_model_fn._reduce_metric_variables(number_of_towers=3))
+
+ # 1st tower = 1.3, 2.3, [3.3, 3.5, 3.7]
+ # 2nd tower = 2.6, 4.6, [6.6, 7.0, 7.4]
+ # 3rd tower = 3.9, 6.9, [9.9, 10.5, 11.1]
+ # Reduced = 7.8, 13.8, [19.8, 21.0, 22.2]
+ # All towers are accumulated into the first tower's variables; the
+ # variables of the remaining towers are reset to 0.0.
+ local_metrics = session.run(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
+
+ self.assertNear(7.8, local_metrics[0], 0.01)
+ self.assertNear(13.8, local_metrics[1], 0.01)
+ self.assertAllClose([19.8, 21.0, 22.2], local_metrics[2], 0.01)
+ self.assertNear(0.0, local_metrics[3], 0.01)
+ self.assertNear(0.0, local_metrics[4], 0.01)
+ self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
+ self.assertNear(0.0, local_metrics[6], 0.01)
+ self.assertNear(0.0, local_metrics[7], 0.01)
+ self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
+
+ def test_reduce_is_idempotent(self):
+ with self.test_session() as session:
+ for tower_id in range(3):
+ self.create_tower_metrics(tower_id)
+
+ session.run(
+ variables.variables_initializer(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
+
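+ # After the first reduction the non-first towers are zeroed out, so
+ # repeating the reduction must leave the accumulated totals unchanged.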
+ for _ in range(20):
+ session.run(
+ replicate_model_fn._reduce_metric_variables(number_of_towers=3))
+
+ local_metrics = session.run(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
+
+ self.assertNear(7.8, local_metrics[0], 0.01)
+ self.assertNear(13.8, local_metrics[1], 0.01)
+ self.assertAllClose([19.8, 21.0, 22.2], local_metrics[2], 0.01)
+ self.assertNear(0.0, local_metrics[3], 0.01)
+ self.assertNear(0.0, local_metrics[4], 0.01)
+ self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
+ self.assertNear(0.0, local_metrics[6], 0.01)
+ self.assertNear(0.0, local_metrics[7], 0.01)
+ self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
+
+ def test_handles_single_tower(self):
+ with self.test_session() as session:
+ self.create_tower_metrics(0)
+ session.run(
+ variables.variables_initializer(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
+
+ session.run(
+ replicate_model_fn._reduce_metric_variables(number_of_towers=1))
+
+ local_metrics = session.run(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
+
+ self.assertNear(1.3, local_metrics[0], 0.01)
+ self.assertNear(2.3, local_metrics[1], 0.01)
+ self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01)
+
+ def test_doesnt_accept_uneven_number_of_variables(self):
+ with self.test_session() as session:
+ for tower_id in range(3):
+ self.create_tower_metrics(tower_id)
+ self.create_metric_variable(-1.0, 'oddball')
+
+ session.run(
+ variables.variables_initializer(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
+
+ with self.assertRaisesRegexp(
+ ValueError, '.+Expected.+local.+variables.+but.+got.+instead.+'):
+ session.run(
+ replicate_model_fn._reduce_metric_variables(number_of_towers=3))
+
+
+class MergeExportOutputsTest(test_util.TensorFlowTestCase):
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(10, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = {'probabilities': math_ops.multiply(features, c)}
+ loss = losses.absolute_difference(
+ labels=labels,
+ predictions=predictions['probabilities'],
+ reduction=losses.Reduction.SUM)
+
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions['probabilities']),
+ 'auc': metrics_lib.auc(labels, predictions['probabilities'])
+ }
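+ # Recover the name of this tower's input split (e.g.
+ # 'split_inputs/split:0') from the tensor's string form and use it as a
+ # string class label, so the test can tell which tower produced which
+ # export output.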
+ tensor_string_repr = str(features)
+ classes = constant_op.constant(
+ re.search('(split_inputs/split:[0-9])', tensor_string_repr).group(1),
+ dtype=dtypes.string)
+
+ export_outputs = {
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ export_output.PredictOutput(predictions),
+ 'classification_output':
+ export_output.ClassificationOutput(predictions['probabilities'],
+ classes),
+ 'classification_scores':
+ export_output.ClassificationOutput(
+ scores=predictions['probabilities']),
+ 'classification_classes':
+ export_output.ClassificationOutput(classes=classes),
+ 'regression_output':
+ export_output.RegressionOutput(predictions['probabilities']),
+ }
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=math_ops.reduce_sum(loss),
+ eval_metric_ops=metrics,
+ predictions=predictions,
+ export_outputs=export_outputs)
+
+ def replicate_estimator_spec(self, session):
+ features = np.array([0.01, 0.002])
+ labels = np.array([0.01, 0.02])
+
+ replicated_model_fn = replicate_model_fn._replicate_model_fn(
+ self.model_fn, devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(features, labels,
+ model_fn_lib.ModeKeys.PREDICT, {})
+ session.run(variables.global_variables_initializer())
+ return estimator_spec
+
+ def test_merge_predict_output(self):
+ with self.test_session() as session:
+ estimator_spec = self.replicate_estimator_spec(session)
+ self.assertAllClose(
+ {
+ 'probabilities': np.array([0.1, 0.02])
+ },
+ session.run(estimator_spec.export_outputs[
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs))
+
+ def test_merge_classification_output_scores_classes(self):
+ with self.test_session() as session:
+ estimator_spec = self.replicate_estimator_spec(session)
+ self.assertAllClose(
+ [0.1, 0.02],
+ session.run(
+ estimator_spec.export_outputs['classification_output'].scores))
+ self.assertAllEqual(
+ [b'split_inputs/split:0', b'split_inputs/split:1'],
+ session.run(
+ estimator_spec.export_outputs['classification_output'].classes))
+
+ def test_merge_classification_output_scores(self):
+ with self.test_session() as session:
+ estimator_spec = self.replicate_estimator_spec(session)
+ self.assertAllClose(
+ [0.1, 0.02],
+ session.run(
+ estimator_spec.export_outputs['classification_scores'].scores))
+ self.assertIsNone(
+ estimator_spec.export_outputs['classification_scores'].classes)
+
+ def test_merge_classification_output_classes(self):
+ with self.test_session() as session:
+ estimator_spec = self.replicate_estimator_spec(session)
+ self.assertAllEqual(
+ [b'split_inputs/split:0', b'split_inputs/split:1'],
+ session.run(
+ estimator_spec.export_outputs['classification_classes'].classes))
+ self.assertIsNone(
+ estimator_spec.export_outputs['classification_classes'].scores)
+
+ def test_merge_regression_output(self):
+ with self.test_session() as session:
+ estimator_spec = self.replicate_estimator_spec(session)
+ self.assertAllClose(
+ [0.1, 0.02],
+ session.run(estimator_spec.export_outputs['regression_output'].value))
+
+
+class GetLocalDevicesTest(test_util.TensorFlowTestCase):
+
+ def test_there_is_at_least_a_cpu(self):
+ self.assertTrue(replicate_model_fn._get_local_devices('CPU'))
+
+ def test_there_is_no_xpu(self):
+ self.assertFalse(
+ replicate_model_fn._get_local_devices('XPU')) # XPU doesn't exist.
+
+ def test_whether_there_is_a_gpu(self):
+ if test.is_gpu_available():
+ self.assertTrue(len(replicate_model_fn._get_local_devices('GPU')))
+
+
+class LocalDeviceSetterTest(test_util.TensorFlowTestCase):
+
+ def test_vars_are_on_ps_but_ops_are_on_workers(self):
+ ps_devices = ['/device:GPU:3']
+ round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices))
+
+ local_device_setter = replicate_model_fn._local_device_setter(
+ ps_devices=ps_devices,
+ ps_strategy=round_robin,
+ worker_device='/device:GPU:2')
+
+ with ops_lib.device(local_device_setter):
+ a = variables.Variable(0.01)
+ self.assertEqual('/device:GPU:3', a.device)
+
+ b = variables.Variable(0.02)
+ self.assertEqual('/device:GPU:3', b.device)
+
+ c = variables.Variable(0.03)
+ self.assertEqual('/device:GPU:3', c.device)
+
+ a_op = array_ops.concat(a, axis=0)
+ self.assertEqual('/device:GPU:2', a_op.device)
+
+ b_op = array_ops.concat(b, axis=0)
+ self.assertEqual('/device:GPU:2', b_op.device)
+
+ def test_round_robin_placement(self):
+ ps_devices = [
+ '/device:GPU:0', '/device:GPU:1', '/device:GPU:3', '/device:GPU:4'
+ ]
+ round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices))
+
+ local_device_setter = replicate_model_fn._local_device_setter(
+ ps_devices=ps_devices,
+ ps_strategy=round_robin,
+ worker_device='/device:GPU:2')
+
+ with ops_lib.device(local_device_setter):
+ a = variables.Variable(0.01)
+ self.assertEqual('/device:GPU:0', a.device)
+
+ b = variables.Variable(0.02)
+ self.assertEqual('/device:GPU:1', b.device)
+
+ c = variables.Variable(0.03)
+ self.assertEqual('/device:GPU:3', c.device)
+
+ a_op = array_ops.concat(a, axis=0)
+ self.assertEqual('/device:GPU:2', a_op.device)
+
+ b_op = array_ops.concat(b, axis=0)
+ self.assertEqual('/device:GPU:2', b_op.device)
+
+ d = variables.Variable(0.03)
+ self.assertEqual('/device:GPU:4', d.device)
+
+ e = variables.Variable(0.03)
+ self.assertEqual('/device:GPU:0', e.device)
+
+ d_op = array_ops.concat(d, axis=0)
+ self.assertEqual('/device:GPU:2', d_op.device)
+
+
+class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
+
+ def test_vectors(self):
+ with self.test_session() as session:
+ total = replicate_model_fn._compute_sum_on_device(
+ [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum')
+
+ self.assertEqual('/device:GPU:0', total.device)
+ self.assertEqual('test_sum', total.op.name)
+ self.assertEqual(10.0, session.run(total))
+
+ def test_tensors(self):
+ with self.test_session() as session:
+ total = replicate_model_fn._compute_sum_on_device(
+ [[1.0, 2.0], [3.0, 4.0]], device='/device:GPU:0', name='test_sum')
+
+ self.assertEqual('/device:GPU:0', total.device)
+ self.assertEqual('test_sum', total.op.name)
+ self.assertAllEqual([4.0, 6.0], session.run(total))
+
+ def test_indexedslices(self):
+ with self.test_session() as session:
+ a = ops_lib.IndexedSlices(
+ constant_op.constant([1.0, 2.0]), [0, 1],
+ dense_shape=constant_op.constant([2]))
+ b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
+
+ total = replicate_model_fn._compute_sum_on_device(
+ [a, b], device='/device:GPU:0')
+
+ self.assertEqual('/device:GPU:0', total.device)
+ self.assertAllEqual([4.0, 6.0],
+ session.run(ops_lib.convert_to_tensor(total)))
+
+ def test_indexedslices_higher_dimensions(self):
+ with self.test_session() as session:
+ a = ops_lib.IndexedSlices(
+ constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1],
+ dense_shape=constant_op.constant([2, 4]))
+ b = ops_lib.IndexedSlices(
+ constant_op.constant([[3.0, 7.0], [4.0, 8.0]]), [0, 1])
+
+ total = replicate_model_fn._compute_sum_on_device(
+ [a, b], device='/device:GPU:0')
+
+ self.assertEqual('/device:GPU:0', total.device)
+ self.assertAllEqual([[4.0, 12.0], [6.0, 14.0]],
+ session.run(ops_lib.convert_to_tensor(total)))
+
+ def test_indexedslices_some_dont_overlap(self):
+ with self.test_session() as session:
+ a = ops_lib.IndexedSlices(
+ constant_op.constant([1.0, 2.0]), [0, 3],
+ dense_shape=constant_op.constant([4]))
+ b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
+
+ total = replicate_model_fn._compute_sum_on_device(
+ [a, b], device='/device:GPU:0')
+
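+ # Only row 0 overlaps (1.0 + 3.0 = 4.0); row 1 comes from b alone, row
+ # 3 from a alone, and row 2 receives no value.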
+ self.assertEqual('/device:GPU:0', total.device)
+ self.assertAllEqual([4.0, 4.0, 0.0, 2.0],
+ session.run(ops_lib.convert_to_tensor(total)))
+
+ def test_no_name_for_indexslices(self):
+ a = ops_lib.IndexedSlices(
+ constant_op.constant([1.0, 2.0]), [0, 1],
+ dense_shape=constant_op.constant([2]))
+ b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
+
+ with self.assertRaisesRegexp(ValueError, '.+name.+not.+expected.+'):
+ _ = replicate_model_fn._compute_sum_on_device(
+ [a, b], device='/device:GPU:0', name='cant_name_indexslices')
+
+
+class ConcatTensorDictsTest(test_util.TensorFlowTestCase):
+
+ def test_example(self):
+ tensor_dicts = [
+ {
+ 'a': np.array([1.0, 2.0]),
+ 'b': np.array([11.0]),
+ 'c': np.array([21.0]),
+ },
+ {
+ 'a': np.array([3.0]),
+ 'b': np.array([12.0, 13.0]),
+ },
+ {
+ 'b': np.array([14.0]),
+ },
+ ]
+
+ with self.test_session() as session:
+ self.assertAllClose({
+ 'a': np.array([1.0, 2.0, 3.0]),
+ 'b': np.array([11.0, 12.0, 13.0, 14.0]),
+ 'c': np.array([21.0]),
+ }, session.run(replicate_model_fn._concat_tensor_dicts(*tensor_dicts)))
+
+
+if __name__ == '__main__':
+ test.main()