aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/keras
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2017-07-07 15:19:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-07 15:24:16 -0700
commite9bea40511b1fe5d2e1e7761f640943d0e17a7df (patch)
treedf7d28b102f3a0fd2c20d2fcf2325d18f4681218 /tensorflow/contrib/keras
parent204c367ab1a38cee71dac2a64164e96abaeffbf2 (diff)
Refactor Keras Sequence utility.
Correct Keras version number. PiperOrigin-RevId: 161252947
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r--tensorflow/contrib/keras/python/keras/__init__.py2
-rw-r--r--tensorflow/contrib/keras/python/keras/utils/data_utils.py33
2 files changed, 2 insertions, 33 deletions
diff --git a/tensorflow/contrib/keras/python/keras/__init__.py b/tensorflow/contrib/keras/python/keras/__init__.py
index 6e0e03d7f7..19380bc8c5 100644
--- a/tensorflow/contrib/keras/python/keras/__init__.py
+++ b/tensorflow/contrib/keras/python/keras/__init__.py
@@ -37,4 +37,4 @@ from tensorflow.contrib.keras.python.keras import utils
from tensorflow.contrib.keras.python.keras import wrappers
from tensorflow.contrib.keras.python.keras.layers import Input
-__version__ = '2.0.5-tf'
+__version__ = '2.0.6-tf'
diff --git a/tensorflow/contrib/keras/python/keras/utils/data_utils.py b/tensorflow/contrib/keras/python/keras/utils/data_utils.py
index 9aa477d522..853625e7c4 100644
--- a/tensorflow/contrib/keras/python/keras/utils/data_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/data_utils.py
@@ -20,7 +20,6 @@ from __future__ import print_function
from abc import abstractmethod
import hashlib
import multiprocessing
-import multiprocessing.managers
from multiprocessing.pool import ThreadPool
import os
import random
@@ -315,34 +314,6 @@ def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
return False
-class HolderManager(multiprocessing.managers.BaseManager):
- """Custom manager to share a Holder object."""
- pass
-
-
-class Holder(object):
- """Object to encapsulate a Sequence.
-
- This allows the Sequence to be shared across multiple workers.
-
- Arguments:
- seq: Sequence object to be shared.
- """
-
- def __init__(self, seq):
- self.seq = seq
-
- def __getitem__(self, idx):
- return self.seq[idx]
-
- def __len__(self):
- return len(self.seq)
-
-
-# Register the Holder class using the ListProxy (allows __len__ and __getitem__)
-HolderManager.register('Holder', Holder, multiprocessing.managers.ListProxy)
-
-
class Sequence(object):
"""Base object for fitting to a sequence of data, such as a dataset.
@@ -488,9 +459,7 @@ class OrderedEnqueuer(SequenceEnqueuer):
sequence,
use_multiprocessing=False,
scheduling='sequential'):
- self.manager = HolderManager()
- self.manager.start()
- self.sequence = self.manager.Holder(sequence)
+ self.sequence = sequence
self.use_multiprocessing = use_multiprocessing
self.scheduling = scheduling
self.workers = 0