diff options
author | 2018-08-09 20:11:37 -0700 | |
---|---|---|
committer | 2018-08-09 20:15:28 -0700 | |
commit | de082e57fb7609c261af72f583e0d5c236023376 (patch) | |
tree | 0ac48de4042bf52caec8307c8bebf191ef692c06 /tensorflow/core/ops/array_grad.cc | |
parent | 652be30e7140ca11b756a3ae0f9bd67f913af399 (diff) |
Add c++ gradient for SplitV op.
PiperOrigin-RevId: 208153311
Diffstat (limited to 'tensorflow/core/ops/array_grad.cc')
-rw-r--r-- | tensorflow/core/ops/array_grad.cc | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index 38bd851da8..1f2e57e9a9 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -244,6 +244,27 @@ Status SplitGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Split", SplitGrad); +Status SplitVGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + *g = FDH::Define( + // Arg defs + {"x: T", "size_splits: Tlen", "dim: int32", "dy: num_split*T"}, + // Ret val defs + {"dx: T", "d_size_splits: Tlen", "d_dim: int32"}, + // Attr defs + {"T: type", "Tlen: type", "num_split: int"}, + // Nodes + { + {{"dx"}, "Concat", {"dim", "dy"}, {{"T", "$T"}, {"N", "$num_split"}}}, + {{"d_size_splits"}, "ZerosLike", {"size_splits"}, {{"T", "$Tlen"}}}, + {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}}, + }); + // clang-format on + VLOG(1) << "SplitVGrad " << DebugString(*g); + return Status::OK(); +} +REGISTER_OP_GRADIENT("SplitV", SplitVGrad); + Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) { int N; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N)); |