diff options
-rw-r--r-- | tensorflow/contrib/learn/python/learn/io/data_feeder.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py | 8 |
2 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/learn/python/learn/io/data_feeder.py b/tensorflow/contrib/learn/python/learn/io/data_feeder.py index e1c0c726c1..6768f9091b 100644 --- a/tensorflow/contrib/learn/python/learn/io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/io/data_feeder.py @@ -284,7 +284,7 @@ class DataFeeder(object): assert self._input_placeholder != None feed_dict = {} if self._epoch_placeholder is not None: - feed_dict[self._epoch_placeholder.name] = self.epoch + feed_dict[self._epoch_placeholder.name] = [self.epoch] # take random indices if self.batch_size < 0: diff --git a/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py b/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py index 300d5f9309..ea361a0d68 100644 --- a/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py @@ -56,16 +56,16 @@ class DataFeederTest(tf.test.TestCase): feed_dict_fn = feeder.get_feed_dict_fn() # First input feed_dict = feed_dict_fn() - self.assertAllClose(feed_dict[epoch.name], 0) + self.assertAllClose(feed_dict[epoch.name], [0]) # Second input feed_dict = feed_dict_fn() - self.assertAllClose(feed_dict[epoch.name], 0) + self.assertAllClose(feed_dict[epoch.name], [0]) # Third input feed_dict = feed_dict_fn() - self.assertAllClose(feed_dict[epoch.name], 0) + self.assertAllClose(feed_dict[epoch.name], [0]) # Back to the first input again, so new epoch. feed_dict = feed_dict_fn() - self.assertAllClose(feed_dict[epoch.name], 1) + self.assertAllClose(feed_dict[epoch.name], [1]) def test_data_feeder_multioutput_regression(self): X = np.matrix([[1, 2], [3, 4]]) |