aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2018-03-05 18:49:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-05 18:53:40 -0800
commit73999dc944b3516d485081fe060d6916c089e412 (patch)
treea77350a24ccc5f5e95d0b2469bce37588b5c2130 /tensorflow/python/layers
parentb5f943201afc06525818f45da28f82559fceced2 (diff)
Fixes a number of usability issues with model_to_estimator, in particular:
- make it possible to use a model that was compiled with a TF optimizer (do not require a Keras optimizer) - do not require input to be dict (input_fn supports plain arrays) - do not require `config` to be a RunConfig instance, can now be a dict (better UX) - make it possible to use a subclassed model (caveat: weights are not preserved, yet) - clear error message when model isn't compiled; improve various error messages PiperOrigin-RevId: 187959927
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--tensorflow/python/layers/base.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 2ec9971b88..c6d16a3bc0 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -127,7 +127,7 @@ class Layer(checkpointable.CheckpointableBase):
# return tensors. When using graph execution, _losses is a list of ops.
self._losses = []
self._reuse = kwargs.get('_reuse')
- self._graph = ops.get_default_graph()
+ self._graph = None # Will be set at build time.
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
call_fn_args = estimator_util.fn_args(self.call)
self._compute_previous_mask = ('mask' in call_fn_args or
@@ -630,7 +630,8 @@ class Layer(checkpointable.CheckpointableBase):
# the same graph as where it was created.
if in_graph_mode:
try:
- ops._get_graph_from_inputs(input_list, graph=self.graph) # pylint: disable=protected-access
+ # Set layer's "graph" at build time
+ self._graph = ops._get_graph_from_inputs(input_list, graph=self._graph) # pylint: disable=protected-access
except ValueError as e:
raise ValueError('Input graph and Layer graph are not the same: %s' % e)
if in_graph_mode or in_deferred_mode: