diff options
author | Gunhan Gulsoy <gunan@google.com> | 2017-01-25 17:32:49 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-25 17:50:01 -0800 |
commit | 30eab385e0c6b08b7afdb3793c7d405f848fc753 (patch) | |
tree | 9fcc35dfb465a0b58545f800398023eeed52b5e9 /tensorflow/core/kernels/matching_files_op.cc | |
parent | de5073445899fbee08dd48ea45fae78fb81066ef (diff) |
Automated rollback of change 145580313
Change: 145619741
Diffstat (limited to 'tensorflow/core/kernels/matching_files_op.cc')
-rw-r--r-- | tensorflow/core/kernels/matching_files_op.cc | 47 |
1 files changed, 16 insertions, 31 deletions
diff --git a/tensorflow/core/kernels/matching_files_op.cc b/tensorflow/core/kernels/matching_files_op.cc index 5eb060f664..a35b5889d3 100644 --- a/tensorflow/core/kernels/matching_files_op.cc +++ b/tensorflow/core/kernels/matching_files_op.cc @@ -29,37 +29,22 @@ class MatchingFilesOp : public OpKernel { public: using OpKernel::OpKernel; void Compute(OpKernelContext* context) override { - const Tensor* patterns_t; - // NOTE(ringwalt): Changing the input name "pattern" to "patterns" would - // break existing graphs. - OP_REQUIRES_OK(context, context->input("pattern", &patterns_t)); - OP_REQUIRES( - context, - TensorShapeUtils::IsScalar(patterns_t->shape()) || - TensorShapeUtils::IsVector(patterns_t->shape()), - errors::InvalidArgument( - "Input patterns tensor must be scalar or vector, but had shape: ", - patterns_t->shape().DebugString())); - const auto patterns = patterns_t->flat<string>(); - int num_patterns = patterns.size(); - int num_files = 0; - std::vector<std::vector<string>> all_fnames(num_patterns); - for (int i = 0; i < num_patterns; i++) { - OP_REQUIRES_OK( - context, - context->env()->GetMatchingPaths(patterns(i), &all_fnames[i])); - num_files += all_fnames[i].size(); - } - Tensor* output_t = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output( - "filenames", TensorShape({num_files}), &output_t)); - auto output = output_t->vec<string>(); - int index = 0; - for (int i = 0; i < num_patterns; ++i) { - for (int j = 0; j < all_fnames[i].size(); j++) { - output(index++) = all_fnames[i][j]; - } + const Tensor* pattern; + OP_REQUIRES_OK(context, context->input("pattern", &pattern)); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(pattern->shape()), + errors::InvalidArgument( + "Input pattern tensor must be scalar, but had shape: ", + pattern->shape().DebugString())); + std::vector<string> fnames; + OP_REQUIRES_OK(context, context->env()->GetMatchingPaths( + pattern->scalar<string>()(), &fnames)); + const int num_out = fnames.size(); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + "filenames", TensorShape({num_out}), &output)); + auto output_vec = output->vec<string>(); + for (int i = 0; i < num_out; ++i) { + output_vec(i) = fnames[i]; } } }; |