Diffstat (limited to 'tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py')
-rw-r--r--  tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py  497
1 file changed, 329 insertions(+), 168 deletions(-)
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
index 2ce11e813f..f665ff7644 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
@@ -36,27 +36,49 @@ from tensorflow.python.platform import tf_logging as logging
# pylint: disable=g-multiple-import,g-bad-import-order
from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels
from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels
+
+
# pylint: enable=g-multiple-import,g-bad-import-order
def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None):
"""Returns shape for input and output of the data feeder."""
+ x_is_dict, y_is_dict = isinstance(x_shape, dict), y_shape is not None and isinstance(y_shape, dict)
+ if y_is_dict and n_classes is not None:
+ assert (isinstance(n_classes, dict))
+
if batch_size is None:
- batch_size = x_shape[0]
+ batch_size = list(x_shape.values())[0][0] if x_is_dict else x_shape[0]
elif batch_size <= 0:
raise ValueError('Invalid batch_size %d.' % batch_size)
- x_shape = list(x_shape[1:]) if len(x_shape) > 1 else [1]
- input_shape = [batch_size] + x_shape
+
+ if x_is_dict:
+ input_shape = {}
+ for k, v in list(x_shape.items()):
+ input_shape[k] = [batch_size] + (list(v[1:]) if len(v) > 1 else [1])
+ else:
+ x_shape = list(x_shape[1:]) if len(x_shape) > 1 else [1]
+ input_shape = [batch_size] + x_shape
+
if y_shape is None:
return input_shape, None, batch_size
- y_shape = list(y_shape[1:]) if len(y_shape) > 1 else []
- # Skip first dimension if it is 1.
- if y_shape and y_shape[0] == 1:
- y_shape = y_shape[1:]
- if n_classes is not None and n_classes > 1:
- output_shape = [batch_size] + y_shape + [n_classes]
+
+ def out_el_shape(out_shape, num_classes):
+ out_shape = list(out_shape[1:]) if len(out_shape) > 1 else []
+ # Skip first dimension if it is 1.
+ if out_shape and out_shape[0] == 1:
+ out_shape = out_shape[1:]
+ if num_classes is not None and num_classes > 1:
+ return [batch_size] + out_shape + [num_classes]
+ else:
+ return [batch_size] + out_shape
+
+ if not y_is_dict:
+ output_shape = out_el_shape(y_shape, n_classes)
else:
- output_shape = [batch_size] + y_shape
+ output_shape = dict([(k, out_el_shape(v, n_classes[k] if n_classes is not None and k in n_classes else None))
+ for k, v in list(y_shape.items())])
+
return input_shape, output_shape, batch_size
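A minimal sketch (illustrative, not part of the diff) of how the dict-aware logic above resolves shapes; the feature names are hypothetical:

    # 100 samples, two feature groups, one 10-class label.
    x_shape = {'age': (100, 1), 'pixels': (100, 28, 28)}
    y_shape = {'label': (100, 1)}
    in_shape, out_shape, batch = _get_in_out_shape(
        x_shape, y_shape, n_classes={'label': 10}, batch_size=32)
    # in_shape  == {'age': [32, 1], 'pixels': [32, 28, 28]}
    # out_shape == {'label': [32, 10]}  (leading 1 dropped, n_classes appended)
    # batch == 32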
@@ -78,15 +100,18 @@ def _is_iterable(x):
def setup_train_data_feeder(
- x, y, n_classes, batch_size=None, shuffle=True, epochs=None):
+ x, y, n_classes, batch_size=None, shuffle=True, epochs=None):
"""Create data feeder, to sample inputs from dataset.
If `x` and `y` are iterators, use `StreamingDataFeeder`.
Args:
- x: numpy, pandas or Dask matrix or iterable.
- y: numpy, pandas or Dask array or iterable.
- n_classes: number of classes.
+ x: numpy, pandas or Dask matrix or dictionary of the aforementioned. Also
+ supports iterables.
+ y: numpy, pandas or Dask array or dictionary of the aforementioned. Also
+ supports iterables.
+ n_classes: number of classes. Must be `None` or the same type as `y`. If `y`
+ is a `dict` (or an iterable that returns dicts), `n_classes` must be a `dict`
+ such that `n_classes[key]` is the number of classes for `y[key]`.
batch_size: size to split data into parts. Must be >= 1.
shuffle: Whether to shuffle the inputs.
epochs: Number of epochs to run.
@@ -102,7 +127,7 @@ def setup_train_data_feeder(
# pylint: disable=g-import-not-at-top
import dask.dataframe as dd
if (isinstance(x, (dd.Series, dd.DataFrame)) and
- (y is None or isinstance(y, (dd.Series, dd.DataFrame)))):
+ (y is None or isinstance(y, (dd.Series, dd.DataFrame)))):
data_feeder_cls = DaskDataFeeder
else:
data_feeder_cls = DataFeeder
@@ -115,31 +140,54 @@ def setup_train_data_feeder(
'streaming learning to work.')
return StreamingDataFeeder(x, y, n_classes, batch_size)
return data_feeder_cls(
- x, y, n_classes, batch_size, shuffle=shuffle, epochs=epochs)
+ x, y, n_classes, batch_size, shuffle=shuffle, epochs=epochs)
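As a usage sketch (hypothetical data, not from the diff), dict-valued inputs flow through the same entry point; non-iterator inputs dispatch to `DataFeeder`, iterators to `StreamingDataFeeder`:

    import numpy as np
    x = {'chars': np.random.randn(1000, 10).astype(np.float32),
         'words': np.random.randn(1000, 20).astype(np.float32)}
    y = {'tags': np.random.randint(0, 5, size=(1000, 1))}
    feeder = setup_train_data_feeder(
        x, y, n_classes={'tags': 5}, batch_size=64, shuffle=True, epochs=1)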
def _batch_data(x, batch_size=None):
if (batch_size is not None) and (batch_size <= 0):
raise ValueError('Invalid batch_size %d.' % batch_size)
- chunk = []
+
+ x_first_el = six.next(x)
+ x = itertools.chain([x_first_el], x)
+
+ chunk = dict([(k, []) for k in list(x_first_el.keys())]) if isinstance(x_first_el, dict) else []
+ chunk_filled = False
for data in x:
- chunk.append(data)
- if (batch_size is not None) and (len(chunk) >= batch_size):
- yield np.matrix(chunk)
- chunk = []
- yield np.matrix(chunk)
+ if isinstance(data, dict):
+ for k, v in list(data.items()):
+ chunk[k].append(v)
+ if (batch_size is not None) and (len(chunk[k]) >= batch_size):
+ chunk[k] = np.matrix(chunk[k])
+ chunk_filled = True
+ if chunk_filled:
+ yield chunk
+ chunk = dict([(k, []) for k in list(x_first_el.keys())]) if isinstance(x_first_el, dict) else []
+ chunk_filled = False
+ else:
+ chunk.append(data)
+ if (batch_size is not None) and (len(chunk) >= batch_size):
+ yield np.matrix(chunk)
+ chunk = []
+
+ if isinstance(x_first_el, dict):
+ for k, v in list(data.items()):
+ chunk[k] = np.matrix(chunk[k])
+ yield chunk
+ else:
+ yield np.matrix(chunk)
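A sketch of the dict branch above on hypothetical data: values are collected per key, and a dict of `np.matrix` chunks is yielded whenever the per-key lists reach `batch_size`, with a smaller remainder yielded at the end:

    stream = iter([{'a': [1, 2]}, {'a': [3, 4]}, {'a': [5, 6]}])
    for batch in _batch_data(stream, batch_size=2):
        print({k: v.shape for k, v in batch.items()})
    # {'a': (2, 2)} for the full chunk, then {'a': (1, 2)} for the remainder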
def setup_predict_data_feeder(x, batch_size=None):
"""Returns an iterable for feeding into predict step.
Args:
- x: numpy, pandas, Dask array or iterable.
- batch_size: Size of batches to split data into.
- If `None`, returns one batch of full size.
+ x: numpy, pandas, Dask array or dictionary of the aforementioned. Also
+ supports iterables.
+ batch_size: Size of batches to split data into. If `None`, returns one
+ batch of full size.
Returns:
- List or iterator of parts of data to predict on.
+ List or iterator (or dictionary thereof) of parts of data to predict on.
Raises:
ValueError: if `batch_size` <= 0.
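A minimal usage sketch, assuming plain numpy input with hypothetical shapes:

    import numpy as np
    x = np.random.randn(250, 4).astype(np.float32)
    for chunk in setup_predict_data_feeder(x, batch_size=100):
        ...  # chunks of at most 100 rows each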
@@ -211,7 +259,7 @@ def _access(data, iloc):
def _check_dtype(dtype):
if dtypes.as_dtype(dtype) == dtypes.float64:
logging.warn(
- 'float64 is not supported by many models, consider casting to float32.')
+ 'float64 is not supported by many models, consider casting to float32.')
return dtype
@@ -219,63 +267,85 @@ class DataFeeder(object):
"""Data feeder is an example class to sample data for TF trainer."""
def __init__(
- self, x, y, n_classes, batch_size=None, shuffle=True, random_state=None,
- epochs=None):
+ self, x, y, n_classes, batch_size=None, shuffle=True, random_state=None,
+ epochs=None):
"""Initializes a DataFeeder instance.
Args:
- x: Feature Nd numpy matrix of shape `[n_samples, n_features, ...]`.
- y: Label vector, either floats for regression or class id for
- classification. If matrix, will consider as a sequence
- of labels. Can be `None` for unsupervised setting.
+ x: One feature sample, which can be either an Nd numpy matrix of shape
+ `[n_samples, n_features, ...]` or a dictionary of Nd numpy matrices.
+ y: Label vector, either floats for regression or class id for
+ classification. If matrix, will consider as a sequence of labels.
+ Can be `None` for unsupervised setting. Also supports a dictionary of
+ labels.
n_classes: Number of classes, 0 and 1 are considered regression, `None`
- will pass through the input labels without one-hot conversion.
- batch_size: Mini-batch size to accumulate.
+ will pass through the input labels without one-hot conversion. Also, if
+ `y` is a `dict`, then `n_classes` must be a `dict` such that
+ `n_classes[key]` is the number of classes for label `y[key]`; `None` otherwise.
+ batch_size: Mini-batch size; number of samples to accumulate per mini-batch.
shuffle: Whether to shuffle `x`.
random_state: Numpy `RandomState` object to reproduce sampling.
epochs: Number of times to iterate over input data before raising
`StopIteration` exception.
Attributes:
- x: Input features.
- y: Input label.
+ x: Input features (ndarray or dictionary of ndarrays).
+ y: Input label (ndarray or dictionary of ndarrays).
n_classes: Number of classes (if `None`, pass through indices without
one-hot conversion).
batch_size: Mini-batch size to accumulate.
- input_shape: Shape of the input.
- output_shape: Shape of the output.
- input_dtype: DType of input.
- output_dtype: DType of output.
+ input_shape: Shape of the input (or dictionary of shapes).
+ output_shape: Shape of the output (or dictionary of shapes).
+ input_dtype: DType of input (or dictionary of dtypes).
+ output_dtype: DType of output (or dictionary of dtypes).
"""
- self._x = check_array(x, dtype=x.dtype)
- # self.n_classes is None means we're passing in raw label indices.
- y_dtype = (
- np.int64 if n_classes is not None and n_classes > 1 else np.float32)
+ x_is_dict, y_is_dict = isinstance(x, dict), y is not None and isinstance(y, dict)
+ if isinstance(y, list):
+ y = np.array(y)
+
+ self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items())]) if x_is_dict else check_array(x, x.dtype)
+ self._y = None if y is None else \
+ dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if y_is_dict else check_array(y, y.dtype)
+
+ # self.n_classes is not None means we're converting raw target indices to one-hot.
if n_classes is not None:
- self._y = (None if y is None else check_array(y, dtype=y_dtype))
- elif isinstance(y, list):
- self._y = np.array(y)
- else:
- self._y = y
+ if not y_is_dict:
+ y_dtype = (np.int64 if n_classes is not None and n_classes > 1 else np.float32)
+ self._y = (None if y is None else check_array(y, dtype=y_dtype))
+
self.n_classes = n_classes
self.max_epochs = epochs
+
+ x_shape = dict([(k, v.shape) for k, v in list(self._x.items())]) if x_is_dict else self._x.shape
+ y_shape = dict(
+ [(k, v.shape) for k, v in list(self._y.items())]) if y_is_dict else None if y is None else self._y.shape
+
self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
- self._x.shape, None if self._y is None else self._y.shape, n_classes,
- batch_size)
+ x_shape, y_shape, n_classes, batch_size)
+
# Input dtype matches dtype of x.
- self._input_dtype = _check_dtype(self._x.dtype)
- # self.n_classes is None means we're passing in raw label indices
- if n_classes is not None or self._y is None:
- self._output_dtype = np.float32
- else:
- self._output_dtype = _check_dtype(self._y.dtype)
+ self._input_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())]) if x_is_dict \
+ else _check_dtype(self._x.dtype)
+
+ # note: self._output_dtype = np.float32 when y is None
+ self._output_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())]) if y_is_dict \
+ else _check_dtype(self._y.dtype) if y is not None else np.float32
+
+ # If n_classes is given for a key of a dict `y`, that label is one-hot
+ # encoded, so its output dtype must be float32.
+ if n_classes is not None and y_is_dict:
+ for key in list(n_classes.keys()):
+ if key in self._output_dtype:
+ self._output_dtype[key] = np.float32
+
self._shuffle = shuffle
self.random_state = np.random.RandomState(
- 42) if random_state is None else random_state
+ 42) if random_state is None else random_state
+
+ num_samples = list(self._x.values())[0].shape[0] if x_is_dict else self._x.shape[0]
if self._shuffle:
- self.indices = self.random_state.permutation(self._x.shape[0])
+ self.indices = self.random_state.permutation(num_samples)
else:
- self.indices = np.array(range(self._x.shape[0]))
+ self.indices = np.array(range(num_samples))
self.offset = 0
self.epoch = 0
self._epoch_placeholder = None
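A construction sketch (hypothetical data) for the mixed case of a plain `x` with a dict-valued `y`:

    import numpy as np
    x = np.random.randn(100, 3).astype(np.float32)
    y = {'cls': np.random.randint(0, 4, size=(100, 1))}
    feeder = DataFeeder(x, y, n_classes={'cls': 4}, batch_size=10)
    # feeder.input_shape  == [10, 3]
    # feeder.output_shape == {'cls': [10, 4]}  (one-hot via the dict n_classes)
    # feeder._output_dtype == {'cls': np.float32}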
@@ -320,19 +390,27 @@ class DataFeeder(object):
Returns:
Two placeholders for inputs and outputs.
"""
- input_shape = [None] + self.input_shape[1:]
- self._input_placeholder = array_ops.placeholder(
- dtypes.as_dtype(self._input_dtype),
- input_shape,
- name='input')
- if self.output_shape is None:
- self._output_placeholder = None
- else:
- output_shape = [None] + self.output_shape[1:]
- self._output_placeholder = array_ops.placeholder(
- dtypes.as_dtype(self._output_dtype),
- output_shape,
- name='output')
+
+ def get_placeholder(shape, dtype, name_prepend):
+ if shape is None:
+ return None
+ if isinstance(shape, dict):
+ placeholder = {}
+ for key in list(shape.keys()):
+ placeholder[key] = array_ops.placeholder(
+ dtypes.as_dtype(dtype[key]),
+ [None] + shape[key][1:],
+ name=name_prepend + '_' + key
+ )
+ else:
+ placeholder = array_ops.placeholder(
+ dtypes.as_dtype(dtype),
+ [None] + shape[1:],
+ name=name_prepend)
+ return placeholder
+
+ self._input_placeholder = get_placeholder(self.input_shape, self._input_dtype, 'input')
+ self._output_placeholder = get_placeholder(self.output_shape, self._output_dtype, 'output')
return self._input_placeholder, self._output_placeholder
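In the dict case the helper above creates one named placeholder per key; a sketch of the resulting names, reusing the hypothetical `feeder` from the previous example:

    inp, out = feeder.input_builder()
    # inp: a single placeholder named 'input' (x was a plain ndarray)
    # out: {'cls': <placeholder 'output_cls' with shape [None, 4]>}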
def set_placeholders(self, input_placeholder, output_placeholder):
@@ -342,21 +420,21 @@ class DataFeeder(object):
input_placeholder: Placeholder for `x` variable. Should match shape
of the examples in the x dataset.
output_placeholder: Placeholder for `y` variable. Should match
- shape of the examples in the y dataset. Can be None.
+ shape of the examples in the y dataset. Can be `None`.
"""
self._input_placeholder = input_placeholder
self._output_placeholder = output_placeholder
def get_feed_params(self):
- """Function returns a dict with data feed params while training.
+ """Function returns a `dict` with data feed params while training.
Returns:
- A dict with data feed params while training.
+ A `dict` with data feed params while training.
"""
return {
- 'epoch': self.epoch,
- 'offset': self.offset,
- 'batch_size': self._batch_size
+ 'epoch': self.epoch,
+ 'offset': self.offset,
+ 'batch_size': self._batch_size
}
def get_feed_dict_fn(self):
@@ -364,8 +442,35 @@ class DataFeeder(object):
Returns:
A function that when called samples a random subset of batch size
- from x and y.
+ from `x` and `y`.
"""
+ x_is_dict, y_is_dict = isinstance(self._x, dict), self._y is not None and isinstance(self._y, dict)
+
+ # Assign input features from random indices.
+ def extract(data, indices):
+ return (np.array(_access(data, indices)).reshape((indices.shape[0], 1))
+ if len(data.shape) == 1 else _access(data, indices))
+
+ # Assign labels from random indices.
+ def assign_label(data, shape, dtype, n_classes, indices):
+ shape[0] = indices.shape[0]
+ out = np.zeros(shape, dtype=dtype)
+ for i in xrange(out.shape[0]):
+ sample = indices[i]
+ # self.n_classes is None means we're passing in raw target indices
+ if n_classes is None:
+ out[i] = _access(data, sample)
+ else:
+ if n_classes > 1:
+ if len(shape) == 2:
+ out.itemset((i, int(_access(data, sample))), 1.0)
+ else:
+ for idx, value in enumerate(_access(data, sample)):
+ out.itemset(tuple([i, idx, value]), 1.0)
+ else:
+ out[i] = _access(data, sample)
+ return out
+
def _feed_dict_fn():
"""Function that samples data into given placeholders."""
if self.max_epochs is not None and self.epoch + 1 > self.max_epochs:
@@ -376,20 +481,19 @@ class DataFeeder(object):
feed_dict[self._epoch_placeholder.name] = [self.epoch]
# Take next batch of indices.
- end = min(self._x.shape[0], self.offset + self._batch_size)
+ x_len = list(self._x.values())[0].shape[0] if x_is_dict else self._x.shape[0]
+ end = min(x_len, self.offset + self._batch_size)
batch_indices = self.indices[self.offset:end]
- # Assign input features from random indices.
- inp = (
- np.array(_access(self._x, batch_indices)).reshape(
- (batch_indices.shape[0], 1))
- if len(self._x.shape) == 1 else _access(self._x, batch_indices))
- feed_dict[self._input_placeholder.name] = inp
+ # Add the input placeholder(s) to the feed dict.
+ feed_dict.update(
+ dict([(self._input_placeholder[k].name, extract(v, batch_indices)) for k, v in list(self._x.items())])
+ if x_is_dict else {self._input_placeholder.name: extract(self._x, batch_indices)})
# move offset and reset it if necessary
self.offset += self._batch_size
- if self.offset >= self._x.shape[0]:
- self.indices = self.random_state.permutation(self._x.shape[0])
+ if self.offset >= x_len:
+ self.indices = self.random_state.permutation(x_len) if self._shuffle else np.array(range(x_len))
self.offset = 0
self.epoch += 1
@@ -397,24 +501,18 @@ class DataFeeder(object):
if self._output_placeholder is None:
return feed_dict
- # assign labels from random indices
- self.output_shape[0] = batch_indices.shape[0]
- out = np.zeros(self.output_shape, dtype=self._output_dtype)
- for i in xrange(out.shape[0]):
- sample = batch_indices[i]
- # self.n_classes is None means we're passing in raw label indices
- if self.n_classes is None:
- out[i] = _access(self._y, sample)
- else:
- if self.n_classes > 1:
- if len(self.output_shape) == 2:
- out.itemset((i, int(_access(self._y, sample))), 1.0)
- else:
- for idx, value in enumerate(_access(self._y, sample)):
- out.itemset(tuple([i, idx, value]), 1.0)
- else:
- out[i] = _access(self._y, sample)
- feed_dict[self._output_placeholder.name] = out
+ # Add the output placeholder(s) to the feed dict.
+ if y_is_dict:
+ for k, v in list(self._y.items()):
+ n_classes = (
+ self.n_classes[k] if k in self.n_classes else None) if self.n_classes is not None else None
+ shape, dtype = self.output_shape[k], self._output_dtype[k]
+ feed_dict.update(
+ {self._output_placeholder[k].name: assign_label(v, shape, dtype, n_classes, batch_indices)})
+ else:
+ shape, dtype, n_classes = self.output_shape, self._output_dtype, self.n_classes
+ feed_dict.update(
+ {self._output_placeholder.name: assign_label(self._y, shape, dtype, n_classes, batch_indices)})
return feed_dict
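A sketch of one call to the returned function, assuming the hypothetical `feeder` from the earlier example (keys are placeholder tensor names):

    feeder.input_builder()  # placeholders must exist first
    feed_fn = feeder.get_feed_dict_fn()
    feed = feed_fn()
    # {'input:0': float32 array of shape (10, 3),
    #  'output_cls:0': one-hot float32 array of shape (10, 4)}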
@@ -433,21 +531,29 @@ class StreamingDataFeeder(DataFeeder):
"""Initializes a StreamingDataFeeder instance.
Args:
- x: iterator that returns for each element, returns features.
- y: iterator that returns for each element, returns 1 or many classes /
- regression values.
- n_classes: indicator of how many classes the label has.
- batch_size: Mini batch size to accumulate.
+ x: iterator that yields one feature sample at a time. A sample can be
+ an Nd numpy matrix or a dictionary of Nd numpy matrices.
+ y: iterator that yields one label sample at a time. A sample can be
+ an Nd numpy matrix or a dictionary of Nd numpy matrices with 1 or many
+ classes / regression values.
+ n_classes: indicator of how many classes the corresponding label sample
+ has, for the purposes of one-hot conversion of the label. If `y` is a
+ dictionary, `n_classes` must be a dictionary (with the same keys as `y`)
+ of how many classes there are in each label in `y`. If a key is
+ present in `y` and missing in `n_classes`, the value is assumed `None`
+ and no one-hot conversion will be applied to the label with that key.
+ batch_size: Mini-batch size to accumulate samples into one batch. If set
+ to `None`, assumes that the iterator yields already-batched elements.
Attributes:
- x: input features.
- y: input label.
+ x: input features (or dictionary of input features).
+ y: input label (or dictionary of input labels).
n_classes: number of classes.
batch_size: mini batch size to accumulate.
- input_shape: shape of the input.
- output_shape: shape of the output.
- input_dtype: dtype of input.
- output_dtype: dtype of output.
+ input_shape: shape of the input (can be dictionary depending on `x`).
+ output_shape: shape of the output (can be dictionary depending on `y`).
+ input_dtype: dtype of input (can be dictionary depending on `x`).
+ output_dtype: dtype of output (can be dictionary depending on `y`).
"""
# pylint: disable=invalid-name,super-init-not-called
x_first_el = six.next(x)
@@ -459,25 +565,48 @@ class StreamingDataFeeder(DataFeeder):
y_first_el = None
self._y = None
self.n_classes = n_classes
- x_first_el = ops.convert_to_tensor(x_first_el)
- y_first_el = ops.convert_to_tensor(y_first_el) if y is not None else None
- self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
- [1] + list(x_first_el.get_shape()),
- [1] + list(y_first_el.get_shape()) if y is not None else None,
- n_classes,
- batch_size)
- self._input_dtype = _check_dtype(x_first_el.dtype).as_numpy_dtype
+
+ x_is_dict, y_is_dict = isinstance(x_first_el, dict), y is not None and isinstance(y_first_el, dict)
+ if y_is_dict and n_classes is not None:
+ assert (isinstance(n_classes, dict))
+
+ # Extract shapes from the first elements.
+ x_first_el_shape = dict([(k, [1] + list(v.shape)) for k, v in list(x_first_el.items())]) if x_is_dict \
+ else [1] + list(x_first_el.shape)
+
+ y_first_el_shape = dict([(k, [1] + list(v.shape)) for k, v in list(y_first_el.items())]) if y_is_dict \
+ else ([1] + list(y_first_el[0].shape if isinstance(y_first_el, list) else y_first_el.shape)
+ if y is not None else None)
+
+ self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(x_first_el_shape, y_first_el_shape,
+ n_classes, batch_size)
+
+ # Input dtype of x_first_el.
+ self._input_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(x_first_el.items())]) if x_is_dict \
+ else _check_dtype(x_first_el.dtype)
+
+ # Output dtype of y_first_el.
+ def check_y_dtype(el):
+ if isinstance(el, list) or isinstance(el, np.ndarray):
+ if isinstance(el, np.ndarray) and el.ndim == 0:
+ return el.dtype
+ else:
+ return _check_dtype(np.dtype(type(el[0])))
+ else:
+ return _check_dtype(np.dtype(type(el)))
+
# Output types are floats, due to both softmaxes and regression req.
- if n_classes is not None and n_classes > 0:
+ if n_classes is not None and (y is None or not y_is_dict) and n_classes > 0:
self._output_dtype = np.float32
- elif y is not None:
- self._output_dtype = _check_dtype(y_first_el.dtype).as_numpy_dtype
+ else:
+ self._output_dtype = dict([(k, check_y_dtype(v)) for k, v in list(y_first_el.items())]) if y_is_dict \
+ else (check_y_dtype(y_first_el) if y is not None else None)
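A streaming sketch (hypothetical generators, not from the diff): shapes and dtypes are inferred from the first element each iterator yields, which the feeder re-chains for consumption:

    import numpy as np
    def x_gen():
        while True:
            yield {'tokens': np.zeros(10, dtype=np.int64)}
    def y_gen():
        while True:
            yield {'tags': np.array([1])}
    feeder = StreamingDataFeeder(x_gen(), y_gen(), {'tags': 5}, batch_size=32)
    # feeder.input_shape  == {'tokens': [32, 10]}
    # feeder.output_shape == {'tags': [32, 5]}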
def get_feed_params(self):
- """Function returns a dict with data feed params while training.
+ """Function returns a `dict` with data feed params while training.
Returns:
- A dict with data feed params while training.
+ A `dict` with data feed params while training.
"""
return {'batch_size': self._batch_size}
@@ -494,50 +623,76 @@ class StreamingDataFeeder(DataFeeder):
"""Samples data and provides it to placeholders.
Returns:
- Dict of input and output tensors.
+ `dict` of input and output tensors.
"""
+
+ def init_array(shape, dtype):
+ if shape is None:
+ return None
+ else:
+ return dict([(k, np.zeros(shape[k], dtype[k])) for k in list(shape.keys())]) if isinstance(shape, dict) else \
+ np.zeros(shape, dtype=dtype)
+
+ def put_data_array(dest, index, source=None, n_classes=None):
+ if source is None:
+ dest = dest[:index]  # works for both 1-D and Nd arrays
+ elif n_classes is not None and n_classes > 1:
+ if len(self.output_shape) == 2:
+ dest.itemset((index, source), 1.0)
+ else:
+ for idx, value in enumerate(source):
+ dest.itemset(tuple([index, idx, value]), 1.0)
+ else:
+ if len(dest.shape) > 1:
+ dest[index, :] = source
+ else:
+ dest[index] = source[0] if isinstance(source, list) else source
+ return dest
+
+ def put_data_array_or_dict(holder, index, data=None, n_classes=None):
+ if holder is None:
+ return None
+ if isinstance(holder, dict):
+ assert (isinstance(data, dict))
+ for k, v in list(holder.items()):
+ num_classes = n_classes[k] if (n_classes is not None and k in n_classes) else None
+ holder[k] = put_data_array(holder[k], index, data[k], num_classes)
+ else:
+ holder = put_data_array(holder, index, data, n_classes)
+ return holder
+
if self.stopped:
raise StopIteration
- try:
- inp = np.zeros(self.input_shape, dtype=self._input_dtype)
- except TypeError as exc:
- raise TypeError('Unrecognized dtype: {}. {}'.format(
- self._input_dtype, exc))
- if self._y is not None:
- out = np.zeros(self.output_shape, dtype=self._output_dtype)
+
+ inp = init_array(self.input_shape, self._input_dtype)
+ out = init_array(self.output_shape, self._output_dtype)
+
for i in xrange(self._batch_size):
# Add handling when queue ends.
try:
- inp[i, :] = six.next(self._x)
+ next_inp = six.next(self._x)
+ inp = put_data_array_or_dict(inp, i, next_inp, None)
except StopIteration:
self.stopped = True
if i == 0:
raise
- inp = inp[:i, :]
- if self._y is not None:
- out = out[:i]
+ inp = put_data_array_or_dict(inp, i, None, None)
+ out = put_data_array_or_dict(out, i, None, None)
break
if self._y is not None:
- y = six.next(self._y)
- if self.n_classes is not None and self.n_classes > 1:
- if len(self.output_shape) == 2:
- out.itemset((i, y), 1.0)
- else:
- for idx, value in enumerate(y):
- out.itemset(tuple([i, idx, value]), 1.0)
- else:
- # The y itertor can sometimes return scalars or singleton lists.
- try:
- out[i] = y
- except ValueError as _:
- assert len(y) == 1, ('Expected singleton label, got {}'
- .format(repr(y)))
- out[i] = y[0]
- if self._y is None:
- return {self._input_placeholder.name: inp}
- return {self._input_placeholder.name: inp,
- self._output_placeholder.name: out}
+ next_out = six.next(self._y)
+ out = put_data_array_or_dict(out, i, next_out, self.n_classes)
+
+ # Create the feed dict.
+ feed_dict = dict([(self._input_placeholder[k].name, inp[k]) for k in list(self._input_placeholder.keys())]) if \
+ isinstance(inp, dict) else {self._input_placeholder.name: inp}
+ if self._y is not None:
+ feed_dict.update(
+ dict([(self._output_placeholder[k].name, out[k]) for k in list(self._output_placeholder.keys())]) \
+ if isinstance(out, dict) else {self._output_placeholder.name: out})
+
+ return feed_dict
return _feed_dict_fn
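Driving the streaming feeder then looks roughly like this (sketch); on `StopIteration` mid-batch, the arrays are truncated to the rows filled so far and returned as one final, smaller batch:

    feed_fn = feeder.get_feed_dict_fn()
    while True:
        try:
            feed = feed_fn()  # dict: placeholder name -> batch array
        except StopIteration:
            break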
@@ -575,6 +730,10 @@ class DaskDataFeeder(object):
input_dtype: dtype of input.
output_dtype: dtype of output.
"""
+
+ if isinstance(x, dict) or isinstance(y, dict):
+ raise ValueError("DaskDataFeeder does not support dictionaries at the moment.")
+
# pylint: disable=invalid-name,super-init-not-called
import dask.dataframe as dd # pylint: disable=g-import-not-at-top
# TODO(terrytangyuan): check x and y dtypes in dask_io like pandas
@@ -601,7 +760,7 @@ class DaskDataFeeder(object):
self._shuffle = shuffle
self.epochs = epochs
self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
- x_shape, y_shape, n_classes, batch_size)
+ x_shape, y_shape, n_classes, batch_size)
self.sample_fraction = self._batch_size / float(x_count)
self._input_dtype = _check_dtype(self._x.dtypes[0])
self._output_dtype = _check_dtype(self._y.dtypes[self._y_columns])
@@ -611,10 +770,10 @@ class DaskDataFeeder(object):
self.random_state = random_state
def get_feed_params(self):
- """Function returns a dict with data feed params while training.
+ """Function returns a `dict` with data feed params while training.
Returns:
- A dict with data feed params while training.
+ A `dict` with data feed params while training.
"""
return {'batch_size': self._batch_size}
@@ -629,13 +788,14 @@ class DaskDataFeeder(object):
A function that when called samples a random subset of batch size
from x and y.
"""
+
def _feed_dict_fn():
"""Samples data and provides it to placeholders."""
# TODO(ipolosukhin): option for with/without replacement (dev version of
# dask)
sample = self.df.random_split(
- [self.sample_fraction, 1 - self.sample_fraction],
- random_state=self.random_state)
+ [self.sample_fraction, 1 - self.sample_fraction],
+ random_state=self.random_state)
inp = extract_pandas_matrix(sample[0][self._x_columns].compute()).tolist()
out = extract_pandas_matrix(sample[0][self._y_columns].compute())
# convert to correct dtype
@@ -650,4 +810,5 @@ class DaskDataFeeder(object):
encoded_out[np.arange(out.size), out] = 1
return {input_placeholder.name: inp,
output_placeholder.name: encoded_out}
+
return _feed_dict_fn
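For reference, the one-hot construction used just above, in isolation (standalone numpy sketch):

    import numpy as np
    out = np.array([2, 0, 1])
    encoded_out = np.zeros((out.size, 3), dtype=np.float32)
    encoded_out[np.arange(out.size), out] = 1
    # [[0., 0., 1.], [1., 0., 0.], [0., 1., 0.]]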