diff options
author | 2018-08-18 18:28:09 +0800 | |
---|---|---|
committer | 2018-08-18 18:35:41 +0800 | |
commit | 8fbafe6c7e75e1d931eca7202ea3a4c5ac8fc2dd (patch) | |
tree | b86bc57ad76ce970eb8bca92c218031778d7ee84 /tensorflow/contrib/data | |
parent | 6e0f1120fd7a6df805a8b712d2d4a38042576b46 (diff) |
CLN: fix code style
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 20 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/ops/batching.py | 12 |
2 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index ebc5160408..9d8e955245 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -759,8 +759,8 @@ class RestructuredDatasetTest(test.TestCase): def create_unknown_shape_dataset(x): return script_ops.py_func( lambda _: ( # pylint: disable=g-long-lambda - np.ones(2, dtype=np.float32), - np.zeros((3, 4), dtype=np.int32)), + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), [x], [dtypes.float32, dtypes.int32]) @@ -789,8 +789,8 @@ class RestructuredDatasetTest(test.TestCase): def create_unknown_shape_dataset(x): return script_ops.py_func( lambda _: ( # pylint: disable=g-long-lambda - np.ones(2, dtype=np.float32), - np.zeros((3, 4), dtype=np.int32)), + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), [x], [dtypes.float32, dtypes.int32]) @@ -802,7 +802,7 @@ class RestructuredDatasetTest(test.TestCase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 10))) iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) + dataset.apply(batching.assert_element_shape(wrong_shapes)) .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -854,8 +854,8 @@ class RestructuredDatasetTest(test.TestCase): def create_unknown_shape_dataset(x): return script_ops.py_func( lambda _: ( # pylint: disable=g-long-lambda - np.ones(2, dtype=np.float32), - np.zeros((3, 4), dtype=np.int32)), + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), [x], [dtypes.float32, dtypes.int32]) @@ -884,8 +884,8 @@ class RestructuredDatasetTest(test.TestCase): def create_unknown_shape_dataset(x): return script_ops.py_func( lambda _: ( # pylint: disable=g-long-lambda - np.ones(2, dtype=np.float32), - np.zeros((3, 4), dtype=np.int32)), + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), [x], [dtypes.float32, dtypes.int32]) @@ -897,7 +897,7 @@ class RestructuredDatasetTest(test.TestCase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((None, 10))) iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) + dataset.apply(batching.assert_element_shape(wrong_shapes)) .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 3cad83fcb1..9c2001c34f 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -666,13 +666,13 @@ def assert_element_shape(expected_shapes): `tf.data.Dataset.apply` """ - def _merge_output_shape(original_shapes, expected_shapes): + def _merge_output_shapes(original_shapes, expected_shapes): flat_original_shapes = nest.flatten(original_shapes) flat_new_shapes = nest.flatten_up_to(original_shapes, expected_shapes) flat_merged_output_shapes = [ - original_shape.merge_with(new_shape) - for original_shape, new_shape in zip(flat_original_shapes, - flat_new_shapes)] + original_shape.merge_with(new_shape) + for original_shape, new_shape in zip(flat_original_shapes, - flat_new_shapes)] return nest.pack_sequence_as(original_shapes, flat_merged_output_shapes) def _check_shape(*elements): @@ -685,8 +685,8 @@ def assert_element_shape(expected_shapes): return nest.pack_sequence_as(elements, checked_tensors) def _apply_fn(dataset): - output_shapes = _merge_output_shape(dataset.output_shapes, - expected_shapes) + output_shapes = _merge_output_shapes(dataset.output_shapes, - expected_shapes) return _RestructuredDataset( dataset.map(_check_shape), dataset.output_types, |