aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Ian Langmore <langmore@google.com>2017-01-17 10:52:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-17 11:09:57 -0800
commit64ea20632bf346a9474b4e0420f1277e8054a002 (patch)
tree39d4e48174e0a5caf5ba4707d6e849e4e747efa9
parent3c44578744668d8524f78c33daef0ccd43f57b25 (diff)
Name change in LinearOperator: batch_shape_dynamic --> batch_shape_tensor.
Similarly for other "dynamic" Ops. Change: 144728885
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijector.py2
-rw-r--r--tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py8
-rw-r--r--tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py18
-rw-r--r--tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py2
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator.py76
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_composition.py10
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_diag.py2
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_identity.py10
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_matrix.py2
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py10
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_tril.py2
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_util.py4
12 files changed, 74 insertions, 72 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/bijector.py b/tensorflow/contrib/distributions/python/ops/bijector.py
index 7e92f49677..41a4f9d859 100644
--- a/tensorflow/contrib/distributions/python/ops/bijector.py
+++ b/tensorflow/contrib/distributions/python/ops/bijector.py
@@ -1977,7 +1977,7 @@ class AffineLinearOperator(Bijector):
if scale.tensor_rank is not None:
batch_ndims = scale.tensor_rank - 2
else:
- batch_ndims = scale.tensor_rank_dynamic() - 2
+ batch_ndims = scale.tensor_rank_tensor() - 2
graph_parents += [batch_ndims]
else:
batch_ndims = 0 # We won't need shape inference when scale is None.
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py
index 2f60554104..6309d36258 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py
@@ -200,16 +200,16 @@ class NonSquareLinearOperatorCompositionTest(
operator = linalg.LinearOperatorComposition(operators)
self.assertAllEqual((2, 3, 5), operator.shape)
- def test_dynamic_shapes_when_statically_available(self):
+ def test_shape_tensors_when_statically_available(self):
operators = [
linalg.LinearOperatorMatrix(rng.rand(2, 3, 4)),
linalg.LinearOperatorMatrix(rng.rand(2, 4, 5))
]
operator = linalg.LinearOperatorComposition(operators)
with self.test_session():
- self.assertAllEqual((2, 3, 5), operator.shape_dynamic().eval())
+ self.assertAllEqual((2, 3, 5), operator.shape_tensor().eval())
- def test_dynamic_shapes_when_only_dynamically_available(self):
+ def test_shape_tensors_when_only_dynamically_available(self):
mat_1 = rng.rand(1, 2, 3, 4)
mat_2 = rng.rand(1, 2, 4, 5)
mat_ph_1 = array_ops.placeholder(dtypes.float64)
@@ -223,7 +223,7 @@ class NonSquareLinearOperatorCompositionTest(
operator = linalg.LinearOperatorComposition(operators)
with self.test_session():
self.assertAllEqual(
- (1, 2, 3, 5), operator.shape_dynamic().eval(feed_dict=feed_dict))
+ (1, 2, 3, 5), operator.shape_tensor().eval(feed_dict=feed_dict))
if __name__ == "__main__":
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py
index 8f77c5e6e3..c099194eed 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py
@@ -31,7 +31,7 @@ rng = np.random.RandomState(123)
class LinearOperatorShape(linalg.LinearOperator):
- """LinearOperator that implements the methods ._shape and _shape_dynamic."""
+ """LinearOperator that implements the methods ._shape and _shape_tensor."""
def __init__(self,
shape,
@@ -49,7 +49,7 @@ class LinearOperatorShape(linalg.LinearOperator):
def _shape(self):
return tensor_shape.TensorShape(self._stored_shape)
- def _shape_dynamic(self):
+ def _shape_tensor(self):
return constant_op.constant(self._stored_shape, dtype=dtypes.int32)
@@ -71,7 +71,7 @@ class LinearOperatorApplyOnly(linalg.LinearOperator):
def _shape(self):
return self._matrix.get_shape()
- def _shape_dynamic(self):
+ def _shape_tensor(self):
return array_ops.shape(self._matrix)
def _apply(self, x, adjoint=False):
@@ -96,11 +96,11 @@ class LinearOperatorTest(test.TestCase):
shape = (1, 2, 3, 4)
operator = LinearOperatorShape(shape)
- self.assertAllEqual(shape, operator.shape_dynamic().eval())
- self.assertAllEqual(4, operator.tensor_rank_dynamic().eval())
- self.assertAllEqual((1, 2), operator.batch_shape_dynamic().eval())
- self.assertAllEqual(4, operator.domain_dimension_dynamic().eval())
- self.assertAllEqual(3, operator.range_dimension_dynamic().eval())
+ self.assertAllEqual(shape, operator.shape_tensor().eval())
+ self.assertAllEqual(4, operator.tensor_rank_tensor().eval())
+ self.assertAllEqual((1, 2), operator.batch_shape_tensor().eval())
+ self.assertAllEqual(4, operator.domain_dimension_tensor().eval())
+ self.assertAllEqual(3, operator.range_dimension_tensor().eval())
def test_is_x_properties(self):
operator = LinearOperatorShape(
@@ -120,7 +120,7 @@ class LinearOperatorTest(test.TestCase):
self.assertAllEqual((2, 3, 4), operator_dense.get_shape())
self.assertAllClose(matrix, operator_dense.eval())
- def test_generic_to_dense_method_non_square_matrix_dynamic(self):
+ def test_generic_to_dense_method_non_square_matrix_tensor(self):
matrix = rng.randn(2, 3, 4)
matrix_ph = array_ops.placeholder(dtypes.float64)
operator = LinearOperatorApplyOnly(matrix_ph)
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py
index 4eac01092f..bf6f8f8302 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py
@@ -96,7 +96,7 @@ class DomainDimensionStubOperator(object):
def __init__(self, domain_dimension):
self._domain_dimension = ops.convert_to_tensor(domain_dimension)
- def domain_dimension_dynamic(self):
+ def domain_dimension_tensor(self):
return self._domain_dimension
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator.py b/tensorflow/contrib/linalg/python/ops/linear_operator.py
index e229820edc..2467603605 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator.py
@@ -180,13 +180,15 @@ class LinearOperator(object):
self._is_positive_definite = is_positive_definite
self._name = name or type(self).__name__
- # We will cache some values to avoid repeatedly adding shape
- # manipulation ops to the graph. Cleaner.
- self._cached_shape_dynamic = None
- self._cached_batch_shape_dynamic = None
- self._cached_domain_dimension_dynamic = None
- self._cached_range_dimension_dynamic = None
- self._cached_tensor_rank_dynamic = None
+ # We will cache some tensors to avoid repeatedly adding shape
+ # manipulation ops to the graph.
+ # Naming convention:
+ # self._cached_X_tensor is the cached version of self._X_tensor.
+ self._cached_shape_tensor = None
+ self._cached_batch_shape_tensor = None
+ self._cached_domain_dimension_tensor = None
+ self._cached_range_dimension_tensor = None
+ self._cached_tensor_rank_tensor = None
@contextlib.contextmanager
def _name_scope(self, name=None, values=None):
@@ -240,10 +242,10 @@ class LinearOperator(object):
"""
return self._shape()
- def _shape_dynamic(self):
- raise NotImplementedError("_shape_dynamic is not implemented.")
+ def _shape_tensor(self):
+ raise NotImplementedError("_shape_tensor is not implemented.")
- def shape_dynamic(self, name="shape_dynamic"):
+ def shape_tensor(self, name="shape_tensor"):
"""Shape of this `LinearOperator`, determined at runtime.
If this operator acts like the batch matrix `A` with
@@ -258,14 +260,14 @@ class LinearOperator(object):
"""
with self._name_scope(name):
# Be clean by avoiding adding shape Ops to the graph too many times.
- if self._cached_shape_dynamic is None:
+ if self._cached_shape_tensor is None:
# Prefer to use statically defined shape if available.
if self.shape.is_fully_defined():
- self._cached_shape_dynamic = linear_operator_util.shape_tensor(
+ self._cached_shape_tensor = linear_operator_util.shape_tensor(
self.shape.as_list())
else:
- self._cached_shape_dynamic = self._shape_dynamic()
- return self._cached_shape_dynamic
+ self._cached_shape_tensor = self._shape_tensor()
+ return self._cached_shape_tensor
@property
def batch_shape(self):
@@ -281,7 +283,7 @@ class LinearOperator(object):
# Derived classes get this "for free" once .shape is implemented.
return self.shape[:-2]
- def batch_shape_dynamic(self, name="batch_shape_dynamic"):
+ def batch_shape_tensor(self, name="batch_shape_tensor"):
"""Shape of batch dimensions of this operator, determined at runtime.
If this operator acts like the batch matrix `A` with
@@ -296,14 +298,14 @@ class LinearOperator(object):
"""
# Derived classes get this "for free" once .shape() is implemented.
with self._name_scope(name):
- if self._cached_batch_shape_dynamic is None:
+ if self._cached_batch_shape_tensor is None:
# Prefer to use statically defined shape if available.
if self.batch_shape.is_fully_defined():
- self._cached_batch_shape_dynamic = linear_operator_util.shape_tensor(
+ self._cached_batch_shape_tensor = linear_operator_util.shape_tensor(
self.batch_shape.as_list(), name="batch_shape")
else:
- self._cached_batch_shape_dynamic = self.shape_dynamic()[:-2]
- return self._cached_batch_shape_dynamic
+ self._cached_batch_shape_tensor = self.shape_tensor()[:-2]
+ return self._cached_batch_shape_tensor
@property
def tensor_rank(self, name="tensor_rank"):
@@ -322,7 +324,7 @@ class LinearOperator(object):
with self._name_scope(name):
return self.shape.ndims
- def tensor_rank_dynamic(self, name="tensor_rank_dynamic"):
+ def tensor_rank_tensor(self, name="tensor_rank_tensor"):
"""Rank (in the sense of tensors) of matrix corresponding to this operator.
If this operator acts like the batch matrix `A` with
@@ -336,15 +338,15 @@ class LinearOperator(object):
"""
# Derived classes get this "for free" once .shape() is implemented.
with self._name_scope(name):
- if self._cached_tensor_rank_dynamic is None:
+ if self._cached_tensor_rank_tensor is None:
# Prefer to use statically defined shape if available.
if self.tensor_rank is not None:
- self._cached_tensor_rank_dynamic = ops.convert_to_tensor(
+ self._cached_tensor_rank_tensor = ops.convert_to_tensor(
self.tensor_rank)
else:
- self._cached_tensor_rank_dynamic = array_ops.size(
- self.shape_dynamic())
- return self._cached_tensor_rank_dynamic
+ self._cached_tensor_rank_tensor = array_ops.size(
+ self.shape_tensor())
+ return self._cached_tensor_rank_tensor
@property
def domain_dimension(self):
@@ -359,7 +361,7 @@ class LinearOperator(object):
# Derived classes get this "for free" once .shape is implemented.
return self.shape[-1]
- def domain_dimension_dynamic(self, name="domain_dimension_dynamic"):
+ def domain_dimension_tensor(self, name="domain_dimension_tensor"):
"""Dimension (in the sense of vector spaces) of the domain of this operator.
Determined at runtime.
@@ -375,14 +377,14 @@ class LinearOperator(object):
"""
# Derived classes get this "for free" once .shape() is implemented.
with self._name_scope(name):
- if self._cached_domain_dimension_dynamic is None:
+ if self._cached_domain_dimension_tensor is None:
# Prefer to use statically defined shape if available.
if self.domain_dimension.value is not None:
- self._cached_domain_dimension_dynamic = ops.convert_to_tensor(
+ self._cached_domain_dimension_tensor = ops.convert_to_tensor(
self.domain_dimension.value)
else:
- self._cached_domain_dimension_dynamic = self.shape_dynamic()[-1]
- return self._cached_domain_dimension_dynamic
+ self._cached_domain_dimension_tensor = self.shape_tensor()[-1]
+ return self._cached_domain_dimension_tensor
@property
def range_dimension(self):
@@ -397,7 +399,7 @@ class LinearOperator(object):
# Derived classes get this "for free" once .shape is implemented.
return self.shape[-2]
- def range_dimension_dynamic(self, name="range_dimension_dynamic"):
+ def range_dimension_tensor(self, name="range_dimension_tensor"):
"""Dimension (in the sense of vector spaces) of the range of this operator.
Determined at runtime.
@@ -413,14 +415,14 @@ class LinearOperator(object):
"""
# Derived classes get this "for free" once .shape() is implemented.
with self._name_scope(name):
- if self._cached_range_dimension_dynamic is None:
+ if self._cached_range_dimension_tensor is None:
# Prefer to use statically defined shape if available.
if self.range_dimension.value is not None:
- self._cached_range_dimension_dynamic = ops.convert_to_tensor(
+ self._cached_range_dimension_tensor = ops.convert_to_tensor(
self.range_dimension.value)
else:
- self._cached_range_dimension_dynamic = self.shape_dynamic()[-2]
- return self._cached_range_dimension_dynamic
+ self._cached_range_dimension_tensor = self.shape_tensor()[-2]
+ return self._cached_range_dimension_tensor
def _assert_non_singular(self):
raise NotImplementedError("assert_non_singular is not implemented.")
@@ -574,12 +576,12 @@ class LinearOperator(object):
if self.batch_shape.is_fully_defined():
batch_shape = self.batch_shape
else:
- batch_shape = self.batch_shape_dynamic()
+ batch_shape = self.batch_shape_tensor()
if self.domain_dimension.value is not None:
n = self.domain_dimension.value
else:
- n = self.domain_dimension_dynamic()
+ n = self.domain_dimension_tensor()
eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype)
return self.apply(eye)
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py
index 3e118ebbd4..81e7735841 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py
@@ -202,7 +202,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
return batch_shape.concatenate(matrix_shape)
- def _shape_dynamic(self):
+ def _shape_tensor(self):
# Avoid messy broadcasting if possible.
if self.shape.is_fully_defined():
return ops.convert_to_tensor(
@@ -212,14 +212,14 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
# the graph. Things will fail at runtime naturally if shapes are
# incompatible.
matrix_shape = array_ops.stack([
- self.operators[0].range_dimension_dynamic(),
- self.operators[-1].domain_dimension_dynamic()
+ self.operators[0].range_dimension_tensor(),
+ self.operators[-1].domain_dimension_tensor()
])
# Dummy Tensor of zeros. Will never be materialized.
- zeros = array_ops.zeros(shape=self.operators[0].batch_shape_dynamic())
+ zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
for operator in self.operators[1:]:
- zeros += array_ops.zeros(shape=operator.batch_shape_dynamic())
+ zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
batch_shape = array_ops.shape(zeros)
return array_ops.concat((batch_shape, matrix_shape), 0)
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py
index d59e8be767..4700e65518 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py
@@ -166,7 +166,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
d_shape = self._diag.get_shape()
return d_shape.concatenate(d_shape[-1:])
- def _shape_dynamic(self):
+ def _shape_tensor(self):
d_shape = array_ops.shape(self._diag)
k = d_shape[-1]
return array_ops.concat((d_shape, [k]), 0)
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py b/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py
index 3304698ec6..6559f8b116 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py
@@ -261,7 +261,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
return batch_shape.concatenate(matrix_shape)
- def _shape_dynamic(self):
+ def _shape_tensor(self):
matrix_shape = array_ops.stack(
(self._num_rows, self._num_rows), axis=0)
if self._batch_shape_arg is None:
@@ -307,7 +307,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
# Dynamic broadcast:
# Always add to an array of zeros, rather than using a "cond", since a
# cond would require copying data from GPU --> CPU.
- special_shape = array_ops.concat((self.batch_shape_dynamic(), [1, 1]), 0)
+ special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
return x + zeros
@@ -320,10 +320,10 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
return self._possibly_broadcast_batch_shape(x)
def _determinant(self):
- return array_ops.ones(shape=self.batch_shape_dynamic(), dtype=self.dtype)
+ return array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)
def _log_abs_determinant(self):
- return array_ops.zeros(shape=self.batch_shape_dynamic(), dtype=self.dtype)
+ return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)
def _solve(self, rhs, adjoint=False):
return self._apply(rhs)
@@ -566,7 +566,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
batch_shape = self.multiplier.get_shape()
return batch_shape.concatenate(matrix_shape)
- def _shape_dynamic(self):
+ def _shape_tensor(self):
matrix_shape = array_ops.stack(
(self._num_rows, self._num_rows), axis=0)
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_matrix.py b/tensorflow/contrib/linalg/python/ops/linear_operator_matrix.py
index 7ca18450d1..3b5dc7c481 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_matrix.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_matrix.py
@@ -157,7 +157,7 @@ class LinearOperatorMatrix(linear_operator.LinearOperator):
def _shape(self):
return self._matrix.get_shape()
- def _shape_dynamic(self):
+ def _shape_tensor(self):
return array_ops.shape(self._matrix)
def _apply(self, x, adjoint=False):
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
index 5de9bb5d77..466fedd578 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
@@ -262,8 +262,8 @@ class SquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
n = operator.domain_dimension.value
x_shape = batch_shape + [n, r]
else:
- batch_shape = operator.batch_shape_dynamic()
- n = operator.domain_dimension_dynamic()
+ batch_shape = operator.batch_shape_tensor()
+ n = operator.domain_dimension_tensor()
x_shape = array_ops.concat((batch_shape, [n, r]), 0)
return random_normal(x_shape, dtype=operator.dtype)
@@ -316,11 +316,11 @@ class NonSquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
n = operator.domain_dimension.value
x_shape = batch_shape + [n, r]
else:
- batch_shape = operator.batch_shape_dynamic()
+ batch_shape = operator.batch_shape_tensor()
if adjoint:
- n = operator.range_dimension_dynamic()
+ n = operator.range_dimension_tensor()
else:
- n = operator.domain_dimension_dynamic()
+ n = operator.domain_dimension_tensor()
x_shape = array_ops.concat((batch_shape, [n, r]), 0)
return random_normal(x_shape, dtype=operator.dtype)
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py
index 7c5b9b6b54..2b1fb4c04c 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py
@@ -157,7 +157,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
def _shape(self):
return self._tril.get_shape()
- def _shape_dynamic(self):
+ def _shape_tensor(self):
return array_ops.shape(self._tril)
def _assert_non_singular(self):
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py
index 44092f0c06..6e56fac2e3 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py
@@ -83,10 +83,10 @@ def assert_compatible_matrix_dimensions(operator, x):
Returns:
`Assert` `Op`.
"""
- # Static checks are done in the base class. Only dynamic asserts here.
+ # Static checks are done in the base class. Only tensor asserts here.
assert_same_dd = check_ops.assert_equal(
array_ops.shape(x)[-2],
- operator.domain_dimension_dynamic(),
+ operator.domain_dimension_tensor(),
message=(
"Incompatible matrix dimensions. "
"shape[-2] of argument to be the same as this operator"))