about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Dan Ringwalt <ringwalt@google.com> 2017-01-31 16:41:50 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-01-31 17:06:15 -0800
commit: 4062c3a74fa5069b44cb67a9f2005cfa90e54ca1 (patch)
tree: 827875679cdaa339e771f03fcc5fb3028876cb42
parent: 8b60a8c131f535b098e72c9a05620e6d52703aa0 (diff)
Automated rollback of change 145619741
Change: 146183030
-rw-r--r-- tensorflow/core/kernels/matching_files_op.cc 47
-rw-r--r-- tensorflow/core/ops/io_ops.cc 6
-rw-r--r-- tensorflow/core/ops/io_ops_test.cc 3
-rw-r--r-- tensorflow/python/kernel_tests/io_ops_test.py 19
-rw-r--r-- tensorflow/python/training/input.py 4
5 files changed, 55 insertions, 24 deletions
diff --git a/tensorflow/core/kernels/matching_files_op.cc b/tensorflow/core/kernels/matching_files_op.cc
index a35b5889d3..5eb060f664 100644
--- a/tensorflow/core/kernels/matching_files_op.cc
+++ b/tensorflow/core/kernels/matching_files_op.cc
@@ -29,22 +29,37 @@ class MatchingFilesOp : public OpKernel {
public:
using OpKernel::OpKernel;
void Compute(OpKernelContext* context) override {
- const Tensor* pattern;
- OP_REQUIRES_OK(context, context->input("pattern", &pattern));
- OP_REQUIRES(context, TensorShapeUtils::IsScalar(pattern->shape()),
- errors::InvalidArgument(
- "Input pattern tensor must be scalar, but had shape: ",
- pattern->shape().DebugString()));
- std::vector<string> fnames;
- OP_REQUIRES_OK(context, context->env()->GetMatchingPaths(
- pattern->scalar<string>()(), &fnames));
- const int num_out = fnames.size();
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- "filenames", TensorShape({num_out}), &output));
- auto output_vec = output->vec<string>();
- for (int i = 0; i < num_out; ++i) {
- output_vec(i) = fnames[i];
+ const Tensor* patterns_t;
+ // NOTE(ringwalt): Changing the input name "pattern" to "patterns" would
+ // break existing graphs.
+ OP_REQUIRES_OK(context, context->input("pattern", &patterns_t));
+ OP_REQUIRES(
+ context,
+ TensorShapeUtils::IsScalar(patterns_t->shape()) ||
+ TensorShapeUtils::IsVector(patterns_t->shape()),
+ errors::InvalidArgument(
+ "Input patterns tensor must be scalar or vector, but had shape: ",
+ patterns_t->shape().DebugString()));
+ const auto patterns = patterns_t->flat<string>();
+ int num_patterns = patterns.size();
+ int num_files = 0;
+ std::vector<std::vector<string>> all_fnames(num_patterns);
+ for (int i = 0; i < num_patterns; i++) {
+ OP_REQUIRES_OK(
+ context,
+ context->env()->GetMatchingPaths(patterns(i), &all_fnames[i]));
+ num_files += all_fnames[i].size();
+ }
+ Tensor* output_t = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ "filenames", TensorShape({num_files}), &output_t));
+ auto output = output_t->vec<string>();
+ int index = 0;
+ for (int i = 0; i < num_patterns; ++i) {
+ for (int j = 0; j < all_fnames[i].size(); j++) {
+ output(index++) = all_fnames[i][j];
+ }
}
}
};
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc
index 1412aeffc5..864f131caa 100644
--- a/tensorflow/core/ops/io_ops.cc
+++ b/tensorflow/core/ops/io_ops.cc
@@ -810,17 +810,17 @@ REGISTER_OP("MatchingFiles")
.Output("filenames: string")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
return Status::OK();
})
.Doc(R"doc(
-Returns the set of files matching a pattern.
+Returns the set of files matching one or more glob patterns.
Note that this routine only supports wildcard characters in the
basename portion of the pattern, not in the directory portion.
-pattern: A (scalar) shell wildcard pattern.
+pattern: Shell wildcard pattern(s). Scalar or vector of type string.
filenames: A vector of matching filenames.
)doc");
diff --git a/tensorflow/core/ops/io_ops_test.cc b/tensorflow/core/ops/io_ops_test.cc
index 99b091bcee..9d98cb9048 100644
--- a/tensorflow/core/ops/io_ops_test.cc
+++ b/tensorflow/core/ops/io_ops_test.cc
@@ -185,7 +185,8 @@ TEST(IoOpsTest, MatchingFiles_ShapeFn) {
INFER_OK(op, "?", "[?]");
INFER_OK(op, "[]", "[?]");
- INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?]");
+ INFER_OK(op, "[42]", "[?]");
+ INFER_ERROR("Shape must be at most rank 1 but is rank 2", op, "[?,?]");
}
} // end namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py
index 0e5ca21c48..472487ccfb 100644
--- a/tensorflow/python/kernel_tests/io_ops_test.py
+++ b/tensorflow/python/kernel_tests/io_ops_test.py
@@ -80,8 +80,8 @@ class IoOpsTest(test.TestCase):
io_ops.matching_files(f.name).eval(), compat.as_bytes(f.name))
# We will look for files matching "ABxDEF.GH*" where "x" is some wildcard.
- pos = files[0].name.find(cases[0])
- pattern = files[0].name[:pos] + 'AB%sDEF.GH*'
+ directory_path = files[0].name[:files[0].name.find(cases[0])]
+ pattern = directory_path + 'AB%sDEF.GH*'
self.assertEqual(
set(io_ops.matching_files(pattern % 'z').eval()),
@@ -102,6 +102,21 @@ class IoOpsTest(test.TestCase):
set(io_ops.matching_files(pattern % '[0-9]').eval()),
self._subset(files, [3, 4]))
+ # Test an empty list input.
+ self.assertItemsEqual(io_ops.matching_files([]).eval(), [])
+
+ # Test multiple exact filenames.
+ self.assertItemsEqual(
+ io_ops.matching_files([
+ files[0].name, files[1].name, files[2].name]).eval(),
+ self._subset(files, [0, 1, 2]))
+
+ # Test multiple globs.
+ self.assertItemsEqual(
+ io_ops.matching_files([
+ pattern % '?', directory_path + 'X?Z*']).eval(),
+ self._subset(files, [0, 1, 3, 4, 6]))
+
for f in files:
f.close()
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 557e96cf5d..f535c692a6 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -56,11 +56,11 @@ def match_filenames_once(pattern, name=None):
"""Save the list of files matching pattern, so it is only computed once.
Args:
- pattern: A file pattern (glob).
+ pattern: A file pattern (glob), or 1D tensor of file patterns.
name: A name for the operations (optional).
Returns:
- A variable that is initialized to the list of files matching pattern.
+ A variable that is initialized to the list of files matching the pattern(s).
"""
with ops.name_scope(name, "matching_filenames", [pattern]) as name:
return variables.Variable(io_ops.matching_files(pattern), trainable=False,