author     Anjali Sridhar <anjalisridhar@google.com>        2018-10-02 14:30:22 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-02 14:35:06 -0700
commit     c921e45bccac86ce0becc71cedc3da2c702d5c38
tree       0a460ab691dd66600bdfee5ecfd68c0666bb7095 /tensorflow/python/keras
parent     e45c90f0e4d17ac22048a73f1e81bd9c7a7a5145
Add support for multiple input/output numpy arrays when using Keras APIs.
PiperOrigin-RevId: 215459075
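In practice this means `fit`/`evaluate`/`predict` can take a list of numpy arrays (one per model input/output) when a `DistributionStrategy` is set, instead of requiring a `tf.data.Dataset`. A minimal usage sketch of the intended call pattern; the two-input model, the shapes, and the contrib-era `MirroredStrategy`/`compile(distribute=...)` API are assumptions for illustration, not taken from the commit:

import numpy as np
import tensorflow as tf

# Illustrative only: model layout and the contrib-era distribute API are assumptions.
strategy = tf.contrib.distribute.MirroredStrategy()

inp_a = tf.keras.layers.Input(shape=(16,))
inp_b = tf.keras.layers.Input(shape=(8,))
merged = tf.keras.layers.concatenate([inp_a, inp_b])
out = tf.keras.layers.Dense(1)(merged)
model = tf.keras.models.Model(inputs=[inp_a, inp_b], outputs=out)

model.compile(optimizer=tf.train.GradientDescentOptimizer(0.01),
              loss='mse',
              distribute=strategy)

# Multiple numpy arrays can now be passed as a list, one array per input.
x1 = np.random.random((256, 16)).astype(np.float32)
x2 = np.random.random((256, 8)).astype(np.float32)
y = np.random.random((256, 1)).astype(np.float32)
model.fit([x1, x2], y, epochs=2, batch_size=32)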
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--  tensorflow/python/keras/engine/distributed_training_utils.py  134
-rw-r--r--  tensorflow/python/keras/engine/training.py                      48
-rw-r--r--  tensorflow/python/keras/engine/training_distributed.py          30
-rw-r--r--  tensorflow/python/keras/models.py                                5
4 files changed, 162 insertions(+), 55 deletions(-)
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index 39341a931b..050602868a 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -17,12 +17,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python.client import session as session_module
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest
@@ -304,23 +310,19 @@ def validate_inputs(x, y, distribution_strategy):
compiled.
Raises:
- ValueError: if input is not a Dataset or a numpy array.
+ ValueError: if input is not a Dataset or a numpy array (when we use
+ MirroredStrategy).
"""
- if isinstance(x, list) or isinstance(y, list):
- raise ValueError('DistributionStrategy does not support lists of numpy'
- 'arrays. You must pass a Dataset object or a numpy array '
- 'as input.')
-
if isinstance(x, dict) or isinstance(y, dict):
- raise ValueError('DistributionStrategy does not support inputs of type '
- 'dict. You must pass a Dataset object or a numpy array as '
- 'input.')
+ raise ValueError('`DistributionStrategy` does not support inputs of type '
+ 'dict. You must pass a `tf.data.Dataset` object or a '
+ 'numpy array as input.')
- if isinstance(x, iterator_ops.Iterator) or \
- isinstance(y, iterator_ops.Iterator):
- raise ValueError('DistributionStrategy does not support inputs of type '
- 'Iterator. You must pass a Dataset object or a numpy '
- 'array as input.')
+ if (isinstance(x, iterator_ops.Iterator) or
+ isinstance(y, iterator_ops.Iterator)):
+ raise ValueError('`DistributionStrategy` does not support inputs of type '
+ 'Iterator. You must pass a `tf.data.Dataset` object or a '
+ 'numpy array as input.')
if distribution_strategy.__class__.__name__ == 'TPUStrategy':
for i in [x, y]:
@@ -334,14 +336,14 @@ def validate_inputs(x, y, distribution_strategy):
'Found unknown shape {} in input {}.'.format(s, i))
-def get_input_batch_params(first_x_value, batch_size, current_strategy):
+def get_input_batch_params(first_x_value, batch_size, distribution_strategy):
"""Calculate the number of batches and steps/steps_per_epoch.
Args:
first_x_value: This is the first input numpy array that is passed in as the
model input.
batch_size: The specified batch_size or the default batch_size of 32.
- current_strategy: The current DistributionStrategy used to compile the
+ distribution_strategy: The current DistributionStrategy used to compile the
model.
Returns:
@@ -359,14 +361,14 @@ def get_input_batch_params(first_x_value, batch_size, current_strategy):
# TODO(anjalisridhar): TPU currently supports using the num_towers property.
# We might want to look into implementing worker_devices. In multi worker
# strategy, perhaps num_towers works better?
- steps = num_batches // current_strategy.num_towers
+ steps = num_batches // distribution_strategy.num_towers
if not steps:
# TODO(anjalisridhar): Number of towers in the error message may not convey
# what we want to the user. Is there another terminology that we can use
# that is consistent across different strategies.
raise ValueError('The number of batches %d is smaller than the number '
'of towers %d used for DistributionStrategy. ' %
- num_batches, current_strategy.num_towers)
+ (num_batches, distribution_strategy.num_towers))
return steps
@@ -376,3 +378,99 @@ def get_batch_dimension(iterator):
# all.
dims = shapes[0].dims
return dims[0] if dims else None
+
+
+def get_cpu_device(distribution_strategy):
+ """Returns the CPU device of the TPU host or the default CPU device string.
+
+ Args:
+ distribution_strategy: The DistributionStrategy used to compile the model.
+
+ Returns:
+ A device string which is the TPU host's CPU device in case of
+ TPUDistributionStrategy or the default CPU device string in all other
+ cases.
+
+ Raises:
+ NotImplementedError: We currently don't support copying numpy data to
+ multiple hosts in the case of Cloud TPU pods.
+ """
+ if distribution_strategy.__class__.__name__ == 'TPUStrategy':
+ if distribution_strategy.num_hosts > 1:
+ raise NotImplementedError('TPUDistributionStrategy does not '
+ 'support numpy inputs when running on Cloud '
+ 'TPU pods.')
+ return distribution_strategy.get_host_cpu_device(0)
+ else:
+ # For all strategies except TPUDistributionStrategy
+ # TODO(anjalisridhar): We may need to modify this when we add support for
+ # multi-worker strategy.
+ return '/CPU:0'
+
+
+def get_var_for_numpy(distribution_strategy, x):
+ if isinstance(x, list):
+ var_x = tuple([_get_var_for_numpy(distribution_strategy, single_input)
+ for single_input in x])
+ else:
+ var_x = _get_var_for_numpy(distribution_strategy, x)
+ return var_x
+
+
+def _get_var_for_numpy(distribution_strategy, input_array):
+ """Creates a variable and assigns the value of the numpy array to it.
+
+ Args:
+ distribution_strategy: The DistributionStrategy used to compile the model.
+ input_array: The input numpy array whose value will be assigned to the
+ variable we create.
+
+ Returns:
+ The variable to which we will copy the value of the input numpy array.
+
+ """
+ with ops.device(get_cpu_device(distribution_strategy)):
+ # Create and initialize a variable on the CPU device. This is the CPU
+ # device of the host in the case of TPUDistributionStrategy.
+ input_var = variables.VariableV1(array_ops.zeros(input_array.shape,
+ input_array.dtype),
+ trainable=False, use_resource=True)
+ K.get_session().run(input_var.initializer)
+
+ # Create a placeholder for the numpy array input slices. We copy the value
+ # of the input numpy array to the variable in slices of size 64 MB to avoid
+ # running into memory issues or RPC message limits.
+ start_placeholder = array_ops.placeholder(dtypes.int64, ())
+ end_placeholder = array_ops.placeholder(dtypes.int64, ())
+ slice_placeholder = array_ops.placeholder(input_var.dtype)
+ assign_slice_op = input_var[start_placeholder:end_placeholder].assign(
+ slice_placeholder)
+
+ # If each batch element is > 64 MB, then we copy each batch element
+ # individually. Otherwise, the slices will be < 128 MB. There might be padding
+ # which might mean that the slices are 128 MB even if the size of the
+ # tensor allocated is less than 128 MB.
+ # This formula gives slices with size:
+ # ceil(64 MB / byte size per batch element) bytes.
+ # Using ceil() guarantees we get a number >= 1.
+
+ # Calculate the size of each batch element.
+ byte_size_per_batch_element = np.prod(input_array.shape[1:]) * \
+ input_var.dtype.size
+
+ # Calculate number of elements we want to copy per slice.
+ batch_size_per_slice = np.ceil((64 << 20) / byte_size_per_batch_element)
+
+ # Copy slices of the above size starting at 0, except the last slice will be
+ # smaller.
+ start = 0
+ limit = input_array.shape[0]
+ while start < limit:
+ end = min(start + batch_size_per_slice, limit)
+ K.get_session().run(assign_slice_op, feed_dict={
+ start_placeholder: start,
+ end_placeholder: end,
+ slice_placeholder: input_array[start:end]})
+ start = end
+
+ return input_var
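The slice-copy loop in `_get_var_for_numpy` above feeds the numpy data into the variable in chunks of roughly 64 MB to stay clear of RPC message limits. A standalone sketch of the slice-size arithmetic, assuming an illustrative float32 input of shape (50000, 32, 32, 3):

import numpy as np

input_array = np.zeros((50000, 32, 32, 3), dtype=np.float32)  # illustrative shape

# Bytes per batch element, i.e. everything except the leading (batch) dimension.
byte_size_per_batch_element = np.prod(input_array.shape[1:]) * input_array.dtype.itemsize
# 32 * 32 * 3 * 4 = 12288 bytes

# Batch elements per ~64 MB slice; ceil() guarantees at least one element per feed.
batch_size_per_slice = int(np.ceil((64 << 20) / byte_size_per_batch_element))
# ceil(67108864 / 12288) = 5462 elements, so 50000 rows are copied in 10 feeds.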
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 5091cac836..c842b8192e 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -20,11 +20,9 @@ from __future__ import print_function
import weakref
import numpy as np
-import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -814,19 +812,21 @@ class Model(Network):
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray):
x_shape = first_x_value.shape
- x_dtype = first_x_value.dtype
if batch_size is None:
batch_size = x_shape[0] // steps
if y is not None:
- first_y_value = nest.flatten(y)[0]
- x = Dataset.from_generator(lambda x=x, y=y: six.moves.zip(x, y),
- output_types=(x_dtype, first_y_value.dtype),
- output_shapes=(x_shape[1:],
- first_y_value.shape[1:]))
+ var_x = distributed_training_utils.get_var_for_numpy(
+ self._distribution_strategy, x)
+ var_y = distributed_training_utils.get_var_for_numpy(
+ self._distribution_strategy, y)
+
+ x = dataset_ops.Dataset.from_tensor_slices((var_x, var_y))
# TODO(anjalisridhar): What should the buffer size be?
x = x.shuffle(10000)
x = x.repeat()
- x = x.batch(batch_size)
+ # We need to use the drop_remainder argument to allow for a static
+ # input shape which is required for TPUs.
+ x = x.batch(batch_size, drop_remainder=True)
y = None
else:
# This case is for the predict call where the dataset only contains
@@ -834,11 +834,13 @@ class Model(Network):
# TODO(anjalisridhar): Raise an error if we are not able to process
# all the predict samples. This can happen if the number of batches is
# not evenly divisible by the number of worker devices.
- x = Dataset.from_generator(lambda x=x: x,
- output_types=x_dtype,
- output_shapes=x_shape[1:])
+ var_x = distributed_training_utils.get_var_for_numpy(
+ self._distribution_strategy, x)
+ x = dataset_ops.Dataset.from_tensor_slices(var_x)
x = x.repeat()
- x = x.batch(batch_size)
+ # We need to use the drop_remainder argument to allow for a static
+ # input shape which is required for TPUs.
+ x = x.batch(batch_size, drop_remainder=True)
# TODO(anjalisridhar): Can we use the iterator and getnext op cache?
# We require users to pass Datasets since we distribute the dataset across
@@ -978,16 +980,18 @@ class Model(Network):
'Make sure that your dataset can generate '
'required number of samples.')
- if (not isinstance(next_element, (list, tuple)) or
- len(next_element) not in [2, 3]):
- raise ValueError(
- 'Please provide model inputs as a list or tuple of 2 or 3'
- 'elements: (input, target) or (input, target, sample_weights)'
- 'Received %s' % next_element)
- if len(next_element) == 2:
- x, y = next_element
+ if isinstance(next_element, (list, tuple)):
+ if len(next_element) not in [2, 3]:
+ raise ValueError(
+ 'Please provide model inputs as a list or tuple of 2 or 3 '
+ 'elements: (input, target) or (input, target, sample_weights). '
+ 'Received %s' % next_element)
+ if len(next_element) == 2:
+ x, y = next_element
+ else:
+ x, y, sample_weight = next_element
else:
- x, y, sample_weight = next_element
+ x = next_element
x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
class_weight, batch_size)
return x, y, sample_weights
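In the `fit` and `predict` paths above, `drop_remainder=True` is what gives the batched dataset a fully known static shape, which the TPU code path requires. A small standalone sketch of the difference, assuming the TF 1.x `output_shapes` property:

import tensorflow as tf

ds = tf.data.Dataset.range(10).batch(3)
print(ds.output_shapes)         # (?,) - batch dimension is unknown (last batch has 1 element)

ds_static = tf.data.Dataset.range(10).batch(3, drop_remainder=True)
print(ds_static.output_shapes)  # (3,) - batch dimension is static; the partial batch is dropped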
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index a6470458d2..04e8d079c0 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.util import nest
# TODO(priyag, sourabhbajaj): Refactor this file to address code duplication.
@@ -296,15 +297,16 @@ def _experimental_fit_loop(
initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
if steps_per_epoch is None:
- raise ValueError('steps_per_epoch should be specified in the fit call.')
- steps_per_run_var = K.variable(
+ raise ValueError('`steps_per_epoch` should be specified when calling '
+ '`fit` on the model.')
+ steps_per_run = K.variable(
value=min(steps_per_epoch, current_strategy.steps_per_run),
dtype='int32',
- name='steps_per_run_var')
+ name='steps_per_run')
with current_strategy.scope():
ctx = current_strategy.run_steps_on_dataset(
- step_fn, iterator, iterations=steps_per_run_var,
+ step_fn, iterator, iterations=steps_per_run,
initial_loop_values=initial_loop_values)
train_op = ctx.run_op
@@ -344,7 +346,7 @@ def _experimental_fit_loop(
batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
callbacks.on_batch_begin(step_index, batch_logs)
if prev_step_count is None or step_count != prev_step_count:
- steps_per_run_var.load(step_count, K.get_session())
+ steps_per_run.load(step_count, K.get_session())
prev_step_count = step_count
try:
_, outputs = K.get_session().run([train_op, output_tensors])
@@ -720,13 +722,9 @@ def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
model.predict_function.updates_op,
model.predict_function.session_kwargs)
- def step_fn(ctx, inputs, targets):
+ def step_fn(ctx, *inputs):
"""Clones the model and calls make_predict_function."""
- # TODO(anjalisridhar): Support predict input correctly as it will not
- # contain targets, only inputs.
- del targets
-
# TODO(priyag, sourabhbajaj): The model gets cloned every time
# fit/test/predict is called. We should look into caching this keyed on
# input shapes.
@@ -824,9 +822,10 @@ def _clone_and_build_model(model, inputs=None, targets=None):
# TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
# single tensor should be OK but it throws an error in that case.
- if (targets is not None and not isinstance(targets, list) and
- not isinstance(targets, dict)):
+ if targets is not None and not isinstance(targets, (list, dict, tuple)):
targets = [targets]
+ if isinstance(targets, tuple):
+ targets = nest.flatten(targets)
cloned_model.compile(
optimizer,
model.loss,
@@ -891,11 +890,12 @@ def _get_input_from_iterator(iterator, model):
"""Get elements from the iterator and verify the input shape and type."""
next_element = iterator.get_next()
- if isinstance(next_element, tuple):
- x, y = next_element
- else:
+ if len(nest.flatten(next_element)) == len(model.inputs):
x = next_element
y = None
+ else:
+ x, y = next_element
+
# Validate that all the elements in x and y are of the same type and shape.
# We can then pass the first element of x and y to `_standardize_weights`
# below and be confident of the output.
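The rewritten `_get_input_from_iterator` above distinguishes an inputs-only dataset (the `predict` case) from an (inputs, targets) dataset by comparing the number of flattened elements with `len(model.inputs)`. A rough illustration of that check, using the internal `nest` utility the diff itself imports; the array shapes and the two-input model are made up:

import numpy as np
from tensorflow.python.util import nest

model_num_inputs = 2                          # stands in for len(model.inputs)

x1, x2 = np.zeros((4, 16)), np.zeros((4, 8))  # two input batches
y = np.zeros((4, 1))                          # one target batch

predict_element = (x1, x2)                    # element of an inputs-only dataset
fit_element = ((x1, x2), y)                   # element of an (inputs, targets) dataset

print(len(nest.flatten(predict_element)))     # 2 == model_num_inputs -> y is set to None
print(len(nest.flatten(fit_element)))         # 3 != model_num_inputs -> unpacked as x, y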
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index b04b4df257..2883c9ad74 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -96,6 +96,8 @@ def _clone_functional_model(model, input_tensors=None):
else:
# Make sure that all input tensors come from a Keras layer.
# If tensor comes from an input layer: cache the input layer.
+ if isinstance(input_tensors, tuple):
+ input_tensors = list(input_tensors)
input_tensors = generic_utils.to_list(input_tensors)
input_tensors_ = []
for i, x in enumerate(input_tensors):
@@ -212,6 +214,9 @@ def _clone_sequential_model(model, input_tensors=None):
raise ValueError('To clone a `Sequential` model, we expect '
' at most one tensor '
'as part of `input_tensors`.')
+
+ if isinstance(input_tensors, tuple):
+ input_tensors = list(input_tensors)
x = generic_utils.to_list(input_tensors)[0]
if K.is_keras_tensor(x):
origin_layer = x._keras_history[0]
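The tuple-to-list conversions added to `models.py` matter because `generic_utils.to_list` only unwraps lists, so a tuple of input tensors would otherwise be kept as a single element. A simplified stand-in for that helper (its real implementation may differ) shows the effect:

def to_list(x):                  # simplified stand-in for generic_utils.to_list
    return x if isinstance(x, list) else [x]

input_tensors = ('tensor_a', 'tensor_b')  # stand-ins for two Keras input tensors

print(to_list(input_tensors))          # [('tensor_a', 'tensor_b')] -> one element, wrong
print(to_list(list(input_tensors)))    # ['tensor_a', 'tensor_b']   -> two elements, as intended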