libtailscale: make tailscale_listener pollable

Use a socketpair(2) and sendmsg/recvmsg to pass a connection fd
from Go to C. This lets people write non-blocking C by polling on a
tailscale_listener for when they should tailscale_accept.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:
David Crawshaw
2023-03-12 14:09:31 -07:00
committed by David Crawshaw
parent 42597d5fb7
commit b0e2f4a4e4
8 changed files with 170 additions and 110 deletions
+1 -1
View File
@@ -46,7 +46,7 @@ int main(void) {
}
close(conn);
}
tailscale_listener_close(ln);
close(ln);
tailscale_close(ts);
return 0;
-4
View File
@@ -44,10 +44,6 @@ PYBIND11_MODULE(_tailscale, m) {
Listen on a given protocol and port
)pbdoc");
m.def("close_listener", &TsnetListenerClose, R"pbdoc(
Create a new tsnet server
)pbdoc");
m.def("accept", [](int ld) { int connOut; int rv = TsnetAccept(ld, &connOut); return std::make_tuple(connOut, rv);}, R"pbdoc(
Accept a given listener and connection
)pbdoc");
+3 -3
View File
@@ -36,8 +36,8 @@ class Tailscale
attach_function :TsnetSetLogFD, [:int, :int], :int
attach_function :TsnetDial, [:int, :string, :string, :pointer], :int, blocking: true
attach_function :TsnetListen, [:int, :string, :string, :pointer], :int
attach_function :TsnetListenerClose, [:int], :int
attach_function :TsnetAccept, [:int, :pointer], :int, blocking: true
attach_function :close, [:int], :int
attach_function :tailscale_accept, [:int, :pointer], :int, blocking: true
attach_function :TsnetErrmsg, [:int, :pointer, :size_t], :int
attach_function :TsnetLoopback, [:int, :pointer, :size_t, :pointer, :pointer], :int
end
@@ -90,7 +90,7 @@ class Tailscale
# Close the listener.
def close
@ts.assert_open
Error.check @ts, Libtailscale::TsnetListenerClose(@listener)
Error.check @ts, Libtailscale::close(@listener)
end
end
+24 -7
View File
@@ -2,6 +2,9 @@
// SPDX-License-Identifier: BSD-3-Clause
#include "tailscale.h"
#include <sys/socket.h>
#include <stdio.h>
#include <unistd.h>
// Functions exported by Go.
extern int TsnetNewServer();
@@ -17,8 +20,6 @@ extern int TsnetSetControlURL(int sd, char* str);
extern int TsnetSetEphemeral(int sd, int ephemeral);
extern int TsnetSetLogFD(int sd, int fd);
extern int TsnetListen(int sd, char* net, char* addr, int* listenerOut);
extern int TsnetListenerClose(int ld);
extern int TsnetAccept(int ld, int* connOut);
extern int TsnetLoopback(int sd, char* addrOut, size_t addrLen, char* proxyOut, char* localOut);
tailscale tailscale_new() {
@@ -45,12 +46,28 @@ int tailscale_listen(tailscale sd, const char* network, const char* addr, tailsc
return TsnetListen(sd, (char*)network, (char*)addr, (int*)listener_out);
}
int tailscale_listener_close(tailscale_listener ld) {
return TsnetListenerClose(ld);
}
int tailscale_accept(tailscale_listener ld, tailscale_conn* conn_out) {
return TsnetAccept(ld, (int*)conn_out);
struct msghdr msg = {0};
char mbuf[256];
struct iovec io = { .iov_base = mbuf, .iov_len = sizeof(mbuf) };
msg.msg_iov = &io;
msg.msg_iovlen = 1;
char cbuf[256];
msg.msg_control = cbuf;
msg.msg_controllen = sizeof(cbuf);
if (recvmsg(ld, &msg, 0) == -1) {
return -1;
}
struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
unsigned char* data = CMSG_DATA(cmsg);
int fd = *(int*)data;
*conn_out = fd;
return 0;
}
int tailscale_set_dir(tailscale sd, const char* dir) {
+94 -66
View File
@@ -14,7 +14,6 @@ import (
"net"
"os"
"sync"
"sync/atomic"
"syscall"
"unsafe"
@@ -50,14 +49,14 @@ func getServer(sd C.int) (*server, error) {
// listeners tracks all the tsnet_listener objects allocated via tsnet_listen.
var listeners struct {
mu sync.Mutex
next C.int
m map[C.int]*listener
mu sync.Mutex
m map[C.int]*listener
}
type listener struct {
s *server
ln net.Listener
fd int // go side fd of socketpair sent to C
}
// conns tracks all the pipe(2)s allocated via tsnet_dial.
@@ -180,47 +179,86 @@ func TsnetListen(sd C.int, network, addr *C.char, listenerOut *C.int) C.int {
return s.recErr(err)
}
listeners.mu.Lock()
if listeners.next == 0 {
// Arbitrary magic number that will hopefully help someone
// debug some type confusion one day.
listeners.next = 37<<16 + 1
}
if listeners.m == nil {
listeners.m = map[C.int]*listener{}
}
ld := listeners.next
listeners.next++
listeners.m[ld] = &listener{s: s, ln: ln}
listeners.mu.Unlock()
*listenerOut = ld
return 0
}
//export TsnetListenerClose
func TsnetListenerClose(ld C.int) C.int {
listeners.mu.Lock()
defer listeners.mu.Unlock()
l := listeners.m[ld]
if l == nil {
return C.EBADF
}
err := l.ln.Close()
delete(listeners.m, ld)
if err != nil {
return l.s.recErr(err)
}
return 0
}
func newConn(s *server, netConn net.Conn, connOut *C.int) C.int {
// The tailscale_listener we return to C is one side of a socketpair(2).
// We do this so we can proactively call ln.Accept in a goroutine and
// feed an fd for the connection through the listener. This lets C use
// epoll on the tailscale_listener to know if it should call
// tailscale_accept, which avoids a blocking call on the far side.
fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
if err != nil {
return s.recErr(err)
}
sp := fds[1]
fdC := C.int(fds[0])
listeners.mu.Lock()
if listeners.m == nil {
listeners.m = map[C.int]*listener{}
}
listeners.m[fdC] = &listener{s: s, ln: ln, fd: sp}
listeners.mu.Unlock()
cleanup := func() {
// If fdC is closed on the C side, then we end up calling
// into cleanup twice. Be careful to avoid syscall.Close
// twice as the FD may have been reallocated.
listeners.mu.Lock()
if tsLn, ok := listeners.m[fdC]; ok && tsLn.ln == ln {
delete(listeners.m, fdC)
syscall.Close(sp)
}
listeners.mu.Unlock()
ln.Close()
}
go func() {
// fdC is never written to, so trying to read from sp blocks
// until fdC is closed. We use this as a signal that C is
// done with the listener, and we can tear it down.
//
// TODO: would using os.NewFile avoid a locked up thread?
var buf [256]byte
syscall.Read(sp, buf[:])
cleanup()
}()
go func() {
defer cleanup()
for {
netConn, err := ln.Accept()
if err != nil {
return
}
var connFd C.int
if err := newConn(s, netConn, &connFd); err != nil {
if s.s.Logf != nil {
s.s.Logf("libtailscale.accept: newConn: %v", err)
}
netConn.Close()
continue
}
rights := syscall.UnixRights(int(connFd))
err = syscall.Sendmsg(sp, nil, rights, nil, 0)
if err != nil {
// We handle sp being closed in the read goroutine above.
if s.s.Logf != nil {
s.s.Logf("libtailscale.accept: sendmsg failed: %v", err)
}
netConn.Close()
// fallthrough to close connFd, then continue Accept()ing
}
syscall.Close(int(connFd)) // now owned by recvmsg
}
}()
*listenerOut = fdC
return 0
}
func newConn(s *server, netConn net.Conn, connOut *C.int) error {
fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
if err != nil {
return err
}
r := os.NewFile(uintptr(fds[1]), "socketpair-r")
c := &conn{s: s.s, c: netConn, r: r}
fdC := C.int(fds[0])
@@ -232,17 +270,21 @@ func newConn(s *server, netConn net.Conn, connOut *C.int) C.int {
conns.m[fdC] = c
conns.mu.Unlock()
var doneOnce atomic.Bool
connCleanup := func() {
if !doneOnce.Swap(true) {
var inCleanup bool
conns.mu.Lock()
if tsConn, ok := conns.m[fdC]; ok && tsConn.c == netConn {
delete(conns.m, fdC)
inCleanup = true
}
conns.mu.Unlock()
if !inCleanup {
return
}
r.Close()
netConn.Close()
conns.mu.Lock()
delete(conns.m, fdC)
conns.mu.Unlock()
}
go func() {
defer connCleanup()
@@ -264,24 +306,7 @@ func newConn(s *server, netConn net.Conn, connOut *C.int) C.int {
}()
*connOut = fdC
return 0
}
//export TsnetAccept
func TsnetAccept(ld C.int, connOut *C.int) C.int {
listeners.mu.Lock()
l := listeners.m[ld]
listeners.mu.Unlock()
if l == nil {
return C.EBADF
}
netConn, err := l.ln.Accept()
if err != nil {
return l.s.recErr(err)
}
return newConn(l.s, netConn, connOut)
return nil
}
//export TsnetDial
@@ -294,7 +319,10 @@ func TsnetDial(sd C.int, network, addr *C.char, connOut *C.int) C.int {
if err != nil {
return s.recErr(err)
}
return newConn(s, netConn, connOut)
if newConn(s, netConn, connOut); err != nil {
return s.recErr(err)
}
return 0
}
//export TsnetSetDir
+6 -10
View File
@@ -87,8 +87,12 @@ extern int tailscale_dial(tailscale sd, const char* network, const char* addr, t
// A tailscale_listener is a socket on the tailnet listening for connections.
//
// It is much like allocating a system socket(2) and calling listen(2).
// Because it is not a system socket, operate on it using the functions
// tailscale_accept and tailscale_listener_close.
// Accept connections with tailscale_accept and close the listener with close.
//
// Under the hood, a tailscale_listener is one half of a socketpair itself,
// used to move the connection fd from Go to C. This means you can use epoll
// or its equivalent on a tailscale_listener to know if there is a connection
// read to accept.
typedef int tailscale_listener;
// tailscale_listen listens for a connection on the tailnet.
@@ -104,14 +108,6 @@ typedef int tailscale_listener;
// Returns zero on success or -1 on error, call tailscale_errmsg for details.
extern int tailscale_listen(tailscale sd, const char* network, const char* addr, tailscale_listener* listener_out);
// tailscale_listener_close closes the listener.
//
// Returns:
// 0 - success
// EBADF - listener is not a valid tailscale_listener
// -1 - call tailscale_errmsg for details
extern int tailscale_listener_close(tailscale_listener listener);
// tailscale_accept accepts a connection on a tailscale_listener.
//
// It is the spiritual equivalent to accept(2).
+39 -16
View File
@@ -2,6 +2,7 @@ package main
import (
"testing"
"time"
"github.com/tailscale/libtailscale/tsnetctest"
)
@@ -11,27 +12,49 @@ func TestConn(t *testing.T) {
// RunTestConn cleans up after itself, so there shouldn't be
// anything left in the global maps.
conns.mu.Lock()
rem := len(conns.m)
conns.mu.Unlock()
if rem > 0 {
t.Fatalf("want no remaining tsnet_conn objects, got %d", rem)
}
listeners.mu.Lock()
rem = len(listeners.m)
listeners.mu.Unlock()
if rem > 0 {
t.Fatalf("want no remaining tsnet_listener objects, got %d", rem)
}
servers.mu.Lock()
rem = len(servers.m)
rem := len(servers.m)
servers.mu.Unlock()
if rem > 0 {
t.Fatalf("want no remaining tsnet objects, got %d", rem)
}
var remConns, remLns int
for i := 0; i < 50; i++ {
conns.mu.Lock()
remConns = len(conns.m)
conns.mu.Unlock()
listeners.mu.Lock()
remLns = len(listeners.m)
listeners.mu.Unlock()
if remConns == 0 && remLns == 0 {
break
}
// We are waiting for cleanup goroutines to finish.
//
// libtailscale closes one side of a socketpair and
// then Go responds to the other side being unreadable
// by closing the connections and listeners.
//
// This is inherently asynchronous.
// Without ditching the standard close(2) and having our
// own close functions.
//
// So we spin for a while
time.Sleep(100 * time.Millisecond)
}
if remConns > 0 {
t.Errorf("want no remaining tsnet_conn objects, got %d", remConns)
}
if remLns > 0 {
t.Errorf("want no remaining tsnet_listener objects, got %d", remLns)
}
}
+3 -3
View File
@@ -109,11 +109,11 @@ int test_conn() {
snprintf(err, errlen, "failed to close r: %d (%s)", errno, strerror(errno));
return 1;
}
if ((ret = tailscale_listener_close(ln)) != 0) {
if ((ret = close(ln)) != 0) {
return set_err(s1, 'a');
}
if ((ret = tailscale_listener_close(ln)) != EBADF) {
snprintf(err, errlen, "double tailscale_listener_close = %d (%s), want EBADF", errno, strerror(errno));
if ((ret = close(ln)) == 0 || errno != EBADF) {
snprintf(err, errlen, "double tailscale_listener close = %d (errno %d: %s), want EBADF", ret, errno, strerror(errno));
return 1;
}