diff options
Diffstat (limited to 'tensorflow/core/framework/tensor.cc')
-rw-r--r-- | tensorflow/core/framework/tensor.cc | 570 |
1 files changed, 570 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc new file mode 100644 index 0000000000..4a1b65db97 --- /dev/null +++ b/tensorflow/core/framework/tensor.cc @@ -0,0 +1,570 @@ +// Implementation notes: +// +// Tensor.cc uses a few templated classes and structs to facilitate +// implementation of the Tensor class. +// +// * Buffer<T>: provides the implementation for a typed array T[n]. +// The array is allocated by the given allocator. It runs T's +// default constructors and destructors when T is not a simple type +// (e.g., string.), and skips them otherwise. +// +// * Helper<T>: provides various routines given type T. The routines +// includes running the constructor and destructor of T[], encoding +// an decoding T[] into/from a Cord, etc. + +#include "tensorflow/core/public/tensor.h" + +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/tensor_coding.h" + +namespace tensorflow { +namespace { + +// Typed ref-counted buffer: T[n]. 
template <typename T>
class Buffer : public TensorBuffer {
 public:
  Buffer(Allocator* a, int64 n);

  void* data() const override { return data_; }
  size_t size() const override { return sizeof(T) * elem_; }
  TensorBuffer* root_buffer() override { return this; }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    int64 rb = size();
    proto->set_requested_bytes(rb);
    proto->set_allocator_name(alloc_->Name());
    if (alloc_->TracksAllocationSizes()) {
      int64 ab = alloc_->AllocatedSize(data_);
      proto->set_allocated_bytes(ab);
    }
  }

 private:
  Allocator* alloc_;  // Not owned; must outlive this buffer.
  T* data_;
  int64 elem_;

  // Private: Buffer is ref-counted and must be released via Unref().
  ~Buffer() override;

  TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
};

// is_simple<T>::value if T[] can be safely constructed and destructed
// without running T() and ~T(). We do not use std::is_trivial<T>
// directly because std::complex<float> is not trivial but its array
// can be constructed and destructed without running its default ctor
// and dtor.
template <typename T>
struct is_simple {
  static const bool value = std::is_trivial<T>::value ||
                            std::is_same<T, complex64>::value ||
                            is_quantized<T>::value;
};

template <>
struct is_simple<bfloat16> {
  static const bool value = true;
};

// A set of helper functions depending on T.
template <typename T>
struct Helper {
  // By default, we assume T is a simple type (float, int32, etc.)
  static_assert(is_simple<T>::value, "T is not a simple type.");
  typedef protobuf::RepeatedField<T> RepeatedFieldType;

  // No constructor to run.
  static void RunCtor(T* p, int n) {}

  // No destructor to run.
  static void RunDtor(T* p, int n) {}

  // Encoder of simple type T to a string. We do a copy.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
                           out);
  }

  // Decoder of simple type T. Copy the bytes from "in" into the
  // tensor buffer. Returns nullptr (after logging) on a size mismatch.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    if (in.size() != sizeof(T) * n) {
      LOG(ERROR) << "Input size was " << in.size() << " and expected "
                 << sizeof(T) * n;
      return nullptr;
    }
    Buffer<T>* buf = new Buffer<T>(a, n);
    port::CopyToArray(in, buf->template base<char>());
    return buf;
  }

  // Memory usage.
  static int64 TotalBytes(TensorBuffer* in, int64 n) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    return in->size();
  }
};

// Helper specialization for string (the only non-simple type we
// support).
template <>
struct Helper<string> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Runs string's default constructor for p[0], p[1], ..., p[n-1].
  static void RunCtor(string* p, int n) {
    for (int i = 0; i < n; ++p, ++i) new (p) string();
  }

  // Runs T's default destructor for p[0], p[1], ..., p[n-1].
  static void RunDtor(string* p, int n) {
    for (int i = 0; i < n; ++p, ++i) p->~string();
  }

  // Encodes "n" elements of type string stored in "in" into Cord
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    port::EncodeStringList(in->base<const string>(), n, out);
  }

  // Decodes "n" elements of type string from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    Buffer<string>* buf = new Buffer<string>(a, n);
    string* strings = buf->template base<string>();
    if (port::DecodeStringList(in, strings, n)) {
      return buf;
    } else {
      // Decoding failed: drop our sole reference so the buffer is freed.
      buf->Unref();
      return nullptr;
    }
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in" (the string array itself plus each
  // string's heap payload).
  static int64 TotalBytes(TensorBuffer* in, int n) {
    int64 tot = in->size();
    DCHECK_EQ(tot, sizeof(string) * n);
    const string* p = in->base<const string>();
    for (int i = 0; i < n; ++i, ++p) tot += p->size();
    return tot;
  }
};

