aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/save_op_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/save_op_test.cc')
-rw-r--r--tensorflow/core/kernels/save_op_test.cc443
1 files changed, 443 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/save_op_test.cc b/tensorflow/core/kernels/save_op_test.cc
new file mode 100644
index 0000000000..ee1ba492a6
--- /dev/null
+++ b/tensorflow/core/kernels/save_op_test.cc
@@ -0,0 +1,443 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/tensor_slice_reader.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+class SaveOpTest : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "Save")
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Input(FakeInput(
+ {DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8, DT_QINT32}))
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(SaveOpTest, Simple) {
+ const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple");
+ const string tensornames[] = {"tensor_int", "tensor_float", "tensor_double",
+ "tensor_qint8", "tensor_qint32"};
+
+ MakeOp();
+ // Add a file name
+ AddInput<string>(TensorShape({}),
+ [&filename](int x) -> string { return filename; });
+
+ // Add the tensor names
+ AddInput<string>(TensorShape({5}),
+ [&tensornames](int x) -> string { return tensornames[x]; });
+
+ // Add a 1-d integer tensor
+ AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; });
+
+ // Add a 2-d float tensor
+ AddInput<float>(TensorShape({2, 4}),
+ [](int x) -> float { return static_cast<float>(x) / 10; });
+
+ // Add a 2-d double tensor
+ AddInput<double>(TensorShape({2, 4}),
+ [](int x) -> double { return static_cast<double>(x) / 20; });
+
+ // Add a 2-d qint8 tensor
+ AddInput<qint8>(TensorShape({3, 2}),
+ [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); });
+
+ // Add a 2-d qint32 tensor
+ AddInput<qint32>(TensorShape({2, 3}), [](int x) -> qint32 {
+ return *reinterpret_cast<qint32*>(&x) * qint8(2);
+ });
+
+ ASSERT_OK(RunOpKernel());
+
+ // Check that the checkpoint file is properly written
+ checkpoint::TensorSliceReader reader(filename,
+ checkpoint::OpenTableTensorSliceReader);
+ EXPECT_OK(reader.status());
+
+ // We expect to find all saved tensors
+ {
+ // The 1-d integer tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type));
+ TensorShape expected({10});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_INT32, type);
+
+ // We expect the tensor value to be correct.
+ TensorSlice s = TensorSlice::ParseOrDie("-");
+ int data[10];
+ std::fill_n(data, 10, 0);
+ EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(i + 1, data[i]);
+ }
+ }
+
+ {
+ // The 2-d float tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type));
+ TensorShape expected({2, 4});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_FLOAT, type);
+
+ // We expect the tensor value to be correct.
+ TensorSlice s = TensorSlice::ParseOrDie("-:-");
+ float data[8];
+ std::fill_n(data, 8, 0);
+ EXPECT_TRUE(reader.CopySliceData("tensor_float", s, data));
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(static_cast<float>(i) / 10, data[i]);
+ }
+ }
+
+ {
+ // The 2-d double tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type));
+ TensorShape expected({2, 4});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_DOUBLE, type);
+
+ // We expect the tensor value to be correct.
+ TensorSlice s = TensorSlice::ParseOrDie("-:-");
+ double data[8];
+ std::fill_n(data, 8, 0);
+ EXPECT_TRUE(reader.CopySliceData("tensor_double", s, data));
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(static_cast<double>(i) / 20, data[i]);
+ }
+ }
+
+ {
+ // The 2-d qint8 tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type));
+ TensorShape expected({3, 2});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_QINT8, type);
+
+ // We expect the tensor value to be correct.
+ TensorSlice s = TensorSlice::ParseOrDie("-:-");
+ qint8 data[6];
+ EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data));
+ for (int i = 0; i < 6; ++i) {
+ EXPECT_EQ(*reinterpret_cast<qint8*>(&i), data[i]);
+ }
+ }
+
+ {
+ // The 2-d qint32 tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type));
+ TensorShape expected({2, 3});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_QINT32, type);
+
+ // We expect the tensor value to be correct.
+ TensorSlice s = TensorSlice::ParseOrDie("-:-");
+ qint32 data[6];
+ EXPECT_TRUE(reader.CopySliceData("tensor_qint32", s, data));
+ for (int i = 0; i < 6; ++i) {
+ EXPECT_EQ(*reinterpret_cast<qint32*>(&i) * qint8(2), data[i]);
+ }
+ }
+}
+
+class SaveSlicesOpTest : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "SaveSlices")
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Input(FakeInput(
+ {DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8, DT_QINT32}))
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+// Here we save only slices. We restore them in a larger tensor and we check
+// that the right slice is restored. It is quite tricky to check that the
+// right slices are actually restored so instead we just check that
+// CopySliceData() return true/false depending on the slice we ask for.
+TEST_F(SaveSlicesOpTest, Slices) {
+ const string filename = io::JoinPath(testing::TmpDir(), "tensor_slices");
+ const string tensornames[] = {"tensor_int", "tensor_float", "tensor_double",
+ "tensor_qint8", "tensor_qint32"};
+ // Specifies that the data we save are slices of larger tensors.
+ // See core/framework/tensor_slice.h for the slice syntax.
+ const string tensorshapes[] = {
+ "10 -", // Full contents of a 10 element vector.
+ "2 4 -:0,2", // A 2x2 slice of a 2x4 tensor.
+ "2 4 0,1:2,2", // A 1x2 slice of a 2x4 tensor.
+ "3 2 -:-", // Full contents of a 3x2 tensor.
+ "2 3 1,1:2,1" // Another 1x1 slice of a2x3 tensor.
+ };
+
+ MakeOp();
+ // Add a file name
+ AddInput<string>(TensorShape({}),
+ [&filename](int x) -> string { return filename; });
+
+ // Add the tensor names
+ AddInput<string>(TensorShape({5}),
+ [&tensornames](int x) -> string { return tensornames[x]; });
+
+ // Add the tensor shapes and slices
+ AddInput<string>(TensorShape({5}), [&tensorshapes](int x) -> string {
+ return tensorshapes[x];
+ });
+
+ // Add a 1-d integer tensor
+ AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; });
+
+ // Add a 2-d float tensor
+ AddInput<float>(TensorShape({2, 2}),
+ [](int x) -> float { return static_cast<float>(x) / 10; });
+
+ // Add a 2-d double tensor
+ AddInput<double>(TensorShape({1, 2}),
+ [](int x) -> double { return static_cast<double>(x) / 20; });
+
+ // Add a 2-d qint8 tensor
+ AddInput<qint8>(TensorShape({3, 2}),
+ [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); });
+
+ // Add a 2-d qint32 tensor
+ AddInput<qint32>(TensorShape({1, 1}), [](int x) -> qint32 {
+ return *reinterpret_cast<qint32*>(&x) * qint8(2);
+ });
+
+ ASSERT_OK(RunOpKernel());
+
+ // Check that the checkpoint file is properly written
+ checkpoint::TensorSliceReader reader(filename,
+ checkpoint::OpenTableTensorSliceReader);
+ EXPECT_OK(reader.status());
+
+ // We expect to find all saved tensors
+ {
+ // The 1-d integer tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type));
+ TensorShape expected({10});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_INT32, type);
+
+ // We saved the full tensor so we should be able to read it all.
+ TensorSlice s = TensorSlice::ParseOrDie("-");
+ int data[10];
+ EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data));
+ }
+
+ {
+ // The 2-d float tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type));
+ TensorShape expected({2, 4});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_FLOAT, type);
+
+ // We saved the slice "-:0,2" so we should not be able to read the full
+ // tensor.
+ TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
+ TensorSlice saved_slice = TensorSlice::ParseOrDie("-:0,2");
+ float data[8];
+ EXPECT_FALSE(reader.CopySliceData("tensor_float", full_slice, data));
+ EXPECT_TRUE(reader.CopySliceData("tensor_float", saved_slice, data));
+ }
+
+ {
+ // The 2-d double tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type));
+ TensorShape expected({2, 4});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_DOUBLE, type);
+
+ // We saved the slice "0,1:2,2" so we should not be able to read the full
+ // tensor.
+ TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
+ TensorSlice saved_slice = TensorSlice::ParseOrDie("0,1:2,2");
+ double data[8];
+ EXPECT_FALSE(reader.CopySliceData("tensor_double", full_slice, data));
+ EXPECT_TRUE(reader.CopySliceData("tensor_double", saved_slice, data));
+ }
+
+ {
+ // The 2-d qint8 tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type));
+ TensorShape expected({3, 2});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_QINT8, type);
+
+ // We saved the full slice.
+ TensorSlice s = TensorSlice::ParseOrDie("-:-");
+ qint8 data[6];
+ EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data));
+ }
+
+ {
+ // The 2-d qint32 tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type));
+ TensorShape expected({2, 3});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_QINT32, type);
+
+ // We expect the tensor value to be correct.
+ TensorSlice s = TensorSlice::ParseOrDie("1,1:2,1");
+ TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
+ TensorSlice saved_slice = TensorSlice::ParseOrDie("1,1:2,1");
+ qint32 data[6];
+ EXPECT_FALSE(reader.CopySliceData("tensor_qint32", full_slice, data));
+ EXPECT_TRUE(reader.CopySliceData("tensor_qint32", saved_slice, data));
+ }
+}
+
+class SaveOpSlices2Test : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "SaveSlices")
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Input(FakeInput({DT_INT32, DT_INT32, DT_FLOAT}))
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(SaveOpSlices2Test, TwoSlices) {
+ const string filename = io::JoinPath(testing::TmpDir(), "three_slices");
+ // We will save 2 slices of the tensor named "four_by_sixteen" which is 4x16,
+ // and one slice of the "small" tensor.
+ const string tensornames[] = {"four_by_sixteen", "four_by_sixteen", "small"};
+ const string tensorshapes[] = {
+ // Slice specifications for the 2 slices of "four_by_sixteen"
+ "4 16 0,2:-", // 1st slice covers indices 0 and 1 in the first dim.
+ "4 16 2,2:-", // 2nd slice covers indices 2 and 3 in the first dim.
+ "" // We save the full "small" tensors.
+ };
+
+ MakeOp();
+ // Add a file name
+ AddInput<string>(TensorShape({}),
+ [&filename](int x) -> string { return filename; });
+
+ // Add the tensor names
+ AddInput<string>(TensorShape({3}),
+ [&tensornames](int x) -> string { return tensornames[x]; });
+
+ // Add the tensor shapes and slices
+ AddInput<string>(TensorShape({3}), [&tensorshapes](int x) -> string {
+ return tensorshapes[x];
+ });
+
+ // Add an integer tensor for slice 0,2:- of a 4x16 tensor: It is 2x16.
+ AddInput<int32>(TensorShape({2, 16}), [](int x) -> int32 { return x + 1; });
+
+ // Add an integer tensor for slice 2,2:- of a 4x16 tensor: It is 2x16.
+ AddInput<int32>(TensorShape({2, 16}),
+ [](int x) -> int32 { return 10 * (x + 1); });
+
+ // Add a float tensor for "small"
+ AddInput<float>(TensorShape({2, 4}),
+ [](int x) -> float { return static_cast<float>(x) / 10; });
+
+ ASSERT_OK(RunOpKernel());
+
+ // Check that the checkpoint file is properly written
+ checkpoint::TensorSliceReader reader(filename,
+ checkpoint::OpenTableTensorSliceReader);
+ EXPECT_OK(reader.status());
+
+ {
+ // Reload the two slices of "four_by_sixteen" into that tensor.
+ Tensor reloaded(DT_INT32, {4, 16});
+
+ // We expect to find all slices
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("four_by_sixteen", &shape, &type));
+ EXPECT_TRUE(shape.IsSameSize(reloaded.shape()));
+ EXPECT_EQ(type, reloaded.dtype());
+
+ // Reload the whole tensor.
+ EXPECT_TRUE(reader.CopySliceData("four_by_sixteen",
+ TensorSlice(reloaded.dims()),
+ reloaded.flat<int>().data()));
+
+ {
+ auto slice = reloaded.Slice(0, 2).flat<int>();
+ for (int i = 0; i < slice.size(); ++i) {
+ EXPECT_EQ(i + 1, slice(i));
+ }
+ }
+ {
+ auto slice = reloaded.Slice(2, 4).flat<int>();
+ for (int i = 0; i < slice.size(); ++i) {
+ EXPECT_EQ(10 * (i + 1), slice(i));
+ }
+ }
+ }
+
+ {
+ // Reload the small float tensor.
+ Tensor reloaded(DT_FLOAT, {2, 4});
+
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("small", &shape, &type));
+ EXPECT_TRUE(shape.IsSameSize(reloaded.shape()));
+ EXPECT_EQ(DT_FLOAT, reloaded.dtype());
+
+ EXPECT_TRUE(reader.CopySliceData("small", TensorSlice(reloaded.dims()),
+ reloaded.flat<float>().data()));
+
+ for (int64 i = 0; i < reloaded.NumElements(); ++i) {
+ EXPECT_EQ(static_cast<float>(i) / 10, reloaded.flat<float>().data()[i]);
+ }
+ }
+}
+
+} // namespace
+} // namespace tensorflow