about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/text_literal_reader.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/text_literal_reader.cc')
-rw-r--r-- tensorflow/compiler/xla/text_literal_reader.cc 155
1 files changed, 155 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc
new file mode 100644
index 0000000000..7876272467
--- /dev/null
+++ b/tensorflow/compiler/xla/text_literal_reader.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/text_literal_reader.h"
+
+#include <limits>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Opens the file at `path` and parses its contents as a text literal.
+//
+// CHECK-fails if `path` ends in ".gz". If the file cannot be opened, the
+// status from Env::NewRandomAccessFile is returned; otherwise the result is
+// whatever ReadAllLines() produces for the opened file.
+StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
+    tensorflow::StringPiece path) {
+  CHECK(!path.ends_with(".gz"))
+      << "TextLiteralReader no longer supports reading .gz files";
+
+  std::unique_ptr<tensorflow::RandomAccessFile> opened_file;
+  tensorflow::Env* const env = tensorflow::Env::Default();
+  Status open_status = env->NewRandomAccessFile(path.ToString(), &opened_file);
+  if (open_status.ok()) {
+    // The reader is handed the raw pointer released from `opened_file`.
+    TextLiteralReader reader(opened_file.release());
+    return reader.ReadAllLines();
+  }
+  return open_status;
+}
+
+// Stores `file` in file_; ReadAllLines() later reads the literal text from it.
+// NOTE(review): ReadPath() passes a pointer released from a unique_ptr, so
+// file_ presumably takes ownership of `file` -- confirm against the member
+// declaration in text_literal_reader.h.
+TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file)
+ : file_(file) {}
+
+namespace {
+// Splits `text` at every occurrence of `delim` and stores the resulting
+// tokens in `*result`, replacing whatever it held before.
+//
+// Unlike tensorflow::str_util::Split, the tokens are StringPieces that alias
+// the memory behind `text` (no per-token string is built), and the caller
+// supplies the output vector so it can be reused across calls.
+//
+// An empty `text` yields no tokens. Otherwise N delimiters always yield N + 1
+// tokens, some of which may be empty (leading, trailing or adjacent
+// delimiters).
+void SplitByDelimToStringPieces(tensorflow::StringPiece text, char delim,
+                                std::vector<tensorflow::StringPiece>* result) {
+  result->clear();
+  if (text.empty()) {
+    return;
+  }
+
+  const char* const text_end = text.data() + text.size();
+  const char* token_begin = text.data();
+  // `cursor` is allowed to reach `text_end` (one past the last character):
+  // that final stop is what emits the trailing token.
+  for (const char* cursor = text.data();; ++cursor) {
+    const bool at_end = (cursor == text_end);
+    if (!at_end && *cursor != delim) {
+      continue;
+    }
+    result->push_back(tensorflow::StringPiece(
+        token_begin, static_cast<size_t>(cursor - token_begin)));
+    if (at_end) {
+      break;
+    }
+    token_begin = cursor + 1;
+  }
+}
+}  // namespace
+
+// Parses the whole of file_ into a literal.
+//
+// The first line is a shape string handed to ShapeUtil::ParseShapeString;
+// only F32 element types are accepted. Each following line must look like
+//   (i0,i1,...): value
+// and assigns `value` to the element at those coordinates. Every element is
+// first filled with quiet NaN, so elements the file never mentions stay NaN.
+//
+// Returns the populated literal, or:
+//  - the read status, if the shape line or a later line cannot be read
+//    (reaching end of input after the shape line is not an error);
+//  - the parse error, if the shape string is not a valid shape;
+//  - Unimplemented, if the element type is not F32;
+//  - InvalidArgument, if a line has no ':' separator, lacks the surrounding
+//    parentheses, holds an unparsable float or coordinate, has the wrong
+//    number of coordinates, or names a coordinate outside the shape.
+StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
+  tensorflow::io::RandomAccessInputStream stream(file_.get());
+  tensorflow::io::BufferedInputStream buf(&stream, 65536);
+  string shape_string;
+  Status s = buf.ReadLine(&shape_string);
+  if (!s.ok()) {
+    return s;
+  }
+
+  // Only rebuild the string when surrounding whitespace was actually removed.
+  tensorflow::StringPiece sp(shape_string);
+  if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) {
+    string tmp = sp.ToString();
+    shape_string = tmp;
+  }
+  TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string));
+  if (shape.element_type() != F32) {
+    return Unimplemented(
+        "unsupported element type for text literal reading: %s",
+        ShapeUtil::HumanString(shape).c_str());
+  }
+  const int rank = shape.dimensions_size();
+
+  auto result = MakeUnique<Literal>();
+  const float fill = std::numeric_limits<float>::quiet_NaN();
+  LiteralUtil::PopulateWithValue<float>(fill, AsInt64Slice(shape.dimensions()),
+                                        result.get());
+
+  // Scratch containers are hoisted out of the loop so their capacity is
+  // reused from line to line.
+  std::vector<tensorflow::StringPiece> pieces;
+  std::vector<tensorflow::StringPiece> coordinates;
+  std::vector<int64> coordinate_values;
+  string line;
+  while (true) {
+    s = buf.ReadLine(&line);
+    if (!s.ok()) {
+      // End of input is the normal way out of the loop. Any other failure is
+      // a real read error and must not be mistaken for a complete file.
+      if (s.code() == tensorflow::error::OUT_OF_RANGE) {
+        break;
+      }
+      return s;
+    }
+
+    SplitByDelimToStringPieces(line, ':', &pieces);
+    // A blank line produces no pieces and a line without ':' produces one;
+    // both must be rejected before pieces[0] / pieces[1] are touched.
+    if (pieces.size() < 2) {
+      return InvalidArgument(
+          "expected ':' between coordinates and value: \"%s\"", line.c_str());
+    }
+    tensorflow::StringPiece coordinates_string = pieces[0];
+    tensorflow::StringPiece value_string = pieces[1];
+    tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string);
+    tensorflow::str_util::RemoveWhitespaceContext(&value_string);
+    if (!coordinates_string.Consume("(")) {
+      return InvalidArgument(
+          "expected '(' at the beginning of coordinates: \"%s\"", line.c_str());
+    }
+    if (!tensorflow::str_util::ConsumeSuffix(&coordinates_string, ")")) {
+      return InvalidArgument("expected ')' at the end of coordinates: \"%s\"",
+                             line.c_str());
+    }
+    float value;
+    if (!tensorflow::strings::safe_strtof(value_string.ToString().c_str(),
+                                          &value)) {
+      return InvalidArgument("could not parse value as float: \"%s\"",
+                             value_string.ToString().c_str());
+    }
+    SplitByDelimToStringPieces(coordinates_string, ',', &coordinates);
+    coordinate_values.clear();
+    for (tensorflow::StringPiece piece : coordinates) {
+      int64 coordinate_value;
+      if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) {
+        return InvalidArgument(
+            "could not parse coordinate member as int64: \"%s\"",
+            piece.ToString().c_str());
+      }
+      coordinate_values.push_back(coordinate_value);
+    }
+    if (static_cast<int64>(coordinate_values.size()) != rank) {
+      return InvalidArgument(
+          "line did not have expected number of coordinates; want %d got %zu: "
+          "\"%s\"",
+          rank, coordinate_values.size(), line.c_str());
+    }
+    // The coordinates come straight from the file, so make sure each one
+    // addresses an element that exists before writing through it.
+    for (int dim = 0; dim < rank; ++dim) {
+      const int64 index = coordinate_values[dim];
+      const int64 bound = shape.dimensions(dim);
+      if (index < 0 || index >= bound) {
+        return InvalidArgument(
+            "coordinate %lld is out of bounds for dimension %d of size %lld: "
+            "\"%s\"",
+            static_cast<long long>(index), dim, static_cast<long long>(bound),
+            line.c_str());
+      }
+    }
+    LiteralUtil::Set<float>(result.get(), coordinate_values, value);
+  }
+  return std::move(result);
+}
+
+} // namespace xla