aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/matching_files_op.cc
diff options
context:
space:
mode:
authorGravatar Gunhan Gulsoy <gunan@google.com>2017-01-25 17:32:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-25 17:50:01 -0800
commit30eab385e0c6b08b7afdb3793c7d405f848fc753 (patch)
tree9fcc35dfb465a0b58545f800398023eeed52b5e9 /tensorflow/core/kernels/matching_files_op.cc
parentde5073445899fbee08dd48ea45fae78fb81066ef (diff)
Automated rollback of change 145580313
Change: 145619741
Diffstat (limited to 'tensorflow/core/kernels/matching_files_op.cc')
-rw-r--r--tensorflow/core/kernels/matching_files_op.cc47
1 files changed, 16 insertions, 31 deletions
diff --git a/tensorflow/core/kernels/matching_files_op.cc b/tensorflow/core/kernels/matching_files_op.cc
index 5eb060f664..a35b5889d3 100644
--- a/tensorflow/core/kernels/matching_files_op.cc
+++ b/tensorflow/core/kernels/matching_files_op.cc
@@ -29,37 +29,22 @@ class MatchingFilesOp : public OpKernel {
public:
using OpKernel::OpKernel;
void Compute(OpKernelContext* context) override {
- const Tensor* patterns_t;
- // NOTE(ringwalt): Changing the input name "pattern" to "patterns" would
- // break existing graphs.
- OP_REQUIRES_OK(context, context->input("pattern", &patterns_t));
- OP_REQUIRES(
- context,
- TensorShapeUtils::IsScalar(patterns_t->shape()) ||
- TensorShapeUtils::IsVector(patterns_t->shape()),
- errors::InvalidArgument(
- "Input patterns tensor must be scalar or vector, but had shape: ",
- patterns_t->shape().DebugString()));
- const auto patterns = patterns_t->flat<string>();
- int num_patterns = patterns.size();
- int num_files = 0;
- std::vector<std::vector<string>> all_fnames(num_patterns);
- for (int i = 0; i < num_patterns; i++) {
- OP_REQUIRES_OK(
- context,
- context->env()->GetMatchingPaths(patterns(i), &all_fnames[i]));
- num_files += all_fnames[i].size();
- }
- Tensor* output_t = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(
- "filenames", TensorShape({num_files}), &output_t));
- auto output = output_t->vec<string>();
- int index = 0;
- for (int i = 0; i < num_patterns; ++i) {
- for (int j = 0; j < all_fnames[i].size(); j++) {
- output(index++) = all_fnames[i][j];
- }
+ const Tensor* pattern;
+ OP_REQUIRES_OK(context, context->input("pattern", &pattern));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(pattern->shape()),
+ errors::InvalidArgument(
+ "Input pattern tensor must be scalar, but had shape: ",
+ pattern->shape().DebugString()));
+ std::vector<string> fnames;
+ OP_REQUIRES_OK(context, context->env()->GetMatchingPaths(
+ pattern->scalar<string>()(), &fnames));
+ const int num_out = fnames.size();
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "filenames", TensorShape({num_out}), &output));
+ auto output_vec = output->vec<string>();
+ for (int i = 0; i < num_out; ++i) {
+ output_vec(i) = fnames[i];
}
}
};