diff options
author | Alexander Gorban <gorban@google.com> | 2018-05-29 17:51:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-29 17:54:03 -0700 |
commit | 7a4d278a3dbb71c0d707e2c5e99423489099f441 (patch) | |
tree | bd252479e10bb5f1a1ef0134191c912d3697df1d /tensorflow/core | |
parent | ce88b47799caa472509a34c6c2e4265e2d16ceb9 (diff) |
Convenience functions to create TensorProto directly from data (std::vector).
PiperOrigin-RevId: 198486802
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/framework/tensor_util.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_util.h | 103 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_util_test.cc | 140 |
3 files changed, 252 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc index 8e3ac25512..65f6dc1c00 100644 --- a/tensorflow/core/framework/tensor_util.cc +++ b/tensorflow/core/framework/tensor_util.cc @@ -168,5 +168,14 @@ Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes, return Status::OK(); } +namespace internal { +void SetTensorProtoShape(std::vector<size_t> shape, + TensorShapeProto* shape_proto) { + for (auto dim : shape) { + shape_proto->mutable_dim()->Add()->set_size(dim); + } +} +} // namespace internal + } // namespace tensor } // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h index 6c218b69e0..43d2d95311 100644 --- a/tensorflow/core/framework/tensor_util.h +++ b/tensorflow/core/framework/tensor_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_ #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include <vector> namespace tensorflow { @@ -54,6 +55,108 @@ Status Concat(const gtl::ArraySlice<Tensor>& tensors, Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes, std::vector<Tensor>* result) TF_MUST_USE_RESULT; +namespace internal { +void SetTensorProtoShape(std::vector<size_t> shape, + TensorShapeProto* shape_proto); + +// Defines value type dependent methods to manipulate `TensorProto`. +// Class specializations has to define following methods: +// static DataType GetDataType() +// static void AddValue(Type value, TensorProto* proto) +template <typename Type> +class TensorProtoHelper : public std::false_type {}; + +template <> +class TensorProtoHelper<string> : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_STRING; } + static void AddValue(const string& value, TensorProto* proto) { + *proto->mutable_string_val()->Add() = value; + } +}; + +template <> +class TensorProtoHelper<int32> : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_INT32; } + static void AddValue(int32 value, TensorProto* proto) { + proto->mutable_int_val()->Add(value); + } +}; + +template <> +class TensorProtoHelper<int64> : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_INT64; } + static void AddValue(int64 value, TensorProto* proto) { + proto->mutable_int64_val()->Add(value); + } +}; + +template <> +class TensorProtoHelper<uint32> : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_UINT32; } + static void AddValue(uint32 value, TensorProto* proto) { + proto->mutable_uint32_val()->Add(value); + } +}; + +template <> +class TensorProtoHelper<uint64> : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_UINT64; } + static void AddValue(uint64 value, TensorProto* proto) { + proto->mutable_uint64_val()->Add(value); + } +}; + +template <> +class TensorProtoHelper<float> : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_FLOAT; } + static void AddValue(float value, TensorProto* proto) { + proto->mutable_float_val()->Add(value); + } +}; + +template <> +class TensorProtoHelper<double> : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_DOUBLE; } + static void AddValue(double value, TensorProto* proto) { + proto->mutable_double_val()->Add(value); + } +}; + +template <> +class TensorProtoHelper<bool> : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_BOOL; } + static void AddValue(bool value, TensorProto* proto) { + proto->mutable_bool_val()->Add(value); + } +}; +} // namespace internal + +// Creates a 'TensorProto' with specified shape and values. +// The dtype and a field to represent data values of the returned 'TensorProto' +// are determined based on type of the 'values' parameter. +template <typename Type> +typename std::enable_if<internal::TensorProtoHelper<Type>::value, + TensorProto>::type +CreateTensorProto(const std::vector<Type>& values, + const std::vector<size_t>& shape) { + TensorProto tensor; + using TypeHelper = internal::TensorProtoHelper<Type>; + tensor.set_dtype(TypeHelper::GetDataType()); + internal::SetTensorProtoShape(shape, tensor.mutable_tensor_shape()); + for (const auto& value : values) { + TypeHelper::AddValue(value, &tensor); + } + return tensor; +} + } // namespace tensor } // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_util_test.cc b/tensorflow/core/framework/tensor_util_test.cc index 69eb8363b2..2b4e1cad2f 100644 --- a/tensorflow/core/framework/tensor_util_test.cc +++ b/tensorflow/core/framework/tensor_util_test.cc @@ -226,5 +226,145 @@ TEST(TensorUtil, ConcatSplitStrings) { } } +TEST(TensorProtoUtil, CreatesStringTensorProto) { + std::vector<string> values{"a", "b", "c"}; + std::vector<size_t> shape{1, 3}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_STRING\n" + "tensor_shape {\n" + " dim {\n" + " size: 1\n" + " }\n" + " dim {\n" + " size: 3\n" + " }\n" + "}\n" + "string_val: \"a\"\n" + "string_val: \"b\"\n" + "string_val: \"c\"\n"); +} + +TEST(TensorProtoUtil, CreatesInt32TensorProto) { + std::vector<int32> values{1, 2}; + std::vector<size_t> shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_INT32\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "int_val: 1\n" + "int_val: 2\n"); +} + +TEST(TensorProtoUtil, CreatesInt64TensorProto) { + std::vector<int64> values{1, 2}; + std::vector<size_t> shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_INT64\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "int64_val: 1\n" + "int64_val: 2\n"); +} + +TEST(TensorProtoUtil, CreatesUInt32TensorProto) { + std::vector<uint32> values{1, 2}; + std::vector<size_t> shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_UINT32\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "uint32_val: 1\n" + "uint32_val: 2\n"); +} + +TEST(TensorProtoUtil, CreatesUInt64TensorProto) { + std::vector<uint64> values{1, 2}; + std::vector<size_t> shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_UINT64\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "uint64_val: 1\n" + "uint64_val: 2\n"); +} + +TEST(TensorProtoUtil, CreatesFloatTensorProto) { + std::vector<float> values{1.1, 2.2}; + std::vector<size_t> shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_FLOAT\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "float_val: 1.1\n" + "float_val: 2.2\n"); +} + +TEST(TensorProtoUtil, CreatesDoubleTensorProto) { + std::vector<double> values{1.1, 2.2}; + std::vector<size_t> shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_DOUBLE\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "double_val: 1.1\n" + "double_val: 2.2\n"); +} + +TEST(TensorProtoUtil, CreatesBoolTensorProto) { + std::vector<bool> values{true, false}; + std::vector<size_t> shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_BOOL\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "bool_val: true\n" + "bool_val: false\n"); +} + } // namespace } // namespace tensorflow |