aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/string_format_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-19 15:40:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 15:45:12 -0700
commitc96841dbd199d3c1a15a89e8c44c7c1d164968b9 (patch)
tree392511a31e3565f862c7257da6a35c31b6587968 /tensorflow/core/kernels/string_format_op.cc
parentb2b98a5ad1b647b77cb42761671cd9b3cf0e87b6 (diff)
This CL adds a new `tf.print` operator that more closely aligns with the standard python `print` method, and deprecates the old `tf.Print` operator (to be removed in in v2.0).
It follows the design doc specified in https://github.com/tensorflow/community/pull/14 and additionally incorporates the community feedback and design review decisions. This CL adds two new internal graph operators: a StringFormat operator that formats a template string with a list of input tensors to insert into the string and outputs a string scalar containing the result, and a PrintV2 operator that prints a string scalar to a specified output stream or logging level. The formatting op is exposed at `tf.strings.Format`. A new python method is exposed at `tf.print` that takes a list of inputs that may be nested structures and may contain tensors, formats them nicely using the formatting op, and returns a PrintV2 operator that prints them. In Eager mode and inside defuns this PrintV2 operator will automatically be executed, but in graph mode it will need to be either added to `sess.run`, or used as a control dependency for other operators being executed. As compared to the previous print function, the new print function: - Has an API that more closely aligns with the standard python3 print - Supports changing the print logging level/output stream - allows printing arbitrary (optionally nested) data structures as opposed to just flat lists of tensors - support printing sparse tensors - changes printed tensor format to show more meaningful summary (recursively print the first and last elements of each tensor dimension, instead of just the first few elements of the tensor irregardless of dimension). PiperOrigin-RevId: 213709924
Diffstat (limited to 'tensorflow/core/kernels/string_format_op.cc')
-rw-r--r--tensorflow/core/kernels/string_format_op.cc65
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/string_format_op.cc b/tensorflow/core/kernels/string_format_op.cc
new file mode 100644
index 0000000000..e4a1887f8d
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class StringFormatOp : public OpKernel {
+ public:
+ explicit StringFormatOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string template_;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("template", &template_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("placeholder", &placeholder_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
+
+ split_template_ = absl::StrSplit(template_, placeholder_);
+ int64 num_placeholders = split_template_.size() - 1;
+ OP_REQUIRES(ctx, ctx->num_inputs() == num_placeholders,
+ errors::InvalidArgument(strings::StrCat(
+ "num placeholders in template and num inputs must match: ",
+ num_placeholders, " vs. ", ctx->num_inputs())));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor* formatted_string = nullptr;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &formatted_string));
+
+ string msg;
+ strings::StrAppend(&msg, split_template_[0].c_str());
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ strings::StrAppend(&msg, ctx->input(i).SummarizeValue(summarize_, true));
+ strings::StrAppend(&msg, split_template_[i + 1].c_str());
+ }
+
+ formatted_string->scalar<string>()() = msg;
+ }
+
+ private:
+ int32 summarize_ = 0;
+ string placeholder_;
+ std::vector<std::string> split_template_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringFormat").Device(DEVICE_CPU),
+ StringFormatOp);
+
+} // end namespace tensorflow