diff options
Diffstat (limited to 'tensorflow/core/kernels/restore_op.cc')
-rw-r--r-- | tensorflow/core/kernels/restore_op.cc | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/restore_op.cc b/tensorflow/core/kernels/restore_op.cc new file mode 100644 index 0000000000..b52c69449c --- /dev/null +++ b/tensorflow/core/kernels/restore_op.cc @@ -0,0 +1,65 @@ +// See docs in ../ops/io_ops.cc. +#include "tensorflow/core/kernels/io.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/tensor_slice_reader.h" + +namespace tensorflow { + +class RestoreOp : public OpKernel { + public: + explicit RestoreOp(OpKernelConstruction* context) : OpKernel(context) { + int preferred_shard; + OP_REQUIRES_OK(context, + context->GetAttr("preferred_shard", &preferred_shard)); + if (preferred_shard == -1) { + preferred_shard_ = checkpoint::TensorSliceReader::kLoadAllShards; + } else { + OP_REQUIRES(context, preferred_shard >= 0, + errors::InvalidArgument("Attribute 'preferred_shard' must be " + "greater or equal to -1")); + preferred_shard_ = preferred_shard; + } + } + void Compute(OpKernelContext* context) override { + RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader, + preferred_shard_, false); + } + + private: + int preferred_shard_; +}; + +REGISTER_KERNEL_BUILDER(Name("Restore").Device(DEVICE_CPU), RestoreOp); + +class RestoreSliceOp : public OpKernel { + public: + explicit RestoreSliceOp(OpKernelConstruction* context) : OpKernel(context) { + int preferred_shard; + OP_REQUIRES_OK(context, + context->GetAttr("preferred_shard", &preferred_shard)); + if (preferred_shard == -1) { + preferred_shard_ = checkpoint::TensorSliceReader::kLoadAllShards; + } else { + OP_REQUIRES(context, preferred_shard >= 0, + errors::InvalidArgument("Attribute 'preferred_shard' must be " + "greater or equal to -1")); + preferred_shard_ = preferred_shard; + } + } + void Compute(OpKernelContext* context) override { + RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader, + preferred_shard_, true); + } + + private: + int preferred_shard_; +}; + +REGISTER_KERNEL_BUILDER(Name("RestoreSlice").Device(DEVICE_CPU), + RestoreSliceOp); + +} // namespace tensorflow |