author     Allen Lavoie <allenl@google.com>                 2018-04-06 10:26:40 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-04-06 10:29:12 -0700
commit     665e44f612c72d39717b0a5163dca82a07e1c174
tree       bbe699c7a92f84d2e25b424fa0af230487db7675 /tensorflow/contrib/cudnn_rnn
parent     53868bfd9705da3fc15b59ab02db39b652686b13
Object-based checkpointing support for unidirectional cuDNN LSTM cells
Once checked in, this will be the only way I know of to save canonical weights when executing eagerly; eager's name-based saving handles only the opaque parameter buffer. I'm not going to convert everything in one go, but this is a start, and everything else now raises NotImplementedError rather than silently saving incorrectly.

Single-layer cuDNN cells can be swapped for un-wrapped cuDNN-compatible cells, or for single cells wrapped in MultiRNNCells; multi-layer cells need the MultiRNNCell wrapping.

PiperOrigin-RevId: 191905703
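Not part of the commit itself: a minimal save/restore sketch of the new path. It mirrors the CudnnRNNTestSaveRestoreCheckpointable tests below and assumes eager execution on a CUDA-enabled build; the temporary checkpoint prefix is illustrative.

    import os
    import tempfile

    import tensorflow as tf

    from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
    from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
    from tensorflow.contrib.eager.python import checkpointable_utils
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import rnn_cell_impl

    tf.enable_eager_execution()

    # Build a single-layer cuDNN LSTM; the first call creates its variables.
    lstm = cudnn_rnn.CudnnLSTM(num_layers=1, num_units=2)
    lstm(array_ops.ones([5, 1, 3]))  # [time, batch, input_size]

    # Object-based saving writes the canonical (split) weights, not the
    # opaque cuDNN parameter buffer that name-based saving would record.
    prefix = os.path.join(tempfile.mkdtemp(), "ckpt")
    save_path = checkpointable_utils.Checkpoint(cell=lstm).save(prefix)

    # A single-layer checkpoint restores into either a bare
    # CudnnCompatibleLSTMCell or a one-cell MultiRNNCell; multi-layer
    # checkpoints require the MultiRNNCell wrapping.
    cell = rnn_cell_impl.MultiRNNCell(
        [cudnn_rnn_ops.CudnnCompatibleLSTMCell(num_units=2)])
    status = checkpointable_utils.Checkpoint(cell=cell).restore(save_path)
    cell(array_ops.ones([1, 3]), cell.zero_state(1, dtypes.float32))
    status.assert_consumed()  # Raises if any saved weight went unmatched.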
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--  tensorflow/contrib/cudnn_rnn/BUILD                                    1
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py  151
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py              20
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py             75
4 files changed, 237 insertions, 10 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD
index 8b5d13f725..d68015ae15 100644
--- a/tensorflow/contrib/cudnn_rnn/BUILD
+++ b/tensorflow/contrib/cudnn_rnn/BUILD
@@ -25,6 +25,7 @@ tf_custom_op_py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/contrib/eager/python:checkpointable_utils",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index 9897c31a98..9cc6ca09ad 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import argparse
import collections
+import functools
import itertools
import os
import sys
@@ -28,13 +29,14 @@ import numpy as np
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
+from tensorflow.contrib.eager.python import checkpointable_utils
from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
-from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_nn_ops
@@ -265,7 +267,7 @@ def _CreateCudnnCompatibleCanonicalRNN(rnn, inputs, is_bidi=False, scope=None):
return outputs, (output_state_fw, output_state_bw)
-class CudnnRNNTestBasic(TensorFlowTestCase):
+class CudnnRNNTestBasic(test_util.TensorFlowTestCase):
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@@ -467,7 +469,7 @@ class CudnnRNNTestBasic(TensorFlowTestCase):
# TODO(jamesqin): Transform to parameterized test after it is included in the
# TF open source codebase.
-class CudnnRNNTestSaveRestore(TensorFlowTestCase):
+class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase):
def _CompareWeights(self, lhs, rhs):
self.assertEqual(len(lhs), len(rhs))
@@ -701,9 +703,146 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase):
self._TestSaveRestoreHelper(CUDNN_RNN_RELU)
+class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
+
+ def _VerifyCheckpoint(
+ self, checkpoint_path, compatible_cell_fn, cudnn_cell_fn,
+ num_layers, input_size, expected_variable_values, num_applications=3):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with ops.device("gpu:0"):
+ cudnn_layer = cudnn_cell_fn()
+ cudnn_checkpoint = checkpointable_utils.Checkpoint(cell=cudnn_layer)
+ status = cudnn_checkpoint.restore(checkpoint_path)
+ inputs = 3. * array_ops.ones([num_applications, num_layers, input_size],
+ dtype=dtypes.float32)
+ cudnn_output, _ = cudnn_layer(inputs)
+ status.assert_consumed().run_restore_ops()
+ second_save_path = cudnn_checkpoint.save(checkpoint_prefix)
+ restore_layer = compatible_cell_fn()
+ restore_layer_checkpoint = checkpointable_utils.Checkpoint(
+ cell=restore_layer)
+ status = restore_layer_checkpoint.restore(second_save_path)
+ current_state = restore_layer.zero_state(1, dtypes.float32)
+ for _ in range(num_applications):
+ restore_layer_output, current_state = restore_layer(
+ inputs=3. * array_ops.ones([1, input_size]),
+ state=current_state)
+ status.assert_consumed().run_restore_ops()
+ self.assertTrue(restore_layer.variables)
+ for variable, expected_value in zip(
+ restore_layer.variables, expected_variable_values):
+ self.assertAllClose(expected_value, self.evaluate(variable))
+ self.assertAllClose(self.evaluate(restore_layer_output),
+ self.evaluate(cudnn_output)[-1, -1:, ...])
+
+ def _CheckpointableSingleCellUnidirectionalTestTemplate(
+ self, single_cell_fn, cudnn_cell_fn):
+ # Single-layer cuDNN cells with object-based checkpointing should be
+ # checkpoint compatible with either single CudnnCompatible cells or
+ # MultiRnnCells with one cell.
+ input_size = 3
+ save_cell_layer = single_cell_fn()
+ save_cell_layer(
+ inputs=array_ops.ones([1, input_size]),
+ state=save_cell_layer.zero_state(1, dtypes.float32))
+ self.assertTrue(save_cell_layer.variables)
+ expected_values = []
+ np.random.seed(10)
+ for variable in save_cell_layer.variables:
+ value = np.random.normal(size=variable.shape)
+ expected_values.append(value)
+ self.evaluate(variable.assign(value))
+ save_checkpoint = checkpointable_utils.Checkpoint(cell=save_cell_layer)
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ first_save_path = save_checkpoint.save(checkpoint_prefix)
+ self._VerifyCheckpoint(
+ checkpoint_path=first_save_path,
+ compatible_cell_fn=
+ lambda: rnn_cell_impl.MultiRNNCell([single_cell_fn()]),
+ cudnn_cell_fn=cudnn_cell_fn,
+ num_layers=1,
+ expected_variable_values=expected_values,
+ input_size=input_size)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ @test_util.run_in_graph_and_eager_modes()
+ def testLSTMCheckpointableSingleLayer(self):
+ num_units = 2
+ direction = CUDNN_RNN_UNIDIRECTION
+ self._CheckpointableSingleCellUnidirectionalTestTemplate(
+ single_cell_fn=functools.partial(
+ cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units),
+ cudnn_cell_fn=functools.partial(
+ cudnn_rnn.CudnnLSTM, num_layers=1, num_units=num_units,
+ direction=direction, name="awesome_lstm"))
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ @test_util.run_in_graph_and_eager_modes()
+ def testGRUCheckpointableSingleLayer(self):
+ num_units = 2
+ direction = CUDNN_RNN_UNIDIRECTION
+ with self.assertRaises(NotImplementedError):
+ # TODO(allenl): Implement object-based saving for GRUs and other cells.
+ self._CheckpointableSingleCellUnidirectionalTestTemplate(
+ single_cell_fn=functools.partial(
+ cudnn_rnn_ops.CudnnCompatibleGRUCell, num_units=num_units),
+ cudnn_cell_fn=functools.partial(
+ cudnn_rnn.CudnnGRU, num_layers=1, num_units=num_units,
+ direction=direction, name="awesome_gru"))
+
+ def _CheckpointableMultiLayerTestTemplate(
+ self, single_cell_fn, cudnn_cell_fn, num_layers):
+
+ def _MultiCellFn():
+ return rnn_cell_impl.MultiRNNCell(
+ [single_cell_fn() for _ in range(num_layers)])
+ input_size = 3
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(graph=save_graph):
+ save_layer = _MultiCellFn()
+ save_layer(inputs=array_ops.ones([1, input_size]),
+ state=save_layer.zero_state(1, dtypes.float32))
+ self.assertTrue(save_layer.variables)
+ expected_values = []
+ np.random.seed(10)
+ for variable in save_layer.variables:
+ value = np.random.normal(size=variable.shape)
+ expected_values.append(value)
+ self.evaluate(variable.assign(value))
+ save_checkpoint = checkpointable_utils.Checkpoint(cell=save_layer)
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ first_save_path = save_checkpoint.save(checkpoint_prefix)
+ self._VerifyCheckpoint(
+ checkpoint_path=first_save_path,
+ compatible_cell_fn=_MultiCellFn, cudnn_cell_fn=cudnn_cell_fn,
+ num_layers=num_layers,
+ expected_variable_values=expected_values,
+ input_size=input_size)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ @test_util.run_in_graph_and_eager_modes()
+ def testCudnnCompatibleLSTMCheckpointableMultiLayer(self):
+ num_units = 2
+ num_layers = 3
+ direction = CUDNN_RNN_UNIDIRECTION
+ self._CheckpointableMultiLayerTestTemplate(
+ single_cell_fn=functools.partial(
+ cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units),
+ cudnn_cell_fn=functools.partial(
+ cudnn_rnn.CudnnLSTM, num_layers=num_layers, num_units=num_units,
+ direction=direction, name="awesome_lstm"),
+ num_layers=num_layers)
+
+
# TODO(jamesqin): Transform to parameterized test after it is included in the
# TF open source codebase.
-class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase):
+class CudnnRNNTestCompatibleRNNCells(test_util.TensorFlowTestCase):
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@@ -884,7 +1023,7 @@ class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase):
rtol=2e-5)
-class CudnnRNNTestParamsSize(TensorFlowTestCase):
+class CudnnRNNTestParamsSize(test_util.TensorFlowTestCase):
def _TestOpaqueParamsSize(self, rnn_mode, num_layers, num_units, input_size,
dtype, direction):
@@ -931,7 +1070,7 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase):
dtype, direction)
-class CudnnRNNTestTraining(TensorFlowTestCase):
+class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1):
"""Compute the numeric gradient of y wrt to x.
diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
index 36fba917a8..00d9544602 100644
--- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
+++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
@@ -142,6 +142,9 @@ class _CudnnRNN(base_layer.Layer):
"""
# pylint:enable=line-too-long
+ # TODO(allenl): Document object-based saving and checkpoint compatibility once
+ # it's implemented for more cuDNN Layers.
+
# The following are constants defined by subclasses.
# Type of RNN cell.
_rnn_mode = None
@@ -363,6 +366,11 @@ class _CudnnRNN(base_layer.Layer):
self._create_saveable()
self.built = True
+ def _gather_saveables_for_checkpoint(self):
+ raise NotImplementedError(
+ "This cell does not yet support object-based saving. File a feature "
+ "request if this limitation bothers you.")
+
def call(self, inputs, initial_state=None, training=True):
"""Runs the forward step for the RNN model.
@@ -499,6 +507,8 @@ class _CudnnRNN(base_layer.Layer):
direction=self.direction,
scope=vs.get_variable_scope(),
name="%s_saveable" % self.trainable_variables[0].name.split(":")[0])
+ self._saveable._add_checkpointable_dependencies( # pylint: disable=protected-access
+ checkpointable=self, dtype=self._plain_dtype)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
@@ -521,6 +531,16 @@ class CudnnLSTM(_CudnnRNN):
return ([self.num_layers * self.num_dirs, batch_size, self.num_units],
[self.num_layers * self.num_dirs, batch_size, self.num_units])
+ @property
+ def _gather_saveables_for_checkpoint(self):
+ if self._direction == CUDNN_RNN_UNIDIRECTION:
+ # Skip one inheritance level to avoid NotImplementedError.
+ return super(_CudnnRNN, self)._gather_saveables_for_checkpoint
+ else:
+ raise NotImplementedError(
+ "Object-based saving does not currently support bidirectional LSTM "
+ "cells. File a feature request if this limitation bothers you.")
+
class _CudnnRNNNoInputC(_CudnnRNN):
"""Abstract simple CudnnRNN layer without input_c."""
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index 622241a177..588a5e705d 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.eager.python import checkpointable_utils
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
@@ -31,6 +32,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.training import checkpointable as checkpointable_lib
from tensorflow.python.training import saver
CUDNN_RNN_UNIDIRECTION = "unidirectional"
@@ -262,13 +264,16 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
# instead of having the master pull all slices and then save them.
slice_spec = ""
params = weights + biases
- param_names = weight_names + bias_names
+ self._weight_names = weight_names
+ self._bias_names = bias_names
+ self._param_names = weight_names + bias_names
+ prefixed_param_names = weight_names + bias_names
if self._scope:
- param_names = ["%s/%s" % (self._scope, pn) for pn in param_names]
-
+ prefixed_param_names = [
+ "%s/%s" % (self._scope, pn) for pn in prefixed_param_names]
specs = [
saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name)
- for param, param_name in zip(params, param_names)
+ for param, param_name in zip(params, prefixed_param_names)
]
super(CudnnOpaqueParamsSaveable, self).__init__(
array_ops.identity(self._variables), specs, name)
@@ -281,6 +286,45 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
return state_ops.assign(
self._variables, opaque_params, validate_shape=False)
+ def _checkpointable_save(self, save_buffer):
+ weights, biases = self._OpaqueParamsToCanonical()
+ with ops.device("gpu:0"):
+ (weights, _), (biases, _) = self._TransformCanonical(
+ weights, biases)
+ for name, tensor in zip(self._param_names, weights + biases):
+ save_buffer[name] = array_ops.identity(tensor)
+
+ def _checkpointable_restore(self, restore_buffer):
+ tensors = [array_ops.identity(restore_buffer[name])
+ for name in self._param_names]
+ return self.restore(
+ restored_tensors=tensors,
+ restored_shapes=None # Unused
+ )
+
+ def _add_checkpointable_dependencies(self, checkpointable, dtype):
+ """Add canonical weight dependencies to `checkpointable`.
+
+ When saving or restoring, converts to or from the opaque buffer
+ format. Weights are saved and loaded in the configuration expected by
+ cuDNN-compatible cells.
+
+ Args:
+ checkpointable: An object inheriting from `CheckpointableBase` to add
+ dependencies to (typically the cuDNN `Layer`).
+ dtype: The dtype for the canonical parameter Tensors.
+ """
+ split_dependencies = checkpointable_utils.split_dependency(
+ component_names=self._param_names,
+ component_dtypes=(dtype,) * len(self._param_names),
+ fill_save_buffer_fn=self._checkpointable_save,
+ consume_restore_buffer_fn=self._checkpointable_restore)
+ self._checkpointable_track_params(checkpointable, split_dependencies)
+
+ def _checkpointable_track_params(self, checkpointable, params):
+ """Tracks parameters in a canonical configuration."""
+ return # NotImplementedError raised by the Layer.
+
def _TFCanonicalNamePrefix(self, layer, is_fwd=True):
if self._direction == CUDNN_RNN_UNIDIRECTION:
return "rnn/multi_rnn_cell/cell_%d/%s" % (layer, self._rnn_cell_name)
@@ -570,6 +614,29 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable):
tf_biases.append(b)
tf_bias_names.append(prefix + "/bias")
+ def _checkpointable_track_params(self, checkpointable, params):
+ """Track parameters for compatibility with CudnnCompatibleLSTMCell."""
+ biases = []
+ weights = []
+ for name in self._weight_names:
+ weights.append(params[name])
+ for name in self._bias_names:
+ biases.append(params[name])
+ assert len(params) == len(weights) + len(biases)
+ if len(weights) == 1 and len(biases) == 1:
+ # For single-layer cells, allow substituting a cell with no MultiRNNCell
+ # wrapping.
+ kernel, = weights # pylint: disable=unbalanced-tuple-unpacking
+ bias, = biases # pylint: disable=unbalanced-tuple-unpacking
+ checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access
+ checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access
+ assert len(biases) == len(weights)
+ for cell_index, (bias, kernel) in enumerate(zip(biases, weights)):
+ cell = checkpointable_lib.Checkpointable()
+ checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access
+ cell.bias = bias
+ cell.kernel = kernel
+
class CudnnGRUSaveable(CudnnOpaqueParamsSaveable):
"""SaveableObject implementation handling Cudnn GRU opaque params."""