aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/experimental/ops/writers.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/experimental/ops/writers.py')
-rw-r--r--tensorflow/python/data/experimental/ops/writers.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/tensorflow/python/data/experimental/ops/writers.py b/tensorflow/python/data/experimental/ops/writers.py
new file mode 100644
index 0000000000..994447cb4d
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/writers.py
@@ -0,0 +1,60 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for tf.data writers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.TFRecordWriter")
+class TFRecordWriter(object):
+ """Writes data to a TFRecord file."""
+
+ def __init__(self, filename, compression_type=None):
+ self._filename = ops.convert_to_tensor(
+ filename, dtypes.string, name="filename")
+ self._compression_type = convert.optional_param_to_tensor(
+ "compression_type",
+ compression_type,
+ argument_default="",
+ argument_dtype=dtypes.string)
+
+ def write(self, dataset):
+ """Returns a `tf.Operation` to write a dataset to a file.
+
+ Args:
+ dataset: a `tf.data.Dataset` whose elements are to be written to a file
+
+ Returns:
+ A `tf.Operation` that, when run, writes contents of `dataset` to a file.
+ """
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+ if (dataset.output_types != dtypes.string or
+ dataset.output_shapes != tensor_shape.scalar()):
+ raise TypeError(
+ "`dataset` must produce scalar `DT_STRING` tensors whereas it "
+ "produces shape {0} and types {1}".format(dataset.output_shapes,
+ dataset.output_types))
+ return gen_dataset_ops.dataset_to_tf_record(
+ dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access