aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Olivia Nordquist <nolivia@google.com>2017-11-30 19:46:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-30 19:49:29 -0800
commit79ad4a423f3e9031eb841a164372cc7476cc112a (patch)
tree2210b3cd5493746e105a1edb58529f87b79ceb48
parent87e2f20c8b4f2ece313584c7c3c5588ee6ae5ece (diff)
enabling Tensor._set_shape() to work with the C API
PiperOrigin-RevId: 177543170
-rw-r--r--tensorflow/python/client/tf_session.i43
-rw-r--r--tensorflow/python/client/tf_session_helper.cc19
-rw-r--r--tensorflow/python/client/tf_session_helper.h14
-rw-r--r--tensorflow/python/framework/ops.py57
4 files changed, 117 insertions, 16 deletions
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 5fa1a7e8fc..d471a39b69 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -532,6 +532,49 @@ def TF_Reset(target, containers=None, config=None):
%unignore TF_GraphGetTensorShapeHelper;
%ignore TF_GraphGetTensorShape;
+// We use TF_GraphSetTensorShape_wrapper instead of
+// TF_GraphSetTensorShape
+%ignore TF_GraphSetTensorShape;
+%unignore tensorflow;
+%unignore TF_GraphSetTensorShape_wrapper;
+
+// $input is a Python list of ints to a vector<int> for TF_GraphSetTensorShape_wrapper
+%typemap(in) (const std::vector<int64_t>& dims)
+ (std::vector<int64_t> dims_local){
+ if ($input != Py_None) {
+ if (!PyList_Check($input)) {
+ SWIG_exception_fail(SWIG_TypeError, tensorflow::strings::Printf(
+ "$symname: expected list but got %s ", Py_TYPE($input)->tp_name).c_str());
+ }
+ size_t size = PyList_Size($input);
+ for (int i = 0; i < size; ++i) {
+ PyObject* item = PyList_GetItem($input, i);
+ dims_local.push_back(PyInt_AsLong(item));
+ }
+ $1 = &dims_local;
+ } else {
+ $1 = nullptr;
+ }
+}
+
+// We use TF_GraphGetTensorShape_wrapper instead of
+// TF_GraphGetTensorShape
+%ignore TF_GraphGetTensorShape;
+%unignore tensorflow;
+%unignore TF_GraphGetTensorShape_wrapper;
+
+// Build a Python list of ints and return it.
+%typemap(out) std::vector<int64_t> tensorflow::TF_GraphGetTensorShape_wrapper {
+ $result = PyList_New($1.size());
+ if (!$result) {
+ SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
+ }
+
+ for (size_t i = 0; i < $1.size(); ++i) {
+ PyList_SET_ITEM($result, i, PyInt_FromLong($1[i]));
+ }
+}
+
%include "tensorflow/python/client/tf_session_helper.h"
%unignoreall
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index ad982e5dd8..e4bf09a0ca 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -407,4 +407,23 @@ TF_Function* TF_GraphToFunction_wrapper(
opts, description, out_status);
}
+void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
+ const std::vector<int64_t>& dims,
+ bool unknown_shape, TF_Status* status) {
+ if (unknown_shape) {
+ TF_GraphSetTensorShape(graph, output, nullptr, -1, status);
+ return;
+ }
+ TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
+}
+
+std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
+ TF_Output output,
+ int num_dims,
+ TF_Status* status) {
+ std::vector<int64_t> dims(num_dims);
+ TF_GraphGetTensorShape(graph, output, dims.data(), num_dims, status);
+ return dims;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 6ed08d3a58..bb7171db31 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -168,6 +168,20 @@ TF_Function* TF_GraphToFunction_wrapper(
const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
const NameVector& output_names, const TF_FunctionOptions* opts,
const char* description, TF_Status* out_status);
+
+// Set the shape of output. If unknown is true, `num_dims` must be set to
+// -1 and `dims` is set to nullptr.
+void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
+ const std::vector<int64_t>& dims,
+ bool unknown_shape, TF_Status* status);
+
+// Return the shape of output. `num_dims` should be the output of
+// TF_GraphGetTensorNumDims. If `num_dims = -1`, this should not be called.
+std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
+ TF_Output output,
+ int num_dims,
+ TF_Status* status);
+
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 5f945ac133..13e6426447 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -374,6 +374,19 @@ class Tensor(_TensorLike):
A `TensorShape` representing the shape of this tensor.
"""
+ if _USE_C_API:
+ graph = self._op._graph._c_graph # pylint: disable=protected-access
+ with errors.raise_exception_on_not_ok_status() as status:
+ num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(),
+ status)
+ if num_dims == -1:
+ dim_list = None
+ else:
+ with errors.raise_exception_on_not_ok_status() as status:
+ dim_list = c_api.TF_GraphGetTensorShape_wrapper(
+ graph, self._as_tf_output(), num_dims, status)
+ dim_list = [None if i == -1 else i for i in dim_list]
+ return tensor_shape.TensorShape(dim_list)
return self._shape
def __iter__(self):
@@ -393,8 +406,8 @@ class Tensor(_TensorLike):
yield self[i]
def _shape_as_list(self):
- if self._shape.ndims is not None:
- return [dim.value for dim in self._shape.dims]
+ if self.shape.ndims is not None:
+ return [dim.value for dim in self.shape.dims]
else:
return None
@@ -410,7 +423,7 @@ class Tensor(_TensorLike):
Returns:
Integer rank or None
"""
- return self._shape.ndims
+ return self.shape.ndims
def get_shape(self):
"""Alias of Tensor.shape."""
@@ -441,14 +454,35 @@ class Tensor(_TensorLike):
```
Args:
- shape: A `TensorShape` representing the shape of this tensor.
+ shape: A `TensorShape` representing the shape of this tensor, a
+ `TensorShapeProto`, a list, a tuple, or None.
Raises:
ValueError: If `shape` is not compatible with the current shape of
this tensor.
"""
- # TODO(skyewm): call C API
- self._shape = self._shape.merge_with(shape)
+ if not _USE_C_API:
+ self._shape = self._shape.merge_with(shape) # pylint: disable=protected-access
+ return
+ if not isinstance(shape, tensor_shape.TensorShape):
+ shape = tensor_shape.TensorShape(shape)
+ dim_list = []
+ if shape.dims is None:
+ unknown_shape = True
+ else:
+ unknown_shape = False
+ for dim in shape.dims:
+ if dim.value is None:
+ dim_list.append(-1)
+ else:
+ dim_list.append(dim.value)
+ with errors.raise_exception_on_not_ok_status() as status:
+ c_api.TF_GraphSetTensorShape_wrapper(
+ self._op._graph._c_graph, # pylint: disable=protected-access
+ self._as_tf_output(),
+ dim_list,
+ unknown_shape,
+ status)
@property
def value_index(self):
@@ -4521,15 +4555,11 @@ def control_dependencies(control_inputs):
See @{tf.Graph.control_dependencies}
for more details.
- When eager execution is enabled, any callable object in the `control_inputs`
- list will be called.
-
Args:
control_inputs: A list of `Operation` or `Tensor` objects which
must be executed or computed before running the operations
defined in the context. Can also be `None` to clear the control
- dependencies. If eager execution is enabled, any callable object in the
- `control_inputs` list will be called.
+ dependencies.
Returns:
A context manager that specifies control dependencies for all
@@ -4538,11 +4568,6 @@ def control_dependencies(control_inputs):
if context.in_graph_mode():
return get_default_graph().control_dependencies(control_inputs)
else:
- if control_inputs:
- # Excute any pending callables.
- for control in control_inputs:
- if callable(control):
- control()
return _NullContextmanager()