diff options
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py | 43 |
1 files changed, 16 insertions, 27 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index 0848c5f62f..7005a647db 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -34,12 +34,10 @@ from tensorflow.python.estimator import util
 from tensorflow.python.estimator.export import export_output as export_output_lib
 from tensorflow.python.framework import device as framework_device
 from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients as gradients_lib
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
@@ -185,17 +183,10 @@ def _split_batch(features, labels, number_of_shards, device):
   """Split input features and labes into batches."""
 
   def split_dictionary(dictionary):
-    """Split a dictionary into shards."""
     shards = [{} for _ in range(number_of_shards)]
     for name, tensor in six.iteritems(dictionary):
-      if isinstance(tensor, sparse_tensor.SparseTensor):
-        for i, shard in enumerate(
-            sparse_ops.sparse_split(
-                sp_input=tensor, num_split=number_of_shards, axis=0)):
-          shards[i][name] = shard
-      else:
-        for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
-          shards[i][name] = shard
+      for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
+        shards[i][name] = shard
     return shards
 
   with ops_lib.name_scope('split_inputs'):
@@ -322,17 +313,7 @@ def _call_optimizer_fn(optimizer_fn, params):
 
 def _compute_sum_on_device(values, device, name=None):
   with ops_lib.device(device):
-    if isinstance(values[0], ops_lib.IndexedSlices):
-      if name:
-        raise ValueError('The name {} is not expected to be given to '
-                         'IndexedSlices {}'.format(name, values))
-
-      values_concat = array_ops.concat([v.values for v in values], axis=0)
-      indices_concat = array_ops.concat([v.indices for v in values], axis=0)
-      return ops_lib.IndexedSlices(values_concat, indices_concat,
-                                   values[0].dense_shape)
-    else:
-      return math_ops.add_n(values, name=name)
+    return math_ops.add_n(values, name=name)
 
 
 def _train_spec(tower_specs,
@@ -357,17 +338,25 @@ def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'):
       [spec.loss for spec in tower_specs], aggregation_device,
       aggregated_loss_name)
 
-  update_ops = []
+  eval_metric_ops_lists = {}
   for tower_spec in tower_specs:
-    for name, (_, update_op) in six.iteritems(tower_spec.eval_metric_ops):
+    metrics = tower_spec.eval_metric_ops or {}
+    for name, (_, update_op) in six.iteritems(metrics):
+      update_ops = eval_metric_ops_lists.setdefault(name, ([]))
       update_ops.append(update_op)
 
-  with ops_lib.control_dependencies(update_ops):
-    reduced_update_op = _reduce_metric_variables(len(tower_specs))
-
   eval_metric_ops = {}
   for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops):
+    with ops_lib.control_dependencies(eval_metric_ops_lists[name]):
+      # This operation reduces local variables across all metrics, yet is
+      # called for every metric. This is redundant and it's done because
+      # it is hard to know what local variables correspond to what metric.
+      # Estimator is going to execute all `reduced_update_op`s as part of
+      # a group inside a single `Session.run()` call, which will avoid duplicate
+      # computation.
+      reduced_update_op = _reduce_metric_variables(len(tower_specs))
     eval_metric_ops[name] = (metric_tensor, reduced_update_op)
+
   estimator_spec['eval_metric_ops'] = eval_metric_ops
   return model_fn_lib.EstimatorSpec(**estimator_spec)
 