aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-18 18:28:09 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-18 18:35:41 +0800
commit8fbafe6c7e75e1d931eca7202ea3a4c5ac8fc2dd (patch)
treeb86bc57ad76ce970eb8bca92c218031778d7ee84 /tensorflow/contrib/data
parent6e0f1120fd7a6df805a8b712d2d4a38042576b46 (diff)
CLN: fix code style
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py20
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py12
2 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index ebc5160408..9d8e955245 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -759,8 +759,8 @@ class RestructuredDatasetTest(test.TestCase):
def create_unknown_shape_dataset(x):
return script_ops.py_func(
lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
[x],
[dtypes.float32, dtypes.int32])
@@ -789,8 +789,8 @@ class RestructuredDatasetTest(test.TestCase):
def create_unknown_shape_dataset(x):
return script_ops.py_func(
lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
[x],
[dtypes.float32, dtypes.int32])
@@ -802,7 +802,7 @@ class RestructuredDatasetTest(test.TestCase):
wrong_shapes = (tensor_shape.TensorShape(2),
tensor_shape.TensorShape((3, 10)))
iterator = (
- dataset.apply(batching.assert_element_shape(wrong_shapes))
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -854,8 +854,8 @@ class RestructuredDatasetTest(test.TestCase):
def create_unknown_shape_dataset(x):
return script_ops.py_func(
lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
[x],
[dtypes.float32, dtypes.int32])
@@ -884,8 +884,8 @@ class RestructuredDatasetTest(test.TestCase):
def create_unknown_shape_dataset(x):
return script_ops.py_func(
lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
[x],
[dtypes.float32, dtypes.int32])
@@ -897,7 +897,7 @@ class RestructuredDatasetTest(test.TestCase):
wrong_shapes = (tensor_shape.TensorShape(2),
tensor_shape.TensorShape((None, 10)))
iterator = (
- dataset.apply(batching.assert_element_shape(wrong_shapes))
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 3cad83fcb1..9c2001c34f 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -666,13 +666,13 @@ def assert_element_shape(expected_shapes):
`tf.data.Dataset.apply`
"""
- def _merge_output_shape(original_shapes, expected_shapes):
+ def _merge_output_shapes(original_shapes, expected_shapes):
flat_original_shapes = nest.flatten(original_shapes)
flat_new_shapes = nest.flatten_up_to(original_shapes, expected_shapes)
flat_merged_output_shapes = [
- original_shape.merge_with(new_shape)
- for original_shape, new_shape in zip(flat_original_shapes,
- flat_new_shapes)]
+ original_shape.merge_with(new_shape)
+ for original_shape, new_shape in zip(flat_original_shapes,
+ flat_new_shapes)]
return nest.pack_sequence_as(original_shapes, flat_merged_output_shapes)
def _check_shape(*elements):
@@ -685,8 +685,8 @@ def assert_element_shape(expected_shapes):
return nest.pack_sequence_as(elements, checked_tensors)
def _apply_fn(dataset):
- output_shapes = _merge_output_shape(dataset.output_shapes,
- expected_shapes)
+ output_shapes = _merge_output_shapes(dataset.output_shapes,
+ expected_shapes)
return _RestructuredDataset(
dataset.map(_check_shape),
dataset.output_types,