From 79ad4a423f3e9031eb841a164372cc7476cc112a Mon Sep 17 00:00:00 2001 From: Olivia Nordquist Date: Thu, 30 Nov 2017 19:46:05 -0800 Subject: enabling Tensor._set_shape() to work with the C API PiperOrigin-RevId: 177543170 --- tensorflow/python/client/tf_session.i | 43 ++++++++++++++++++++ tensorflow/python/client/tf_session_helper.cc | 19 +++++++++ tensorflow/python/client/tf_session_helper.h | 14 +++++++ tensorflow/python/framework/ops.py | 57 +++++++++++++++++++-------- 4 files changed, 117 insertions(+), 16 deletions(-) diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 5fa1a7e8fc..d471a39b69 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -532,6 +532,49 @@ def TF_Reset(target, containers=None, config=None): %unignore TF_GraphGetTensorShapeHelper; %ignore TF_GraphGetTensorShape; +// We use TF_GraphSetTensorShape_wrapper instead of +// TF_GraphSetTensorShape +%ignore TF_GraphSetTensorShape; +%unignore tensorflow; +%unignore TF_GraphSetTensorShape_wrapper; + +// $input is a Python list of ints to a vector for TF_GraphSetTensorShape_wrapper +%typemap(in) (const std::vector& dims) + (std::vector dims_local){ + if ($input != Py_None) { + if (!PyList_Check($input)) { + SWIG_exception_fail(SWIG_TypeError, tensorflow::strings::Printf( + "$symname: expected list but got %s ", Py_TYPE($input)->tp_name).c_str()); + } + size_t size = PyList_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* item = PyList_GetItem($input, i); + dims_local.push_back(PyInt_AsLong(item)); + } + $1 = &dims_local; + } else { + $1 = nullptr; + } +} + +// We use TF_GraphGetTensorShape_wrapper instead of +// TF_GraphGetTensorShape +%ignore TF_GraphGetTensorShape; +%unignore tensorflow; +%unignore TF_GraphGetTensorShape_wrapper; + +// Build a Python list of ints and return it. +%typemap(out) std::vector tensorflow::TF_GraphGetTensorShape_wrapper { + $result = PyList_New($1.size()); + if (!$result) { + SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list"); + } + + for (size_t i = 0; i < $1.size(); ++i) { + PyList_SET_ITEM($result, i, PyInt_FromLong($1[i])); + } +} + %include "tensorflow/python/client/tf_session_helper.h" %unignoreall diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index ad982e5dd8..e4bf09a0ca 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -407,4 +407,23 @@ TF_Function* TF_GraphToFunction_wrapper( opts, description, out_status); } +void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output, + const std::vector& dims, + bool unknown_shape, TF_Status* status) { + if (unknown_shape) { + TF_GraphSetTensorShape(graph, output, nullptr, -1, status); + return; + } + TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status); +} + +std::vector TF_GraphGetTensorShape_wrapper(TF_Graph* graph, + TF_Output output, + int num_dims, + TF_Status* status) { + std::vector dims(num_dims); + TF_GraphGetTensorShape(graph, output, dims.data(), num_dims, status); + return dims; +} + } // namespace tensorflow diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 6ed08d3a58..bb7171db31 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -168,6 +168,20 @@ TF_Function* TF_GraphToFunction_wrapper( const std::vector& inputs, const std::vector& outputs, const NameVector& output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* out_status); + +// Set the shape of output. If unknown is true, `num_dims` must be set to +// -1 and `dims` is set to nullptr. +void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output, + const std::vector& dims, + bool unknown_shape, TF_Status* status); + +// Return the shape of output. `num_dims` should be the output of +// TF_GraphGetTensorNumDims. If `num_dims = -1`, this should not be called. +std::vector TF_GraphGetTensorShape_wrapper(TF_Graph* graph, + TF_Output output, + int num_dims, + TF_Status* status); + } // namespace tensorflow #endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 5f945ac133..13e6426447 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -374,6 +374,19 @@ class Tensor(_TensorLike): A `TensorShape` representing the shape of this tensor. """ + if _USE_C_API: + graph = self._op._graph._c_graph # pylint: disable=protected-access + with errors.raise_exception_on_not_ok_status() as status: + num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(), + status) + if num_dims == -1: + dim_list = None + else: + with errors.raise_exception_on_not_ok_status() as status: + dim_list = c_api.TF_GraphGetTensorShape_wrapper( + graph, self._as_tf_output(), num_dims, status) + dim_list = [None if i == -1 else i for i in dim_list] + return tensor_shape.TensorShape(dim_list) return self._shape def __iter__(self): @@ -393,8 +406,8 @@ class Tensor(_TensorLike): yield self[i] def _shape_as_list(self): - if self._shape.ndims is not None: - return [dim.value for dim in self._shape.dims] + if self.shape.ndims is not None: + return [dim.value for dim in self.shape.dims] else: return None @@ -410,7 +423,7 @@ class Tensor(_TensorLike): Returns: Integer rank or None """ - return self._shape.ndims + return self.shape.ndims def get_shape(self): """Alias of Tensor.shape.""" @@ -441,14 +454,35 @@ class Tensor(_TensorLike): ``` Args: - shape: A `TensorShape` representing the shape of this tensor. + shape: A `TensorShape` representing the shape of this tensor, a + `TensorShapeProto`, a list, a tuple, or None. Raises: ValueError: If `shape` is not compatible with the current shape of this tensor. """ - # TODO(skyewm): call C API - self._shape = self._shape.merge_with(shape) + if not _USE_C_API: + self._shape = self._shape.merge_with(shape) # pylint: disable=protected-access + return + if not isinstance(shape, tensor_shape.TensorShape): + shape = tensor_shape.TensorShape(shape) + dim_list = [] + if shape.dims is None: + unknown_shape = True + else: + unknown_shape = False + for dim in shape.dims: + if dim.value is None: + dim_list.append(-1) + else: + dim_list.append(dim.value) + with errors.raise_exception_on_not_ok_status() as status: + c_api.TF_GraphSetTensorShape_wrapper( + self._op._graph._c_graph, # pylint: disable=protected-access + self._as_tf_output(), + dim_list, + unknown_shape, + status) @property def value_index(self): @@ -4521,15 +4555,11 @@ def control_dependencies(control_inputs): See @{tf.Graph.control_dependencies} for more details. - When eager execution is enabled, any callable object in the `control_inputs` - list will be called. - Args: control_inputs: A list of `Operation` or `Tensor` objects which must be executed or computed before running the operations defined in the context. Can also be `None` to clear the control - dependencies. If eager execution is enabled, any callable object in the - `control_inputs` list will be called. + dependencies. Returns: A context manager that specifies control dependencies for all @@ -4538,11 +4568,6 @@ def control_dependencies(control_inputs): if context.in_graph_mode(): return get_default_graph().control_dependencies(control_inputs) else: - if control_inputs: - # Excute any pending callables. - for control in control_inputs: - if callable(control): - control() return _NullContextmanager() -- cgit v1.2.3