aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/io.h
blob: 7e548f1ad044e5e5dbff1a81bb595b371a8f1c52 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#ifndef TENSORFLOW_KERNELS_IO_H_
#define TENSORFLOW_KERNELS_IO_H_

#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_writer.h"

namespace tensorflow {

// Forward declaration only: the prototypes below take OpKernelContext by
// pointer, so its full definition is not required in this header.
class OpKernelContext;

// SaveTensors
//
// Writes the tensors passed as inputs of `context` through a writer that is
// constructed by invoking `builder_func`.
//
// Input layout expected on `context`:
//   input 0      single-element string tensor holding the file name.
//   input 1      the names of the remaining tensors.
//   input 2      present only when `save_slices` is true: the shape and
//                slice specifications.
//   the rest     the tensors to be saved.
void SaveTensors(OpKernelContext* context,
                 checkpoint::TensorSliceWriter::CreateBuilderFunction
                     builder_func,
                 bool save_slices);

// RestoreTensor
//
// Reads one tensor from the reader obtained through `open_func` and emits it
// as output 0 of `context` (i.e. context->output(0)). `preferred_shard` has
// the same meaning as the `preferred_shard` parameter of TensorSliceReader.
//
// Input layout expected on `context`:
//   input 0      single-element string tensor holding the file name.
//   input 1      single-element string tensor naming the output to restore.
//   input 2      present only when `restore_slice` is true: the shape and
//                slice specification of the tensor to restore.
void RestoreTensor(
    OpKernelContext* context,
    checkpoint::TensorSliceReader::OpenTableFunction open_func,
    int preferred_shard,
    bool restore_slice);

}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_IO_H_