aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu/python/tpu/datasets.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/datasets.py')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/datasets.py16
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py
index 2e472a2805..d879170b68 100644
--- a/tensorflow/contrib/tpu/python/tpu/datasets.py
+++ b/tensorflow/contrib/tpu/python/tpu/datasets.py
@@ -166,11 +166,21 @@ def StreamingFilesDataset(files,
return remote_iterator.get_next()
def MapFn(unused_input):
- return functional_ops.remote_call(
+ if isinstance(source_dataset.output_types, dtypes.DType):
+ output_types = [source_dataset.output_types]
+ elif isinstance(source_dataset.output_types, (list, tuple)):
+ output_types = source_dataset.output_types
+ else:
+ raise ValueError('source dataset has invalid output types')
+ remote_calls = functional_ops.remote_call(
args=[source_handle],
- Tout=[dtypes.string],
+ Tout=output_types,
f=LoadingFunc,
- target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)[0]
+ target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)
+ if len(remote_calls) == 1:
+ return remote_calls[0]
+ else:
+ return remote_calls
with ops.device('/job:%s' % worker_job):
output_dataset = dataset_ops.Dataset.range(2).repeat().map(