diff --git a/include/linux/mroute6.h b/include/linux/mroute6.h index e5e5b8282551..e1b9fb06e1ea 100644 --- a/include/linux/mroute6.h +++ b/include/linux/mroute6.h @@ -111,12 +111,12 @@ extern int ip6mr_get_route(struct net *net, struct sk_buff *skb, struct rtmsg *rtm, u32 portid); #ifdef CONFIG_IPV6_MROUTE -extern struct sock *mroute6_socket(struct net *net, struct sk_buff *skb); +bool mroute6_is_socket(struct net *net, struct sk_buff *skb); extern int ip6mr_sk_done(struct sock *sk); #else -static inline struct sock *mroute6_socket(struct net *net, struct sk_buff *skb) +static inline bool mroute6_is_socket(struct net *net, struct sk_buff *skb) { - return NULL; + return false; } static inline int ip6mr_sk_done(struct sock *sk) { diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c index 997c7f19ad62..a6eb0e699b15 100644 --- a/net/ipv6/ip6_output.c +++ b/net/ipv6/ip6_output.c @@ -71,7 +71,7 @@ static int ip6_finish_output2(struct net *net, struct sock *sk, struct sk_buff * struct inet6_dev *idev = ip6_dst_idev(skb_dst(skb)); if (!(dev->flags & IFF_LOOPBACK) && sk_mc_loop(sk) && - ((mroute6_socket(net, skb) && + ((mroute6_is_socket(net, skb) && !(IP6CB(skb)->flags & IP6SKB_FORWARDED)) || ipv6_chk_mcast_addr(dev, &ipv6_hdr(skb)->daddr, &ipv6_hdr(skb)->saddr))) { diff --git a/net/ipv6/ip6mr.c b/net/ipv6/ip6mr.c index e397990f6eb8..a0e297ddca6e 100644 --- a/net/ipv6/ip6mr.c +++ b/net/ipv6/ip6mr.c @@ -58,7 +58,7 @@ struct mr6_table { struct list_head list; possible_net_t net; u32 id; - struct sock *mroute6_sk; + struct sock __rcu *mroute6_sk; struct timer_list ipmr_expire_timer; struct list_head mfc6_unres_queue; struct list_head mfc6_cache_array[MFC6_LINES]; @@ -1121,6 +1121,7 @@ static void ip6mr_cache_resolve(struct net *net, struct mr6_table *mrt, static int ip6mr_cache_report(struct mr6_table *mrt, struct sk_buff *pkt, mifi_t mifi, int assert) { + struct sock *mroute6_sk; struct sk_buff *skb; struct mrt6msg *msg; int ret; @@ -1190,17 +1191,19 @@ static int ip6mr_cache_report(struct mr6_table *mrt, struct sk_buff *pkt, skb->ip_summed = CHECKSUM_UNNECESSARY; } - if (!mrt->mroute6_sk) { + rcu_read_lock(); + mroute6_sk = rcu_dereference(mrt->mroute6_sk); + if (!mroute6_sk) { + rcu_read_unlock(); kfree_skb(skb); return -EINVAL; } mrt6msg_netlink_event(mrt, skb); - /* - * Deliver to user space multicast routing algorithms - */ - ret = sock_queue_rcv_skb(mrt->mroute6_sk, skb); + /* Deliver to user space multicast routing algorithms */ + ret = sock_queue_rcv_skb(mroute6_sk, skb); + rcu_read_unlock(); if (ret < 0) { net_warn_ratelimited("mroute6: pending queue full, dropping entries\n"); kfree_skb(skb); @@ -1584,11 +1587,11 @@ static int ip6mr_sk_init(struct mr6_table *mrt, struct sock *sk) rtnl_lock(); write_lock_bh(&mrt_lock); - if (likely(mrt->mroute6_sk == NULL)) { - mrt->mroute6_sk = sk; - net->ipv6.devconf_all->mc_forwarding++; - } else { + if (rtnl_dereference(mrt->mroute6_sk)) { err = -EADDRINUSE; + } else { + rcu_assign_pointer(mrt->mroute6_sk, sk); + net->ipv6.devconf_all->mc_forwarding++; } write_unlock_bh(&mrt_lock); @@ -1614,9 +1617,9 @@ int ip6mr_sk_done(struct sock *sk) rtnl_lock(); ip6mr_for_each_table(mrt, net) { - if (sk == mrt->mroute6_sk) { + if (sk == rtnl_dereference(mrt->mroute6_sk)) { write_lock_bh(&mrt_lock); - mrt->mroute6_sk = NULL; + RCU_INIT_POINTER(mrt->mroute6_sk, NULL); net->ipv6.devconf_all->mc_forwarding--; write_unlock_bh(&mrt_lock); inet6_netconf_notify_devconf(net, RTM_NEWNETCONF, @@ -1630,11 +1633,12 @@ int ip6mr_sk_done(struct sock *sk) } } rtnl_unlock(); + synchronize_rcu(); return err; } -struct sock *mroute6_socket(struct net *net, struct sk_buff *skb) +bool mroute6_is_socket(struct net *net, struct sk_buff *skb) { struct mr6_table *mrt; struct flowi6 fl6 = { @@ -1646,8 +1650,9 @@ struct sock *mroute6_socket(struct net *net, struct sk_buff *skb) if (ip6mr_fib_lookup(net, &fl6, &mrt) < 0) return NULL; - return mrt->mroute6_sk; + return rcu_access_pointer(mrt->mroute6_sk); } +EXPORT_SYMBOL(mroute6_is_socket); /* * Socket options and virtual interface manipulation. The whole @@ -1674,7 +1679,8 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns return -ENOENT; if (optname != MRT6_INIT) { - if (sk != mrt->mroute6_sk && !ns_capable(net->user_ns, CAP_NET_ADMIN)) + if (sk != rcu_access_pointer(mrt->mroute6_sk) && + !ns_capable(net->user_ns, CAP_NET_ADMIN)) return -EACCES; } @@ -1696,7 +1702,8 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns if (vif.mif6c_mifi >= MAXMIFS) return -ENFILE; rtnl_lock(); - ret = mif6_add(net, mrt, &vif, sk == mrt->mroute6_sk); + ret = mif6_add(net, mrt, &vif, + sk == rtnl_dereference(mrt->mroute6_sk)); rtnl_unlock(); return ret; @@ -1731,7 +1738,9 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns ret = ip6mr_mfc_delete(mrt, &mfc, parent); else ret = ip6mr_mfc_add(net, mrt, &mfc, - sk == mrt->mroute6_sk, parent); + sk == + rtnl_dereference(mrt->mroute6_sk), + parent); rtnl_unlock(); return ret; @@ -1783,7 +1792,7 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns /* "pim6reg%u" should not exceed 16 bytes (IFNAMSIZ) */ if (v != RT_TABLE_DEFAULT && v >= 100000000) return -EINVAL; - if (sk == mrt->mroute6_sk) + if (sk == rcu_access_pointer(mrt->mroute6_sk)) return -EBUSY; rtnl_lock();