aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-08-31 20:19:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-31 20:24:08 -0700
commitc4ac04af7bb1ceb2e83cd6e7a3e1cac5f3d5c256 (patch)
treefb197ef6ea0498635fa4a91d04a7ea7006d97ed0 /tensorflow/core/framework
parent68452bed16ce2cfee8f962e5abef92d27fb0796f (diff)
[tf.data] Avoiding serialization of (potentially large) tensors during optimization.
PiperOrigin-RevId: 211179990
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/dataset.cc7
-rw-r--r--tensorflow/core/framework/dataset.h29
2 files changed, 32 insertions, 4 deletions
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index b0b27ce94f..9ffd8e1ee0 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -179,6 +179,13 @@ Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx,
return Status::OK();
}
+void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
+ Node** output) {
+ *output = ops::SourceOp(
+ "Placeholder",
+ b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
+}
+
void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
Node** output) {
*output = ops::SourceOp(
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index e06ca68bca..04865a1d4f 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -110,10 +110,11 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
- // Adds a Const node with Tensor value to the Graph.
+ // Adds a `Const` node for the given tensor value to the graph.
+ //
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
- // non-null if the method returns with an OK status.
- // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
+ // non-null if the method returns with an OK status. The returned `Node`
+ // pointer is owned by the backing graph of `GraphDefBuilder`.
Status AddTensor(const Tensor& val, Node** output) {
AddTensorInternal(val, output);
if (*output == nullptr) {
@@ -122,6 +123,20 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
+ // Adds a `Placeholder` node for the given tensor value to the graph.
+ //
+ // `*output` contains a pointer to the output `Node`. It is guaranteed to be
+ // non-null if the method returns with an OK status. The returned `Node`
+ // pointer is owned by the backing graph of `GraphDefBuilder`.
+ Status AddPlaceholder(const Tensor& val, Node** output) {
+ AddPlaceholderInternal(val, output);
+ if (*output == nullptr) {
+ return errors::Internal(
+ "AddPlaceholder: Failed to build Placeholder op.");
+ }
+ return Status::OK();
+ }
+
Status AddDataset(const DatasetBase* dataset,
const std::vector<Node*>& inputs, Node** output) {
return AddDataset(dataset, inputs, {}, output);
@@ -168,6 +183,7 @@ class GraphDefBuilderWrapper {
}
private:
+ void AddPlaceholderInternal(const Tensor& val, Node** output);
void AddTensorInternal(const Tensor& val, Node** output);
Status EnsureFunctionIsStateless(const FunctionLibraryDefinition& flib_def,
@@ -334,7 +350,8 @@ class SerializationContext {
public:
struct Params {
bool allow_stateful_functions = false;
- const FunctionLibraryDefinition* flib_def; // Not owned.
+ const FunctionLibraryDefinition* flib_def = nullptr; // Not owned.
+ std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned.
};
explicit SerializationContext(Params params) : params_(std::move(params)) {}
@@ -343,6 +360,10 @@ class SerializationContext {
const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; }
+ std::vector<std::pair<string, Tensor>>* input_list() {
+ return params_.input_list;
+ }
+
private:
Params params_;