diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index 48d79ecbbf..4c50d40aaa 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -28,7 +28,6 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging @@ -44,7 +43,7 @@ def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None): x_is_dict, y_is_dict = isinstance( x_shape, dict), y_shape is not None and isinstance(y_shape, dict) if y_is_dict and n_classes is not None: - assert (isinstance(n_classes, dict)) + assert isinstance(n_classes, dict) if batch_size is None: batch_size = list(x_shape.values())[0][0] if x_is_dict else x_shape[0] @@ -322,10 +321,12 @@ class DataFeeder(object): self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items()) ]) if x_is_dict else check_array(x, x.dtype) - self._y = None if y is None else \ - dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if x_is_dict else check_array(y, y.dtype) + self._y = None if y is None else ( + dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) + if y_is_dict else check_array(y, y.dtype)) - # self.n_classes is not None means we're converting raw target indices to one-hot. + # self.n_classes is not None means we're converting raw target indices + # to one-hot. if n_classes is not None: if not y_is_dict: y_dtype = (np.int64 @@ -344,12 +345,15 @@ class DataFeeder(object): x_shape, y_shape, n_classes, batch_size) # Input dtype matches dtype of x. - self._input_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())]) if x_is_dict \ - else _check_dtype(self._x.dtype) - - # note: self._output_dtype = np.float32 when y is None - self._output_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())]) if y_is_dict \ - else _check_dtype(self._y.dtype) if y is not None else np.float32 + self._input_dtype = ( + dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())]) + if x_is_dict else _check_dtype(self._x.dtype)) + + # self._output_dtype == np.float32 when y is None + self._output_dtype = ( + dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())]) + if y_is_dict else ( + _check_dtype(self._y.dtype) if y is not None else np.float32)) # self.n_classes is None means we're passing in raw target indices if n_classes is not None and y_is_dict: |