summaryrefslogtreecommitdiff
path: root/tools/testing/selftests/net/rds/test.py
blob: 4a7178d11193f1a51a9871e58df0ae1bd6db6e44 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
#! /usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0

import argparse
import ctypes
import errno
import hashlib
import os
import select
import signal
import socket
import subprocess
import sys
import atexit
from pwd import getpwuid
from os import stat

# Allow utils module to be imported from different directory
# (the parent of this script's directory is appended to sys.path so that
# the `lib.py.utils` import below can be resolved from there).
this_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(this_dir, "../"))
from lib.py.utils import ip

# Raw libc handle; it is only used to reach setns(), which netns_socket()
# calls in a forked child to switch to another network namespace.
libc = ctypes.cdll.LoadLibrary('libc.so.6')
setns = libc.setns

# Names of the two network namespaces created by this test.
net0 = 'net0'
net1 = 'net1'

# Names of the two interfaces that get moved into net0 and net1.
# NOTE(review): `ip link add type veth` is issued further down without
# explicit names, so these presumably have to match the names the kernel
# picks for the new pair -- confirm.
veth0 = 'veth0'
veth1 = 'veth1'

# Helper function for creating a socket inside a network namespace.
# We need this because otherwise RDS will detect that the two TCP
# sockets are on the same interface and use the loop transport instead
# of the TCP transport.
def netns_socket(netns, *args):
    """Create a socket inside the network namespace *netns*.

    A child process is forked, switched into the namespace with setns()
    and creates the socket there.  The descriptor is then handed back to
    the parent over an AF_UNIX socketpair.

    Args:
        netns: name of a namespace file under /var/run/netns.
        *args: positional arguments for socket.socket() (family, type, ...).

    Returns:
        A socket.socket object for the socket created in the namespace.

    Raises:
        RuntimeError: if the child could not enter the namespace or could
            not create and send the socket.
    """
    u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)

    child = os.fork()
    if child == 0:
        # The child must never fall back into the caller's code and must
        # not flush stdio buffers inherited from the parent, so it always
        # leaves through os._exit().
        status = 1
        try:
            u1.close()

            # change network namespace
            with open(f'/var/run/netns/{netns}') as f:
                # setns is a plain ctypes function: it reports failure
                # through its return value and never raises.
                if setns(f.fileno(), 0) != 0:
                    raise OSError(f'setns() into {netns} failed')

            # create socket in target namespace
            s = socket.socket(*args)

            # send resulting socket to parent
            socket.send_fds(u0, [], [s.fileno()])
            status = 0
        except Exception as e:
            print(f'netns_socket({netns}): {e}', file=sys.stderr)
            sys.stderr.flush()
        finally:
            os._exit(status)

    # Drop the parent's copy of the sending end, so that recv_fds() sees
    # end-of-file instead of blocking forever if the child dies early.
    u0.close()
    try:
        # receive socket from child
        _, fds, _, _ = socket.recv_fds(u1, 0, 1)
    finally:
        u1.close()
        os.waitpid(child, 0)

    if not fds:
        raise RuntimeError(f'could not create socket in netns {netns}')

    try:
        return socket.fromfd(fds[0], *args)
    finally:
        # fromfd() duplicates the descriptor, so release the received one.
        os.close(fds[0])

def signal_handler(sig, frame):
    """Report that the test ran into its timeout and exit with status 1.

    The signature is the one the signal module uses for handlers; neither
    the signal number nor the stack frame is looked at.
    """
    print('Test timed out')
    raise SystemExit(1)

