aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/dataset.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/dataset.h')
-rw-r--r--tensorflow/core/framework/dataset.h120
1 files changed, 61 insertions, 59 deletions
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 66e836f9a6..e0c26d9286 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -40,6 +40,8 @@ limitations under the License.
namespace tensorflow {
+class DatasetBase;
+
// Interface for reading values from a key-value store.
// Used for restoring iterator state.
class IteratorStateReader {
@@ -66,7 +68,6 @@ class IteratorStateWriter {
// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
-class GraphDatasetBase;
class Node;
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
@@ -120,7 +121,7 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
- Status AddDataset(const GraphDatasetBase* dataset,
+ Status AddDataset(const DatasetBase* dataset,
const std::vector<Node*>& inputs, Node** output) {
return AddDataset(dataset, inputs, {}, output);
}
@@ -133,7 +134,7 @@ class GraphDefBuilderWrapper {
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
// non-null if the method returns with an OK status.
// The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
- Status AddDataset(const GraphDatasetBase* dataset,
+ Status AddDataset(const DatasetBase* dataset,
const std::vector<Node*>& inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
Node** output) {
@@ -145,7 +146,7 @@ class GraphDefBuilderWrapper {
}
Status AddDataset(
- const GraphDatasetBase* dataset,
+ const DatasetBase* dataset,
const std::vector<std::pair<size_t, Node*>>& inputs,
const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
@@ -276,6 +277,19 @@ class IteratorContext {
explicit IteratorContext(Params params) : params_(std::move(params)) {}
+ explicit IteratorContext(OpKernelContext* ctx) {
+ params_.env = ctx->env();
+ params_.runner = *(ctx->runner());
+ params_.lib = ctx->function_library();
+ // NOTE: must use reinterpret_cast because function.h forward-declares
+ // Device.
+ DeviceBase* device =
+ reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
+ params_.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ }
+
Env* env() const { return params_.env; }
std::function<void(std::function<void()>)>* runner() {
@@ -355,6 +369,11 @@ class IteratorBase {
virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) = 0;
+ Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) {
+ return GetNext(&ctx, out_tensors, end_of_sequence);
+ }
+
// Returns a vector of DataType values, representing the respective
// element types of each tuple component in the outputs of this
// iterator.
@@ -406,10 +425,40 @@ class IteratorBase {
}
};
+// Represents runtime information needed to construct a dataset.
+class DatasetContext {
+ public:
+ struct Params {
+ string name;
+ };
+
+ explicit DatasetContext(Params params) : params_(std::move(params)) {}
+
+ explicit DatasetContext(OpKernelContext* ctx) {
+ params_.name = ctx->op_kernel().type_string();
+ }
+
+ const string& name() const { return params_.name; }
+
+ private:
+ Params params_;
+};
+
// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
public:
+ // Key for storing the Dataset graph in the serialized format.
+ TF_EXPORT static const char kDatasetGraphKey[];
+
+ // Key for storing the output node of the Dataset graph in the serialized
+ // format.
+ TF_EXPORT static const char kDatasetGraphOutputNodeKey[];
+
+ explicit DatasetBase(DatasetContext&& ctx) : name_(ctx.name()) {}
+
+ const string& name() const { return name_; }
+
// Returns a new iterator for iterating over the range of elements in
// this dataset.
//
@@ -426,6 +475,11 @@ class DatasetBase : public core::RefCounted {
return (*iterator)->Initialize(ctx);
}
+ Status MakeIterator(IteratorContext&& ctx, const string& prefix,
+ std::unique_ptr<IteratorBase>* iterator) const {
+ return MakeIterator(&ctx, prefix, iterator);
+ }
+
// Returns a vector of DataType values, representing the respective
// element types of each tuple component in the outputs of this
// dataset.
@@ -441,16 +495,9 @@ class DatasetBase : public core::RefCounted {
// Serializes the dataset and writes it to the `writer`.
virtual Status Save(SerializationContext* ctx,
- IteratorStateWriter* writer) const {
- return errors::Unimplemented("%s does not support serialization",
- DebugString());
- }
+ IteratorStateWriter* writer) const;
protected:
- // TODO(srbs): Ideally all graph related logic should reside in
- // GraphDatasetBase. However, that would require Datasets defined in all ops
- // to derive from GraphDatasetBase. Once that is done we can move
- // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase.
class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
public:
DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
@@ -463,54 +510,15 @@ class DatasetBase : public core::RefCounted {
// TODO(jsimsa): Consolidate overloading into a single method.
virtual Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
- Node** node) const {
- return AsGraphDefInternal(b, node);
- }
-
- virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
- Node** node) const {
- return errors::Unimplemented("%s does not support serialization",
- DebugString());
- }
+ Node** node) const = 0;
virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const = 0;
friend class DatasetToGraphOp; // For access to graph related members.
-};
-
-// Base-class for datasets that are built by ops.
-class GraphDatasetBase : public DatasetBase {
- public:
- GraphDatasetBase(OpKernelContext* ctx)
- : op_name_(ctx->op_kernel().type_string()) {}
-
- const string op_name() const { return op_name_; }
-
- Status Save(SerializationContext* ctx,
- IteratorStateWriter* writer) const override {
- string serialized_graph_def;
- string output_node;
- TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
- return Status::OK();
- }
-
- // Key for storing the Dataset graph in the serialized format.
- TF_EXPORT static const char kDatasetGraphKey[];
-
- // Key for storing the output node of the Dataset graph in the serialized
- // format.
- TF_EXPORT static const char kDatasetGraphOutputNodeKey[];
private:
- Status Serialize(SerializationContext* ctx, string* serialized_graph_def,
- string* output_node) const;
-
- const string op_name_;
+ const string name_;
};
// Represents an iterator that is associated with a particular dataset.
@@ -718,12 +726,6 @@ class BackgroundWorker {
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);
};
-namespace dataset {
-
-IteratorContext MakeIteratorContext(OpKernelContext* ctx);
-
-} // namespace dataset
-
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_