aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/training_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/engine/training_utils.py')
-rw-r--r--tensorflow/python/keras/engine/training_utils.py138
1 files changed, 136 insertions, 2 deletions
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 728a2b493b..dbbc87daf9 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -19,9 +19,11 @@ from __future__ import division
from __future__ import print_function
import copy
+import math
import numpy as np
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_util
@@ -31,6 +33,135 @@ from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.ops import math_ops
+def _map_nested(data, func):
+ """Maps each nested element using func."""
+ if isinstance(data, list):
+ return [_map_nested(nested_data, func) for nested_data in data]
+ elif isinstance(data, tuple):
+ return tuple(_map_nested(nested_data, func) for nested_data in data)
+ elif isinstance(data, dict):
+ return {
+ k: _map_nested(nested_data, func) for k, nested_data in data.items()
+ }
+ else:
+ return func(data)
+
+
+def _nested_all(data, cond_func):
+ """Checks if all elements in a nested structure satisfy cond_func."""
+ if isinstance(data, (tuple, list)):
+ return all([_nested_all(nested_data, cond_func) for nested_data in data])
+ elif isinstance(data, dict):
+ return all(
+ [_nested_all(nested_data, cond_func) for nested_data in data.values()])
+ else:
+ return cond_func(data)
+
+
+def _nested_any(data, cond_func):
+ """Checks if any nested_elements in a nested structure satisfy cond_func."""
+ if isinstance(data, (tuple, list)):
+ return any([_nested_any(nested_data, cond_func) for nested_data in data])
+ elif isinstance(data, dict):
+ return any(
+ [_nested_any(nested_data, cond_func) for nested_data in data.values()])
+ else:
+ return cond_func(data)
+
+
+def _convert_lists_to_tuples(data):
+ """Converts all lists to tuples, since Datasets expect tuples."""
+ if isinstance(data, (tuple, list)):
+ return tuple(_convert_lists_to_tuples(nested_data) for nested_data in data)
+ elif isinstance(data, dict):
+ return {
+ k: _convert_lists_to_tuples(nested_data)
+ for k, nested_data in data.items()
+ }
+ else:
+ return data
+
+
+def _get_batch_axis_size(data):
+ """Returns batch axis shape for nested data."""
+ if isinstance(data, (tuple, list)):
+ return _get_batch_axis_size(data[0])
+ elif isinstance(data, dict):
+ return _get_batch_axis_size(list(data.values()))
+ else:
+ return int(data.shape[0])
+
+
def convert_to_iterator(x=None,
                        y=None,
                        sample_weights=None,
                        batch_size=None,
                        steps_per_epoch=None,
                        epochs=1,
                        shuffle=False):
  """Converts NumPy arrays or EagerTensors to an EagerIterator.

  Combines all provided data into a single EagerIterator. If `x` is already
  an `EagerIterator` it is returned unchanged, together with
  `steps_per_epoch` exactly as passed in.

  Arguments:
      x: NumPy array or EagerTensor, or list of Numpy arrays or EagerTensors
        representing inputs to a model.
      y: Optional. NumPy array or EagerTensor, or list of Numpy arrays or
        EagerTensors representing targets of a model.
      sample_weights: Optional NumPy array or EagerTensor representing sample
        weights.
      batch_size: Used to batch data and calculate how many steps EagerIterator
        should take per epoch.
      steps_per_epoch: If provided, how many steps EagerIterator should take per
        epoch.
      epochs: Epochs to repeat iterator for.
      shuffle: Whether to shuffle the dataset. Shuffling is applied after
        batching and before repeating.

  Raises:
      ValueError: if steps_per_epoch cannot be calculated from the data
        provided, i.e. neither `steps_per_epoch` nor `batch_size` is given.

  Returns:
      (Iterator, steps_per_epoch).

  """
  if isinstance(x, iterator_ops.EagerIterator):
    return x, steps_per_epoch

  def _contains_none(structure):
    # True when `structure` is None or has a None leaf anywhere inside it.
    return _nested_any(structure, lambda element: element is None)

  if not _contains_none(sample_weights):
    data = (x, y, sample_weights)
  elif not _contains_none(y):
    data = (x, y)
  else:
    # always wrap in a tuple, so we know y, sample_weights weren't set
    # even when x has multiple elements
    data = (x,)

  data = _convert_lists_to_tuples(data)
  if steps_per_epoch is None and batch_size is not None:
    num_samples = _get_batch_axis_size(data)
    # True division (this file imports `division` from __future__), rounded
    # up so that a final partial batch still counts as one step.
    steps_per_epoch = int(math.ceil(num_samples / batch_size))

  if steps_per_epoch is None:
    # Explicit spaces between the sentences: adjacent string literals are
    # concatenated verbatim.
    raise ValueError('Could not determine steps_per_epoch. '
                     'Please provide either batch_size or '
                     'steps_per_epoch.')

  # TODO(omalleyt) for NumPy arrays in graph mode
  # placeholder ops should be used
  # this is only ideal for eager mode
  dataset = dataset_ops.Dataset.from_tensor_slices(data)

  if batch_size is not None:
    dataset = dataset.batch(batch_size)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.repeat(epochs)
  iterator = dataset.make_one_shot_iterator()

  return iterator, steps_per_epoch
+
+
def check_num_samples(ins,
batch_size=None,
steps=None,
@@ -128,8 +259,8 @@ def standardize_input_data(data,
except KeyError as e:
raise ValueError('No data provided for "' + e.args[0] + '". Need data '
'for each key in: ' + str(names))
- elif isinstance(data, list):
- if isinstance(data[0], list):
+ elif isinstance(data, (list, tuple)):
+ if isinstance(data[0], (list, tuple)):
data = [np.asarray(d) for d in data]
elif len(names) == 1 and isinstance(data[0], (float, int)):
data = [np.asarray(data)]
@@ -482,6 +613,9 @@ def standardize_weights(y,
Raises:
ValueError: In case of invalid user-provided arguments.
"""
+ # Iterator may return sample_weight as 1-tuple
+ if isinstance(sample_weight, tuple):
+ sample_weight = sample_weight[0]
if sample_weight_mode is not None:
if sample_weight_mode != 'temporal':
raise ValueError('"sample_weight_mode '