1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
|
// See docs in ../ops/io_ops.cc.
#include "tensorflow/core/kernels/io.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
namespace tensorflow {
class RestoreOp : public OpKernel {
public:
explicit RestoreOp(OpKernelConstruction* context) : OpKernel(context) {
int preferred_shard;
OP_REQUIRES_OK(context,
context->GetAttr("preferred_shard", &preferred_shard));
if (preferred_shard == -1) {
preferred_shard_ = checkpoint::TensorSliceReader::kLoadAllShards;
} else {
OP_REQUIRES(context, preferred_shard >= 0,
errors::InvalidArgument("Attribute 'preferred_shard' must be "
"greater or equal to -1"));
preferred_shard_ = preferred_shard;
}
}
void Compute(OpKernelContext* context) override {
RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
preferred_shard_, false);
}
private:
int preferred_shard_;
};
REGISTER_KERNEL_BUILDER(Name("Restore").Device(DEVICE_CPU), RestoreOp);
class RestoreSliceOp : public OpKernel {
public:
explicit RestoreSliceOp(OpKernelConstruction* context) : OpKernel(context) {
int preferred_shard;
OP_REQUIRES_OK(context,
context->GetAttr("preferred_shard", &preferred_shard));
if (preferred_shard == -1) {
preferred_shard_ = checkpoint::TensorSliceReader::kLoadAllShards;
} else {
OP_REQUIRES(context, preferred_shard >= 0,
errors::InvalidArgument("Attribute 'preferred_shard' must be "
"greater or equal to -1"));
preferred_shard_ = preferred_shard;
}
}
void Compute(OpKernelContext* context) override {
RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
preferred_shard_, true);
}
private:
int preferred_shard_;
};
REGISTER_KERNEL_BUILDER(Name("RestoreSlice").Device(DEVICE_CPU),
RestoreSliceOp);
} // namespace tensorflow
|