template <typename T>
struct ProtoHelper {};

// For a C++ type "T" (float, double, int32, etc.), the repeated field
// "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
// int32, string, etc) in the TensorProto is used for serializing the
// tensor of type "T".
#define PROTO_TRAITS(T, F, N)                                          \
  template <>                                                          \
  struct ProtoHelper<T> {                                              \
    typedef Helper<F>::RepeatedFieldType FieldType;                    \
    static FieldType::const_iterator Begin(const TensorProto& proto) { \
      return proto.N##_val().begin();                                  \
    }                                                                  \
    static size_t NumElements(const TensorProto& proto) {              \
      return proto.N##_val().size();                                   \
    }                                                                  \
    static void Fill(const T* data, size_t n, TensorProto* proto) {    \
      typename ProtoHelper<T>::FieldType copy(data, data + n);         \
      proto->mutable_##N##_val()->Swap(&copy);                         \
    }                                                                  \
  };
PROTO_TRAITS(float, float, float);
PROTO_TRAITS(double, double, double);
PROTO_TRAITS(int32, int32, int);
PROTO_TRAITS(uint8, int32, int);
PROTO_TRAITS(int16, int32, int);
PROTO_TRAITS(int8, int32, int);
PROTO_TRAITS(int64, int64, int64);
PROTO_TRAITS(bool, bool, bool);
PROTO_TRAITS(string, string, string);
PROTO_TRAITS(qint8, int32, int);
PROTO_TRAITS(quint8, int32, int);
#undef PROTO_TRAITS

// complex64 is serialized as interleaved (real, imag) float pairs in
// scomplex_val, hence the /2 and *2 factors below.
template <>
struct ProtoHelper<complex64> {
  typedef Helper<float>::RepeatedFieldType FieldType;
  static const complex64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.scomplex_val().size() / 2;
  }
  static void Fill(const complex64* data, size_t n, TensorProto* proto) {
    const float* p = reinterpret_cast<const float*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_scomplex_val()->Swap(&copy);
  }
};

// qint32 reuses the int_val field via reinterpret_cast (same layout
// as int32).
template <>
struct ProtoHelper<qint32> {
  typedef Helper<int32>::RepeatedFieldType FieldType;
  static const qint32* Begin(const TensorProto& proto) {
    return reinterpret_cast<const qint32*>(proto.int_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int_val().size();
  }
  static void Fill(const qint32* data, size_t n, TensorProto* proto) {
    const int32* p = reinterpret_cast<const int32*>(data);
    FieldType copy(p, p + n);
    proto->mutable_int_val()->Swap(&copy);
  }
};

// bfloat16 stores its raw 16-bit pattern (data[i].value) widened into
// int_val, one int per element.
template <>
struct ProtoHelper<bfloat16> {
  typedef Helper<float>::RepeatedFieldType FieldType;
  static const bfloat16* Begin(const TensorProto& proto) {
    return reinterpret_cast<const bfloat16*>(proto.int_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int_val().size();
  }
  static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
    proto->mutable_int_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_int_val()->AddAlreadyReserved(data[i].value);
    }
  }
};

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64 n)
    : alloc_(a), data_(a->Allocate<T>(n)), elem_(n) {
  // data_ may be null if allocation failed; skip ctor runs in that case.
  if (data_) Helper<T>::RunCtor(data_, elem_);
}

template <typename T>
Buffer<T>::~Buffer() {
  if (data_) {
    Helper<T>::RunDtor(data_, elem_);
    alloc_->Deallocate<T>(data_);
  }
}

// Allocates a T[n] buffer. Fills in the buffer with repeated values
// in "in". If "in" has less values than "n", fills the rest of T[n]
// with the last value. If "in" has no values, fills T[n] with the
// default value for T.
//
// This routine is using the typed fields (float_val, etc.) in the
// tensor proto as opposed to the untyped binary representation
// (tensor_content). This is used when we expect the TensorProto is
// used by a client program which may not know how to encode a tensor
// in the compact binary representation.
template <typename T>
TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
  CHECK_GT(n, 0);
  Buffer<T>* buf = new Buffer<T>(a, n);
  T* data = buf->template base<T>();
  const int64 in_n = ProtoHelper<T>::NumElements(in);
  auto begin = ProtoHelper<T>::Begin(in);
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const T& last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, T());
  }
  return buf;
}

// Copies T[n] stored in the buffer "in" into the repeated field in
// "out" corresponding to type T.
template <typename T>
void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) {
  const T* data = in.base<const T>();
  // NOTE: T may not be the same as
  // ProtoHelper<T>::FieldType::value_type. E.g., T==int16,
  // ProtoHelper<T>::FieldType::value_type==int32. If performance is
  // critical, we can specialize T=float and do memcpy directly.
  ProtoHelper<T>::Fill(data, n, out);
}

