diff options
author | 2018-08-31 20:19:50 -0700 | |
---|---|---|
committer | 2018-08-31 20:24:08 -0700 | |
commit | c4ac04af7bb1ceb2e83cd6e7a3e1cac5f3d5c256 (patch) | |
tree | fb197ef6ea0498635fa4a91d04a7ea7006d97ed0 /tensorflow/core/framework | |
parent | 68452bed16ce2cfee8f962e5abef92d27fb0796f (diff) |
[tf.data] Avoiding serialization of (potentially large) tensors during optimization.
PiperOrigin-RevId: 211179990
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/dataset.cc | 7 | ||||
-rw-r--r-- | tensorflow/core/framework/dataset.h | 29 |
2 files changed, 32 insertions, 4 deletions
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index b0b27ce94f..9ffd8e1ee0 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -179,6 +179,13 @@ Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx, return Status::OK(); } +void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val, + Node** output) { + *output = ops::SourceOp( + "Placeholder", + b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape())); +} + void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val, Node** output) { *output = ops::SourceOp( diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index e06ca68bca..04865a1d4f 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -110,10 +110,11 @@ class GraphDefBuilderWrapper { return Status::OK(); } - // Adds a Const node with Tensor value to the Graph. + // Adds a `Const` node for the given tensor value to the graph. + // // `*output` contains a pointer to the output `Node`. It is guaranteed to be - // non-null if the method returns with an OK status. - // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + // non-null if the method returns with an OK status. The returned `Node` + // pointer is owned by the backing graph of `GraphDefBuilder`. Status AddTensor(const Tensor& val, Node** output) { AddTensorInternal(val, output); if (*output == nullptr) { @@ -122,6 +123,20 @@ class GraphDefBuilderWrapper { return Status::OK(); } + // Adds a `Placeholder` node for the given tensor value to the graph. + // + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. The returned `Node` + // pointer is owned by the backing graph of `GraphDefBuilder`. + Status AddPlaceholder(const Tensor& val, Node** output) { + AddPlaceholderInternal(val, output); + if (*output == nullptr) { + return errors::Internal( + "AddPlaceholder: Failed to build Placeholder op."); + } + return Status::OK(); + } + Status AddDataset(const DatasetBase* dataset, const std::vector<Node*>& inputs, Node** output) { return AddDataset(dataset, inputs, {}, output); @@ -168,6 +183,7 @@ class GraphDefBuilderWrapper { } private: + void AddPlaceholderInternal(const Tensor& val, Node** output); void AddTensorInternal(const Tensor& val, Node** output); Status EnsureFunctionIsStateless(const FunctionLibraryDefinition& flib_def, @@ -334,7 +350,8 @@ class SerializationContext { public: struct Params { bool allow_stateful_functions = false; - const FunctionLibraryDefinition* flib_def; // Not owned. + const FunctionLibraryDefinition* flib_def = nullptr; // Not owned. + std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned. }; explicit SerializationContext(Params params) : params_(std::move(params)) {} @@ -343,6 +360,10 @@ class SerializationContext { const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; } + std::vector<std::pair<string, Tensor>>* input_list() { + return params_.input_list; + } + private: Params params_; |