aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py43
1 files changed, 16 insertions, 27 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index 0848c5f62f..7005a647db 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -34,12 +34,10 @@ from tensorflow.python.estimator import util
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients as gradients_lib
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
@@ -185,17 +183,10 @@ def _split_batch(features, labels, number_of_shards, device):
"""Split input features and labes into batches."""
def split_dictionary(dictionary):
- """Split a dictionary into shards."""
shards = [{} for _ in range(number_of_shards)]
for name, tensor in six.iteritems(dictionary):
- if isinstance(tensor, sparse_tensor.SparseTensor):
- for i, shard in enumerate(
- sparse_ops.sparse_split(
- sp_input=tensor, num_split=number_of_shards, axis=0)):
- shards[i][name] = shard
- else:
- for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
- shards[i][name] = shard
+ for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
+ shards[i][name] = shard
return shards
with ops_lib.name_scope('split_inputs'):
@@ -322,17 +313,7 @@ def _call_optimizer_fn(optimizer_fn, params):
def _compute_sum_on_device(values, device, name=None):
with ops_lib.device(device):
- if isinstance(values[0], ops_lib.IndexedSlices):
- if name:
- raise ValueError('The name {} is not expected to be given to '
- 'IndexedSlices {}'.format(name, values))
-
- values_concat = array_ops.concat([v.values for v in values], axis=0)
- indices_concat = array_ops.concat([v.indices for v in values], axis=0)
- return ops_lib.IndexedSlices(values_concat, indices_concat,
- values[0].dense_shape)
- else:
- return math_ops.add_n(values, name=name)
+ return math_ops.add_n(values, name=name)
def _train_spec(tower_specs,
@@ -357,17 +338,25 @@ def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'):
[spec.loss for spec in tower_specs], aggregation_device,
aggregated_loss_name)
- update_ops = []
+ eval_metric_ops_lists = {}
for tower_spec in tower_specs:
- for name, (_, update_op) in six.iteritems(tower_spec.eval_metric_ops):
+ metrics = tower_spec.eval_metric_ops or {}
+ for name, (_, update_op) in six.iteritems(metrics):
+ update_ops = eval_metric_ops_lists.setdefault(name, ([]))
update_ops.append(update_op)
- with ops_lib.control_dependencies(update_ops):
- reduced_update_op = _reduce_metric_variables(len(tower_specs))
-
eval_metric_ops = {}
for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops):
+ with ops_lib.control_dependencies(eval_metric_ops_lists[name]):
+ # This operation reduces local variables across all metrics, yet is
+ # called for every metric. This is redundant and it's done because
+ # it is hard to know what local variables correspond to what metric.
+ # Estimator is going to execute all `reduced_update_op`s as part of
+ # a group inside a single `Session.run()` call, which will avoid duplicate
+ # computation.
+ reduced_update_op = _reduce_metric_variables(len(tower_specs))
eval_metric_ops[name] = (metric_tensor, reduced_update_op)
+
estimator_spec['eval_metric_ops'] = eval_metric_ops
return model_fn_lib.EstimatorSpec(**estimator_spec)