summaryrefslogtreecommitdiff
path: root/net/ipv4/igmp.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/igmp.c')
-rw-r--r--net/ipv4/igmp.c96
1 files changed, 55 insertions, 41 deletions
diff --git a/net/ipv4/igmp.c b/net/ipv4/igmp.c
index 1f3183168a90..5088f90835ae 100644
--- a/net/ipv4/igmp.c
+++ b/net/ipv4/igmp.c
@@ -1615,9 +1615,10 @@ int ip_mc_join_group(struct sock *sk , struct ip_mreqn *imr)
{
int err;
u32 addr = imr->imr_multiaddr.s_addr;
- struct ip_mc_socklist *iml, *i;
+ struct ip_mc_socklist *iml=NULL, *i;
struct in_device *in_dev;
struct inet_sock *inet = inet_sk(sk);
+ int ifindex;
int count = 0;
if (!MULTICAST(addr))
@@ -1633,37 +1634,30 @@ int ip_mc_join_group(struct sock *sk , struct ip_mreqn *imr)
goto done;
}
- iml = (struct ip_mc_socklist *)sock_kmalloc(sk, sizeof(*iml), GFP_KERNEL);
-
err = -EADDRINUSE;
+ ifindex = imr->imr_ifindex;
for (i = inet->mc_list; i; i = i->next) {
- if (memcmp(&i->multi, imr, sizeof(*imr)) == 0) {
- /* New style additions are reference counted */
- if (imr->imr_address.s_addr == 0) {
- i->count++;
- err = 0;
- }
+ if (i->multi.imr_multiaddr.s_addr == addr &&
+ i->multi.imr_ifindex == ifindex)
goto done;
- }
count++;
}
err = -ENOBUFS;
- if (iml == NULL || count >= sysctl_igmp_max_memberships)
+ if (count >= sysctl_igmp_max_memberships)
+ goto done;
+ iml = (struct ip_mc_socklist *)sock_kmalloc(sk,sizeof(*iml),GFP_KERNEL);
+ if (iml == NULL)
goto done;
+
memcpy(&iml->multi, imr, sizeof(*imr));
iml->next = inet->mc_list;
- iml->count = 1;
iml->sflist = NULL;
iml->sfmode = MCAST_EXCLUDE;
inet->mc_list = iml;
ip_mc_inc_group(in_dev, addr);
- iml = NULL;
err = 0;
-
done:
rtnl_shunlock();
- if (iml)
- sock_kfree_s(sk, iml, sizeof(*iml));
return err;
}
@@ -1693,30 +1687,25 @@ int ip_mc_leave_group(struct sock *sk, struct ip_mreqn *imr)
{
struct inet_sock *inet = inet_sk(sk);
struct ip_mc_socklist *iml, **imlp;
+ struct in_device *in_dev;
+ u32 group = imr->imr_multiaddr.s_addr;
+ u32 ifindex;
rtnl_lock();
+ in_dev = ip_mc_find_dev(imr);
+ if (!in_dev) {
+ rtnl_unlock();
+ return -ENODEV;
+ }
+ ifindex = imr->imr_ifindex;
for (imlp = &inet->mc_list; (iml = *imlp) != NULL; imlp = &iml->next) {
- if (iml->multi.imr_multiaddr.s_addr==imr->imr_multiaddr.s_addr &&
- iml->multi.imr_address.s_addr==imr->imr_address.s_addr &&
- (!imr->imr_ifindex || iml->multi.imr_ifindex==imr->imr_ifindex)) {
- struct in_device *in_dev;
-
- in_dev = inetdev_by_index(iml->multi.imr_ifindex);
- if (in_dev)
- (void) ip_mc_leave_src(sk, iml, in_dev);
- if (--iml->count) {
- rtnl_unlock();
- if (in_dev)
- in_dev_put(in_dev);
- return 0;
- }
+ if (iml->multi.imr_multiaddr.s_addr == group &&
+ iml->multi.imr_ifindex == ifindex) {
+ (void) ip_mc_leave_src(sk, iml, in_dev);
*imlp = iml->next;
- if (in_dev) {
- ip_mc_dec_group(in_dev, imr->imr_multiaddr.s_addr);
- in_dev_put(in_dev);
- }
+ ip_mc_dec_group(in_dev, group);
rtnl_unlock();
sock_kfree_s(sk, iml, sizeof(*iml));
return 0;
@@ -1736,6 +1725,7 @@ int ip_mc_source(int add, int omode, struct sock *sk, struct
struct in_device *in_dev = NULL;
struct inet_sock *inet = inet_sk(sk);
struct ip_sf_socklist *psl;
+ int leavegroup = 0;
int i, j, rv;
if (!MULTICAST(addr))
@@ -1755,15 +1745,20 @@ int ip_mc_source(int add, int omode, struct sock *sk, struct
err = -EADDRNOTAVAIL;
for (pmc=inet->mc_list; pmc; pmc=pmc->next) {
- if (memcmp(&pmc->multi, mreqs, 2*sizeof(__u32)) == 0)
+ if (pmc->multi.imr_multiaddr.s_addr == imr.imr_multiaddr.s_addr
+ && pmc->multi.imr_ifindex == imr.imr_ifindex)
break;
}
- if (!pmc) /* must have a prior join */
+ if (!pmc) { /* must have a prior join */
+ err = -EINVAL;
goto done;
+ }
/* if a source filter was set, must be the same mode as before */
if (pmc->sflist) {
- if (pmc->sfmode != omode)
+ if (pmc->sfmode != omode) {
+ err = -EINVAL;
goto done;
+ }
} else if (pmc->sfmode != omode) {
/* allow mode switches for empty-set filters */
ip_mc_add_src(in_dev, &mreqs->imr_multiaddr, omode, 0, NULL, 0);
@@ -1775,7 +1770,7 @@ int ip_mc_source(int add, int omode, struct sock *sk, struct
psl = pmc->sflist;
if (!add) {
if (!psl)
- goto done;
+ goto done; /* err = -EADDRNOTAVAIL */
rv = !0;
for (i=0; i<psl->sl_count; i++) {
rv = memcmp(&psl->sl_addr[i], &mreqs->imr_sourceaddr,
@@ -1784,7 +1779,13 @@ int ip_mc_source(int add, int omode, struct sock *sk, struct
break;
}
if (rv) /* source not found */
+ goto done; /* err = -EADDRNOTAVAIL */
+
+ /* special case - (INCLUDE, empty) == LEAVE_GROUP */
+ if (psl->sl_count == 1 && omode == MCAST_INCLUDE) {
+ leavegroup = 1;
goto done;
+ }
/* update the interface filter */
ip_mc_del_src(in_dev, &mreqs->imr_multiaddr, omode, 1,
@@ -1842,18 +1843,21 @@ int ip_mc_source(int add, int omode, struct sock *sk, struct
&mreqs->imr_sourceaddr, 1);
done:
rtnl_shunlock();
+ if (leavegroup)
+ return ip_mc_leave_group(sk, &imr);
return err;
}
int ip_mc_msfilter(struct sock *sk, struct ip_msfilter *msf, int ifindex)
{
- int err;
+ int err = 0;
struct ip_mreqn imr;
u32 addr = msf->imsf_multiaddr;
struct ip_mc_socklist *pmc;
struct in_device *in_dev;
struct inet_sock *inet = inet_sk(sk);
struct ip_sf_socklist *newpsl, *psl;
+ int leavegroup = 0;
if (!MULTICAST(addr))
return -EINVAL;
@@ -1872,15 +1876,22 @@ int ip_mc_msfilter(struct sock *sk, struct ip_msfilter *msf, int ifindex)
err = -ENODEV;
goto done;
}
- err = -EADDRNOTAVAIL;
+
+ /* special case - (INCLUDE, empty) == LEAVE_GROUP */
+ if (msf->imsf_fmode == MCAST_INCLUDE && msf->imsf_numsrc == 0) {
+ leavegroup = 1;
+ goto done;
+ }
for (pmc=inet->mc_list; pmc; pmc=pmc->next) {
if (pmc->multi.imr_multiaddr.s_addr == msf->imsf_multiaddr &&
pmc->multi.imr_ifindex == imr.imr_ifindex)
break;
}
- if (!pmc) /* must have a prior join */
+ if (!pmc) { /* must have a prior join */
+ err = -EINVAL;
goto done;
+ }
if (msf->imsf_numsrc) {
newpsl = (struct ip_sf_socklist *)sock_kmalloc(sk,
IP_SFLSIZE(msf->imsf_numsrc), GFP_KERNEL);
@@ -1909,8 +1920,11 @@ int ip_mc_msfilter(struct sock *sk, struct ip_msfilter *msf, int ifindex)
0, NULL, 0);
pmc->sflist = newpsl;
pmc->sfmode = msf->imsf_fmode;
+ err = 0;
done:
rtnl_shunlock();
+ if (leavegroup)
+ err = ip_mc_leave_group(sk, &imr);
return err;
}