diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index 4c50d40aaa..db18ebf05d 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -28,13 +28,14 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging # pylint: disable=g-multiple-import,g-bad-import-order from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels - # pylint: enable=g-multiple-import,g-bad-import-order @@ -365,8 +366,13 @@ class DataFeeder(object): self.random_state = np.random.RandomState( 42) if random_state is None else random_state - num_samples = list(self._x.values())[0].shape[ - 0] if x_is_dict else self._x.shape[0] + if x_is_dict: + num_samples = list(self._x.values())[0].shape[0] + elif tensor_util.is_tensor(self._x): + num_samples = self._x.shape[0].value # shape will be a Dimension, extract an int + else: + num_samples = self._x.shape[0] + if self._shuffle: self.indices = self.random_state.permutation(num_samples) else: |