aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Ian Langmore <langmore@google.com>2017-01-17 19:51:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-17 20:07:35 -0800
commit1c7ef3db0840d6530e7dde693b9351b8205f7e84 (patch)
tree5ef7d9e688a8fcf7787bff2dd59dae5ef5cfccde
parent64ce45728c19c1058196dd64b19b2d10b7af42a3 (diff)
linear_operator_test_util.py: test_log_abs_det added. This tests
LinearOperator.log_abs_determinant. This test was missing previously. Change: 144787893
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py25
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
index 466fedd578..85cd7fcd9a 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
@@ -174,6 +174,29 @@ class LinearOperatorDerivedClassTest(test.TestCase):
feed_dict=feed_dict)
self.assertAC(op_det_v, mat_det_v)
+ def test_log_abs_det(self):
+ self._maybe_skip("log_abs_det")
+ for use_placeholder in False, True:
+ for shape in self._shapes_to_test:
+ for dtype in self._dtypes_to_test:
+ if dtype.is_complex:
+ self.skipTest(
+ "tf.matrix_determinant does not work with complex, so this "
+ "test is being skipped.")
+ with self.test_session(graph=ops.Graph()) as sess:
+ sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
+ operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ shape, dtype, use_placeholder=use_placeholder)
+ op_log_abs_det = operator.log_abs_determinant()
+ mat_log_abs_det = math_ops.log(
+ math_ops.abs(linalg_ops.matrix_determinant(mat)))
+ if not use_placeholder:
+ self.assertAllEqual(shape[:-2], op_log_abs_det.get_shape())
+ op_log_abs_det_v, mat_log_abs_det_v = sess.run(
+ [op_log_abs_det, mat_log_abs_det],
+ feed_dict=feed_dict)
+ self.assertAC(op_log_abs_det_v, mat_log_abs_det_v)
+
def test_apply(self):
self._maybe_skip("apply")
for use_placeholder in False, True:
@@ -291,7 +314,7 @@ class NonSquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
@property
def _tests_to_skip(self):
"""List of test names to skip."""
- return ["solve", "det"]
+ return ["solve", "det", "log_abs_det"]
@property
def _shapes_to_test(self):