aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
index 4c50d40aaa..db18ebf05d 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
@@ -28,13 +28,14 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
# pylint: disable=g-multiple-import,g-bad-import-order
from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels
from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels
-
# pylint: enable=g-multiple-import,g-bad-import-order
@@ -365,8 +366,13 @@ class DataFeeder(object):
self.random_state = np.random.RandomState(
42) if random_state is None else random_state
- num_samples = list(self._x.values())[0].shape[
- 0] if x_is_dict else self._x.shape[0]
+ if x_is_dict:
+ num_samples = list(self._x.values())[0].shape[0]
+ elif tensor_util.is_tensor(self._x):
+ num_samples = self._x.shape[0].value # shape will be a Dimension, extract an int
+ else:
+ num_samples = self._x.shape[0]
+
if self._shuffle:
self.indices = self.random_state.permutation(num_samples)
else: