diff options
author | 2017-01-31 16:41:50 -0800 | |
---|---|---|
committer | 2017-01-31 17:06:15 -0800 | |
commit | 4062c3a74fa5069b44cb67a9f2005cfa90e54ca1 (patch) | |
tree | 827875679cdaa339e771f03fcc5fb3028876cb42 | |
parent | 8b60a8c131f535b098e72c9a05620e6d52703aa0 (diff) |
Automated rollback of change 145619741
Change: 146183030
-rw-r--r-- | tensorflow/core/kernels/matching_files_op.cc | 47 | ||||
-rw-r--r-- | tensorflow/core/ops/io_ops.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/ops/io_ops_test.cc | 3 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/io_ops_test.py | 19 | ||||
-rw-r--r-- | tensorflow/python/training/input.py | 4 |
5 files changed, 55 insertions, 24 deletions
diff --git a/tensorflow/core/kernels/matching_files_op.cc b/tensorflow/core/kernels/matching_files_op.cc index a35b5889d3..5eb060f664 100644 --- a/tensorflow/core/kernels/matching_files_op.cc +++ b/tensorflow/core/kernels/matching_files_op.cc @@ -29,22 +29,37 @@ class MatchingFilesOp : public OpKernel { public: using OpKernel::OpKernel; void Compute(OpKernelContext* context) override { - const Tensor* pattern; - OP_REQUIRES_OK(context, context->input("pattern", &pattern)); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(pattern->shape()), - errors::InvalidArgument( - "Input pattern tensor must be scalar, but had shape: ", - pattern->shape().DebugString())); - std::vector<string> fnames; - OP_REQUIRES_OK(context, context->env()->GetMatchingPaths( - pattern->scalar<string>()(), &fnames)); - const int num_out = fnames.size(); - Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - "filenames", TensorShape({num_out}), &output)); - auto output_vec = output->vec<string>(); - for (int i = 0; i < num_out; ++i) { - output_vec(i) = fnames[i]; + const Tensor* patterns_t; + // NOTE(ringwalt): Changing the input name "pattern" to "patterns" would + // break existing graphs. + OP_REQUIRES_OK(context, context->input("pattern", &patterns_t)); + OP_REQUIRES( + context, + TensorShapeUtils::IsScalar(patterns_t->shape()) || + TensorShapeUtils::IsVector(patterns_t->shape()), + errors::InvalidArgument( + "Input patterns tensor must be scalar or vector, but had shape: ", + patterns_t->shape().DebugString())); + const auto patterns = patterns_t->flat<string>(); + int num_patterns = patterns.size(); + int num_files = 0; + std::vector<std::vector<string>> all_fnames(num_patterns); + for (int i = 0; i < num_patterns; i++) { + OP_REQUIRES_OK( + context, + context->env()->GetMatchingPaths(patterns(i), &all_fnames[i])); + num_files += all_fnames[i].size(); + } + Tensor* output_t = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + "filenames", TensorShape({num_files}), &output_t)); + auto output = output_t->vec<string>(); + int index = 0; + for (int i = 0; i < num_patterns; ++i) { + for (int j = 0; j < all_fnames[i].size(); j++) { + output(index++) = all_fnames[i][j]; + } } } }; diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc index 1412aeffc5..864f131caa 100644 --- a/tensorflow/core/ops/io_ops.cc +++ b/tensorflow/core/ops/io_ops.cc @@ -810,17 +810,17 @@ REGISTER_OP("MatchingFiles") .Output("filenames: string") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); return Status::OK(); }) .Doc(R"doc( -Returns the set of files matching a pattern. +Returns the set of files matching one or more glob patterns. Note that this routine only supports wildcard characters in the basename portion of the pattern, not in the directory portion. -pattern: A (scalar) shell wildcard pattern. +pattern: Shell wildcard pattern(s). Scalar or vector of type string. filenames: A vector of matching filenames. )doc"); diff --git a/tensorflow/core/ops/io_ops_test.cc b/tensorflow/core/ops/io_ops_test.cc index 99b091bcee..9d98cb9048 100644 --- a/tensorflow/core/ops/io_ops_test.cc +++ b/tensorflow/core/ops/io_ops_test.cc @@ -185,7 +185,8 @@ TEST(IoOpsTest, MatchingFiles_ShapeFn) { INFER_OK(op, "?", "[?]"); INFER_OK(op, "[]", "[?]"); - INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?]"); + INFER_OK(op, "[42]", "[?]"); + INFER_ERROR("Shape must be at most rank 1 but is rank 2", op, "[?,?]"); } } // end namespace tensorflow diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py index 0e5ca21c48..472487ccfb 100644 --- a/tensorflow/python/kernel_tests/io_ops_test.py +++ b/tensorflow/python/kernel_tests/io_ops_test.py @@ -80,8 +80,8 @@ class IoOpsTest(test.TestCase): io_ops.matching_files(f.name).eval(), compat.as_bytes(f.name)) # We will look for files matching "ABxDEF.GH*" where "x" is some wildcard. - pos = files[0].name.find(cases[0]) - pattern = files[0].name[:pos] + 'AB%sDEF.GH*' + directory_path = files[0].name[:files[0].name.find(cases[0])] + pattern = directory_path + 'AB%sDEF.GH*' self.assertEqual( set(io_ops.matching_files(pattern % 'z').eval()), @@ -102,6 +102,21 @@ class IoOpsTest(test.TestCase): set(io_ops.matching_files(pattern % '[0-9]').eval()), self._subset(files, [3, 4])) + # Test an empty list input. + self.assertItemsEqual(io_ops.matching_files([]).eval(), []) + + # Test multiple exact filenames. + self.assertItemsEqual( + io_ops.matching_files([ + files[0].name, files[1].name, files[2].name]).eval(), + self._subset(files, [0, 1, 2])) + + # Test multiple globs. + self.assertItemsEqual( + io_ops.matching_files([ + pattern % '?', directory_path + 'X?Z*']).eval(), + self._subset(files, [0, 1, 3, 4, 6])) + for f in files: f.close() diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 557e96cf5d..f535c692a6 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -56,11 +56,11 @@ def match_filenames_once(pattern, name=None): """Save the list of files matching pattern, so it is only computed once. Args: - pattern: A file pattern (glob). + pattern: A file pattern (glob), or 1D tensor of file patterns. name: A name for the operations (optional). Returns: - A variable that is initialized to the list of files matching pattern. + A variable that is initialized to the list of files matching the pattern(s). """ with ops.name_scope(name, "matching_filenames", [pattern]) as name: return variables.Variable(io_ops.matching_files(pattern), trainable=False, |