diff options
Diffstat (limited to 'tensorflow/core/framework/dataset.h')
-rw-r--r-- | tensorflow/core/framework/dataset.h | 120 |
1 files changed, 61 insertions, 59 deletions
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 66e836f9a6..e0c26d9286 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -40,6 +40,8 @@ limitations under the License. namespace tensorflow { +class DatasetBase; + // Interface for reading values from a key-value store. // Used for restoring iterator state. class IteratorStateReader { @@ -66,7 +68,6 @@ class IteratorStateWriter { // Forward declarations to avoid introducing a dependency on headers in // "tensorflow/core/graph/...". class GraphDefBuilder; -class GraphDatasetBase; class Node; // Wrapper around GraphDefBuilder. Used to serialize Dataset graph. @@ -120,7 +121,7 @@ class GraphDefBuilderWrapper { return Status::OK(); } - Status AddDataset(const GraphDatasetBase* dataset, + Status AddDataset(const DatasetBase* dataset, const std::vector<Node*>& inputs, Node** output) { return AddDataset(dataset, inputs, {}, output); } @@ -133,7 +134,7 @@ class GraphDefBuilderWrapper { // `*output` contains a pointer to the output `Node`. It is guaranteed to be // non-null if the method returns with an OK status. // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. - Status AddDataset(const GraphDatasetBase* dataset, + Status AddDataset(const DatasetBase* dataset, const std::vector<Node*>& inputs, const std::vector<std::pair<StringPiece, AttrValue>>& attrs, Node** output) { @@ -145,7 +146,7 @@ class GraphDefBuilderWrapper { } Status AddDataset( - const GraphDatasetBase* dataset, + const DatasetBase* dataset, const std::vector<std::pair<size_t, Node*>>& inputs, const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs, const std::vector<std::pair<StringPiece, AttrValue>>& attrs, @@ -276,6 +277,19 @@ class IteratorContext { explicit IteratorContext(Params params) : params_(std::move(params)) {} + explicit IteratorContext(OpKernelContext* ctx) { + params_.env = ctx->env(); + params_.runner = *(ctx->runner()); + params_.lib = ctx->function_library(); + // NOTE: must use reinterpret_cast because function.h forward-declares + // Device. + DeviceBase* device = + reinterpret_cast<DeviceBase*>(ctx->function_library()->device()); + params_.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; + } + Env* env() const { return params_.env; } std::function<void(std::function<void()>)>* runner() { @@ -355,6 +369,11 @@ class IteratorBase { virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) = 0; + Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors, + bool* end_of_sequence) { + return GetNext(&ctx, out_tensors, end_of_sequence); + } + // Returns a vector of DataType values, representing the respective // element types of each tuple component in the outputs of this // iterator. @@ -406,10 +425,40 @@ class IteratorBase { } }; +// Represents runtime information needed to construct a dataset. +class DatasetContext { + public: + struct Params { + string name; + }; + + explicit DatasetContext(Params params) : params_(std::move(params)) {} + + explicit DatasetContext(OpKernelContext* ctx) { + params_.name = ctx->op_kernel().type_string(); + } + + const string& name() const { return params_.name; } + + private: + Params params_; +}; + // Represents a (potentially infinite) range of outputs, where each // output is a tuple of tensors. class DatasetBase : public core::RefCounted { public: + // Key for storing the Dataset graph in the serialized format. + TF_EXPORT static const char kDatasetGraphKey[]; + + // Key for storing the output node of the Dataset graph in the serialized + // format. + TF_EXPORT static const char kDatasetGraphOutputNodeKey[]; + + explicit DatasetBase(DatasetContext&& ctx) : name_(ctx.name()) {} + + const string& name() const { return name_; } + // Returns a new iterator for iterating over the range of elements in // this dataset. // @@ -426,6 +475,11 @@ class DatasetBase : public core::RefCounted { return (*iterator)->Initialize(ctx); } + Status MakeIterator(IteratorContext&& ctx, const string& prefix, + std::unique_ptr<IteratorBase>* iterator) const { + return MakeIterator(&ctx, prefix, iterator); + } + // Returns a vector of DataType values, representing the respective // element types of each tuple component in the outputs of this // dataset. @@ -441,16 +495,9 @@ class DatasetBase : public core::RefCounted { // Serializes the dataset and writes it to the `writer`. virtual Status Save(SerializationContext* ctx, - IteratorStateWriter* writer) const { - return errors::Unimplemented("%s does not support serialization", - DebugString()); - } + IteratorStateWriter* writer) const; protected: - // TODO(srbs): Ideally all graph related logic should reside in - // GraphDatasetBase. However, that would require Datasets defined in all ops - // to derive from GraphDatasetBase. Once that is done we can move - // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase. class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { public: DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} @@ -463,54 +510,15 @@ class DatasetBase : public core::RefCounted { // TODO(jsimsa): Consolidate overloading into a single method. virtual Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, - Node** node) const { - return AsGraphDefInternal(b, node); - } - - virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b, - Node** node) const { - return errors::Unimplemented("%s does not support serialization", - DebugString()); - } + Node** node) const = 0; virtual std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const = 0; friend class DatasetToGraphOp; // For access to graph related members. -}; - -// Base-class for datasets that are built by ops. -class GraphDatasetBase : public DatasetBase { - public: - GraphDatasetBase(OpKernelContext* ctx) - : op_name_(ctx->op_kernel().type_string()) {} - - const string op_name() const { return op_name_; } - - Status Save(SerializationContext* ctx, - IteratorStateWriter* writer) const override { - string serialized_graph_def; - string output_node; - TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(kDatasetGraphKey, serialized_graph_def)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node)); - return Status::OK(); - } - - // Key for storing the Dataset graph in the serialized format. - TF_EXPORT static const char kDatasetGraphKey[]; - - // Key for storing the output node of the Dataset graph in the serialized - // format. - TF_EXPORT static const char kDatasetGraphOutputNodeKey[]; private: - Status Serialize(SerializationContext* ctx, string* serialized_graph_def, - string* output_node) const; - - const string op_name_; + const string name_; }; // Represents an iterator that is associated with a particular dataset. @@ -718,12 +726,6 @@ class BackgroundWorker { std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_); }; -namespace dataset { - -IteratorContext MakeIteratorContext(OpKernelContext* ctx); - -} // namespace dataset - } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ |