From aa695d3ce8bc54c5d6b10d8a32a55f78337d145c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 12 May 2016 11:23:22 -0800 Subject: Provide epoch with the shape it's placeholder variable declared it as: i.e., as a shape (1) tensor rather than as a scalar. Change: 122184518 --- tensorflow/contrib/learn/python/learn/io/data_feeder.py | 2 +- tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/io/data_feeder.py b/tensorflow/contrib/learn/python/learn/io/data_feeder.py index e1c0c726c1..6768f9091b 100644 --- a/tensorflow/contrib/learn/python/learn/io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/io/data_feeder.py @@ -284,7 +284,7 @@ class DataFeeder(object): assert self._input_placeholder != None feed_dict = {} if self._epoch_placeholder is not None: - feed_dict[self._epoch_placeholder.name] = self.epoch + feed_dict[self._epoch_placeholder.name] = [self.epoch] # take random indices if self.batch_size < 0: diff --git a/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py b/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py index 300d5f9309..ea361a0d68 100644 --- a/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py @@ -56,16 +56,16 @@ class DataFeederTest(tf.test.TestCase): feed_dict_fn = feeder.get_feed_dict_fn() # First input feed_dict = feed_dict_fn() - self.assertAllClose(feed_dict[epoch.name], 0) + self.assertAllClose(feed_dict[epoch.name], [0]) # Second input feed_dict = feed_dict_fn() - self.assertAllClose(feed_dict[epoch.name], 0) + self.assertAllClose(feed_dict[epoch.name], [0]) # Third input feed_dict = feed_dict_fn() - self.assertAllClose(feed_dict[epoch.name], 0) + self.assertAllClose(feed_dict[epoch.name], [0]) # Back to the first input again, so new epoch. feed_dict = feed_dict_fn() - self.assertAllClose(feed_dict[epoch.name], 1) + self.assertAllClose(feed_dict[epoch.name], [1]) def test_data_feeder_multioutput_regression(self): X = np.matrix([[1, 2], [3, 4]]) -- cgit v1.2.3