aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-06-11 15:04:30 -0700
committerGravatar Akshay Modi <nareshmodi@google.com>2018-06-11 15:04:30 -0700
commit6d9c0ba224f5903375ae26f582ef233740477e29 (patch)
treebfcaef48a56c6d3900d1db5a248e9d9f155c48e1 /tensorflow/contrib/training
parenta4b390bffbcb01d8f57f25c007277d457f752a69 (diff)
parentab51450c817674c8ff08a7ae4f8ac50cdc4bed8b (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
index 409aba817c..a2444934bc 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
@@ -45,14 +46,14 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
- # pylint: disable=protected-access
if padded_shapes is None:
self._padded_shapes = nest.map_structure(
- dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes)
+ convert.partial_shape_to_tensor, input_dataset.output_shapes)
else:
self._padded_shapes = nest.map_structure_up_to(
- input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor,
+ input_dataset.output_shapes, convert.partial_shape_to_tensor,
padded_shapes)
+ # pylint: disable=protected-access
padding_values = (
padding_values if padding_values is not None else
dataset_ops._default_padding(input_dataset))