diff options
Diffstat (limited to 'tools/net/ynl/lib/ynl.py')
| -rw-r--r-- | tools/net/ynl/lib/ynl.py | 311 | 
1 files changed, 216 insertions, 95 deletions
diff --git a/tools/net/ynl/lib/ynl.py b/tools/net/ynl/lib/ynl.py index 1e10512b2117..5fa7957f6e0f 100644 --- a/tools/net/ynl/lib/ynl.py +++ b/tools/net/ynl/lib/ynl.py @@ -7,6 +7,7 @@ import random  import socket  import struct  from struct import Struct +import sys  import yaml  import ipaddress  import uuid @@ -84,6 +85,10 @@ class NlError(Exception):      return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}" +class ConfigError(Exception): +    pass + +  class NlAttr:      ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])      type_formats = { @@ -113,20 +118,6 @@ class NlAttr:                  else format.little          return format.native -    @classmethod -    def formatted_string(cls, raw, display_hint): -        if display_hint == 'mac': -            formatted = ':'.join('%02x' % b for b in raw) -        elif display_hint == 'hex': -            formatted = bytes.hex(raw, ' ') -        elif display_hint in [ 'ipv4', 'ipv6' ]: -            formatted = format(ipaddress.ip_address(raw)) -        elif display_hint == 'uuid': -            formatted = str(uuid.UUID(bytes=raw)) -        else: -            formatted = raw -        return formatted -      def as_scalar(self, attr_type, byte_order=None):          format = self.get_format(attr_type, byte_order)          return format.unpack(self.raw)[0] @@ -148,23 +139,6 @@ class NlAttr:          format = self.get_format(type)          return [ x[0] for x in format.iter_unpack(self.raw) ] -    def as_struct(self, members): -        value = dict() -        offset = 0 -        for m in members: -            # TODO: handle non-scalar members -            if m.type == 'binary': -                decoded = self.raw[offset : offset + m['len']] -                offset += m['len'] -            elif m.type in NlAttr.type_formats: -                format = self.get_format(m.type, m.byte_order) -                [ decoded ] = format.unpack_from(self.raw, offset) -                offset += format.size -            if m.display_hint: -                decoded = self.formatted_string(decoded, m.display_hint) -            value[m.name] = decoded -        return value -      def __repr__(self):          return f"[type:{self.type} len:{self._len}] {self.raw}" @@ -244,11 +218,11 @@ class NlMsg:          return self.nl_type      def __repr__(self): -        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n" +        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"          if self.error: -            msg += '\terror: ' + str(self.error) +            msg += '\n\terror: ' + str(self.error)          if self.extack: -            msg += '\textack: ' + repr(self.extack) +            msg += '\n\textack: ' + repr(self.extack)          return msg @@ -370,7 +344,7 @@ class NetlinkProtocol:          fixed_header_size = 0          if ynl:              op = ynl.rsp_by_value[msg.cmd()] -            fixed_header_size = ynl._fixed_header_size(op.fixed_header) +            fixed_header_size = ynl._struct_size(op.fixed_header)          msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)          return msg @@ -379,6 +353,9 @@ class NetlinkProtocol:              raise Exception(f'Multicast group "{mcast_name}" not present in the spec')          return mcast_groups[mcast_name].value +    def msghdr_size(self): +        return 16 +  class GenlProtocol(NetlinkProtocol):      def __init__(self, family_name): @@ -404,6 +381,28 @@ class GenlProtocol(NetlinkProtocol):              raise Exception(f'Multicast group "{mcast_name}" not present in the family')          return self.genl_family['mcast'][mcast_name] +    def msghdr_size(self): +        return super().msghdr_size() + 4 + + +class SpaceAttrs: +    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values']) + +    def __init__(self, attr_space, attrs, outer = None): +        outer_scopes = outer.scopes if outer else [] +        inner_scope = self.SpecValuesPair(attr_space, attrs) +        self.scopes = [inner_scope] + outer_scopes + +    def lookup(self, name): +        for scope in self.scopes: +            if name in scope.spec: +                if name in scope.values: +                    return scope.values[name] +                spec_name = scope.spec.yaml['name'] +                raise Exception( +                    f"No value for '{name}' in attribute space '{spec_name}'") +        raise Exception(f"Attribute '{name}' not defined in any attribute-set") +  #  # YNL implementation details. @@ -411,7 +410,8 @@ class GenlProtocol(NetlinkProtocol):  class YnlFamily(SpecFamily): -    def __init__(self, def_path, schema=None, process_unknown=False): +    def __init__(self, def_path, schema=None, process_unknown=False, +                 recv_size=0):          super().__init__(def_path, schema)          self.include_raw = False @@ -426,6 +426,17 @@ class YnlFamily(SpecFamily):          except KeyError:              raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") +        self._recv_dbg = False +        # Note that netlink will use conservative (min) message size for +        # the first dump recv() on the socket, our setting will only matter +        # from the second recv() on. +        self._recv_size = recv_size if recv_size else 131072 +        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo) +        # for a message, so smaller receive sizes will lead to truncation. +        # Note that the min size for other families may be larger than 4k! +        if self._recv_size < 4000: +            raise ConfigError() +          self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)          self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)          self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) @@ -449,18 +460,61 @@ class YnlFamily(SpecFamily):          self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,                               mcast_id) -    def _add_attr(self, space, name, value): +    def set_recv_dbg(self, enabled): +        self._recv_dbg = enabled + +    def _recv_dbg_print(self, reply, nl_msgs): +        if not self._recv_dbg: +            return +        print("Recv: read", len(reply), "bytes,", +              len(nl_msgs.msgs), "messages", file=sys.stderr) +        for nl_msg in nl_msgs: +            print("  ", nl_msg, file=sys.stderr) + +    def _encode_enum(self, attr_spec, value): +        enum = self.consts[attr_spec['enum']] +        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): +            scalar = 0 +            if isinstance(value, str): +                value = [value] +            for single_value in value: +                scalar += enum.entries[single_value].user_value(as_flags = True) +            return scalar +        else: +            return enum.entries[value].user_value() + +    def _get_scalar(self, attr_spec, value): +        try: +            return int(value) +        except (ValueError, TypeError) as e: +            if 'enum' not in attr_spec: +                raise e +        return self._encode_enum(attr_spec, value) + +    def _add_attr(self, space, name, value, search_attrs):          try:              attr = self.attr_sets[space][name]          except KeyError:              raise Exception(f"Space '{space}' has no attribute '{name}'")          nl_type = attr.value + +        if attr.is_multi and isinstance(value, list): +            attr_payload = b'' +            for subvalue in value: +                attr_payload += self._add_attr(space, name, subvalue, search_attrs) +            return attr_payload +          if attr["type"] == 'nest':              nl_type |= Netlink.NLA_F_NESTED              attr_payload = b'' +            sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs)              for subname, subvalue in value.items(): -                attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue) +                attr_payload += self._add_attr(attr['nested-attributes'], +                                               subname, subvalue, sub_attrs)          elif attr["type"] == 'flag': +            if not value: +                # If value is absent or false then skip attribute creation. +                return b''              attr_payload = b''          elif attr["type"] == 'string':              attr_payload = str(value).encode('ascii') + b'\x00' @@ -469,18 +523,36 @@ class YnlFamily(SpecFamily):                  attr_payload = value              elif isinstance(value, str):                  attr_payload = bytes.fromhex(value) +            elif isinstance(value, dict) and attr.struct_name: +                attr_payload = self._encode_struct(attr.struct_name, value)              else:                  raise Exception(f'Unknown type for binary attribute, value: {value}') -        elif attr.is_auto_scalar: -            scalar = int(value) -            real_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') -            format = NlAttr.get_format(real_type, attr.byte_order) -            attr_payload = format.pack(int(value)) -        elif attr['type'] in NlAttr.type_formats: -            format = NlAttr.get_format(attr['type'], attr.byte_order) -            attr_payload = format.pack(int(value)) +        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar: +            scalar = self._get_scalar(attr, value) +            if attr.is_auto_scalar: +                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') +            else: +                attr_type = attr["type"] +            format = NlAttr.get_format(attr_type, attr.byte_order) +            attr_payload = format.pack(scalar)          elif attr['type'] in "bitfield32": -            attr_payload = struct.pack("II", int(value["value"]), int(value["selector"])) +            scalar_value = self._get_scalar(attr, value["value"]) +            scalar_selector = self._get_scalar(attr, value["selector"]) +            attr_payload = struct.pack("II", scalar_value, scalar_selector) +        elif attr['type'] == 'sub-message': +            msg_format = self._resolve_selector(attr, search_attrs) +            attr_payload = b'' +            if msg_format.fixed_header: +                attr_payload += self._encode_struct(msg_format.fixed_header, value) +            if msg_format.attr_set: +                if msg_format.attr_set in self.attr_sets: +                    nl_type |= Netlink.NLA_F_NESTED +                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs) +                    for subname, subvalue in value.items(): +                        attr_payload += self._add_attr(msg_format.attr_set, +                                                       subname, subvalue, sub_attrs) +                else: +                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")          else:              raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') @@ -503,17 +575,13 @@ class YnlFamily(SpecFamily):      def _decode_binary(self, attr, attr_spec):          if attr_spec.struct_name: -            members = self.consts[attr_spec.struct_name] -            decoded = attr.as_struct(members) -            for m in members: -                if m.enum: -                    decoded[m.name] = self._decode_enum(decoded[m.name], m) +            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)          elif attr_spec.sub_type:              decoded = attr.as_c_array(attr_spec.sub_type)          else:              decoded = attr.as_bin()              if attr_spec.display_hint: -                decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint) +                decoded = self._formatted_string(decoded, attr_spec.display_hint)          return decoded      def _decode_array_nest(self, attr, attr_spec): @@ -527,6 +595,16 @@ class YnlFamily(SpecFamily):              decoded.append({ item.type: subattrs })          return decoded +    def _decode_nest_type_value(self, attr, attr_spec): +        decoded = {} +        value = attr +        for name in attr_spec['type-value']: +            value = NlAttr(value.raw, 0) +            decoded[name] = value.type +        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes']) +        decoded.update(subattrs) +        return decoded +      def _decode_unknown(self, attr):          if attr.is_nest:              return self._decode(NlAttrs(attr.raw), None) @@ -548,29 +626,27 @@ class YnlFamily(SpecFamily):          else:              rsp[name] = [decoded] -    def _resolve_selector(self, attr_spec, vals): +    def _resolve_selector(self, attr_spec, search_attrs):          sub_msg = attr_spec.sub_message          if sub_msg not in self.sub_msgs:              raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")          sub_msg_spec = self.sub_msgs[sub_msg]          selector = attr_spec.selector -        if selector not in vals: -            raise Exception(f"There is no value for {selector} to resolve '{attr_spec.name}'") -        value = vals[selector] +        value = search_attrs.lookup(selector)          if value not in sub_msg_spec.formats:              raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")          spec = sub_msg_spec.formats[value]          return spec -    def _decode_sub_msg(self, attr, attr_spec, rsp): -        msg_format = self._resolve_selector(attr_spec, rsp) +    def _decode_sub_msg(self, attr, attr_spec, search_attrs): +        msg_format = self._resolve_selector(attr_spec, search_attrs)          decoded = {}          offset = 0          if msg_format.fixed_header: -            decoded.update(self._decode_fixed_header(attr, msg_format.fixed_header)); -            offset = self._fixed_header_size(msg_format.fixed_header) +            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header)); +            offset = self._struct_size(msg_format.fixed_header)          if msg_format.attr_set:              if msg_format.attr_set in self.attr_sets:                  subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) @@ -579,10 +655,12 @@ class YnlFamily(SpecFamily):                  raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")          return decoded -    def _decode(self, attrs, space): +    def _decode(self, attrs, space, outer_attrs = None): +        rsp = dict()          if space:              attr_space = self.attr_sets[space] -        rsp = dict() +            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs) +          for attr in attrs:              try:                  attr_spec = attr_space.attrs_by_val[attr.type] @@ -594,7 +672,7 @@ class YnlFamily(SpecFamily):                  continue              if attr_spec["type"] == 'nest': -                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes']) +                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)                  decoded = subdict              elif attr_spec["type"] == 'string':                  decoded = attr.as_strz() @@ -617,7 +695,9 @@ class YnlFamily(SpecFamily):                      selector = self._decode_enum(selector, attr_spec)                  decoded = {"value": value, "selector": selector}              elif attr_spec["type"] == 'sub-message': -                decoded = self._decode_sub_msg(attr, attr_spec, rsp) +                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs) +            elif attr_spec["type"] == 'nest-type-value': +                decoded = self._decode_nest_type_value(attr, attr_spec)              else:                  if not self.process_unknown:                      raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') @@ -658,20 +738,23 @@ class YnlFamily(SpecFamily):              return          msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set)) -        offset = 20 + self._fixed_header_size(op.fixed_header) +        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)          path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,                                          extack['bad-attr-offs'])          if path:              del extack['bad-attr-offs']              extack['bad-attr'] = path -    def _fixed_header_size(self, name): +    def _struct_size(self, name):          if name: -            fixed_header_members = self.consts[name].members +            members = self.consts[name].members              size = 0 -            for m in fixed_header_members: +            for m in members:                  if m.type in ['pad', 'binary']: -                    size += m.len +                    if m.struct: +                        size += self._struct_size(m.struct) +                    else: +                        size += m.len                  else:                      format = NlAttr.get_format(m.type, m.byte_order)                      size += format.size @@ -679,26 +762,71 @@ class YnlFamily(SpecFamily):          else:              return 0 -    def _decode_fixed_header(self, msg, name): -        fixed_header_members = self.consts[name].members -        fixed_header_attrs = dict() +    def _decode_struct(self, data, name): +        members = self.consts[name].members +        attrs = dict()          offset = 0 -        for m in fixed_header_members: +        for m in members:              value = None              if m.type == 'pad':                  offset += m.len              elif m.type == 'binary': -                value = msg.raw[offset : offset + m.len] -                offset += m.len +                if m.struct: +                    len = self._struct_size(m.struct) +                    value = self._decode_struct(data[offset : offset + len], +                                                m.struct) +                    offset += len +                else: +                    value = data[offset : offset + m.len] +                    offset += m.len              else:                  format = NlAttr.get_format(m.type, m.byte_order) -                [ value ] = format.unpack_from(msg.raw, offset) +                [ value ] = format.unpack_from(data, offset)                  offset += format.size              if value is not None:                  if m.enum:                      value = self._decode_enum(value, m) -                fixed_header_attrs[m.name] = value -        return fixed_header_attrs +                elif m.display_hint: +                    value = self._formatted_string(value, m.display_hint) +                attrs[m.name] = value +        return attrs + +    def _encode_struct(self, name, vals): +        members = self.consts[name].members +        attr_payload = b'' +        for m in members: +            value = vals.pop(m.name) if m.name in vals else None +            if m.type == 'pad': +                attr_payload += bytearray(m.len) +            elif m.type == 'binary': +                if m.struct: +                    if value is None: +                        value = dict() +                    attr_payload += self._encode_struct(m.struct, value) +                else: +                    if value is None: +                        attr_payload += bytearray(m.len) +                    else: +                        attr_payload += bytes.fromhex(value) +            else: +                if value is None: +                    value = 0 +                format = NlAttr.get_format(m.type, m.byte_order) +                attr_payload += format.pack(value) +        return attr_payload + +    def _formatted_string(self, raw, display_hint): +        if display_hint == 'mac': +            formatted = ':'.join('%02x' % b for b in raw) +        elif display_hint == 'hex': +            formatted = bytes.hex(raw, ' ') +        elif display_hint in [ 'ipv4', 'ipv6' ]: +            formatted = format(ipaddress.ip_address(raw)) +        elif display_hint == 'uuid': +            formatted = str(uuid.UUID(bytes=raw)) +        else: +            formatted = raw +        return formatted      def handle_ntf(self, decoded):          msg = dict() @@ -707,7 +835,7 @@ class YnlFamily(SpecFamily):          op = self.rsp_by_value[decoded.cmd()]          attrs = self._decode(decoded.raw_attrs, op.attr_set.name)          if op.fixed_header: -            attrs.update(self._decode_fixed_header(decoded, op.fixed_header)) +            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))          msg['name'] = op['name']          msg['msg'] = attrs @@ -716,11 +844,12 @@ class YnlFamily(SpecFamily):      def check_ntf(self):          while True:              try: -                reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT) +                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)              except BlockingIOError:                  return              nms = NlMsgs(reply) +            self._recv_dbg_print(reply, nms)              for nl_msg in nms:                  if nl_msg.error:                      print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) @@ -759,20 +888,11 @@ class YnlFamily(SpecFamily):          req_seq = random.randint(1024, 65535)          msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) -        fixed_header_members = []          if op.fixed_header: -            fixed_header_members = self.consts[op.fixed_header].members -            for m in fixed_header_members: -                value = vals.pop(m.name) if m.name in vals else 0 -                if m.type == 'pad': -                    msg += bytearray(m.len) -                elif m.type == 'binary': -                    msg += bytes.fromhex(value) -                else: -                    format = NlAttr.get_format(m.type, m.byte_order) -                    msg += format.pack(value) +            msg += self._encode_struct(op.fixed_header, vals) +        search_attrs = SpaceAttrs(op.attr_set, vals)          for name, value in vals.items(): -            msg += self._add_attr(op.attr_set.name, name, value) +            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)          msg = _genl_msg_finalize(msg)          self.sock.send(msg, 0) @@ -780,8 +900,9 @@ class YnlFamily(SpecFamily):          done = False          rsp = []          while not done: -            reply = self.sock.recv(128 * 1024) +            reply = self.sock.recv(self._recv_size)              nms = NlMsgs(reply, attr_space=op.attr_set) +            self._recv_dbg_print(reply, nms)              for nl_msg in nms:                  if nl_msg.extack:                      self._decode_extack(msg, op, nl_msg.extack) @@ -808,7 +929,7 @@ class YnlFamily(SpecFamily):                  rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)                  if op.fixed_header: -                    rsp_msg.update(self._decode_fixed_header(decoded, op.fixed_header)) +                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))                  rsp.append(rsp_msg)          if not rsp:  | 
