aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-08-14 15:04:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-14 15:07:53 -0700
commit0e4c277b3dc6e6004d47c400cb684e11b9b2e2f0 (patch)
tree23f3ce0cddca8a856f42b11459df2cd61dbd5c08 /tensorflow/core/graph
parenta00fba38681fa28d6d20cd730fe362864b819a0d (diff)
ZerosLike in symbolicgradient for resource variables
PiperOrigin-RevId: 208720651
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r--tensorflow/core/graph/gradients.cc41
1 files changed, 31 insertions, 10 deletions
diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc
index c1a8a63784..bec41712b1 100644
--- a/tensorflow/core/graph/gradients.cc
+++ b/tensorflow/core/graph/gradients.cc
@@ -65,16 +65,37 @@ struct NodeOutEq {
static Node* AddZerosLike(Graph* g, NodeOut input) {
DCHECK_LT(0, input.dtype());
DCHECK_LT(input.dtype(), DT_FLOAT_REF);
- NodeDef ndef;
- ndef.set_name(g->NewName(kNodeLabel));
- ndef.set_op("ZerosLike");
- ndef.add_input(input.name());
- AddNodeAttr("T", input.dtype(), &ndef);
- Status s;
- Node* ret = g->AddNode(ndef, &s);
- TF_CHECK_OK(s);
- g->AddEdge(input.node, input.index, ret, 0);
- return ret;
+ if (input.dtype() == DT_RESOURCE) {
+ NodeDef read_def;
+ read_def.set_name(g->NewName("Read"));
+ read_def.set_op("ReadVariableOp");
+ read_def.add_input(input.name());
+ AddNodeAttr("dtype", DT_FLOAT, &read_def);
+ Status s;
+ Node* read = g->AddNode(read_def, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, read, 0);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("ZerosLike");
+ ndef.add_input(read_def.name());
+ AddNodeAttr("T", DT_FLOAT, &ndef);
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(read, 0, ret, 0);
+ return ret;
+ } else {
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("ZerosLike");
+ ndef.add_input(input.name());
+ AddNodeAttr("T", input.dtype(), &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, ret, 0);
+ return ret;
+ }
}
static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) {