diff options
Diffstat (limited to 'tensorflow/core/graph/gradients.cc')
-rw-r--r-- | tensorflow/core/graph/gradients.cc | 41 |
1 files changed, 31 insertions, 10 deletions
diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc index c1a8a63784..bec41712b1 100644 --- a/tensorflow/core/graph/gradients.cc +++ b/tensorflow/core/graph/gradients.cc @@ -65,16 +65,37 @@ struct NodeOutEq { static Node* AddZerosLike(Graph* g, NodeOut input) { DCHECK_LT(0, input.dtype()); DCHECK_LT(input.dtype(), DT_FLOAT_REF); - NodeDef ndef; - ndef.set_name(g->NewName(kNodeLabel)); - ndef.set_op("ZerosLike"); - ndef.add_input(input.name()); - AddNodeAttr("T", input.dtype(), &ndef); - Status s; - Node* ret = g->AddNode(ndef, &s); - TF_CHECK_OK(s); - g->AddEdge(input.node, input.index, ret, 0); - return ret; + if (input.dtype() == DT_RESOURCE) { + NodeDef read_def; + read_def.set_name(g->NewName("Read")); + read_def.set_op("ReadVariableOp"); + read_def.add_input(input.name()); + AddNodeAttr("dtype", DT_FLOAT, &read_def); + Status s; + Node* read = g->AddNode(read_def, &s); + TF_CHECK_OK(s); + g->AddEdge(input.node, input.index, read, 0); + NodeDef ndef; + ndef.set_name(g->NewName(kNodeLabel)); + ndef.set_op("ZerosLike"); + ndef.add_input(read_def.name()); + AddNodeAttr("T", DT_FLOAT, &ndef); + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + g->AddEdge(read, 0, ret, 0); + return ret; + } else { + NodeDef ndef; + ndef.set_name(g->NewName(kNodeLabel)); + ndef.set_op("ZerosLike"); + ndef.add_input(input.name()); + AddNodeAttr("T", input.dtype(), &ndef); + Status s; + Node* ret = g->AddNode(ndef, &s); + TF_CHECK_OK(s); + g->AddEdge(input.node, input.index, ret, 0); + return ret; + } } static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) { |