diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py | 12 |
1 files changed, 3 insertions, 9 deletions
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index db18ebf05d..4c50d40aaa 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -28,14 +28,13 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging # pylint: disable=g-multiple-import,g-bad-import-order from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels + # pylint: enable=g-multiple-import,g-bad-import-order @@ -366,13 +365,8 @@ class DataFeeder(object): self.random_state = np.random.RandomState( 42) if random_state is None else random_state - if x_is_dict: - num_samples = list(self._x.values())[0].shape[0] - elif tensor_util.is_tensor(self._x): - num_samples = self._x.shape[0].value # shape will be a Dimension, extract an int - else: - num_samples = self._x.shape[0] - + num_samples = list(self._x.values())[0].shape[ + 0] if x_is_dict else self._x.shape[0] if self._shuffle: self.indices = self.random_state.permutation(num_samples) else: |