aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/tensor_slice_writer.cc
blob: 1b5629bc6b40c85ed1f59146a7ddf732ff8325fb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include "tensorflow/core/util/tensor_slice_writer.h"

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/env.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"

namespace tensorflow {

namespace checkpoint {

namespace {

// TensorSliceWriter::Builder implementation that forwards entries to a
// table::TableBuilder writing into the WritableFile given at construction.
class TableBuilder : public TensorSliceWriter::Builder {
 public:
  // Stores `f` in a unique_ptr, so this object deletes it, and creates a
  // table::TableBuilder with default table::Options on top of the same file.
  // `name` is kept only for use in error messages.
  TableBuilder(const string& name, WritableFile* f)
      : name_(name),
        file_(f),
        builder_(new table::TableBuilder(table::Options(), f)) {}

  // Hands a single key/value pair to the underlying table builder.
  void Add(StringPiece key, StringPiece val) override {
    builder_->Add(key, val);
  }

  // Finishes the table, then closes the file. `*file_size` receives the
  // builder's reported file size when both steps succeed and -1 otherwise.
  // Any failure is re-wrapped as an Internal error that names the file.
  // The table builder and the file are released before returning, whatever
  // the outcome.
  Status Finish(int64* file_size) override {
    *file_size = -1;
    Status status = builder_->Finish();
    if (status.ok()) {
      status = file_->Close();
    }
    if (status.ok()) {
      *file_size = builder_->FileSize();
    } else {
      status = errors::Internal("Error writing (tmp) checkpoint file: ", name_,
                                ": ", status.ToString());
    }
    builder_.reset();
    file_.reset();
    return status;
  }

 private:
  // File name, used for error reporting.
  string name_;
  // Output file. Declared before builder_, so it is destroyed after it.
  std::unique_ptr<WritableFile> file_;
  // Table builder that was constructed with the raw pointer held by file_.
  std::unique_ptr<table::TableBuilder> builder_;
};
}  // anonymous namespace

// Opens `name` as a writable file through the default Env and, on success,
// stores a newly allocated TableBuilder for it in `*builder`. On failure
// `*builder` is left as nullptr and the error from opening the file is
// returned unchanged.
Status CreateTableTensorSliceBuilder(const string& name,
                                     TensorSliceWriter::Builder** builder) {
  *builder = nullptr;
  WritableFile* file;
  const Status open_status = Env::Default()->NewWritableFile(name, &file);
  if (!open_status.ok()) {
    return open_status;
  }
  // The TableBuilder stores `file` in a unique_ptr and so takes it over.
  *builder = new TableBuilder(name, file);
  return Status::OK();
}

TensorSliceWriter::TensorSliceWriter(const string& filename,
                                     CreateBuilderFunction create_builder)
    : filename_(filename),
      create_builder_(create_builder),
      tmpname_(strings::StrCat(filename, ".tempstate", random::New64())),
      slices_(0) {}

// Writes the accumulated metadata and tensor slice data to a temporary file
// and then renames that file to the final file name.
//
// Returns the first error encountered: from creating the builder, from
// finishing the builder, or from the rename. If finishing the builder fails,
// a best-effort attempt is made to delete the temporary file. If only the
// rename fails, the error is logged and the temporary file is left in place.
Status TensorSliceWriter::Finish() {
  // Must start out null: the failure path below deletes `b`, and a
  // create_builder_ implementation is not guaranteed to assign the
  // out-parameter before returning an error.
  Builder* b = nullptr;
  Status s = create_builder_(tmpname_, &b);
  if (!s.ok()) {
    delete b;
    return s;
  }
  std::unique_ptr<Builder> builder(b);

  // We save the saved tensor slice metadata as the first element.
  string meta;
  sts_.AppendToString(&meta);
  builder->Add(kSavedTensorSlicesKey, meta);

  // Go through all the data and add them
  for (const auto& x : data_) {
    builder->Add(x.first, x.second);
  }

  // Given a defined value up front so it is never read uninitialized, even
  // with a Builder that does not write to it.
  int64 file_size = -1;
  s = builder->Finish(&file_size);
  // We need to rename the file to the proper name
  if (s.ok()) {
    s = Env::Default()->RenameFile(tmpname_, filename_);
    if (s.ok()) {
      VLOG(1) << "Written " << slices_ << " slices for "
              << sts_.meta().tensor_size() << " tensors (" << file_size
              << " bytes) to " << filename_;
    } else {
      LOG(ERROR) << "Failed to rename file " << tmpname_ << " to " << filename_;
    }
  } else {
    // Best-effort cleanup: the write error in `s` is what gets reported, so
    // the result of the delete is intentionally not inspected.
    Env::Default()->DeleteFile(tmpname_);
  }
  return s;
}

}  // namespace checkpoint

}  // namespace tensorflow