diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-05-12 11:23:22 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-05-12 12:31:24 -0700 |
commit | aa695d3ce8bc54c5d6b10d8a32a55f78337d145c (patch) | |
tree | 2988bc8206eb30f390a1c9a2ec27e800b63496cd | |
parent | 0558011a44291fd717036be1ea230d1e872d88e0 (diff) |
Provide epoch with the shape it's placeholder variable declared it as:
i.e., as a shape (1) tensor rather than as a scalar.
Change: 122184518
-rw-r--r-- | tensorflow/contrib/learn/python/learn/io/data_feeder.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py | 8 |
2 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/learn/python/learn/io/data_feeder.py b/tensorflow/contrib/learn/python/learn/io/data_feeder.py index e1c0c726c1..6768f9091b 100644 --- a/tensorflow/contrib/learn/python/learn/io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/io/data_feeder.py @@ -284,7 +284,7 @@ class DataFeeder(object): assert self._input_placeholder != None feed_dict = {} if self._epoch_placeholder is not None: - feed_dict[self._epoch_placeholder.name] = self.epoch + feed_dict[self._epoch_placeholder.name] = [self.epoch] # take random indices if self.batch_size < 0: diff --git a/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py b/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py index 300d5f9309..ea361a0d68 100644 --- a/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/tests/test_data_feeder.py @@ -56,16 +56,16 @@ class DataFeederTest(tf.test.TestCase): feed_dict_fn = feeder.get_feed_dict_fn() # First input feed_dict = feed_dict_fn() - self.assertAllClose(feed_dict[epoch.name], 0) + self.assertAllClose(feed_dict[epoch.name], [0]) # Second input feed_dict = feed_dict_fn() - self.assertAllClose(feed_dict[epoch.name], 0) + self.assertAllClose(feed_dict[epoch.name], [0]) # Third input feed_dict = feed_dict_fn() - self.assertAllClose(feed_dict[epoch.name], 0) + self.assertAllClose(feed_dict[epoch.name], [0]) # Back to the first input again, so new epoch. feed_dict = feed_dict_fn() - self.assertAllClose(feed_dict[epoch.name], 1) + self.assertAllClose(feed_dict[epoch.name], [1]) def test_data_feeder_multioutput_regression(self): X = np.matrix([[1, 2], [3, 4]]) |