summaryrefslogtreecommitdiff
path: root/net/ipv4/af_inet.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/af_inet.c')
-rw-r--r--net/ipv4/af_inet.c41
1 files changed, 34 insertions, 7 deletions
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 33b7dffa7732..c5376c725503 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -86,6 +86,7 @@
#include <linux/poll.h>
#include <linux/netfilter_ipv4.h>
#include <linux/random.h>
+#include <linux/slab.h>
#include <asm/uaccess.h>
#include <asm/system.h>
@@ -153,7 +154,7 @@ void inet_sock_destruct(struct sock *sk)
WARN_ON(sk->sk_forward_alloc);
kfree(inet->opt);
- dst_release(sk->sk_dst_cache);
+ dst_release(rcu_dereference_check(sk->sk_dst_cache, 1));
sk_refcnt_debug_dec(sk);
}
EXPORT_SYMBOL(inet_sock_destruct);
@@ -418,6 +419,8 @@ int inet_release(struct socket *sock)
if (sk) {
long timeout;
+ inet_rps_reset_flow(sk);
+
/* Applications forget to leave groups before exiting */
ip_mc_drop_socket(sk);
@@ -530,6 +533,8 @@ int inet_dgram_connect(struct socket *sock, struct sockaddr * uaddr,
{
struct sock *sk = sock->sk;
+ if (addr_len < sizeof(uaddr->sa_family))
+ return -EINVAL;
if (uaddr->sa_family == AF_UNSPEC)
return sk->sk_prot->disconnect(sk, flags);
@@ -573,6 +578,9 @@ int inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
int err;
long timeo;
+ if (addr_len < sizeof(uaddr->sa_family))
+ return -EINVAL;
+
lock_sock(sk);
if (uaddr->sa_family == AF_UNSPEC) {
@@ -714,6 +722,8 @@ int inet_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg,
{
struct sock *sk = sock->sk;
+ inet_rps_record_flow(sk);
+
/* We may need to bind the socket. */
if (!inet_sk(sk)->inet_num && inet_autobind(sk))
return -EAGAIN;
@@ -722,12 +732,13 @@ int inet_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg,
}
EXPORT_SYMBOL(inet_sendmsg);
-
static ssize_t inet_sendpage(struct socket *sock, struct page *page, int offset,
size_t size, int flags)
{
struct sock *sk = sock->sk;
+ inet_rps_record_flow(sk);
+
/* We may need to bind the socket. */
if (!inet_sk(sk)->inet_num && inet_autobind(sk))
return -EAGAIN;
@@ -737,6 +748,22 @@ static ssize_t inet_sendpage(struct socket *sock, struct page *page, int offset,
return sock_no_sendpage(sock, page, offset, size, flags);
}
+int inet_recvmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg,
+ size_t size, int flags)
+{
+ struct sock *sk = sock->sk;
+ int addr_len = 0;
+ int err;
+
+ inet_rps_record_flow(sk);
+
+ err = sk->sk_prot->recvmsg(iocb, sk, msg, size, flags & MSG_DONTWAIT,
+ flags & ~MSG_DONTWAIT, &addr_len);
+ if (err >= 0)
+ msg->msg_namelen = addr_len;
+ return err;
+}
+EXPORT_SYMBOL(inet_recvmsg);
int inet_shutdown(struct socket *sock, int how)
{
@@ -866,7 +893,7 @@ const struct proto_ops inet_stream_ops = {
.setsockopt = sock_common_setsockopt,
.getsockopt = sock_common_getsockopt,
.sendmsg = tcp_sendmsg,
- .recvmsg = sock_common_recvmsg,
+ .recvmsg = inet_recvmsg,
.mmap = sock_no_mmap,
.sendpage = tcp_sendpage,
.splice_read = tcp_splice_read,
@@ -893,7 +920,7 @@ const struct proto_ops inet_dgram_ops = {
.setsockopt = sock_common_setsockopt,
.getsockopt = sock_common_getsockopt,
.sendmsg = inet_sendmsg,
- .recvmsg = sock_common_recvmsg,
+ .recvmsg = inet_recvmsg,
.mmap = sock_no_mmap,
.sendpage = inet_sendpage,
#ifdef CONFIG_COMPAT
@@ -923,7 +950,7 @@ static const struct proto_ops inet_sockraw_ops = {
.setsockopt = sock_common_setsockopt,
.getsockopt = sock_common_getsockopt,
.sendmsg = inet_sendmsg,
- .recvmsg = sock_common_recvmsg,
+ .recvmsg = inet_recvmsg,
.mmap = sock_no_mmap,
.sendpage = inet_sendpage,
#ifdef CONFIG_COMPAT
@@ -1401,10 +1428,10 @@ EXPORT_SYMBOL_GPL(snmp_fold_field);
int snmp_mib_init(void __percpu *ptr[2], size_t mibsize)
{
BUG_ON(ptr == NULL);
- ptr[0] = __alloc_percpu(mibsize, __alignof__(unsigned long long));
+ ptr[0] = __alloc_percpu(mibsize, __alignof__(unsigned long));
if (!ptr[0])
goto err0;
- ptr[1] = __alloc_percpu(mibsize, __alignof__(unsigned long long));
+ ptr[1] = __alloc_percpu(mibsize, __alignof__(unsigned long));
if (!ptr[1])
goto err1;
return 0;