aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/array_grad.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-09 20:11:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-09 20:15:28 -0700
commitde082e57fb7609c261af72f583e0d5c236023376 (patch)
tree0ac48de4042bf52caec8307c8bebf191ef692c06 /tensorflow/core/ops/array_grad.cc
parent652be30e7140ca11b756a3ae0f9bd67f913af399 (diff)
Add c++ gradient for SplitV op.
PiperOrigin-RevId: 208153311
Diffstat (limited to 'tensorflow/core/ops/array_grad.cc')
-rw-r--r--tensorflow/core/ops/array_grad.cc21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index 38bd851da8..1f2e57e9a9 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -244,6 +244,27 @@ Status SplitGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Split", SplitGrad);
+Status SplitVGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ // Arg defs
+ {"x: T", "size_splits: Tlen", "dim: int32", "dy: num_split*T"},
+ // Ret val defs
+ {"dx: T", "d_size_splits: Tlen", "d_dim: int32"},
+ // Attr defs
+ {"T: type", "Tlen: type", "num_split: int"},
+ // Nodes
+ {
+ {{"dx"}, "Concat", {"dim", "dy"}, {{"T", "$T"}, {"N", "$num_split"}}},
+ {{"d_size_splits"}, "ZerosLike", {"size_splits"}, {{"T", "$Tlen"}}},
+ {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
+ });
+ // clang-format on
+ VLOG(1) << "SplitVGrad " << DebugString(*g);
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("SplitV", SplitVGrad);
+
Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) {
int N;
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N));