// Ref/Unref helpers that tolerate a null buffer pointer.
void RefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Ref();
}

void UnrefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Unref();
}

}  // end namespace

Tensor::Tensor() : Tensor(DT_FLOAT) {}

// An empty (zero-element) tensor of the given type; no buffer is allocated.
Tensor::Tensor(DataType type) : type_(type), shape_({0}), buf_(nullptr) {}

// Copy constructor shares the underlying buffer (ref-counted), it does
// not copy the data.
Tensor::Tensor(const Tensor& other)
    : type_(other.dtype()), shape_(other.shape()), buf_(other.buf_) {
  RefIfNonNull(buf_);
}

// Wraps an existing buffer; takes a new reference on "buf" (may be null).
Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
    : type_(type), shape_(shape), buf_(buf) {
  RefIfNonNull(buf);
}

bool Tensor::IsInitialized() const {
  return buf_ != nullptr && buf_->data() != nullptr;
}

Tensor::~Tensor() { UnrefIfNonNull(buf_); }

// Points this tensor at other's buffer under a (compatible) new shape.
// The element counts must match; the old buffer is released.
void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
  CHECK_EQ(shape.num_elements(), other.NumElements());
  type_ = other.dtype();
  shape_ = shape;
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
}

// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
#define SINGLE_ARG(...) __VA_ARGS__
#define CASE(TYPE, STMTS)             \
  case DataTypeToEnum<TYPE>::value: { \
    typedef TYPE T;                   \
    STMTS;                            \
    break;                            \
  }
#define CASES(TYPE_ENUM, STMTS)                       \
  switch (TYPE_ENUM) {                                \
    CASE(float, SINGLE_ARG(STMTS))                    \
    CASE(double, SINGLE_ARG(STMTS))                   \
    CASE(int32, SINGLE_ARG(STMTS))                    \
    CASE(uint8, SINGLE_ARG(STMTS))                    \
    CASE(int16, SINGLE_ARG(STMTS))                    \
    CASE(int8, SINGLE_ARG(STMTS))                     \
    CASE(string, SINGLE_ARG(STMTS))                   \
    CASE(complex64, SINGLE_ARG(STMTS))                \
    CASE(int64, SINGLE_ARG(STMTS))                    \
    CASE(bool, SINGLE_ARG(STMTS))                     \
    CASE(qint32, SINGLE_ARG(STMTS))                   \
    CASE(quint8, SINGLE_ARG(STMTS))                   \
    CASE(qint8, SINGLE_ARG(STMTS))                    \
    CASE(bfloat16, SINGLE_ARG(STMTS))                 \
    case DT_INVALID:                                  \
      LOG(FATAL) << "Type not set";                   \
      break;                                          \
    default:                                          \
      LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \
      break;                                          \
  }

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
    : type_(type), shape_(shape), buf_(nullptr) {
  CHECK_NOTNULL(a);
  // Only allocate when there is at least one element.
  if (shape_.num_elements() > 0) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
  }
}

Tensor::Tensor(DataType type, const TensorShape& shape)
    : Tensor(cpu_allocator(), type, shape) {}

template <typename T>
class SubBuffer : public TensorBuffer {
 public:
  // This buffer is an alias to buf[delta, delta + n).
  SubBuffer(TensorBuffer* buf, int64 delta, int64 n)
      : root_(buf->root_buffer()), data_(buf->base<T>() + delta), elem_(n) {
    // Sanity check. The caller should ensure the sub buffer is valid.
    CHECK_LE(root_->base<T>(), this->base<T>());
    T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
    CHECK_LE(this->base<T>(), root_limit);
    CHECK_LE(this->base<T>() + n, root_limit);
    // Hold a ref of the underlying root buffer.
    // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
    root_->Ref();
  }

  void* data() const override { return data_; }
  size_t size() const override { return sizeof(T) * elem_; }
  TensorBuffer* root_buffer() override { return root_; }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    root_->FillAllocationDescription(proto);
  }

 private:
  TensorBuffer* root_;  // Ref held for the lifetime of this sub-buffer.
  T* data_;
  int64 elem_;

  ~SubBuffer() override { root_->Unref(); }

  TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
};

