aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-08 18:24:10 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-08 19:32:47 -0700
commit4809767db21dc877e807149ac571d6c75d8590ad (patch)
tree21a78a26a79cf93038930f146467ccf8d015d071
parente76d1d1163aeeaab22cd6451bede297ceb7ed467 (diff)
Switch cudnn_rnn to use the C++ shape inference functions.
Change: 132630173
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py49
1 files changed, 4 insertions, 45 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index 20bb37be03..dcaa394681 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -18,10 +18,10 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops
+from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import resource_loader
@@ -318,47 +318,6 @@ def _cudnn_rnn_backward(op, *grad):
direction=op.get_attr("direction"))
-@ops.RegisterShape("CudnnRNNParamsSize")
-def _cudnn_rnn_params_size_shape(_):
- params_size_shape = tensor_shape.TensorShape([])
- return [params_size_shape]
-
-
-@ops.RegisterShape("CudnnRNN")
-def _cudnn_rnn_forward_shape(op):
- """Shape function for the CudnnRNN forward operation.
-
- Args:
- op: the forward op.
- Returns:
- A list of shapes for the forward operation.
- """
- input_shape = op.inputs[0].get_shape()
- input_h_shape = op.inputs[1].get_shape()
- seq_length = input_shape[0]
- batch_size = input_shape[1]
- num_units = input_h_shape[2]
- direction = op.get_attr("direction")
- rnn_mode = op.get_attr("rnn_mode")
- dir_count = tensor_shape.as_dimension(
- 2) if direction == "bidirectional" else tensor_shape.as_dimension(1)
- output_shape = [seq_length, batch_size, dir_count * num_units]
- output_h_shape = input_h_shape
- output_c_shape = output_h_shape if rnn_mode == "lstm" else []
- return [output_shape, output_h_shape, output_c_shape, None]
-
-
-@ops.RegisterShape("CudnnRNNBackprop")
-def _cudnn_rnn_backward_shape(op):
- """Shape function for the CudnnRNN backward operation.
-
- Args:
- op: the backward operation.
- Returns:
- A list shapes for the backward operation.
- """
- input_shape = op.inputs[0].get_shape()
- input_h_shape = op.inputs[1].get_shape()
- input_c_shape = op.inputs[2].get_shape()
- params_shape = op.inputs[3].get_shape()
- return [input_shape, input_h_shape, input_c_shape, params_shape]
+ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("CudnnRNN")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("CudnnRNNBackprop")(common_shapes.call_cpp_shape_fn)