diff options
author | Nick Felt <nickfelt@google.com> | 2018-07-24 09:49:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-24 09:53:22 -0700 |
commit | 568727eed199dba04e37f500265b50f96fed455e (patch) | |
tree | 999f31d1469b3b5f2dc12d5ca04061cfe6062faa /tensorflow/python/saved_model | |
parent | f8bbd3ceb7e86b7595ba74a9a03cfc7c1be252a8 (diff) |
Add v2 summary support to Estimator.train() and MonitoredSession hooks
This change makes Estimator.train() support v2 summaries (tf.contrib.summary.*) out-of-the-box, to match the support for v1 summaries. Estimator.train() will now handle the boilerplate necessary to initialize a file writer and enable summary writing every N steps, and will ensure that its own automatically exported summaries (for loss and global_step/sec) get written to the same underlying events file.
As part of this change, tf.train.SummarySaverHook, tf.train.CheckpointSaverHook, tf.train.StepCounterHook, and tf.train.ProfilerHook have also been adapted to write summaries using the v2 summary system (via a compatibility layer), instead of using FileWriterCache.
A couple additional smaller changes are:
- the 'session' parameter to FileWriter() can now be a callable returning a tf.Session instance.
- the introduction of tf.contrib.summary.record_summaries_if() which takes a boolean tensor for direct control of tf.contrib.summary.should_record_summaries().
- EstimatorSpec.train_op, besides a tf.Operation, is now allowed to be any Tensor-equivalent object rather than just a tf.Tensor.
PiperOrigin-RevId: 205843986
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/builder_impl.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index e58be804c2..b67d0f2362 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -28,6 +28,7 @@ from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging @@ -178,10 +179,10 @@ class SavedModelBuilder(object): stored as a collection with key TRAIN_OP_KEY, but not executed. Raises: - TypeError if Train op is not of type `Operation`. + TypeError if Train op is not of type `Operation` or a Tensor. """ if train_op is not None: - if (not isinstance(train_op, ops.Tensor) and + if (not tensor_util.is_tensor(train_op) and not isinstance(train_op, ops.Operation)): raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op) ops.add_to_collection(constants.TRAIN_OP_KEY, train_op) |