aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/tensor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/tensor.cc')
-rw-r--r--tensorflow/core/framework/tensor.cc72
1 files changed, 58 insertions, 14 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index ce4ae2d3d3..05db5a32b5 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -721,14 +721,58 @@ bool Tensor::CanUseDMA() const {
#undef CASE
namespace {
+// Print from left dim to right dim recursively.
template <typename T>
-string SummarizeArray(int64 limit, int64 num_elts, const char* data) {
+void PrintOneDim(int dim_index, gtl::InlinedVector<int64, 4> shape, int64 limit,
+ int shape_size, T* data, int64* data_index, string* result) {
+ if (*data_index >= limit) return;
+ int64 element_count = shape[dim_index];
+ // We have reached the right-most dimension of the tensor.
+ if (dim_index == shape_size - 1) {
+ for (int64 i = 0; i < element_count; i++) {
+ if (*data_index >= limit) return;
+ if (i > 0) strings::StrAppend(result, " ");
+ strings::StrAppend(result, data[(*data_index)++]);
+ }
+ return;
+ }
+ // Loop every element of one dim.
+ for (int64 i = 0; i < element_count; i++) {
+ bool flag = false;
+ if (*data_index < limit) {
+ strings::StrAppend(result, "[");
+ flag = true;
+ }
+ // As for each element, print the sub-dim.
+ PrintOneDim(dim_index + 1, shape, limit, shape_size,
+ data, data_index, result);
+ if (*data_index < limit || flag) {
+ strings::StrAppend(result, "]");
+ flag = false;
+ }
+ }
+}
+
+template <typename T>
+string SummarizeArray(int64 limit, int64 num_elts, const TensorShape& tensor_shape,
+ const char* data) {
string ret;
const T* array = reinterpret_cast<const T*>(data);
- for (int64 i = 0; i < limit; ++i) {
- if (i > 0) strings::StrAppend(&ret, " ");
- strings::StrAppend(&ret, array[i]);
+
+ const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
+ if(shape.empty()) {
+ for (int64 i = 0; i < limit; ++i) {
+ if (i > 0) strings::StrAppend(&ret, " ");
+ strings::StrAppend(&ret, array[i]);
+ }
+ if (num_elts > limit) strings::StrAppend(&ret, "...");
+ return ret;
}
+ int64 data_index = 0;
+ const int shape_size = tensor_shape.dims();
+ PrintOneDim(0, shape, limit, shape_size,
+ array, &data_index, &ret);
+
if (num_elts > limit) strings::StrAppend(&ret, "...");
return ret;
}
@@ -744,40 +788,40 @@ string Tensor::SummarizeValue(int64 max_entries) const {
const char* data = limit > 0 ? tensor_data().data() : nullptr;
switch (dtype()) {
case DT_HALF:
- return SummarizeArray<Eigen::half>(limit, num_elts, data);
+ return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data);
break;
case DT_FLOAT:
- return SummarizeArray<float>(limit, num_elts, data);
+ return SummarizeArray<float>(limit, num_elts, shape_, data);
break;
case DT_DOUBLE:
- return SummarizeArray<double>(limit, num_elts, data);
+ return SummarizeArray<double>(limit, num_elts, shape_, data);
break;
case DT_INT32:
- return SummarizeArray<int32>(limit, num_elts, data);
+ return SummarizeArray<int32>(limit, num_elts, shape_, data);
break;
case DT_UINT8:
case DT_QUINT8:
- return SummarizeArray<uint8>(limit, num_elts, data);
+ return SummarizeArray<uint8>(limit, num_elts, shape_, data);
break;
case DT_UINT16:
case DT_QUINT16:
- return SummarizeArray<uint16>(limit, num_elts, data);
+ return SummarizeArray<uint16>(limit, num_elts, shape_, data);
break;
case DT_INT16:
case DT_QINT16:
- return SummarizeArray<int16>(limit, num_elts, data);
+ return SummarizeArray<int16>(limit, num_elts, shape_, data);
break;
case DT_INT8:
case DT_QINT8:
- return SummarizeArray<int8>(limit, num_elts, data);
+ return SummarizeArray<int8>(limit, num_elts, shape_, data);
break;
case DT_INT64:
- return SummarizeArray<int64>(limit, num_elts, data);
+ return SummarizeArray<int64>(limit, num_elts, shape_, data);
break;
case DT_BOOL:
// TODO(tucker): Is it better to emit "True False..."? This
// will emit "1 0..." which is more compact.
- return SummarizeArray<bool>(limit, num_elts, data);
+ return SummarizeArray<bool>(limit, num_elts, shape_, data);
break;
default: {
// All irregular cases