aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Alexander Gorban <gorban@google.com>2018-05-29 17:51:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-29 17:54:03 -0700
commit7a4d278a3dbb71c0d707e2c5e99423489099f441 (patch)
treebd252479e10bb5f1a1ef0134191c912d3697df1d /tensorflow/core
parentce88b47799caa472509a34c6c2e4265e2d16ceb9 (diff)
Convenience functions to create TensorProto directly from data (std::vector).
PiperOrigin-RevId: 198486802
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/framework/tensor_util.cc9
-rw-r--r--tensorflow/core/framework/tensor_util.h103
-rw-r--r--tensorflow/core/framework/tensor_util_test.cc140
3 files changed, 252 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc
index 8e3ac25512..65f6dc1c00 100644
--- a/tensorflow/core/framework/tensor_util.cc
+++ b/tensorflow/core/framework/tensor_util.cc
@@ -168,5 +168,14 @@ Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
return Status::OK();
}
+namespace internal {
+void SetTensorProtoShape(std::vector<size_t> shape,
+ TensorShapeProto* shape_proto) {
+ for (auto dim : shape) {
+ shape_proto->mutable_dim()->Add()->set_size(dim);
+ }
+}
+} // namespace internal
+
} // namespace tensor
} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h
index 6c218b69e0..43d2d95311 100644
--- a/tensorflow/core/framework/tensor_util.h
+++ b/tensorflow/core/framework/tensor_util.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
#include <vector>
namespace tensorflow {
@@ -54,6 +55,108 @@ Status Concat(const gtl::ArraySlice<Tensor>& tensors,
Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
std::vector<Tensor>* result) TF_MUST_USE_RESULT;
+namespace internal {
+void SetTensorProtoShape(std::vector<size_t> shape,
+ TensorShapeProto* shape_proto);
+
+// Defines value type dependent methods to manipulate `TensorProto`.
+// Class specializations has to define following methods:
+// static DataType GetDataType()
+// static void AddValue(Type value, TensorProto* proto)
+template <typename Type>
+class TensorProtoHelper : public std::false_type {};
+
+template <>
+class TensorProtoHelper<string> : public std::true_type {
+ public:
+ static DataType GetDataType() { return DataType::DT_STRING; }
+ static void AddValue(const string& value, TensorProto* proto) {
+ *proto->mutable_string_val()->Add() = value;
+ }
+};
+
+template <>
+class TensorProtoHelper<int32> : public std::true_type {
+ public:
+ static DataType GetDataType() { return DataType::DT_INT32; }
+ static void AddValue(int32 value, TensorProto* proto) {
+ proto->mutable_int_val()->Add(value);
+ }
+};
+
+template <>
+class TensorProtoHelper<int64> : public std::true_type {
+ public:
+ static DataType GetDataType() { return DataType::DT_INT64; }
+ static void AddValue(int64 value, TensorProto* proto) {
+ proto->mutable_int64_val()->Add(value);
+ }
+};
+
+template <>
+class TensorProtoHelper<uint32> : public std::true_type {
+ public:
+ static DataType GetDataType() { return DataType::DT_UINT32; }
+ static void AddValue(uint32 value, TensorProto* proto) {
+ proto->mutable_uint32_val()->Add(value);
+ }
+};
+
+template <>
+class TensorProtoHelper<uint64> : public std::true_type {
+ public:
+ static DataType GetDataType() { return DataType::DT_UINT64; }
+ static void AddValue(uint64 value, TensorProto* proto) {
+ proto->mutable_uint64_val()->Add(value);
+ }
+};
+
+template <>
+class TensorProtoHelper<float> : public std::true_type {
+ public:
+ static DataType GetDataType() { return DataType::DT_FLOAT; }
+ static void AddValue(float value, TensorProto* proto) {
+ proto->mutable_float_val()->Add(value);
+ }
+};
+
+template <>
+class TensorProtoHelper<double> : public std::true_type {
+ public:
+ static DataType GetDataType() { return DataType::DT_DOUBLE; }
+ static void AddValue(double value, TensorProto* proto) {
+ proto->mutable_double_val()->Add(value);
+ }
+};
+
+template <>
+class TensorProtoHelper<bool> : public std::true_type {
+ public:
+ static DataType GetDataType() { return DataType::DT_BOOL; }
+ static void AddValue(bool value, TensorProto* proto) {
+ proto->mutable_bool_val()->Add(value);
+ }
+};
+} // namespace internal
+
+// Creates a 'TensorProto' with specified shape and values.
+// The dtype and a field to represent data values of the returned 'TensorProto'
+// are determined based on type of the 'values' parameter.
+template <typename Type>
+typename std::enable_if<internal::TensorProtoHelper<Type>::value,
+ TensorProto>::type
+CreateTensorProto(const std::vector<Type>& values,
+ const std::vector<size_t>& shape) {
+ TensorProto tensor;
+ using TypeHelper = internal::TensorProtoHelper<Type>;
+ tensor.set_dtype(TypeHelper::GetDataType());
+ internal::SetTensorProtoShape(shape, tensor.mutable_tensor_shape());
+ for (const auto& value : values) {
+ TypeHelper::AddValue(value, &tensor);
+ }
+ return tensor;
+}
+
} // namespace tensor
} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_util_test.cc b/tensorflow/core/framework/tensor_util_test.cc
index 69eb8363b2..2b4e1cad2f 100644
--- a/tensorflow/core/framework/tensor_util_test.cc
+++ b/tensorflow/core/framework/tensor_util_test.cc
@@ -226,5 +226,145 @@ TEST(TensorUtil, ConcatSplitStrings) {
}
}
+TEST(TensorProtoUtil, CreatesStringTensorProto) {
+ std::vector<string> values{"a", "b", "c"};
+ std::vector<size_t> shape{1, 3};
+
+ auto proto = tensor::CreateTensorProto(values, shape);
+
+ EXPECT_EQ(proto.DebugString(),
+ "dtype: DT_STRING\n"
+ "tensor_shape {\n"
+ " dim {\n"
+ " size: 1\n"
+ " }\n"
+ " dim {\n"
+ " size: 3\n"
+ " }\n"
+ "}\n"
+ "string_val: \"a\"\n"
+ "string_val: \"b\"\n"
+ "string_val: \"c\"\n");
+}
+
+TEST(TensorProtoUtil, CreatesInt32TensorProto) {
+ std::vector<int32> values{1, 2};
+ std::vector<size_t> shape{2};
+
+ auto proto = tensor::CreateTensorProto(values, shape);
+
+ EXPECT_EQ(proto.DebugString(),
+ "dtype: DT_INT32\n"
+ "tensor_shape {\n"
+ " dim {\n"
+ " size: 2\n"
+ " }\n"
+ "}\n"
+ "int_val: 1\n"
+ "int_val: 2\n");
+}
+
+TEST(TensorProtoUtil, CreatesInt64TensorProto) {
+ std::vector<int64> values{1, 2};
+ std::vector<size_t> shape{2};
+
+ auto proto = tensor::CreateTensorProto(values, shape);
+
+ EXPECT_EQ(proto.DebugString(),
+ "dtype: DT_INT64\n"
+ "tensor_shape {\n"
+ " dim {\n"
+ " size: 2\n"
+ " }\n"
+ "}\n"
+ "int64_val: 1\n"
+ "int64_val: 2\n");
+}
+
+TEST(TensorProtoUtil, CreatesUInt32TensorProto) {
+ std::vector<uint32> values{1, 2};
+ std::vector<size_t> shape{2};
+
+ auto proto = tensor::CreateTensorProto(values, shape);
+
+ EXPECT_EQ(proto.DebugString(),
+ "dtype: DT_UINT32\n"
+ "tensor_shape {\n"
+ " dim {\n"
+ " size: 2\n"
+ " }\n"
+ "}\n"
+ "uint32_val: 1\n"
+ "uint32_val: 2\n");
+}
+
+TEST(TensorProtoUtil, CreatesUInt64TensorProto) {
+ std::vector<uint64> values{1, 2};
+ std::vector<size_t> shape{2};
+
+ auto proto = tensor::CreateTensorProto(values, shape);
+
+ EXPECT_EQ(proto.DebugString(),
+ "dtype: DT_UINT64\n"
+ "tensor_shape {\n"
+ " dim {\n"
+ " size: 2\n"
+ " }\n"
+ "}\n"
+ "uint64_val: 1\n"
+ "uint64_val: 2\n");
+}
+
+TEST(TensorProtoUtil, CreatesFloatTensorProto) {
+ std::vector<float> values{1.1, 2.2};
+ std::vector<size_t> shape{2};
+
+ auto proto = tensor::CreateTensorProto(values, shape);
+
+ EXPECT_EQ(proto.DebugString(),
+ "dtype: DT_FLOAT\n"
+ "tensor_shape {\n"
+ " dim {\n"
+ " size: 2\n"
+ " }\n"
+ "}\n"
+ "float_val: 1.1\n"
+ "float_val: 2.2\n");
+}
+
+TEST(TensorProtoUtil, CreatesDoubleTensorProto) {
+ std::vector<double> values{1.1, 2.2};
+ std::vector<size_t> shape{2};
+
+ auto proto = tensor::CreateTensorProto(values, shape);
+
+ EXPECT_EQ(proto.DebugString(),
+ "dtype: DT_DOUBLE\n"
+ "tensor_shape {\n"
+ " dim {\n"
+ " size: 2\n"
+ " }\n"
+ "}\n"
+ "double_val: 1.1\n"
+ "double_val: 2.2\n");
+}
+
+TEST(TensorProtoUtil, CreatesBoolTensorProto) {
+ std::vector<bool> values{true, false};
+ std::vector<size_t> shape{2};
+
+ auto proto = tensor::CreateTensorProto(values, shape);
+
+ EXPECT_EQ(proto.DebugString(),
+ "dtype: DT_BOOL\n"
+ "tensor_shape {\n"
+ " dim {\n"
+ " size: 2\n"
+ " }\n"
+ "}\n"
+ "bool_val: true\n"
+ "bool_val: false\n");
+}
+
} // namespace
} // namespace tensorflow