aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2017-11-06 07:43:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-06 07:48:39 -0800
commit584de66eb968592b00d9f46d6f56a9de23ae1295 (patch)
treefac511d1589c3c2e680c31233a2b3d0af29bf70c
parentc217140933e36f4f2ac18a2add51e27b1d146b24 (diff)
Replicate `Estimator.model_fn` across available GPUs.
def replicate_model_fn(model_fn, optimizer_fn, devices=None): """Replicate `Estimator.model_fn` over GPUs. ... I tested that it seems to give the right result on cnn_mnist.py on 1 CPU, 1 real GPU, 4 allow_soft_placement=True GPUs. Some measurements on CNN MNIST across steps 19300-20000: 1) no replicate_model_fn call: global_step/sec: 156.254 global_step/sec: 155.074 global_step/sec: 155.74 global_step/sec: 153.636 global_step/sec: 157.218 global_step/sec: 159.644 2) replicate across one hardware GPU: global_step/sec: 158.171 global_step/sec: 165.618 global_step/sec: 162.773 global_step/sec: 159.204 global_step/sec: 162.289 global_step/sec: 167.173 3) replicate across 4 software GPUs on one hardware GPU (soft placement): global_step/sec: 75.47 global_step/sec: 76.16 global_step/sec: 75.18 Loss numbers didn't change across the three configurations. PiperOrigin-RevId: 174704385
-rw-r--r--tensorflow/contrib/estimator/BUILD64
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py470
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py901
3 files changed, 1434 insertions, 1 deletions
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index a0f83ac105..6eb2cfdaca 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -7,6 +7,7 @@ package(
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
filegroup(
name = "all_files",
@@ -30,6 +31,7 @@ py_library(
":head",
":logit_fns",
":multi_head",
+ ":replicate_model_fn",
"//tensorflow/python:util",
],
)
@@ -227,9 +229,69 @@ py_test(
"//tensorflow/python:string_ops",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
- "//tensorflow/python/estimator:prediction_keys",
+ "//tensorflow/python/ops/losses",
"//tensorflow/python/saved_model:signature_constants",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
+
+py_library(
+ name = "replicate_model_fn",
+ srcs = [
+ "python/estimator/replicate_model_fn.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:device",
+ "//tensorflow/python:device_lib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/estimator:export_output",
+ "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/estimator:util",
+ "@six_archive//:six",
+ ],
+)
+
+cuda_py_test(
+ name = "replicate_model_fn_test",
+ size = "small",
+ srcs = ["python/estimator/replicate_model_fn_test.py"],
+ additional_deps = [
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:dnn",
+ "//tensorflow/python/estimator:export_export",
+ "//tensorflow/python/estimator:export_output",
+ "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/estimator:numpy_io",
+ "//tensorflow/python/estimator:optimizers",
+ "//tensorflow/python/estimator:prediction_keys",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python/ops/losses",
+ "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ":replicate_model_fn",
+ ],
+ tags = ["requires-gpu-sm35"],
+)
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
new file mode 100644
index 0000000000..7005a647db
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -0,0 +1,470 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities to replicate model_fn's over local GPUs.
+
+This file contains util that allow to replicate `Estimator.model_fn` over
+GPUs. Replicated version of a `model_fn` is returned that can subsequently
+be used with `Estimator`.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import six
+
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.client import device_lib
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator import util
+from tensorflow.python.estimator.export import export_output as export_output_lib
+from tensorflow.python.framework import device as framework_device
+from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients as gradients_lib
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.training import training_util
+
+
+def replicate_model_fn(model_fn, optimizer_fn, devices=None):
+ """Replicate `Estimator.model_fn` over GPUs within a single host.
+
+ The given `model_fn` specifies a single forward pass of a model. To replicate
+ such a model over GPUs, each GPU gets its own instance of the forward pass
+ (a.k.a. a tower). The input features and labels get sharded into the chunks
+ that correspond to the number of GPUs. Each tower computes its own loss based
+ on its input. For each such loss, gradients are computed. After that, the
+ available losses are summed to form aggregated loss. The available
+ gradients are summed too. Then, they update weights using the specified
+ optimizer.
+
+ If `devices` are `None`, then all available GPUs are going to be used for
+ replication. If no GPUs are available, then the model is going to be
+ placed on the CPU.
+
+ Two modes of local replication over available GPUs are supported:
+ 1) If exactly 1 GPU is detected, then variables and operations are placed
+ onto GPU.
+ 2) If more than 1 GPU is detected, then variables are going to be placed on
+ the CPU. Replicas of operations are placed on each individual GPU.
+
+ Here is an example of how one might use their `model_fn` to run over GPUs:
+ ```python
+ def optimizer_fn():
+ return tf.train.GradientDescentOptimizer(learning_rate=0.001)
+ ...
+ def model_fn(...): # See `model_fn` in `Estimator`.
+ loss = ...
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ # See the section below on `EstimatorSpec.train_op`.
+ return EstimatorSpec(mode=mode, loss=loss, train_op=tf.noop())
+
+ # No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`.
+ return EstimatorSpec(...)
+ ...
+ classifier = tf.estimator.Estimator(
+ model_fn=replicate_model_fn.replicate_model_fn(model_fn, optimizer_fn))
+ ```
+
+ On `EstimatorSpec.train_op`:
+ `model_fn` returns `EstimatorSpec.train_op` for
+ `tf.estimator.GraphKeys.TRAIN`. It is typically derived using an optimizer.
+ `replicate_model_fn` ignores the returned `EstimatorSpec.train_op`, so there
+ is no need to use an optimizer inside the user's `model_fn`. The
+ `EstimatorSpec.loss` subgraph is going to be executed, while
+ `EstimatorSpec.train_op` isn't going to be executed. One could pass
+ `train_op=tf.noop()` to `EstimatorSpec`.
+
+ On sharding input features and labels:
+ Input features and labels are split for consumption by each tower. They are
+ split across the dimension 0. Features and labels need to be batch major.
+
+ On reduction algorithms:
+ Certain algorithms were chosen for aggregating results of computations on
+ multiple towers:
+ - Losses from all towers are reduced using sum.
+ - Gradients are reduced using sum for each trainable variable.
+ - `eval_metrics_ops` are reduced per metric using `reduce_mean`.
+ - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are
+ reduced using concatenation.
+ - For all other fields of `EstimatorSpec` the values of the first tower
+ are taken.
+
+ On replication of variables:
+ Variables are not duplicated between towers. Instead, they are placed on a
+ single device as defined above and shared across towers.
+
+ Other current limitations:
+ - `predictions` are not supported for `ModeKeys.EVAL`. That is required for
+ `tf.contrib.estimator.add_metrics`.
+
+ Args:
+ model_fn: `model_fn` as defined in `Estimator`. See the section above about
+ the train_op argument of `EstimatorSpec`.
+ optimizer_fn: a function that returns an optimizer instance. The function
+ may accept one `params` argument. This is the `params` argument as
+ defined by `Estimator`. See the `Estimator` documentation for details.
+ devices: Optional list of devices to replicate the model across. This
+ argument can be used to replice only on the subset of available GPUs.
+ If `None`, then all available GPUs are going to be used for replication.
+ If no GPUs are available, then the model is going to be placed on the CPU.
+
+ Returns:
+ A replicated version of the supplied `model_fn`. Returned function that
+ conforms to the requirements of `Estimator`'s `model_fn` and can be used
+ instead of the supplied `model_fn`.
+ """
+ if not devices:
+ devices = _get_local_devices('GPU') or _get_local_devices('CPU')
+
+ is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0]
+ local_ps_device = '/{}:0'.format('GPU' if is_a_single_gpu_case else 'CPU')
+
+ tf_logging.info('Replicating the `model_fn` across {}. Local parameter '
+ 'server device is going to be {}.'.format(
+ devices, local_ps_device))
+
+ def replicated_model_fn(mode, features, labels, params=None, config=None):
+ """Replicated version of `model_fn` to be used instead."""
+ feature_shards, label_shards = _split_batch(
+ features, labels, len(devices), device=local_ps_device)
+ tower_specs = _get_loss_towers(
+ model_fn=model_fn,
+ mode=mode,
+ features=feature_shards,
+ labels=label_shards,
+ params=params,
+ config=config,
+ devices=devices,
+ local_ps_device=local_ps_device)
+
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ train_op = _minimize_towers(tower_specs,
+ _call_optimizer_fn(optimizer_fn, params))
+ return _train_spec(
+ tower_specs, train_op, aggregation_device=local_ps_device)
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ return _eval_spec(tower_specs, aggregation_device=local_ps_device)
+ elif mode == model_fn_lib.ModeKeys.PREDICT:
+ return _predict_spec(tower_specs, aggregation_device=local_ps_device)
+
+ return replicated_model_fn
+
+
+def _get_local_devices(device_type):
+ local_device_protos = device_lib.list_local_devices()
+ return [
+ device.name
+ for device in local_device_protos
+ if device.device_type == device_type
+ ]
+
+
+def _split_batch(features, labels, number_of_shards, device):
+ """Split input features and labes into batches."""
+
+ def split_dictionary(dictionary):
+ shards = [{} for _ in range(number_of_shards)]
+ for name, tensor in six.iteritems(dictionary):
+ for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
+ shards[i][name] = shard
+ return shards
+
+ with ops_lib.name_scope('split_inputs'):
+ with ops_lib.device(device):
+ if isinstance(features, dict):
+ feature_shards = split_dictionary(features)
+ else:
+ feature_shards = array_ops.split(features, number_of_shards)
+
+ if labels is None:
+ label_shards = None
+ elif isinstance(labels, dict):
+ label_shards = split_dictionary(labels)
+ else:
+ label_shards = array_ops.split(labels, number_of_shards)
+ return feature_shards, label_shards
+
+
+_DEFAULT_NAME_SCOPE_PATTERN = 'tower_{}'
+
+
+def _get_loss_towers(model_fn,
+ mode,
+ features,
+ labels,
+ params,
+ config,
+ devices,
+ local_ps_device,
+ name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
+ """Replicate the loss computation across devices."""
+ tower_specs = []
+
+ model_fn_args = util.fn_args(model_fn)
+ optional_params = {}
+ if 'params' in model_fn_args:
+ optional_params['params'] = copy.deepcopy(params)
+ if 'config' in model_fn_args:
+ optional_params['config'] = copy.deepcopy(config)
+
+ for i, device in enumerate(devices):
+ is_the_first_tower = (i == 0)
+
+ device_setter = _local_device_setter(
+ worker_device=device, ps_device=local_ps_device)
+
+ # We would like to preserve the names of the variables and ops that a user
+ # might be relying on. Names with prefix are going to resolve to variables
+ # and ops of the first tower.
+ name_scope = name_scope_pattern
+ if is_the_first_tower:
+ name_scope = ''
+
+ with variable_scope.variable_scope('', reuse=not is_the_first_tower):
+ with ops_lib.name_scope(name_scope.format(i)):
+ with ops_lib.device(device_setter):
+ labels_shard = None
+ if labels:
+ labels_shard = labels[i]
+
+ tower_specs.append(
+ model_fn(
+ mode=mode,
+ features=features[i],
+ labels=labels_shard,
+ **optional_params))
+ return tower_specs
+
+
+def _local_device_setter(ps_device, worker_device):
+ """A device setter that puts distributes Var/Ops to PS/workers."""
+ ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']
+
+ def local_device_chooser(op):
+ current_device = framework_device.DeviceSpec.from_string(op.device or '')
+
+ node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
+ if node_def.op in ps_ops:
+ ps_device_spec = framework_device.DeviceSpec.from_string(
+ '{}'.format(ps_device))
+
+ ps_device_spec.merge_from(current_device)
+ return ps_device_spec.to_string()
+ else:
+ worker_device_spec = framework_device.DeviceSpec.from_string(
+ worker_device or '')
+ worker_device_spec.merge_from(current_device)
+ return worker_device_spec.to_string()
+
+ return local_device_chooser
+
+
+def _minimize_towers(tower_specs, optimizer):
+ """Aggregate and apply gradients for computed losses."""
+ grad_lists = {}
+ for tower_spec in tower_specs:
+ with ops_lib.device(tower_spec.loss.device):
+ variables = variables_lib.trainable_variables()
+ gradients = gradients_lib.gradients(tower_spec.loss, variables)
+
+ for var, grad in zip(variables, gradients):
+ if grad is not None:
+ grad_lists.setdefault(var, []).append(grad)
+
+ aggregated_grads = []
+ with ops_lib.name_scope('gradient_aggregating'):
+ for var, grads in six.iteritems(grad_lists):
+ grad = _compute_sum_on_device(grads, var.device)
+ aggregated_grads.append((grad, var))
+
+ train_op = optimizer.apply_gradients(
+ aggregated_grads, global_step=training_util.get_global_step())
+
+ return train_op
+
+
+def _call_optimizer_fn(optimizer_fn, params):
+ arguments = {}
+ optimizer_fn_arguments = util.fn_args(optimizer_fn)
+ if 'params' in optimizer_fn_arguments:
+ arguments['params'] = params
+ return optimizer_fn(**arguments)
+
+
+def _compute_sum_on_device(values, device, name=None):
+ with ops_lib.device(device):
+ return math_ops.add_n(values, name=name)
+
+
+def _train_spec(tower_specs,
+ train_op,
+ aggregation_device,
+ aggregated_loss_name='loss'):
+ """Populate replicated EstimatorSpec for `GraphKeys.TRAIN`."""
+ estimator_spec = tower_specs[0]._asdict()
+ estimator_spec['mode'] = model_fn_lib.ModeKeys.TRAIN
+ estimator_spec['train_op'] = train_op
+ estimator_spec['loss'] = _compute_sum_on_device(
+ [spec.loss for spec in tower_specs], aggregation_device,
+ aggregated_loss_name)
+ return model_fn_lib.EstimatorSpec(**estimator_spec)
+
+
+def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'):
+ """Populate replicated EstimatorSpec for `GraphKeys.EVAL`."""
+ estimator_spec = tower_specs[0]._asdict()
+ estimator_spec['mode'] = model_fn_lib.ModeKeys.EVAL
+ estimator_spec['loss'] = _compute_sum_on_device(
+ [spec.loss for spec in tower_specs], aggregation_device,
+ aggregated_loss_name)
+
+ eval_metric_ops_lists = {}
+ for tower_spec in tower_specs:
+ metrics = tower_spec.eval_metric_ops or {}
+ for name, (_, update_op) in six.iteritems(metrics):
+ update_ops = eval_metric_ops_lists.setdefault(name, ([]))
+ update_ops.append(update_op)
+
+ eval_metric_ops = {}
+ for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops):
+ with ops_lib.control_dependencies(eval_metric_ops_lists[name]):
+ # This operation reduces local variables across all metrics, yet is
+ # called for every metric. This is redundant and it's done because
+ # it is hard to know what local variables correspond to what metric.
+ # Estimator is going to execute all `reduced_update_op`s as part of
+ # a group inside a single `Session.run()` call, which will avoid duplicate
+ # computation.
+ reduced_update_op = _reduce_metric_variables(len(tower_specs))
+ eval_metric_ops[name] = (metric_tensor, reduced_update_op)
+
+ estimator_spec['eval_metric_ops'] = eval_metric_ops
+ return model_fn_lib.EstimatorSpec(**estimator_spec)
+
+
+def _reduce_metric_variables(number_of_towers):
+ """Aggregate local variables used in metrics into the first tower."""
+ if number_of_towers == 1:
+ return control_flow_ops.no_op()
+
+ metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)
+ variables_per_tower = len(metric_variables) // number_of_towers
+
+ if len(metric_variables) % number_of_towers != 0:
+ raise ValueError(
+ 'Different `EstimatorSpec.eval_metric_ops` across `model_fn()` calls.'
+ ' Expected {} local variables, but got {} instead.'.format(
+ variables_per_tower * number_of_towers, len(metric_variables)))
+
+ # `metric_variables` has the size of `variables_per_tower` x
+ # number_of_towers. Each tower is produced by calling the same model_fn.
+ # First `variables_per_tower` correspond to the first tower. Each such
+ # variable has an replica at the `(variables_per_tower * i)` position, where
+ # `i` is `[1.. number_of_towers]`. We are going to add values from replicas
+ # to each variable of the first tower. We then zero out replica values, so
+ # that `_reduce_metric_variables` operation is idempotent. If a metric
+ # is then computed based on local variables from the first tower, then the
+ # resulting metric is an estimate for all `number_of_towers` towers.
+ ops = []
+ for i in range(0, variables_per_tower):
+ next_replica_id = i + variables_per_tower
+ replicas = [
+ metric_variables[replica_id]
+ for replica_id in range(next_replica_id, len(metric_variables),
+ variables_per_tower)
+ ] # `replicas` doesn't contain the first-tower variable.
+
+ reduce_op = state_ops.assign_add(metric_variables[i],
+ math_ops.add_n(replicas))
+
+ with ops_lib.control_dependencies([reduce_op]):
+ for replica in replicas:
+ zeros_for_replica = array_ops.zeros(
+ array_ops.shape(replica), dtype=replica.dtype)
+ zero_out_replica_op = state_ops.assign(replica, zeros_for_replica)
+ ops.append(zero_out_replica_op)
+
+ return control_flow_ops.group(*ops)
+
+
+def _predict_spec(tower_specs, aggregation_device):
+ """Populate replicated EstimatorSpec for `GraphKeys.PREDICT`."""
+ estimator_spec = tower_specs[0]._asdict()
+ estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT
+
+ with ops_lib.device(aggregation_device):
+ estimator_spec['predictions'] = _concat_tensor_dicts(
+ *[tower_spec.predictions for tower_spec in tower_specs])
+
+ export_outputs_dict = _dict_concat(
+ *[tower_spec.export_outputs for tower_spec in tower_specs])
+
+ export_outputs = {}
+ for name, export_output_list in six.iteritems(export_outputs_dict):
+ if isinstance(export_output_list[0], export_output_lib.PredictOutput):
+ export_outputs[name] = export_output_lib.PredictOutput(
+ outputs=_concat_tensor_dicts(*[
+ export_output.outputs for export_output in export_output_list
+ ]))
+ elif isinstance(export_output_list[0],
+ export_output_lib.RegressionOutput):
+ export_outputs[name] = export_output_lib.RegressionOutput(
+ value=array_ops.concat(
+ [export_output.value for export_output in export_output_list],
+ axis=0))
+ elif isinstance(export_output_list[0],
+ export_output_lib.ClassificationOutput):
+ scores = None
+ if export_output_list[0].scores is not None:
+ scores = array_ops.concat(
+ [export_output.scores for export_output in export_output_list],
+ axis=0)
+
+ classes = None
+ if export_output_list[0].classes is not None:
+ classes = array_ops.stack(
+ [export_output.classes for export_output in export_output_list],
+ axis=0)
+
+ export_outputs[name] = export_output_lib.ClassificationOutput(
+ scores=scores, classes=classes)
+
+ estimator_spec['export_outputs'] = export_outputs
+ return model_fn_lib.EstimatorSpec(**estimator_spec)
+
+
+def _concat_tensor_dicts(*tensor_dicts):
+ return {
+ name: array_ops.concat(tensors, axis=0, name=name)
+ for name, tensors in six.iteritems(_dict_concat(*tensor_dicts))
+ }
+
+
+def _dict_concat(*dicts):
+ list_dict = {}
+ for d in dicts:
+ if d is None:
+ continue
+
+ for k, v in six.iteritems(d):
+ list_dict.setdefault(k, []).append(v)
+ return list_dict
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
new file mode 100644
index 0000000000..10b47fba5a
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -0,0 +1,901 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for utilities that replicate `Estimator.model_fn` over GPUs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+import shutil
+import tempfile
+import numpy as np
+import six
+
+from tensorflow.contrib.estimator.python.estimator import replicate_model_fn
+from tensorflow.python.estimator import estimator as estimator_lib
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.canned import dnn
+from tensorflow.python.estimator.canned import optimizers
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.export import export_output
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import gradient_descent
+
+
+class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def test_complete_flow(self):
+ n_classes = 3
+ input_dimension = 2
+ batch_size = 12
+
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+ x_data = data.reshape(batch_size, input_dimension)
+ y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data},
+ y=y_data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data}, batch_size=batch_size, shuffle=False)
+
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+
+ estimator = dnn.DNNClassifier(
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+
+ def optimizer_fn():
+ return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05)
+
+ # TODO(isaprykin): Switch Estimator to use allow_soft_placement=True
+ # during export_savedmodel and then switch this test to replicate over
+ # GPUs instead of CPUs.
+ estimator = estimator_lib.Estimator(
+ model_fn=replicate_model_fn.replicate_model_fn(
+ estimator.model_fn,
+ optimizer_fn,
+ devices=['/cpu:0', '/cpu:0', '/cpu:0']),
+ model_dir=estimator.model_dir,
+ config=estimator.config,
+ params=estimator.params)
+
+ num_steps = 10
+ estimator.train(train_input_fn, steps=num_steps)
+
+ scores = estimator.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops_lib.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+
+ predicted_proba = np.array([
+ x[prediction_keys.PredictionKeys.PROBABILITIES]
+ for x in estimator.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
+
+ feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = estimator.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def _as_label(self, data_in_float):
+ return np.rint(data_in_float).astype(np.int64)
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+
+class ReplicateModelTest(test_util.TensorFlowTestCase):
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(10, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = math_ops.multiply(features, c)
+
+ loss = None
+ if mode is not model_fn_lib.ModeKeys.PREDICT:
+ loss = losses.absolute_difference(
+ labels=labels,
+ predictions=predictions,
+ reduction=losses.Reduction.SUM)
+ loss = math_ops.reduce_sum(loss)
+
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions),
+ 'auc': metrics_lib.auc(labels, predictions)
+ }
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ eval_metric_ops=metrics,
+ predictions={'probabilities': predictions},
+ train_op=control_flow_ops.no_op()) # This train_op isn't actually used.
+
+ def optimizer_fn(self, params):
+ return gradient_descent.GradientDescentOptimizer(params['learning_rate'])
+
+ @property
+ def params(self):
+ params = {}
+ params['learning_rate'] = 1.0
+ return params
+
+ def test_train(self):
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn.replicate_model_fn(
+ self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN,
+ features, labels, self.params)
+ session.run(variables.global_variables_initializer())
+
+ # loss = feature * c - label
+ total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
+ self.assertEqual(total_loss, session.run(estimator_spec.loss))
+
+ # loss' of c is 3.
+ # new value of c = 10 - learning rate * 3 = 7.0.
+ session.run(estimator_spec.train_op)
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual(7.0, session.run(c))
+
+ def test_train_spec_with_optimizer_without_params(self):
+
+ def optimizer_fn_without_params():
+ return gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
+
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ with self.test_session() as session: # pylint: disable=unused-variable
+ replicated_model_fn = replicate_model_fn.replicate_model_fn(
+ self.model_fn,
+ optimizer_fn_without_params,
+ devices=['/gpu:0', '/gpu:1'])
+ # This call is going to fail if `replicated_model_fn` is still passing
+ # `params` inside `optimizer_fn`, even though the latter doesn't take any:
+ estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN,
+ features, labels, self.params)
+ del estimator_spec
+
+ def test_eval(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn.replicate_model_fn(
+ self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features,
+ labels, self.params)
+ session.run(variables.local_variables_initializer())
+ session.run(variables.global_variables_initializer())
+
+ accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+ auc, b = estimator_spec.eval_metric_ops['auc']
+
+ session.run([a, b])
+ accuracy = session.run(accuracy)
+ auc = session.run(auc)
+
+ # Accuracy is 0.0 (no match) in the first tower.
+ # Accuracy is 1.0 (match) in the second tower, since the feature
+ # times weight "c" happened to be equal to the label.
+ total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
+
+ self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
+ self.assertEqual(0, auc)
+ self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
+
+ def test_predict(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn.replicate_model_fn(
+ self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT,
+ features, labels, self.params)
+ session.run(variables.global_variables_initializer())
+
+ self.assertAllClose({
+ 'probabilities': np.array([[0.1], [0.02]])
+ }, session.run(estimator_spec.predictions))
+
+ def test_train_single_tower(self):
+ features = np.array([[1.0], [2.0]])
+ labels = np.array([[1.0], [2.0]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn.replicate_model_fn(
+ self.model_fn, self.optimizer_fn)
+ estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN,
+ features, labels, self.params)
+ session.run(variables.global_variables_initializer())
+
+ # loss = feature * c - label
+ total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
+ self.assertEqual(total_loss, session.run(estimator_spec.loss))
+
+ # loss' of c is 3.
+ # new value of c = 10 - learning rate * 3 = 7.0.
+ session.run(estimator_spec.train_op)
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual(7.0, session.run(c))
+
+ def test_eval_single_tower(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn.replicate_model_fn(
+ self.model_fn, self.optimizer_fn, devices=['/gpu:0'])
+ estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features,
+ labels, self.params)
+ session.run(variables.local_variables_initializer())
+ session.run(variables.global_variables_initializer())
+
+ accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+ auc, b = estimator_spec.eval_metric_ops['auc']
+
+ session.run([a, b])
+ accuracy = session.run(accuracy)
+ auc = session.run(auc)
+
+ # Accuracy is 0.0 (no match) in the first tower.
+ # Accuracy is 1.0 (match) in the second tower, since the feature
+ # times weight "c" happened to be equal to the label.
+ total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
+
+ self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
+ self.assertEqual(0, auc)
+ self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
+
+ def test_predict_single_tower(self):
+ features = np.array([[0.01], [0.002]])
+ labels = np.array([[0.01], [0.02]])
+
+ with self.test_session() as session:
+ replicated_model_fn = replicate_model_fn.replicate_model_fn(
+ self.model_fn, self.optimizer_fn, devices=['/gpu:0'])
+ estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT,
+ features, labels, self.params)
+ session.run(variables.global_variables_initializer())
+
+ self.assertAllClose({
+ 'probabilities': np.array([[0.1], [0.02]])
+ }, session.run(estimator_spec.predictions))
+
+
+class GetLossTowersTest(test_util.TensorFlowTestCase):
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(0.25, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
+ labels = np.array([0.1, 0.2, 0.3, labels[0]])
+
+ loss = losses.absolute_difference(
+ labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
+
+ return model_fn_lib.EstimatorSpec(mode=mode, loss=math_ops.reduce_sum(loss))
+
+ def test_gradients_are_computed(self):
+ with self.test_session() as session:
+ tower_specs = replicate_model_fn._get_loss_towers(
+ self.model_fn,
+ mode=None,
+ features=[[0.6], [1.6]],
+ labels=[[0.6], [0.6]],
+ params=None,
+ config=None,
+ devices=['/gpu:0', '/gpu:1'],
+ local_ps_device='/gpu:0',
+ name_scope_pattern='test_tower_{}')
+ session.run(variables.global_variables_initializer())
+
+ self.assertEqual(len(tower_specs), 2)
+
+ self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
+ self.assertEqual('Sum:0', tower_specs[0].loss.name)
+ self.assertEqual(1.0, session.run(tower_specs[0].loss))
+
+ self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
+ self.assertEqual('test_tower_1/Sum:0', tower_specs[1].loss.name)
+ # The input batch for the second tower had a loss that is 1.0
+ # bigger: 0.6 vs 1.6.
+ self.assertEqual(2.0, session.run(tower_specs[1].loss))
+
+ self.assertEqual(1, len(variables.global_variables()))
+ self.assertEqual(1, len(variables.trainable_variables()))
+
+ with variable_scope.variable_scope('', reuse=True):
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual(0.25, session.run(c))
+
+
+class SplitBatchTest(test_util.TensorFlowTestCase):
+
+ def evaluate_shards(self, first_list, second_list):
+ evaluate_items = lambda x: x.eval()
+ return list(map(evaluate_items, first_list)), list(
+ map(evaluate_items, second_list))
+
+ def test_simple_half_split(self):
+ with self.test_session() as session: # pylint: disable=unused-variable
+ features = [0.0, 1.0, 2.0, 3.0]
+ labels = [10.0, 11.0, 12.0, 13.0]
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 2, device='/gpu:0')
+
+ feature_shards, label_shards = self.evaluate_shards(
+ feature_shards, label_shards)
+
+ self.assertAllEqual([[0.0, 1.0], [2.0, 3.0]], feature_shards)
+ self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards)
+
+ def test_to_each_their_own(self):
+ with self.test_session() as session: # pylint: disable=unused-variable
+ features = [0.0, 1.0, 2.0, 3.0]
+ labels = [10.0, 11.0, 12.0, 13.0]
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 4, device='/gpu:0')
+
+ feature_shards, label_shards = self.evaluate_shards(
+ feature_shards, label_shards)
+
+ self.assertAllEqual([[0.0], [1.0], [2.0], [3.0]], feature_shards)
+ self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards)
+
+ def test_one_batch(self):
+ with self.test_session() as session: # pylint: disable=unused-variable
+ features = [0.0, 1.0, 2.0, 3.0]
+ labels = [10.0, 11.0, 12.0, 13.0]
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 1, device='/gpu:0')
+
+ feature_shards, label_shards = self.evaluate_shards(
+ feature_shards, label_shards)
+
+ self.assertAllEqual([[0.0, 1.0, 2.0, 3.0]], feature_shards)
+ self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards)
+
+ def test_half_split_in_dictionary(self):
+ with self.test_session() as session: # pylint: disable=unused-variable
+ features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
+ labels = [10.0, 11.0, 12.0, 13.0]
+
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 2, device='/gpu:0')
+
+ self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval())
+ self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval())
+ self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval())
+ self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval())
+ self.assertAllEqual([10.0, 11.0], label_shards[0].eval())
+ self.assertAllEqual([12.0, 13.0], label_shards[1].eval())
+
+ def test_one_batch_in_dictionary(self):
+ with self.test_session() as session: # pylint: disable=unused-variable
+ features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
+ labels = [10.0, 11.0, 12.0, 13.0]
+
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 1, device='/gpu:0')
+
+ self.assertAllEqual([0.0, 1.0, 2.0, 3.0],
+ feature_shards[0]['first'].eval())
+ self.assertAllEqual([4.0, 5.0, 6.0, 7.0],
+ feature_shards[0]['second'].eval())
+ self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval())
+
+ def test_feature_and_label_dictionaries(self):
+ with self.test_session() as session: # pylint: disable=unused-variable
+ features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
+ labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]}
+
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 2, device='/gpu:0')
+
+ self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval())
+ self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval())
+ self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval())
+ self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval())
+ self.assertAllEqual([10.0], label_shards[0]['first'].eval())
+ self.assertAllEqual([12.0], label_shards[0]['second'].eval())
+ self.assertAllEqual([11], label_shards[1]['first'].eval())
+ self.assertAllEqual([13.0], label_shards[1]['second'].eval())
+
+
+class TrainSpecTest(test_util.TensorFlowTestCase):
+
+ expected_predictions = {}
+
+ def create_estimator_spec(self, loss):
+ return model_fn_lib.EstimatorSpec(
+ mode=model_fn_lib.ModeKeys.TRAIN,
+ loss=loss,
+ train_op=loss, # Not used; currently required.
+ predictions=self.expected_predictions)
+
+ def create_constant_loss(self, loss_value):
+ return constant_op.constant(loss_value, dtype=dtypes.float64)
+
+ def test_example(self):
+ with self.test_session() as session:
+ tower_losses = list(map(self.create_constant_loss, [2, 4, 6]))
+ tower_specs = list(map(self.create_estimator_spec, tower_losses))
+
+ expected_train_op = tower_losses[1]
+
+ estimator_spec = replicate_model_fn._train_spec(
+ tower_specs, expected_train_op, aggregation_device='/gpu:0')
+
+ self.assertEqual(expected_train_op, estimator_spec.train_op)
+ self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
+ self.assertEqual(self.expected_predictions, estimator_spec.predictions)
+
+
+class EvalSpecTest(test_util.TensorFlowTestCase):
+
+ def create_estimator_spec(self, loss, metrics):
+ return model_fn_lib.EstimatorSpec(
+ mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics)
+
+ def create_constant_loss(self, loss_value):
+ return constant_op.constant(loss_value, dtype=dtypes.float64)
+
+ def create_eval_metrics(self, noise):
+ predictions = np.array([0.1, 0.2, 0.3, 0.6 + noise])
+ labels = np.array([0.1, 0.2, 0.3, 0.6])
+
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions),
+ 'auc': metrics_lib.auc(labels, predictions)
+ }
+ return metrics
+
+ def test_example(self):
+ with self.test_session() as session:
+ tower_losses = map(self.create_constant_loss, [2, 4, 6])
+ tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
+ tower_specs = [
+ self.create_estimator_spec(l, m)
+ for l, m in zip(tower_losses, tower_metrics)
+ ]
+ session.run(variables.local_variables_initializer())
+
+ estimator_spec = replicate_model_fn._eval_spec(
+ tower_specs, aggregation_device='/device:GPU:0')
+
+ accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+ auc, b = estimator_spec.eval_metric_ops['auc']
+
+ self.assertEqual('/device:CPU:0', accuracy.device)
+ self.assertEqual('/device:CPU:0', auc.device)
+
+ session.run([a, b])
+ accuracy = session.run(accuracy)
+ auc = session.run(auc)
+
+ self.assertNear((12 - 2) / 12, accuracy, 0.01)
+ self.assertEqual(0, auc)
+ self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
+
+ def test_handles_single_tower(self):
+ with self.test_session() as session:
+ tower_losses = map(self.create_constant_loss, [5])
+ tower_metrics = map(self.create_eval_metrics, [0.2])
+ tower_specs = [
+ self.create_estimator_spec(l, m)
+ for l, m in zip(tower_losses, tower_metrics)
+ ]
+ session.run(variables.local_variables_initializer())
+
+ estimator_spec = replicate_model_fn._eval_spec(
+ tower_specs, aggregation_device='/device:GPU:0')
+
+ accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+ auc, b = estimator_spec.eval_metric_ops['auc']
+
+ self.assertEqual('/device:CPU:0', accuracy.device)
+ self.assertEqual('/device:CPU:0', auc.device)
+
+ session.run([a, b])
+ accuracy = session.run(accuracy)
+ auc = session.run(auc)
+
+ self.assertNear((4 - 1) / 4, accuracy, 0.01)
+ self.assertEqual(0, auc)
+ self.assertEqual(5, session.run(estimator_spec.loss))
+
+
+class PredictSpecTest(test_util.TensorFlowTestCase):
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(0.25, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = math_ops.add(np.array([features[0], features[0]]), c)
+
+ return model_fn_lib.EstimatorSpec(
+ mode=model_fn_lib.ModeKeys.PREDICT,
+ predictions={
+ 'probabilities': predictions
+ })
+
+ def test_example(self):
+ with self.test_session() as session:
+ tower_specs = replicate_model_fn._get_loss_towers(
+ self.model_fn,
+ mode=None,
+ features=[[0.1], [0.2]],
+ labels=[[], []],
+ params=None,
+ config=None,
+ devices=['/gpu:0', '/gpu:1'],
+ local_ps_device='/gpu:0',
+ )
+ session.run(variables.global_variables_initializer())
+
+ estimator_spec = replicate_model_fn._predict_spec(
+ tower_specs, aggregation_device='/gpu:0')
+
+ self.assertEqual('/device:GPU:0',
+ estimator_spec.predictions['probabilities'].device)
+ self.assertAllClose({
+ 'probabilities': np.array([0.35, 0.35, 0.45, 0.45])
+ }, session.run(estimator_spec.predictions))
+
+
+class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
+
+ def create_metric_variable(self, initial_value, name):
+ return variable_scope.variable(
+ initial_value,
+ trainable=False,
+ collections=[ops_lib.GraphKeys.METRIC_VARIABLES],
+ validate_shape=True,
+ name=name)
+
+ def create_tower_metrics(self, tower_id):
+ with variable_scope.variable_scope('', reuse=(tower_id != 0)):
+ self.create_metric_variable(1.3 * (tower_id + 1), 'total')
+ self.create_metric_variable(2.3 * (tower_id + 1), 'count')
+ self.create_metric_variable(
+ np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total')
+
+ def test_example(self):
+ with self.test_session() as session:
+ for tower_id in range(3):
+ self.create_tower_metrics(tower_id)
+
+ session.run(
+ variables.variables_initializer(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
+
+ session.run(
+ replicate_model_fn._reduce_metric_variables(number_of_towers=3))
+
+ # 1st tower = 1.3, 2.3, [3.3, 3.5, 3.7]
+ # 2nd tower = 2.6, 4.6, [6.6, 7.0, 7.4]
+ # 3rd tower = 3.9, 6.9, [9.9, 10.5, 11.1]
+ # Reduced = 7.8, 13.8, [19.8, 21.0, 22.2]
+ # Towers are accumulated in the first tower.
+ local_metrics = session.run(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
+
+ self.assertNear(7.8, local_metrics[0], 0.01)
+ self.assertNear(13.8, local_metrics[1], 0.01)
+ self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01)
+ self.assertNear(0.0, local_metrics[3], 0.01)
+ self.assertNear(0.0, local_metrics[4], 0.01)
+ self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
+ self.assertNear(0.0, local_metrics[6], 0.01)
+ self.assertNear(0.0, local_metrics[7], 0.01)
+ self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
+
+ def test_reduce_is_idempotent(self):
+ with self.test_session() as session:
+ for tower_id in range(3):
+ self.create_tower_metrics(tower_id)
+
+ session.run(
+ variables.variables_initializer(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
+
+ for _ in range(20):
+ session.run(
+ replicate_model_fn._reduce_metric_variables(number_of_towers=3))
+
+ local_metrics = session.run(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
+
+ self.assertNear(7.8, local_metrics[0], 0.01)
+ self.assertNear(13.8, local_metrics[1], 0.01)
+ self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01)
+ self.assertNear(0.0, local_metrics[3], 0.01)
+ self.assertNear(0.0, local_metrics[4], 0.01)
+ self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
+ self.assertNear(0.0, local_metrics[6], 0.01)
+ self.assertNear(0.0, local_metrics[7], 0.01)
+ self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
+
+ def test_handles_single_tower(self):
+ with self.test_session() as session:
+ self.create_tower_metrics(0)
+ session.run(
+ variables.variables_initializer(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
+
+ session.run(
+ replicate_model_fn._reduce_metric_variables(number_of_towers=1))
+
+ local_metrics = session.run(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
+
+ self.assertNear(1.3, local_metrics[0], 0.01)
+ self.assertNear(2.3, local_metrics[1], 0.01)
+ self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01)
+
+ def test_doesnt_accept_uneven_number_of_variables(self):
+ with self.test_session() as session:
+ for tower_id in range(3):
+ self.create_tower_metrics(tower_id)
+ self.create_metric_variable(-1.0, 'oddball')
+
+ session.run(
+ variables.variables_initializer(
+ ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
+
+ with self.assertRaisesRegexp(ValueError, ''):
+ session.run(
+ replicate_model_fn._reduce_metric_variables(number_of_towers=3))
+
+
+class MergeExportOutputsTest(test_util.TensorFlowTestCase):
+
+ def optimizer_fn(self):
+ return gradient_descent.GradientDescentOptimizer(1.0)
+
+ def model_fn(self, mode, features, labels, params):
+ c = variable_scope.get_variable(
+ 'c',
+ initializer=constant_op.constant(10, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = {'probabilities': math_ops.multiply(features, c)}
+ loss = losses.absolute_difference(
+ labels=labels,
+ predictions=predictions['probabilities'],
+ reduction=losses.Reduction.SUM)
+
+ metrics = {
+ 'accuracy': metrics_lib.accuracy(labels, predictions['probabilities']),
+ 'auc': metrics_lib.auc(labels, predictions['probabilities'])
+ }
+ tensor_string_repr = str(features)
+ classes = constant_op.constant(
+ re.search('(split_inputs/split:[0-9])', tensor_string_repr).group(1),
+ dtype=dtypes.string)
+
+ export_outputs = {
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ export_output.PredictOutput(predictions),
+ 'classification_output':
+ export_output.ClassificationOutput(predictions['probabilities'],
+ classes),
+ 'classification_scores':
+ export_output.ClassificationOutput(
+ scores=predictions['probabilities']),
+ 'classification_classes':
+ export_output.ClassificationOutput(classes=classes),
+ 'regression_output':
+ export_output.RegressionOutput(predictions['probabilities']),
+ }
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=math_ops.reduce_sum(loss),
+ eval_metric_ops=metrics,
+ predictions=predictions,
+ train_op=loss, # This train_op isn't actually used.
+ export_outputs=export_outputs)
+
+ def replicate_estimator_spec(self, session):
+ features = np.array([0.01, 0.002])
+ labels = np.array([0.01, 0.02])
+
+ replicated_model_fn = replicate_model_fn.replicate_model_fn(
+ self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
+ estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT,
+ features, labels, {})
+ session.run(variables.global_variables_initializer())
+ return estimator_spec
+
+ def test_merde_predict_output(self):
+ with self.test_session() as session:
+ estimator_spec = self.replicate_estimator_spec(session)
+ self.assertAllClose(
+ {
+ 'probabilities': np.array([0.1, 0.02])
+ },
+ session.run(estimator_spec.export_outputs[
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs))
+
+ def test_merge_classification_output_scores_classes(self):
+ with self.test_session() as session:
+ estimator_spec = self.replicate_estimator_spec(session)
+ self.assertAllClose(
+ [0.1, 0.02],
+ session.run(
+ estimator_spec.export_outputs['classification_output'].scores))
+ self.assertAllEqual(
+ [b'split_inputs/split:0', b'split_inputs/split:1'],
+ session.run(
+ estimator_spec.export_outputs['classification_output'].classes))
+
+ def test_merge_classification_output_scores(self):
+ with self.test_session() as session:
+ estimator_spec = self.replicate_estimator_spec(session)
+ self.assertAllClose(
+ [0.1, 0.02],
+ session.run(
+ estimator_spec.export_outputs['classification_scores'].scores))
+ self.assertEqual(
+ None, estimator_spec.export_outputs['classification_scores'].classes)
+
+ def test_merge_classification_output_classes(self):
+ with self.test_session() as session:
+ estimator_spec = self.replicate_estimator_spec(session)
+ self.assertAllEqual(
+ [b'split_inputs/split:0', b'split_inputs/split:1'],
+ session.run(
+ estimator_spec.export_outputs['classification_classes'].classes))
+ self.assertEqual(
+ None, estimator_spec.export_outputs['classification_classes'].scores)
+
+ def test_merge_regression_output(self):
+ with self.test_session() as session:
+ estimator_spec = self.replicate_estimator_spec(session)
+ self.assertAllClose(
+ [0.1, 0.02],
+ session.run(estimator_spec.export_outputs['regression_output'].value))
+
+
+class GetLocalDevicesTest(test_util.TensorFlowTestCase):
+
+ def test_there_is_at_least_a_cpu(self):
+ self.assertTrue(replicate_model_fn._get_local_devices('CPU'))
+
+ def test_there_is_no_xpu(self):
+ self.assertFalse(
+ replicate_model_fn._get_local_devices('XPU')) # XPU doesn't exist.
+
+ def test_whether_there_is_a_gpu(self):
+ self.assertEqual(
+ len(replicate_model_fn._get_local_devices('GPU')),
+ test.is_gpu_available())
+
+
+class LocalDeviceSetterTest(test_util.TensorFlowTestCase):
+
+ def test_vars_are_on_ps_but_ops_are_on_workers(self):
+ local_device_setter = replicate_model_fn._local_device_setter(
+ ps_device='/device:GPU:3', worker_device='/device:GPU:2')
+
+ with ops_lib.device(local_device_setter):
+ c = variables.Variable(0.01)
+ self.assertEqual('/device:GPU:3', c.device)
+
+ cc = variables.Variable(0.02)
+ self.assertEqual('/device:GPU:3', cc.device)
+
+ ccc = variables.Variable(0.03)
+ self.assertEqual('/device:GPU:3', ccc.device)
+
+ c_op = array_ops.concat(c, axis=0)
+ self.assertEqual('/device:GPU:2', c_op.device)
+
+ cc_op = array_ops.concat(cc, axis=0)
+ self.assertEqual('/device:GPU:2', cc_op.device)
+
+
+class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
+
+ def test_example(self):
+ with self.test_session() as session:
+ total = replicate_model_fn._compute_sum_on_device(
+ [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum')
+
+ self.assertEqual('/device:GPU:0', total.device)
+ self.assertEqual('test_sum', total.op.name)
+ self.assertEqual(10.0, session.run(total))
+
+
+class ConcatTensorDictsTest(test_util.TensorFlowTestCase):
+
+ def test_example(self):
+ tensor_dicts = [
+ {
+ 'a': np.array([1.0, 2.0]),
+ 'b': np.array([11.0]),
+ 'c': np.array([21.0]),
+ },
+ {
+ 'a': np.array([3.0]),
+ 'b': np.array([12.0, 13.0]),
+ },
+ {
+ 'b': np.array([14.0]),
+ },
+ ]
+
+ with self.test_session() as session:
+ self.assertAllClose({
+ 'a': np.array([1.0, 2.0, 3.0]),
+ 'b': np.array([11.0, 12.0, 13.0, 14.0]),
+ 'c': np.array([21.0]),
+ }, session.run(replicate_model_fn._concat_tensor_dicts(*tensor_dicts)))
+
+
+if __name__ == '__main__':
+ test.main()