aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/BUILD8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt19
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt38
-rw-r--r--tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt4
-rw-r--r--tensorflow/core/framework/tensor.cc112
-rw-r--r--tensorflow/core/framework/tensor.h2
-rw-r--r--tensorflow/core/framework/tensor_test.cc57
-rw-r--r--tensorflow/core/kernels/BUILD27
-rw-r--r--tensorflow/core/kernels/logging_ops.cc57
-rw-r--r--tensorflow/core/kernels/logging_ops_test.cc22
-rw-r--r--tensorflow/core/kernels/string_format_op.cc65
-rw-r--r--tensorflow/core/kernels/string_format_op_test.cc66
-rw-r--r--tensorflow/core/ops/logging_ops.cc19
-rw-r--r--tensorflow/core/ops/string_ops.cc27
-rw-r--r--tensorflow/python/BUILD2
-rw-r--r--tensorflow/python/framework/test_util.py60
-rw-r--r--tensorflow/python/kernel_tests/BUILD13
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py313
-rw-r--r--tensorflow/python/kernel_tests/string_format_op_test.py384
-rw-r--r--tensorflow/python/ops/logging_ops.py260
-rw-r--r--tensorflow/python/ops/string_ops.py84
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt4
26 files changed, 1635 insertions, 28 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 4b2589aaeb..e82dd13b31 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1067,7 +1067,6 @@ tf_gen_op_libs(
"spectral_ops",
"state_ops",
"stateless_random_ops",
- "string_ops",
"summary_ops",
"training_ops",
],
@@ -1075,6 +1074,13 @@ tf_gen_op_libs(
tf_gen_op_libs(
op_lib_names = [
+ "string_ops",
+ ],
+ deps = ["@com_google_absl//absl/strings"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
"array_ops",
],
deps = [":protos_all_cc"],
diff --git a/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000000..4cb8955dcb
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,19 @@
+op {
+ graph_op_name: "PrintV2"
+ in_arg {
+ name: "input"
+ description: <<END
+The string scalar to print.
+END
+ }
+ attr {
+ name: "output_stream"
+ description: <<END
+A string specifying the output stream or logging level to print to.
+END
+ }
+ summary: "Prints a string scalar."
+ description: <<END
+Prints a string scalar to the desired output_stream.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000000..a82dae9e48
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,38 @@
+op {
+ graph_op_name: "StringFormat"
+ in_arg {
+ name: "inputs"
+ description: <<END
+The list of tensors to format into the placeholder string.
+END
+ }
+
+ out_arg {
+ name: "output"
+ description: <<END
+= The resulting string scalar.
+END
+ }
+ attr {
+ name: "template"
+ description: <<END
+A string, the template to format tensor summaries into.
+END
+ }
+ attr {
+ name: "placeholder"
+ description: <<END
+A string, at each placeholder in the template a subsequent tensor summary will be inserted.
+END
+ }
+ attr {
+ name: "summarize"
+ description: <<END
+When formatting the tensor summaries print the first and last summarize entries of each tensor dimension.
+END
+ }
+ summary: "Formats a string template using a list of tensors."
+ description: <<END
+Formats a string template using a list of tensors, pretty-printing tensor summaries.
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000000..e22d980424
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "PrintV2"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000000..8f0b1db45d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StringFormat"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 516afa517d..eb9c79ff2d 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -948,9 +948,69 @@ void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
}
}
+// Appends the spacing between elements for a given dim onto a result string
+void PrintDimSpacing(int dim_index, int num_dims, string* result) {
+ if (dim_index == num_dims - 1) {
+ strings::StrAppend(result, " ");
+ return;
+ }
+ for (int j = 0; j < num_dims - dim_index - 1; j++) {
+ strings::StrAppend(result, "\n");
+ }
+ for (int j = 0; j <= dim_index; j++) {
+ strings::StrAppend(result, " ");
+ }
+}
+
+// Print from left dim to right dim recursively.
+template <typename T>
+void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
+ int64 num_elts_at_ends, int num_dims, const T* data,
+ int64 data_index, string* result) {
+ // We have recursed beyond all the dimensions into a single element
+ // of the tensor.
+ if (dim_index == num_dims) {
+ strings::StrAppend(result, PrintOneElement(data[data_index]));
+ return;
+ }
+
+ strings::StrAppend(result, "[");
+ int64 element_count = shape[dim_index];
+ int64 start_of_end =
+ std::max(num_elts_at_ends, element_count - num_elts_at_ends);
+
+ // Loop every element of one dim.
+ int64 elements_per_iter = 1;
+ for (int i = dim_index + 1; i < num_dims; i++) {
+ elements_per_iter *= shape[i];
+ }
+ for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
+ if (i > 0) {
+ PrintDimSpacing(dim_index, num_dims, result);
+ }
+
+ // As for each element, print the sub-dim.
+ PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+ data_index + elements_per_iter * i, result);
+ }
+ if (element_count > 2 * num_elts_at_ends) {
+ PrintDimSpacing(dim_index, num_dims, result);
+ strings::StrAppend(result, "...");
+ }
+ for (int64 i = start_of_end; i < element_count; i++) {
+ // As for each element, print the sub-dim.
+ PrintDimSpacing(dim_index, num_dims, result);
+ PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+ data_index + elements_per_iter * i, result);
+ }
+
+ strings::StrAppend(result, "]");
+}
+
template <typename T>
string SummarizeArray(int64 limit, int64 num_elts,
- const TensorShape& tensor_shape, const char* data) {
+ const TensorShape& tensor_shape, const char* data,
+ const bool print_v2) {
string ret;
const T* array = reinterpret_cast<const T*>(data);
@@ -963,17 +1023,26 @@ string SummarizeArray(int64 limit, int64 num_elts,
if (num_elts > limit) strings::StrAppend(&ret, "...");
return ret;
}
- int64 data_index = 0;
- const int shape_size = tensor_shape.dims();
- PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
+ if (print_v2) {
+ const int num_dims = tensor_shape.dims();
+ PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
+ } else {
+ int64 data_index = 0;
+ const int shape_size = tensor_shape.dims();
+ PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
+
+ if (num_elts > limit) strings::StrAppend(&ret, "...");
+ }
- if (num_elts > limit) strings::StrAppend(&ret, "...");
return ret;
}
} // namespace
-string Tensor::SummarizeValue(int64 max_entries) const {
+string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
const int64 num_elts = NumElements();
+ if (max_entries < 0) {
+ max_entries = num_elts;
+ }
size_t limit = std::min(max_entries, num_elts);
if ((limit > 0) && (buf_ == nullptr)) {
return strings::StrCat("uninitialized Tensor of ", num_elts,
@@ -982,50 +1051,54 @@ string Tensor::SummarizeValue(int64 max_entries) const {
const char* data = limit > 0 ? tensor_data().data() : nullptr;
switch (dtype()) {
case DT_HALF:
- return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data);
+ return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
+ print_v2);
break;
case DT_FLOAT:
- return SummarizeArray<float>(limit, num_elts, shape_, data);
+ return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
break;
case DT_DOUBLE:
- return SummarizeArray<double>(limit, num_elts, shape_, data);
+ return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT32:
- return SummarizeArray<uint32>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT32:
- return SummarizeArray<int32>(limit, num_elts, shape_, data);
+ return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT8:
case DT_QUINT8:
- return SummarizeArray<uint8>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT16:
case DT_QUINT16:
- return SummarizeArray<uint16>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT16:
case DT_QINT16:
- return SummarizeArray<int16>(limit, num_elts, shape_, data);
+ return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT8:
case DT_QINT8:
- return SummarizeArray<int8>(limit, num_elts, shape_, data);
+ return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT64:
- return SummarizeArray<uint64>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT64:
- return SummarizeArray<int64>(limit, num_elts, shape_, data);
+ return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
break;
case DT_BOOL:
// TODO(tucker): Is it better to emit "True False..."? This
// will emit "1 0..." which is more compact.
- return SummarizeArray<bool>(limit, num_elts, shape_, data);
+ return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
break;
default: {
// All irregular cases
string ret;
+ if (print_v2) {
+ strings::StrAppend(&ret, "[");
+ }
// TODO(irving): Don't call flat every time around this
// loop.
for (size_t i = 0; i < limit; ++i) {
@@ -1045,6 +1118,9 @@ string Tensor::SummarizeValue(int64 max_entries) const {
}
}
if (max_entries < num_elts) strings::StrAppend(&ret, "...");
+ if (print_v2) {
+ strings::StrAppend(&ret, "]");
+ }
return ret;
}
}
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 696fd277cd..5f5d2021a4 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -430,7 +430,7 @@ class Tensor {
int64 begin) const;
/// Render the first `max_entries` values in `*this` into a string.
- string SummarizeValue(int64 max_entries) const;
+ string SummarizeValue(int64 max_entries, bool print_v2 = false) const;
/// A human-readable summary of the tensor suitable for debugging.
string DebugString() const;
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 9a78cdc91e..fc05c86990 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -1295,6 +1295,63 @@ TEST(SummarizeValue, STRING) {
EXPECT_EQ("one two three four five one...", x.SummarizeValue(6));
}
+TEST(SummarizeValue, INT32_PRINT_V2) {
+ Tensor x = MkTensor<int>(DT_INT32, TensorShape({5}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+ EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]",
+ x.SummarizeValue(16, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({0}), {});
+ EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, INT32Dims_PRINT_V2) {
+ Tensor x = MkTensor<int>(DT_INT32, TensorShape({3, 4}),
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ EXPECT_EQ("[[1 ... 4]\n ...\n [9 ... 12]]", x.SummarizeValue(1, true));
+ EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+ x.SummarizeValue(10, true));
+ EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+ x.SummarizeValue(-1, true));
+}
+
+TEST(SummarizeValue, FLOAT_PRINT_V2) {
+ Tensor x = MkTensor<float>(DT_FLOAT, TensorShape({5}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+ EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]",
+ x.SummarizeValue(16, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({0}), {});
+ EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, BOOL_PRINT_V2) {
+ Tensor x = MkTensor<bool>(DT_BOOL, TensorShape({5}), {false, true, true});
+ EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[0 1 ... 0 1]", x.SummarizeValue(2, true));
+}
+
+TEST(SummarizeValue, STRING_PRINT_V2) {
+ Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}),
+ {"one", "two", "three", "four", "five"});
+ EXPECT_EQ("[one two three four five]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[one two three four five]", x.SummarizeValue(-1, true));
+ x = MkTensor<string>(DT_STRING, TensorShape({5, 1, 5}),
+ {"one", "two", "three", "four", "five"});
+ EXPECT_EQ("[one two three four five one...]", x.SummarizeValue(6, true));
+}
+
void BM_CreateAndDestroy(int iters) {
TensorShape shape({10, 20});
while (--iters) {
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 7aa1169061..b0d04a7213 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2707,6 +2707,7 @@ cc_library(
)
LOGGING_DEPS = [
+ "@com_google_absl//absl/strings",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -2764,6 +2765,7 @@ tf_cc_tests(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
],
)
@@ -4401,6 +4403,7 @@ cc_library(
":reduce_join_op",
":regex_full_match_op",
":regex_replace_op",
+ ":string_format_op",
":string_join_op",
":string_length_op",
":string_split_op",
@@ -4432,6 +4435,30 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "string_format_op",
+ prefix = "string_format_op",
+ deps = STRING_DEPS + ["@com_google_absl//absl/strings"],
+)
+
+tf_cc_test(
+ name = "string_format_op_test",
+ size = "small",
+ srcs = ["string_format_op_test.cc"],
+ deps = [
+ ":string_format_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "string_join_op",
prefix = "string_join_op",
deps = STRING_DEPS,
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index 6b6a14e9a7..8bafd5739d 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <iostream>
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -74,8 +75,7 @@ class PrintOp : public OpKernel {
string msg;
strings::StrAppend(&msg, message_);
for (int i = 1; i < ctx->num_inputs(); ++i) {
- strings::StrAppend(&msg, "[", ctx->input(i).SummarizeValue(summarize_),
- "]");
+ strings::StrAppend(&msg, ctx->input(i).SummarizeValue(summarize_));
}
std::cerr << msg << std::endl;
}
@@ -90,6 +90,59 @@ class PrintOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp);
+class PrintV2Op : public OpKernel {
+ public:
+ explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_));
+
+ auto output_stream_index =
+ std::find(std::begin(valid_output_streams_),
+ std::end(valid_output_streams_), output_stream_);
+
+ if (output_stream_index == std::end(valid_output_streams_)) {
+ string error_msg = strings::StrCat(
+ "Unknown output stream: ", output_stream_, ", Valid streams are:");
+ for (auto valid_stream : valid_output_streams_) {
+ strings::StrAppend(&error_msg, " ", valid_stream);
+ }
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
+ const string& msg = input_->scalar<string>()();
+
+ if (output_stream_ == "stdout") {
+ std::cout << msg << std::endl;
+ } else if (output_stream_ == "stderr") {
+ std::cerr << msg << std::endl;
+ } else if (output_stream_ == "log(info)") {
+ LOG(INFO) << msg << std::endl;
+ } else if (output_stream_ == "log(warning)") {
+ LOG(WARNING) << msg << std::endl;
+ } else if (output_stream_ == "log(error)") {
+ LOG(ERROR) << msg << std::endl;
+ } else {
+ string error_msg = strings::StrCat(
+ "Unknown output stream: ", output_stream_, ", Valid streams are:");
+ for (auto valid_stream : valid_output_streams_) {
+ strings::StrAppend(&error_msg, " ", valid_stream);
+ }
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+ }
+ }
+
+ const char* valid_output_streams_[6] = {"stdout", "stderr", "log(info)",
+ "log(warning)", "log(error)"};
+
+ private:
+ string output_stream_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("PrintV2").Device(DEVICE_CPU), PrintV2Op);
+
class TimestampOp : public OpKernel {
public:
explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {}
diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc
index 5e6958f364..a259d995fa 100644
--- a/tensorflow/core/kernels/logging_ops_test.cc
+++ b/tensorflow/core/kernels/logging_ops_test.cc
@@ -23,11 +23,33 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
+class PrintingV2GraphTest : public OpsTestBase {
+ protected:
+ Status Init(const string& output_stream = "log(warning)") {
+ TF_CHECK_OK(NodeDefBuilder("op", "PrintV2")
+ .Input(FakeInput(DT_STRING))
+ .Attr("output_stream", output_stream)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(PrintingV2GraphTest, StringSuccess) {
+ TF_ASSERT_OK(Init());
+ AddInputFromArray<string>(TensorShape({}), {"bar"});
+ TF_ASSERT_OK(RunOpKernel());
+}
+
+TEST_F(PrintingV2GraphTest, InvalidOutputStream) {
+ ASSERT_NE(::tensorflow::Status::OK(), (Init("invalid_output_stream")));
+}
+
class PrintingGraphTest : public OpsTestBase {
protected:
Status Init(DataType input_type1, DataType input_type2, string msg = "",
diff --git a/tensorflow/core/kernels/string_format_op.cc b/tensorflow/core/kernels/string_format_op.cc
new file mode 100644
index 0000000000..e4a1887f8d
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class StringFormatOp : public OpKernel {
+ public:
+ explicit StringFormatOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string template_;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("template", &template_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("placeholder", &placeholder_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
+
+ split_template_ = absl::StrSplit(template_, placeholder_);
+ int64 num_placeholders = split_template_.size() - 1;
+ OP_REQUIRES(ctx, ctx->num_inputs() == num_placeholders,
+ errors::InvalidArgument(strings::StrCat(
+ "num placeholders in template and num inputs must match: ",
+ num_placeholders, " vs. ", ctx->num_inputs())));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor* formatted_string = nullptr;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &formatted_string));
+
+ string msg;
+ strings::StrAppend(&msg, split_template_[0].c_str());
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ strings::StrAppend(&msg, ctx->input(i).SummarizeValue(summarize_, true));
+ strings::StrAppend(&msg, split_template_[i + 1].c_str());
+ }
+
+ formatted_string->scalar<string>()() = msg;
+ }
+
+ private:
+ int32 summarize_ = 0;
+ string placeholder_;
+ std::vector<std::string> split_template_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringFormat").Device(DEVICE_CPU),
+ StringFormatOp);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/string_format_op_test.cc b/tensorflow/core/kernels/string_format_op_test.cc
new file mode 100644
index 0000000000..13130a5797
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+class StringFormatGraphTest : public OpsTestBase {
+ protected:
+ Status Init(int num_inputs, DataType input_type,
+ const string& template_ = "%s", const string& placeholder = "%s",
+ int summarize = 3) {
+ TF_CHECK_OK(NodeDefBuilder("op", "StringFormat")
+ .Input(FakeInput(num_inputs, input_type))
+ .Attr("template", template_)
+ .Attr("placeholder", placeholder)
+ .Attr("summarize", summarize)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(StringFormatGraphTest, Int32Success_7) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s"));
+
+ AddInputFromArray<int32>(TensorShape({7}), {1, 2, 3, 4, 5, 6, 7});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [1 2 3 ... 5 6 7]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(StringFormatGraphTest, Int32Success_3_3) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s", "%s", 1));
+
+ AddInputFromArray<int32>(TensorShape({3, 3}), {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [[1 ... 3]\n ..."
+ "\n [7 ... 9]]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+} // end namespace
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index 639d211767..2034d3601b 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -20,6 +20,8 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::InferenceContext;
+
REGISTER_OP("Assert")
.Input("condition: bool")
.Input("data: T")
@@ -44,6 +46,23 @@ REGISTER_OP("Print")
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Print");
+REGISTER_OP("PrintV2")
+ .Input("input: string")
+ .SetIsStateful()
+ .Attr(
+ "output_stream: {'stdout', 'stderr', 'log(info)', "
+ "'log(warning)', 'log(error)'} = 'stderr'")
+ .SetShapeFn([](InferenceContext* c) {
+ // Make sure that the input is a scalar.
+ if (c->Rank(c->input(0)) != 0) {
+ return errors::InvalidArgument("input must be a scalar, but has rank: ",
+ c->Rank(c->input(0)));
+ }
+ return Status::OK();
+ });
+
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("PrintV2");
+
// ----------------------------------------------------------------------------
// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
// inputs or outputs in various ways.
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index ef8b15dc8a..99159839d0 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -102,6 +103,32 @@ REGISTER_OP("AsString")
.Attr("fill: string = ''")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("StringFormat")
+ .Input("inputs: T")
+ .Output("output: string")
+ .Attr("T: list(type) >= 0")
+ .Attr("template: string = '%s'")
+ .Attr("placeholder: string = '%s'")
+ .Attr("summarize: int = 3")
+ .SetShapeFn([](InferenceContext* c) {
+ string template_;
+ string placeholder;
+ TF_RETURN_IF_ERROR(c->GetAttr("template", &template_));
+ TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder));
+
+ std::vector<std::string> split_template;
+ split_template = absl::StrSplit(template_, placeholder);
+ int64 num_placeholders = split_template.size() - 1;
+ if (c->num_inputs() != num_placeholders) {
+ return errors::InvalidArgument(strings::StrCat(
+ "num placeholders in template and num inputs must match: ",
+ num_placeholders, " vs. ", c->num_inputs()));
+ }
+
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
REGISTER_OP("StringJoin")
.Input("inputs: N * string")
.Attr("N: int")
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index d70e9c5798..9730e9933a 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2324,6 +2324,8 @@ py_library(
deps = [
":framework_for_generated_wrappers",
":logging_ops_gen",
+ ":platform",
+ ":string_ops",
":util",
],
)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index b7398238f5..c302072aa1 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -24,6 +24,7 @@ from collections import OrderedDict
import contextlib
import gc
import itertools
+import os
import math
import random
import re
@@ -868,6 +869,19 @@ def device(use_gpu):
yield
+class CapturedWrites(object):
+ """A utility class to load the captured writes made to a stream."""
+
+ def __init__(self, capture_location):
+ self.capture_location = capture_location
+
+ def contents(self):
+ """Get the captured writes as a single string."""
+ with open(self.capture_location) as tmp_file:
+ output_data = "".join(tmp_file.readlines())
+ return output_data
+
+
class ErrorLoggingSession(session.Session):
"""Wrapper around a Session that logs errors in run().
"""
@@ -934,6 +948,52 @@ class TensorFlowTestCase(googletest.TestCase):
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
+ @contextlib.contextmanager
+ def captureWritesToStream(self, stream):
+ """A context manager that captures the writes to a given stream.
+
+ This context manager captures all writes to a given stream inside of a
+ `CapturedWrites` object. When this context manager is created, it yields
+ the `CapturedWrites` object. The captured contents can be accessed by
+ calling `.contents()` on the `CapturedWrites`.
+
+ For this function to work, the stream must have a file descriptor that
+ can be modified using `os.dup` and `os.dup2`, and the stream must support
+ a `.flush()` method. The default python sys.stdout and sys.stderr are
+ examples of this. Note that this does not work in Colab or Jupyter
+ notebooks, because those use alternate stdout streams.
+
+ Example:
+ ```python
+ class MyOperatorTest(test_util.TensorFlowTestCase):
+ def testMyOperator(self):
+ input = [1.0, 2.0, 3.0, 4.0, 5.0]
+ with self.captureWritesToStream(sys.stdout) as captured:
+ result = MyOperator(input).eval()
+ self.assertStartsWith(captured.contents(), "This was printed.")
+ ```
+
+ Args:
+ stream: The stream whose writes should be captured. This
+ stream must have a file descriptor, support writing via using that
+ file descriptor, and must have a `.flush()` method.
+
+ Yields:
+ A `CapturedWrites` object that contains all writes to the specified stream
+ made during this context.
+ """
+ stream.flush()
+ fd = stream.fileno()
+ tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir())
+ tmp_file = open(tmp_file_path, "w")
+ orig_fd = os.dup(fd)
+ os.dup2(tmp_file.fileno(), fd)
+ try:
+ yield CapturedWrites(tmp_file_path)
+ finally:
+ tmp_file.close()
+ os.dup2(orig_fd, fd)
+
def _AssertProtoEquals(self, a, b, msg=None):
"""Asserts that a and b are the same proto.
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index a048eaa69f..9dc6df77f1 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -961,6 +961,19 @@ tf_py_test(
)
tf_py_test(
+ name = "string_format_op_test",
+ size = "small",
+ srcs = ["string_format_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:math_ops",
+ ],
+)
+
+tf_py_test(
name = "string_join_op_test",
size = "small",
srcs = ["string_join_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index 82729b9e27..79fe9de62f 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -18,14 +18,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
class LoggingOpsTest(test.TestCase):
@@ -57,6 +66,305 @@ class LoggingOpsTest(test.TestCase):
out.eval()
+class PrintV2Test(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensor(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorVarySummarize(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=1)
+ self.evaluate(print_op)
+
+ expected = "[0 ... 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=2)
+ self.evaluate(print_op)
+
+ expected = "[0 1 ... 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=3)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=-1)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 3 4 5 6 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneVariable(self):
+ with self.test_session():
+ var = variables.Variable(math_ops.range(10))
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(var)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintTwoVariablesInStructWithAssignAdd(self):
+ with self.test_session():
+ var_one = variables.Variable(2.14)
+ plus_one = var_one.assign_add(1.0)
+ var_two = variables.Variable(math_ops.range(10))
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ with self.captureWritesToStream(sys.stderr) as printed:
+ self.evaluate(plus_one)
+ print_op = logging_ops.print_v2(var_one, {"second": var_two})
+ self.evaluate(print_op)
+ expected = "3.14 {'second': [0 1 2 ... 7 8 9]}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintTwoTensors(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, tensor * 10)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9] [0 10 20 ... 70 80 90]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintPlaceholderGeneration(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
+ self.evaluate(print_op)
+ expected = "{}6 {'{}': [0 10 20 ... 70 80 90]}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintNoTensors(self):
+ with self.test_session():
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
+ self.evaluate(print_op)
+ expected = "23 [23, 5] {'6': 12}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintFloatScalar(self):
+ with self.test_session():
+ tensor = ops.convert_to_tensor(434.43)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+ expected = "434.43"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintStringScalar(self):
+ with self.test_session():
+ tensor = ops.convert_to_tensor("scalar")
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+ expected = "scalar"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintComplexTensorStruct(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ small_tensor = constant_op.constant([0.3, 12.4, -16.1])
+ big_tensor = math_ops.mul(tensor, 10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ "first:", tensor, "middle:",
+ {"small": small_tensor, "Big": big_tensor}, 10,
+ [tensor * 2, tensor])
+ self.evaluate(print_op)
+ # Note that the keys in the dict will always be sorted,
+ # so 'Big' comes before 'small'
+ expected = ("first: [0 1 2 ... 7 8 9] "
+ "middle: {'Big': [0 10 20 ... 70 80 90], "
+ "'small': [0.3 12.4 -16.1]} "
+ "10 [[0 2 4 ... 14 16 18], [0 1 2 ... 7 8 9]]")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintSparseTensor(self):
+ with self.test_session():
+ ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
+ val = [0, 10, 13, 4, 14, 32, 33]
+ shape = [5, 6]
+
+ sparse = sparse_tensor.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int64),
+ constant_op.constant(shape, dtypes.int64))
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(sparse)
+ self.evaluate(print_op)
+ expected = ("'SparseTensor(indices=[[0 0]\n"
+ " [1 0]\n"
+ " [1 3]\n"
+ " ...\n"
+ " [1 4]\n"
+ " [3 2]\n"
+ " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])'")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintSparseTensorInDataStruct(self):
+ with self.test_session():
+ ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
+ val = [0, 10, 13, 4, 14, 32, 33]
+ shape = [5, 6]
+
+ sparse = sparse_tensor.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int64),
+ constant_op.constant(shape, dtypes.int64))
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2([sparse])
+ self.evaluate(print_op)
+ expected = ("['SparseTensor(indices=[[0 0]\n"
+ " [1 0]\n"
+ " [1 3]\n"
+ " ...\n"
+ " [1 4]\n"
+ " [3 2]\n"
+ " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])']")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorStdout(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stdout) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=sys.stdout)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogInfo(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.info)
+ self.evaluate(print_op)
+ self.assertTrue("I" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogWarning(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.warning)
+ self.evaluate(print_op)
+ self.assertTrue("W" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogError(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.error)
+ self.evaluate(print_op)
+ self.assertTrue("E" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testInvalidOutputStreamRaisesError(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ with self.assertRaises(ValueError):
+ print_op = logging_ops.print_v2(
+ tensor, output_stream="unknown")
+ self.evaluate(print_op)
+
+ def testPrintOpName(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ print_op = logging_ops.print_v2(tensor, name="print_name")
+ self.assertEqual(print_op.name, "print_name")
+
+ def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ formatted_string = string_ops.string_format("{}", tensor)
+ print_op = logging_ops.print_v2(formatted_string)
+ self.evaluate(print_op)
+ graph_ops = ops.get_default_graph().get_operations()
+ format_ops = [op for op in graph_ops if op.type == "StringFormat"]
+ # Should be only 1 format_op for graph mode.
+ self.assertEqual(len(format_ops), 1)
+
+ def testPrintOneTensorEagerOnOpCreate(self):
+ with self.test_session():
+ with context.eager_mode():
+ tensor = math_ops.range(10)
+ expected = "[0 1 2 ... 7 8 9]"
+ with self.captureWritesToStream(sys.stderr) as printed:
+ logging_ops.print_v2(tensor)
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintInDefunWithoutExplicitEvalOfPrint(self):
+ @function.defun
+ def f():
+ tensor = math_ops.range(10)
+ logging_ops.print_v2(tensor)
+ return tensor
+
+ expected = "[0 1 2 ... 7 8 9]"
+ with self.captureWritesToStream(sys.stderr) as printed_one:
+ x = f()
+ self.evaluate(x)
+ self.assertTrue((expected + "\n") in printed_one.contents())
+
+ # We execute the function again to make sure it doesn't only print on the
+ # first call.
+ with self.captureWritesToStream(sys.stderr) as printed_two:
+ y = f()
+ self.evaluate(y)
+ self.assertTrue((expected + "\n") in printed_two.contents())
+
+
class PrintGradientTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
@@ -65,6 +373,11 @@ class PrintGradientTest(test.TestCase):
inp_printed = logging_ops.Print(inp, [inp])
self.assertEqual(inp.get_shape(), inp_printed.get_shape())
+ def testPrintString(self):
+ inp = constant_op.constant(2.0, shape=[100, 32])
+ inp_printed = logging_ops.Print(inp, ["hello"])
+ self.assertEqual(inp.get_shape(), inp_printed.get_shape())
+
def testPrintGradient(self):
with self.cached_session():
inp = constant_op.constant(2.0, shape=[100, 32], name="in")
diff --git a/tensorflow/python/kernel_tests/string_format_op_test.py b/tensorflow/python/kernel_tests/string_format_op_test.py
new file mode 100644
index 0000000000..afa71db909
--- /dev/null
+++ b/tensorflow/python/kernel_tests/string_format_op_test.py
@@ -0,0 +1,384 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.kernels.logging_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class StringFormatOpTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDim(self):
+ with self.test_session():
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.test_session():
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", [tensor])
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneVariableScalar(self):
+ with self.test_session():
+ var = variables.Variable(3.34)
+ format_output = string_ops.string_format("{}", [var])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ out = self.evaluate(format_output)
+ expected = "3.34"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneVariableOneDim(self):
+ with self.test_session():
+ var = variables.Variable(math_ops.range(10))
+ format_output = string_ops.string_format("{}", [var])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatTwoVariablesWithAssignAdd(self):
+ with self.test_session():
+ var_one = variables.Variable(2.14)
+ plus_one = var_one.assign_add(1.0)
+ var_two = variables.Variable(math_ops.range(10))
+ format_output = string_ops.string_format("{}, {}", [var_one, var_two])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ self.evaluate(plus_one)
+ out = self.evaluate(format_output)
+ expected = "3.14, [0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimFloat(self):
+ with self.test_session():
+ tensor = constant_op.constant([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = "[0 0.1 0.2 ... 0.5 0.6 0.7]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimMatchesSummarize(self):
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimVarySummarize(self):
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=-1)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=1)
+ out = self.evaluate(format_output)
+ expected = "[0 ... 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=2)
+ out = self.evaluate(format_output)
+ expected = "[0 1 ... 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.test_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=10)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimAlmostSummarize(self):
+ with self.test_session():
+ tensor = math_ops.range(5)
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDimLessThanSummarize(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(4), [2, 2])
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1]\n"
+ " [2 3]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDim(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDimSummarizeTwo(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}", tensor, summarize=2)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 ... 8 9]\n"
+ " [10 11 ... 18 19]\n"
+ " ...\n"
+ " [80 81 ... 88 89]\n"
+ " [90 91 ... 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorThreeDim(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(1000), [10, 10, 10])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]\n"
+ "\n"
+ " [[100 101 102 ... 107 108 109]\n"
+ " [110 111 112 ... 117 118 119]\n"
+ " [120 121 122 ... 127 128 129]\n"
+ " ...\n [170 171 172 ... 177 178 179]\n"
+ " [180 181 182 ... 187 188 189]\n"
+ " [190 191 192 ... 197 198 199]]\n"
+ "\n"
+ " [[200 201 202 ... 207 208 209]\n"
+ " [210 211 212 ... 217 218 219]\n"
+ " [220 221 222 ... 227 228 229]\n"
+ " ...\n"
+ " [270 271 272 ... 277 278 279]\n"
+ " [280 281 282 ... 287 288 289]\n"
+ " [290 291 292 ... 297 298 299]]\n"
+ "\n"
+ " ...\n"
+ "\n"
+ " [[700 701 702 ... 707 708 709]\n"
+ " [710 711 712 ... 717 718 719]\n"
+ " [720 721 722 ... 727 728 729]\n"
+ " ...\n"
+ " [770 771 772 ... 777 778 779]\n"
+ " [780 781 782 ... 787 788 789]\n"
+ " [790 791 792 ... 797 798 799]]\n"
+ "\n"
+ " [[800 801 802 ... 807 808 809]\n"
+ " [810 811 812 ... 817 818 819]\n"
+ " [820 821 822 ... 827 828 829]\n"
+ " ...\n"
+ " [870 871 872 ... 877 878 879]\n"
+ " [880 881 882 ... 887 888 889]\n"
+ " [890 891 892 ... 897 898 899]]\n"
+ "\n"
+ " [[900 901 902 ... 907 908 909]\n"
+ " [910 911 912 ... 917 918 919]\n"
+ " [920 921 922 ... 927 928 929]\n"
+ " ...\n"
+ " [970 971 972 ... 977 978 979]\n"
+ " [980 981 982 ... 987 988 989]\n"
+ " [990 991 992 ... 997 998 999]]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplatePrefix(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplatePrefixAndSuffix(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}, suffix",
+ tensor)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], suffix")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplateSuffix(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}, suffix", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], suffix")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatNoTensor(self):
+ with self.test_session():
+ format_output = string_ops.string_format("No tensor.", ())
+ out = self.evaluate(format_output)
+ expected = "No tensor."
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatMultiTensor(self):
+ with self.test_session():
+ tensor_one = array_ops.reshape(math_ops.range(100), [10, 10])
+ tensor_two = tensor_one * 10
+ format_output = string_ops.string_format("One: {},\nTwo: {}",
+ (tensor_one, tensor_two))
+ out = self.evaluate(format_output)
+ expected = ("One: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]],\n"
+ "Two: [[0 10 20 ... 70 80 90]\n"
+ " [100 110 120 ... 170 180 190]\n"
+ " [200 210 220 ... 270 280 290]\n"
+ " ...\n"
+ " [700 710 720 ... 770 780 790]\n"
+ " [800 810 820 ... 870 880 890]\n"
+ " [900 910 920 ... 970 980 990]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatSummarizeOne(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor,
+ summarize=1)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 ... 9]\n"
+ " ...\n"
+ " [90 ... 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatSummarizeTwo(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor,
+ summarize=2)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 ... 8 9]\n"
+ " [10 11 ... 18 19]\n"
+ " ...\n"
+ " [80 81 ... 88 89]\n"
+ " [90 91 ... 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatPlaceholder(self):
+ with self.test_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: %t%", tensor,
+ placeholder="%t%")
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testTensorCountMustMatchPlaceholderCount(self):
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"2 placeholder\(s\) in template does not match 1 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{} {}", tensor)
+ self.evaluate(format_output)
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"2 placeholder\(s\) in template does not match 1 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{} {}", [tensor])
+ self.evaluate(format_output)
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"1 placeholder\(s\) in template does not match 2 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", (tensor, tensor))
+ self.evaluate(format_output)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index df41933f8a..4c53f33af1 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -19,13 +19,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import pprint
+import random
+import sys
+
+import six
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_logging_ops
+from tensorflow.python.ops import string_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_logging_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -40,7 +51,32 @@ from tensorflow.python.util.tf_export import tf_export
# For users with Python 3 or Python 2.7
# with `from __future__ import print_function`, we could also allow lowercase.
# See https://github.com/tensorflow/tensorflow/issues/18053
-@tf_export("Print")
+
+
+# pylint: disable=invalid-name
+@deprecated("2018-08-20", "Use tf.print instead of tf.Print. Note that "
+ "tf.print returns a no-output operator that directly "
+ "prints the output. Outside of defuns or eager mode, "
+ "this operator will not be executed unless it is "
+ "directly specified in session.run or used as a "
+ "control dependency for other operators. This is "
+ "only a concern in graph mode. Below is an example "
+ "of how to ensure tf.print executes in graph mode:\n"
+ """```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ print_op = tf.print(tensor)
+ with tf.control_dependencies([print_op]):
+ out = tf.add(tensor, tensor)
+ sess.run(out)
+ ```
+Additionally, to use tf.print in python 2.7, users must make sure to import
+the following:
+
+ `from __future__ import print_function`
+""")
+@tf_export(v1=["Print"])
def Print(input_, data, message=None, first_n=None, summarize=None,
name=None):
"""Prints a list of tensors.
@@ -66,6 +102,228 @@ def Print(input_, data, message=None, first_n=None, summarize=None,
A `Tensor`. Has the same type and contents as `input_`.
"""
return gen_logging_ops._print(input_, data, message, first_n, summarize, name)
+# pylint: enable=invalid-name
+
+
+def _generate_placeholder_string(x, default_placeholder="{}"):
+ """Generate and return a string that does not appear in `x`."""
+ placeholder = default_placeholder
+ rng = random.Random(5)
+ while placeholder in x:
+ placeholder = placeholder + str(rng.randint(0, 9))
+ return placeholder
+
+
+# Temporarily disable pylint g-doc-args error to allow giving more context
+# about what the kwargs are.
+# Because we are using arbitrary-length positional arguments, python 2
+# does not support explicitly specifying the keyword arguments in the
+# function definition.
+# pylint: disable=g-doc-args
+@tf_export("print")
+def print_v2(*inputs, **kwargs):
+ """Print the specified inputs.
+
+ Returns an operator that prints the specified inputs to a desired
+ output stream or logging level. The inputs may be dense or sparse Tensors,
+ primitive python objects, data structures that contain Tensors, and printable
+ python objects. Printed tensors will recursively show the first and last
+ `summarize` elements of each dimension.
+
+ With eager execution enabled and/or inside a `tf.contrib.eager.defun` this
+ operator will automatically execute, and users only need to call `tf.print`
+ without using the return value. When constructing graphs outside of a
+ `tf.contrib.eager.defun`, one must either include the returned op
+ in the input to `session.run`, or use the operator as a control dependency for
+ executed ops by specifying `with tf.control_dependencies([print_op])`.
+
+ @compatibility(python2)
+ In python 2.7, make sure to import the following:
+ `from __future__ import print_function`
+ @end_compatibility
+
+ Example:
+ Single-input usage:
+ ```python
+ tf.enable_eager_execution()
+ tensor = tf.range(10)
+ tf.print(tensor, output_stream=sys.stderr)
+ ```
+ (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+ Multi-input usage:
+ ```python
+ tf.enable_eager_execution()
+ tensor = tf.range(10)
+ tf.print("tensors:", tensor, {2: tensor * 2}, output_stream=sys.stdout)
+ ```
+ (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+ sys.stdout)
+
+ Usage in a defun:
+ ```python
+ tf.enable_eager_execution()
+
+ @tf.contrib.eager.defun
+ def f():
+ tensor = tf.range(10)
+ tf.print(tensor, output_stream=sys.stderr)
+ return tensor
+
+ range_tensor = f()
+ ```
+ (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+ Usage when constructing graphs:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ print_op = tf.print("tensors:", tensor, {2: tensor * 2},
+ output_stream=sys.stdout)
+ with tf.control_dependencies([print_op]):
+ tripled_tensor = tensor * 3
+ sess.run(tripled_tensor)
+ ```
+ (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+ sys.stdout)
+
+ Note: This op is only partially compatible with Jupyter notebooks and colabs.
+ Because it prints to the C++ standard out / standard error, this will go
+ in the notebook kernel's console output, not in the notebook cell output.
+
+ Args:
+ *inputs: Positional arguments that are the inputs to print. Inputs in the
+ printed output will be separated by spaces. Inputs may be python
+ primitives, tensors, data structures such as dicts and lists that
+ may contain tensors (with the data structures possibly nested in
+ arbitrary ways), and printable python objects.
+ output_stream: The output stream or logging level to print to. Defaults to
+ sys.stderr, but sys.stdout, tf.logging.info, tf.logging.warning, and
+ tf.logging.error are also supported.
+ summarize: The first and last `summarize` elements within each dimension are
+ recursively printed per Tensor. If None, then the first 3 and last 3
+ elements of each dimension are printed for each tensor. If set to -1, it
+ will print all elements of every tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ A print operator that prints the specified inputs in the specified output
+ stream or logging level.
+
+ Raises:
+ ValueError: If an unsupported output stream is specified.
+ """
+ # Because we are using arbitrary-length positional arguments, python 2
+ # does not support explicitly specifying the keyword arguments in the
+ # function definition. So, we manually get the keyword arguments w/ default
+ # values here.
+ output_stream = kwargs.pop("output_stream", sys.stderr)
+ name = kwargs.pop("name", None)
+ summarize = kwargs.pop("summarize", 3)
+ if kwargs:
+ raise ValueError("Unrecognized keyword arguments for tf.print: %s" % kwargs)
+ format_name = None
+ if name:
+ format_name = name + "_format"
+
+ # Match the C++ string constants representing the different output streams.
+ # Keep this updated!
+ output_stream_to_constant = {
+ sys.stdout: "stdout",
+ sys.stderr: "stderr",
+ tf_logging.INFO: "log(info)",
+ tf_logging.info: "log(info)",
+ tf_logging.WARN: "log(warning)",
+ tf_logging.warning: "log(warning)",
+ tf_logging.warn: "log(warning)",
+ tf_logging.ERROR: "log(error)",
+ tf_logging.error: "log(error)",
+ }
+
+ output_stream_string = output_stream_to_constant.get(output_stream)
+ if not output_stream_string:
+ raise ValueError(
+ "Unsupported output stream or logging level " +
+ str(output_stream) + ". Supported streams are sys.stdout, "
+ "sys.stderr, tf.logging.info, "
+ "tf.logging.warning, tf.logging.error")
+
+ # If we are only printing a single string scalar, there is no need to format
+ if (len(inputs) == 1 and tensor_util.is_tensor(inputs[0])
+ and (not isinstance(inputs[0], sparse_tensor.SparseTensor))
+ and inputs[0].shape and (inputs[0].dtype == dtypes.string)):
+ formatted_string = inputs[0]
+ # Otherwise, we construct an appropriate template for the tensors we are
+ # printing, and format the template using those tensors.
+ else:
+ # For each input to this print function, we extract any nested tensors,
+ # and construct an appropriate template to format representing the
+ # printed input.
+ templates = []
+ tensors = []
+ tensor_free_structure = nest.map_structure(
+ lambda x: "" if tensor_util.is_tensor(x) else x,
+ inputs)
+ tensor_free_template = " ".join(pprint.pformat(x)
+ for x in tensor_free_structure)
+ placeholder = _generate_placeholder_string(tensor_free_template)
+
+ for input_ in inputs:
+ placeholders = []
+ # Use the nest utilities to flatten & process any nested elements in this
+ # input. The placeholder for a tensor in the template should be the
+ # placeholder string, and the placeholder for a non-tensor can just be
+ # the printed value of the non-tensor itself.
+ for x in nest.flatten(input_):
+ # support sparse tensors
+ if isinstance(x, sparse_tensor.SparseTensor):
+ tensors.extend([x.indices, x.values, x.dense_shape])
+ placeholders.append(
+ "SparseTensor(indices={}, values={}, shape={})".format(
+ placeholder, placeholder, placeholder)
+ )
+ elif tensor_util.is_tensor(x):
+ tensors.append(x)
+ placeholders.append(placeholder)
+ else:
+ placeholders.append(x)
+
+ if isinstance(input_, six.string_types):
+ # If the current input to format/print is a normal string, that string
+ # can act as the template.
+ cur_template = input_
+ else:
+ # We pack the placeholders into a data structure that matches the
+ # input data structure format, then format that data structure
+ # into a string template.
+ #
+ # NOTE: We must use pprint.pformat here for building the template for
+ # unordered data structures such as `dict`, because `str` doesn't
+ # guarantee orderings, while pprint prints in sorted order. pprint
+ # will match the ordering of `nest.flatten`.
+ # This even works when nest.flatten reorders OrderedDicts, because
+ # pprint is printing *after* the OrderedDicts have been reordered.
+ cur_template = pprint.pformat(
+ nest.pack_sequence_as(input_, placeholders))
+ templates.append(cur_template)
+
+ # We join the templates for the various inputs into a single larger
+ # template. We also remove all quotes surrounding the placeholders, so that
+ # the formatted/printed output will not contain quotes around tensors.
+ # (example of where these quotes might appear: if we have added a
+ # placeholder string into a list, then pretty-formatted that list)
+ template = " ".join(templates)
+ template = template.replace("'" + placeholder + "'", placeholder)
+ formatted_string = string_ops.string_format(
+ inputs=tensors, template=template, placeholder=placeholder,
+ summarize=summarize,
+ name=format_name)
+
+ return gen_logging_ops.print_v2(formatted_string,
+ output_stream=output_stream_string,
+ name=name)
+# pylint: enable=g-doc-args
@ops.RegisterGradient("Print")
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index b2c6937368..5d949467fd 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -29,14 +29,15 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.util import compat as util_compat
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_string_ops import *
+from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -103,6 +104,87 @@ def regex_replace(source, pattern, rewrite, replace_global=True):
rewrite=rewrite, replace_global=replace_global)
+@tf_export("strings.format")
+def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
+ r"""Formats a string template using a list of tensors.
+
+ Formats a string template using a list of tensors, abbreviating tensors by
+ only printing the first and last `summarize` elements of each dimension
+ (recursively). If formatting only one tensor into a template, the tensor does
+ not have to be wrapped in a list.
+
+ Example:
+ Formatting a single-tensor template:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ formatted = tf.strings.format("tensor: {}, suffix", tensor)
+ out = sess.run(formatted)
+ expected = "tensor: [0 1 2 ... 7 8 9], suffix"
+
+ assert(out.decode() == expected)
+ ```
+
+ Formatting a multi-tensor template:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor_one = tf.reshape(tf.range(100), [10, 10])
+ tensor_two = tf.range(10)
+ formatted = tf.strings.format("first: {}, second: {}, suffix",
+ (tensor_one, tensor_two))
+
+ out = sess.run(formatted)
+ expected = ("first: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix")
+
+ assert(out.decode() == expected)
+ ```
+
+ Args:
+ template: A string template to format tensor values into.
+ inputs: A list of `Tensor` objects, or a single Tensor.
+ The list of tensors to format into the template string. If a solitary
+ tensor is passed in, the input tensor will automatically be wrapped as a
+ list.
+ placeholder: An optional `string`. Defaults to `{}`.
+ At each placeholder occurring in the template, a subsequent tensor
+ will be inserted.
+ summarize: An optional `int`. Defaults to `3`.
+ When formatting the tensors, show the first and last `summarize`
+ entries of each tensor dimension (recursively). If set to -1, all
+ elements of the tensor will be shown.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`.
+
+ Raises:
+ ValueError: if the number of placeholders does not match the number of
+ inputs.
+ """
+ # If there is only one tensor to format, we will automatically wrap it in a
+ # list to simplify the user experience
+ if tensor_util.is_tensor(inputs):
+ inputs = [inputs]
+ if template.count(placeholder) != len(inputs):
+ raise ValueError("%s placeholder(s) in template does not match %s tensor(s)"
+ " provided as input" % (template.count(placeholder),
+ len(inputs)))
+
+ return gen_string_ops.string_format(inputs,
+ template=template,
+ placeholder=placeholder,
+ summarize=summarize,
+ name=name)
+
+
@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
"""Split elements of `source` based on `delimiter` into a `SparseTensor`.
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 14ab885c91..6ff4343e9e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1593,6 +1593,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "print"
+ argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
+ }
+ member_method {
name: "py_func"
argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index 018be7b9f9..c81c156518 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.strings"
tf_module {
member_method {
+ name: "format"
+ argspec: "args=[\'template\', \'inputs\', \'placeholder\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'{}\', \'3\', \'None\'], "
+ }
+ member_method {
name: "join"
argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 323d2fc519..db90c007d4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -581,10 +581,6 @@ tf_module {
argspec: "args=[\'op_type\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "Print"
- argspec: "args=[\'input_\', \'data\', \'message\', \'first_n\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
name: "abs"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1541,6 +1537,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "print"
+ argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
+ }
+ member_method {
name: "py_func"
argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index 018be7b9f9..c81c156518 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.strings"
tf_module {
member_method {
+ name: "format"
+ argspec: "args=[\'template\', \'inputs\', \'placeholder\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'{}\', \'3\', \'None\'], "
+ }
+ member_method {
name: "join"
argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}