aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar Nick Felt <nickfelt@google.com>2018-07-24 09:49:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-24 09:53:22 -0700
commit568727eed199dba04e37f500265b50f96fed455e (patch)
tree999f31d1469b3b5f2dc12d5ca04061cfe6062faa /tensorflow/python/saved_model
parentf8bbd3ceb7e86b7595ba74a9a03cfc7c1be252a8 (diff)
Add v2 summary support to Estimator.train() and MonitoredSession hooks
This change makes Estimator.train() support v2 summaries (tf.contrib.summary.*) out-of-the-box, to match the support for v1 summaries. Estimator.train() will now handle the boilerplate necessary to initialize a file writer and enable summary writing every N steps, and will ensure that its own automatically exported summaries (for loss and global_step/sec) get written to the same underlying events file. As part of this change, tf.train.SummarySaverHook, tf.train.CheckpointSaverHook, tf.train.StepCounterHook, and tf.train.ProfilerHook have also been adapted to write summaries using the v2 summary system (via a compatibility layer), instead of using FileWriterCache. A couple additional smaller changes are: - the 'session' parameter to FileWriter() can now be a callable returning a tf.Session instance. - the introduction of tf.contrib.summary.record_summaries_if() which takes a boolean tensor for direct control of tf.contrib.summary.should_record_summaries(). - EstimatorSpec.train_op, besides a tf.Operation, is now allowed to be any Tensor-equivalent object rather than just a tf.Tensor. PiperOrigin-RevId: 205843986
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/builder_impl.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index e58be804c2..b67d0f2362 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -28,6 +28,7 @@ from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
@@ -178,10 +179,10 @@ class SavedModelBuilder(object):
stored as a collection with key TRAIN_OP_KEY, but not executed.
Raises:
- TypeError if Train op is not of type `Operation`.
+ TypeError if Train op is not of type `Operation` or a Tensor.
"""
if train_op is not None:
- if (not isinstance(train_op, ops.Tensor) and
+ if (not tensor_util.is_tensor(train_op) and
not isinstance(train_op, ops.Operation)):
raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)