diff options
author | Alexandre Passos <apassos@google.com> | 2018-08-14 15:04:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-14 15:07:53 -0700 |
commit | 0e4c277b3dc6e6004d47c400cb684e11b9b2e2f0 (patch) | |
tree | 23f3ce0cddca8a856f42b11459df2cd61dbd5c08 /tensorflow/core/graph | |
parent | a00fba38681fa28d6d20cd730fe362864b819a0d (diff) |
ZerosLike in symbolicgradient for resource variables
PiperOrigin-RevId: 208720651
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/gradients.cc | 41 |
1 files changed, 31 insertions, 10 deletions
diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc index c1a8a63784..bec41712b1 100644 --- a/tensorflow/core/graph/gradients.cc +++ b/tensorflow/core/graph/gradients.cc @@ -65,16 +65,37 @@ struct NodeOutEq { static Node* AddZerosLike(Graph* g, NodeOut input) { DCHECK_LT(0, input.dtype()); DCHECK_LT(input.dtype(), DT_FLOAT_REF); - NodeDef ndef; - ndef.set_name(g->NewName(kNodeLabel)); - ndef.set_op("ZerosLike"); - ndef.add_input(input.name()); - AddNodeAttr("T", input.dtype(), &ndef); - Status s; - Node* ret = g->AddNode(ndef, &s); - TF_CHECK_OK(s); - g->AddEdge(input.node, input.index, ret, 0); - return ret; + if (input.dtype() == DT_RESOURCE) { + NodeDef read_def; + read_def.set_name(g->NewName("Read")); + read_def.set_op("ReadVariableOp"); + read_def.add_input(input.name()); + AddNodeAttr("dtype", DT_FLOAT, &read_def); + Status s; + Node* read = g->AddNode(read_def, &s); + TF_CHECK_OK(s); + g->AddEdge(input.node, input.index, read, 0); + NodeDef ndef; + ndef.set_name(g->NewName(kNodeLabel)); + ndef.set_op("ZerosLike"); + ndef.add_input(read_def.name()); + AddNodeAttr("T", DT_FLOAT, &ndef); + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + g->AddEdge(read, 0, ret, 0); + return ret; + } else { + NodeDef ndef; + ndef.set_name(g->NewName(kNodeLabel)); + ndef.set_op("ZerosLike"); + ndef.add_input(input.name()); + AddNodeAttr("T", input.dtype(), &ndef); + Status s; + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + g->AddEdge(input.node, input.index, ret, 0); + return ret; + } } static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) { |