diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-09-08 18:24:10 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-08 19:32:47 -0700 |
commit | 4809767db21dc877e807149ac571d6c75d8590ad (patch) | |
tree | 21a78a26a79cf93038930f146467ccf8d015d071 | |
parent | e76d1d1163aeeaab22cd6451bede297ceb7ed467 (diff) |
Switch cudnn_rnn to use the C++ shape inference functions.
Change: 132630173
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 49 |
1 files changed, 4 insertions, 45 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 20bb37be03..dcaa394681 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -18,10 +18,10 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops +from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import load_library from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import resource_loader @@ -318,47 +318,6 @@ def _cudnn_rnn_backward(op, *grad): direction=op.get_attr("direction")) -@ops.RegisterShape("CudnnRNNParamsSize") -def _cudnn_rnn_params_size_shape(_): - params_size_shape = tensor_shape.TensorShape([]) - return [params_size_shape] - - -@ops.RegisterShape("CudnnRNN") -def _cudnn_rnn_forward_shape(op): - """Shape function for the CudnnRNN forward operation. - - Args: - op: the forward op. - Returns: - A list of shapes for the forward operation. - """ - input_shape = op.inputs[0].get_shape() - input_h_shape = op.inputs[1].get_shape() - seq_length = input_shape[0] - batch_size = input_shape[1] - num_units = input_h_shape[2] - direction = op.get_attr("direction") - rnn_mode = op.get_attr("rnn_mode") - dir_count = tensor_shape.as_dimension( - 2) if direction == "bidirectional" else tensor_shape.as_dimension(1) - output_shape = [seq_length, batch_size, dir_count * num_units] - output_h_shape = input_h_shape - output_c_shape = output_h_shape if rnn_mode == "lstm" else [] - return [output_shape, output_h_shape, output_c_shape, None] - - -@ops.RegisterShape("CudnnRNNBackprop") -def _cudnn_rnn_backward_shape(op): - """Shape function for the CudnnRNN backward operation. - - Args: - op: the backward operation. - Returns: - A list shapes for the backward operation. - """ - input_shape = op.inputs[0].get_shape() - input_h_shape = op.inputs[1].get_shape() - input_c_shape = op.inputs[2].get_shape() - params_shape = op.inputs[3].get_shape() - return [input_shape, input_h_shape, input_c_shape, params_shape] +ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("CudnnRNN")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("CudnnRNNBackprop")(common_shapes.call_cpp_shape_fn) |