aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-09-12 17:20:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 17:24:17 -0700
commit97d7281354af43ed5fd53ebf729cea76de84acdb (patch)
tree55d04925af27ceec6ae12307a18439e8ed7b68e3 /tensorflow/python/eager
parent20192a94258c870e617c8cf71d23a297383f05f2 (diff)
eager: Graceful failure on invalid inputs.
Tests added to pywrap_tfe_test.py would fail (segmentation fault / infinite loop) without corresponding fixes to pywrap_tfe.i and pywrap_tfe_src.cc Other statements that would fail ungracefully without this fix (and with eager execution enabled) include: tf.split(value=0, num_or_size_splits=-1) tf.dynamic_partition(data=0, partitions=0, num_partitions=-1) tf.split(value=0, num_or_size_splits=1.23, num=-1) tf.unstack(value=0, num=-1) PiperOrigin-RevId: 212731927
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/BUILD1
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc13
-rw-r--r--tensorflow/python/eager/pywrap_tfe_test.py25
3 files changed, 33 insertions, 6 deletions
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 85da1baaf0..c1bc27d443 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -345,6 +345,7 @@ py_test(
deps = [
":backprop",
":context",
+ ":core",
":test",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index c6a55949ab..1a8f3577b2 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -2563,13 +2563,18 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
int num_retvals = 0;
for (int i = 0; i < op_def->output_arg_size(); i++) {
const auto& output_arg = op_def->output_arg(i);
+ int delta = 1;
if (!output_arg.number_attr().empty()) {
- num_retvals += attr_list_sizes[output_arg.number_attr()];
+ delta = attr_list_sizes[output_arg.number_attr()];
} else if (!output_arg.type_list_attr().empty()) {
- num_retvals += attr_list_sizes[output_arg.type_list_attr()];
- } else {
- num_retvals++;
+ delta = attr_list_sizes[output_arg.type_list_attr()];
+ }
+ if (delta < 0) {
+ RaiseFallbackException(
+ "Attributes suggest that the size of an output list is less than 0");
+ return nullptr;
}
+ num_retvals += delta;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py
index fd8ab695b8..669fa08488 100644
--- a/tensorflow/python/eager/pywrap_tfe_test.py
+++ b/tensorflow/python/eager/pywrap_tfe_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import core
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -123,8 +124,8 @@ class Tests(test.TestCase):
def testFastpathExecute_MixedPrecisionVariableTapeWrite(self):
ctx = context.context()
with backprop.GradientTape(persistent=True) as tape:
- a_2_by_2 = constant_op.constant(
- [[1.0, 2.0], [3.0, 4.0]], dtype=dtypes.float32)
+ a_2_by_2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]],
+ dtype=dtypes.float32)
a_2_by_2_fp16 = math_ops.cast(a_2_by_2, dtype=dtypes.float16)
m1 = resource_variable_ops.ResourceVariable(a_2_by_2)
m2 = resource_variable_ops._MixedPrecisionVariable(
@@ -233,6 +234,26 @@ class Tests(test.TestCase):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
ctx_handle, None, [], a_2_by_2)
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastPathExecute_InvalidAttributes(self):
+ split_dim = constant_op.constant(0, dtype=dtypes.int32)
+ value = constant_op.constant([0, 1, 2, 3], dtype=dtypes.float32)
+ ctx = context.context()
+ ctx_handle = ctx._handle
+ with self.assertRaises(core._FallbackException):
+ pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
+ "Split", None, None, split_dim,
+ value, "num_split", -1)
+
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testInvalidNumOutputs(self):
+ with self.assertRaisesRegexp(
+ Exception,
+ "Value for attr 'num_split' of -1 must be at least minimum 1"):
+ array_ops.split(value=[1, 2, 3], num_or_size_splits=-1)
+
if __name__ == "__main__":
test.main()