# Command line handling.  Everything is optional: a directory for the
# logs, a timeout for hung runs, and integer percentages for simulated
# tcp packet loss, corruption and duplication.
parser = argparse.ArgumentParser(
    description="init script args",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("-d", "--logdir", action="store", default="/tmp",
                    help="directory to store logs")
parser.add_argument('--timeout', type=int, default=0,
                    help="timeout to terminate hung test")

# The three impairment knobs only differ in their flags and help text.
for flags, help_text in (
    (('-l', '--loss'), "Simulate tcp packet loss"),
    (('-c', '--corruption'), "Simulate tcp packet corruption"),
    (('-u', '--duplicate'), "Simulate tcp packet duplication"),
):
    parser.add_argument(*flags, help=help_text, type=int, default=0)

args = parser.parse_args()

logdir = args.logdir
# Keep the impairment values as strings with a trailing '%' sign.
packet_loss = f'{args.loss}%'
packet_corruption = f'{args.corruption}%'
packet_duplicate = f'{args.duplicate}%'

# Create both namespaces, then a veth pair.
for ns in (net0, net1):
    ip(f"netns add {ns}")
ip("link add type veth")

addrs = [
    # we technically don't need different port numbers, but this will
    # help identify traffic in the network analyzer
    ('10.0.0.1', 10000),
    ('10.0.0.2', 20000),
]

# Namespace/interface pairs, in the same order as addrs.
links = [(net0, veth0), (net1, veth1)]

# move interfaces to separate namespaces so they can no longer be
# bound directly; this prevents rds from switching over from the tcp
# transport to the loop transport.
for ns, dev in links:
    ip(f"link set {dev} netns {ns} up")

# add addresses
for (ns, dev), (host, _port) in zip(links, addrs):
    ip(f"-n {ns} addr add {host}/32 dev {dev}")

# add routes: each side gets a host route to the other side's address
for (ns, dev), (peer, _port) in zip(links, reversed(addrs)):
    ip(f"-n {ns} route add {peer}/32 dev {dev}")

# sanity check that our two interfaces/addresses are correctly set up
# and communicating by doing a single ping
ip(f"netns exec {net0} ping -c 1 {addrs[1][0]}")

# Start a packet capture on each network; every capture gets its own
# forked child process.
for net in [net0, net1]:
    tcpdump_pid = os.fork()
    if tcpdump_pid != 0:
        # parent: carry on with the next namespace
        continue

    # child: make sure the capture file exists before looking up its owner
    pcap = f'{logdir}/{net}.pcap'
    subprocess.check_call(['touch', pcap])

    # hand the name of the file's owner to tcpdump's -Z option
    owner = getpwuid(stat(pcap).st_uid).pw_name
    ip(f"netns exec {net} /usr/sbin/tcpdump -Z {owner} -i any -w {pcap}")
    sys.exit(0)

# simulate packet loss, duplication and corruption
# A netem root qdisc is added on each namespace's interface, using the
# percentage strings built from the --corruption, --loss and --duplicate
# options.  The backslash continuations sit inside the string literal, so
# the indentation of the continued lines ends up in the command string.
for net, iface in [(net0, veth0), (net1, veth1)]:
    ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem  \
         corrupt {packet_corruption} loss {packet_loss} duplicate  \
         {packet_duplicate}")

# add a timeout
if args.timeout > 0:
    # Install the handler before arming the alarm, so that SIGALRM can
    # never arrive while the default action (which ends the process
    # without the 'Test timed out' message) is still in place.
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(args.timeout)

# One RDS socket per namespace, in the same order as addrs.
sockets = [
    netns_socket(ns_name, socket.AF_RDS, socket.SOCK_SEQPACKET)
    for ns_name in (net0, net1)
]

# Bind every socket to its address and make it non-blocking.
for s, addr in zip(sockets, addrs):
    s.bind(addr)
    s.setblocking(False)

# Lookup tables used by the send/receive loop.
fileno_to_socket = {sock.fileno(): sock for sock in sockets}
addr_to_socket = dict(zip(addrs, sockets))
socket_to_addr = dict(zip(sockets, addrs))

# Running sha256 objects, keyed by (sender fileno, receiver fileno).
send_hashes = {}
recv_hashes = {}

ep = select.epoll()
for s in sockets:
    ep.register(s, select.EPOLLRDNORM)

# Number of packets to send, and how many were sent/received so far.
n = 50000
nr_send = 0
nr_recv = 0

# Alternate between a send phase and a receive phase until all n packets
# have been sent and everything sent so far has been received.
while nr_send < n:
    # Send as much as we can without blocking
    print("sending...", nr_send, nr_recv)
    while nr_send < n:
        packet_name = f'packet {nr_send}'.encode('utf-8')
        payload = hashlib.sha256(packet_name).hexdigest().encode('utf-8')

        # pseudo-random send/receive pattern
        sender = sockets[nr_send % 2]
        receiver = sockets[1 - (nr_send % 3) % 2]

        try:
            sender.sendto(payload, socket_to_addr[receiver])
        except BlockingIOError:
            break
        except OSError as e:
            # these errors only end the current send phase
            if e.errno in (errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE):
                break
            raise
        else:
            # the packet went out: fold it into the hash for this pair
            pair = (sender.fileno(), receiver.fileno())
            send_hashes.setdefault(pair, hashlib.sha256()).update(
                f'<{payload}>'.encode('utf-8'))
            nr_send += 1

    # Receive as much as we can without blocking
    print("receiving...", nr_send, nr_recv)
    while nr_recv < nr_send:
        for fileno, eventmask in ep.poll():
            receiver = fileno_to_socket[fileno]
            if not eventmask & select.EPOLLRDNORM:
                continue

            # drain this socket until it would block
            while True:
                try:
                    recv_data, address = receiver.recvfrom(1024)
                except BlockingIOError:
                    break

                sender = addr_to_socket[address]
                pair = (sender.fileno(), receiver.fileno())
                recv_hashes.setdefault(pair, hashlib.sha256()).update(
                    f'<{recv_data}>'.encode('utf-8'))
                nr_recv += 1

    # exercise net/rds/tcp.c:rds_tcp_sysctl_reset()
    for net in [net0, net1]:
        for knob in ('rcvbuf', 'sndbuf'):
            ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_{knob}=10000")

print("done", nr_send, nr_recv)

# the Python socket module doesn't know these
RDS_INFO_FIRST = 10000
RDS_INFO_LAST = 10017

nr_success = 0
nr_error = 0

# Query every RDS_INFO_* option on every socket and count the outcomes.
for s in sockets:
    for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1):
        # Sigh, the Python socket module doesn't allow us to pass
        # buffer lengths greater than 1024 for some reason. RDS
        # wants multiple pages.
        try:
            s.getsockopt(socket.SOL_RDS, optname, 1024)
        except OSError:
            # every failure (ENOSPC included) is counted, none is fatal
            nr_error += 1
        else:
            nr_success += 1

print(f"getsockopt(): {nr_success}/{nr_error}")

print("Stopping network packet captures")
stop_captures = ['killall', '-q', 'tcpdump']
subprocess.check_call(stop_captures)

# We're done sending and receiving stuff, now let's check if what
# we received is what we sent.
for pair, sent in send_hashes.items():
    received = recv_hashes.get(pair)

    if received is None:
        print("FAIL: No data received")
        sys.exit(1)

    expected_digest = sent.hexdigest()
    received_digest = received.hexdigest()
    if expected_digest != received_digest:
        print("FAIL: Send/recv mismatch")
        print("hash expected:", expected_digest)
        print("hash received:", received_digest)
        sys.exit(1)

    # pair holds the (sender, receiver) file descriptor numbers
    sender, receiver = pair
    print(f"{sender}/{receiver}: ok")

print("Success")
sys.exit(0)