diff options
author | 2018-03-05 18:49:53 -0800 | |
---|---|---|
committer | 2018-03-05 18:53:40 -0800 | |
commit | 73999dc944b3516d485081fe060d6916c089e412 (patch) | |
tree | a77350a24ccc5f5e95d0b2469bce37588b5c2130 /tensorflow/python/layers | |
parent | b5f943201afc06525818f45da28f82559fceced2 (diff) |
Fixes a number of usability issues with model_to_estimator, in particular:
- make it possible to use a model that was compiled with a TF optimizer (do not require a Keras optimizer)
- do not require input to be dict (input_fn supports plain arrays)
- do not require `config` to be a RunConfig instance, can now be a dict (better UX)
- make it possible to use a subclassed model (caveat: weights are not preserved, yet)
- clear error message when model isn't compiled; improve various error messages
PiperOrigin-RevId: 187959927
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r-- | tensorflow/python/layers/base.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 2ec9971b88..c6d16a3bc0 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -127,7 +127,7 @@ class Layer(checkpointable.CheckpointableBase): # return tensors. When using graph execution, _losses is a list of ops. self._losses = [] self._reuse = kwargs.get('_reuse') - self._graph = ops.get_default_graph() + self._graph = None # Will be set at build time. self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name call_fn_args = estimator_util.fn_args(self.call) self._compute_previous_mask = ('mask' in call_fn_args or @@ -630,7 +630,8 @@ class Layer(checkpointable.CheckpointableBase): # the same graph as where it was created. if in_graph_mode: try: - ops._get_graph_from_inputs(input_list, graph=self.graph) # pylint: disable=protected-access + # Set layer's "graph" at build time + self._graph = ops._get_graph_from_inputs(input_list, graph=self._graph) # pylint: disable=protected-access except ValueError as e: raise ValueError('Input graph and Layer graph are not the same: %s' % e) if in_graph_mode or in_deferred_mode: |