Commit c130ece5 authored by Alberts S's avatar Alberts S
Browse files

Use Route class in Router.py for easier management

parent 501371ce
import asyncio
import ipaddress
import json
from typing import List, Optional
import asyncssh
......@@ -7,6 +9,32 @@ from CapybaraNetty import CapybaraNetty
from Vty import Vty
class Route:
def __init__(self, prefix: ipaddress.ip_network, gateway: ipaddress.ip_address, interface: str):
self.prefix = ipaddress.ip_network(prefix)
self.gateway = ipaddress.ip_address(gateway)
self.interface = interface
def as_dict(self) -> dict:
return {
"prefix": str(self.prefix),
"gateway": str(self.gateway),
"interface": self.interface,
}
def __str__(self):
return f"{str(self.prefix):<16s} via {str(self.gateway):<16s} dev {self.interface:<20s}"
def __hash__(self):
return hash(json.dumps(self.as_dict(), sort_keys=True, default=str))
def __eq__(self, other):
if not isinstance(other, type(self)):
return NotImplemented
else:
return self.as_dict() == other.as_dict()
class Router(CapybaraNetty):
managed_routes = None
......@@ -60,7 +88,7 @@ class Router(CapybaraNetty):
return self
async def get_routes(self):
async def get_routes(self) -> list:
self.routes = await self.vty.get_routes()
return self.routes
......@@ -93,17 +121,18 @@ class Router(CapybaraNetty):
self.default_gateway = default_route["nexthops"][0]["ip"]
return self.default_gateway
async def add_routes(self, routes):
async def add_routes(self, routes: List[Route]):
route: Route
for route in routes:
self.add_managed_route_info(route)
await self.vty.add_routes(routes)
await self.vty.add_routes([route.as_dict() for route in routes])
if self.debug_assert_routes:
await self.assert_route_state()
async def del_routes(self, routes):
async def del_routes(self, routes: List[Route]):
for route in routes:
self.del_managed_route_info(route)
stdout, stderr = await self.vty.del_routes(routes)
stdout, stderr = await self.vty.del_routes([route.as_dict() for route in routes])
if self.debug_assert_routes:
await self.assert_route_state(must_not_exist=routes)
......@@ -125,16 +154,11 @@ class Router(CapybaraNetty):
def del_managed_interfaces_info(self):
self.managed_interfaces = {}
def add_managed_route_info(self, route):
self.managed_routes.update({route["prefix"]: route})
def add_managed_route_info(self, route: Route):
self.managed_routes.update({route.prefix: route})
# Adds additional attribute ext_info to managed_route
# Useful to store route specific information
def add_managed_route_info_supplement(self, route_prefix, supplemented_data):
self.managed_routes |= {route_prefix: {"ext_info": supplemented_data}}
def del_managed_route_info(self, route):
del self.managed_routes[route["prefix"]]
def del_managed_route_info(self, route: Route):
del self.managed_routes[route.prefix]
def del_managed_routes_info(self):
self.managed_routes = {}
......@@ -175,24 +199,25 @@ class Router(CapybaraNetty):
# By default checks if our managed routes - exist on the router
# if must_not_exist is passed - will check for the on-existence of passed routes
async def assert_route_state(self, must_not_exist=None):
async def assert_route_state(self, must_not_exist: Optional[List[Route]] = None):
if must_not_exist is None:
must_not_exist = []
await self.get_routes()
for k, v in self.managed_routes.items():
for prefix in self.managed_routes.keys():
try:
self.__logger.debug(f"{self.log_name()} Asserting existence of {v['prefix']}")
assert any(route["prefix"] == v["prefix"] for route in self.routes)
self.__logger.debug(f"{self.log_name()} Asserting existence of {prefix}")
assert any(route["prefix"] == str(prefix) for route in self.routes)
except AssertionError:
self.__logger.error(f"{self.log_name()} Route {v['prefix']} is not setup")
self.__logger.error(f"{self.log_name()} Route {prefix} is not setup")
raise
for v in must_not_exist:
for managed_route in must_not_exist:
try:
self.__logger.debug(f"{self.log_name()} Asserting nonexistence of {v['prefix']}")
assert not any(route["prefix"] == v["prefix"] for route in self.routes)
self.__logger.debug(f"{self.log_name()} Asserting nonexistence of {managed_route.prefix}")
assert not any(route["prefix"] == str(managed_route.prefix) for route in self.routes)
except AssertionError:
self.__logger.error(f"{self.log_name()} Route {v['prefix']} is setup, but shouldn't be")
self.__logger.error(f"{self.log_name()} Route {managed_route.prefix} is setup, but shouldn't be")
raise
async def ssh_exec(self, command, **kwargs):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment