diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-10-02 14:30:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 14:35:06 -0700 |
commit | c921e45bccac86ce0becc71cedc3da2c702d5c38 (patch) | |
tree | 0a460ab691dd66600bdfee5ecfd68c0666bb7095 /tensorflow/python/keras | |
parent | e45c90f0e4d17ac22048a73f1e81bd9c7a7a5145 (diff) |
Add support for multiple input/output numpy arrays when using Keras APIs.
PiperOrigin-RevId: 215459075
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/engine/distributed_training_utils.py | 134 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 48 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_distributed.py | 30 | ||||
-rw-r--r-- | tensorflow/python/keras/models.py | 5 |
4 files changed, 162 insertions, 55 deletions
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py index 39341a931b..050602868a 100644 --- a/tensorflow/python/keras/engine/distributed_training_utils.py +++ b/tensorflow/python/keras/engine/distributed_training_utils.py @@ -17,12 +17,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.client import session as session_module from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest @@ -304,23 +310,19 @@ def validate_inputs(x, y, distribution_strategy): compiled. Raises: - ValueError: if input is not a Dataset or a numpy array. + ValueError: if input is not a Dataset or a numpy array(when we use + MirroredStrategy). """ - if isinstance(x, list) or isinstance(y, list): - raise ValueError('DistributionStrategy does not support lists of numpy' - 'arrays. You must pass a Dataset object or a numpy array ' - 'as input.') - if isinstance(x, dict) or isinstance(y, dict): - raise ValueError('DistributionStrategy does not support inputs of type ' - 'dict. You must pass a Dataset object or a numpy array as ' - 'input.') + raise ValueError('`DistributionStrategy` does not support inputs of type ' + 'dict. You must pass a `tf.data.Dataset` object or a ' + 'numpy array as input.') - if isinstance(x, iterator_ops.Iterator) or \ - isinstance(y, iterator_ops.Iterator): - raise ValueError('DistributionStrategy does not support inputs of type ' - 'Iterator. You must pass a Dataset object or a numpy ' - 'array as input.') + if (isinstance(x, iterator_ops.Iterator) or + isinstance(y, iterator_ops.Iterator)): + raise ValueError('`DistributionStrategy` does not support inputs of type ' + 'Iterator. You must pass a `tf.data.Dataset` object or a ' + 'numpy array as input.') if distribution_strategy.__class__.__name__ == 'TPUStrategy': for i in [x, y]: @@ -334,14 +336,14 @@ def validate_inputs(x, y, distribution_strategy): 'Found unknown shape {} in input {}.'.format(s, i)) -def get_input_batch_params(first_x_value, batch_size, current_strategy): +def get_input_batch_params(first_x_value, batch_size, distribution_strategy): """Calculate the number of batches and steps/steps_per_epoch. Args: first_x_value: This is the first input numpy array that is passed in as the model input. batch_size: The specified batch_size or the default batch_size of 32. - current_strategy: The current DistributionStrategy used to compile the + distribution_strategy: The current DistributionStrategy used to compile the model. Returns: @@ -359,14 +361,14 @@ def get_input_batch_params(first_x_value, batch_size, current_strategy): # TODO(anjalisridhar): TPU currently supports using the num_towers property. # We might want to look into implementing worker_devices. In multi worker # strategy, perhaps num_towers works better? - steps = num_batches // current_strategy.num_towers + steps = num_batches // distribution_strategy.num_towers if not steps: # TODO(anjalisridhar): Number of towers in the error message may not convey # what we want to the user. Is there another terminology that we can use # that is consistent across different strategies. raise ValueError('The number of batches %d is smaller than the number ' 'of towers %d used for DistributionStrategy. ' % - num_batches, current_strategy.num_towers) + (num_batches, distribution_strategy.num_towers)) return steps @@ -376,3 +378,99 @@ def get_batch_dimension(iterator): # all. dims = shapes[0].dims return dims[0] if dims else None + + +def get_cpu_device(distribution_strategy): + """Returns the CPU device of the TPU host or the default CPU device string. + + Args: + distribution_strategy: The DistributionStrategy used to compile the model. + + Returns: + A device string which is the TPU host's CPU device in case of + TPUDistributionStrategy or the default CPU device string in all other + cases. + + Raises: + NotImplementedError: We currently don't support copying numpy data to + multiple hosts in the case of Cloud TPU pods. + """ + if distribution_strategy.__class__.__name__ == 'TPUStrategy': + if distribution_strategy.num_hosts > 1: + raise NotImplementedError('TPUDistributionStrategy does not ' + 'support numpy inputs when running on Cloud' + 'TPU pods.') + return distribution_strategy.get_host_cpu_device(0) + else: + # For all strategies except TPUDistributionStrategy + # TODO(anjalisridhar): We may need to modify this when we add support for + # multi-worker strategy. + return '/CPU:0' + + +def get_var_for_numpy(distribution_strategy, x): + if isinstance(x, list): + var_x = tuple([_get_var_for_numpy(distribution_strategy, single_input) + for single_input in x]) + else: + var_x = _get_var_for_numpy(distribution_strategy, x) + return var_x + + +def _get_var_for_numpy(distribution_strategy, input_array): + """Creates a variable and assigns the value of the numpy array to it. + + Args: + distribution_strategy: The DistributionStrategy used to compile the model. + input_array: The input numpy array whose value will be assigned to the + variable we create. + + Returns: + The variable to which we will copy the value of the input numpy array. + + """ + with ops.device(get_cpu_device(distribution_strategy)): + # Create and initialize a variable on the CPU device. This is the CPU + # device of the host in the case of TPUDistributionStrategy. + input_var = variables.VariableV1(array_ops.zeros(input_array.shape, + input_array.dtype), + trainable=False, use_resource=True) + K.get_session().run(input_var.initializer) + + # Create a placeholder for the numpy array input slices. We copy the value + # of the input numpy array to the variable in slices of size 64 MB to avoid + # running into memory issues or RPC message limits. + start_placeholder = array_ops.placeholder(dtypes.int64, ()) + end_placeholder = array_ops.placeholder(dtypes.int64, ()) + slice_placeholder = array_ops.placeholder(input_var.dtype) + assign_slice_op = input_var[start_placeholder:end_placeholder].assign( + slice_placeholder) + + # If each batch element is > 64 MB, then we copy each batch element + # individually. Otherwise, the slices will be < 128 MB. There might be padding + # which might mean that the slices are 128 MB even if the size of the + # tensor allocated is less than 128 MB. + # This formula gives slices with size: + # ceil(64 MB / byte size per batch element) bytes. + # Using ceil() guarantees we get a number >= 1. + + # Calculate the size of each batch element. + byte_size_per_batch_element = np.prod(input_array.shape[1:]) * \ + input_var.dtype.size + + # Calculate number of elements we want to copy per slice. + batch_size_per_slice = np.ceil((64 << 20) / byte_size_per_batch_element) + + # Copy slices of the above size starting at 0, except the last slice will be + # smaller. + start = 0 + limit = input_array.shape[0] + while start < limit: + end = min(start + batch_size_per_slice, limit) + K.get_session().run(assign_slice_op, feed_dict={ + start_placeholder: start, + end_placeholder: end, + slice_placeholder: input_array[start:end]}) + start = end + + return input_var diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 5091cac836..c842b8192e 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -20,11 +20,9 @@ from __future__ import print_function import weakref import numpy as np -import six from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.data.ops.dataset_ops import Dataset from tensorflow.python.eager import context from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -814,19 +812,21 @@ class Model(Network): first_x_value = nest.flatten(x)[0] if isinstance(first_x_value, np.ndarray): x_shape = first_x_value.shape - x_dtype = first_x_value.dtype if batch_size is None: batch_size = x_shape[0] // steps if y is not None: - first_y_value = nest.flatten(y)[0] - x = Dataset.from_generator(lambda x=x, y=y: six.moves.zip(x, y), - output_types=(x_dtype, first_y_value.dtype), - output_shapes=(x_shape[1:], - first_y_value.shape[1:])) + var_x = distributed_training_utils.get_var_for_numpy( + self._distribution_strategy, x) + var_y = distributed_training_utils.get_var_for_numpy( + self._distribution_strategy, y) + + x = dataset_ops.Dataset.from_tensor_slices((var_x, var_y)) # TODO(anjalisridhar): What should the buffer size be? x = x.shuffle(10000) x = x.repeat() - x = x.batch(batch_size) + # We need to use the drop_remainder argument to allow for a static + # input shape which is required for TPUs. + x = x.batch(batch_size, drop_remainder=True) y = None else: # This case is for the predict call where the dataset only contains @@ -834,11 +834,13 @@ class Model(Network): # TODO(anjalisridhar): Raise an error if we are not able to process # all the predict samples. This can happen if the number of batches is # not evenly divisible by the number of worker devices. - x = Dataset.from_generator(lambda x=x: x, - output_types=x_dtype, - output_shapes=x_shape[1:]) + var_x = distributed_training_utils.get_var_for_numpy( + self._distribution_strategy, x) + x = dataset_ops.Dataset.from_tensor_slices(var_x) x = x.repeat() - x = x.batch(batch_size) + # We need to use the drop_remainder argument to allow for a static + # input shape which is required for TPUs. + x = x.batch(batch_size, drop_remainder=True) # TODO(anjalisridhar): Can we use the iterator and getnext op cache? # We require users to pass Datasets since we distribute the dataset across @@ -978,16 +980,18 @@ class Model(Network): 'Make sure that your dataset can generate ' 'required number of samples.') - if (not isinstance(next_element, (list, tuple)) or - len(next_element) not in [2, 3]): - raise ValueError( - 'Please provide model inputs as a list or tuple of 2 or 3' - 'elements: (input, target) or (input, target, sample_weights)' - 'Received %s' % next_element) - if len(next_element) == 2: - x, y = next_element + if isinstance(next_element, (list, tuple)): + if len(next_element) not in [2, 3]: + raise ValueError( + 'Please provide model inputs as a list or tuple of 2 or 3' + 'elements: (input, target) or (input, target, sample_weights)' + 'Received %s' % next_element) + if len(next_element) == 2: + x, y = next_element + else: + x, y, sample_weight = next_element else: - x, y, sample_weight = next_element + x = next_element x, y, sample_weights = self._standardize_weights(x, y, sample_weight, class_weight, batch_size) return x, y, sample_weights diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index a6470458d2..04e8d079c0 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import nest # TODO(priyag, sourabhbajaj): Refactor this file to address code duplication. @@ -296,15 +297,16 @@ def _experimental_fit_loop( initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype) if steps_per_epoch is None: - raise ValueError('steps_per_epoch should be specified in the fit call.') - steps_per_run_var = K.variable( + raise ValueError('`steps_per_epoch` should be specified when calling ' + '`fit` on the model.') + steps_per_run = K.variable( value=min(steps_per_epoch, current_strategy.steps_per_run), dtype='int32', - name='steps_per_run_var') + name='steps_per_run') with current_strategy.scope(): ctx = current_strategy.run_steps_on_dataset( - step_fn, iterator, iterations=steps_per_run_var, + step_fn, iterator, iterations=steps_per_run, initial_loop_values=initial_loop_values) train_op = ctx.run_op @@ -344,7 +346,7 @@ def _experimental_fit_loop( batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count} callbacks.on_batch_begin(step_index, batch_logs) if prev_step_count is None or step_count != prev_step_count: - steps_per_run_var.load(step_count, K.get_session()) + steps_per_run.load(step_count, K.get_session()) prev_step_count = step_count try: _, outputs = K.get_session().run([train_op, output_tensors]) @@ -720,13 +722,9 @@ def _experimental_predict_loop(model, iterator, verbose=0, steps=None): model.predict_function.updates_op, model.predict_function.session_kwargs) - def step_fn(ctx, inputs, targets): + def step_fn(ctx, *inputs): """Clones the model and calls make_predict_function.""" - # TODO(anjalisridhar): Support predict input correctly as it will not - # contain targets, only inputs. - del targets - # TODO(priyag, sourabhbajaj): The model gets cloned every time # fit/test/predict is called. We should look into caching this keyed on # input shapes. @@ -824,9 +822,10 @@ def _clone_and_build_model(model, inputs=None, targets=None): # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a # single tensor should be OK but it throws an error in that case. - if (targets is not None and not isinstance(targets, list) and - not isinstance(targets, dict)): + if targets is not None and not isinstance(targets, (list, dict, tuple)): targets = [targets] + if isinstance(targets, tuple): + targets = nest.flatten(targets) cloned_model.compile( optimizer, model.loss, @@ -891,11 +890,12 @@ def _get_input_from_iterator(iterator, model): """Get elements from the iterator and verify the input shape and type.""" next_element = iterator.get_next() - if isinstance(next_element, tuple): - x, y = next_element - else: + if len(nest.flatten(next_element)) == len(model.inputs): x = next_element y = None + else: + x, y = next_element + # Validate that all the elements in x and y are of the same type and shape. # We can then pass the first element of x and y to `_standardize_weights` # below and be confident of the output. diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index b04b4df257..2883c9ad74 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -96,6 +96,8 @@ def _clone_functional_model(model, input_tensors=None): else: # Make sure that all input tensors come from a Keras layer. # If tensor comes from an input layer: cache the input layer. + if isinstance(input_tensors, tuple): + input_tensors = list(input_tensors) input_tensors = generic_utils.to_list(input_tensors) input_tensors_ = [] for i, x in enumerate(input_tensors): @@ -212,6 +214,9 @@ def _clone_sequential_model(model, input_tensors=None): raise ValueError('To clone a `Sequential` model, we expect ' ' at most one tensor ' 'as part of `input_tensors`.') + + if isinstance(input_tensors, tuple): + input_tensors = list(input_tensors) x = generic_utils.to_list(input_tensors)[0] if K.is_keras_tensor(x): origin_layer = x._keras_history[0] |