aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-05-12 11:23:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-12 12:31:24 -0700
commitaa695d3ce8bc54c5d6b10d8a32a55f78337d145c (patch)
tree2988bc8206eb30f390a1c9a2ec27e800b63496cd
parent0558011a44291fd717036be1ea230d1e872d88e0 (diff)
Provide epoch with the shape it's placeholder variable declared it as:
i.e., as a shape (1) tensor rather than as a scalar. Change: 122184518
-rw-r--r--tensorflow/contrib/learn/python/learn/io/data_feeder.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py8
2 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/learn/python/learn/io/data_feeder.py b/tensorflow/contrib/learn/python/learn/io/data_feeder.py
index e1c0c726c1..6768f9091b 100644
--- a/tensorflow/contrib/learn/python/learn/io/data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/io/data_feeder.py
@@ -284,7 +284,7 @@ class DataFeeder(object):
assert self._input_placeholder != None
feed_dict = {}
if self._epoch_placeholder is not None:
- feed_dict[self._epoch_placeholder.name] = self.epoch
+ feed_dict[self._epoch_placeholder.name] = [self.epoch]
# take random indices
if self.batch_size < 0:
diff --git a/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py b/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py
index 300d5f9309..ea361a0d68 100644
--- a/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py
@@ -56,16 +56,16 @@ class DataFeederTest(tf.test.TestCase):
feed_dict_fn = feeder.get_feed_dict_fn()
# First input
feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[epoch.name], 0)
+ self.assertAllClose(feed_dict[epoch.name], [0])
# Second input
feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[epoch.name], 0)
+ self.assertAllClose(feed_dict[epoch.name], [0])
# Third input
feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[epoch.name], 0)
+ self.assertAllClose(feed_dict[epoch.name], [0])
# Back to the first input again, so new epoch.
feed_dict = feed_dict_fn()
- self.assertAllClose(feed_dict[epoch.name], 1)
+ self.assertAllClose(feed_dict[epoch.name], [1])
def test_data_feeder_multioutput_regression(self):
X = np.matrix([[1, 2], [3, 4]])