diff options
author | Alexandre Passos <apassos@google.com> | 2018-10-08 13:50:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-08 13:58:40 -0700 |
commit | eec9ca8f0baccd249a49046fe31b460903e44850 (patch) | |
tree | b6397af544af7c05abca4bea08bd6354f90bedf1 /tensorflow/core/graph | |
parent | 494bbdfced3fd8596721d12e73676c4967f452e4 (diff) |
Partial support tfe.defun in tf.gradients.
Doesn't attempt to deal with cases where we might have already generated
the functiondef for the parent function as in that case we cannot easily
modify the forward pass.
PiperOrigin-RevId: 216243224
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/graph.cc | 13 | ||||
-rw-r--r-- | tensorflow/core/graph/graph.h | 5 | ||||
-rw-r--r-- | tensorflow/core/graph/node_builder.cc | 8 |
3 files changed, 22 insertions, 4 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 7a4a0096fa..6f068546d2 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -142,6 +142,19 @@ void Node::Clear() { assigned_device_name_index_ = 0; } +void Node::UpdateProperties() { + DataTypeVector inputs; + DataTypeVector outputs; + Status status = + InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs); + if (!status.ok()) { + LOG(ERROR) << "Failed at updating node: " << status; + return; + } + props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def, + inputs, outputs); +} + const string& Node::name() const { return props_->node_def.name(); } const string& Node::type_string() const { return props_->node_def.op(); } const NodeDef& Node::def() const { return props_->node_def; } diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 2944951f82..228b1331d9 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -171,6 +171,7 @@ class Node { template <typename T> void AddAttr(const string& name, const T& val) { SetAttrValue(val, AddAttrHelper(name)); + UpdateProperties(); } void ClearAttr(const string& name); @@ -211,6 +212,10 @@ class Node { // e.g. in AddAttr. void MaybeCopyOnWrite(); + // Called after an attr has changed. Decides whether we need to update some + // property of the node (stored in props_). + void UpdateProperties(); + AttrValue* AddAttrHelper(const string& name); // A set of mutually exclusive classes for different kinds of nodes, diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index d92874909f..68a20fcc5f 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -140,10 +140,10 @@ void NodeBuilder::AddIndexError(const Node* node, int i) { strings::StrCat("Attempt to add nullptr Node to node with type ", def_builder_.op_def().name())); } else { - errors_.emplace_back( - strings::StrCat("Attempt to add output ", i, " of ", node->name(), - " not in range [0, ", node->num_outputs(), - ") to node with type ", def_builder_.op_def().name())); + errors_.emplace_back(strings::StrCat( + "Attempt to add output ", i, " of ", node->name(), " not in range [0, ", + node->num_outputs(), ") to node with type ", + def_builder_.op_def().name(), ". Node: ", node->DebugString())); } } |