aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py26
1 files changed, 15 insertions, 11 deletions
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
index 48d79ecbbf..4c50d40aaa 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
@@ -28,7 +28,6 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
@@ -44,7 +43,7 @@ def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None):
x_is_dict, y_is_dict = isinstance(
x_shape, dict), y_shape is not None and isinstance(y_shape, dict)
if y_is_dict and n_classes is not None:
- assert (isinstance(n_classes, dict))
+ assert isinstance(n_classes, dict)
if batch_size is None:
batch_size = list(x_shape.values())[0][0] if x_is_dict else x_shape[0]
@@ -322,10 +321,12 @@ class DataFeeder(object):
self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items())
]) if x_is_dict else check_array(x, x.dtype)
- self._y = None if y is None else \
- dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if x_is_dict else check_array(y, y.dtype)
+ self._y = None if y is None else (
+ dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())])
+ if y_is_dict else check_array(y, y.dtype))
- # self.n_classes is not None means we're converting raw target indices to one-hot.
+ # self.n_classes is not None means we're converting raw target indices
+ # to one-hot.
if n_classes is not None:
if not y_is_dict:
y_dtype = (np.int64
@@ -344,12 +345,15 @@ class DataFeeder(object):
x_shape, y_shape, n_classes, batch_size)
# Input dtype matches dtype of x.
- self._input_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())]) if x_is_dict \
- else _check_dtype(self._x.dtype)
-
- # note: self._output_dtype = np.float32 when y is None
- self._output_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())]) if y_is_dict \
- else _check_dtype(self._y.dtype) if y is not None else np.float32
+ self._input_dtype = (
+ dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())])
+ if x_is_dict else _check_dtype(self._x.dtype))
+
+ # self._output_dtype == np.float32 when y is None
+ self._output_dtype = (
+ dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())])
+ if y_is_dict else (
+ _check_dtype(self._y.dtype) if y is not None else np.float32))
# self.n_classes is None means we're passing in raw target indices
if n_classes is not None and y_is_dict: