diff options
Diffstat (limited to 'tensorflow/core/framework/tensor.cc')
-rw-r--r-- | tensorflow/core/framework/tensor.cc | 72 |
1 files changed, 58 insertions, 14 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index ce4ae2d3d3..05db5a32b5 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -721,14 +721,58 @@ bool Tensor::CanUseDMA() const { #undef CASE namespace { +// Print from left dim to right dim recursively. template <typename T> -string SummarizeArray(int64 limit, int64 num_elts, const char* data) { +void PrintOneDim(int dim_index, gtl::InlinedVector<int64, 4> shape, int64 limit, + int shape_size, T* data, int64* data_index, string* result) { + if (*data_index >= limit) return; + int64 element_count = shape[dim_index]; + // We have reached the right-most dimension of the tensor. + if (dim_index == shape_size - 1) { + for (int64 i = 0; i < element_count; i++) { + if (*data_index >= limit) return; + if (i > 0) strings::StrAppend(result, " "); + strings::StrAppend(result, data[(*data_index)++]); + } + return; + } + // Loop every element of one dim. + for (int64 i = 0; i < element_count; i++) { + bool flag = false; + if (*data_index < limit) { + strings::StrAppend(result, "["); + flag = true; + } + // As for each element, print the sub-dim. + PrintOneDim(dim_index + 1, shape, limit, shape_size, + data, data_index, result); + if (*data_index < limit || flag) { + strings::StrAppend(result, "]"); + flag = false; + } + } +} + +template <typename T> +string SummarizeArray(int64 limit, int64 num_elts, const TensorShape& tensor_shape, + const char* data) { string ret; const T* array = reinterpret_cast<const T*>(data); - for (int64 i = 0; i < limit; ++i) { - if (i > 0) strings::StrAppend(&ret, " "); - strings::StrAppend(&ret, array[i]); + + const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes(); + if(shape.empty()) { + for (int64 i = 0; i < limit; ++i) { + if (i > 0) strings::StrAppend(&ret, " "); + strings::StrAppend(&ret, array[i]); + } + if (num_elts > limit) strings::StrAppend(&ret, "..."); + return ret; } + int64 data_index = 0; + const int shape_size = tensor_shape.dims(); + PrintOneDim(0, shape, limit, shape_size, + array, &data_index, &ret); + if (num_elts > limit) strings::StrAppend(&ret, "..."); return ret; } @@ -744,40 +788,40 @@ string Tensor::SummarizeValue(int64 max_entries) const { const char* data = limit > 0 ? tensor_data().data() : nullptr; switch (dtype()) { case DT_HALF: - return SummarizeArray<Eigen::half>(limit, num_elts, data); + return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data); break; case DT_FLOAT: - return SummarizeArray<float>(limit, num_elts, data); + return SummarizeArray<float>(limit, num_elts, shape_, data); break; case DT_DOUBLE: - return SummarizeArray<double>(limit, num_elts, data); + return SummarizeArray<double>(limit, num_elts, shape_, data); break; case DT_INT32: - return SummarizeArray<int32>(limit, num_elts, data); + return SummarizeArray<int32>(limit, num_elts, shape_, data); break; case DT_UINT8: case DT_QUINT8: - return SummarizeArray<uint8>(limit, num_elts, data); + return SummarizeArray<uint8>(limit, num_elts, shape_, data); break; case DT_UINT16: case DT_QUINT16: - return SummarizeArray<uint16>(limit, num_elts, data); + return SummarizeArray<uint16>(limit, num_elts, shape_, data); break; case DT_INT16: case DT_QINT16: - return SummarizeArray<int16>(limit, num_elts, data); + return SummarizeArray<int16>(limit, num_elts, shape_, data); break; case DT_INT8: case DT_QINT8: - return SummarizeArray<int8>(limit, num_elts, data); + return SummarizeArray<int8>(limit, num_elts, shape_, data); break; case DT_INT64: - return SummarizeArray<int64>(limit, num_elts, data); + return SummarizeArray<int64>(limit, num_elts, shape_, data); break; case DT_BOOL: // TODO(tucker): Is it better to emit "True False..."? This // will emit "1 0..." which is more compact. - return SummarizeArray<bool>(limit, num_elts, data); + return SummarizeArray<bool>(limit, num_elts, shape_, data); break; default: { // All irregular cases |