aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/save_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/save_op.cc')
-rw-r--r--tensorflow/core/kernels/save_op.cc81
1 files changed, 81 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/save_op.cc b/tensorflow/core/kernels/save_op.cc
new file mode 100644
index 0000000000..71a15c643e
--- /dev/null
+++ b/tensorflow/core/kernels/save_op.cc
@@ -0,0 +1,81 @@
+// See docs in ../ops/io_ops.cc
+#include "tensorflow/core/kernels/io.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/util/tensor_slice_writer.h"
+
+namespace tensorflow {
+
+class SaveOp : public OpKernel {
+ public:
+ explicit SaveOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ SaveTensors(context, &checkpoint::CreateTableTensorSliceBuilder, false);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
+
+class SaveSlicesOp : public OpKernel {
+ public:
+ explicit SaveSlicesOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ SaveTensors(context, &checkpoint::CreateTableTensorSliceBuilder, true);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("SaveSlices").Device(DEVICE_CPU), SaveSlicesOp);
+
+class ShardedFilenameOp : public OpKernel {
+ public:
+ explicit ShardedFilenameOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ static const char* input_names[3] = {"basename", "shard", "num_shards"};
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()),
+ errors::InvalidArgument(
+ input_names[i], " must be a scalar, got shape ",
+ ctx->input(i).shape().ShortDebugString()));
+ }
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
+ out->scalar<string>()() = strings::Printf(
+ "%s-%05d-of-%05d", ctx->input(0).scalar<string>()().c_str(),
+ ctx->input(1).scalar<int32>()(), ctx->input(2).scalar<int32>()());
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ShardedFilename").Device(DEVICE_CPU),
+ ShardedFilenameOp);
+
+class ShardedFilespecOp : public OpKernel {
+ public:
+ explicit ShardedFilespecOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ static const char* input_names[2] = {"basename", "num_shards"};
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()),
+ errors::InvalidArgument(
+ input_names[i], " must be a scalar, got shape ",
+ ctx->input(i).shape().ShortDebugString()));
+ }
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
+ out->scalar<string>()() = strings::Printf(
+ "%s-\?\?\?\?\?-of-%05d", ctx->input(0).scalar<string>()().c_str(),
+ ctx->input(1).scalar<int32>()());
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("ShardedFilespec").Device(DEVICE_CPU),
+ ShardedFilespecOp);
+
+} // namespace tensorflow