diff options
author | 2017-07-07 15:19:40 -0700 | |
---|---|---|
committer | 2017-07-07 15:24:16 -0700 | |
commit | e9bea40511b1fe5d2e1e7761f640943d0e17a7df (patch) | |
tree | df7d28b102f3a0fd2c20d2fcf2325d18f4681218 /tensorflow/contrib/keras | |
parent | 204c367ab1a38cee71dac2a64164e96abaeffbf2 (diff) |
Refactor Keras Sequence utility.
Correct Keras version number.
PiperOrigin-RevId: 161252947
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r-- | tensorflow/contrib/keras/python/keras/__init__.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/keras/python/keras/utils/data_utils.py | 33 |
2 files changed, 2 insertions, 33 deletions
diff --git a/tensorflow/contrib/keras/python/keras/__init__.py b/tensorflow/contrib/keras/python/keras/__init__.py index 6e0e03d7f7..19380bc8c5 100644 --- a/tensorflow/contrib/keras/python/keras/__init__.py +++ b/tensorflow/contrib/keras/python/keras/__init__.py @@ -37,4 +37,4 @@ from tensorflow.contrib.keras.python.keras import utils from tensorflow.contrib.keras.python.keras import wrappers from tensorflow.contrib.keras.python.keras.layers import Input -__version__ = '2.0.5-tf' +__version__ = '2.0.6-tf' diff --git a/tensorflow/contrib/keras/python/keras/utils/data_utils.py b/tensorflow/contrib/keras/python/keras/utils/data_utils.py index 9aa477d522..853625e7c4 100644 --- a/tensorflow/contrib/keras/python/keras/utils/data_utils.py +++ b/tensorflow/contrib/keras/python/keras/utils/data_utils.py @@ -20,7 +20,6 @@ from __future__ import print_function from abc import abstractmethod import hashlib import multiprocessing -import multiprocessing.managers from multiprocessing.pool import ThreadPool import os import random @@ -315,34 +314,6 @@ def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535): return False -class HolderManager(multiprocessing.managers.BaseManager): - """Custom manager to share a Holder object.""" - pass - - -class Holder(object): - """Object to encapsulate a Sequence. - - This allows the Sequence to be shared across multiple workers. - - Arguments: - seq: Sequence object to be shared. - """ - - def __init__(self, seq): - self.seq = seq - - def __getitem__(self, idx): - return self.seq[idx] - - def __len__(self): - return len(self.seq) - - -# Register the Holder class using the ListProxy (allows __len__ and __getitem__) -HolderManager.register('Holder', Holder, multiprocessing.managers.ListProxy) - - class Sequence(object): """Base object for fitting to a sequence of data, such as a dataset. @@ -488,9 +459,7 @@ class OrderedEnqueuer(SequenceEnqueuer): sequence, use_multiprocessing=False, scheduling='sequential'): - self.manager = HolderManager() - self.manager.start() - self.sequence = self.manager.Holder(sequence) + self.sequence = sequence self.use_multiprocessing = use_multiprocessing self.scheduling = scheduling self.workers = 0 |