aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/timeseries/examples
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-02-09 10:13:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-09 10:16:58 -0800
commit81b698d0dad88b0bb33af720247a3eba1d248e9d (patch)
tree3b2647c2558fcc02be393a2c57faea2031ffb572 /tensorflow/contrib/timeseries/examples
parent5e6e691ba320fba71aec0a426ff6e5608ddcad1d (diff)
TFTS: Better handling of exogenous features
Adds (dummy) exogenous features to the LSTM model-building example, and adds some small methods needed to support that (fetching the shape of embedded exogenous features). Also makes it more automatic to export a SavedModel with exogenous features (placeholder shapes will be inferred from the given FeatureColumns), which makes the LSTM example friendlier. PiperOrigin-RevId: 185157085
Diffstat (limited to 'tensorflow/contrib/timeseries/examples')
-rw-r--r--tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv200
-rw-r--r--tensorflow/contrib/timeseries/examples/lstm.py56
2 files changed, 142 insertions, 114 deletions
diff --git a/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv b/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv
index 02a60d1cf6..b49a0662c2 100644
--- a/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv
+++ b/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv
@@ -1,100 +1,100 @@
-0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867
-1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303
-2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864
-3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426
-4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223
-5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842
-6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606
-7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347
-8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951
-9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228
-10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897
-11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634
-12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594
-13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394
-14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609
-15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449
-16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251
-17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382
-18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767
-19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713
-20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251
-21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811
-22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681
-23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735
-24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436
-25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899
-26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814
-27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727
-28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582
-29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555
-30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696
-31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548
-32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627
-33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104
-34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156
-35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459
-36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576
-37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584
-38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577
-39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467
-40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566
-41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909
-42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021
-43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831
-44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905
-45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271
-46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094
-47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554
-48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769
-49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606
-50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629
-51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199
-52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961
-53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122
-54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454
-55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301
-56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182
-57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365
-58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011
-59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449
-60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229
-61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259
-62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272
-63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989
-64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496
-65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376
-66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206
-67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502
-68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219
-69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125
-70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514
-71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166
-72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832
-73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913
-74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188
-75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388
-76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136
-77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766
-78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959
-79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083
-80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483
-81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656
-82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107
-83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991
-84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527
-85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649
-86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788
-87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289
-88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298
-89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873
-90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669
-91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462
-92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232
-93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225
-94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288
-95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086
-96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161
-97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227
-98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937
-99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724
+0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867,1.,0.
+1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303,1.,0.
+2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864,1.,0.
+3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426,1.,0.
+4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223,1.,0.
+5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842,1.,0.
+6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606,1.,0.
+7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347,1.,0.
+8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951,1.,0.
+9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228,1.,0.
+10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897,1.,0.
+11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634,1.,0.
+12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594,1.,0.
+13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394,1.,0.
+14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609,1.,0.
+15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449,1.,0.
+16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251,1.,0.
+17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382,1.,0.
+18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767,1.,0.
+19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713,1.,0.
+20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251,1.,0.
+21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811,1.,0.
+22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681,1.,0.
+23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735,1.,0.
+24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436,1.,0.
+25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899,1.,0.
+26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814,1.,0.
+27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727,1.,0.
+28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582,1.,0.
+29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555,1.,0.
+30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696,1.,0.
+31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548,1.,0.
+32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627,1.,0.
+33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104,1.,0.
+34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156,1.,0.
+35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459,1.,0.
+36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576,1.,0.
+37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584,1.,0.
+38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577,1.,0.
+39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467,1.,0.
+40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566,1.,0.
+41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909,1.,0.
+42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021,1.,0.
+43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831,1.,0.
+44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905,1.,0.
+45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271,1.,0.
+46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094,1.,0.
+47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554,1.,0.
+48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769,1.,0.
+49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606,1.,0.
+50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629,1.,0.
+51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199,1.,0.
+52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961,1.,0.
+53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122,1.,0.
+54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454,1.,0.
+55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301,1.,0.
+56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182,1.,0.
+57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365,1.,0.
+58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011,1.,0.
+59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449,1.,0.
+60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229,1.,0.
+61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259,1.,0.
+62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272,1.,0.
+63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989,1.,0.
+64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496,1.,0.
+65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376,1.,0.
+66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206,1.,0.
+67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502,1.,0.
+68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219,1.,0.
+69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125,1.,0.
+70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514,1.,0.
+71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166,1.,0.
+72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832,1.,0.
+73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913,1.,0.
+74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188,1.,0.
+75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388,1.,0.
+76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136,1.,0.
+77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766,1.,0.
+78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959,1.,0.
+79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083,1.,0.
+80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483,1.,0.
+81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656,1.,0.
+82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107,1.,0.
+83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991,1.,0.
+84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527,1.,0.
+85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649,1.,0.
+86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788,1.,0.
+87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289,1.,0.
+88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298,1.,0.
+89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873,1.,0.
+90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669,1.,0.
+91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462,1.,0.
+92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232,1.,0.
+93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225,1.,0.
+94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288,1.,0.
+95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086,1.,0.
+96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161,1.,0.
+97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227,1.,0.
+98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937,1.,0.
+99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724,1.,0.
diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py
index 630f4fc057..f37cafcc50 100644
--- a/tensorflow/contrib/timeseries/examples/lstm.py
+++ b/tensorflow/contrib/timeseries/examples/lstm.py
@@ -48,7 +48,8 @@ _DATA_FILE = path.join(_MODULE_PATH, "data/multivariate_periods.csv")
class _LSTMModel(ts_model.SequentialTimeSeriesModel):
"""A time series model-building example using an RNNCell."""
- def __init__(self, num_units, num_features, dtype=tf.float32):
+ def __init__(self, num_units, num_features, exogenous_feature_columns=None,
+ dtype=tf.float32):
"""Initialize/configure the model object.
Note that we do not start graph building here. Rather, this object is a
@@ -58,6 +59,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
num_units: The number of units in the model's LSTMCell.
num_features: The dimensionality of the time series (features per
timestep).
+ exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
+ objects representing features which are inputs to the model but are
+ not predicted by it. These must then be present for training,
+ evaluation, and prediction.
dtype: The floating point data type to use.
"""
super(_LSTMModel, self).__init__(
@@ -65,6 +70,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
train_output_names=["mean"],
predict_output_names=["mean"],
num_features=num_features,
+ exogenous_feature_columns=exogenous_feature_columns,
dtype=dtype)
self._num_units = num_units
# Filled in by initialize_graph()
@@ -104,6 +110,8 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
tf.zeros([], dtype=tf.int64),
# The previous observation or prediction.
tf.zeros([self.num_features], dtype=self.dtype),
+ # The most recently seen exogenous features.
+ tf.zeros(self._get_exogenous_embedding_shape(), dtype=self.dtype),
# The state of the RNNCell (batch dimension removed since this parent
# class will broadcast).
[tf.squeeze(state_element, axis=0)
@@ -131,7 +139,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
loss (note that we could also return other measures of goodness of fit,
although only "loss" will be optimized).
"""
- state_from_time, prediction, lstm_state = state
+ state_from_time, prediction, exogenous, lstm_state = state
with tf.control_dependencies(
[tf.assert_equal(current_times, state_from_time)]):
# Subtract the mean and divide by the variance of the series. Slightly
@@ -143,16 +151,22 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
(prediction - transformed_values) ** 2, axis=-1)
# Keep track of the new observation in model state. It won't be run
# through the LSTM until the next _imputation_step.
- new_state_tuple = (current_times, transformed_values, lstm_state)
+ new_state_tuple = (current_times, transformed_values,
+ exogenous, lstm_state)
return (new_state_tuple, predictions)
def _prediction_step(self, current_times, state):
"""Advance the RNN state using a previous observation or prediction."""
- _, previous_observation_or_prediction, lstm_state = state
+ _, previous_observation_or_prediction, exogenous, lstm_state = state
+ # Update LSTM state based on the most recent exogenous and endogenous
+ # features.
+ inputs = tf.concat([previous_observation_or_prediction, exogenous],
+ axis=-1)
lstm_output, new_lstm_state = self._lstm_cell_run(
- inputs=previous_observation_or_prediction, state=lstm_state)
+ inputs=inputs, state=lstm_state)
next_prediction = self._predict_from_lstm_output(lstm_output)
- new_state_tuple = (current_times, next_prediction, new_lstm_state)
+ new_state_tuple = (current_times, next_prediction,
+ exogenous, new_lstm_state)
return new_state_tuple, {"mean": self._scale_back_data(next_prediction)}
def _imputation_step(self, current_times, state):
@@ -164,9 +178,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
def _exogenous_input_step(
self, current_times, current_exogenous_regressors, state):
- """Update model state based on exogenous regressors."""
- raise NotImplementedError(
- "Exogenous inputs are not implemented for this example.")
+ """Save exogenous regressors in model state for use in _prediction_step."""
+ state_from_time, prediction, _, lstm_state = state
+ return (state_from_time, prediction,
+ current_exogenous_regressors, lstm_state)
def train_and_predict(
@@ -174,24 +189,37 @@ def train_and_predict(
export_directory=None):
"""Train and predict using a custom time series model."""
# Construct an Estimator from our LSTM model.
+ exogenous_feature_columns = [
+ # Exogenous features are not part of the loss, but can inform
+ # predictions. In this example the features have no extra information, but
+ # are included as an API example.
+ tf.contrib.layers.real_valued_column(
+ "2d_exogenous_feature", dimension=2)]
estimator = ts_estimators.TimeSeriesRegressor(
- model=_LSTMModel(num_features=5, num_units=128),
+ model=_LSTMModel(num_features=5, num_units=128,
+ exogenous_feature_columns=exogenous_feature_columns),
optimizer=tf.train.AdamOptimizer(0.001), config=estimator_config,
# Set state to be saved across windows.
state_manager=state_management.ChainingStateManager())
reader = tf.contrib.timeseries.CSVReader(
csv_file_name,
column_names=((tf.contrib.timeseries.TrainEvalFeatures.TIMES,)
- + (tf.contrib.timeseries.TrainEvalFeatures.VALUES,) * 5))
+ + (tf.contrib.timeseries.TrainEvalFeatures.VALUES,) * 5
+ + ("2d_exogenous_feature",) * 2))
train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(
reader, batch_size=4, window_size=32)
estimator.train(input_fn=train_input_fn, steps=training_steps)
evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)
evaluation = estimator.evaluate(input_fn=evaluation_input_fn, steps=1)
# Predict starting after the evaluation
+ predict_exogenous_features = {
+ "2d_exogenous_feature": numpy.concatenate(
+ [numpy.ones([1, 100, 1]), numpy.zeros([1, 100, 1])],
+ axis=-1)}
(predictions,) = tuple(estimator.predict(
input_fn=tf.contrib.timeseries.predict_continuation_input_fn(
- evaluation, steps=100)))
+ evaluation, steps=100,
+ exogenous_features=predict_exogenous_features)))
times = evaluation["times"][0]
observed = evaluation["observed"][0, :, :]
predicted_mean = numpy.squeeze(numpy.concatenate(
@@ -204,7 +232,6 @@ def train_and_predict(
input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
export_location = estimator.export_savedmodel(
export_directory, input_receiver_fn)
-
# Predict using the SavedModel
with tf.Graph().as_default():
with tf.Session() as session:
@@ -213,7 +240,8 @@ def train_and_predict(
saved_model_output = (
tf.contrib.timeseries.saved_model_utils.predict_continuation(
continue_from=evaluation, signatures=signatures,
- session=session, steps=100))
+ session=session, steps=100,
+ exogenous_features=predict_exogenous_features))
# The exported model gives the same results as the Estimator.predict()
# call above.
numpy.testing.assert_allclose(