1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the private `_RestructuredDataset` transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class RestructuredDatasetTest(test_base.DatasetTestBase):
  """Tests for `batching._RestructuredDataset`."""

  def testRestructureDataset(self):
    """Checks accepted and rejected restructurings of a nested dataset.

    A single-element dataset is built from the nested placeholder structure
    `(a, (b, c))`, then re-wrapped with `_RestructuredDataset` using a range
    of new type structures and optional shape lists. Accepted combinations
    must report the requested types (and shapes, when given); rejected
    combinations must raise `ValueError` at construction time.
    """
    # Nested structure (a, (b, c)): `a` is declared without a shape, `b` with
    # shape [None] and `c` with shape [20, 30].
    components = (array_ops.placeholder(dtypes.int32),
                  (array_ops.placeholder(dtypes.int32, shape=[None]),
                   array_ops.placeholder(dtypes.int32, shape=[20, 30])))
    dataset = dataset_ops.Dataset.from_tensors(components)

    i32 = dtypes.int32

    # Each case is (new_types, new_shape_lists). A `new_shape_lists` of None
    # means no shapes are passed, so only the types are verified below.
    test_cases = [((i32, i32, i32), None),
                  (((i32, i32), i32), None),
                  ((i32, i32, i32), (None, None, None)),
                  ((i32, i32, i32), ([17], [17], [20, 30]))]

    for new_types, new_shape_lists in test_cases:
      # pylint: disable=protected-access
      new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
      # pylint: enable=protected-access
      self.assertEqual(new_types, new.output_types)
      if new_shape_lists is not None:
        # Compare leaf by leaf: a None entry is expected to surface as a shape
        # with no known rank, anything else must match the shape exactly.
        for expected_shape_list, shape in zip(
            nest.flatten(new_shape_lists), nest.flatten(new.output_shapes)):
          if expected_shape_list is None:
            self.assertIsNone(shape.ndims)
          else:
            self.assertEqual(expected_shape_list, shape.as_list())

    # Combinations that must be rejected. NOTE(review): judging from the
    # values, these appear to cover a dtype mismatch (int64), a wrong number
    # of types, a shape structure that differs from the type structure, a
    # wrong number of shapes, and a shape ([21, 30]) that conflicts with the
    # declared [20, 30] -- confirm against `_RestructuredDataset`.
    fail_cases = [((i32, dtypes.int64, i32), None),
                  ((i32, i32, i32, i32), None),
                  ((i32, i32, i32), ((None, None), None)),
                  ((i32, i32, i32), (None, None, None, None)),
                  ((i32, i32, i32), (None, [None], [21, 30]))]

    for new_types, new_shape_lists in fail_cases:
      with self.assertRaises(ValueError):
        # Only construction is under test, so the result is not bound.
        # pylint: disable=protected-access
        batching._RestructuredDataset(dataset, new_types, new_shape_lists)
        # pylint: enable=protected-access
# Hand control to `test.main()` (from `tensorflow.python.platform`) only when
# this file is executed directly rather than imported.
if __name__ == "__main__":
  test.main()
|