Diffstat (limited to 'tensorflow/core/kernels/io.cc')
 -rw-r--r--  tensorflow/core/kernels/io.cc  270
 1 file changed, 270 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/kernels/io.cc b/tensorflow/core/kernels/io.cc
new file mode 100644
index 0000000000..9d6921aa8e
--- /dev/null
+++ b/tensorflow/core/kernels/io.cc
@@ -0,0 +1,270 @@
+// See docs in ../ops/io_ops.cc
+#include <unordered_map>
+
+#include "tensorflow/core/kernels/io.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/util/tensor_slice_reader.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+#include "tensorflow/core/util/tensor_slice_writer.h"
+
+namespace tensorflow {
+
+namespace {
+bool ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape,
+ TensorSlice* slice, TensorShape* shape_slice,
+ string* error) {
+ CHECK(!shape_and_slice.empty());
+ // Syntax: dim0 dim1 dim2 ... <slice string>
+ // Where slice string is defined in core/framework/tensor_slice.h
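+  // Illustrative example (assumed syntax per tensor_slice.h, not exhaustive):
+  //   "4 5 10 -:-:1,3"
+  // describes a stored tensor of shape [4, 5, 10] and a slice that keeps all
+  // of dims 0 and 1 and elements [1, 4) of dim 2.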
+ std::vector<string> splits = str_util::Split(shape_and_slice, ' ');
+
+ // Must have at least 2 strings.
+ if (splits.size() < 2) {
+ *error = strings::StrCat(
+        "Need at least two elements in shape_and_slice specification: ",
+ shape_and_slice);
+ return false;
+ }
+ int num_dims = splits.size() - 1;
+ shape->Clear();
+ for (int i = 0; i < num_dims; ++i) {
+ int dim;
+ if (!str_util::NumericParse32(splits[i], &dim)) {
+      *error = strings::StrCat("Non-numeric dimension in shape_and_slice: ",
+ shape_and_slice);
+ return false;
+ }
+ shape->AddDim(dim);
+ }
+ // The last split is the slice specification.
+ slice->Clear();
+ auto status = slice->Parse(splits.back(), slice);
+ if (!status.ok()) {
+ *error = status.error_message();
+ return false;
+ }
+ // The specified slice must be compatible with the specified shape.
+ status = slice->SliceTensorShape(*shape, shape_slice);
+ if (!status.ok()) {
+ *error = status.error_message();
+ return false;
+ }
+ return true;
+}
+} // namespace
+
+void SaveTensors(
+ OpKernelContext* context,
+ checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
+ bool save_slices) {
+ const Tensor& filename_t = context->input(0);
+ {
+ const int64 size = filename_t.NumElements();
+ OP_REQUIRES(
+ context, size == 1,
+ errors::InvalidArgument(
+ "Input 0 (filename) must be a string scalar; got a tensor of ",
+            size, " elements"));
+ }
+
+ const Tensor& tensor_names_t = context->input(1);
+ const int64 N = tensor_names_t.NumElements();
+ const string* tensor_shapes_and_slices_ptr = nullptr;
+ if (save_slices) {
+ const Tensor& tensor_shapes_and_slices_t = context->input(2);
+ OP_REQUIRES(
+ context, tensor_shapes_and_slices_t.NumElements() == N,
+ errors::InvalidArgument("Expected ", N,
+ " elements for the tensor "
+ "shapes and slices but got ",
+ tensor_shapes_and_slices_t.NumElements()));
+ tensor_shapes_and_slices_ptr =
+ tensor_shapes_and_slices_t.flat<string>().data();
+ }
+ // Path, names, and slices if save_slices is true.
+ const int kFixedInputs = save_slices ? 3 : 2;
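+  // Layout sketch (derived from the checks below, not from the op docs): with
+  // save_slices == true and N == 2, the inputs are
+  //   {filename, tensor_names, shapes_and_slices, tensor_0, tensor_1}.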
+ OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs,
+              errors::InvalidArgument("Expected a total of ", N + kFixedInputs,
+                                      " inputs because input #1 (a string "
+                                      "tensor of saved names) contains ",
+                                      N, " names, but received ",
+                                      context->num_inputs(), " inputs"));
+
+ VLOG(1) << "About to save tensors to file " << filename_t.flat<string>()(0)
+ << "...";
+ checkpoint::TensorSliceWriter writer(filename_t.flat<string>()(0),
+ builder_func);
+
+ Status s;
+ auto tensor_names_flat = tensor_names_t.flat<string>();
+
+ string error;
+ for (int64 i = 0; i < N; ++i) {
+ const string& name = tensor_names_flat(i);
+ const Tensor& input = context->input(i + kFixedInputs);
+ TensorShape shape(input.shape());
+ TensorSlice slice(input.dims());
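+    // Absent a shape_and_slice entry, `shape` is the tensor's own shape and
+    // `slice` spans the full extent of every dimension, so the whole tensor
+    // is written; the block below overrides both when a spec is given.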
+ if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
+ const string& shape_spec = tensor_shapes_and_slices_ptr[i];
+ TensorShape slice_shape;
+ OP_REQUIRES(context, ParseShapeAndSlice(shape_spec, &shape, &slice,
+ &slice_shape, &error),
+ errors::InvalidArgument(error));
+ OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
+ errors::InvalidArgument("Slice in shape_and_slice "
+ "specification does not match the "
+ "shape of the tensor to save: ",
+ shape_spec, ", tensor: ",
+ input.shape().DebugString()));
+ }
+
+#define WRITER_ADD(dt) \
+ case dt: \
+ s = writer.Add(name, shape, slice, \
+ input.flat<EnumToDataType<dt>::Type>().data()); \
+ break
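+// For reference, WRITER_ADD(DT_FLOAT) expands to roughly:
+//   case DT_FLOAT:
+//     s = writer.Add(name, shape, slice, input.flat<float>().data());
+//     break;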
+
+ switch (input.dtype()) {
+ WRITER_ADD(DT_FLOAT);
+ WRITER_ADD(DT_DOUBLE);
+ WRITER_ADD(DT_INT32);
+ WRITER_ADD(DT_UINT8);
+ WRITER_ADD(DT_INT16);
+ WRITER_ADD(DT_INT8);
+ WRITER_ADD(DT_INT64);
+ WRITER_ADD(DT_QUINT8);
+ WRITER_ADD(DT_QINT8);
+ WRITER_ADD(DT_QINT32);
+ default:
+ context->SetStatus(errors::Unimplemented("Saving data type ",
+ DataTypeString(input.dtype()),
+ " not yet supported"));
+ return;
+ }
+#undef WRITER_ADD
+ if (!s.ok()) {
+ context->SetStatus(s);
+ return;
+ }
+ }
+
+ s = writer.Finish();
+ if (!s.ok()) {
+ context->SetStatus(s);
+ }
+}
+
+void RestoreTensor(OpKernelContext* context,
+ checkpoint::TensorSliceReader::OpenTableFunction open_func,
+ int preferred_shard, bool restore_slice) {
+ const Tensor& file_pattern_t = context->input(0);
+ {
+ const int64 size = file_pattern_t.NumElements();
+ OP_REQUIRES(
+ context, size == 1,
+ errors::InvalidArgument(
+ "Input 0 (file_pattern) must be a string scalar; got a tensor of ",
+            size, " elements"));
+ }
+ const string& file_pattern = file_pattern_t.flat<string>()(0);
+
+ const Tensor& tensor_name_t = context->input(1);
+ {
+ const int64 size = tensor_name_t.NumElements();
+ OP_REQUIRES(
+ context, size == 1,
+ errors::InvalidArgument(
+ "Input 1 (tensor_name) must be a string scalar; got a tensor of ",
+            size, " elements"));
+ }
+ const string& tensor_name = tensor_name_t.flat<string>()(0);
+
+ const string* tensor_shape_and_slice_ptr = nullptr;
+ if (restore_slice) {
+ const Tensor& tensor_shape_and_slice_t = context->input(2);
+ OP_REQUIRES(
+ context, tensor_shape_and_slice_t.NumElements() == 1,
+ errors::InvalidArgument("Expected 1 element for the tensor "
+ "shape and slice but got ",
+ tensor_shape_and_slice_t.NumElements()));
+ tensor_shape_and_slice_ptr = tensor_shape_and_slice_t.flat<string>().data();
+ }
+
+ // If we cannot find a cached reader we will allocate our own.
+ std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
+
+ const checkpoint::TensorSliceReader* reader =
+ context->slice_reader_cache()->GetReader(file_pattern, open_func,
+ preferred_shard);
+ if (!reader) {
+ allocated_reader.reset(new checkpoint::TensorSliceReader(
+ file_pattern, open_func, preferred_shard));
+ reader = allocated_reader.get();
+ }
+ OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status());
+
+ // Get the shape and type from the save file.
+ DataType type;
+ TensorShape saved_shape;
+ OP_REQUIRES(
+ context, reader->HasTensor(tensor_name, &saved_shape, &type),
+ errors::NotFound("Tensor name \"", tensor_name,
+ "\" not found in checkpoint files ", file_pattern));
+ OP_REQUIRES(
+ context, type == context->expected_output_dtype(0),
+ errors::InvalidArgument("Expected to restore a tensor of type ",
+ DataTypeString(context->expected_output_dtype(0)),
+ ", got a tensor of type ", DataTypeString(type),
+ " instead: tensor_name = ", tensor_name));
+
+ // Shape of the output and slice to load.
+ TensorShape output_shape(saved_shape);
+ TensorSlice slice_to_load(saved_shape.dims());
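+  // With no slice spec (or an empty one), the entire saved tensor is
+  // restored: output_shape stays equal to saved_shape and slice_to_load
+  // covers every dimension in full.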
+ if (restore_slice && !tensor_shape_and_slice_ptr[0].empty()) {
+ const string& shape_spec = tensor_shape_and_slice_ptr[0];
+ TensorShape parsed_shape;
+ string error;
+ OP_REQUIRES(context,
+ ParseShapeAndSlice(shape_spec, &parsed_shape, &slice_to_load,
+ &output_shape, &error),
+ errors::InvalidArgument(error));
+ OP_REQUIRES(
+ context, parsed_shape.IsSameSize(saved_shape),
+ errors::InvalidArgument(
+ "Shape in shape_and_slice spec does not match the shape in the "
+ "save file: ",
+ parsed_shape.DebugString(), ", save file shape: ",
+ saved_shape.DebugString()));
+ }
+
+ Tensor* t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &t));
+#define READER_COPY(dt) \
+ case dt: \
+ reader->CopySliceData(tensor_name, slice_to_load, \
+ t->flat<EnumToDataType<dt>::Type>().data()); \
+ break
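+// For reference, READER_COPY(DT_FLOAT) expands to roughly:
+//   case DT_FLOAT:
+//     reader->CopySliceData(tensor_name, slice_to_load,
+//                           t->flat<float>().data());
+//     break;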
+
+ switch (type) {
+ READER_COPY(DT_FLOAT);
+ READER_COPY(DT_DOUBLE);
+ READER_COPY(DT_INT32);
+ READER_COPY(DT_UINT8);
+ READER_COPY(DT_INT16);
+ READER_COPY(DT_INT8);
+ READER_COPY(DT_INT64);
+ default:
+ context->SetStatus(errors::Unimplemented(
+ "Restoring data type ", DataTypeString(type), " not yet supported"));
+ }
+}
+
+} // namespace tensorflow