// Returns a tensor aliasing rows [start, limit) along dimension 0.
// No data is copied: the result shares (a sub-range of) this tensor's
// buffer. Slicing the full range returns *this directly.
Tensor Tensor::Slice(int64 start, int64 limit) const {
  CHECK_GE(dims(), 1);
  CHECK_LE(0, start);
  CHECK_LE(start, limit);
  int64 dim0_size = shape_.dim_size(0);
  CHECK_LE(limit, dim0_size);
  if ((start == 0) && (limit == dim0_size)) {
    return *this;
  }
  Tensor ret;
  ret.type_ = type_;
  ret.shape_ = shape_;
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    const int64 elems_per_dim0 = NumElements() / dim0_size;
    const int64 delta = start * elems_per_dim0;
    dim0_size = limit - start;
    ret.shape_.set_dim(0, dim0_size);
    const int64 num_elems = dim0_size * elems_per_dim0;
    if (buf_) {
      CASES(type_, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}

bool Tensor::FromProto(const TensorProto& proto) {
  return FromProto(cpu_allocator(), proto);
}

// Parses "proto" into *this, allocating with "a". Returns false on an
// invalid shape, DT_INVALID dtype, or a decode failure; on success the
// previous buffer (if any) is released.
bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
  CHECK_NOTNULL(a);
  TensorBuffer* p = nullptr;
  if (!TensorShape::IsValid(proto.tensor_shape())) return false;
  if (proto.dtype() == DT_INVALID) return false;
  TensorShape shape(proto.tensor_shape());
  const int64 N = shape.num_elements();
  if (N > 0 && proto.dtype()) {
    if (!proto.tensor_content().empty()) {
      // Compact binary representation takes precedence over typed fields.
      const auto& content = proto.tensor_content();
      CASES(proto.dtype(), p = Helper<T>::Decode(a, content, N));
    } else {
      CASES(proto.dtype(), p = FromProtoField<T>(a, proto, N));
    }
    if (p == nullptr) return false;
  }
  type_ = proto.dtype();
  shape_ = shape;
  UnrefIfNonNull(buf_);
  buf_ = p;
  return true;
}

// Serializes using the typed repeated fields (float_val, etc.).
void Tensor::AsProtoField(TensorProto* proto) const {
  proto->Clear();
  proto->set_dtype(dtype());
  shape_.AsProto(proto->mutable_tensor_shape());
  if (buf_) {
    CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
  }
}

// Serializes using the compact binary tensor_content field.
void Tensor::AsProtoTensorContent(TensorProto* proto) const {
  proto->Clear();
  proto->set_dtype(type_);
  shape_.AsProto(proto->mutable_tensor_shape());
  if (buf_) {
    CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
                                     proto->mutable_tensor_content()));
  }
}

size_t Tensor::TotalBytes() const {
  if (shape_.num_elements() == 0) return 0;
  CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
  CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
  return 0;  // Makes compiler happy.
}

bool Tensor::CanUseDMA() const {
  CASES(dtype(), return is_simple<T>::value);
  return false;  // Makes compiler happy.
}

#undef CASES
#undef CASE

// Renders up to max_entries space-separated values for debugging.
string Tensor::SummarizeValue(int64 max_entries) const {
  string ret;
  for (int64 i = 0; i < std::min(max_entries, NumElements()); ++i) {
    if (i > 0) strings::StrAppend(&ret, " ");
    switch (dtype()) {
      case DT_STRING:
        strings::StrAppend(&ret, str_util::CEscape(flat<string>()(i)));
        break;
      case DT_BOOL:
        strings::StrAppend(&ret, flat<bool>()(i) ? "True" : "False");
        break;

#define CASE(DT_ENUM)                                                   \
  case DT_ENUM:                                                         \
    strings::StrAppend(&ret, flat<EnumToDataType<DT_ENUM>::Type>()(i)); \
    break

      CASE(DT_FLOAT);
      CASE(DT_DOUBLE);
      CASE(DT_INT32);
      CASE(DT_UINT8);
      CASE(DT_INT16);
      CASE(DT_INT8);
      CASE(DT_INT64);

#undef CASE
      default:
        // TODO(zhifengc, josh11b): Pretty-print other types (bool,
        // complex64, quantized, bfloat16).
        strings::StrAppend(&ret, " ?");
    }
  }
  if (max_entries < NumElements()) strings::StrAppend(&ret, "...");

  return ret;
}

// View of the raw bytes; empty for an unallocated tensor. The view is
// only valid while this tensor's buffer is alive.
StringPiece Tensor::tensor_data() const {
  if (buf_ == nullptr) return StringPiece();  // Don't die for empty tensors
  return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
}

string Tensor::DebugString() const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()), " shape: ",
                         shape().ShortDebugString(), " values: ",
                         SummarizeValue(3), ">");
}

// NOTE(review): buf_ is dereferenced unconditionally here, unlike the
// null-tolerant methods above — callers presumably only invoke this on
// allocated tensors; confirm before calling on an empty tensor.
void Tensor::FillDescription(TensorDescription* description) const {
  description->set_dtype(dtype());
  shape().AsProto(description->mutable_shape());
  buf_->FillAllocationDescription(
      description->mutable_allocation_description());
}

}  // namespace tensorflow