diff options
Diffstat (limited to 'tensorflow/js/ops/ts_op_gen_test.cc')
-rw-r--r-- | tensorflow/js/ops/ts_op_gen_test.cc | 138 |
1 files changed, 86 insertions, 52 deletions
diff --git a/tensorflow/js/ops/ts_op_gen_test.cc b/tensorflow/js/ops/ts_op_gen_test.cc index 9a85c021b0..03241689b5 100644 --- a/tensorflow/js/ops/ts_op_gen_test.cc +++ b/tensorflow/js/ops/ts_op_gen_test.cc @@ -36,7 +36,6 @@ void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) { << "'" << s << "' does not contain '" << expected << "'"; } -// TODO(kreeger): Add multiple outputs here? constexpr char kBaseOpDef[] = R"( op { name: "Foo" @@ -79,50 +78,15 @@ op { summary: "Summary for op Foo." description: "Description for op Foo." } -op { - name: "DeprecatedFoo" - input_arg { - name: "input" - description: "Description for input." - type: DT_FLOAT - } - output_arg { - name: "output" - description: "Description for output." - type: DT_FLOAT - } - deprecation { - explanation: "Deprecated." - } -} -op { - name: "MultiOutputFoo" - input_arg { - name: "input" - description: "Description for input." - type: DT_FLOAT - } - output_arg { - name: "output1" - description: "Description for output 1." - type: DT_FLOAT - } - output_arg { - name: "output2" - description: "Description for output 2." - type: DT_FLOAT - } - summary: "Summary for op MultiOutputFoo." - description: "Description for op MultiOutputFoo." -} )"; // Generate TypeScript code -// @param api_def_str TODO doc me. -void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) { +void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str, + string* ts_file_text) { Env* env = Env::Default(); OpList op_defs; - protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); + protobuf::TextFormat::ParseFromString( + op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs); ApiDefMap api_def_map(op_defs); if (!api_def_str.empty()) { @@ -138,11 +102,11 @@ void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) { TEST(TsOpGenTest, TestImports) { string ts_file_text; - GenerateTsOpFileText("", &ts_file_text); + GenerateTsOpFileText("", "", &ts_file_text); const string expected = R"( import * as tfc from '@tensorflow/tfjs-core'; -import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils'; +import {createTensorsTypeOpAttr, nodeBackend} from './op_utils'; )"; ExpectContainsStr(ts_file_text, expected); } @@ -160,12 +124,10 @@ op { )"; string ts_file_text; - GenerateTsOpFileText(api_def, &ts_file_text); + GenerateTsOpFileText("", api_def, &ts_file_text); const string expected = R"( export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor { - return null; -} )"; ExpectContainsStr(ts_file_text, expected); } @@ -179,34 +141,106 @@ op { )"; string ts_file_text; - GenerateTsOpFileText(api_def, &ts_file_text); + GenerateTsOpFileText("", api_def, &ts_file_text); const string expected = R"( export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor { - return null; -} )"; ExpectDoesNotContainStr(ts_file_text, expected); } TEST(TsOpGenTest, SkipDeprecated) { + const string op_def = R"( +op { + name: "DeprecatedFoo" + input_arg { + name: "input" + type_attr: "T" + description: "Description for input." + } + output_arg { + name: "output" + description: "Description for output." + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + description: "Type for input" + allowed_values { + list { + type: DT_FLOAT + } + } + } + deprecation { + explanation: "Deprecated." + } +} +)"; + string ts_file_text; - GenerateTsOpFileText("", &ts_file_text); + GenerateTsOpFileText(op_def, "", &ts_file_text); ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo"); } TEST(TsOpGenTest, MultiOutput) { + const string op_def = R"( +op { + name: "MultiOutputFoo" + input_arg { + name: "input" + description: "Description for input." + type_attr: "T" + } + output_arg { + name: "output1" + description: "Description for output 1." + type: DT_FLOAT + } + output_arg { + name: "output2" + description: "Description for output 2." + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + description: "Type for input" + allowed_values { + list { + type: DT_FLOAT + } + } + } + summary: "Summary for op MultiOutputFoo." + description: "Description for op MultiOutputFoo." +} +)"; + string ts_file_text; - GenerateTsOpFileText("", &ts_file_text); + GenerateTsOpFileText(op_def, "", &ts_file_text); const string expected = R"( export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] { - return null; -} )"; ExpectContainsStr(ts_file_text, expected); } +TEST(TsOpGenTest, OpAttrs) { + string ts_file_text; + GenerateTsOpFileText("", "", &ts_file_text); + + const string expectedFooAttrs = R"( + const opAttrs = [ + createTensorsTypeOpAttr('T', images), + {name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length} + ]; +)"; + + ExpectContainsStr(ts_file_text, expectedFooAttrs); +} + } // namespace } // namespace tensorflow |