#!/usr/bin/python3

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
'''Mirrors the host's IP subnets into nftables sets.

Fills the `local_subnets4` and `local_subnets6` sets of the `inet filter`
table from the addresses reported over netlink, then keeps them updated as
RTM_NEWADDR / RTM_DELADDR events arrive.
'''

from socket import AF_INET, AF_INET6

import nftables
import pyroute2


def send_command(nft, s):
    '''Sends a command to nftables.

    Sends the specified command string to the specified nftables instance.
    Waits for acknowledgement. The command is echoed to stdout before it is
    run.

    Args:
      nft: object whose cmd(str) method returns a 3-tuple of
        (failed, stdout, stderr); presumably an nftables.Nftables.
      s: the nftables command string to run.

    Raises:
      RuntimeError: if the command fails.
    '''
    print(s)
    # Only the failure flag and the error text are of interest here.
    failed, _stdout, stderr = nft.cmd(s)
    if failed:
        raise RuntimeError('nftables command failed: %s' % stderr)


def set_name(family):
    '''Generates the relevant nftables set name for the specified protocol.

    Args:
      family: address family, either AF_INET or AF_INET6.

    Returns:
      'local_subnets4' for AF_INET, 'local_subnets6' for AF_INET6.

    Raises:
      ValueError: if the family is neither AF_INET nor AF_INET6.
    '''
    if family == AF_INET:
        return 'local_subnets4'
    if family == AF_INET6:
        return 'local_subnets6'
    raise ValueError("don't have a set for family %d" % family)


def flush(family):
    '''Creates an nftables command to flush the local subnet set.

    Args:
      family: address family, either AF_INET or AF_INET6.

    Returns:
      The 'flush set' command string for that family's set.

    Raises:
      ValueError: if there is no set for the family.
    '''
    return 'flush set inet filter %s' % set_name(family)


def add(msg):
    '''Creates an nftables command to add a subnet.

    Processes a netlink message to extract its subnet, and creates an nftables
    command to add that subnet to the appropriate set.

    Args:
      msg: mapping with a 'family' entry, a 'prefixlen' entry, and an 'attrs'
        entry that iterates as (name, value) pairs.

    Returns:
      The 'add element' command string for the message's subnet.

    Raises:
      ValueError: if the message has no 'IFA_ADDRESS' attribute, or if there
        is no set for the message's family.
    '''
    for k, v in msg['attrs']:
        if k == 'IFA_ADDRESS':
            # The first IFA_ADDRESS attribute wins.
            subnet = v + '/' + str(msg['prefixlen'])
            break
    else:
        raise ValueError("message lacks 'IFA_ADDRESS' attribute")
    return 'add element inet filter %s { %s }' % (set_name(
        msg['family']), subnet)


def add_current(netlink, family=None):
    '''Creates an nftables command to add all current subnets.

    Args:
      netlink: object whose get_addr() method yields address messages
        accepted by add(); presumably a pyroute2.IPRoute.
      family: if truthy, only messages of this address family are included;
        otherwise every message is included.

    Returns:
      The individual 'add element' commands joined with '; '. This is the
      empty string when no message matches.
    '''
    return '; '.join(
        add(msg) for msg in netlink.get_addr()
        if not family or msg['family'] == family)


def maintain_sets(nft, netlink):
    '''Populates the local subnet sets and keeps them up to date.

    Loops forever, so this only stops by raising.

    Args:
      nft: nftables instance passed through to send_command().
      netlink: object providing bind(), get() and get_addr(); presumably a
        pyroute2.IPRoute.

    Raises:
      RuntimeError: if an nftables command fails.
      ValueError: if an address message cannot be turned into a command.
    '''
    # Start listening for netlink messages now so we don't miss any while
    # we're querying the current IP address set.
    netlink.bind()

    # Populate the sets with the initial set of IP addresses.
    send_command(
        nft,
        '; '.join([flush(AF_INET),
                   flush(AF_INET6),
                   add_current(netlink)]))

    # Watch for changes.
    while True:
        for msg in netlink.get():
            if msg['event'] == 'RTM_NEWADDR':
                send_command(nft, add(msg))
            elif msg['event'] == 'RTM_DELADDR':
                # We might still be holding another address on the same subnet,
                # so we can't just delete this subnet. Rebuild the subnet set
                # from scratch.
                fam = msg['family']
                send_command(nft,
                             flush(fam) + '; ' + add_current(netlink, fam))


def main():
    '''Runs the set maintenance loop against the live nftables and netlink.'''
    nft = nftables.Nftables()
    with pyroute2.IPRoute() as netlink:
        maintain_sets(nft, netlink)


if __name__ == '__main__':
    main()