# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause from collections import namedtuple import functools import os import random import socket import struct from struct import Struct import yaml import ipaddress import uuid from .nlspec import SpecFamily # # Generic Netlink code which should really be in some library, but I can't quickly find one. # class Netlink: # Netlink socket SOL_NETLINK = 270 NETLINK_ADD_MEMBERSHIP = 1 NETLINK_CAP_ACK = 10 NETLINK_EXT_ACK = 11 NETLINK_GET_STRICT_CHK = 12 # Netlink message NLMSG_ERROR = 2 NLMSG_DONE = 3 NLM_F_REQUEST = 1 NLM_F_ACK = 4 NLM_F_ROOT = 0x100 NLM_F_MATCH = 0x200 NLM_F_REPLACE = 0x100 NLM_F_EXCL = 0x200 NLM_F_CREATE = 0x400 NLM_F_APPEND = 0x800 NLM_F_CAPPED = 0x100 NLM_F_ACK_TLVS = 0x200 NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH NLA_F_NESTED = 0x8000 NLA_F_NET_BYTEORDER = 0x4000 NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER # Genetlink defines NETLINK_GENERIC = 16 GENL_ID_CTRL = 0x10 # nlctrl CTRL_CMD_GETFAMILY = 3 CTRL_ATTR_FAMILY_ID = 1 CTRL_ATTR_FAMILY_NAME = 2 CTRL_ATTR_MAXATTR = 5 CTRL_ATTR_MCAST_GROUPS = 7 CTRL_ATTR_MCAST_GRP_NAME = 1 CTRL_ATTR_MCAST_GRP_ID = 2 # Extack types NLMSGERR_ATTR_MSG = 1 NLMSGERR_ATTR_OFFS = 2 NLMSGERR_ATTR_COOKIE = 3 NLMSGERR_ATTR_POLICY = 4 NLMSGERR_ATTR_MISS_TYPE = 5 NLMSGERR_ATTR_MISS_NEST = 6 class NlError(Exception): def __init__(self, nl_msg): self.nl_msg = nl_msg def __str__(self): return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}" class NlAttr: ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little']) type_formats = { 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")), 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")), 'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("h"), Struct("I"), Struct("i"), Struct("Q"), Struct("q"), Struct(">= 1 i += 1 else: value = enum.entries_by_val[raw].name return value def _decode_binary(self, attr, attr_spec): if attr_spec.struct_name: members = self.consts[attr_spec.struct_name] decoded = attr.as_struct(members) for m in members: if m.enum: decoded[m.name] = self._decode_enum(decoded[m.name], m) elif attr_spec.sub_type: decoded = attr.as_c_array(attr_spec.sub_type) else: decoded = attr.as_bin() if attr_spec.display_hint: decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint) return decoded def _decode_array_nest(self, attr, attr_spec): decoded = [] offset = 0 while offset < len(attr.raw): item = NlAttr(attr.raw, offset) offset += item.full_len subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) decoded.append({ item.type: subattrs }) return decoded def _decode(self, attrs, space): attr_space = self.attr_sets[space] rsp = dict() for attr in attrs: try: attr_spec = attr_space.attrs_by_val[attr.type] except KeyError: raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'") if attr_spec["type"] == 'nest': subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes']) decoded = subdict elif attr_spec["type"] == 'string': decoded = attr.as_strz() elif attr_spec["type"] == 'binary': decoded = self._decode_binary(attr, attr_spec) elif attr_spec["type"] == 'flag': decoded = True elif attr_spec["type"] in NlAttr.type_formats: decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) elif attr_spec["type"] == 'array-nest': decoded = self._decode_array_nest(attr, attr_spec) else: raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') if 'enum' in attr_spec: decoded = self._decode_enum(decoded, attr_spec) if not attr_spec.is_multi: rsp[attr_spec['name']] = decoded elif attr_spec.name in rsp: rsp[attr_spec.name].append(decoded) else: rsp[attr_spec.name] = [decoded] return rsp def _decode_extack_path(self, attrs, attr_set, offset, target): for attr in attrs: try: attr_spec = attr_set.attrs_by_val[attr.type] except KeyError: raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") if offset > target: break if offset == target: return '.' + attr_spec.name if offset + attr.full_len <= target: offset += attr.full_len continue if attr_spec['type'] != 'nest': raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") offset += 4 subpath = self._decode_extack_path(NlAttrs(attr.raw), self.attr_sets[attr_spec['nested-attributes']], offset, target) if subpath is None: return None return '.' + attr_spec.name + subpath return None def _decode_extack(self, request, op, extack): if 'bad-attr-offs' not in extack: return msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set)) offset = 20 + self._fixed_header_size(op) path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, extack['bad-attr-offs']) if path: del extack['bad-attr-offs'] extack['bad-attr'] = path def _fixed_header_size(self, op): if op.fixed_header: fixed_header_members = self.consts[op.fixed_header].members size = 0 for m in fixed_header_members: format = NlAttr.get_format(m.type, m.byte_order) size += format.size return size else: return 0 def _decode_fixed_header(self, msg, name): fixed_header_members = self.consts[name].members fixed_header_attrs = dict() offset = 0 for m in fixed_header_members: format = NlAttr.get_format(m.type, m.byte_order) [ value ] = format.unpack_from(msg.raw, offset) offset += format.size if m.enum: value = self._decode_enum(value, m) fixed_header_attrs[m.name] = value return fixed_header_attrs def handle_ntf(self, decoded): msg = dict() if self.include_raw: msg['raw'] = decoded op = self.rsp_by_value[decoded.cmd()] attrs = self._decode(decoded.raw_attrs, op.attr_set.name) if op.fixed_header: attrs.update(self._decode_fixed_header(decoded, op.fixed_header)) msg['name'] = op['name'] msg['msg'] = attrs self.async_msg_queue.append(msg) def check_ntf(self): while True: try: reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT) except BlockingIOError: return nms = NlMsgs(reply) for nl_msg in nms: if nl_msg.error: print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) print(nl_msg) continue if nl_msg.done: print("Netlink done while checking for ntf!?") continue decoded = self.nlproto.decode(self, nl_msg) if decoded.cmd() not in self.async_msg_ids: print("Unexpected msg id done while checking for ntf", decoded) continue self.handle_ntf(decoded) def operation_do_attributes(self, name): """ For a given operation name, find and return a supported set of attributes (as a dict). """ op = self.find_operation(name) if not op: return None return op['do']['request']['attributes'].copy() def _op(self, method, vals, flags, dump=False): op = self.ops[method] nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK for flag in flags or []: nl_flags |= flag if dump: nl_flags |= Netlink.NLM_F_DUMP req_seq = random.randint(1024, 65535) msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) fixed_header_members = [] if op.fixed_header: fixed_header_members = self.consts[op.fixed_header].members for m in fixed_header_members: value = vals.pop(m.name) if m.name in vals else 0 format = NlAttr.get_format(m.type, m.byte_order) msg += format.pack(value) for name, value in vals.items(): msg += self._add_attr(op.attr_set.name, name, value) msg = _genl_msg_finalize(msg) self.sock.send(msg, 0) done = False rsp = [] while not done: reply = self.sock.recv(128 * 1024) nms = NlMsgs(reply, attr_space=op.attr_set) for nl_msg in nms: if nl_msg.extack: self._decode_extack(msg, op, nl_msg.extack) if nl_msg.error: raise NlError(nl_msg) if nl_msg.done: if nl_msg.extack: print("Netlink warning:") print(nl_msg) done = True break decoded = self.nlproto.decode(self, nl_msg) # Check if this is a reply to our request if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value: if decoded.cmd() in self.async_msg_ids: self.handle_ntf(decoded) continue else: print('Unexpected message: ' + repr(decoded)) continue rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) if op.fixed_header: rsp_msg.update(self._decode_fixed_header(decoded, op.fixed_header)) rsp.append(rsp_msg) if not rsp: return None if not dump and len(rsp) == 1: return rsp[0] return rsp def do(self, method, vals, flags): return self._op(method, vals, flags) def dump(self, method, vals): return self._op(method, vals, [], dump=True)