From 104593a44b2996b084572eb373359428616fdb29 Mon Sep 17 00:00:00 2001 From: Cosmin Apreuetsei Date: Thu, 27 Feb 2014 03:49:45 +0200 Subject: [PATCH] upgrade to 3.0-rc1 + win32 build --- csrc/socket/WHAT | 6 +- csrc/socket/build-mingw32.sh | 10 +- csrc/socket/src/auxiliar.c | 17 +- csrc/socket/src/auxiliar.h | 5 +- csrc/socket/src/buffer.c | 40 +- csrc/socket/src/buffer.h | 6 +- csrc/socket/src/except.c | 8 +- csrc/socket/src/except.h | 2 - csrc/socket/src/inet.c | 391 ++++++++++++++--- csrc/socket/src/inet.h | 30 +- csrc/socket/src/io.c | 2 - csrc/socket/src/io.h | 4 +- csrc/socket/src/luasocket.c | 22 +- csrc/socket/src/luasocket.h | 7 +- csrc/socket/src/mime.c | 728 +++++++++++++++++++++++++++++++ csrc/socket/src/mime.h | 29 ++ csrc/socket/src/options.c | 247 ++++++++++- csrc/socket/src/options.h | 49 ++- csrc/socket/src/select.c | 68 ++- csrc/socket/src/select.h | 2 - csrc/socket/src/serial.c | 188 ++++++++ csrc/socket/src/socket.h | 6 +- csrc/socket/src/tcp.c | 266 ++++++++--- csrc/socket/src/tcp.h | 3 +- csrc/socket/src/timeout.c | 26 +- csrc/socket/src/timeout.h | 2 - csrc/socket/src/udp.c | 235 +++++++--- csrc/socket/src/udp.h | 3 +- csrc/socket/src/unix.c | 346 +++++++++++++++ csrc/socket/src/unix.h | 30 ++ csrc/socket/src/usocket.c | 113 ++++- csrc/socket/src/usocket.h | 23 +- csrc/socket/src/wsocket.c | 79 +++- csrc/socket/src/wsocket.h | 18 +- csrc/socket/test/README | 12 - csrc/socket/test/testclnt.lua | 713 ------------------------------ csrc/socket/test/testsrvr.lua | 15 - csrc/socket/test/testsupport.lua | 37 -- ltn12.lua | 305 +++++++++++++ mime.lua | 90 ++++ socket.exclude | 2 + socket.lua | 76 ++-- socket.md | 12 +- socket/ftp.lua | 285 ++++++++++++ socket/headers.lua | 104 +++++ socket/http.lua | 356 +++++++++++++++ socket/mbox.lua | 92 ++++ socket/smtp.lua | 256 +++++++++++ socket/tftp.lua | 154 +++++++ socket/tp.lua | 126 ++++++ socket/url.lua | 170 ++++---- 51 files changed, 4548 insertions(+), 1268 deletions(-) create mode 100644 csrc/socket/src/mime.c create mode 100644 csrc/socket/src/mime.h create mode 100644 csrc/socket/src/serial.c create mode 100644 csrc/socket/src/unix.c create mode 100644 csrc/socket/src/unix.h delete mode 100644 csrc/socket/test/README delete mode 100644 csrc/socket/test/testclnt.lua delete mode 100644 csrc/socket/test/testsrvr.lua delete mode 100644 csrc/socket/test/testsupport.lua create mode 100644 ltn12.lua create mode 100644 mime.lua create mode 100644 socket/ftp.lua create mode 100644 socket/headers.lua create mode 100644 socket/http.lua create mode 100644 socket/mbox.lua create mode 100644 socket/smtp.lua create mode 100644 socket/tftp.lua create mode 100644 socket/tp.lua diff --git a/csrc/socket/WHAT b/csrc/socket/WHAT index 31a5ae1..834f751 100644 --- a/csrc/socket/WHAT +++ b/csrc/socket/WHAT @@ -1,5 +1 @@ -LuaSocket 2.0.2 from http://w3.impa.br/~diego/software/luasocket/ (MIT license) - -download link: http://files.luaforge.net/releases/luasocket/luasocket/luasocket-2.0.2/luasocket-2.0.2.tar.gz - -socket.lua moved to root dir +LuaSocket 3.0-rc1 from https://github.com/diegonehab/luasocket (MIT license) diff --git a/csrc/socket/build-mingw32.sh b/csrc/socket/build-mingw32.sh index d679bc7..1c484ab 100644 --- a/csrc/socket/build-mingw32.sh +++ b/csrc/socket/build-mingw32.sh @@ -1,2 +1,8 @@ -files="$(ls -1 src/*.c | grep -v usocket)" -gcc -O2 -s -static-libgcc $files -shared -o ../../bin/mingw32/clib/socket_core.dll -I. -I../lua -L../../bin/mingw32 -llua51 -lwsock32 +files="$(ls -1 src/*.c | grep -v "usocket\|unix\|serial\|mime")" +mkdir -p ../../bin/mingw32/clib/socket +gcc -O2 -s -static-libgcc $files -shared -o ../../bin/mingw32/clib/socket/core.dll \ + -I. -I../lua -L../../bin/mingw32 -llua51 -lws2_32 -DWINVER=0x0501 -DLUASOCKET_INET_PTON + +mkdir -p ../../bin/mingw32/clib/mime +gcc -O2 -s -static-libgcc src/mime.c -shared -o ../../bin/mingw32/clib/mime/core.dll \ + -I. -I../lua -L../../bin/mingw32 -llua51 diff --git a/csrc/socket/src/auxiliar.c b/csrc/socket/src/auxiliar.c index 9514970..de625e9 100644 --- a/csrc/socket/src/auxiliar.c +++ b/csrc/socket/src/auxiliar.c @@ -1,8 +1,6 @@ /*=========================================================================*\ * Auxiliar routines for class hierarchy manipulation * LuaSocket toolkit -* -* RCS ID: $Id: auxiliar.c,v 1.14 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ #include #include @@ -24,7 +22,7 @@ int auxiliar_open(lua_State *L) { * Creates a new class with given methods * Methods whose names start with __ are passed directly to the metatable. \*-------------------------------------------------------------------------*/ -void auxiliar_newclass(lua_State *L, const char *classname, luaL_reg *func) { +void auxiliar_newclass(lua_State *L, const char *classname, luaL_Reg *func) { luaL_newmetatable(L, classname); /* mt */ /* create __index table to place methods */ lua_pushstring(L, "__index"); /* mt,"__index" */ @@ -81,7 +79,7 @@ void auxiliar_add2group(lua_State *L, const char *classname, const char *groupna \*-------------------------------------------------------------------------*/ int auxiliar_checkboolean(lua_State *L, int objidx) { if (!lua_isboolean(L, objidx)) - luaL_typerror(L, objidx, lua_typename(L, LUA_TBOOLEAN)); + auxiliar_typeerror(L, objidx, lua_typename(L, LUA_TBOOLEAN)); return lua_toboolean(L, objidx); } @@ -147,3 +145,14 @@ void *auxiliar_getgroupudata(lua_State *L, const char *groupname, int objidx) { void *auxiliar_getclassudata(lua_State *L, const char *classname, int objidx) { return luaL_checkudata(L, objidx, classname); } + +/*-------------------------------------------------------------------------*\ +* Throws error when argument does not have correct type. +* Used to be part of lauxlib in Lua 5.1, was dropped from 5.2. +\*-------------------------------------------------------------------------*/ +int auxiliar_typeerror (lua_State *L, int narg, const char *tname) { + const char *msg = lua_pushfstring(L, "%s expected, got %s", tname, + luaL_typename(L, narg)); + return luaL_argerror(L, narg, msg); +} + diff --git a/csrc/socket/src/auxiliar.h b/csrc/socket/src/auxiliar.h index 18b8495..ea99013 100644 --- a/csrc/socket/src/auxiliar.h +++ b/csrc/socket/src/auxiliar.h @@ -27,15 +27,13 @@ * * The mapping from class name to the corresponding metatable and the * reverse mapping are done using lauxlib. -* -* RCS ID: $Id: auxiliar.h,v 1.9 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ #include "lua.h" #include "lauxlib.h" int auxiliar_open(lua_State *L); -void auxiliar_newclass(lua_State *L, const char *classname, luaL_reg *func); +void auxiliar_newclass(lua_State *L, const char *classname, luaL_Reg *func); void auxiliar_add2group(lua_State *L, const char *classname, const char *group); void auxiliar_setclass(lua_State *L, const char *classname, int objidx); void *auxiliar_checkclass(lua_State *L, const char *classname, int objidx); @@ -44,5 +42,6 @@ void *auxiliar_getclassudata(lua_State *L, const char *groupname, int objidx); void *auxiliar_getgroupudata(lua_State *L, const char *groupname, int objidx); int auxiliar_checkboolean(lua_State *L, int objidx); int auxiliar_tostring(lua_State *L); +int auxiliar_typeerror(lua_State *L, int narg, const char *tname); #endif /* AUXILIAR_H */ diff --git a/csrc/socket/src/buffer.c b/csrc/socket/src/buffer.c index 73f4ffa..423d804 100644 --- a/csrc/socket/src/buffer.c +++ b/csrc/socket/src/buffer.c @@ -1,8 +1,6 @@ /*=========================================================================*\ * Input/Output interface for Lua programs * LuaSocket toolkit -* -* RCS ID: $Id: buffer.c,v 1.28 2007/06/11 23:44:54 diego Exp $ \*=========================================================================*/ #include "lua.h" #include "lauxlib.h" @@ -42,7 +40,7 @@ int buffer_open(lua_State *L) { * Initializes C structure \*-------------------------------------------------------------------------*/ void buffer_init(p_buffer buf, p_io io, p_timeout tm) { - buf->first = buf->last = 0; + buf->first = buf->last = 0; buf->io = io; buf->tm = tm; buf->received = buf->sent = 0; @@ -53,8 +51,8 @@ void buffer_init(p_buffer buf, p_io io, p_timeout tm) { * object:getstats() interface \*-------------------------------------------------------------------------*/ int buffer_meth_getstats(lua_State *L, p_buffer buf) { - lua_pushnumber(L, buf->received); - lua_pushnumber(L, buf->sent); + lua_pushnumber(L, (lua_Number) buf->received); + lua_pushnumber(L, (lua_Number) buf->sent); lua_pushnumber(L, timeout_gettime() - buf->birthday); return 3; } @@ -63,8 +61,8 @@ int buffer_meth_getstats(lua_State *L, p_buffer buf) { * object:setstats() interface \*-------------------------------------------------------------------------*/ int buffer_meth_setstats(lua_State *L, p_buffer buf) { - buf->received = (long) luaL_optnumber(L, 2, buf->received); - buf->sent = (long) luaL_optnumber(L, 3, buf->sent); + buf->received = (long) luaL_optnumber(L, 2, (lua_Number) buf->received); + buf->sent = (long) luaL_optnumber(L, 3, (lua_Number) buf->sent); if (lua_isnumber(L, 4)) buf->birthday = timeout_gettime() - lua_tonumber(L, 4); lua_pushnumber(L, 1); return 1; @@ -80,7 +78,7 @@ int buffer_meth_send(lua_State *L, p_buffer buf) { const char *data = luaL_checklstring(L, 2, &size); long start = (long) luaL_optnumber(L, 3, 1); long end = (long) luaL_optnumber(L, 4, -1); - p_timeout tm = timeout_markstart(buf->tm); + timeout_markstart(buf->tm); if (start < 0) start = (long) (size+start+1); if (end < 0) end = (long) (size+end+1); if (start < 1) start = (long) 1; @@ -90,15 +88,15 @@ int buffer_meth_send(lua_State *L, p_buffer buf) { if (err != IO_DONE) { lua_pushnil(L); lua_pushstring(L, buf->io->error(buf->io->ctx, err)); - lua_pushnumber(L, sent+start-1); + lua_pushnumber(L, (lua_Number) (sent+start-1)); } else { - lua_pushnumber(L, sent+start-1); + lua_pushnumber(L, (lua_Number) (sent+start-1)); lua_pushnil(L); lua_pushnil(L); } #ifdef LUASOCKET_DEBUG /* push time elapsed during operation as the last return value */ - lua_pushnumber(L, timeout_gettime() - timeout_getstart(tm)); + lua_pushnumber(L, timeout_gettime() - timeout_getstart(buf->tm)); #endif return lua_gettop(L) - top; } @@ -111,7 +109,7 @@ int buffer_meth_receive(lua_State *L, p_buffer buf) { luaL_Buffer b; size_t size; const char *part = luaL_optlstring(L, 3, "", &size); - p_timeout tm = timeout_markstart(buf->tm); + timeout_markstart(buf->tm); /* initialize buffer with optional extra prefix * (useful for concatenating previous partial results) */ luaL_buffinit(L, &b); @@ -122,9 +120,15 @@ int buffer_meth_receive(lua_State *L, p_buffer buf) { if (p[0] == '*' && p[1] == 'l') err = recvline(buf, &b); else if (p[0] == '*' && p[1] == 'a') err = recvall(buf, &b); else luaL_argcheck(L, 0, 2, "invalid receive pattern"); - /* get a fixed number of bytes (minus what was already partially - * received) */ - } else err = recvraw(buf, (size_t) lua_tonumber(L, 2)-size, &b); + /* get a fixed number of bytes (minus what was already partially + * received) */ + } else { + double n = lua_tonumber(L, 2); + size_t wanted = (size_t) n; + luaL_argcheck(L, n >= 0, 2, "invalid receive pattern"); + if (size == 0 || wanted > size) + err = recvraw(buf, wanted-size, &b); + } /* check if there was an error */ if (err != IO_DONE) { /* we can't push anyting in the stack before pushing the @@ -141,7 +145,7 @@ int buffer_meth_receive(lua_State *L, p_buffer buf) { } #ifdef LUASOCKET_DEBUG /* push time elapsed during operation as the last return value */ - lua_pushnumber(L, timeout_gettime() - timeout_getstart(tm)); + lua_pushnumber(L, timeout_gettime() - timeout_getstart(buf->tm)); #endif return lua_gettop(L) - top; } @@ -166,7 +170,7 @@ static int sendraw(p_buffer buf, const char *data, size_t count, size_t *sent) { size_t total = 0; int err = IO_DONE; while (total < count && err == IO_DONE) { - size_t done; + size_t done = 0; size_t step = (count-total <= STEPSIZE)? count-total: STEPSIZE; err = io->send(io->ctx, data+total, step, &done, tm); total += done; @@ -225,7 +229,7 @@ static int recvline(p_buffer buf, luaL_Buffer *b) { pos = 0; while (pos < count && data[pos] != '\n') { /* we ignore all \r's */ - if (data[pos] != '\r') luaL_putchar(b, data[pos]); + if (data[pos] != '\r') luaL_addchar(b, data[pos]); pos++; } if (pos < count) { /* found '\n' */ diff --git a/csrc/socket/src/buffer.h b/csrc/socket/src/buffer.h index baf93ca..1281bb3 100644 --- a/csrc/socket/src/buffer.h +++ b/csrc/socket/src/buffer.h @@ -14,8 +14,6 @@ * * The module is built on top of the I/O abstraction defined in io.h and the * timeout management is done with the timeout.h interface. -* -* RCS ID: $Id: buffer.h,v 1.12 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ #include "lua.h" @@ -31,8 +29,8 @@ typedef struct t_buffer_ { size_t sent, received; /* bytes sent, and bytes received */ p_io io; /* IO driver used for this buffer */ p_timeout tm; /* timeout management for this buffer */ - size_t first, last; /* index of first and last bytes of stored data */ - char data[BUF_SIZE]; /* storage space for buffer data */ + size_t first, last; /* index of first and last bytes of stored data */ + char data[BUF_SIZE]; /* storage space for buffer data */ } t_buffer; typedef t_buffer *p_buffer; diff --git a/csrc/socket/src/except.c b/csrc/socket/src/except.c index 5faa5be..002e701 100644 --- a/csrc/socket/src/except.c +++ b/csrc/socket/src/except.c @@ -1,8 +1,6 @@ /*=========================================================================*\ * Simple exception support * LuaSocket toolkit -* -* RCS ID: $Id: except.c,v 1.8 2005/09/29 06:11:41 diego Exp $ \*=========================================================================*/ #include @@ -21,7 +19,7 @@ static int finalize(lua_State *L); static int do_nothing(lua_State *L); /* except functions */ -static luaL_reg func[] = { +static luaL_Reg func[] = { {"newtry", global_newtry}, {"protect", global_protect}, {NULL, NULL} @@ -94,6 +92,10 @@ static int global_protect(lua_State *L) { * Init module \*-------------------------------------------------------------------------*/ int except_open(lua_State *L) { +#if LUA_VERSION_NUM > 501 && !defined(LUA_COMPAT_MODULE) + luaL_setfuncs(L, func, 0); +#else luaL_openlib(L, NULL, func, 0); +#endif return 0; } diff --git a/csrc/socket/src/except.h b/csrc/socket/src/except.h index 81efb29..1e7a245 100644 --- a/csrc/socket/src/except.h +++ b/csrc/socket/src/except.h @@ -24,8 +24,6 @@ * * With these two function, it's easy to write functions that throw * exceptions on error, but that don't interrupt the user script. -* -* RCS ID: $Id: except.h,v 1.2 2005/09/29 06:11:41 diego Exp $ \*=========================================================================*/ #include "lua.h" diff --git a/csrc/socket/src/inet.c b/csrc/socket/src/inet.c index f2cddee..1a411f6 100644 --- a/csrc/socket/src/inet.c +++ b/csrc/socket/src/inet.c @@ -1,10 +1,9 @@ /*=========================================================================*\ * Internet domain functions * LuaSocket toolkit -* -* RCS ID: $Id: inet.c,v 1.28 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ #include +#include #include #include "lua.h" @@ -16,14 +15,18 @@ * Internal function prototypes. \*=========================================================================*/ static int inet_global_toip(lua_State *L); +static int inet_global_getaddrinfo(lua_State *L); static int inet_global_tohostname(lua_State *L); +static int inet_global_getnameinfo(lua_State *L); static void inet_pushresolved(lua_State *L, struct hostent *hp); static int inet_global_gethostname(lua_State *L); /* DNS functions */ -static luaL_reg func[] = { - { "toip", inet_global_toip }, - { "tohostname", inet_global_tohostname }, +static luaL_Reg func[] = { + { "toip", inet_global_toip}, + { "getaddrinfo", inet_global_getaddrinfo}, + { "tohostname", inet_global_tohostname}, + { "getnameinfo", inet_global_getnameinfo}, { "gethostname", inet_global_gethostname}, { NULL, NULL} }; @@ -38,7 +41,11 @@ int inet_open(lua_State *L) { lua_pushstring(L, "dns"); lua_newtable(L); +#if LUA_VERSION_NUM > 501 && !defined(LUA_COMPAT_MODULE) + luaL_setfuncs(L, func, 0); +#else luaL_openlib(L, NULL, func, 0); +#endif lua_settable(L, -3); return 0; } @@ -54,7 +61,7 @@ static int inet_gethost(const char *address, struct hostent **hp) { struct in_addr addr; if (inet_aton(address, &addr)) return socket_gethostbyaddr((char *) &addr, sizeof(addr), hp); - else + else return socket_gethostbyname(address, hp); } @@ -64,7 +71,7 @@ static int inet_gethost(const char *address, struct hostent **hp) { \*-------------------------------------------------------------------------*/ static int inet_global_tohostname(lua_State *L) { const char *address = luaL_checkstring(L, 1); - struct hostent *hp = NULL; + struct hostent *hp = NULL; int err = inet_gethost(address, &hp); if (err != IO_DONE) { lua_pushnil(L); @@ -76,6 +83,50 @@ static int inet_global_tohostname(lua_State *L) { return 2; } +static int inet_global_getnameinfo(lua_State *L) { + char hbuf[NI_MAXHOST]; + char sbuf[NI_MAXSERV]; + int i, ret; + struct addrinfo hints; + struct addrinfo *resolved, *iter; + const char *host = luaL_optstring(L, 1, NULL); + const char *serv = luaL_optstring(L, 2, NULL); + + if (!(host || serv)) + luaL_error(L, "host and serv cannot be both nil"); + + memset(&hints, 0, sizeof(hints)); + hints.ai_socktype = SOCK_STREAM; + hints.ai_family = PF_UNSPEC; + + ret = getaddrinfo(host, serv, &hints, &resolved); + if (ret != 0) { + lua_pushnil(L); + lua_pushstring(L, socket_gaistrerror(ret)); + return 2; + } + + lua_newtable(L); + for (i = 1, iter = resolved; iter; i++, iter = iter->ai_next) { + getnameinfo(iter->ai_addr, (socklen_t) iter->ai_addrlen, + hbuf, host? (socklen_t) sizeof(hbuf): 0, + sbuf, serv? (socklen_t) sizeof(sbuf): 0, 0); + if (host) { + lua_pushnumber(L, i); + lua_pushstring(L, hbuf); + lua_settable(L, -3); + } + } + freeaddrinfo(resolved); + + if (serv) { + lua_pushstring(L, sbuf); + return 2; + } else { + return 1; + } +} + /*-------------------------------------------------------------------------*\ * Returns all information provided by the resolver given a host name * or ip address @@ -83,7 +134,7 @@ static int inet_global_tohostname(lua_State *L) { static int inet_global_toip(lua_State *L) { const char *address = luaL_checkstring(L, 1); - struct hostent *hp = NULL; + struct hostent *hp = NULL; int err = inet_gethost(address, &hp); if (err != IO_DONE) { lua_pushnil(L); @@ -95,6 +146,70 @@ static int inet_global_toip(lua_State *L) return 2; } +int inet_optfamily(lua_State* L, int narg, const char* def) +{ + static const char* optname[] = { "unspec", "inet", "inet6", NULL }; + static int optvalue[] = { PF_UNSPEC, PF_INET, PF_INET6, 0 }; + + return optvalue[luaL_checkoption(L, narg, def, optname)]; +} + +int inet_optsocktype(lua_State* L, int narg, const char* def) +{ + static const char* optname[] = { "stream", "dgram", NULL }; + static int optvalue[] = { SOCK_STREAM, SOCK_DGRAM, 0 }; + + return optvalue[luaL_checkoption(L, narg, def, optname)]; +} + +static int inet_global_getaddrinfo(lua_State *L) +{ + const char *hostname = luaL_checkstring(L, 1); + struct addrinfo *iterator = NULL, *resolved = NULL; + struct addrinfo hints; + int i = 1, ret = 0; + memset(&hints, 0, sizeof(hints)); + hints.ai_socktype = SOCK_STREAM; + hints.ai_family = PF_UNSPEC; + ret = getaddrinfo(hostname, NULL, &hints, &resolved); + if (ret != 0) { + lua_pushnil(L); + lua_pushstring(L, socket_gaistrerror(ret)); + return 2; + } + lua_newtable(L); + for (iterator = resolved; iterator; iterator = iterator->ai_next) { + char hbuf[NI_MAXHOST]; + ret = getnameinfo(iterator->ai_addr, (socklen_t) iterator->ai_addrlen, + hbuf, (socklen_t) sizeof(hbuf), NULL, 0, NI_NUMERICHOST); + if (ret){ + lua_pushnil(L); + lua_pushstring(L, socket_gaistrerror(ret)); + return 2; + } + lua_pushnumber(L, i); + lua_newtable(L); + switch (iterator->ai_family) { + case AF_INET: + lua_pushliteral(L, "family"); + lua_pushliteral(L, "inet"); + lua_settable(L, -3); + break; + case AF_INET6: + lua_pushliteral(L, "family"); + lua_pushliteral(L, "inet6"); + lua_settable(L, -3); + break; + } + lua_pushliteral(L, "addr"); + lua_pushstring(L, hbuf); + lua_settable(L, -3); + lua_settable(L, -3); + i++; + } + freeaddrinfo(resolved); + return 1; +} /*-------------------------------------------------------------------------*\ * Gets the host name @@ -105,7 +220,7 @@ static int inet_global_gethostname(lua_State *L) name[256] = '\0'; if (gethostname(name, 256) < 0) { lua_pushnil(L); - lua_pushstring(L, "gethostname failed"); + lua_pushstring(L, socket_strerror(errno)); return 2; } else { lua_pushstring(L, name); @@ -113,43 +228,76 @@ static int inet_global_gethostname(lua_State *L) } } - - /*=========================================================================*\ * Lua methods \*=========================================================================*/ /*-------------------------------------------------------------------------*\ * Retrieves socket peer name \*-------------------------------------------------------------------------*/ -int inet_meth_getpeername(lua_State *L, p_socket ps) +int inet_meth_getpeername(lua_State *L, p_socket ps, int family) { - struct sockaddr_in peer; + int err; + struct sockaddr_storage peer; socklen_t peer_len = sizeof(peer); + char name[INET6_ADDRSTRLEN]; + char port[6]; /* 65535 = 5 bytes + 0 to terminate it */ if (getpeername(*ps, (SA *) &peer, &peer_len) < 0) { lua_pushnil(L); - lua_pushstring(L, "getpeername failed"); - } else { - lua_pushstring(L, inet_ntoa(peer.sin_addr)); - lua_pushnumber(L, ntohs(peer.sin_port)); + lua_pushstring(L, socket_strerror(errno)); + return 2; } - return 2; + err = getnameinfo((struct sockaddr *) &peer, peer_len, + name, INET6_ADDRSTRLEN, + port, sizeof(port), NI_NUMERICHOST | NI_NUMERICSERV); + if (err) { + lua_pushnil(L); + lua_pushstring(L, gai_strerror(err)); + return 2; + } + lua_pushstring(L, name); + lua_pushinteger(L, (int) strtol(port, (char **) NULL, 10)); + if (family == PF_INET) { + lua_pushliteral(L, "inet"); + } else if (family == PF_INET6) { + lua_pushliteral(L, "inet6"); + } else { + lua_pushliteral(L, "uknown family"); + } + return 3; } /*-------------------------------------------------------------------------*\ * Retrieves socket local name \*-------------------------------------------------------------------------*/ -int inet_meth_getsockname(lua_State *L, p_socket ps) +int inet_meth_getsockname(lua_State *L, p_socket ps, int family) { - struct sockaddr_in local; - socklen_t local_len = sizeof(local); - if (getsockname(*ps, (SA *) &local, &local_len) < 0) { + int err; + struct sockaddr_storage peer; + socklen_t peer_len = sizeof(peer); + char name[INET6_ADDRSTRLEN]; + char port[6]; /* 65535 = 5 bytes + 0 to terminate it */ + if (getsockname(*ps, (SA *) &peer, &peer_len) < 0) { lua_pushnil(L); - lua_pushstring(L, "getsockname failed"); - } else { - lua_pushstring(L, inet_ntoa(local.sin_addr)); - lua_pushnumber(L, ntohs(local.sin_port)); + lua_pushstring(L, socket_strerror(errno)); + return 2; } - return 2; + err=getnameinfo((struct sockaddr *)&peer, peer_len, + name, INET6_ADDRSTRLEN, port, 6, NI_NUMERICHOST | NI_NUMERICSERV); + if (err) { + lua_pushnil(L); + lua_pushstring(L, gai_strerror(err)); + return 2; + } + lua_pushstring(L, name); + lua_pushstring(L, port); + if (family == PF_INET) { + lua_pushliteral(L, "inet"); + } else if (family == PF_INET6) { + lua_pushliteral(L, "inet6"); + } else { + lua_pushliteral(L, "uknown family"); + } + return 3; } /*=========================================================================*\ @@ -198,65 +346,150 @@ static void inet_pushresolved(lua_State *L, struct hostent *hp) /*-------------------------------------------------------------------------*\ * Tries to create a new inet socket \*-------------------------------------------------------------------------*/ -const char *inet_trycreate(p_socket ps, int type) { - return socket_strerror(socket_create(ps, AF_INET, type, 0)); +const char *inet_trycreate(p_socket ps, int family, int type) { + return socket_strerror(socket_create(ps, family, type, 0)); +} + +/*-------------------------------------------------------------------------*\ +* "Disconnects" a DGRAM socket +\*-------------------------------------------------------------------------*/ +const char *inet_trydisconnect(p_socket ps, int family, p_timeout tm) +{ + switch (family) { + case PF_INET: { + struct sockaddr_in sin; + memset((char *) &sin, 0, sizeof(sin)); + sin.sin_family = AF_UNSPEC; + sin.sin_addr.s_addr = INADDR_ANY; + return socket_strerror(socket_connect(ps, (SA *) &sin, + sizeof(sin), tm)); + } + case PF_INET6: { + struct sockaddr_in6 sin6; + struct in6_addr addrany = IN6ADDR_ANY_INIT; + memset((char *) &sin6, 0, sizeof(sin6)); + sin6.sin6_family = AF_UNSPEC; + sin6.sin6_addr = addrany; + return socket_strerror(socket_connect(ps, (SA *) &sin6, + sizeof(sin6), tm)); + } + } + return NULL; } /*-------------------------------------------------------------------------*\ * Tries to connect to remote address (address, port) \*-------------------------------------------------------------------------*/ -const char *inet_tryconnect(p_socket ps, const char *address, - unsigned short port, p_timeout tm) +const char *inet_tryconnect(p_socket ps, int *family, const char *address, + const char *serv, p_timeout tm, struct addrinfo *connecthints) { - struct sockaddr_in remote; - int err; - memset(&remote, 0, sizeof(remote)); - remote.sin_family = AF_INET; - remote.sin_port = htons(port); - if (strcmp(address, "*")) { - if (!inet_aton(address, &remote.sin_addr)) { - struct hostent *hp = NULL; - struct in_addr **addr; - err = socket_gethostbyname(address, &hp); - if (err != IO_DONE) return socket_hoststrerror(err); - addr = (struct in_addr **) hp->h_addr_list; - memcpy(&remote.sin_addr, *addr, sizeof(struct in_addr)); + struct addrinfo *iterator = NULL, *resolved = NULL; + const char *err = NULL; + /* try resolving */ + err = socket_gaistrerror(getaddrinfo(address, serv, + connecthints, &resolved)); + if (err != NULL) { + if (resolved) freeaddrinfo(resolved); + return err; + } + for (iterator = resolved; iterator; iterator = iterator->ai_next) { + timeout_markstart(tm); + /* create new socket if necessary. if there was no + * bind, we need to create one for every new family + * that shows up while iterating. if there was a + * bind, all families will be the same and we will + * not enter this branch. */ + if (*family != iterator->ai_family) { + socket_destroy(ps); + err = socket_strerror(socket_create(ps, iterator->ai_family, + iterator->ai_socktype, iterator->ai_protocol)); + if (err != NULL) { + freeaddrinfo(resolved); + return err; + } + *family = iterator->ai_family; + /* all sockets initially non-blocking */ + socket_setnonblocking(ps); } - } else remote.sin_family = AF_UNSPEC; - err = socket_connect(ps, (SA *) &remote, sizeof(remote), tm); - return socket_strerror(err); + /* try connecting to remote address */ + err = socket_strerror(socket_connect(ps, (SA *) iterator->ai_addr, + (socklen_t) iterator->ai_addrlen, tm)); + /* if success, break out of loop */ + if (err == NULL) break; + } + freeaddrinfo(resolved); + /* here, if err is set, we failed */ + return err; +} + +/*-------------------------------------------------------------------------*\ +* Tries to accept a socket +\*-------------------------------------------------------------------------*/ +const char *inet_tryaccept(p_socket server, int family, p_socket client, + p_timeout tm) +{ + socklen_t len; + t_sockaddr_storage addr; + if (family == PF_INET6) { + len = sizeof(struct sockaddr_in6); + } else { + len = sizeof(struct sockaddr_in); + } + return socket_strerror(socket_accept(server, client, (SA *) &addr, + &len, tm)); } /*-------------------------------------------------------------------------*\ * Tries to bind socket to (address, port) \*-------------------------------------------------------------------------*/ -const char *inet_trybind(p_socket ps, const char *address, unsigned short port) +const char *inet_trybind(p_socket ps, const char *address, const char *serv, + struct addrinfo *bindhints) { - struct sockaddr_in local; - int err; - memset(&local, 0, sizeof(local)); - /* address is either wildcard or a valid ip address */ - local.sin_addr.s_addr = htonl(INADDR_ANY); - local.sin_port = htons(port); - local.sin_family = AF_INET; - if (strcmp(address, "*") && !inet_aton(address, &local.sin_addr)) { - struct hostent *hp = NULL; - struct in_addr **addr; - err = socket_gethostbyname(address, &hp); - if (err != IO_DONE) return socket_hoststrerror(err); - addr = (struct in_addr **) hp->h_addr_list; - memcpy(&local.sin_addr, *addr, sizeof(struct in_addr)); + struct addrinfo *iterator = NULL, *resolved = NULL; + const char *err = NULL; + t_socket sock = *ps; + /* translate luasocket special values to C */ + if (strcmp(address, "*") == 0) address = NULL; + if (!serv) serv = "0"; + /* try resolving */ + err = socket_gaistrerror(getaddrinfo(address, serv, bindhints, &resolved)); + if (err) { + if (resolved) freeaddrinfo(resolved); + return err; } - err = socket_bind(ps, (SA *) &local, sizeof(local)); - if (err != IO_DONE) socket_destroy(ps); - return socket_strerror(err); + /* iterate over resolved addresses until one is good */ + for (iterator = resolved; iterator; iterator = iterator->ai_next) { + if(sock == SOCKET_INVALID) { + err = socket_strerror(socket_create(&sock, iterator->ai_family, + iterator->ai_socktype, iterator->ai_protocol)); + if(err) + continue; + } + /* try binding to local address */ + err = socket_strerror(socket_bind(&sock, + (SA *) iterator->ai_addr, + (socklen_t) iterator->ai_addrlen)); + + /* keep trying unless bind succeeded */ + if (err) { + if(sock != *ps) + socket_destroy(&sock); + } else { + /* remember what we connected to, particularly the family */ + *bindhints = *iterator; + break; + } + } + /* cleanup and return error */ + freeaddrinfo(resolved); + *ps = sock; + return err; } /*-------------------------------------------------------------------------*\ -* Some systems do not provide this so that we provide our own. It's not -* marvelously fast, but it works just fine. +* Some systems do not provide these so that we provide our own. \*-------------------------------------------------------------------------*/ -#ifdef INET_ATON +#ifdef LUASOCKET_INET_ATON int inet_aton(const char *cp, struct in_addr *inp) { unsigned int a = 0, b = 0, c = 0, d = 0; @@ -278,4 +511,26 @@ int inet_aton(const char *cp, struct in_addr *inp) } #endif +#ifdef LUASOCKET_INET_PTON +int inet_pton(int af, const char *src, void *dst) +{ + struct addrinfo hints, *res; + int ret = 1; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = af; + hints.ai_flags = AI_NUMERICHOST; + if (getaddrinfo(src, NULL, &hints, &res) != 0) return -1; + if (af == AF_INET) { + struct sockaddr_in *in = (struct sockaddr_in *) res->ai_addr; + memcpy(dst, &in->sin_addr, sizeof(in->sin_addr)); + } else if (af == AF_INET6) { + struct sockaddr_in6 *in = (struct sockaddr_in6 *) res->ai_addr; + memcpy(dst, &in->sin6_addr, sizeof(in->sin6_addr)); + } else { + ret = -1; + } + freeaddrinfo(res); + return ret; +} +#endif diff --git a/csrc/socket/src/inet.h b/csrc/socket/src/inet.h index 7662266..1f1a96a 100644 --- a/csrc/socket/src/inet.h +++ b/csrc/socket/src/inet.h @@ -13,30 +13,38 @@ * getpeername and getsockname functions as seen by Lua programs. * * The Lua functions toip and tohostname are also implemented here. -* -* RCS ID: $Id: inet.h,v 1.16 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ #include "lua.h" #include "socket.h" #include "timeout.h" #ifdef _WIN32 -#define INET_ATON +#define LUASOCKET_INET_ATON #endif int inet_open(lua_State *L); -const char *inet_trycreate(p_socket ps, int type); -const char *inet_tryconnect(p_socket ps, const char *address, - unsigned short port, p_timeout tm); -const char *inet_trybind(p_socket ps, const char *address, - unsigned short port); +const char *inet_trycreate(p_socket ps, int family, int type); +const char *inet_tryconnect(p_socket ps, int *family, const char *address, + const char *serv, p_timeout tm, struct addrinfo *connecthints); +const char *inet_trybind(p_socket ps, const char *address, const char *serv, + struct addrinfo *bindhints); +const char *inet_trydisconnect(p_socket ps, int family, p_timeout tm); +const char *inet_tryaccept(p_socket server, int family, p_socket client, p_timeout tm); -int inet_meth_getpeername(lua_State *L, p_socket ps); -int inet_meth_getsockname(lua_State *L, p_socket ps); +int inet_meth_getpeername(lua_State *L, p_socket ps, int family); +int inet_meth_getsockname(lua_State *L, p_socket ps, int family); -#ifdef INET_ATON +int inet_optfamily(lua_State* L, int narg, const char* def); +int inet_optsocktype(lua_State* L, int narg, const char* def); + +#ifdef LUASOCKET_INET_ATON int inet_aton(const char *cp, struct in_addr *inp); #endif +#ifdef LUASOCKET_INET_PTON +const char *inet_ntop(int af, const void *src, char *dst, socklen_t cnt); +int inet_pton(int af, const char *src, void *dst); +#endif + #endif /* INET_H */ diff --git a/csrc/socket/src/io.c b/csrc/socket/src/io.c index 06dc50e..35f46f7 100644 --- a/csrc/socket/src/io.c +++ b/csrc/socket/src/io.c @@ -1,8 +1,6 @@ /*=========================================================================*\ * Input/Output abstraction * LuaSocket toolkit -* -* RCS ID: $Id: io.c,v 1.6 2005/09/29 06:11:41 diego Exp $ \*=========================================================================*/ #include "io.h" diff --git a/csrc/socket/src/io.h b/csrc/socket/src/io.h index cce3aaf..76a3e58 100644 --- a/csrc/socket/src/io.h +++ b/csrc/socket/src/io.h @@ -11,8 +11,6 @@ * * The module socket.h implements this interface, and thus the module tcp.h * is very simple. -* -* RCS ID: $Id: io.h,v 1.11 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ #include #include "lua.h" @@ -24,7 +22,7 @@ enum { IO_DONE = 0, /* operation completed successfully */ IO_TIMEOUT = -1, /* operation timed out */ IO_CLOSED = -2, /* the connection has been closed */ - IO_UNKNOWN = -3 + IO_UNKNOWN = -3 }; /* interface to error message function */ diff --git a/csrc/socket/src/luasocket.c b/csrc/socket/src/luasocket.c index 11ffee9..e6ee747 100644 --- a/csrc/socket/src/luasocket.c +++ b/csrc/socket/src/luasocket.c @@ -10,8 +10,6 @@ * involved in setting up both client and server connections. The provided * IO routines, however, follow the Lua style, being very similar to the * standard Lua read and write functions. -* -* RCS ID: $Id: luasocket.c,v 1.53 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ /*=========================================================================*\ @@ -20,9 +18,6 @@ #include "lua.h" #include "lauxlib.h" -#if !defined(LUA_VERSION_NUM) || (LUA_VERSION_NUM < 501) -#include "compat-5.1.h" -#endif /*=========================================================================*\ * LuaSocket includes @@ -47,7 +42,7 @@ static int base_open(lua_State *L); /*-------------------------------------------------------------------------*\ * Modules and functions \*-------------------------------------------------------------------------*/ -static const luaL_reg mod[] = { +static const luaL_Reg mod[] = { {"auxiliar", auxiliar_open}, {"except", except_open}, {"timeout", timeout_open}, @@ -59,7 +54,7 @@ static const luaL_reg mod[] = { {NULL, NULL} }; -static luaL_reg func[] = { +static luaL_Reg func[] = { {"skip", global_skip}, {"__unload", global_unload}, {NULL, NULL} @@ -83,13 +78,26 @@ static int global_unload(lua_State *L) { return 0; } +#if LUA_VERSION_NUM > 501 +int luaL_typerror (lua_State *L, int narg, const char *tname) { + const char *msg = lua_pushfstring(L, "%s expected, got %s", + tname, luaL_typename(L, narg)); + return luaL_argerror(L, narg, msg); +} +#endif + /*-------------------------------------------------------------------------*\ * Setup basic stuff. \*-------------------------------------------------------------------------*/ static int base_open(lua_State *L) { if (socket_open()) { /* export functions (and leave namespace table on top of stack) */ +#if LUA_VERSION_NUM > 501 && !defined(LUA_COMPAT_MODULE) + lua_newtable(L); + luaL_setfuncs(L, func, 0); +#else luaL_openlib(L, "socket", func, 0); +#endif #ifdef LUASOCKET_DEBUG lua_pushstring(L, "_DEBUG"); lua_pushboolean(L, 1); diff --git a/csrc/socket/src/luasocket.h b/csrc/socket/src/luasocket.h index 67270ab..f75d21f 100644 --- a/csrc/socket/src/luasocket.h +++ b/csrc/socket/src/luasocket.h @@ -5,17 +5,14 @@ * Networking support for the Lua language * Diego Nehab * 9/11/1999 -* -* RCS ID: $Id: luasocket.h,v 1.25 2007/06/11 23:44:54 diego Exp $ \*=========================================================================*/ #include "lua.h" /*-------------------------------------------------------------------------*\ * Current socket library version \*-------------------------------------------------------------------------*/ -#define LUASOCKET_VERSION "LuaSocket 2.0.2" -#define LUASOCKET_COPYRIGHT "Copyright (C) 2004-2007 Diego Nehab" -#define LUASOCKET_AUTHORS "Diego Nehab" +#define LUASOCKET_VERSION "LuaSocket 3.0-rc1" +#define LUASOCKET_COPYRIGHT "Copyright (C) 1999-2013 Diego Nehab" /*-------------------------------------------------------------------------*\ * This macro prefixes all exported API functions diff --git a/csrc/socket/src/mime.c b/csrc/socket/src/mime.c new file mode 100644 index 0000000..dd37dcf --- /dev/null +++ b/csrc/socket/src/mime.c @@ -0,0 +1,728 @@ +/*=========================================================================*\ +* MIME support functions +* LuaSocket toolkit +\*=========================================================================*/ +#include + +#include "lua.h" +#include "lauxlib.h" + +#if !defined(LUA_VERSION_NUM) || (LUA_VERSION_NUM < 501) +#include "compat-5.1.h" +#endif + +#include "mime.h" + +/*=========================================================================*\ +* Don't want to trust escape character constants +\*=========================================================================*/ +typedef unsigned char UC; +static const char CRLF[] = "\r\n"; +static const char EQCRLF[] = "=\r\n"; + +/*=========================================================================*\ +* Internal function prototypes. +\*=========================================================================*/ +static int mime_global_wrp(lua_State *L); +static int mime_global_b64(lua_State *L); +static int mime_global_unb64(lua_State *L); +static int mime_global_qp(lua_State *L); +static int mime_global_unqp(lua_State *L); +static int mime_global_qpwrp(lua_State *L); +static int mime_global_eol(lua_State *L); +static int mime_global_dot(lua_State *L); + +static size_t dot(int c, size_t state, luaL_Buffer *buffer); +static void b64setup(UC *base); +static size_t b64encode(UC c, UC *input, size_t size, luaL_Buffer *buffer); +static size_t b64pad(const UC *input, size_t size, luaL_Buffer *buffer); +static size_t b64decode(UC c, UC *input, size_t size, luaL_Buffer *buffer); + +static void qpsetup(UC *class, UC *unbase); +static void qpquote(UC c, luaL_Buffer *buffer); +static size_t qpdecode(UC c, UC *input, size_t size, luaL_Buffer *buffer); +static size_t qpencode(UC c, UC *input, size_t size, + const char *marker, luaL_Buffer *buffer); +static size_t qppad(UC *input, size_t size, luaL_Buffer *buffer); + +/* code support functions */ +static luaL_Reg func[] = { + { "dot", mime_global_dot }, + { "b64", mime_global_b64 }, + { "eol", mime_global_eol }, + { "qp", mime_global_qp }, + { "qpwrp", mime_global_qpwrp }, + { "unb64", mime_global_unb64 }, + { "unqp", mime_global_unqp }, + { "wrp", mime_global_wrp }, + { NULL, NULL } +}; + +/*-------------------------------------------------------------------------*\ +* Quoted-printable globals +\*-------------------------------------------------------------------------*/ +static UC qpclass[256]; +static UC qpbase[] = "0123456789ABCDEF"; +static UC qpunbase[256]; +enum {QP_PLAIN, QP_QUOTED, QP_CR, QP_IF_LAST}; + +/*-------------------------------------------------------------------------*\ +* Base64 globals +\*-------------------------------------------------------------------------*/ +static const UC b64base[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +static UC b64unbase[256]; + +/*=========================================================================*\ +* Exported functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Initializes module +\*-------------------------------------------------------------------------*/ +MIME_API int luaopen_mime_core(lua_State *L) +{ +#if LUA_VERSION_NUM > 501 && !defined(LUA_COMPAT_MODULE) + lua_newtable(L); + luaL_setfuncs(L, func, 0); +#else + luaL_openlib(L, "mime", func, 0); +#endif + /* make version string available to scripts */ + lua_pushstring(L, "_VERSION"); + lua_pushstring(L, MIME_VERSION); + lua_rawset(L, -3); + /* initialize lookup tables */ + qpsetup(qpclass, qpunbase); + b64setup(b64unbase); + return 1; +} + +/*=========================================================================*\ +* Global Lua functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Incrementaly breaks a string into lines. The string can have CRLF breaks. +* A, n = wrp(l, B, length) +* A is a copy of B, broken into lines of at most 'length' bytes. +* 'l' is how many bytes are left for the first line of B. +* 'n' is the number of bytes left in the last line of A. +\*-------------------------------------------------------------------------*/ +static int mime_global_wrp(lua_State *L) +{ + size_t size = 0; + int left = (int) luaL_checknumber(L, 1); + const UC *input = (UC *) luaL_optlstring(L, 2, NULL, &size); + const UC *last = input + size; + int length = (int) luaL_optnumber(L, 3, 76); + luaL_Buffer buffer; + /* end of input black-hole */ + if (!input) { + /* if last line has not been terminated, add a line break */ + if (left < length) lua_pushstring(L, CRLF); + /* otherwise, we are done */ + else lua_pushnil(L); + lua_pushnumber(L, length); + return 2; + } + luaL_buffinit(L, &buffer); + while (input < last) { + switch (*input) { + case '\r': + break; + case '\n': + luaL_addstring(&buffer, CRLF); + left = length; + break; + default: + if (left <= 0) { + left = length; + luaL_addstring(&buffer, CRLF); + } + luaL_addchar(&buffer, *input); + left--; + break; + } + input++; + } + luaL_pushresult(&buffer); + lua_pushnumber(L, left); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Fill base64 decode map. +\*-------------------------------------------------------------------------*/ +static void b64setup(UC *unbase) +{ + int i; + for (i = 0; i <= 255; i++) unbase[i] = (UC) 255; + for (i = 0; i < 64; i++) unbase[b64base[i]] = (UC) i; + unbase['='] = 0; +} + +/*-------------------------------------------------------------------------*\ +* Acumulates bytes in input buffer until 3 bytes are available. +* Translate the 3 bytes into Base64 form and append to buffer. +* Returns new number of bytes in buffer. +\*-------------------------------------------------------------------------*/ +static size_t b64encode(UC c, UC *input, size_t size, + luaL_Buffer *buffer) +{ + input[size++] = c; + if (size == 3) { + UC code[4]; + unsigned long value = 0; + value += input[0]; value <<= 8; + value += input[1]; value <<= 8; + value += input[2]; + code[3] = b64base[value & 0x3f]; value >>= 6; + code[2] = b64base[value & 0x3f]; value >>= 6; + code[1] = b64base[value & 0x3f]; value >>= 6; + code[0] = b64base[value]; + luaL_addlstring(buffer, (char *) code, 4); + size = 0; + } + return size; +} + +/*-------------------------------------------------------------------------*\ +* Encodes the Base64 last 1 or 2 bytes and adds padding '=' +* Result, if any, is appended to buffer. +* Returns 0. +\*-------------------------------------------------------------------------*/ +static size_t b64pad(const UC *input, size_t size, + luaL_Buffer *buffer) +{ + unsigned long value = 0; + UC code[4] = {'=', '=', '=', '='}; + switch (size) { + case 1: + value = input[0] << 4; + code[1] = b64base[value & 0x3f]; value >>= 6; + code[0] = b64base[value]; + luaL_addlstring(buffer, (char *) code, 4); + break; + case 2: + value = input[0]; value <<= 8; + value |= input[1]; value <<= 2; + code[2] = b64base[value & 0x3f]; value >>= 6; + code[1] = b64base[value & 0x3f]; value >>= 6; + code[0] = b64base[value]; + luaL_addlstring(buffer, (char *) code, 4); + break; + default: + break; + } + return 0; +} + +/*-------------------------------------------------------------------------*\ +* Acumulates bytes in input buffer until 4 bytes are available. +* Translate the 4 bytes from Base64 form and append to buffer. +* Returns new number of bytes in buffer. +\*-------------------------------------------------------------------------*/ +static size_t b64decode(UC c, UC *input, size_t size, + luaL_Buffer *buffer) +{ + /* ignore invalid characters */ + if (b64unbase[c] > 64) return size; + input[size++] = c; + /* decode atom */ + if (size == 4) { + UC decoded[3]; + int valid, value = 0; + value = b64unbase[input[0]]; value <<= 6; + value |= b64unbase[input[1]]; value <<= 6; + value |= b64unbase[input[2]]; value <<= 6; + value |= b64unbase[input[3]]; + decoded[2] = (UC) (value & 0xff); value >>= 8; + decoded[1] = (UC) (value & 0xff); value >>= 8; + decoded[0] = (UC) value; + /* take care of paddding */ + valid = (input[2] == '=') ? 1 : (input[3] == '=') ? 2 : 3; + luaL_addlstring(buffer, (char *) decoded, valid); + return 0; + /* need more data */ + } else return size; +} + +/*-------------------------------------------------------------------------*\ +* Incrementally applies the Base64 transfer content encoding to a string +* A, B = b64(C, D) +* A is the encoded version of the largest prefix of C .. D that is +* divisible by 3. B has the remaining bytes of C .. D, *without* encoding. +* The easiest thing would be to concatenate the two strings and +* encode the result, but we can't afford that or Lua would dupplicate +* every chunk we received. +\*-------------------------------------------------------------------------*/ +static int mime_global_b64(lua_State *L) +{ + UC atom[3]; + size_t isize = 0, asize = 0; + const UC *input = (UC *) luaL_optlstring(L, 1, NULL, &isize); + const UC *last = input + isize; + luaL_Buffer buffer; + /* end-of-input blackhole */ + if (!input) { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } + /* make sure we don't confuse buffer stuff with arguments */ + lua_settop(L, 2); + /* process first part of the input */ + luaL_buffinit(L, &buffer); + while (input < last) + asize = b64encode(*input++, atom, asize, &buffer); + input = (UC *) luaL_optlstring(L, 2, NULL, &isize); + /* if second part is nil, we are done */ + if (!input) { + size_t osize = 0; + asize = b64pad(atom, asize, &buffer); + luaL_pushresult(&buffer); + /* if the output is empty and the input is nil, return nil */ + lua_tolstring(L, -1, &osize); + if (osize == 0) lua_pushnil(L); + lua_pushnil(L); + return 2; + } + /* otherwise process the second part */ + last = input + isize; + while (input < last) + asize = b64encode(*input++, atom, asize, &buffer); + luaL_pushresult(&buffer); + lua_pushlstring(L, (char *) atom, asize); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Incrementally removes the Base64 transfer content encoding from a string +* A, B = b64(C, D) +* A is the encoded version of the largest prefix of C .. D that is +* divisible by 4. B has the remaining bytes of C .. D, *without* encoding. +\*-------------------------------------------------------------------------*/ +static int mime_global_unb64(lua_State *L) +{ + UC atom[4]; + size_t isize = 0, asize = 0; + const UC *input = (UC *) luaL_optlstring(L, 1, NULL, &isize); + const UC *last = input + isize; + luaL_Buffer buffer; + /* end-of-input blackhole */ + if (!input) { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } + /* make sure we don't confuse buffer stuff with arguments */ + lua_settop(L, 2); + /* process first part of the input */ + luaL_buffinit(L, &buffer); + while (input < last) + asize = b64decode(*input++, atom, asize, &buffer); + input = (UC *) luaL_optlstring(L, 2, NULL, &isize); + /* if second is nil, we are done */ + if (!input) { + size_t osize = 0; + luaL_pushresult(&buffer); + /* if the output is empty and the input is nil, return nil */ + lua_tolstring(L, -1, &osize); + if (osize == 0) lua_pushnil(L); + lua_pushnil(L); + return 2; + } + /* otherwise, process the rest of the input */ + last = input + isize; + while (input < last) + asize = b64decode(*input++, atom, asize, &buffer); + luaL_pushresult(&buffer); + lua_pushlstring(L, (char *) atom, asize); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Quoted-printable encoding scheme +* all (except CRLF in text) can be =XX +* CLRL in not text must be =XX=XX +* 33 through 60 inclusive can be plain +* 62 through 126 inclusive can be plain +* 9 and 32 can be plain, unless in the end of a line, where must be =XX +* encoded lines must be no longer than 76 not counting CRLF +* soft line-break are =CRLF +* To encode one byte, we need to see the next two. +* Worst case is when we see a space, and wonder if a CRLF is comming +\*-------------------------------------------------------------------------*/ +/*-------------------------------------------------------------------------*\ +* Split quoted-printable characters into classes +* Precompute reverse map for encoding +\*-------------------------------------------------------------------------*/ +static void qpsetup(UC *cl, UC *unbase) +{ + int i; + for (i = 0; i < 256; i++) cl[i] = QP_QUOTED; + for (i = 33; i <= 60; i++) cl[i] = QP_PLAIN; + for (i = 62; i <= 126; i++) cl[i] = QP_PLAIN; + cl['\t'] = QP_IF_LAST; + cl[' '] = QP_IF_LAST; + cl['\r'] = QP_CR; + for (i = 0; i < 256; i++) unbase[i] = 255; + unbase['0'] = 0; unbase['1'] = 1; unbase['2'] = 2; + unbase['3'] = 3; unbase['4'] = 4; unbase['5'] = 5; + unbase['6'] = 6; unbase['7'] = 7; unbase['8'] = 8; + unbase['9'] = 9; unbase['A'] = 10; unbase['a'] = 10; + unbase['B'] = 11; unbase['b'] = 11; unbase['C'] = 12; + unbase['c'] = 12; unbase['D'] = 13; unbase['d'] = 13; + unbase['E'] = 14; unbase['e'] = 14; unbase['F'] = 15; + unbase['f'] = 15; +} + +/*-------------------------------------------------------------------------*\ +* Output one character in form =XX +\*-------------------------------------------------------------------------*/ +static void qpquote(UC c, luaL_Buffer *buffer) +{ + luaL_addchar(buffer, '='); + luaL_addchar(buffer, qpbase[c >> 4]); + luaL_addchar(buffer, qpbase[c & 0x0F]); +} + +/*-------------------------------------------------------------------------*\ +* Accumulate characters until we are sure about how to deal with them. +* Once we are sure, output to the buffer, in the correct form. +\*-------------------------------------------------------------------------*/ +static size_t qpencode(UC c, UC *input, size_t size, + const char *marker, luaL_Buffer *buffer) +{ + input[size++] = c; + /* deal with all characters we can have */ + while (size > 0) { + switch (qpclass[input[0]]) { + /* might be the CR of a CRLF sequence */ + case QP_CR: + if (size < 2) return size; + if (input[1] == '\n') { + luaL_addstring(buffer, marker); + return 0; + } else qpquote(input[0], buffer); + break; + /* might be a space and that has to be quoted if last in line */ + case QP_IF_LAST: + if (size < 3) return size; + /* if it is the last, quote it and we are done */ + if (input[1] == '\r' && input[2] == '\n') { + qpquote(input[0], buffer); + luaL_addstring(buffer, marker); + return 0; + } else luaL_addchar(buffer, input[0]); + break; + /* might have to be quoted always */ + case QP_QUOTED: + qpquote(input[0], buffer); + break; + /* might never have to be quoted */ + default: + luaL_addchar(buffer, input[0]); + break; + } + input[0] = input[1]; input[1] = input[2]; + size--; + } + return 0; +} + +/*-------------------------------------------------------------------------*\ +* Deal with the final characters +\*-------------------------------------------------------------------------*/ +static size_t qppad(UC *input, size_t size, luaL_Buffer *buffer) +{ + size_t i; + for (i = 0; i < size; i++) { + if (qpclass[input[i]] == QP_PLAIN) luaL_addchar(buffer, input[i]); + else qpquote(input[i], buffer); + } + if (size > 0) luaL_addstring(buffer, EQCRLF); + return 0; +} + +/*-------------------------------------------------------------------------*\ +* Incrementally converts a string to quoted-printable +* A, B = qp(C, D, marker) +* Marker is the text to be used to replace CRLF sequences found in A. +* A is the encoded version of the largest prefix of C .. D that +* can be encoded without doubts. +* B has the remaining bytes of C .. D, *without* encoding. +\*-------------------------------------------------------------------------*/ +static int mime_global_qp(lua_State *L) +{ + + size_t asize = 0, isize = 0; + UC atom[3]; + const UC *input = (UC *) luaL_optlstring(L, 1, NULL, &isize); + const UC *last = input + isize; + const char *marker = luaL_optstring(L, 3, CRLF); + luaL_Buffer buffer; + /* end-of-input blackhole */ + if (!input) { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } + /* make sure we don't confuse buffer stuff with arguments */ + lua_settop(L, 3); + /* process first part of input */ + luaL_buffinit(L, &buffer); + while (input < last) + asize = qpencode(*input++, atom, asize, marker, &buffer); + input = (UC *) luaL_optlstring(L, 2, NULL, &isize); + /* if second part is nil, we are done */ + if (!input) { + asize = qppad(atom, asize, &buffer); + luaL_pushresult(&buffer); + if (!(*lua_tostring(L, -1))) lua_pushnil(L); + lua_pushnil(L); + return 2; + } + /* otherwise process rest of input */ + last = input + isize; + while (input < last) + asize = qpencode(*input++, atom, asize, marker, &buffer); + luaL_pushresult(&buffer); + lua_pushlstring(L, (char *) atom, asize); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Accumulate characters until we are sure about how to deal with them. +* Once we are sure, output the to the buffer, in the correct form. +\*-------------------------------------------------------------------------*/ +static size_t qpdecode(UC c, UC *input, size_t size, luaL_Buffer *buffer) { + int d; + input[size++] = c; + /* deal with all characters we can deal */ + switch (input[0]) { + /* if we have an escape character */ + case '=': + if (size < 3) return size; + /* eliminate soft line break */ + if (input[1] == '\r' && input[2] == '\n') return 0; + /* decode quoted representation */ + c = qpunbase[input[1]]; d = qpunbase[input[2]]; + /* if it is an invalid, do not decode */ + if (c > 15 || d > 15) luaL_addlstring(buffer, (char *)input, 3); + else luaL_addchar(buffer, (char) ((c << 4) + d)); + return 0; + case '\r': + if (size < 2) return size; + if (input[1] == '\n') luaL_addlstring(buffer, (char *)input, 2); + return 0; + default: + if (input[0] == '\t' || (input[0] > 31 && input[0] < 127)) + luaL_addchar(buffer, input[0]); + return 0; + } +} + +/*-------------------------------------------------------------------------*\ +* Incrementally decodes a string in quoted-printable +* A, B = qp(C, D) +* A is the decoded version of the largest prefix of C .. D that +* can be decoded without doubts. +* B has the remaining bytes of C .. D, *without* decoding. +\*-------------------------------------------------------------------------*/ +static int mime_global_unqp(lua_State *L) +{ + size_t asize = 0, isize = 0; + UC atom[3]; + const UC *input = (UC *) luaL_optlstring(L, 1, NULL, &isize); + const UC *last = input + isize; + luaL_Buffer buffer; + /* end-of-input blackhole */ + if (!input) { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } + /* make sure we don't confuse buffer stuff with arguments */ + lua_settop(L, 2); + /* process first part of input */ + luaL_buffinit(L, &buffer); + while (input < last) + asize = qpdecode(*input++, atom, asize, &buffer); + input = (UC *) luaL_optlstring(L, 2, NULL, &isize); + /* if second part is nil, we are done */ + if (!input) { + luaL_pushresult(&buffer); + if (!(*lua_tostring(L, -1))) lua_pushnil(L); + lua_pushnil(L); + return 2; + } + /* otherwise process rest of input */ + last = input + isize; + while (input < last) + asize = qpdecode(*input++, atom, asize, &buffer); + luaL_pushresult(&buffer); + lua_pushlstring(L, (char *) atom, asize); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Incrementally breaks a quoted-printed string into lines +* A, n = qpwrp(l, B, length) +* A is a copy of B, broken into lines of at most 'length' bytes. +* 'l' is how many bytes are left for the first line of B. +* 'n' is the number of bytes left in the last line of A. +* There are two complications: lines can't be broken in the middle +* of an encoded =XX, and there might be line breaks already +\*-------------------------------------------------------------------------*/ +static int mime_global_qpwrp(lua_State *L) +{ + size_t size = 0; + int left = (int) luaL_checknumber(L, 1); + const UC *input = (UC *) luaL_optlstring(L, 2, NULL, &size); + const UC *last = input + size; + int length = (int) luaL_optnumber(L, 3, 76); + luaL_Buffer buffer; + /* end-of-input blackhole */ + if (!input) { + if (left < length) lua_pushstring(L, EQCRLF); + else lua_pushnil(L); + lua_pushnumber(L, length); + return 2; + } + /* process all input */ + luaL_buffinit(L, &buffer); + while (input < last) { + switch (*input) { + case '\r': + break; + case '\n': + left = length; + luaL_addstring(&buffer, CRLF); + break; + case '=': + if (left <= 3) { + left = length; + luaL_addstring(&buffer, EQCRLF); + } + luaL_addchar(&buffer, *input); + left--; + break; + default: + if (left <= 1) { + left = length; + luaL_addstring(&buffer, EQCRLF); + } + luaL_addchar(&buffer, *input); + left--; + break; + } + input++; + } + luaL_pushresult(&buffer); + lua_pushnumber(L, left); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Here is what we do: \n, and \r are considered candidates for line +* break. We issue *one* new line marker if any of them is seen alone, or +* followed by a different one. That is, \n\n and \r\r will issue two +* end of line markers each, but \r\n, \n\r etc will only issue *one* +* marker. This covers Mac OS, Mac OS X, VMS, Unix and DOS, as well as +* probably other more obscure conventions. +* +* c is the current character being processed +* last is the previous character +\*-------------------------------------------------------------------------*/ +#define eolcandidate(c) (c == '\r' || c == '\n') +static int eolprocess(int c, int last, const char *marker, + luaL_Buffer *buffer) +{ + if (eolcandidate(c)) { + if (eolcandidate(last)) { + if (c == last) luaL_addstring(buffer, marker); + return 0; + } else { + luaL_addstring(buffer, marker); + return c; + } + } else { + luaL_addchar(buffer, (char) c); + return 0; + } +} + +/*-------------------------------------------------------------------------*\ +* Converts a string to uniform EOL convention. +* A, n = eol(o, B, marker) +* A is the converted version of the largest prefix of B that can be +* converted unambiguously. 'o' is the context returned by the previous +* call. 'n' is the new context. +\*-------------------------------------------------------------------------*/ +static int mime_global_eol(lua_State *L) +{ + int ctx = luaL_checkint(L, 1); + size_t isize = 0; + const char *input = luaL_optlstring(L, 2, NULL, &isize); + const char *last = input + isize; + const char *marker = luaL_optstring(L, 3, CRLF); + luaL_Buffer buffer; + luaL_buffinit(L, &buffer); + /* end of input blackhole */ + if (!input) { + lua_pushnil(L); + lua_pushnumber(L, 0); + return 2; + } + /* process all input */ + while (input < last) + ctx = eolprocess(*input++, ctx, marker, &buffer); + luaL_pushresult(&buffer); + lua_pushnumber(L, ctx); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Takes one byte and stuff it if needed. +\*-------------------------------------------------------------------------*/ +static size_t dot(int c, size_t state, luaL_Buffer *buffer) +{ + luaL_addchar(buffer, (char) c); + switch (c) { + case '\r': + return 1; + case '\n': + return (state == 1)? 2: 0; + case '.': + if (state == 2) + luaL_addchar(buffer, '.'); + default: + return 0; + } +} + +/*-------------------------------------------------------------------------*\ +* Incrementally applies smtp stuffing to a string +* A, n = dot(l, D) +\*-------------------------------------------------------------------------*/ +static int mime_global_dot(lua_State *L) +{ + size_t isize = 0, state = (size_t) luaL_checknumber(L, 1); + const char *input = luaL_optlstring(L, 2, NULL, &isize); + const char *last = input + isize; + luaL_Buffer buffer; + /* end-of-input blackhole */ + if (!input) { + lua_pushnil(L); + lua_pushnumber(L, 2); + return 2; + } + /* process all input */ + luaL_buffinit(L, &buffer); + while (input < last) + state = dot(*input++, state, &buffer); + luaL_pushresult(&buffer); + lua_pushnumber(L, (lua_Number) state); + return 2; +} + diff --git a/csrc/socket/src/mime.h b/csrc/socket/src/mime.h new file mode 100644 index 0000000..99968a5 --- /dev/null +++ b/csrc/socket/src/mime.h @@ -0,0 +1,29 @@ +#ifndef MIME_H +#define MIME_H +/*=========================================================================*\ +* Core MIME support +* LuaSocket toolkit +* +* This module provides functions to implement transfer content encodings +* and formatting conforming to RFC 2045. It is used by mime.lua, which +* provide a higher level interface to this functionality. +\*=========================================================================*/ +#include "lua.h" + +/*-------------------------------------------------------------------------*\ +* Current MIME library version +\*-------------------------------------------------------------------------*/ +#define MIME_VERSION "MIME 1.0.3" +#define MIME_COPYRIGHT "Copyright (C) 2004-2013 Diego Nehab" +#define MIME_AUTHORS "Diego Nehab" + +/*-------------------------------------------------------------------------*\ +* This macro prefixes all exported API functions +\*-------------------------------------------------------------------------*/ +#ifndef MIME_API +#define MIME_API extern +#endif + +MIME_API int luaopen_mime_core(lua_State *L); + +#endif /* MIME_H */ diff --git a/csrc/socket/src/options.c b/csrc/socket/src/options.c index 5da3c51..8ac2a14 100644 --- a/csrc/socket/src/options.c +++ b/csrc/socket/src/options.c @@ -1,8 +1,6 @@ /*=========================================================================*\ * Common option interface * LuaSocket toolkit -* -* RCS ID: $Id: options.c,v 1.6 2005/11/20 07:20:23 diego Exp $ \*=========================================================================*/ #include @@ -17,9 +15,15 @@ * Internal functions prototypes \*=========================================================================*/ static int opt_setmembership(lua_State *L, p_socket ps, int level, int name); +static int opt_ip6_setmembership(lua_State *L, p_socket ps, int level, int name); static int opt_setboolean(lua_State *L, p_socket ps, int level, int name); +static int opt_getboolean(lua_State *L, p_socket ps, int level, int name); +static int opt_setint(lua_State *L, p_socket ps, int level, int name); +static int opt_getint(lua_State *L, p_socket ps, int level, int name); static int opt_set(lua_State *L, p_socket ps, int level, int name, void *val, int len); +static int opt_get(lua_State *L, p_socket ps, int level, int name, + void *val, int* len); /*=========================================================================*\ * Exported functions @@ -40,42 +44,116 @@ int opt_meth_setoption(lua_State *L, p_opt opt, p_socket ps) return opt->func(L, ps); } +int opt_meth_getoption(lua_State *L, p_opt opt, p_socket ps) +{ + const char *name = luaL_checkstring(L, 2); /* obj, name, ... */ + while (opt->name && strcmp(name, opt->name)) + opt++; + if (!opt->func) { + char msg[45]; + sprintf(msg, "unsupported option `%.35s'", name); + luaL_argerror(L, 2, msg); + } + return opt->func(L, ps); +} + /* enables reuse of local address */ -int opt_reuseaddr(lua_State *L, p_socket ps) +int opt_set_reuseaddr(lua_State *L, p_socket ps) { return opt_setboolean(L, ps, SOL_SOCKET, SO_REUSEADDR); } +int opt_get_reuseaddr(lua_State *L, p_socket ps) +{ + return opt_getboolean(L, ps, SOL_SOCKET, SO_REUSEADDR); +} + +/* enables reuse of local port */ +int opt_set_reuseport(lua_State *L, p_socket ps) +{ + return opt_setboolean(L, ps, SOL_SOCKET, SO_REUSEPORT); +} + +int opt_get_reuseport(lua_State *L, p_socket ps) +{ + return opt_getboolean(L, ps, SOL_SOCKET, SO_REUSEPORT); +} + /* disables the Naggle algorithm */ -int opt_tcp_nodelay(lua_State *L, p_socket ps) +int opt_set_tcp_nodelay(lua_State *L, p_socket ps) { return opt_setboolean(L, ps, IPPROTO_TCP, TCP_NODELAY); } -int opt_keepalive(lua_State *L, p_socket ps) +int opt_get_tcp_nodelay(lua_State *L, p_socket ps) +{ + return opt_getboolean(L, ps, IPPROTO_TCP, TCP_NODELAY); +} + +int opt_set_keepalive(lua_State *L, p_socket ps) { return opt_setboolean(L, ps, SOL_SOCKET, SO_KEEPALIVE); } -int opt_dontroute(lua_State *L, p_socket ps) +int opt_get_keepalive(lua_State *L, p_socket ps) +{ + return opt_getboolean(L, ps, SOL_SOCKET, SO_KEEPALIVE); +} + +int opt_set_dontroute(lua_State *L, p_socket ps) { return opt_setboolean(L, ps, SOL_SOCKET, SO_DONTROUTE); } -int opt_broadcast(lua_State *L, p_socket ps) +int opt_set_broadcast(lua_State *L, p_socket ps) { return opt_setboolean(L, ps, SOL_SOCKET, SO_BROADCAST); } -int opt_ip_multicast_loop(lua_State *L, p_socket ps) +int opt_set_ip6_unicast_hops(lua_State *L, p_socket ps) +{ + return opt_setint(L, ps, IPPROTO_IPV6, IPV6_UNICAST_HOPS); +} + +int opt_get_ip6_unicast_hops(lua_State *L, p_socket ps) +{ + return opt_getint(L, ps, IPPROTO_IPV6, IPV6_UNICAST_HOPS); +} + +int opt_set_ip6_multicast_hops(lua_State *L, p_socket ps) +{ + return opt_setint(L, ps, IPPROTO_IPV6, IPV6_MULTICAST_HOPS); +} + +int opt_get_ip6_multicast_hops(lua_State *L, p_socket ps) +{ + return opt_getint(L, ps, IPPROTO_IPV6, IPV6_MULTICAST_HOPS); +} + +int opt_set_ip_multicast_loop(lua_State *L, p_socket ps) { return opt_setboolean(L, ps, IPPROTO_IP, IP_MULTICAST_LOOP); } -int opt_linger(lua_State *L, p_socket ps) +int opt_get_ip_multicast_loop(lua_State *L, p_socket ps) +{ + return opt_getboolean(L, ps, IPPROTO_IP, IP_MULTICAST_LOOP); +} + +int opt_set_ip6_multicast_loop(lua_State *L, p_socket ps) +{ + return opt_setboolean(L, ps, IPPROTO_IPV6, IPV6_MULTICAST_LOOP); +} + +int opt_get_ip6_multicast_loop(lua_State *L, p_socket ps) +{ + return opt_getboolean(L, ps, IPPROTO_IPV6, IPV6_MULTICAST_LOOP); +} + +int opt_set_linger(lua_State *L, p_socket ps) { struct linger li; /* obj, name, table */ - if (!lua_istable(L, 3)) luaL_typerror(L, 3, lua_typename(L, LUA_TTABLE)); + if (!lua_istable(L, 3)) auxiliar_typeerror(L,3,lua_typename(L, LUA_TTABLE)); lua_pushstring(L, "on"); lua_gettable(L, 3); if (!lua_isboolean(L, -1)) @@ -89,29 +167,87 @@ int opt_linger(lua_State *L, p_socket ps) return opt_set(L, ps, SOL_SOCKET, SO_LINGER, (char *) &li, sizeof(li)); } -int opt_ip_multicast_ttl(lua_State *L, p_socket ps) +int opt_get_linger(lua_State *L, p_socket ps) { - int val = (int) luaL_checknumber(L, 3); /* obj, name, int */ - return opt_set(L, ps, SOL_SOCKET, SO_LINGER, (char *) &val, sizeof(val)); + struct linger li; /* obj, name */ + int len = sizeof(li); + int err = opt_get(L, ps, SOL_SOCKET, SO_LINGER, (char *) &li, &len); + if (err) + return err; + lua_newtable(L); + lua_pushboolean(L, li.l_onoff); + lua_setfield(L, -2, "on"); + lua_pushinteger(L, li.l_linger); + lua_setfield(L, -2, "timeout"); + return 1; } -int opt_ip_add_membership(lua_State *L, p_socket ps) +int opt_set_ip_multicast_ttl(lua_State *L, p_socket ps) +{ + return opt_setint(L, ps, IPPROTO_IP, IP_MULTICAST_TTL); +} + +int opt_set_ip_multicast_if(lua_State *L, p_socket ps) +{ + const char *address = luaL_checkstring(L, 3); /* obj, name, ip */ + struct in_addr val; + val.s_addr = htonl(INADDR_ANY); + if (strcmp(address, "*") && !inet_aton(address, &val)) + luaL_argerror(L, 3, "ip expected"); + return opt_set(L, ps, IPPROTO_IP, IP_MULTICAST_IF, + (char *) &val, sizeof(val)); +} + +int opt_get_ip_multicast_if(lua_State *L, p_socket ps) +{ + struct in_addr val; + socklen_t len = sizeof(val); + if (getsockopt(*ps, IPPROTO_IP, IP_MULTICAST_IF, (char *) &val, &len) < 0) { + lua_pushnil(L); + lua_pushstring(L, "getsockopt failed"); + return 2; + } + lua_pushstring(L, inet_ntoa(val)); + return 1; +} + +int opt_set_ip_add_membership(lua_State *L, p_socket ps) { return opt_setmembership(L, ps, IPPROTO_IP, IP_ADD_MEMBERSHIP); } -int opt_ip_drop_membersip(lua_State *L, p_socket ps) +int opt_set_ip_drop_membersip(lua_State *L, p_socket ps) { return opt_setmembership(L, ps, IPPROTO_IP, IP_DROP_MEMBERSHIP); } +int opt_set_ip6_add_membership(lua_State *L, p_socket ps) +{ + return opt_ip6_setmembership(L, ps, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP); +} + +int opt_set_ip6_drop_membersip(lua_State *L, p_socket ps) +{ + return opt_ip6_setmembership(L, ps, IPPROTO_IPV6, IPV6_DROP_MEMBERSHIP); +} + +int opt_get_ip6_v6only(lua_State *L, p_socket ps) +{ + return opt_getboolean(L, ps, IPPROTO_IPV6, IPV6_V6ONLY); +} + +int opt_set_ip6_v6only(lua_State *L, p_socket ps) +{ + return opt_setboolean(L, ps, IPPROTO_IPV6, IPV6_V6ONLY); +} + /*=========================================================================*\ * Auxiliar functions \*=========================================================================*/ static int opt_setmembership(lua_State *L, p_socket ps, int level, int name) { struct ip_mreq val; /* obj, name, table */ - if (!lua_istable(L, 3)) luaL_typerror(L, 3, lua_typename(L, LUA_TTABLE)); + if (!lua_istable(L, 3)) auxiliar_typeerror(L,3,lua_typename(L, LUA_TTABLE)); lua_pushstring(L, "multiaddr"); lua_gettable(L, 3); if (!lua_isstring(L, -1)) @@ -129,6 +265,45 @@ static int opt_setmembership(lua_State *L, p_socket ps, int level, int name) return opt_set(L, ps, level, name, (char *) &val, sizeof(val)); } +static int opt_ip6_setmembership(lua_State *L, p_socket ps, int level, int name) +{ + struct ipv6_mreq val; /* obj, opt-name, table */ + memset(&val, 0, sizeof(val)); + if (!lua_istable(L, 3)) auxiliar_typeerror(L,3,lua_typename(L, LUA_TTABLE)); + lua_pushstring(L, "multiaddr"); + lua_gettable(L, 3); + if (!lua_isstring(L, -1)) + luaL_argerror(L, 3, "string 'multiaddr' field expected"); + if (!inet_pton(AF_INET6, lua_tostring(L, -1), &val.ipv6mr_multiaddr)) + luaL_argerror(L, 3, "invalid 'multiaddr' ip address"); + lua_pushstring(L, "interface"); + lua_gettable(L, 3); + /* By default we listen to interface on default route + * (sigh). However, interface= can override it. We should + * support either number, or name for it. Waiting for + * windows port of if_nametoindex */ + if (!lua_isnil(L, -1)) { + if (lua_isnumber(L, -1)) { + val.ipv6mr_interface = (unsigned int) lua_tonumber(L, -1); + } else + luaL_argerror(L, -1, "number 'interface' field expected"); + } + return opt_set(L, ps, level, name, (char *) &val, sizeof(val)); +} + +static +int opt_get(lua_State *L, p_socket ps, int level, int name, void *val, int* len) +{ + socklen_t socklen = *len; + if (getsockopt(*ps, level, name, (char *) val, &socklen) < 0) { + lua_pushnil(L); + lua_pushstring(L, "getsockopt failed"); + return 2; + } + *len = socklen; + return 0; +} + static int opt_set(lua_State *L, p_socket ps, int level, int name, void *val, int len) { @@ -141,9 +316,49 @@ int opt_set(lua_State *L, p_socket ps, int level, int name, void *val, int len) return 1; } +static int opt_getboolean(lua_State *L, p_socket ps, int level, int name) +{ + int val = 0; + int len = sizeof(val); + int err = opt_get(L, ps, level, name, (char *) &val, &len); + if (err) + return err; + lua_pushboolean(L, val); + return 1; +} + +int opt_get_error(lua_State *L, p_socket ps) +{ + int val = 0; + socklen_t len = sizeof(val); + if (getsockopt(*ps, SOL_SOCKET, SO_ERROR, (char *) &val, &len) < 0) { + lua_pushnil(L); + lua_pushstring(L, "getsockopt failed"); + return 2; + } + lua_pushstring(L, socket_strerror(val)); + return 1; +} + static int opt_setboolean(lua_State *L, p_socket ps, int level, int name) { int val = auxiliar_checkboolean(L, 3); /* obj, name, bool */ return opt_set(L, ps, level, name, (char *) &val, sizeof(val)); } +static int opt_getint(lua_State *L, p_socket ps, int level, int name) +{ + int val = 0; + int len = sizeof(val); + int err = opt_get(L, ps, level, name, (char *) &val, &len); + if (err) + return err; + lua_pushnumber(L, val); + return 1; +} + +static int opt_setint(lua_State *L, p_socket ps, int level, int name) +{ + int val = (int) lua_tonumber(L, 3); /* obj, name, int */ + return opt_set(L, ps, level, name, (char *) &val, sizeof(val)); +} diff --git a/csrc/socket/src/options.h b/csrc/socket/src/options.h index 4981cf2..5657a06 100644 --- a/csrc/socket/src/options.h +++ b/csrc/socket/src/options.h @@ -6,8 +6,6 @@ * * This module provides a common interface to socket options, used mainly by * modules UDP and TCP. -* -* RCS ID: $Id: options.h,v 1.4 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ #include "lua.h" @@ -20,20 +18,43 @@ typedef struct t_opt { } t_opt; typedef t_opt *p_opt; -/* supported options */ -int opt_dontroute(lua_State *L, p_socket ps); -int opt_broadcast(lua_State *L, p_socket ps); -int opt_reuseaddr(lua_State *L, p_socket ps); -int opt_tcp_nodelay(lua_State *L, p_socket ps); -int opt_keepalive(lua_State *L, p_socket ps); -int opt_linger(lua_State *L, p_socket ps); -int opt_reuseaddr(lua_State *L, p_socket ps); -int opt_ip_multicast_ttl(lua_State *L, p_socket ps); -int opt_ip_multicast_loop(lua_State *L, p_socket ps); -int opt_ip_add_membership(lua_State *L, p_socket ps); -int opt_ip_drop_membersip(lua_State *L, p_socket ps); +/* supported options for setoption */ +int opt_set_dontroute(lua_State *L, p_socket ps); +int opt_set_broadcast(lua_State *L, p_socket ps); +int opt_set_reuseaddr(lua_State *L, p_socket ps); +int opt_set_tcp_nodelay(lua_State *L, p_socket ps); +int opt_set_keepalive(lua_State *L, p_socket ps); +int opt_set_linger(lua_State *L, p_socket ps); +int opt_set_reuseaddr(lua_State *L, p_socket ps); +int opt_set_reuseport(lua_State *L, p_socket ps); +int opt_set_ip_multicast_if(lua_State *L, p_socket ps); +int opt_set_ip_multicast_ttl(lua_State *L, p_socket ps); +int opt_set_ip_multicast_loop(lua_State *L, p_socket ps); +int opt_set_ip_add_membership(lua_State *L, p_socket ps); +int opt_set_ip_drop_membersip(lua_State *L, p_socket ps); +int opt_set_ip6_unicast_hops(lua_State *L, p_socket ps); +int opt_set_ip6_multicast_hops(lua_State *L, p_socket ps); +int opt_set_ip6_multicast_loop(lua_State *L, p_socket ps); +int opt_set_ip6_add_membership(lua_State *L, p_socket ps); +int opt_set_ip6_drop_membersip(lua_State *L, p_socket ps); +int opt_set_ip6_v6only(lua_State *L, p_socket ps); + +/* supported options for getoption */ +int opt_get_reuseaddr(lua_State *L, p_socket ps); +int opt_get_tcp_nodelay(lua_State *L, p_socket ps); +int opt_get_keepalive(lua_State *L, p_socket ps); +int opt_get_linger(lua_State *L, p_socket ps); +int opt_get_reuseaddr(lua_State *L, p_socket ps); +int opt_get_ip_multicast_loop(lua_State *L, p_socket ps); +int opt_get_ip_multicast_if(lua_State *L, p_socket ps); +int opt_get_error(lua_State *L, p_socket ps); +int opt_get_ip6_multicast_loop(lua_State *L, p_socket ps); +int opt_get_ip6_multicast_hops(lua_State *L, p_socket ps); +int opt_get_ip6_unicast_hops(lua_State *L, p_socket ps); +int opt_get_ip6_v6only(lua_State *L, p_socket ps); /* invokes the appropriate option handler */ int opt_meth_setoption(lua_State *L, p_opt opt, p_socket ps); +int opt_meth_getoption(lua_State *L, p_opt opt, p_socket ps); #endif diff --git a/csrc/socket/src/select.c b/csrc/socket/src/select.c index d70f662..fafaa62 100644 --- a/csrc/socket/src/select.c +++ b/csrc/socket/src/select.c @@ -1,8 +1,6 @@ /*=========================================================================*\ * Select implementation * LuaSocket toolkit -* -* RCS ID: $Id: select.c,v 1.22 2005/11/20 07:20:23 diego Exp $ \*=========================================================================*/ #include @@ -18,8 +16,8 @@ \*=========================================================================*/ static t_socket getfd(lua_State *L); static int dirty(lua_State *L); -static t_socket collect_fd(lua_State *L, int tab, t_socket max_fd, - int itab, fd_set *set); +static void collect_fd(lua_State *L, int tab, int itab, + fd_set *set, t_socket *max_fd); static int check_dirty(lua_State *L, int tab, int dtab, fd_set *set); static void return_fd(lua_State *L, fd_set *set, t_socket max_fd, int itab, int tab, int start); @@ -27,7 +25,7 @@ static void make_assoc(lua_State *L, int tab); static int global_select(lua_State *L); /* functions in library namespace */ -static luaL_reg func[] = { +static luaL_Reg func[] = { {"select", global_select}, {NULL, NULL} }; @@ -39,7 +37,14 @@ static luaL_reg func[] = { * Initializes module \*-------------------------------------------------------------------------*/ int select_open(lua_State *L) { + lua_pushstring(L, "_SETSIZE"); + lua_pushnumber(L, FD_SETSIZE); + lua_rawset(L, -3); +#if LUA_VERSION_NUM > 501 && !defined(LUA_COMPAT_MODULE) + luaL_setfuncs(L, func, 0); +#else luaL_openlib(L, NULL, func, 0); +#endif return 0; } @@ -51,7 +56,7 @@ int select_open(lua_State *L) { \*-------------------------------------------------------------------------*/ static int global_select(lua_State *L) { int rtab, wtab, itab, ret, ndirty; - t_socket max_fd; + t_socket max_fd = SOCKET_INVALID; fd_set rset, wset; t_timeout tm; double t = luaL_optnumber(L, 3, -1); @@ -60,12 +65,12 @@ static int global_select(lua_State *L) { lua_newtable(L); itab = lua_gettop(L); lua_newtable(L); rtab = lua_gettop(L); lua_newtable(L); wtab = lua_gettop(L); - max_fd = collect_fd(L, 1, SOCKET_INVALID, itab, &rset); + collect_fd(L, 1, itab, &rset, &max_fd); + collect_fd(L, 2, itab, &wset, &max_fd); ndirty = check_dirty(L, 1, rtab, &rset); t = ndirty > 0? 0.0: t; timeout_init(&tm, t, -1); timeout_markstart(&tm); - max_fd = collect_fd(L, 2, max_fd, itab, &wset); ret = socket_select(max_fd+1, &rset, &wset, NULL, &tm); if (ret > 0 || ndirty > 0) { return_fd(L, &rset, max_fd+1, itab, rtab, ndirty); @@ -77,7 +82,7 @@ static int global_select(lua_State *L) { lua_pushstring(L, "timeout"); return 3; } else { - lua_pushstring(L, "error"); + luaL_error(L, "select failed"); return 3; } } @@ -92,8 +97,10 @@ static t_socket getfd(lua_State *L) { if (!lua_isnil(L, -1)) { lua_pushvalue(L, -2); lua_call(L, 1, 1); - if (lua_isnumber(L, -1)) - fd = (t_socket) lua_tonumber(L, -1); + if (lua_isnumber(L, -1)) { + double numfd = lua_tonumber(L, -1); + fd = (numfd >= 0.0)? (t_socket) numfd: SOCKET_INVALID; + } } lua_pop(L, 1); return fd; @@ -112,12 +119,14 @@ static int dirty(lua_State *L) { return is; } -static t_socket collect_fd(lua_State *L, int tab, t_socket max_fd, - int itab, fd_set *set) { - int i = 1; - if (lua_isnil(L, tab)) - return max_fd; - while (1) { +static void collect_fd(lua_State *L, int tab, int itab, + fd_set *set, t_socket *max_fd) { + int i = 1, n = 0; + /* nil is the same as an empty table */ + if (lua_isnil(L, tab)) return; + /* otherwise we need it to be a table */ + luaL_checktype(L, tab, LUA_TTABLE); + for ( ;; ) { t_socket fd; lua_pushnumber(L, i); lua_gettable(L, tab); @@ -125,26 +134,37 @@ static t_socket collect_fd(lua_State *L, int tab, t_socket max_fd, lua_pop(L, 1); break; } + /* getfd figures out if this is a socket */ fd = getfd(L); if (fd != SOCKET_INVALID) { + /* make sure we don't overflow the fd_set */ +#ifdef _WIN32 + if (n >= FD_SETSIZE) + luaL_argerror(L, tab, "too many sockets"); +#else + if (fd >= FD_SETSIZE) + luaL_argerror(L, tab, "descriptor too large for set size"); +#endif FD_SET(fd, set); - if (max_fd == SOCKET_INVALID || max_fd < fd) - max_fd = fd; - lua_pushnumber(L, fd); + n++; + /* keep track of the largest descriptor so far */ + if (*max_fd == SOCKET_INVALID || *max_fd < fd) + *max_fd = fd; + /* make sure we can map back from descriptor to the object */ + lua_pushnumber(L, (lua_Number) fd); lua_pushvalue(L, -2); lua_settable(L, itab); } lua_pop(L, 1); i = i + 1; } - return max_fd; } static int check_dirty(lua_State *L, int tab, int dtab, fd_set *set) { int ndirty = 0, i = 1; if (lua_isnil(L, tab)) return 0; - while (1) { + for ( ;; ) { t_socket fd; lua_pushnumber(L, i); lua_gettable(L, tab); @@ -171,7 +191,7 @@ static void return_fd(lua_State *L, fd_set *set, t_socket max_fd, for (fd = 0; fd < max_fd; fd++) { if (FD_ISSET(fd, set)) { lua_pushnumber(L, ++start); - lua_pushnumber(L, fd); + lua_pushnumber(L, (lua_Number) fd); lua_gettable(L, itab); lua_settable(L, tab); } @@ -181,7 +201,7 @@ static void return_fd(lua_State *L, fd_set *set, t_socket max_fd, static void make_assoc(lua_State *L, int tab) { int i = 1, atab; lua_newtable(L); atab = lua_gettop(L); - while (1) { + for ( ;; ) { lua_pushnumber(L, i); lua_gettable(L, tab); if (!lua_isnil(L, -1)) { diff --git a/csrc/socket/src/select.h b/csrc/socket/src/select.h index aa3db4a..8750200 100644 --- a/csrc/socket/src/select.h +++ b/csrc/socket/src/select.h @@ -8,8 +8,6 @@ * method getfd() which returns the descriptor to be passed to the * underlying select function. Another method, dirty(), should return * true if there is data ready for reading (required for buffered input). -* -* RCS ID: $Id: select.h,v 1.7 2004/06/16 01:02:07 diego Exp $ \*=========================================================================*/ int select_open(lua_State *L); diff --git a/csrc/socket/src/serial.c b/csrc/socket/src/serial.c new file mode 100644 index 0000000..583d4e5 --- /dev/null +++ b/csrc/socket/src/serial.c @@ -0,0 +1,188 @@ +/*=========================================================================*\ +* Serial stream +* LuaSocket toolkit +\*=========================================================================*/ +#include + +#include "lua.h" +#include "lauxlib.h" + +#include "auxiliar.h" +#include "socket.h" +#include "options.h" +#include "unix.h" +#include + +/* +Reuses userdata definition from unix.h, since it is useful for all +stream-like objects. + +If we stored the serial path for use in error messages or userdata +printing, we might need our own userdata definition. + +Group usage is semi-inherited from unix.c, but unnecessary since we +have only one object type. +*/ + +/*=========================================================================*\ +* Internal function prototypes +\*=========================================================================*/ +static int global_create(lua_State *L); +static int meth_send(lua_State *L); +static int meth_receive(lua_State *L); +static int meth_close(lua_State *L); +static int meth_settimeout(lua_State *L); +static int meth_getfd(lua_State *L); +static int meth_setfd(lua_State *L); +static int meth_dirty(lua_State *L); +static int meth_getstats(lua_State *L); +static int meth_setstats(lua_State *L); + +/* serial object methods */ +static luaL_Reg serial_methods[] = { + {"__gc", meth_close}, + {"__tostring", auxiliar_tostring}, + {"close", meth_close}, + {"dirty", meth_dirty}, + {"getfd", meth_getfd}, + {"getstats", meth_getstats}, + {"setstats", meth_setstats}, + {"receive", meth_receive}, + {"send", meth_send}, + {"setfd", meth_setfd}, + {"settimeout", meth_settimeout}, + {NULL, NULL} +}; + +/* our socket creation function */ +/* this is an ad-hoc module that returns a single function + * as such, do not include other functions in this array. */ +static luaL_Reg func[] = { + {"serial", global_create}, + {NULL, NULL} +}; + + +/*-------------------------------------------------------------------------*\ +* Initializes module +\*-------------------------------------------------------------------------*/ +LUASOCKET_API int luaopen_socket_serial(lua_State *L) { + /* create classes */ + auxiliar_newclass(L, "serial{client}", serial_methods); + /* create class groups */ + auxiliar_add2group(L, "serial{client}", "serial{any}"); +#if LUA_VERSION_NUM > 501 && !defined(LUA_COMPAT_MODULE) + lua_pushcfunction(L, global_create); + (void) func; +#else + /* set function into socket namespace */ + luaL_openlib(L, "socket", func, 0); + lua_pushcfunction(L, global_create); +#endif + return 1; +} + +/*=========================================================================*\ +* Lua methods +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Just call buffered IO methods +\*-------------------------------------------------------------------------*/ +static int meth_send(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkclass(L, "serial{client}", 1); + return buffer_meth_send(L, &un->buf); +} + +static int meth_receive(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkclass(L, "serial{client}", 1); + return buffer_meth_receive(L, &un->buf); +} + +static int meth_getstats(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkclass(L, "serial{client}", 1); + return buffer_meth_getstats(L, &un->buf); +} + +static int meth_setstats(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkclass(L, "serial{client}", 1); + return buffer_meth_setstats(L, &un->buf); +} + +/*-------------------------------------------------------------------------*\ +* Select support methods +\*-------------------------------------------------------------------------*/ +static int meth_getfd(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkgroup(L, "serial{any}", 1); + lua_pushnumber(L, (int) un->sock); + return 1; +} + +/* this is very dangerous, but can be handy for those that are brave enough */ +static int meth_setfd(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkgroup(L, "serial{any}", 1); + un->sock = (t_socket) luaL_checknumber(L, 2); + return 0; +} + +static int meth_dirty(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkgroup(L, "serial{any}", 1); + lua_pushboolean(L, !buffer_isempty(&un->buf)); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Closes socket used by object +\*-------------------------------------------------------------------------*/ +static int meth_close(lua_State *L) +{ + p_unix un = (p_unix) auxiliar_checkgroup(L, "serial{any}", 1); + socket_destroy(&un->sock); + lua_pushnumber(L, 1); + return 1; +} + + +/*-------------------------------------------------------------------------*\ +* Just call tm methods +\*-------------------------------------------------------------------------*/ +static int meth_settimeout(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkgroup(L, "serial{any}", 1); + return timeout_meth_settimeout(L, &un->tm); +} + +/*=========================================================================*\ +* Library functions +\*=========================================================================*/ + + +/*-------------------------------------------------------------------------*\ +* Creates a serial object +\*-------------------------------------------------------------------------*/ +static int global_create(lua_State *L) { + const char* path = luaL_checkstring(L, 1); + + /* allocate unix object */ + p_unix un = (p_unix) lua_newuserdata(L, sizeof(t_unix)); + + /* open serial device */ + t_socket sock = open(path, O_NOCTTY|O_RDWR); + + /*printf("open %s on %d\n", path, sock);*/ + + if (sock < 0) { + lua_pushnil(L); + lua_pushstring(L, socket_strerror(errno)); + lua_pushnumber(L, errno); + return 3; + } + /* set its type as client object */ + auxiliar_setclass(L, "serial{client}", -1); + /* initialize remaining structure fields */ + socket_setnonblocking(&sock); + un->sock = sock; + io_init(&un->io, (p_send) socket_write, (p_recv) socket_read, + (p_error) socket_ioerror, &un->sock); + timeout_init(&un->tm, -1, -1); + buffer_init(&un->buf, &un->io, &un->tm); + return 1; +} diff --git a/csrc/socket/src/socket.h b/csrc/socket/src/socket.h index 656c7f5..63573de 100644 --- a/csrc/socket/src/socket.h +++ b/csrc/socket/src/socket.h @@ -8,8 +8,6 @@ * differences. Also, not all *nix platforms behave the same. This module * (and the associated usocket.h and wsocket.h) factor these differences and * creates a interface compatible with the io.h module. -* -* RCS ID: $Id: socket.h,v 1.20 2005/11/20 07:20:23 diego Exp $ \*=========================================================================*/ #include "io.h" @@ -61,6 +59,7 @@ int socket_accept(p_socket ps, p_socket pa, SA *addr, socklen_t *addr_len, p_timeout tm); const char *socket_hoststrerror(int err); +const char *socket_gaistrerror(int err); const char *socket_strerror(int err); /* these are perfect to use with the io abstraction module @@ -68,6 +67,9 @@ const char *socket_strerror(int err); int socket_send(p_socket ps, const char *data, size_t count, size_t *sent, p_timeout tm); int socket_recv(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm); +int socket_write(p_socket ps, const char *data, size_t count, + size_t *sent, p_timeout tm); +int socket_read(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm); const char *socket_ioerror(p_socket ps, int err); int socket_gethostbyaddr(const char *addr, socklen_t len, struct hostent **hp); diff --git a/csrc/socket/src/tcp.c b/csrc/socket/src/tcp.c index 6b8a79b..6594bda 100644 --- a/csrc/socket/src/tcp.c +++ b/csrc/socket/src/tcp.c @@ -1,10 +1,8 @@ /*=========================================================================*\ -* TCP object +* TCP object * LuaSocket toolkit -* -* RCS ID: $Id: tcp.c,v 1.41 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ -#include +#include #include "lua.h" #include "lauxlib.h" @@ -19,8 +17,11 @@ * Internal function prototypes \*=========================================================================*/ static int global_create(lua_State *L); +static int global_create6(lua_State *L); +static int global_connect(lua_State *L); static int meth_connect(lua_State *L); static int meth_listen(lua_State *L); +static int meth_getfamily(lua_State *L); static int meth_bind(lua_State *L); static int meth_send(lua_State *L); static int meth_getstats(lua_State *L); @@ -31,6 +32,7 @@ static int meth_shutdown(lua_State *L); static int meth_receive(lua_State *L); static int meth_accept(lua_State *L); static int meth_close(lua_State *L); +static int meth_getoption(lua_State *L); static int meth_setoption(lua_State *L); static int meth_settimeout(lua_State *L); static int meth_getfd(lua_State *L); @@ -38,7 +40,7 @@ static int meth_setfd(lua_State *L); static int meth_dirty(lua_State *L); /* tcp object methods */ -static luaL_reg tcp[] = { +static luaL_Reg tcp_methods[] = { {"__gc", meth_close}, {"__tostring", auxiliar_tostring}, {"accept", meth_accept}, @@ -46,7 +48,9 @@ static luaL_reg tcp[] = { {"close", meth_close}, {"connect", meth_connect}, {"dirty", meth_dirty}, + {"getfamily", meth_getfamily}, {"getfd", meth_getfd}, + {"getoption", meth_getoption}, {"getpeername", meth_getpeername}, {"getsockname", meth_getsockname}, {"getstats", meth_getstats}, @@ -64,17 +68,29 @@ static luaL_reg tcp[] = { }; /* socket option handlers */ -static t_opt opt[] = { - {"keepalive", opt_keepalive}, - {"reuseaddr", opt_reuseaddr}, - {"tcp-nodelay", opt_tcp_nodelay}, - {"linger", opt_linger}, +static t_opt optget[] = { + {"keepalive", opt_get_keepalive}, + {"reuseaddr", opt_get_reuseaddr}, + {"tcp-nodelay", opt_get_tcp_nodelay}, + {"linger", opt_get_linger}, + {"error", opt_get_error}, + {NULL, NULL} +}; + +static t_opt optset[] = { + {"keepalive", opt_set_keepalive}, + {"reuseaddr", opt_set_reuseaddr}, + {"tcp-nodelay", opt_set_tcp_nodelay}, + {"ipv6-v6only", opt_set_ip6_v6only}, + {"linger", opt_set_linger}, {NULL, NULL} }; /* functions in library namespace */ -static luaL_reg func[] = { +static luaL_Reg func[] = { {"tcp", global_create}, + {"tcp6", global_create6}, + {"connect", global_connect}, {NULL, NULL} }; @@ -84,15 +100,19 @@ static luaL_reg func[] = { int tcp_open(lua_State *L) { /* create classes */ - auxiliar_newclass(L, "tcp{master}", tcp); - auxiliar_newclass(L, "tcp{client}", tcp); - auxiliar_newclass(L, "tcp{server}", tcp); + auxiliar_newclass(L, "tcp{master}", tcp_methods); + auxiliar_newclass(L, "tcp{client}", tcp_methods); + auxiliar_newclass(L, "tcp{server}", tcp_methods); /* create class groups */ auxiliar_add2group(L, "tcp{master}", "tcp{any}"); auxiliar_add2group(L, "tcp{client}", "tcp{any}"); auxiliar_add2group(L, "tcp{server}", "tcp{any}"); /* define library functions */ - luaL_openlib(L, NULL, func, 0); +#if LUA_VERSION_NUM > 501 && !defined(LUA_COMPAT_MODULE) + luaL_setfuncs(L, func, 0); +#else + luaL_openlib(L, NULL, func, 0); +#endif return 0; } @@ -125,10 +145,16 @@ static int meth_setstats(lua_State *L) { /*-------------------------------------------------------------------------*\ * Just call option handler \*-------------------------------------------------------------------------*/ +static int meth_getoption(lua_State *L) +{ + p_tcp tcp = (p_tcp) auxiliar_checkgroup(L, "tcp{any}", 1); + return opt_meth_getoption(L, optget, &tcp->sock); +} + static int meth_setoption(lua_State *L) { p_tcp tcp = (p_tcp) auxiliar_checkgroup(L, "tcp{any}", 1); - return opt_meth_setoption(L, opt, &tcp->sock); + return opt_meth_setoption(L, optset, &tcp->sock); } /*-------------------------------------------------------------------------*\ @@ -145,7 +171,7 @@ static int meth_getfd(lua_State *L) static int meth_setfd(lua_State *L) { p_tcp tcp = (p_tcp) auxiliar_checkgroup(L, "tcp{any}", 1); - tcp->sock = (t_socket) luaL_checknumber(L, 2); + tcp->sock = (t_socket) luaL_checknumber(L, 2); return 0; } @@ -157,43 +183,51 @@ static int meth_dirty(lua_State *L) } /*-------------------------------------------------------------------------*\ -* Waits for and returns a client object attempting connection to the -* server object +* Waits for and returns a client object attempting connection to the +* server object \*-------------------------------------------------------------------------*/ static int meth_accept(lua_State *L) { p_tcp server = (p_tcp) auxiliar_checkclass(L, "tcp{server}", 1); p_timeout tm = timeout_markstart(&server->tm); t_socket sock; - int err = socket_accept(&server->sock, &sock, NULL, NULL, tm); + const char *err = inet_tryaccept(&server->sock, server->family, &sock, tm); /* if successful, push client socket */ - if (err == IO_DONE) { + if (err == NULL) { p_tcp clnt = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); auxiliar_setclass(L, "tcp{client}", -1); /* initialize structure fields */ + memset(clnt, 0, sizeof(t_tcp)); socket_setnonblocking(&sock); clnt->sock = sock; - io_init(&clnt->io, (p_send) socket_send, (p_recv) socket_recv, + io_init(&clnt->io, (p_send) socket_send, (p_recv) socket_recv, (p_error) socket_ioerror, &clnt->sock); timeout_init(&clnt->tm, -1, -1); buffer_init(&clnt->buf, &clnt->io, &clnt->tm); + clnt->family = server->family; return 1; } else { - lua_pushnil(L); - lua_pushstring(L, socket_strerror(err)); + lua_pushnil(L); + lua_pushstring(L, err); return 2; } } /*-------------------------------------------------------------------------*\ -* Binds an object to an address +* Binds an object to an address \*-------------------------------------------------------------------------*/ static int meth_bind(lua_State *L) { p_tcp tcp = (p_tcp) auxiliar_checkclass(L, "tcp{master}", 1); const char *address = luaL_checkstring(L, 2); - unsigned short port = (unsigned short) luaL_checknumber(L, 3); - const char *err = inet_trybind(&tcp->sock, address, port); + const char *port = luaL_checkstring(L, 3); + const char *err; + struct addrinfo bindhints; + memset(&bindhints, 0, sizeof(bindhints)); + bindhints.ai_socktype = SOCK_STREAM; + bindhints.ai_family = tcp->family; + bindhints.ai_flags = AI_PASSIVE; + err = inet_trybind(&tcp->sock, address, port, &bindhints); if (err) { lua_pushnil(L); lua_pushstring(L, err); @@ -210,9 +244,16 @@ static int meth_connect(lua_State *L) { p_tcp tcp = (p_tcp) auxiliar_checkgroup(L, "tcp{any}", 1); const char *address = luaL_checkstring(L, 2); - unsigned short port = (unsigned short) luaL_checknumber(L, 3); - p_timeout tm = timeout_markstart(&tcp->tm); - const char *err = inet_tryconnect(&tcp->sock, address, port, tm); + const char *port = luaL_checkstring(L, 3); + struct addrinfo connecthints; + const char *err; + memset(&connecthints, 0, sizeof(connecthints)); + connecthints.ai_socktype = SOCK_STREAM; + /* make sure we try to connect only to the same family */ + connecthints.ai_family = tcp->family; + timeout_markstart(&tcp->tm); + err = inet_tryconnect(&tcp->sock, &tcp->family, address, port, + &tcp->tm, &connecthints); /* have to set the class even if it failed due to non-blocking connects */ auxiliar_setclass(L, "tcp{client}", 1); if (err) { @@ -220,13 +261,12 @@ static int meth_connect(lua_State *L) lua_pushstring(L, err); return 2; } - /* turn master object into a client object */ lua_pushnumber(L, 1); return 1; } /*-------------------------------------------------------------------------*\ -* Closes socket used by object +* Closes socket used by object \*-------------------------------------------------------------------------*/ static int meth_close(lua_State *L) { @@ -236,6 +276,21 @@ static int meth_close(lua_State *L) return 1; } +/*-------------------------------------------------------------------------*\ +* Returns family as string +\*-------------------------------------------------------------------------*/ +static int meth_getfamily(lua_State *L) +{ + p_tcp tcp = (p_tcp) auxiliar_checkgroup(L, "tcp{any}", 1); + if (tcp->family == PF_INET6) { + lua_pushliteral(L, "inet6"); + return 1; + } else { + lua_pushliteral(L, "inet4"); + return 1; + } +} + /*-------------------------------------------------------------------------*\ * Puts the sockt in listen mode \*-------------------------------------------------------------------------*/ @@ -260,27 +315,13 @@ static int meth_listen(lua_State *L) \*-------------------------------------------------------------------------*/ static int meth_shutdown(lua_State *L) { + /* SHUT_RD, SHUT_WR, SHUT_RDWR have the value 0, 1, 2, so we can use method index directly */ + static const char* methods[] = { "receive", "send", "both", NULL }; p_tcp tcp = (p_tcp) auxiliar_checkclass(L, "tcp{client}", 1); - const char *how = luaL_optstring(L, 2, "both"); - switch (how[0]) { - case 'b': - if (strcmp(how, "both")) goto error; - socket_shutdown(&tcp->sock, 2); - break; - case 's': - if (strcmp(how, "send")) goto error; - socket_shutdown(&tcp->sock, 1); - break; - case 'r': - if (strcmp(how, "receive")) goto error; - socket_shutdown(&tcp->sock, 0); - break; - } + int how = luaL_checkoption(L, 2, "both", methods); + socket_shutdown(&tcp->sock, how); lua_pushnumber(L, 1); return 1; -error: - luaL_argerror(L, 2, "invalid shutdown method"); - return 0; } /*-------------------------------------------------------------------------*\ @@ -289,13 +330,13 @@ error: static int meth_getpeername(lua_State *L) { p_tcp tcp = (p_tcp) auxiliar_checkgroup(L, "tcp{any}", 1); - return inet_meth_getpeername(L, &tcp->sock); + return inet_meth_getpeername(L, &tcp->sock, tcp->family); } static int meth_getsockname(lua_State *L) { p_tcp tcp = (p_tcp) auxiliar_checkgroup(L, "tcp{any}", 1); - return inet_meth_getsockname(L, &tcp->sock); + return inet_meth_getsockname(L, &tcp->sock, tcp->family); } /*-------------------------------------------------------------------------*\ @@ -311,25 +352,31 @@ static int meth_settimeout(lua_State *L) * Library functions \*=========================================================================*/ /*-------------------------------------------------------------------------*\ -* Creates a master tcp object +* Creates a master tcp object \*-------------------------------------------------------------------------*/ -static int global_create(lua_State *L) -{ +static int tcp_create(lua_State *L, int family) { t_socket sock; - const char *err = inet_trycreate(&sock, SOCK_STREAM); + const char *err = inet_trycreate(&sock, family, SOCK_STREAM); /* try to allocate a system socket */ - if (!err) { + if (!err) { /* allocate tcp object */ p_tcp tcp = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); + memset(tcp, 0, sizeof(t_tcp)); /* set its type as master object */ auxiliar_setclass(L, "tcp{master}", -1); /* initialize remaining structure fields */ socket_setnonblocking(&sock); + if (family == PF_INET6) { + int yes = 1; + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, + (void *)&yes, sizeof(yes)); + } tcp->sock = sock; - io_init(&tcp->io, (p_send) socket_send, (p_recv) socket_recv, + io_init(&tcp->io, (p_send) socket_send, (p_recv) socket_recv, (p_error) socket_ioerror, &tcp->sock); timeout_init(&tcp->tm, -1, -1); buffer_init(&tcp->buf, &tcp->io, &tcp->tm); + tcp->family = family; return 1; } else { lua_pushnil(L); @@ -337,3 +384,106 @@ static int global_create(lua_State *L) return 2; } } + +static int global_create(lua_State *L) { + return tcp_create(L, AF_INET); +} + +static int global_create6(lua_State *L) { + return tcp_create(L, AF_INET6); +} + +#if 0 +static const char *tryconnect6(const char *remoteaddr, const char *remoteserv, + struct addrinfo *connecthints, p_tcp tcp) { + struct addrinfo *iterator = NULL, *resolved = NULL; + const char *err = NULL; + /* try resolving */ + err = socket_gaistrerror(getaddrinfo(remoteaddr, remoteserv, + connecthints, &resolved)); + if (err != NULL) { + if (resolved) freeaddrinfo(resolved); + return err; + } + /* iterate over all returned addresses trying to connect */ + for (iterator = resolved; iterator; iterator = iterator->ai_next) { + p_timeout tm = timeout_markstart(&tcp->tm); + /* create new socket if necessary. if there was no + * bind, we need to create one for every new family + * that shows up while iterating. if there was a + * bind, all families will be the same and we will + * not enter this branch. */ + if (tcp->family != iterator->ai_family) { + socket_destroy(&tcp->sock); + err = socket_strerror(socket_create(&tcp->sock, + iterator->ai_family, iterator->ai_socktype, + iterator->ai_protocol)); + if (err != NULL) { + freeaddrinfo(resolved); + return err; + } + tcp->family = iterator->ai_family; + /* all sockets initially non-blocking */ + socket_setnonblocking(&tcp->sock); + } + /* finally try connecting to remote address */ + err = socket_strerror(socket_connect(&tcp->sock, + (SA *) iterator->ai_addr, + (socklen_t) iterator->ai_addrlen, tm)); + /* if success, break out of loop */ + if (err == NULL) break; + } + + freeaddrinfo(resolved); + /* here, if err is set, we failed */ + return err; +} +#endif + +static int global_connect(lua_State *L) { + const char *remoteaddr = luaL_checkstring(L, 1); + const char *remoteserv = luaL_checkstring(L, 2); + const char *localaddr = luaL_optstring(L, 3, NULL); + const char *localserv = luaL_optstring(L, 4, "0"); + int family = inet_optfamily(L, 5, "unspec"); + p_tcp tcp = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); + struct addrinfo bindhints, connecthints; + const char *err = NULL; + /* initialize tcp structure */ + memset(tcp, 0, sizeof(t_tcp)); + io_init(&tcp->io, (p_send) socket_send, (p_recv) socket_recv, + (p_error) socket_ioerror, &tcp->sock); + timeout_init(&tcp->tm, -1, -1); + buffer_init(&tcp->buf, &tcp->io, &tcp->tm); + tcp->sock = SOCKET_INVALID; + tcp->family = PF_UNSPEC; + /* allow user to pick local address and port */ + memset(&bindhints, 0, sizeof(bindhints)); + bindhints.ai_socktype = SOCK_STREAM; + bindhints.ai_family = family; + bindhints.ai_flags = AI_PASSIVE; + if (localaddr) { + err = inet_trybind(&tcp->sock, localaddr, localserv, &bindhints); + if (err) { + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + tcp->family = bindhints.ai_family; + } + /* try to connect to remote address and port */ + memset(&connecthints, 0, sizeof(connecthints)); + connecthints.ai_socktype = SOCK_STREAM; + /* make sure we try to connect only to the same family */ + connecthints.ai_family = bindhints.ai_family; + err = inet_tryconnect(&tcp->sock, &tcp->family, remoteaddr, remoteserv, + &tcp->tm, &connecthints); + if (err) { + socket_destroy(&tcp->sock); + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + auxiliar_setclass(L, "tcp{client}", -1); + return 1; +} diff --git a/csrc/socket/src/tcp.h b/csrc/socket/src/tcp.h index 511357f..eded620 100644 --- a/csrc/socket/src/tcp.h +++ b/csrc/socket/src/tcp.h @@ -13,8 +13,6 @@ * objects are tcp objects bound to some local address. Client objects are * tcp objects either connected to some address or returned by the accept * method of a server object. -* -* RCS ID: $Id: tcp.h,v 1.7 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ #include "lua.h" @@ -27,6 +25,7 @@ typedef struct t_tcp_ { t_io io; t_buffer buf; t_timeout tm; + int family; } t_tcp; typedef t_tcp *p_tcp; diff --git a/csrc/socket/src/timeout.c b/csrc/socket/src/timeout.c index c1df102..bdd5e1c 100644 --- a/csrc/socket/src/timeout.c +++ b/csrc/socket/src/timeout.c @@ -1,10 +1,10 @@ /*=========================================================================*\ * Timeout management functions * LuaSocket toolkit -* -* RCS ID: $Id: timeout.c,v 1.30 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ #include +#include +#include #include "lua.h" #include "lauxlib.h" @@ -33,7 +33,7 @@ static int timeout_lua_gettime(lua_State *L); static int timeout_lua_sleep(lua_State *L); -static luaL_reg func[] = { +static luaL_Reg func[] = { { "gettime", timeout_lua_gettime }, { "sleep", timeout_lua_sleep }, { NULL, NULL } @@ -144,7 +144,11 @@ double timeout_gettime(void) { * Initializes module \*-------------------------------------------------------------------------*/ int timeout_open(lua_State *L) { +#if LUA_VERSION_NUM > 501 && !defined(LUA_COMPAT_MODULE) + luaL_setfuncs(L, func, 0); +#else luaL_openlib(L, NULL, func, 0); +#endif return 0; } @@ -187,13 +191,23 @@ static int timeout_lua_gettime(lua_State *L) /*-------------------------------------------------------------------------*\ * Sleep for n seconds. \*-------------------------------------------------------------------------*/ +#ifdef _WIN32 int timeout_lua_sleep(lua_State *L) { double n = luaL_checknumber(L, 1); -#ifdef _WIN32 - Sleep((int)(n*1000)); + if (n < 0.0) n = 0.0; + if (n < DBL_MAX/1000.0) n *= 1000.0; + if (n > INT_MAX) n = INT_MAX; + Sleep((int)n); + return 0; +} #else +int timeout_lua_sleep(lua_State *L) +{ + double n = luaL_checknumber(L, 1); struct timespec t, r; + if (n < 0.0) n = 0.0; + if (n > INT_MAX) n = INT_MAX; t.tv_sec = (int) n; n -= t.tv_sec; t.tv_nsec = (int) (n * 1000000000); @@ -202,6 +216,6 @@ int timeout_lua_sleep(lua_State *L) t.tv_sec = r.tv_sec; t.tv_nsec = r.tv_nsec; } -#endif return 0; } +#endif diff --git a/csrc/socket/src/timeout.h b/csrc/socket/src/timeout.h index d2d8964..6715ca7 100644 --- a/csrc/socket/src/timeout.h +++ b/csrc/socket/src/timeout.h @@ -3,8 +3,6 @@ /*=========================================================================*\ * Timeout management functions * LuaSocket toolkit -* -* RCS ID: $Id: timeout.h,v 1.14 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ #include "lua.h" diff --git a/csrc/socket/src/udp.c b/csrc/socket/src/udp.c index fc25aa0..a9f2393 100644 --- a/csrc/socket/src/udp.c +++ b/csrc/socket/src/udp.c @@ -1,10 +1,9 @@ /*=========================================================================*\ -* UDP object +* UDP object * LuaSocket toolkit -* -* RCS ID: $Id: udp.c,v 1.29 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ -#include +#include +#include #include "lua.h" #include "lauxlib.h" @@ -18,36 +17,40 @@ /* min and max macros */ #ifndef MIN #define MIN(x, y) ((x) < (y) ? x : y) -#endif +#endif #ifndef MAX #define MAX(x, y) ((x) > (y) ? x : y) -#endif +#endif /*=========================================================================*\ * Internal function prototypes \*=========================================================================*/ static int global_create(lua_State *L); +static int global_create6(lua_State *L); static int meth_send(lua_State *L); static int meth_sendto(lua_State *L); static int meth_receive(lua_State *L); static int meth_receivefrom(lua_State *L); +static int meth_getfamily(lua_State *L); static int meth_getsockname(lua_State *L); static int meth_getpeername(lua_State *L); static int meth_setsockname(lua_State *L); static int meth_setpeername(lua_State *L); static int meth_close(lua_State *L); static int meth_setoption(lua_State *L); +static int meth_getoption(lua_State *L); static int meth_settimeout(lua_State *L); static int meth_getfd(lua_State *L); static int meth_setfd(lua_State *L); static int meth_dirty(lua_State *L); /* udp object methods */ -static luaL_reg udp[] = { +static luaL_Reg udp_methods[] = { {"__gc", meth_close}, {"__tostring", auxiliar_tostring}, {"close", meth_close}, {"dirty", meth_dirty}, + {"getfamily", meth_getfamily}, {"getfd", meth_getfd}, {"getpeername", meth_getpeername}, {"getsockname", meth_getsockname}, @@ -57,27 +60,49 @@ static luaL_reg udp[] = { {"sendto", meth_sendto}, {"setfd", meth_setfd}, {"setoption", meth_setoption}, + {"getoption", meth_getoption}, {"setpeername", meth_setpeername}, {"setsockname", meth_setsockname}, {"settimeout", meth_settimeout}, {NULL, NULL} }; -/* socket options */ -static t_opt opt[] = { - {"dontroute", opt_dontroute}, - {"broadcast", opt_broadcast}, - {"reuseaddr", opt_reuseaddr}, - {"ip-multicast-ttl", opt_ip_multicast_ttl}, - {"ip-multicast-loop", opt_ip_multicast_loop}, - {"ip-add-membership", opt_ip_add_membership}, - {"ip-drop-membership", opt_ip_drop_membersip}, - {NULL, NULL} +/* socket options for setoption */ +static t_opt optset[] = { + {"dontroute", opt_set_dontroute}, + {"broadcast", opt_set_broadcast}, + {"reuseaddr", opt_set_reuseaddr}, + {"reuseport", opt_set_reuseport}, + {"ip-multicast-if", opt_set_ip_multicast_if}, + {"ip-multicast-ttl", opt_set_ip_multicast_ttl}, + {"ip-multicast-loop", opt_set_ip_multicast_loop}, + {"ip-add-membership", opt_set_ip_add_membership}, + {"ip-drop-membership", opt_set_ip_drop_membersip}, + {"ipv6-unicast-hops", opt_set_ip6_unicast_hops}, + {"ipv6-multicast-hops", opt_set_ip6_unicast_hops}, + {"ipv6-multicast-loop", opt_set_ip6_multicast_loop}, + {"ipv6-add-membership", opt_set_ip6_add_membership}, + {"ipv6-drop-membership", opt_set_ip6_drop_membersip}, + {"ipv6-v6only", opt_set_ip6_v6only}, + {NULL, NULL} +}; + +/* socket options for getoption */ +static t_opt optget[] = { + {"ip-multicast-if", opt_get_ip_multicast_if}, + {"ip-multicast-loop", opt_get_ip_multicast_loop}, + {"error", opt_get_error}, + {"ipv6-unicast-hops", opt_get_ip6_unicast_hops}, + {"ipv6-multicast-hops", opt_get_ip6_unicast_hops}, + {"ipv6-multicast-loop", opt_get_ip6_multicast_loop}, + {"ipv6-v6only", opt_get_ip6_v6only}, + {NULL, NULL} }; /* functions in library namespace */ -static luaL_reg func[] = { +static luaL_Reg func[] = { {"udp", global_create}, + {"udp6", global_create6}, {NULL, NULL} }; @@ -87,15 +112,19 @@ static luaL_reg func[] = { int udp_open(lua_State *L) { /* create classes */ - auxiliar_newclass(L, "udp{connected}", udp); - auxiliar_newclass(L, "udp{unconnected}", udp); + auxiliar_newclass(L, "udp{connected}", udp_methods); + auxiliar_newclass(L, "udp{unconnected}", udp_methods); /* create class groups */ auxiliar_add2group(L, "udp{connected}", "udp{any}"); auxiliar_add2group(L, "udp{unconnected}", "udp{any}"); auxiliar_add2group(L, "udp{connected}", "select{able}"); auxiliar_add2group(L, "udp{unconnected}", "select{able}"); /* define library functions */ - luaL_openlib(L, NULL, func, 0); +#if LUA_VERSION_NUM > 501 && !defined(LUA_COMPAT_MODULE) + luaL_setfuncs(L, func, 0); +#else + luaL_openlib(L, NULL, func, 0); +#endif return 0; } @@ -125,7 +154,7 @@ static int meth_send(lua_State *L) { lua_pushstring(L, udp_strerror(err)); return 2; } - lua_pushnumber(L, sent); + lua_pushnumber(L, (lua_Number) sent); return 1; } @@ -137,24 +166,31 @@ static int meth_sendto(lua_State *L) { size_t count, sent = 0; const char *data = luaL_checklstring(L, 2, &count); const char *ip = luaL_checkstring(L, 3); - unsigned short port = (unsigned short) luaL_checknumber(L, 4); + const char *port = luaL_checkstring(L, 4); p_timeout tm = &udp->tm; - struct sockaddr_in addr; int err; - memset(&addr, 0, sizeof(addr)); - if (!inet_aton(ip, &addr.sin_addr)) - luaL_argerror(L, 3, "invalid ip address"); - addr.sin_family = AF_INET; - addr.sin_port = htons(port); + struct addrinfo aihint; + struct addrinfo *ai; + memset(&aihint, 0, sizeof(aihint)); + aihint.ai_family = udp->family; + aihint.ai_socktype = SOCK_DGRAM; + aihint.ai_flags = AI_NUMERICHOST | AI_NUMERICSERV; + err = getaddrinfo(ip, port, &aihint, &ai); + if (err) { + lua_pushnil(L); + lua_pushstring(L, gai_strerror(err)); + return 2; + } timeout_markstart(tm); - err = socket_sendto(&udp->sock, data, count, &sent, - (SA *) &addr, sizeof(addr), tm); + err = socket_sendto(&udp->sock, data, count, &sent, ai->ai_addr, + (socklen_t) ai->ai_addrlen, tm); + freeaddrinfo(ai); if (err != IO_DONE) { lua_pushnil(L); lua_pushstring(L, udp_strerror(err)); return 2; } - lua_pushnumber(L, sent); + lua_pushnumber(L, (lua_Number) sent); return 1; } @@ -170,6 +206,9 @@ static int meth_receive(lua_State *L) { count = MIN(count, sizeof(buffer)); timeout_markstart(tm); err = socket_recv(&udp->sock, buffer, count, &got, tm); + /* Unlike TCP, recv() of zero is not closed, but a zero-length packet. */ + if (err == IO_CLOSED) + err = IO_DONE; if (err != IO_DONE) { lua_pushnil(L); lua_pushstring(L, udp_strerror(err)); @@ -182,28 +221,55 @@ static int meth_receive(lua_State *L) { /*-------------------------------------------------------------------------*\ * Receives data and sender from a UDP socket \*-------------------------------------------------------------------------*/ -static int meth_receivefrom(lua_State *L) { +static int meth_receivefrom(lua_State *L) +{ p_udp udp = (p_udp) auxiliar_checkclass(L, "udp{unconnected}", 1); - struct sockaddr_in addr; - socklen_t addr_len = sizeof(addr); char buffer[UDP_DATAGRAMSIZE]; size_t got, count = (size_t) luaL_optnumber(L, 2, sizeof(buffer)); int err; p_timeout tm = &udp->tm; + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + char addrstr[INET6_ADDRSTRLEN]; + char portstr[6]; timeout_markstart(tm); count = MIN(count, sizeof(buffer)); - err = socket_recvfrom(&udp->sock, buffer, count, &got, - (SA *) &addr, &addr_len, tm); - if (err == IO_DONE) { - lua_pushlstring(L, buffer, got); - lua_pushstring(L, inet_ntoa(addr.sin_addr)); - lua_pushnumber(L, ntohs(addr.sin_port)); - return 3; - } else { + err = socket_recvfrom(&udp->sock, buffer, count, &got, (SA *) &addr, + &addr_len, tm); + /* Unlike TCP, recv() of zero is not closed, but a zero-length packet. */ + if (err == IO_CLOSED) + err = IO_DONE; + if (err != IO_DONE) { lua_pushnil(L); lua_pushstring(L, udp_strerror(err)); return 2; } + err = getnameinfo((struct sockaddr *)&addr, addr_len, addrstr, + INET6_ADDRSTRLEN, portstr, 6, NI_NUMERICHOST | NI_NUMERICSERV); + if (err) { + lua_pushnil(L); + lua_pushstring(L, gai_strerror(err)); + return 2; + } + lua_pushlstring(L, buffer, got); + lua_pushstring(L, addrstr); + lua_pushinteger(L, (int) strtol(portstr, (char **) NULL, 10)); + return 3; +} + +/*-------------------------------------------------------------------------*\ +* Returns family as string +\*-------------------------------------------------------------------------*/ +static int meth_getfamily(lua_State *L) +{ + p_udp udp = (p_udp) auxiliar_checkgroup(L, "udp{any}", 1); + if (udp->family == PF_INET6) { + lua_pushliteral(L, "inet6"); + return 1; + } else { + lua_pushliteral(L, "inet4"); + return 1; + } } /*-------------------------------------------------------------------------*\ @@ -234,12 +300,12 @@ static int meth_dirty(lua_State *L) { \*-------------------------------------------------------------------------*/ static int meth_getpeername(lua_State *L) { p_udp udp = (p_udp) auxiliar_checkclass(L, "udp{connected}", 1); - return inet_meth_getpeername(L, &udp->sock); + return inet_meth_getpeername(L, &udp->sock, udp->family); } static int meth_getsockname(lua_State *L) { p_udp udp = (p_udp) auxiliar_checkgroup(L, "udp{any}", 1); - return inet_meth_getsockname(L, &udp->sock); + return inet_meth_getsockname(L, &udp->sock, udp->family); } /*-------------------------------------------------------------------------*\ @@ -247,7 +313,15 @@ static int meth_getsockname(lua_State *L) { \*-------------------------------------------------------------------------*/ static int meth_setoption(lua_State *L) { p_udp udp = (p_udp) auxiliar_checkgroup(L, "udp{any}", 1); - return opt_meth_setoption(L, opt, &udp->sock); + return opt_meth_setoption(L, optset, &udp->sock); +} + +/*-------------------------------------------------------------------------*\ +* Just call option handler +\*-------------------------------------------------------------------------*/ +static int meth_getoption(lua_State *L) { + p_udp udp = (p_udp) auxiliar_checkgroup(L, "udp{any}", 1); + return opt_meth_getoption(L, optget, &udp->sock); } /*-------------------------------------------------------------------------*\ @@ -264,26 +338,37 @@ static int meth_settimeout(lua_State *L) { static int meth_setpeername(lua_State *L) { p_udp udp = (p_udp) auxiliar_checkgroup(L, "udp{any}", 1); p_timeout tm = &udp->tm; - const char *address = luaL_checkstring(L, 2); + const char *address = luaL_checkstring(L, 2); int connecting = strcmp(address, "*"); - unsigned short port = connecting ? - (unsigned short) luaL_checknumber(L, 3) : - (unsigned short) luaL_optnumber(L, 3, 0); - const char *err = inet_tryconnect(&udp->sock, address, port, tm); - if (err) { - lua_pushnil(L); - lua_pushstring(L, err); - return 2; + const char *port = connecting? luaL_checkstring(L, 3): "0"; + struct addrinfo connecthints; + const char *err; + memset(&connecthints, 0, sizeof(connecthints)); + connecthints.ai_socktype = SOCK_DGRAM; + /* make sure we try to connect only to the same family */ + connecthints.ai_family = udp->family; + if (connecting) { + err = inet_tryconnect(&udp->sock, &udp->family, address, + port, tm, &connecthints); + if (err) { + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + auxiliar_setclass(L, "udp{connected}", 1); + } else { + /* we ignore possible errors because Mac OS X always + * returns EAFNOSUPPORT */ + inet_trydisconnect(&udp->sock, udp->family, tm); + auxiliar_setclass(L, "udp{unconnected}", 1); } /* change class to connected or unconnected depending on address */ - if (connecting) auxiliar_setclass(L, "udp{connected}", 1); - else auxiliar_setclass(L, "udp{unconnected}", 1); lua_pushnumber(L, 1); return 1; } /*-------------------------------------------------------------------------*\ -* Closes socket used by object +* Closes socket used by object \*-------------------------------------------------------------------------*/ static int meth_close(lua_State *L) { p_udp udp = (p_udp) auxiliar_checkgroup(L, "udp{any}", 1); @@ -298,8 +383,14 @@ static int meth_close(lua_State *L) { static int meth_setsockname(lua_State *L) { p_udp udp = (p_udp) auxiliar_checkclass(L, "udp{unconnected}", 1); const char *address = luaL_checkstring(L, 2); - unsigned short port = (unsigned short) luaL_checknumber(L, 3); - const char *err = inet_trybind(&udp->sock, address, port); + const char *port = luaL_checkstring(L, 3); + const char *err; + struct addrinfo bindhints; + memset(&bindhints, 0, sizeof(bindhints)); + bindhints.ai_socktype = SOCK_DGRAM; + bindhints.ai_family = udp->family; + bindhints.ai_flags = AI_PASSIVE; + err = inet_trybind(&udp->sock, address, port, &bindhints); if (err) { lua_pushnil(L); lua_pushstring(L, err); @@ -313,20 +404,26 @@ static int meth_setsockname(lua_State *L) { * Library functions \*=========================================================================*/ /*-------------------------------------------------------------------------*\ -* Creates a master udp object +* Creates a master udp object \*-------------------------------------------------------------------------*/ -static int global_create(lua_State *L) { +static int udp_create(lua_State *L, int family) { t_socket sock; - const char *err = inet_trycreate(&sock, SOCK_DGRAM); + const char *err = inet_trycreate(&sock, family, SOCK_DGRAM); /* try to allocate a system socket */ - if (!err) { - /* allocate tcp object */ + if (!err) { + /* allocate udp object */ p_udp udp = (p_udp) lua_newuserdata(L, sizeof(t_udp)); auxiliar_setclass(L, "udp{unconnected}", -1); /* initialize remaining structure fields */ socket_setnonblocking(&sock); + if (family == PF_INET6) { + int yes = 1; + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, + (void *)&yes, sizeof(yes)); + } udp->sock = sock; timeout_init(&udp->tm, -1, -1); + udp->family = family; return 1; } else { lua_pushnil(L); @@ -334,3 +431,11 @@ static int global_create(lua_State *L) { return 2; } } + +static int global_create(lua_State *L) { + return udp_create(L, AF_INET); +} + +static int global_create6(lua_State *L) { + return udp_create(L, AF_INET6); +} diff --git a/csrc/socket/src/udp.h b/csrc/socket/src/udp.h index 2801712..2b831a5 100644 --- a/csrc/socket/src/udp.h +++ b/csrc/socket/src/udp.h @@ -11,8 +11,6 @@ * originally unconnected. They can be "connected" to a given address * with a call to the setpeername function. The same function can be used to * break the connection. -* -* RCS ID: $Id: udp.h,v 1.10 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ #include "lua.h" @@ -25,6 +23,7 @@ typedef struct t_udp_ { t_socket sock; t_timeout tm; + int family; } t_udp; typedef t_udp *p_udp; diff --git a/csrc/socket/src/unix.c b/csrc/socket/src/unix.c new file mode 100644 index 0000000..91aaaf8 --- /dev/null +++ b/csrc/socket/src/unix.c @@ -0,0 +1,346 @@ +/*=========================================================================*\ +* Unix domain socket +* LuaSocket toolkit +\*=========================================================================*/ +#include + +#include "lua.h" +#include "lauxlib.h" + +#include "auxiliar.h" +#include "socket.h" +#include "options.h" +#include "unix.h" +#include + +/*=========================================================================*\ +* Internal function prototypes +\*=========================================================================*/ +static int global_create(lua_State *L); +static int meth_connect(lua_State *L); +static int meth_listen(lua_State *L); +static int meth_bind(lua_State *L); +static int meth_send(lua_State *L); +static int meth_shutdown(lua_State *L); +static int meth_receive(lua_State *L); +static int meth_accept(lua_State *L); +static int meth_close(lua_State *L); +static int meth_setoption(lua_State *L); +static int meth_settimeout(lua_State *L); +static int meth_getfd(lua_State *L); +static int meth_setfd(lua_State *L); +static int meth_dirty(lua_State *L); +static int meth_getstats(lua_State *L); +static int meth_setstats(lua_State *L); + +static const char *unix_tryconnect(p_unix un, const char *path); +static const char *unix_trybind(p_unix un, const char *path); + +/* unix object methods */ +static luaL_Reg unix_methods[] = { + {"__gc", meth_close}, + {"__tostring", auxiliar_tostring}, + {"accept", meth_accept}, + {"bind", meth_bind}, + {"close", meth_close}, + {"connect", meth_connect}, + {"dirty", meth_dirty}, + {"getfd", meth_getfd}, + {"getstats", meth_getstats}, + {"setstats", meth_setstats}, + {"listen", meth_listen}, + {"receive", meth_receive}, + {"send", meth_send}, + {"setfd", meth_setfd}, + {"setoption", meth_setoption}, + {"setpeername", meth_connect}, + {"setsockname", meth_bind}, + {"settimeout", meth_settimeout}, + {"shutdown", meth_shutdown}, + {NULL, NULL} +}; + +/* socket option handlers */ +static t_opt optset[] = { + {"keepalive", opt_set_keepalive}, + {"reuseaddr", opt_set_reuseaddr}, + {"linger", opt_set_linger}, + {NULL, NULL} +}; + +/* our socket creation function */ +/* this is an ad-hoc module that returns a single function + * as such, do not include other functions in this array. */ +static luaL_Reg func[] = { + {"unix", global_create}, + {NULL, NULL} +}; + + +/*-------------------------------------------------------------------------*\ +* Initializes module +\*-------------------------------------------------------------------------*/ +int luaopen_socket_unix(lua_State *L) { + /* create classes */ + auxiliar_newclass(L, "unix{master}", unix_methods); + auxiliar_newclass(L, "unix{client}", unix_methods); + auxiliar_newclass(L, "unix{server}", unix_methods); + /* create class groups */ + auxiliar_add2group(L, "unix{master}", "unix{any}"); + auxiliar_add2group(L, "unix{client}", "unix{any}"); + auxiliar_add2group(L, "unix{server}", "unix{any}"); +#if LUA_VERSION_NUM > 501 && !defined(LUA_COMPAT_MODULE) + lua_pushcfunction(L, global_create); + (void) func; +#else + /* set function into socket namespace */ + luaL_openlib(L, "socket", func, 0); + lua_pushcfunction(L, global_create); +#endif + /* return the function instead of the 'socket' table */ + return 1; +} + +/*=========================================================================*\ +* Lua methods +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Just call buffered IO methods +\*-------------------------------------------------------------------------*/ +static int meth_send(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkclass(L, "unix{client}", 1); + return buffer_meth_send(L, &un->buf); +} + +static int meth_receive(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkclass(L, "unix{client}", 1); + return buffer_meth_receive(L, &un->buf); +} + +static int meth_getstats(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkclass(L, "unix{client}", 1); + return buffer_meth_getstats(L, &un->buf); +} + +static int meth_setstats(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkclass(L, "unix{client}", 1); + return buffer_meth_setstats(L, &un->buf); +} + +/*-------------------------------------------------------------------------*\ +* Just call option handler +\*-------------------------------------------------------------------------*/ +static int meth_setoption(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkgroup(L, "unix{any}", 1); + return opt_meth_setoption(L, optset, &un->sock); +} + +/*-------------------------------------------------------------------------*\ +* Select support methods +\*-------------------------------------------------------------------------*/ +static int meth_getfd(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkgroup(L, "unix{any}", 1); + lua_pushnumber(L, (int) un->sock); + return 1; +} + +/* this is very dangerous, but can be handy for those that are brave enough */ +static int meth_setfd(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkgroup(L, "unix{any}", 1); + un->sock = (t_socket) luaL_checknumber(L, 2); + return 0; +} + +static int meth_dirty(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkgroup(L, "unix{any}", 1); + lua_pushboolean(L, !buffer_isempty(&un->buf)); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Waits for and returns a client object attempting connection to the +* server object +\*-------------------------------------------------------------------------*/ +static int meth_accept(lua_State *L) { + p_unix server = (p_unix) auxiliar_checkclass(L, "unix{server}", 1); + p_timeout tm = timeout_markstart(&server->tm); + t_socket sock; + int err = socket_accept(&server->sock, &sock, NULL, NULL, tm); + /* if successful, push client socket */ + if (err == IO_DONE) { + p_unix clnt = (p_unix) lua_newuserdata(L, sizeof(t_unix)); + auxiliar_setclass(L, "unix{client}", -1); + /* initialize structure fields */ + socket_setnonblocking(&sock); + clnt->sock = sock; + io_init(&clnt->io, (p_send)socket_send, (p_recv)socket_recv, + (p_error) socket_ioerror, &clnt->sock); + timeout_init(&clnt->tm, -1, -1); + buffer_init(&clnt->buf, &clnt->io, &clnt->tm); + return 1; + } else { + lua_pushnil(L); + lua_pushstring(L, socket_strerror(err)); + return 2; + } +} + +/*-------------------------------------------------------------------------*\ +* Binds an object to an address +\*-------------------------------------------------------------------------*/ +static const char *unix_trybind(p_unix un, const char *path) { + struct sockaddr_un local; + size_t len = strlen(path); + int err; + if (len >= sizeof(local.sun_path)) return "path too long"; + memset(&local, 0, sizeof(local)); + strcpy(local.sun_path, path); + local.sun_family = AF_UNIX; +#ifdef UNIX_HAS_SUN_LEN + local.sun_len = sizeof(local.sun_family) + sizeof(local.sun_len) + + len + 1; + err = socket_bind(&un->sock, (SA *) &local, local.sun_len); + +#else + err = socket_bind(&un->sock, (SA *) &local, + sizeof(local.sun_family) + len); +#endif + if (err != IO_DONE) socket_destroy(&un->sock); + return socket_strerror(err); +} + +static int meth_bind(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkclass(L, "unix{master}", 1); + const char *path = luaL_checkstring(L, 2); + const char *err = unix_trybind(un, path); + if (err) { + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + lua_pushnumber(L, 1); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Turns a master unix object into a client object. +\*-------------------------------------------------------------------------*/ +static const char *unix_tryconnect(p_unix un, const char *path) +{ + struct sockaddr_un remote; + int err; + size_t len = strlen(path); + if (len >= sizeof(remote.sun_path)) return "path too long"; + memset(&remote, 0, sizeof(remote)); + strcpy(remote.sun_path, path); + remote.sun_family = AF_UNIX; + timeout_markstart(&un->tm); +#ifdef UNIX_HAS_SUN_LEN + remote.sun_len = sizeof(remote.sun_family) + sizeof(remote.sun_len) + + len + 1; + err = socket_connect(&un->sock, (SA *) &remote, remote.sun_len, &un->tm); +#else + err = socket_connect(&un->sock, (SA *) &remote, + sizeof(remote.sun_family) + len, &un->tm); +#endif + if (err != IO_DONE) socket_destroy(&un->sock); + return socket_strerror(err); +} + +static int meth_connect(lua_State *L) +{ + p_unix un = (p_unix) auxiliar_checkclass(L, "unix{master}", 1); + const char *path = luaL_checkstring(L, 2); + const char *err = unix_tryconnect(un, path); + if (err) { + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + /* turn master object into a client object */ + auxiliar_setclass(L, "unix{client}", 1); + lua_pushnumber(L, 1); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Closes socket used by object +\*-------------------------------------------------------------------------*/ +static int meth_close(lua_State *L) +{ + p_unix un = (p_unix) auxiliar_checkgroup(L, "unix{any}", 1); + socket_destroy(&un->sock); + lua_pushnumber(L, 1); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Puts the sockt in listen mode +\*-------------------------------------------------------------------------*/ +static int meth_listen(lua_State *L) +{ + p_unix un = (p_unix) auxiliar_checkclass(L, "unix{master}", 1); + int backlog = (int) luaL_optnumber(L, 2, 32); + int err = socket_listen(&un->sock, backlog); + if (err != IO_DONE) { + lua_pushnil(L); + lua_pushstring(L, socket_strerror(err)); + return 2; + } + /* turn master object into a server object */ + auxiliar_setclass(L, "unix{server}", 1); + lua_pushnumber(L, 1); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Shuts the connection down partially +\*-------------------------------------------------------------------------*/ +static int meth_shutdown(lua_State *L) +{ + /* SHUT_RD, SHUT_WR, SHUT_RDWR have the value 0, 1, 2, so we can use method index directly */ + static const char* methods[] = { "receive", "send", "both", NULL }; + p_unix tcp = (p_unix) auxiliar_checkclass(L, "unix{client}", 1); + int how = luaL_checkoption(L, 2, "both", methods); + socket_shutdown(&tcp->sock, how); + lua_pushnumber(L, 1); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Just call tm methods +\*-------------------------------------------------------------------------*/ +static int meth_settimeout(lua_State *L) { + p_unix un = (p_unix) auxiliar_checkgroup(L, "unix{any}", 1); + return timeout_meth_settimeout(L, &un->tm); +} + +/*=========================================================================*\ +* Library functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Creates a master unix object +\*-------------------------------------------------------------------------*/ +static int global_create(lua_State *L) { + t_socket sock; + int err = socket_create(&sock, AF_UNIX, SOCK_STREAM, 0); + /* try to allocate a system socket */ + if (err == IO_DONE) { + /* allocate unix object */ + p_unix un = (p_unix) lua_newuserdata(L, sizeof(t_unix)); + /* set its type as master object */ + auxiliar_setclass(L, "unix{master}", -1); + /* initialize remaining structure fields */ + socket_setnonblocking(&sock); + un->sock = sock; + io_init(&un->io, (p_send) socket_send, (p_recv) socket_recv, + (p_error) socket_ioerror, &un->sock); + timeout_init(&un->tm, -1, -1); + buffer_init(&un->buf, &un->io, &un->tm); + return 1; + } else { + lua_pushnil(L); + lua_pushstring(L, socket_strerror(err)); + return 2; + } +} diff --git a/csrc/socket/src/unix.h b/csrc/socket/src/unix.h new file mode 100644 index 0000000..8cc7a79 --- /dev/null +++ b/csrc/socket/src/unix.h @@ -0,0 +1,30 @@ +#ifndef UNIX_H +#define UNIX_H +/*=========================================================================*\ +* Unix domain object +* LuaSocket toolkit +* +* This module is just an example of how to extend LuaSocket with a new +* domain. +\*=========================================================================*/ +#include "lua.h" + +#include "buffer.h" +#include "timeout.h" +#include "socket.h" + +#ifndef UNIX_API +#define UNIX_API extern +#endif + +typedef struct t_unix_ { + t_socket sock; + t_io io; + t_buffer buf; + t_timeout tm; +} t_unix; +typedef t_unix *p_unix; + +UNIX_API int luaopen_socket_unix(lua_State *L); + +#endif /* UNIX_H */ diff --git a/csrc/socket/src/usocket.c b/csrc/socket/src/usocket.c index 70c6e1e..096ecd0 100644 --- a/csrc/socket/src/usocket.c +++ b/csrc/socket/src/usocket.c @@ -5,8 +5,6 @@ * The code is now interrupt-safe. * The penalty of calling select to avoid busy-wait is only paid when * the I/O call fail in the first place. -* -* RCS ID: $Id: usocket.c,v 1.38 2007/10/13 23:55:20 diego Exp $ \*=========================================================================*/ #include #include @@ -16,7 +14,7 @@ /*-------------------------------------------------------------------------*\ * Wait for readable/writable/connected socket with timeout \*-------------------------------------------------------------------------*/ -#ifdef SOCKET_POLL +#ifndef SOCKET_SELECT #include #define WAITFD_R POLLIN @@ -30,9 +28,9 @@ int socket_waitfd(p_socket ps, int sw, p_timeout tm) { pfd.revents = 0; if (timeout_iszero(tm)) return IO_TIMEOUT; /* optimize timeout == 0 case */ do { - int t = (int)(timeout_getretry(tm)*1e3); - ret = poll(&pfd, 1, t >= 0? t: -1); - } while (ret == -1 && errno == EINTR); + int t = (int)(timeout_getretry(tm)*1e3); + ret = poll(&pfd, 1, t >= 0? t: -1); + } while (ret == -1 && errno == EINTR); if (ret == -1) return errno; if (ret == 0) return IO_TIMEOUT; if (sw == WAITFD_C && (pfd.revents & (POLLIN|POLLERR))) return IO_CLOSED; @@ -49,6 +47,7 @@ int socket_waitfd(p_socket ps, int sw, p_timeout tm) { fd_set rfds, wfds, *rp, *wp; struct timeval tv, *tp; double t; + if (*ps >= FD_SETSIZE) return EINVAL; if (timeout_iszero(tm)) return IO_TIMEOUT; /* optimize timeout == 0 case */ do { /* must set bits within loop, because select may have modifed them */ @@ -182,11 +181,7 @@ int socket_connect(p_socket ps, SA *addr, socklen_t len, p_timeout tm) { * Accept with timeout \*-------------------------------------------------------------------------*/ int socket_accept(p_socket ps, p_socket pa, SA *addr, socklen_t *len, p_timeout tm) { - SA daddr; - socklen_t dlen = sizeof(daddr); if (*ps == SOCKET_INVALID) return IO_CLOSED; - if (!addr) addr = &daddr; - if (!len) len = &dlen; for ( ;; ) { int err; if ((*pa = accept(*ps, addr, len)) != SOCKET_INVALID) return IO_DONE; @@ -213,14 +208,13 @@ int socket_send(p_socket ps, const char *data, size_t count, for ( ;; ) { long put = (long) send(*ps, data, count, 0); /* if we sent anything, we are done */ - if (put > 0) { + if (put >= 0) { *sent = put; return IO_DONE; } err = errno; - /* send can't really return 0, but EPIPE means the connection was - closed */ - if (put == 0 || err == EPIPE) return IO_CLOSED; + /* EPIPE means the connection was closed */ + if (err == EPIPE) return IO_CLOSED; /* we call was interrupted, just try again */ if (err == EINTR) continue; /* if failed fatal reason, report error */ @@ -243,12 +237,12 @@ int socket_sendto(p_socket ps, const char *data, size_t count, size_t *sent, if (*ps == SOCKET_INVALID) return IO_CLOSED; for ( ;; ) { long put = (long) sendto(*ps, data, count, 0, addr, len); - if (put > 0) { + if (put >= 0) { *sent = put; return IO_DONE; } err = errno; - if (put == 0 || err == EPIPE) return IO_CLOSED; + if (err == EPIPE) return IO_CLOSED; if (err == EINTR) continue; if (err != EAGAIN) return err; if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err; @@ -301,6 +295,66 @@ int socket_recvfrom(p_socket ps, char *data, size_t count, size_t *got, return IO_UNKNOWN; } + +/*-------------------------------------------------------------------------*\ +* Write with timeout +* +* socket_read and socket_write are cut-n-paste of socket_send and socket_recv, +* with send/recv replaced with write/read. We can't just use write/read +* in the socket version, because behaviour when size is zero is different. +\*-------------------------------------------------------------------------*/ +int socket_write(p_socket ps, const char *data, size_t count, + size_t *sent, p_timeout tm) +{ + int err; + *sent = 0; + /* avoid making system calls on closed sockets */ + if (*ps == SOCKET_INVALID) return IO_CLOSED; + /* loop until we send something or we give up on error */ + for ( ;; ) { + long put = (long) write(*ps, data, count); + /* if we sent anything, we are done */ + if (put >= 0) { + *sent = put; + return IO_DONE; + } + err = errno; + /* EPIPE means the connection was closed */ + if (err == EPIPE) return IO_CLOSED; + /* we call was interrupted, just try again */ + if (err == EINTR) continue; + /* if failed fatal reason, report error */ + if (err != EAGAIN) return err; + /* wait until we can send something or we timeout */ + if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err; + } + /* can't reach here */ + return IO_UNKNOWN; +} + +/*-------------------------------------------------------------------------*\ +* Read with timeout +* See note for socket_write +\*-------------------------------------------------------------------------*/ +int socket_read(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm) { + int err; + *got = 0; + if (*ps == SOCKET_INVALID) return IO_CLOSED; + for ( ;; ) { + long taken = (long) read(*ps, data, count); + if (taken > 0) { + *got = taken; + return IO_DONE; + } + err = errno; + if (taken == 0) return IO_CLOSED; + if (err == EINTR) continue; + if (err != EAGAIN) return err; + if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; + } + return IO_UNKNOWN; +} + /*-------------------------------------------------------------------------*\ * Put socket into blocking mode \*-------------------------------------------------------------------------*/ @@ -360,7 +414,7 @@ const char *socket_strerror(int err) { case ECONNABORTED: return "closed"; case ECONNRESET: return "closed"; case ETIMEDOUT: return "timeout"; - default: return strerror(errno); + default: return strerror(err); } } @@ -368,3 +422,28 @@ const char *socket_ioerror(p_socket ps, int err) { (void) ps; return socket_strerror(err); } + +const char *socket_gaistrerror(int err) { + if (err == 0) return NULL; + switch (err) { + case EAI_AGAIN: return "temporary failure in name resolution"; + case EAI_BADFLAGS: return "invalid value for ai_flags"; +#ifdef EAI_BADHINTS + case EAI_BADHINTS: return "invalid value for hints"; +#endif + case EAI_FAIL: return "non-recoverable failure in name resolution"; + case EAI_FAMILY: return "ai_family not supported"; + case EAI_MEMORY: return "memory allocation failure"; + case EAI_NONAME: + return "host or service not provided, or not known"; + case EAI_OVERFLOW: return "argument buffer overflow"; +#ifdef EAI_PROTOCOL + case EAI_PROTOCOL: return "resolved protocol is unknown"; +#endif + case EAI_SERVICE: return "service not supported for socket type"; + case EAI_SOCKTYPE: return "ai_socktype not supported"; + case EAI_SYSTEM: return strerror(errno); + default: return gai_strerror(err); + } +} + diff --git a/csrc/socket/src/usocket.h b/csrc/socket/src/usocket.h index f2a89aa..45f2f99 100644 --- a/csrc/socket/src/usocket.h +++ b/csrc/socket/src/usocket.h @@ -3,8 +3,6 @@ /*=========================================================================*\ * Socket compatibilization module for Unix * LuaSocket toolkit -* -* RCS ID: $Id: usocket.h,v 1.7 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ /*=========================================================================*\ @@ -31,9 +29,30 @@ #include /* TCP options (nagle algorithm disable) */ #include +#include + +#ifndef SO_REUSEPORT +#define SO_REUSEPORT SO_REUSEADDR +#endif + +/* Some platforms use IPV6_JOIN_GROUP instead if + * IPV6_ADD_MEMBERSHIP. The semantics are same, though. */ +#ifndef IPV6_ADD_MEMBERSHIP +#ifdef IPV6_JOIN_GROUP +#define IPV6_ADD_MEMBERSHIP IPV6_JOIN_GROUP +#endif /* IPV6_JOIN_GROUP */ +#endif /* !IPV6_ADD_MEMBERSHIP */ + +/* Same with IPV6_DROP_MEMBERSHIP / IPV6_LEAVE_GROUP. */ +#ifndef IPV6_DROP_MEMBERSHIP +#ifdef IPV6_LEAVE_GROUP +#define IPV6_DROP_MEMBERSHIP IPV6_LEAVE_GROUP +#endif /* IPV6_LEAVE_GROUP */ +#endif /* !IPV6_DROP_MEMBERSHIP */ typedef int t_socket; typedef t_socket *p_socket; +typedef struct sockaddr_storage t_sockaddr_storage; #define SOCKET_INVALID (-1) diff --git a/csrc/socket/src/wsocket.c b/csrc/socket/src/wsocket.c index 6022565..b4a4384 100644 --- a/csrc/socket/src/wsocket.c +++ b/csrc/socket/src/wsocket.c @@ -4,8 +4,6 @@ * * The penalty of calling select to avoid busy-wait is only paid when * the I/O call fail in the first place. -* -* RCS ID: $Id: wsocket.c,v 1.36 2007/06/11 23:44:54 diego Exp $ \*=========================================================================*/ #include @@ -54,7 +52,7 @@ int socket_waitfd(p_socket ps, int sw, p_timeout tm) { if (timeout_iszero(tm)) return IO_TIMEOUT; /* optimize timeout == 0 case */ if (sw & WAITFD_R) { FD_ZERO(&rfds); - FD_SET(*ps, &rfds); + FD_SET(*ps, &rfds); rp = &rfds; } if (sw & WAITFD_W) { FD_ZERO(&wfds); FD_SET(*ps, &wfds); wp = &wfds; } @@ -171,11 +169,7 @@ int socket_listen(p_socket ps, int backlog) { \*-------------------------------------------------------------------------*/ int socket_accept(p_socket ps, p_socket pa, SA *addr, socklen_t *len, p_timeout tm) { - SA daddr; - socklen_t dlen = sizeof(daddr); if (*ps == SOCKET_INVALID) return IO_CLOSED; - if (!addr) addr = &daddr; - if (!len) len = &dlen; for ( ;; ) { int err; /* try to get client socket */ @@ -187,8 +181,6 @@ int socket_accept(p_socket ps, p_socket pa, SA *addr, socklen_t *len, /* call select to avoid busy wait */ if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; } - /* can't reach here */ - return IO_UNKNOWN; } /*-------------------------------------------------------------------------*\ @@ -207,7 +199,7 @@ int socket_send(p_socket ps, const char *data, size_t count, /* loop until we send something or we give up on error */ for ( ;; ) { /* try to send something */ - int put = send(*ps, data, (int) count, 0); + int put = send(*ps, data, (int) count, 0); /* if we sent something, we are done */ if (put > 0) { *sent = put; @@ -220,8 +212,6 @@ int socket_send(p_socket ps, const char *data, size_t count, /* avoid busy wait */ if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err; } - /* can't reach here */ - return IO_UNKNOWN; } /*-------------------------------------------------------------------------*\ @@ -243,14 +233,15 @@ int socket_sendto(p_socket ps, const char *data, size_t count, size_t *sent, if (err != WSAEWOULDBLOCK) return err; if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err; } - return IO_UNKNOWN; } /*-------------------------------------------------------------------------*\ * Receive with timeout \*-------------------------------------------------------------------------*/ -int socket_recv(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm) { - int err; +int socket_recv(p_socket ps, char *data, size_t count, size_t *got, + p_timeout tm) +{ + int err, prev = IO_DONE; *got = 0; if (*ps == SOCKET_INVALID) return IO_CLOSED; for ( ;; ) { @@ -261,18 +252,25 @@ int socket_recv(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm } if (taken == 0) return IO_CLOSED; err = WSAGetLastError(); - if (err != WSAEWOULDBLOCK) return err; + /* On UDP, a connreset simply means the previous send failed. + * So we try again. + * On TCP, it means our socket is now useless, so the error passes. + * (We will loop again, exiting because the same error will happen) */ + if (err != WSAEWOULDBLOCK) { + if (err != WSAECONNRESET || prev == WSAECONNRESET) return err; + prev = err; + } if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; } - return IO_UNKNOWN; } /*-------------------------------------------------------------------------*\ * Recvfrom with timeout \*-------------------------------------------------------------------------*/ int socket_recvfrom(p_socket ps, char *data, size_t count, size_t *got, - SA *addr, socklen_t *len, p_timeout tm) { - int err; + SA *addr, socklen_t *len, p_timeout tm) +{ + int err, prev = IO_DONE; *got = 0; if (*ps == SOCKET_INVALID) return IO_CLOSED; for ( ;; ) { @@ -283,10 +281,16 @@ int socket_recvfrom(p_socket ps, char *data, size_t count, size_t *got, } if (taken == 0) return IO_CLOSED; err = WSAGetLastError(); - if (err != WSAEWOULDBLOCK) return err; + /* On UDP, a connreset simply means the previous send failed. + * So we try again. + * On TCP, it means our socket is now useless, so the error passes. + * (We will loop again, exiting because the same error will happen) */ + if (err != WSAEWOULDBLOCK) { + if (err != WSAECONNRESET || prev == WSAECONNRESET) return err; + prev = err; + } if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; } - return IO_UNKNOWN; } /*-------------------------------------------------------------------------*\ @@ -346,8 +350,8 @@ const char *socket_strerror(int err) { } const char *socket_ioerror(p_socket ps, int err) { - (void) ps; - return socket_strerror(err); + (void) ps; + return socket_strerror(err); } static const char *wstrerror(int err) { @@ -399,3 +403,32 @@ static const char *wstrerror(int err) { default: return "Unknown error"; } } + +const char *socket_gaistrerror(int err) { + if (err == 0) return NULL; + switch (err) { + case EAI_AGAIN: return "temporary failure in name resolution"; + case EAI_BADFLAGS: return "invalid value for ai_flags"; +#ifdef EAI_BADHINTS + case EAI_BADHINTS: return "invalid value for hints"; +#endif + case EAI_FAIL: return "non-recoverable failure in name resolution"; + case EAI_FAMILY: return "ai_family not supported"; + case EAI_MEMORY: return "memory allocation failure"; + case EAI_NONAME: + return "host or service not provided, or not known"; +#ifdef EAI_OVERFLOW + case EAI_OVERFLOW: return "argument buffer overflow"; +#endif +#ifdef EAI_PROTOCOL + case EAI_PROTOCOL: return "resolved protocol is unknown"; +#endif + case EAI_SERVICE: return "service not supported for socket type"; + case EAI_SOCKTYPE: return "ai_socktype not supported"; +#ifdef EAI_SYSTEM + case EAI_SYSTEM: return strerror(errno); +#endif + default: return gai_strerror(err); + } +} + diff --git a/csrc/socket/src/wsocket.h b/csrc/socket/src/wsocket.h index b536683..3986640 100644 --- a/csrc/socket/src/wsocket.h +++ b/csrc/socket/src/wsocket.h @@ -3,19 +3,31 @@ /*=========================================================================*\ * Socket compatibilization module for Win32 * LuaSocket toolkit -* -* RCS ID: $Id: wsocket.h,v 1.4 2005/10/07 04:40:59 diego Exp $ \*=========================================================================*/ /*=========================================================================*\ * WinSock include files \*=========================================================================*/ -#include +#include +#include typedef int socklen_t; +typedef SOCKADDR_STORAGE t_sockaddr_storage; typedef SOCKET t_socket; typedef t_socket *p_socket; +#ifndef IPV6_V6ONLY +#define IPV6_V6ONLY 27 +#endif + #define SOCKET_INVALID (INVALID_SOCKET) +#ifndef SO_REUSEPORT +#define SO_REUSEPORT SO_REUSEADDR +#endif + +#ifndef AI_NUMERICSERV +#define AI_NUMERICSERV (0) +#endif + #endif /* WSOCKET_H */ diff --git a/csrc/socket/test/README b/csrc/socket/test/README deleted file mode 100644 index 180fa27..0000000 --- a/csrc/socket/test/README +++ /dev/null @@ -1,12 +0,0 @@ -This provides the automated test scripts used to make sure the library -is working properly. - -The files provided are: - - testsrvr.lua -- test server - testclnt.lua -- test client - -To run these tests, just run lua on the server and then on the client. - -Good luck, -Diego. diff --git a/csrc/socket/test/testclnt.lua b/csrc/socket/test/testclnt.lua deleted file mode 100644 index a7ca1ba..0000000 --- a/csrc/socket/test/testclnt.lua +++ /dev/null @@ -1,713 +0,0 @@ -local socket = require"socket" - -host = host or "localhost" -port = port or "8383" - -function pass(...) - local s = string.format(unpack(arg)) - io.stderr:write(s, "\n") -end - -function fail(...) - local s = string.format(unpack(arg)) - io.stderr:write("ERROR: ", s, "!\n") -socket.sleep(3) - os.exit() -end - -function warn(...) - local s = string.format(unpack(arg)) - io.stderr:write("WARNING: ", s, "\n") -end - -function remote(...) - local s = string.format(unpack(arg)) - s = string.gsub(s, "\n", ";") - s = string.gsub(s, "%s+", " ") - s = string.gsub(s, "^%s*", "") - control:send(s .. "\n") - control:receive() -end - -function test(test) - io.stderr:write("----------------------------------------------\n", - "testing: ", test, "\n", - "----------------------------------------------\n") -end - -function check_timeout(tm, sl, elapsed, err, opp, mode, alldone) - if tm < sl then - if opp == "send" then - if not err then warn("must be buffered") - elseif err == "timeout" then pass("proper timeout") - else fail("unexpected error '%s'", err) end - else - if err ~= "timeout" then fail("should have timed out") - else pass("proper timeout") end - end - else - if mode == "total" then - if elapsed > tm then - if err ~= "timeout" then fail("should have timed out") - else pass("proper timeout") end - elseif elapsed < tm then - if err then fail(err) - else pass("ok") end - else - if alldone then - if err then fail("unexpected error '%s'", err) - else pass("ok") end - else - if err ~= "timeout" then fail(err) - else pass("proper timeoutk") end - end - end - else - if err then fail(err) - else pass("ok") end - end - end -end - -if not socket._DEBUG then - fail("Please define LUASOCKET_DEBUG and recompile LuaSocket") -end - -io.stderr:write("----------------------------------------------\n", -"LuaSocket Test Procedures\n", -"----------------------------------------------\n") - -start = socket.gettime() - -function reconnect() - io.stderr:write("attempting data connection... ") - if data then data:close() end - remote [[ - if data then data:close() data = nil end - data = server:accept() - data:setoption("tcp-nodelay", true) - ]] - data, err = socket.connect(host, port) - if not data then fail(err) - else pass("connected!") end - data:setoption("tcp-nodelay", true) -end - -pass("attempting control connection...") -control, err = socket.connect(host, port) -if err then fail(err) -else pass("connected!") end -control:setoption("tcp-nodelay", true) - ------------------------------------------------------------------------- -function test_methods(sock, methods) - for _, v in pairs(methods) do - if type(sock[v]) ~= "function" then - fail(sock.class .. " method '" .. v .. "' not registered") - end - end - pass(sock.class .. " methods are ok") -end - ------------------------------------------------------------------------- -function test_mixed(len) - reconnect() - local inter = math.ceil(len/4) - local p1 = "unix " .. string.rep("x", inter) .. "line\n" - local p2 = "dos " .. string.rep("y", inter) .. "line\r\n" - local p3 = "raw " .. string.rep("z", inter) .. "bytes" - local p4 = "end" .. string.rep("w", inter) .. "bytes" - local bp1, bp2, bp3, bp4 -remote (string.format("str = data:receive(%d)", - string.len(p1)+string.len(p2)+string.len(p3)+string.len(p4))) - sent, err = data:send(p1..p2..p3..p4) - if err then fail(err) end -remote "data:send(str); data:close()" - bp1, err = data:receive() - if err then fail(err) end - bp2, err = data:receive() - if err then fail(err) end - bp3, err = data:receive(string.len(p3)) - if err then fail(err) end - bp4, err = data:receive("*a") - if err then fail(err) end - if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 and bp4 == p4 then - pass("patterns match") - else fail("patterns don't match") end -end - ------------------------------------------------------------------------- -function test_asciiline(len) - reconnect() - local str, str10, back, err - str = string.rep("x", math.mod(len, 10)) - str10 = string.rep("aZb.c#dAe?", math.floor(len/10)) - str = str .. str10 -remote "str = data:receive()" - sent, err = data:send(str.."\n") - if err then fail(err) end -remote "data:send(str ..'\\n')" - back, err = data:receive() - if err then fail(err) end - if back == str then pass("lines match") - else fail("lines don't match") end -end - ------------------------------------------------------------------------- -function test_rawline(len) - reconnect() - local str, str10, back, err - str = string.rep(string.char(47), math.mod(len, 10)) - str10 = string.rep(string.char(120,21,77,4,5,0,7,36,44,100), - math.floor(len/10)) - str = str .. str10 -remote "str = data:receive()" - sent, err = data:send(str.."\n") - if err then fail(err) end -remote "data:send(str..'\\n')" - back, err = data:receive() - if err then fail(err) end - if back == str then pass("lines match") - else fail("lines don't match") end -end - ------------------------------------------------------------------------- -function test_raw(len) - reconnect() - local half = math.floor(len/2) - local s1, s2, back, err - s1 = string.rep("x", half) - s2 = string.rep("y", len-half) -remote (string.format("str = data:receive(%d)", len)) - sent, err = data:send(s1) - if err then fail(err) end - sent, err = data:send(s2) - if err then fail(err) end -remote "data:send(str)" - back, err = data:receive(len) - if err then fail(err) end - if back == s1..s2 then pass("blocks match") - else fail("blocks don't match") end -end - ------------------------------------------------------------------------- -function test_totaltimeoutreceive(len, tm, sl) - reconnect() - local str, err, partial - pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl) - remote (string.format ([[ - data:settimeout(%d) - str = string.rep('a', %d) - data:send(str) - print('server: sleeping for %ds') - socket.sleep(%d) - print('server: woke up') - data:send(str) - ]], 2*tm, len, sl, sl)) - data:settimeout(tm, "total") -local t = socket.gettime() - str, err, partial, elapsed = data:receive(2*len) - check_timeout(tm, sl, elapsed, err, "receive", "total", - string.len(str or partial) == 2*len) -end - ------------------------------------------------------------------------- -function test_totaltimeoutsend(len, tm, sl) - reconnect() - local str, err, total - pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl) - remote (string.format ([[ - data:settimeout(%d) - str = data:receive(%d) - print('server: sleeping for %ds') - socket.sleep(%d) - print('server: woke up') - str = data:receive(%d) - ]], 2*tm, len, sl, sl, len)) - data:settimeout(tm, "total") - str = string.rep("a", 2*len) - total, err, partial, elapsed = data:send(str) - check_timeout(tm, sl, elapsed, err, "send", "total", - total == 2*len) -end - ------------------------------------------------------------------------- -function test_blockingtimeoutreceive(len, tm, sl) - reconnect() - local str, err, partial - pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl) - remote (string.format ([[ - data:settimeout(%d) - str = string.rep('a', %d) - data:send(str) - print('server: sleeping for %ds') - socket.sleep(%d) - print('server: woke up') - data:send(str) - ]], 2*tm, len, sl, sl)) - data:settimeout(tm) - str, err, partial, elapsed = data:receive(2*len) - check_timeout(tm, sl, elapsed, err, "receive", "blocking", - string.len(str or partial) == 2*len) -end - ------------------------------------------------------------------------- -function test_blockingtimeoutsend(len, tm, sl) - reconnect() - local str, err, total - pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl) - remote (string.format ([[ - data:settimeout(%d) - str = data:receive(%d) - print('server: sleeping for %ds') - socket.sleep(%d) - print('server: woke up') - str = data:receive(%d) - ]], 2*tm, len, sl, sl, len)) - data:settimeout(tm) - str = string.rep("a", 2*len) - total, err, partial, elapsed = data:send(str) - check_timeout(tm, sl, elapsed, err, "send", "blocking", - total == 2*len) -end - ------------------------------------------------------------------------- -function empty_connect() - reconnect() - if data then data:close() data = nil end - remote [[ - if data then data:close() data = nil end - data = server:accept() - ]] - data, err = socket.connect("", port) - if not data then - pass("ok") - data = socket.connect(host, port) - else - pass("gethostbyname returns localhost on empty string...") - end -end - ------------------------------------------------------------------------- -function isclosed(c) - return c:getfd() == -1 or c:getfd() == (2^32-1) -end - -function active_close() - reconnect() - if isclosed(data) then fail("should not be closed") end - data:close() - if not isclosed(data) then fail("should be closed") end - data = nil - local udp = socket.udp() - if isclosed(udp) then fail("should not be closed") end - udp:close() - if not isclosed(udp) then fail("should be closed") end - pass("ok") -end - ------------------------------------------------------------------------- -function test_closed() - local back, partial, err - local str = 'little string' - reconnect() - pass("trying read detection") - remote (string.format ([[ - data:send('%s') - data:close() - data = nil - ]], str)) - -- try to get a line - back, err, partial = data:receive() - if not err then fail("should have gotten 'closed'.") - elseif err ~= "closed" then fail("got '"..err.."' instead of 'closed'.") - elseif str ~= partial then fail("didn't receive partial result.") - else pass("graceful 'closed' received") end - reconnect() - pass("trying write detection") - remote [[ - data:close() - data = nil - ]] - total, err, partial = data:send(string.rep("ugauga", 100000)) - if not err then - pass("failed: output buffer is at least %d bytes long!", total) - elseif err ~= "closed" then - fail("got '"..err.."' instead of 'closed'.") - else - pass("graceful 'closed' received after %d bytes were sent", partial) - end -end - ------------------------------------------------------------------------- -function test_selectbugs() - local r, s, e = socket.select(nil, nil, 0.1) - assert(type(r) == "table" and type(s) == "table" and - (e == "timeout" or e == "error")) - pass("both nil: ok") - local udp = socket.udp() - udp:close() - r, s, e = socket.select({ udp }, { udp }, 0.1) - assert(type(r) == "table" and type(s) == "table" and - (e == "timeout" or e == "error")) - pass("closed sockets: ok") - e = pcall(socket.select, "wrong", 1, 0.1) - assert(e == false) - e = pcall(socket.select, {}, 1, 0.1) - assert(e == false) - pass("invalid input: ok") -end - ------------------------------------------------------------------------- -function accept_timeout() - io.stderr:write("accept with timeout (if it hangs, it failed): ") - local s, e = socket.bind("*", 0, 0) - assert(s, e) - local t = socket.gettime() - s:settimeout(1) - local c, e = s:accept() - assert(not c, "should not accept") - assert(e == "timeout", string.format("wrong error message (%s)", e)) - t = socket.gettime() - t - assert(t < 2, string.format("took to long to give up (%gs)", t)) - s:close() - pass("good") -end - ------------------------------------------------------------------------- -function connect_timeout() - io.stderr:write("connect with timeout (if it hangs, it failed!): ") - local t = socket.gettime() - local c, e = socket.tcp() - assert(c, e) - c:settimeout(0.1) - local t = socket.gettime() - local r, e = c:connect("10.0.0.1", 81) -print(r, e) - assert(not r, "should not connect") - assert(socket.gettime() - t < 2, "took too long to give up.") - c:close() - print("ok") -end - ------------------------------------------------------------------------- -function accept_errors() - io.stderr:write("not listening: ") - local d, e = socket.bind("*", 0) - assert(d, e); - local c, e = socket.tcp(); - assert(c, e); - d:setfd(c:getfd()) - d:settimeout(2) - local r, e = d:accept() - assert(not r and e) - print("ok: ", e) - io.stderr:write("not supported: ") - local c, e = socket.udp() - assert(c, e); - d:setfd(c:getfd()) - local r, e = d:accept() - assert(not r and e) - print("ok: ", e) -end - ------------------------------------------------------------------------- -function connect_errors() - io.stderr:write("connection refused: ") - local c, e = socket.connect("localhost", 1); - assert(not c and e) - print("ok: ", e) - io.stderr:write("host not found: ") - local c, e = socket.connect("host.is.invalid", 1); - assert(not c and e, e) - print("ok: ", e) -end - ------------------------------------------------------------------------- -function rebind_test() - local c = socket.bind("localhost", 0) - local i, p = c:getsockname() - local s, e = socket.tcp() - assert(s, e) - s:setoption("reuseaddr", false) - r, e = s:bind("localhost", p) - assert(not r, "managed to rebind!") - assert(e) - print("ok: ", e) -end - ------------------------------------------------------------------------- -function getstats_test() - reconnect() - local t = 0 - for i = 1, 25 do - local c = math.random(1, 100) - remote (string.format ([[ - str = data:receive(%d) - data:send(str) - ]], c)) - data:send(string.rep("a", c)) - data:receive(c) - t = t + c - local r, s, a = data:getstats() - assert(r == t, "received count failed" .. tostring(r) - .. "/" .. tostring(t)) - assert(s == t, "sent count failed" .. tostring(s) - .. "/" .. tostring(t)) - end - print("ok") -end - - ------------------------------------------------------------------------- -function test_nonblocking(size) - reconnect() -print("Testing " .. 2*size .. " bytes") -remote(string.format([[ - data:send(string.rep("a", %d)) - socket.sleep(0.5) - data:send(string.rep("b", %d) .. "\n") -]], size, size)) - local err = "timeout" - local part = "" - local str - data:settimeout(0) - while 1 do - str, err, part = data:receive("*l", part) - if err ~= "timeout" then break end - end - assert(str == (string.rep("a", size) .. string.rep("b", size))) - reconnect() -remote(string.format([[ - str = data:receive(%d) - socket.sleep(0.5) - str = data:receive(2*%d, str) - data:send(str) -]], size, size)) - data:settimeout(0) - local start = 0 - while 1 do - ret, err, start = data:send(str, start+1) - if err ~= "timeout" then break end - end - data:send("\n") - data:settimeout(-1) - local back = data:receive(2*size) - assert(back == str, "'" .. back .. "' vs '" .. str .. "'") - print("ok") -end - ------------------------------------------------------------------------- -function test_readafterclose() - local back, partial, err - local str = 'little string' - reconnect() - pass("trying repeated '*a' pattern") - remote (string.format ([[ - data:send('%s') - data:close() - data = nil - ]], str)) - back, err, partial = data:receive("*a") - assert(back == str, "unexpected data read") - back, err, partial = data:receive("*a") - assert(back == nil and err == "closed", "should have returned 'closed'") - print("ok") - reconnect() - pass("trying active close before '*a'") - remote (string.format ([[ - data:close() - data = nil - ]])) - data:close() - back, err, partial = data:receive("*a") - assert(back == nil and err == "closed", "should have returned 'closed'") - print("ok") - reconnect() - pass("trying active close before '*l'") - remote (string.format ([[ - data:close() - data = nil - ]])) - data:close() - back, err, partial = data:receive() - assert(back == nil and err == "closed", "should have returned 'closed'") - print("ok") - reconnect() - pass("trying active close before raw 1") - remote (string.format ([[ - data:close() - data = nil - ]])) - data:close() - back, err, partial = data:receive(1) - assert(back == nil and err == "closed", "should have returned 'closed'") - print("ok") - reconnect() - pass("trying active close before raw 0") - remote (string.format ([[ - data:close() - data = nil - ]])) - data:close() - back, err, partial = data:receive(0) - assert(back == nil and err == "closed", "should have returned 'closed'") - print("ok") -end - -test("method registration") -test_methods(socket.tcp(), { - "accept", - "bind", - "close", - "connect", - "dirty", - "getfd", - "getpeername", - "getsockname", - "getstats", - "setstats", - "listen", - "receive", - "send", - "setfd", - "setoption", - "setpeername", - "setsockname", - "settimeout", - "shutdown", -}) - -test_methods(socket.udp(), { - "close", - "getpeername", - "dirty", - "getfd", - "getpeername", - "getsockname", - "receive", - "receivefrom", - "send", - "sendto", - "setfd", - "setoption", - "setpeername", - "setsockname", - "settimeout" -}) - -test("testing read after close") -test_readafterclose() - -test("select function") -test_selectbugs() - -test("connect function") -connect_timeout() -empty_connect() -connect_errors() - -test("rebinding: ") -rebind_test() - -test("active close: ") -active_close() - -test("closed connection detection: ") -test_closed() - -test("accept function: ") -accept_timeout() -accept_errors() - -test("getstats test") -getstats_test() - -test("character line") -test_asciiline(1) -test_asciiline(17) -test_asciiline(200) -test_asciiline(4091) -test_asciiline(80199) -test_asciiline(8000000) -test_asciiline(80199) -test_asciiline(4091) -test_asciiline(200) -test_asciiline(17) -test_asciiline(1) - -test("mixed patterns") -test_mixed(1) -test_mixed(17) -test_mixed(200) -test_mixed(4091) -test_mixed(801990) -test_mixed(4091) -test_mixed(200) -test_mixed(17) -test_mixed(1) - -test("binary line") -test_rawline(1) -test_rawline(17) -test_rawline(200) -test_rawline(4091) -test_rawline(80199) -test_rawline(8000000) -test_rawline(80199) -test_rawline(4091) -test_rawline(200) -test_rawline(17) -test_rawline(1) - -test("raw transfer") -test_raw(1) -test_raw(17) -test_raw(200) -test_raw(4091) -test_raw(80199) -test_raw(8000000) -test_raw(80199) -test_raw(4091) -test_raw(200) -test_raw(17) -test_raw(1) - -test("non-blocking transfer") -test_nonblocking(1) -test_nonblocking(17) -test_nonblocking(200) -test_nonblocking(4091) -test_nonblocking(80199) -test_nonblocking(800000) -test_nonblocking(80199) -test_nonblocking(4091) -test_nonblocking(200) -test_nonblocking(17) -test_nonblocking(1) - -test("total timeout on send") -test_totaltimeoutsend(800091, 1, 3) -test_totaltimeoutsend(800091, 2, 3) -test_totaltimeoutsend(800091, 5, 2) -test_totaltimeoutsend(800091, 3, 1) - -test("total timeout on receive") -test_totaltimeoutreceive(800091, 1, 3) -test_totaltimeoutreceive(800091, 2, 3) -test_totaltimeoutreceive(800091, 3, 2) -test_totaltimeoutreceive(800091, 3, 1) - -test("blocking timeout on send") -test_blockingtimeoutsend(800091, 1, 3) -test_blockingtimeoutsend(800091, 2, 3) -test_blockingtimeoutsend(800091, 3, 2) -test_blockingtimeoutsend(800091, 3, 1) - -test("blocking timeout on receive") -test_blockingtimeoutreceive(800091, 1, 3) -test_blockingtimeoutreceive(800091, 2, 3) -test_blockingtimeoutreceive(800091, 3, 2) -test_blockingtimeoutreceive(800091, 3, 1) - -test(string.format("done in %.2fs", socket.gettime() - start)) diff --git a/csrc/socket/test/testsrvr.lua b/csrc/socket/test/testsrvr.lua deleted file mode 100644 index f1972c2..0000000 --- a/csrc/socket/test/testsrvr.lua +++ /dev/null @@ -1,15 +0,0 @@ -socket = require("socket"); -host = host or "localhost"; -port = port or "8383"; -server = assert(socket.bind(host, port)); -ack = "\n"; -while 1 do - print("server: waiting for client connection..."); - control = assert(server:accept()); - while 1 do - command = assert(control:receive()); - assert(control:send(ack)); - print(command); - (loadstring(command))(); - end -end diff --git a/csrc/socket/test/testsupport.lua b/csrc/socket/test/testsupport.lua deleted file mode 100644 index acad8f5..0000000 --- a/csrc/socket/test/testsupport.lua +++ /dev/null @@ -1,37 +0,0 @@ -function readfile(name) - local f = io.open(name, "rb") - if not f then return nil end - local s = f:read("*a") - f:close() - return s -end - -function similar(s1, s2) - return string.lower(string.gsub(s1 or "", "%s", "")) == - string.lower(string.gsub(s2 or "", "%s", "")) -end - -function fail(msg) - msg = msg or "failed" - error(msg, 2) -end - -function compare(input, output) - local original = readfile(input) - local recovered = readfile(output) - if original ~= recovered then fail("comparison failed") - else print("ok") end -end - -local G = _G -local set = rawset -local warn = print - -local setglobal = function(table, key, value) - warn("changed " .. key) - set(table, key, value) -end - -setmetatable(G, { - __newindex = setglobal -}) diff --git a/ltn12.lua b/ltn12.lua new file mode 100644 index 0000000..1014de2 --- /dev/null +++ b/ltn12.lua @@ -0,0 +1,305 @@ +----------------------------------------------------------------------------- +-- LTN12 - Filters, sources, sinks and pumps. +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module +----------------------------------------------------------------------------- +local string = require("string") +local table = require("table") +local base = _G +local _M = {} +if module then -- heuristic for exporting a global package table + ltn12 = _M +end +local filter,source,sink,pump = {},{},{},{} + +_M.filter = filter +_M.source = source +_M.sink = sink +_M.pump = pump + +-- 2048 seems to be better in windows... +_M.BLOCKSIZE = 2048 +_M._VERSION = "LTN12 1.0.3" + +----------------------------------------------------------------------------- +-- Filter stuff +----------------------------------------------------------------------------- +-- returns a high level filter that cycles a low-level filter +function filter.cycle(low, ctx, extra) + base.assert(low) + return function(chunk) + local ret + ret, ctx = low(ctx, chunk, extra) + return ret + end +end + +-- chains a bunch of filters together +-- (thanks to Wim Couwenberg) +function filter.chain(...) + local arg = {...} + local n = select('#',...) + local top, index = 1, 1 + local retry = "" + return function(chunk) + retry = chunk and retry + while true do + if index == top then + chunk = arg[index](chunk) + if chunk == "" or top == n then return chunk + elseif chunk then index = index + 1 + else + top = top+1 + index = top + end + else + chunk = arg[index](chunk or "") + if chunk == "" then + index = index - 1 + chunk = retry + elseif chunk then + if index == n then return chunk + else index = index + 1 end + else base.error("filter returned inappropriate nil") end + end + end + end +end + +----------------------------------------------------------------------------- +-- Source stuff +----------------------------------------------------------------------------- +-- create an empty source +local function empty() + return nil +end + +function source.empty() + return empty +end + +-- returns a source that just outputs an error +function source.error(err) + return function() + return nil, err + end +end + +-- creates a file source +function source.file(handle, io_err) + if handle then + return function() + local chunk = handle:read(_M.BLOCKSIZE) + if not chunk then handle:close() end + return chunk + end + else return source.error(io_err or "unable to open file") end +end + +-- turns a fancy source into a simple source +function source.simplify(src) + base.assert(src) + return function() + local chunk, err_or_new = src() + src = err_or_new or src + if not chunk then return nil, err_or_new + else return chunk end + end +end + +-- creates string source +function source.string(s) + if s then + local i = 1 + return function() + local chunk = string.sub(s, i, i+_M.BLOCKSIZE-1) + i = i + _M.BLOCKSIZE + if chunk ~= "" then return chunk + else return nil end + end + else return source.empty() end +end + +-- creates rewindable source +function source.rewind(src) + base.assert(src) + local t = {} + return function(chunk) + if not chunk then + chunk = table.remove(t) + if not chunk then return src() + else return chunk end + else + table.insert(t, chunk) + end + end +end + +-- chains a source with one or several filter(s) +function source.chain(src, f, ...) + if ... then f=filter.chain(f, ...) end + base.assert(src and f) + local last_in, last_out = "", "" + local state = "feeding" + local err + return function() + if not last_out then + base.error('source is empty!', 2) + end + while true do + if state == "feeding" then + last_in, err = src() + if err then return nil, err end + last_out = f(last_in) + if not last_out then + if last_in then + base.error('filter returned inappropriate nil') + else + return nil + end + elseif last_out ~= "" then + state = "eating" + if last_in then last_in = "" end + return last_out + end + else + last_out = f(last_in) + if last_out == "" then + if last_in == "" then + state = "feeding" + else + base.error('filter returned ""') + end + elseif not last_out then + if last_in then + base.error('filter returned inappropriate nil') + else + return nil + end + else + return last_out + end + end + end + end +end + +-- creates a source that produces contents of several sources, one after the +-- other, as if they were concatenated +-- (thanks to Wim Couwenberg) +function source.cat(...) + local arg = {...} + local src = table.remove(arg, 1) + return function() + while src do + local chunk, err = src() + if chunk then return chunk end + if err then return nil, err end + src = table.remove(arg, 1) + end + end +end + +----------------------------------------------------------------------------- +-- Sink stuff +----------------------------------------------------------------------------- +-- creates a sink that stores into a table +function sink.table(t) + t = t or {} + local f = function(chunk, err) + if chunk then table.insert(t, chunk) end + return 1 + end + return f, t +end + +-- turns a fancy sink into a simple sink +function sink.simplify(snk) + base.assert(snk) + return function(chunk, err) + local ret, err_or_new = snk(chunk, err) + if not ret then return nil, err_or_new end + snk = err_or_new or snk + return 1 + end +end + +-- creates a file sink +function sink.file(handle, io_err) + if handle then + return function(chunk, err) + if not chunk then + handle:close() + return 1 + else return handle:write(chunk) end + end + else return sink.error(io_err or "unable to open file") end +end + +-- creates a sink that discards data +local function null() + return 1 +end + +function sink.null() + return null +end + +-- creates a sink that just returns an error +function sink.error(err) + return function() + return nil, err + end +end + +-- chains a sink with one or several filter(s) +function sink.chain(f, snk, ...) + if ... then + local args = { f, snk, ... } + snk = table.remove(args, #args) + f = filter.chain(unpack(args)) + end + base.assert(f and snk) + return function(chunk, err) + if chunk ~= "" then + local filtered = f(chunk) + local done = chunk and "" + while true do + local ret, snkerr = snk(filtered, err) + if not ret then return nil, snkerr end + if filtered == done then return 1 end + filtered = f(done) + end + else return 1 end + end +end + +----------------------------------------------------------------------------- +-- Pump stuff +----------------------------------------------------------------------------- +-- pumps one chunk from the source to the sink +function pump.step(src, snk) + local chunk, src_err = src() + local ret, snk_err = snk(chunk, src_err) + if chunk and ret then return 1 + else return nil, src_err or snk_err end +end + +-- pumps all data from a source to a sink, using a step function +function pump.all(src, snk, step) + base.assert(src and snk) + step = step or pump.step + while true do + local ret, err = step(src, snk) + if not ret then + if err then return nil, err + else return 1 end + end + end +end + +return _M diff --git a/mime.lua b/mime.lua new file mode 100644 index 0000000..642cd9c --- /dev/null +++ b/mime.lua @@ -0,0 +1,90 @@ +----------------------------------------------------------------------------- +-- MIME support for the Lua language. +-- Author: Diego Nehab +-- Conforming to RFCs 2045-2049 +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +----------------------------------------------------------------------------- +local base = _G +local ltn12 = require("ltn12") +local mime = require("mime.core") +local io = require("io") +local string = require("string") +local _M = mime + +-- encode, decode and wrap algorithm tables +local encodet, decodet, wrapt = {},{},{} + +_M.encodet = encodet +_M.decodet = decodet +_M.wrapt = wrapt + +-- creates a function that chooses a filter by name from a given table +local function choose(table) + return function(name, opt1, opt2) + if base.type(name) ~= "string" then + name, opt1, opt2 = "default", name, opt1 + end + local f = table[name or "nil"] + if not f then + base.error("unknown key (" .. base.tostring(name) .. ")", 3) + else return f(opt1, opt2) end + end +end + +-- define the encoding filters +encodet['base64'] = function() + return ltn12.filter.cycle(_M.b64, "") +end + +encodet['quoted-printable'] = function(mode) + return ltn12.filter.cycle(_M.qp, "", + (mode == "binary") and "=0D=0A" or "\r\n") +end + +-- define the decoding filters +decodet['base64'] = function() + return ltn12.filter.cycle(_M.unb64, "") +end + +decodet['quoted-printable'] = function() + return ltn12.filter.cycle(_M.unqp, "") +end + +local function format(chunk) + if chunk then + if chunk == "" then return "''" + else return string.len(chunk) end + else return "nil" end +end + +-- define the line-wrap filters +wrapt['text'] = function(length) + length = length or 76 + return ltn12.filter.cycle(_M.wrp, length, length) +end +wrapt['base64'] = wrapt['text'] +wrapt['default'] = wrapt['text'] + +wrapt['quoted-printable'] = function() + return ltn12.filter.cycle(_M.qpwrp, 76, 76) +end + +-- function that choose the encoding, decoding or wrap algorithm +_M.encode = choose(encodet) +_M.decode = choose(decodet) +_M.wrap = choose(wrapt) + +-- define the end-of-line normalization filter +function _M.normalize(marker) + return ltn12.filter.cycle(_M.eol, 0, marker) +end + +-- high level stuffing filter +function _M.stuff() + return ltn12.filter.cycle(_M.dot, 2) +end + +return _M \ No newline at end of file diff --git a/socket.exclude b/socket.exclude index ec0bda5..dc4e180 100644 --- a/socket.exclude +++ b/socket.exclude @@ -16,6 +16,8 @@ !/media/ !/socket* +!/mime.lua +!/ltn12.lua !/socket/ !/socket/** !/bin/mingw32/clib/socket* diff --git a/socket.lua b/socket.lua index 0b63c51..3913e6f 100644 --- a/socket.lua +++ b/socket.lua @@ -1,7 +1,6 @@ ----------------------------------------------------------------------------- -- LuaSocket helper module -- Author: Diego Nehab --- RCS ID: $Id: socket.lua,v 1.22 2005/11/22 08:33:29 diego Exp $ ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- @@ -10,38 +9,53 @@ local base = _G local string = require("string") local math = require("math") -local socket = require("socket_core") -module("socket") +local socket = require("socket.core") + +local _M = socket ----------------------------------------------------------------------------- -- Exported auxiliar functions ----------------------------------------------------------------------------- -function connect(address, port, laddress, lport) - local sock, err = socket.tcp() - if not sock then return nil, err end - if laddress then - local res, err = sock:bind(laddress, lport, -1) - if not res then return nil, err end +function _M.connect4(address, port, laddress, lport) + return socket.connect(address, port, laddress, lport, "inet") +end + +function _M.connect6(address, port, laddress, lport) + return socket.connect(address, port, laddress, lport, "inet6") +end + +function _M.bind(host, port, backlog) + if host == "*" then host = "0.0.0.0" end + local addrinfo, err = socket.dns.getaddrinfo(host); + if not addrinfo then return nil, err end + local sock, res + err = "no info on address" + for i, alt in base.ipairs(addrinfo) do + if alt.family == "inet" then + sock, err = socket.tcp() + else + sock, err = socket.tcp6() + end + if not sock then return nil, err end + sock:setoption("reuseaddr", true) + res, err = sock:bind(alt.addr, port) + if not res then + sock:close() + else + res, err = sock:listen(backlog) + if not res then + sock:close() + else + return sock + end + end end - local res, err = sock:connect(address, port) - if not res then return nil, err end - return sock + return nil, err end -function bind(host, port, backlog) - local sock, err = socket.tcp() - if not sock then return nil, err end - sock:setoption("reuseaddr", true) - local res, err = sock:bind(host, port) - if not res then return nil, err end - res, err = sock:listen(backlog) - if not res then return nil, err end - return sock -end +_M.try = _M.newtry() -try = newtry() - -function choose(table) +function _M.choose(table) return function(name, opt1, opt2) if base.type(name) ~= "string" then name, opt1, opt2 = "default", name, opt1 @@ -56,10 +70,11 @@ end -- Socket sources and sinks, conforming to LTN12 ----------------------------------------------------------------------------- -- create namespaces inside LuaSocket namespace -sourcet = {} -sinkt = {} +local sourcet, sinkt = {}, {} +_M.sourcet = sourcet +_M.sinkt = sinkt -BLOCKSIZE = 2048 +_M.BLOCKSIZE = 2048 sinkt["close-when-done"] = function(sock) return base.setmetatable({ @@ -89,7 +104,7 @@ end sinkt["default"] = sinkt["keep-open"] -sink = choose(sinkt) +_M.sink = _M.choose(sinkt) sourcet["by-length"] = function(sock, length) return base.setmetatable({ @@ -129,5 +144,6 @@ end sourcet["default"] = sourcet["until-closed"] -source = choose(sourcet) +_M.source = _M.choose(sourcet) +return _M diff --git a/socket.md b/socket.md index 9359d4d..7728da8 100644 --- a/socket.md +++ b/socket.md @@ -3,14 +3,14 @@ project: socket tagline: networking support --- -## `local socket = require'socket'` +**NOTE: This is just a distribution of luasocket. luasocket is developed [here][luasocket site].** -LuaSocket. +## `local socket = require'socket'` ## Documentation -It's [here][luasocket doc]. - - -[luasocket doc]: http://w3.impa.br/~diego/software/luasocket/reference.html +There's up-to-date luasocket documentation [here][luasocket doc]. Ignore the [official site]. +[luasocket doc]: https://rawgithub.com/diegonehab/luasocket/master/doc/index.html +[official site]: http://w3.impa.br/~diego/software/luasocket/ +[luasocket site]: https://github.com/diegonehab/luasocket diff --git a/socket/ftp.lua b/socket/ftp.lua new file mode 100644 index 0000000..ea1145b --- /dev/null +++ b/socket/ftp.lua @@ -0,0 +1,285 @@ +----------------------------------------------------------------------------- +-- FTP support for the Lua language +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +----------------------------------------------------------------------------- +local base = _G +local table = require("table") +local string = require("string") +local math = require("math") +local socket = require("socket") +local url = require("socket.url") +local tp = require("socket.tp") +local ltn12 = require("ltn12") +socket.ftp = {} +local _M = socket.ftp +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +-- timeout in seconds before the program gives up on a connection +_M.TIMEOUT = 60 +-- default port for ftp service +_M.PORT = 21 +-- this is the default anonymous password. used when no password is +-- provided in url. should be changed to your e-mail. +_M.USER = "ftp" +_M.PASSWORD = "anonymous@anonymous.org" + +----------------------------------------------------------------------------- +-- Low level FTP API +----------------------------------------------------------------------------- +local metat = { __index = {} } + +function _M.open(server, port, create) + local tp = socket.try(tp.connect(server, port or _M.PORT, _M.TIMEOUT, create)) + local f = base.setmetatable({ tp = tp }, metat) + -- make sure everything gets closed in an exception + f.try = socket.newtry(function() f:close() end) + return f +end + +function metat.__index:portconnect() + self.try(self.server:settimeout(_M.TIMEOUT)) + self.data = self.try(self.server:accept()) + self.try(self.data:settimeout(_M.TIMEOUT)) +end + +function metat.__index:pasvconnect() + self.data = self.try(socket.tcp()) + self.try(self.data:settimeout(_M.TIMEOUT)) + self.try(self.data:connect(self.pasvt.ip, self.pasvt.port)) +end + +function metat.__index:login(user, password) + self.try(self.tp:command("user", user or _M.USER)) + local code, reply = self.try(self.tp:check{"2..", 331}) + if code == 331 then + self.try(self.tp:command("pass", password or _M.PASSWORD)) + self.try(self.tp:check("2..")) + end + return 1 +end + +function metat.__index:pasv() + self.try(self.tp:command("pasv")) + local code, reply = self.try(self.tp:check("2..")) + local pattern = "(%d+)%D(%d+)%D(%d+)%D(%d+)%D(%d+)%D(%d+)" + local a, b, c, d, p1, p2 = socket.skip(2, string.find(reply, pattern)) + self.try(a and b and c and d and p1 and p2, reply) + self.pasvt = { + ip = string.format("%d.%d.%d.%d", a, b, c, d), + port = p1*256 + p2 + } + if self.server then + self.server:close() + self.server = nil + end + return self.pasvt.ip, self.pasvt.port +end + +function metat.__index:port(ip, port) + self.pasvt = nil + if not ip then + ip, port = self.try(self.tp:getcontrol():getsockname()) + self.server = self.try(socket.bind(ip, 0)) + ip, port = self.try(self.server:getsockname()) + self.try(self.server:settimeout(_M.TIMEOUT)) + end + local pl = math.mod(port, 256) + local ph = (port - pl)/256 + local arg = string.gsub(string.format("%s,%d,%d", ip, ph, pl), "%.", ",") + self.try(self.tp:command("port", arg)) + self.try(self.tp:check("2..")) + return 1 +end + +function metat.__index:send(sendt) + self.try(self.pasvt or self.server, "need port or pasv first") + -- if there is a pasvt table, we already sent a PASV command + -- we just get the data connection into self.data + if self.pasvt then self:pasvconnect() end + -- get the transfer argument and command + local argument = sendt.argument or + url.unescape(string.gsub(sendt.path or "", "^[/\\]", "")) + if argument == "" then argument = nil end + local command = sendt.command or "stor" + -- send the transfer command and check the reply + self.try(self.tp:command(command, argument)) + local code, reply = self.try(self.tp:check{"2..", "1.."}) + -- if there is not a a pasvt table, then there is a server + -- and we already sent a PORT command + if not self.pasvt then self:portconnect() end + -- get the sink, source and step for the transfer + local step = sendt.step or ltn12.pump.step + local readt = {self.tp.c} + local checkstep = function(src, snk) + -- check status in control connection while downloading + local readyt = socket.select(readt, nil, 0) + if readyt[tp] then code = self.try(self.tp:check("2..")) end + return step(src, snk) + end + local sink = socket.sink("close-when-done", self.data) + -- transfer all data and check error + self.try(ltn12.pump.all(sendt.source, sink, checkstep)) + if string.find(code, "1..") then self.try(self.tp:check("2..")) end + -- done with data connection + self.data:close() + -- find out how many bytes were sent + local sent = socket.skip(1, self.data:getstats()) + self.data = nil + return sent +end + +function metat.__index:receive(recvt) + self.try(self.pasvt or self.server, "need port or pasv first") + if self.pasvt then self:pasvconnect() end + local argument = recvt.argument or + url.unescape(string.gsub(recvt.path or "", "^[/\\]", "")) + if argument == "" then argument = nil end + local command = recvt.command or "retr" + self.try(self.tp:command(command, argument)) + local code,reply = self.try(self.tp:check{"1..", "2.."}) + if (code >= 200) and (code <= 299) then + recvt.sink(reply) + return 1 + end + if not self.pasvt then self:portconnect() end + local source = socket.source("until-closed", self.data) + local step = recvt.step or ltn12.pump.step + self.try(ltn12.pump.all(source, recvt.sink, step)) + if string.find(code, "1..") then self.try(self.tp:check("2..")) end + self.data:close() + self.data = nil + return 1 +end + +function metat.__index:cwd(dir) + self.try(self.tp:command("cwd", dir)) + self.try(self.tp:check(250)) + return 1 +end + +function metat.__index:type(type) + self.try(self.tp:command("type", type)) + self.try(self.tp:check(200)) + return 1 +end + +function metat.__index:greet() + local code = self.try(self.tp:check{"1..", "2.."}) + if string.find(code, "1..") then self.try(self.tp:check("2..")) end + return 1 +end + +function metat.__index:quit() + self.try(self.tp:command("quit")) + self.try(self.tp:check("2..")) + return 1 +end + +function metat.__index:close() + if self.data then self.data:close() end + if self.server then self.server:close() end + return self.tp:close() +end + +----------------------------------------------------------------------------- +-- High level FTP API +----------------------------------------------------------------------------- +local function override(t) + if t.url then + local u = url.parse(t.url) + for i,v in base.pairs(t) do + u[i] = v + end + return u + else return t end +end + +local function tput(putt) + putt = override(putt) + socket.try(putt.host, "missing hostname") + local f = _M.open(putt.host, putt.port, putt.create) + f:greet() + f:login(putt.user, putt.password) + if putt.type then f:type(putt.type) end + f:pasv() + local sent = f:send(putt) + f:quit() + f:close() + return sent +end + +local default = { + path = "/", + scheme = "ftp" +} + +local function parse(u) + local t = socket.try(url.parse(u, default)) + socket.try(t.scheme == "ftp", "wrong scheme '" .. t.scheme .. "'") + socket.try(t.host, "missing hostname") + local pat = "^type=(.)$" + if t.params then + t.type = socket.skip(2, string.find(t.params, pat)) + socket.try(t.type == "a" or t.type == "i", + "invalid type '" .. t.type .. "'") + end + return t +end + +local function sput(u, body) + local putt = parse(u) + putt.source = ltn12.source.string(body) + return tput(putt) +end + +_M.put = socket.protect(function(putt, body) + if base.type(putt) == "string" then return sput(putt, body) + else return tput(putt) end +end) + +local function tget(gett) + gett = override(gett) + socket.try(gett.host, "missing hostname") + local f = _M.open(gett.host, gett.port, gett.create) + f:greet() + f:login(gett.user, gett.password) + if gett.type then f:type(gett.type) end + f:pasv() + f:receive(gett) + f:quit() + return f:close() +end + +local function sget(u) + local gett = parse(u) + local t = {} + gett.sink = ltn12.sink.table(t) + tget(gett) + return table.concat(t) +end + +_M.command = socket.protect(function(cmdt) + cmdt = override(cmdt) + socket.try(cmdt.host, "missing hostname") + socket.try(cmdt.command, "missing command") + local f = open(cmdt.host, cmdt.port, cmdt.create) + f:greet() + f:login(cmdt.user, cmdt.password) + f.try(f.tp:command(cmdt.command, cmdt.argument)) + if cmdt.check then f.try(f.tp:check(cmdt.check)) end + f:quit() + return f:close() +end) + +_M.get = socket.protect(function(gett) + if base.type(gett) == "string" then return sget(gett) + else return tget(gett) end +end) + +return _M \ No newline at end of file diff --git a/socket/headers.lua b/socket/headers.lua new file mode 100644 index 0000000..1eb8223 --- /dev/null +++ b/socket/headers.lua @@ -0,0 +1,104 @@ +----------------------------------------------------------------------------- +-- Canonic header field capitalization +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- +local socket = require("socket") +socket.headers = {} +local _M = socket.headers + +_M.canonic = { + ["accept"] = "Accept", + ["accept-charset"] = "Accept-Charset", + ["accept-encoding"] = "Accept-Encoding", + ["accept-language"] = "Accept-Language", + ["accept-ranges"] = "Accept-Ranges", + ["action"] = "Action", + ["alternate-recipient"] = "Alternate-Recipient", + ["age"] = "Age", + ["allow"] = "Allow", + ["arrival-date"] = "Arrival-Date", + ["authorization"] = "Authorization", + ["bcc"] = "Bcc", + ["cache-control"] = "Cache-Control", + ["cc"] = "Cc", + ["comments"] = "Comments", + ["connection"] = "Connection", + ["content-description"] = "Content-Description", + ["content-disposition"] = "Content-Disposition", + ["content-encoding"] = "Content-Encoding", + ["content-id"] = "Content-ID", + ["content-language"] = "Content-Language", + ["content-length"] = "Content-Length", + ["content-location"] = "Content-Location", + ["content-md5"] = "Content-MD5", + ["content-range"] = "Content-Range", + ["content-transfer-encoding"] = "Content-Transfer-Encoding", + ["content-type"] = "Content-Type", + ["cookie"] = "Cookie", + ["date"] = "Date", + ["diagnostic-code"] = "Diagnostic-Code", + ["dsn-gateway"] = "DSN-Gateway", + ["etag"] = "ETag", + ["expect"] = "Expect", + ["expires"] = "Expires", + ["final-log-id"] = "Final-Log-ID", + ["final-recipient"] = "Final-Recipient", + ["from"] = "From", + ["host"] = "Host", + ["if-match"] = "If-Match", + ["if-modified-since"] = "If-Modified-Since", + ["if-none-match"] = "If-None-Match", + ["if-range"] = "If-Range", + ["if-unmodified-since"] = "If-Unmodified-Since", + ["in-reply-to"] = "In-Reply-To", + ["keywords"] = "Keywords", + ["last-attempt-date"] = "Last-Attempt-Date", + ["last-modified"] = "Last-Modified", + ["location"] = "Location", + ["max-forwards"] = "Max-Forwards", + ["message-id"] = "Message-ID", + ["mime-version"] = "MIME-Version", + ["original-envelope-id"] = "Original-Envelope-ID", + ["original-recipient"] = "Original-Recipient", + ["pragma"] = "Pragma", + ["proxy-authenticate"] = "Proxy-Authenticate", + ["proxy-authorization"] = "Proxy-Authorization", + ["range"] = "Range", + ["received"] = "Received", + ["received-from-mta"] = "Received-From-MTA", + ["references"] = "References", + ["referer"] = "Referer", + ["remote-mta"] = "Remote-MTA", + ["reply-to"] = "Reply-To", + ["reporting-mta"] = "Reporting-MTA", + ["resent-bcc"] = "Resent-Bcc", + ["resent-cc"] = "Resent-Cc", + ["resent-date"] = "Resent-Date", + ["resent-from"] = "Resent-From", + ["resent-message-id"] = "Resent-Message-ID", + ["resent-reply-to"] = "Resent-Reply-To", + ["resent-sender"] = "Resent-Sender", + ["resent-to"] = "Resent-To", + ["retry-after"] = "Retry-After", + ["return-path"] = "Return-Path", + ["sender"] = "Sender", + ["server"] = "Server", + ["smtp-remote-recipient"] = "SMTP-Remote-Recipient", + ["status"] = "Status", + ["subject"] = "Subject", + ["te"] = "TE", + ["to"] = "To", + ["trailer"] = "Trailer", + ["transfer-encoding"] = "Transfer-Encoding", + ["upgrade"] = "Upgrade", + ["user-agent"] = "User-Agent", + ["vary"] = "Vary", + ["via"] = "Via", + ["warning"] = "Warning", + ["will-retry-until"] = "Will-Retry-Until", + ["www-authenticate"] = "WWW-Authenticate", + ["x-mailer"] = "X-Mailer", +} + +return _M \ No newline at end of file diff --git a/socket/http.lua b/socket/http.lua new file mode 100644 index 0000000..1d0eb50 --- /dev/null +++ b/socket/http.lua @@ -0,0 +1,356 @@ +----------------------------------------------------------------------------- +-- HTTP/1.1 client support for the Lua language. +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +------------------------------------------------------------------------------- +local socket = require("socket") +local url = require("socket.url") +local ltn12 = require("ltn12") +local mime = require("mime") +local string = require("string") +local headers = require("socket.headers") +local base = _G +local table = require("table") +socket.http = {} +local _M = socket.http + +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +-- connection timeout in seconds +_M.TIMEOUT = 60 +-- default port for document retrieval +_M.PORT = 80 +-- user agent field sent in request +_M.USERAGENT = socket._VERSION + +----------------------------------------------------------------------------- +-- Reads MIME headers from a connection, unfolding where needed +----------------------------------------------------------------------------- +local function receiveheaders(sock, headers) + local line, name, value, err + headers = headers or {} + -- get first line + line, err = sock:receive() + if err then return nil, err end + -- headers go until a blank line is found + while line ~= "" do + -- get field-name and value + name, value = socket.skip(2, string.find(line, "^(.-):%s*(.*)")) + if not (name and value) then return nil, "malformed reponse headers" end + name = string.lower(name) + -- get next line (value might be folded) + line, err = sock:receive() + if err then return nil, err end + -- unfold any folded values + while string.find(line, "^%s") do + value = value .. line + line = sock:receive() + if err then return nil, err end + end + -- save pair in table + if headers[name] then headers[name] = headers[name] .. ", " .. value + else headers[name] = value end + end + return headers +end + +----------------------------------------------------------------------------- +-- Extra sources and sinks +----------------------------------------------------------------------------- +socket.sourcet["http-chunked"] = function(sock, headers) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function() + -- get chunk size, skip extention + local line, err = sock:receive() + if err then return nil, err end + local size = base.tonumber(string.gsub(line, ";.*", ""), 16) + if not size then return nil, "invalid chunk size" end + -- was it the last chunk? + if size > 0 then + -- if not, get chunk and skip terminating CRLF + local chunk, err, part = sock:receive(size) + if chunk then sock:receive() end + return chunk, err + else + -- if it was, read trailers into headers table + headers, err = receiveheaders(sock, headers) + if not headers then return nil, err end + end + end + }) +end + +socket.sinkt["http-chunked"] = function(sock) + return base.setmetatable({ + getfd = function() return sock:getfd() end, + dirty = function() return sock:dirty() end + }, { + __call = function(self, chunk, err) + if not chunk then return sock:send("0\r\n\r\n") end + local size = string.format("%X\r\n", string.len(chunk)) + return sock:send(size .. chunk .. "\r\n") + end + }) +end + +----------------------------------------------------------------------------- +-- Low level HTTP API +----------------------------------------------------------------------------- +local metat = { __index = {} } + +function _M.open(host, port, create) + -- create socket with user connect function, or with default + local c = socket.try((create or socket.tcp)()) + local h = base.setmetatable({ c = c }, metat) + -- create finalized try + h.try = socket.newtry(function() h:close() end) + -- set timeout before connecting + h.try(c:settimeout(_M.TIMEOUT)) + h.try(c:connect(host, port or _M.PORT)) + -- here everything worked + return h +end + +function metat.__index:sendrequestline(method, uri) + local reqline = string.format("%s %s HTTP/1.1\r\n", method or "GET", uri) + return self.try(self.c:send(reqline)) +end + +function metat.__index:sendheaders(tosend) + local canonic = headers.canonic + local h = "\r\n" + for f, v in base.pairs(tosend) do + h = (canonic[f] or f) .. ": " .. v .. "\r\n" .. h + end + self.try(self.c:send(h)) + return 1 +end + +function metat.__index:sendbody(headers, source, step) + source = source or ltn12.source.empty() + step = step or ltn12.pump.step + -- if we don't know the size in advance, send chunked and hope for the best + local mode = "http-chunked" + if headers["content-length"] then mode = "keep-open" end + return self.try(ltn12.pump.all(source, socket.sink(mode, self.c), step)) +end + +function metat.__index:receivestatusline() + local status = self.try(self.c:receive(5)) + -- identify HTTP/0.9 responses, which do not contain a status line + -- this is just a heuristic, but is what the RFC recommends + if status ~= "HTTP/" then return nil, status end + -- otherwise proceed reading a status line + status = self.try(self.c:receive("*l", status)) + local code = socket.skip(2, string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) + return self.try(base.tonumber(code), status) +end + +function metat.__index:receiveheaders() + return self.try(receiveheaders(self.c)) +end + +function metat.__index:receivebody(headers, sink, step) + sink = sink or ltn12.sink.null() + step = step or ltn12.pump.step + local length = base.tonumber(headers["content-length"]) + local t = headers["transfer-encoding"] -- shortcut + local mode = "default" -- connection close + if t and t ~= "identity" then mode = "http-chunked" + elseif base.tonumber(headers["content-length"]) then mode = "by-length" end + return self.try(ltn12.pump.all(socket.source(mode, self.c, length), + sink, step)) +end + +function metat.__index:receive09body(status, sink, step) + local source = ltn12.source.rewind(socket.source("until-closed", self.c)) + source(status) + return self.try(ltn12.pump.all(source, sink, step)) +end + +function metat.__index:close() + return self.c:close() +end + +----------------------------------------------------------------------------- +-- High level HTTP API +----------------------------------------------------------------------------- +local function adjusturi(reqt) + local u = reqt + -- if there is a proxy, we need the full url. otherwise, just a part. + if not reqt.proxy and not _M.PROXY then + u = { + path = socket.try(reqt.path, "invalid path 'nil'"), + params = reqt.params, + query = reqt.query, + fragment = reqt.fragment + } + end + return url.build(u) +end + +local function adjustproxy(reqt) + local proxy = reqt.proxy or _M.PROXY + if proxy then + proxy = url.parse(proxy) + return proxy.host, proxy.port or 3128 + else + return reqt.host, reqt.port + end +end + +local function adjustheaders(reqt) + -- default headers + local host = reqt.host + if reqt.port then host = host .. ":" .. reqt.port end + local lower = { + ["user-agent"] = _M.USERAGENT, + ["host"] = host, + ["connection"] = "close, TE", + ["te"] = "trailers" + } + -- if we have authentication information, pass it along + if reqt.user and reqt.password then + lower["authorization"] = + "Basic " .. (mime.b64(reqt.user .. ":" .. reqt.password)) + end + -- override with user headers + for i,v in base.pairs(reqt.headers or lower) do + lower[string.lower(i)] = v + end + return lower +end + +-- default url parts +local default = { + host = "", + port = _M.PORT, + path ="/", + scheme = "http" +} + +local function adjustrequest(reqt) + -- parse url if provided + local nreqt = reqt.url and url.parse(reqt.url, default) or {} + -- explicit components override url + for i,v in base.pairs(reqt) do nreqt[i] = v end + if nreqt.port == "" then nreqt.port = 80 end + socket.try(nreqt.host and nreqt.host ~= "", + "invalid host '" .. base.tostring(nreqt.host) .. "'") + -- compute uri if user hasn't overriden + nreqt.uri = reqt.uri or adjusturi(nreqt) + -- ajust host and port if there is a proxy + nreqt.host, nreqt.port = adjustproxy(nreqt) + -- adjust headers in request + nreqt.headers = adjustheaders(nreqt) + return nreqt +end + +local function shouldredirect(reqt, code, headers) + return headers.location and + string.gsub(headers.location, "%s", "") ~= "" and + (reqt.redirect ~= false) and + (code == 301 or code == 302 or code == 303 or code == 307) and + (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") + and (not reqt.nredirects or reqt.nredirects < 5) +end + +local function shouldreceivebody(reqt, code) + if reqt.method == "HEAD" then return nil end + if code == 204 or code == 304 then return nil end + if code >= 100 and code < 200 then return nil end + return 1 +end + +-- forward declarations +local trequest, tredirect + +--[[local]] function tredirect(reqt, location) + local result, code, headers, status = trequest { + -- the RFC says the redirect URL has to be absolute, but some + -- servers do not respect that + url = url.absolute(reqt.url, location), + source = reqt.source, + sink = reqt.sink, + headers = reqt.headers, + proxy = reqt.proxy, + nredirects = (reqt.nredirects or 0) + 1, + create = reqt.create + } + -- pass location header back as a hint we redirected + headers = headers or {} + headers.location = headers.location or location + return result, code, headers, status +end + +--[[local]] function trequest(reqt) + -- we loop until we get what we want, or + -- until we are sure there is no way to get it + local nreqt = adjustrequest(reqt) + local h = _M.open(nreqt.host, nreqt.port, nreqt.create) + -- send request line and headers + h:sendrequestline(nreqt.method, nreqt.uri) + h:sendheaders(nreqt.headers) + -- if there is a body, send it + if nreqt.source then + h:sendbody(nreqt.headers, nreqt.source, nreqt.step) + end + local code, status = h:receivestatusline() + -- if it is an HTTP/0.9 server, simply get the body and we are done + if not code then + h:receive09body(status, nreqt.sink, nreqt.step) + return 1, 200 + end + local headers + -- ignore any 100-continue messages + while code == 100 do + headers = h:receiveheaders() + code, status = h:receivestatusline() + end + headers = h:receiveheaders() + -- at this point we should have a honest reply from the server + -- we can't redirect if we already used the source, so we report the error + if shouldredirect(nreqt, code, headers) and not nreqt.source then + h:close() + return tredirect(reqt, headers.location) + end + -- here we are finally done + if shouldreceivebody(nreqt, code) then + h:receivebody(headers, nreqt.sink, nreqt.step) + end + h:close() + return 1, code, headers, status +end + +local function srequest(u, b) + local t = {} + local reqt = { + url = u, + sink = ltn12.sink.table(t) + } + if b then + reqt.source = ltn12.source.string(b) + reqt.headers = { + ["content-length"] = string.len(b), + ["content-type"] = "application/x-www-form-urlencoded" + } + reqt.method = "POST" + end + local code, headers, status = socket.skip(1, trequest(reqt)) + return table.concat(t), code, headers, status +end + +_M.request = socket.protect(function(reqt, body) + if base.type(reqt) == "string" then return srequest(reqt, body) + else return trequest(reqt) end +end) + +return _M \ No newline at end of file diff --git a/socket/mbox.lua b/socket/mbox.lua new file mode 100644 index 0000000..7724ae2 --- /dev/null +++ b/socket/mbox.lua @@ -0,0 +1,92 @@ +local _M = {} + +if module then + mbox = _M +end + +function _M.split_message(message_s) + local message = {} + message_s = string.gsub(message_s, "\r\n", "\n") + string.gsub(message_s, "^(.-\n)\n", function (h) message.headers = h end) + string.gsub(message_s, "^.-\n\n(.*)", function (b) message.body = b end) + if not message.body then + string.gsub(message_s, "^\n(.*)", function (b) message.body = b end) + end + if not message.headers and not message.body then + message.headers = message_s + end + return message.headers or "", message.body or "" +end + +function _M.split_headers(headers_s) + local headers = {} + headers_s = string.gsub(headers_s, "\r\n", "\n") + headers_s = string.gsub(headers_s, "\n[ ]+", " ") + string.gsub("\n" .. headers_s, "\n([^\n]+)", function (h) table.insert(headers, h) end) + return headers +end + +function _M.parse_header(header_s) + header_s = string.gsub(header_s, "\n[ ]+", " ") + header_s = string.gsub(header_s, "\n+", "") + local _, __, name, value = string.find(header_s, "([^%s:]-):%s*(.*)") + return name, value +end + +function _M.parse_headers(headers_s) + local headers_t = _M.split_headers(headers_s) + local headers = {} + for i = 1, #headers_t do + local name, value = _M.parse_header(headers_t[i]) + if name then + name = string.lower(name) + if headers[name] then + headers[name] = headers[name] .. ", " .. value + else headers[name] = value end + end + end + return headers +end + +function _M.parse_from(from) + local _, __, name, address = string.find(from, "^%s*(.-)%s*%<(.-)%>") + if not address then + _, __, address = string.find(from, "%s*(.+)%s*") + end + name = name or "" + address = address or "" + if name == "" then name = address end + name = string.gsub(name, '"', "") + return name, address +end + +function _M.split_mbox(mbox_s) + mbox = {} + mbox_s = string.gsub(mbox_s, "\r\n", "\n") .."\n\nFrom \n" + local nj, i, j = 1, 1, 1 + while 1 do + i, nj = string.find(mbox_s, "\n\nFrom .-\n", j) + if not i then break end + local message = string.sub(mbox_s, j, i-1) + table.insert(mbox, message) + j = nj+1 + end + return mbox +end + +function _M.parse(mbox_s) + local mbox = _M.split_mbox(mbox_s) + for i = 1, #mbox do + mbox[i] = _M.parse_message(mbox[i]) + end + return mbox +end + +function _M.parse_message(message_s) + local message = {} + message.headers, message.body = _M.split_message(message_s) + message.headers = _M.parse_headers(message.headers) + return message +end + +return _M diff --git a/socket/smtp.lua b/socket/smtp.lua new file mode 100644 index 0000000..b113d00 --- /dev/null +++ b/socket/smtp.lua @@ -0,0 +1,256 @@ +----------------------------------------------------------------------------- +-- SMTP client support for the Lua language. +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +----------------------------------------------------------------------------- +local base = _G +local coroutine = require("coroutine") +local string = require("string") +local math = require("math") +local os = require("os") +local socket = require("socket") +local tp = require("socket.tp") +local ltn12 = require("ltn12") +local headers = require("socket.headers") +local mime = require("mime") + +socket.smtp = {} +local _M = socket.smtp + +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +-- timeout for connection +_M.TIMEOUT = 60 +-- default server used to send e-mails +_M.SERVER = "localhost" +-- default port +_M.PORT = 25 +-- domain used in HELO command and default sendmail +-- If we are under a CGI, try to get from environment +_M.DOMAIN = os.getenv("SERVER_NAME") or "localhost" +-- default time zone (means we don't know) +_M.ZONE = "-0000" + +--------------------------------------------------------------------------- +-- Low level SMTP API +----------------------------------------------------------------------------- +local metat = { __index = {} } + +function metat.__index:greet(domain) + self.try(self.tp:check("2..")) + self.try(self.tp:command("EHLO", domain or _M.DOMAIN)) + return socket.skip(1, self.try(self.tp:check("2.."))) +end + +function metat.__index:mail(from) + self.try(self.tp:command("MAIL", "FROM:" .. from)) + return self.try(self.tp:check("2..")) +end + +function metat.__index:rcpt(to) + self.try(self.tp:command("RCPT", "TO:" .. to)) + return self.try(self.tp:check("2..")) +end + +function metat.__index:data(src, step) + self.try(self.tp:command("DATA")) + self.try(self.tp:check("3..")) + self.try(self.tp:source(src, step)) + self.try(self.tp:send("\r\n.\r\n")) + return self.try(self.tp:check("2..")) +end + +function metat.__index:quit() + self.try(self.tp:command("QUIT")) + return self.try(self.tp:check("2..")) +end + +function metat.__index:close() + return self.tp:close() +end + +function metat.__index:login(user, password) + self.try(self.tp:command("AUTH", "LOGIN")) + self.try(self.tp:check("3..")) + self.try(self.tp:send(mime.b64(user) .. "\r\n")) + self.try(self.tp:check("3..")) + self.try(self.tp:send(mime.b64(password) .. "\r\n")) + return self.try(self.tp:check("2..")) +end + +function metat.__index:plain(user, password) + local auth = "PLAIN " .. mime.b64("\0" .. user .. "\0" .. password) + self.try(self.tp:command("AUTH", auth)) + return self.try(self.tp:check("2..")) +end + +function metat.__index:auth(user, password, ext) + if not user or not password then return 1 end + if string.find(ext, "AUTH[^\n]+LOGIN") then + return self:login(user, password) + elseif string.find(ext, "AUTH[^\n]+PLAIN") then + return self:plain(user, password) + else + self.try(nil, "authentication not supported") + end +end + +-- send message or throw an exception +function metat.__index:send(mailt) + self:mail(mailt.from) + if base.type(mailt.rcpt) == "table" then + for i,v in base.ipairs(mailt.rcpt) do + self:rcpt(v) + end + else + self:rcpt(mailt.rcpt) + end + self:data(ltn12.source.chain(mailt.source, mime.stuff()), mailt.step) +end + +function _M.open(server, port, create) + local tp = socket.try(tp.connect(server or _M.SERVER, port or _M.PORT, + _M.TIMEOUT, create)) + local s = base.setmetatable({tp = tp}, metat) + -- make sure tp is closed if we get an exception + s.try = socket.newtry(function() + s:close() + end) + return s +end + +-- convert headers to lowercase +local function lower_headers(headers) + local lower = {} + for i,v in base.pairs(headers or lower) do + lower[string.lower(i)] = v + end + return lower +end + +--------------------------------------------------------------------------- +-- Multipart message source +----------------------------------------------------------------------------- +-- returns a hopefully unique mime boundary +local seqno = 0 +local function newboundary() + seqno = seqno + 1 + return string.format('%s%05d==%05u', os.date('%d%m%Y%H%M%S'), + math.random(0, 99999), seqno) +end + +-- send_message forward declaration +local send_message + +-- yield the headers all at once, it's faster +local function send_headers(tosend) + local canonic = headers.canonic + local h = "\r\n" + for f,v in base.pairs(tosend) do + h = (canonic[f] or f) .. ': ' .. v .. "\r\n" .. h + end + coroutine.yield(h) +end + +-- yield multipart message body from a multipart message table +local function send_multipart(mesgt) + -- make sure we have our boundary and send headers + local bd = newboundary() + local headers = lower_headers(mesgt.headers or {}) + headers['content-type'] = headers['content-type'] or 'multipart/mixed' + headers['content-type'] = headers['content-type'] .. + '; boundary="' .. bd .. '"' + send_headers(headers) + -- send preamble + if mesgt.body.preamble then + coroutine.yield(mesgt.body.preamble) + coroutine.yield("\r\n") + end + -- send each part separated by a boundary + for i, m in base.ipairs(mesgt.body) do + coroutine.yield("\r\n--" .. bd .. "\r\n") + send_message(m) + end + -- send last boundary + coroutine.yield("\r\n--" .. bd .. "--\r\n\r\n") + -- send epilogue + if mesgt.body.epilogue then + coroutine.yield(mesgt.body.epilogue) + coroutine.yield("\r\n") + end +end + +-- yield message body from a source +local function send_source(mesgt) + -- make sure we have a content-type + local headers = lower_headers(mesgt.headers or {}) + headers['content-type'] = headers['content-type'] or + 'text/plain; charset="iso-8859-1"' + send_headers(headers) + -- send body from source + while true do + local chunk, err = mesgt.body() + if err then coroutine.yield(nil, err) + elseif chunk then coroutine.yield(chunk) + else break end + end +end + +-- yield message body from a string +local function send_string(mesgt) + -- make sure we have a content-type + local headers = lower_headers(mesgt.headers or {}) + headers['content-type'] = headers['content-type'] or + 'text/plain; charset="iso-8859-1"' + send_headers(headers) + -- send body from string + coroutine.yield(mesgt.body) +end + +-- message source +function send_message(mesgt) + if base.type(mesgt.body) == "table" then send_multipart(mesgt) + elseif base.type(mesgt.body) == "function" then send_source(mesgt) + else send_string(mesgt) end +end + +-- set defaul headers +local function adjust_headers(mesgt) + local lower = lower_headers(mesgt.headers) + lower["date"] = lower["date"] or + os.date("!%a, %d %b %Y %H:%M:%S ") .. (mesgt.zone or _M.ZONE) + lower["x-mailer"] = lower["x-mailer"] or socket._VERSION + -- this can't be overriden + lower["mime-version"] = "1.0" + return lower +end + +function _M.message(mesgt) + mesgt.headers = adjust_headers(mesgt) + -- create and return message source + local co = coroutine.create(function() send_message(mesgt) end) + return function() + local ret, a, b = coroutine.resume(co) + if ret then return a, b + else return nil, a end + end +end + +--------------------------------------------------------------------------- +-- High level SMTP API +----------------------------------------------------------------------------- +_M.send = socket.protect(function(mailt) + local s = _M.open(mailt.server, mailt.port, mailt.create) + local ext = s:greet(mailt.domain) + s:auth(mailt.user, mailt.password, ext) + s:send(mailt) + s:quit() + return s:close() +end) + +return _M \ No newline at end of file diff --git a/socket/tftp.lua b/socket/tftp.lua new file mode 100644 index 0000000..ed99cd1 --- /dev/null +++ b/socket/tftp.lua @@ -0,0 +1,154 @@ +----------------------------------------------------------------------------- +-- TFTP support for the Lua language +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Load required files +----------------------------------------------------------------------------- +local base = _G +local table = require("table") +local math = require("math") +local string = require("string") +local socket = require("socket") +local ltn12 = require("ltn12") +local url = require("socket.url") +module("socket.tftp") + +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +local char = string.char +local byte = string.byte + +PORT = 69 +local OP_RRQ = 1 +local OP_WRQ = 2 +local OP_DATA = 3 +local OP_ACK = 4 +local OP_ERROR = 5 +local OP_INV = {"RRQ", "WRQ", "DATA", "ACK", "ERROR"} + +----------------------------------------------------------------------------- +-- Packet creation functions +----------------------------------------------------------------------------- +local function RRQ(source, mode) + return char(0, OP_RRQ) .. source .. char(0) .. mode .. char(0) +end + +local function WRQ(source, mode) + return char(0, OP_RRQ) .. source .. char(0) .. mode .. char(0) +end + +local function ACK(block) + local low, high + low = math.mod(block, 256) + high = (block - low)/256 + return char(0, OP_ACK, high, low) +end + +local function get_OP(dgram) + local op = byte(dgram, 1)*256 + byte(dgram, 2) + return op +end + +----------------------------------------------------------------------------- +-- Packet analysis functions +----------------------------------------------------------------------------- +local function split_DATA(dgram) + local block = byte(dgram, 3)*256 + byte(dgram, 4) + local data = string.sub(dgram, 5) + return block, data +end + +local function get_ERROR(dgram) + local code = byte(dgram, 3)*256 + byte(dgram, 4) + local msg + _,_, msg = string.find(dgram, "(.*)\000", 5) + return string.format("error code %d: %s", code, msg) +end + +----------------------------------------------------------------------------- +-- The real work +----------------------------------------------------------------------------- +local function tget(gett) + local retries, dgram, sent, datahost, dataport, code + local last = 0 + socket.try(gett.host, "missing host") + local con = socket.try(socket.udp()) + local try = socket.newtry(function() con:close() end) + -- convert from name to ip if needed + gett.host = try(socket.dns.toip(gett.host)) + con:settimeout(1) + -- first packet gives data host/port to be used for data transfers + local path = string.gsub(gett.path or "", "^/", "") + path = url.unescape(path) + retries = 0 + repeat + sent = try(con:sendto(RRQ(path, "octet"), gett.host, gett.port)) + dgram, datahost, dataport = con:receivefrom() + retries = retries + 1 + until dgram or datahost ~= "timeout" or retries > 5 + try(dgram, datahost) + -- associate socket with data host/port + try(con:setpeername(datahost, dataport)) + -- default sink + local sink = gett.sink or ltn12.sink.null() + -- process all data packets + while 1 do + -- decode packet + code = get_OP(dgram) + try(code ~= OP_ERROR, get_ERROR(dgram)) + try(code == OP_DATA, "unhandled opcode " .. code) + -- get data packet parts + local block, data = split_DATA(dgram) + -- if not repeated, write + if block == last+1 then + try(sink(data)) + last = block + end + -- last packet brings less than 512 bytes of data + if string.len(data) < 512 then + try(con:send(ACK(block))) + try(con:close()) + try(sink(nil)) + return 1 + end + -- get the next packet + retries = 0 + repeat + sent = try(con:send(ACK(last))) + dgram, err = con:receive() + retries = retries + 1 + until dgram or err ~= "timeout" or retries > 5 + try(dgram, err) + end +end + +local default = { + port = PORT, + path ="/", + scheme = "tftp" +} + +local function parse(u) + local t = socket.try(url.parse(u, default)) + socket.try(t.scheme == "tftp", "invalid scheme '" .. t.scheme .. "'") + socket.try(t.host, "invalid host") + return t +end + +local function sget(u) + local gett = parse(u) + local t = {} + gett.sink = ltn12.sink.table(t) + tget(gett) + return table.concat(t) +end + +get = socket.protect(function(gett) + if base.type(gett) == "string" then return sget(gett) + else return tget(gett) end +end) + diff --git a/socket/tp.lua b/socket/tp.lua new file mode 100644 index 0000000..cbeff56 --- /dev/null +++ b/socket/tp.lua @@ -0,0 +1,126 @@ +----------------------------------------------------------------------------- +-- Unified SMTP/FTP subsystem +-- LuaSocket toolkit. +-- Author: Diego Nehab +----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- +-- Declare module and import dependencies +----------------------------------------------------------------------------- +local base = _G +local string = require("string") +local socket = require("socket") +local ltn12 = require("ltn12") + +socket.tp = {} +local _M = socket.tp + +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +_M.TIMEOUT = 60 + +----------------------------------------------------------------------------- +-- Implementation +----------------------------------------------------------------------------- +-- gets server reply (works for SMTP and FTP) +local function get_reply(c) + local code, current, sep + local line, err = c:receive() + local reply = line + if err then return nil, err end + code, sep = socket.skip(2, string.find(line, "^(%d%d%d)(.?)")) + if not code then return nil, "invalid server reply" end + if sep == "-" then -- reply is multiline + repeat + line, err = c:receive() + if err then return nil, err end + current, sep = socket.skip(2, string.find(line, "^(%d%d%d)(.?)")) + reply = reply .. "\n" .. line + -- reply ends with same code + until code == current and sep == " " + end + return code, reply +end + +-- metatable for sock object +local metat = { __index = {} } + +function metat.__index:check(ok) + local code, reply = get_reply(self.c) + if not code then return nil, reply end + if base.type(ok) ~= "function" then + if base.type(ok) == "table" then + for i, v in base.ipairs(ok) do + if string.find(code, v) then + return base.tonumber(code), reply + end + end + return nil, reply + else + if string.find(code, ok) then return base.tonumber(code), reply + else return nil, reply end + end + else return ok(base.tonumber(code), reply) end +end + +function metat.__index:command(cmd, arg) + cmd = string.upper(cmd) + if arg then + return self.c:send(cmd .. " " .. arg.. "\r\n") + else + return self.c:send(cmd .. "\r\n") + end +end + +function metat.__index:sink(snk, pat) + local chunk, err = c:receive(pat) + return snk(chunk, err) +end + +function metat.__index:send(data) + return self.c:send(data) +end + +function metat.__index:receive(pat) + return self.c:receive(pat) +end + +function metat.__index:getfd() + return self.c:getfd() +end + +function metat.__index:dirty() + return self.c:dirty() +end + +function metat.__index:getcontrol() + return self.c +end + +function metat.__index:source(source, step) + local sink = socket.sink("keep-open", self.c) + local ret, err = ltn12.pump.all(source, sink, step or ltn12.pump.step) + return ret, err +end + +-- closes the underlying c +function metat.__index:close() + self.c:close() + return 1 +end + +-- connect with server and return c object +function _M.connect(host, port, timeout, create) + local c, e = (create or socket.tcp)() + if not c then return nil, e end + c:settimeout(timeout or _M.TIMEOUT) + local r, e = c:connect(host, port) + if not r then + c:close() + return nil, e + end + return base.setmetatable({c = c}, metat) +end + +return _M \ No newline at end of file diff --git a/socket/url.lua b/socket/url.lua index 5a1ea21..7809535 100644 --- a/socket/url.lua +++ b/socket/url.lua @@ -2,7 +2,6 @@ -- URI parsing, composition and relative URL resolution -- LuaSocket toolkit. -- Author: Diego Nehab --- RCS ID: $Id: url.lua,v 1.38 2006/04/03 04:45:42 diego Exp $ ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- @@ -11,12 +10,15 @@ local string = require("string") local base = _G local table = require("table") -module("socket.url") +local socket = require("socket") + +socket.url = {} +local _M = socket.url ----------------------------------------------------------------------------- -- Module version ----------------------------------------------------------------------------- -_VERSION = "URL 1.0.1" +_M._VERSION = "URL 1.0.3" ----------------------------------------------------------------------------- -- Encodes a string into its escaped hexadecimal representation @@ -25,10 +27,10 @@ _VERSION = "URL 1.0.1" -- Returns -- escaped representation of string binary ----------------------------------------------------------------------------- -function escape(s) - return string.gsub(s, "([^A-Za-z0-9_])", function(c) +function _M.escape(s) + return (string.gsub(s, "([^A-Za-z0-9_])", function(c) return string.format("%%%02x", string.byte(c)) - end) + end)) end ----------------------------------------------------------------------------- @@ -40,25 +42,25 @@ end -- escaped representation of string binary ----------------------------------------------------------------------------- local function make_set(t) - local s = {} - for i,v in base.ipairs(t) do - s[t[i]] = 1 - end - return s + local s = {} + for i,v in base.ipairs(t) do + s[t[i]] = 1 + end + return s end -- these are allowed withing a path segment, along with alphanum -- other characters must be escaped local segment_set = make_set { "-", "_", ".", "!", "~", "*", "'", "(", - ")", ":", "@", "&", "=", "+", "$", ",", + ")", ":", "@", "&", "=", "+", "$", ",", } local function protect_segment(s) - return string.gsub(s, "([^A-Za-z0-9_])", function (c) - if segment_set[c] then return c - else return string.format("%%%02x", string.byte(c)) end - end) + return string.gsub(s, "([^A-Za-z0-9_])", function (c) + if segment_set[c] then return c + else return string.format("%%%02x", string.byte(c)) end + end) end ----------------------------------------------------------------------------- @@ -68,10 +70,10 @@ end -- Returns -- escaped representation of string binary ----------------------------------------------------------------------------- -function unescape(s) - return string.gsub(s, "%%(%x%x)", function(hex) +function _M.unescape(s) + return (string.gsub(s, "%%(%x%x)", function(hex) return string.char(base.tonumber(hex, 16)) - end) + end)) end ----------------------------------------------------------------------------- @@ -121,7 +123,7 @@ end -- Obs: -- the leading '/' in {/} is considered part of ----------------------------------------------------------------------------- -function parse(url, default) +function _M.parse(url, default) -- initialize default parameters local parsed = {} for i,v in base.pairs(default or parsed) do parsed[i] = v end @@ -142,7 +144,7 @@ function parse(url, default) parsed.authority = n return "" end) - -- get query stringing + -- get query string url = string.gsub(url, "%?(.*)", function(q) parsed.query = q return "" @@ -158,9 +160,12 @@ function parse(url, default) if not authority then return parsed end authority = string.gsub(authority,"^([^@]*)@", function(u) parsed.userinfo = u; return "" end) - authority = string.gsub(authority, ":([^:]*)$", + authority = string.gsub(authority, ":([^:%]]*)$", function(p) parsed.port = p; return "" end) - if authority ~= "" then parsed.host = authority end + if authority ~= "" then + -- IPv6? + parsed.host = string.match(authority, "^%[(.+)%]$") or authority + end local userinfo = parsed.userinfo if not userinfo then return parsed end userinfo = string.gsub(userinfo, ":([^:]*)$", @@ -177,24 +182,27 @@ end -- Returns -- a stringing with the corresponding URL ----------------------------------------------------------------------------- -function build(parsed) - local ppath = parse_path(parsed.path or "") - local url = build_path(ppath) +function _M.build(parsed) + local ppath = _M.parse_path(parsed.path or "") + local url = _M.build_path(ppath) if parsed.params then url = url .. ";" .. parsed.params end if parsed.query then url = url .. "?" .. parsed.query end - local authority = parsed.authority - if parsed.host then - authority = parsed.host - if parsed.port then authority = authority .. ":" .. parsed.port end - local userinfo = parsed.userinfo - if parsed.user then - userinfo = parsed.user - if parsed.password then - userinfo = userinfo .. ":" .. parsed.password - end - end - if userinfo then authority = userinfo .. "@" .. authority end - end + local authority = parsed.authority + if parsed.host then + authority = parsed.host + if string.find(authority, ":") then -- IPv6? + authority = "[" .. authority .. "]" + end + if parsed.port then authority = authority .. ":" .. parsed.port end + local userinfo = parsed.userinfo + if parsed.user then + userinfo = parsed.user + if parsed.password then + userinfo = userinfo .. ":" .. parsed.password + end + end + if userinfo then authority = userinfo .. "@" .. authority end + end if authority then url = "//" .. authority .. url end if parsed.scheme then url = parsed.scheme .. ":" .. url end if parsed.fragment then url = url .. "#" .. parsed.fragment end @@ -210,14 +218,14 @@ end -- Returns -- corresponding absolute url ----------------------------------------------------------------------------- -function absolute(base_url, relative_url) +function _M.absolute(base_url, relative_url) if base.type(base_url) == "table" then base_parsed = base_url - base_url = build(base_parsed) + base_url = _M.build(base_parsed) else - base_parsed = parse(base_url) + base_parsed = _M.parse(base_url) end - local relative_parsed = parse(relative_url) + local relative_parsed = _M.parse(relative_url) if not base_parsed then return relative_url elseif not relative_parsed then return base_url elseif relative_parsed.scheme then return relative_url @@ -233,12 +241,12 @@ function absolute(base_url, relative_url) relative_parsed.query = base_parsed.query end end - else + else relative_parsed.path = absolute_path(base_parsed.path or "", relative_parsed.path) end end - return build(relative_parsed) + return _M.build(relative_parsed) end end @@ -249,17 +257,17 @@ end -- Returns -- segment: a table with one entry per segment ----------------------------------------------------------------------------- -function parse_path(path) - local parsed = {} - path = path or "" - --path = string.gsub(path, "%s", "") - string.gsub(path, "([^/]+)", function (s) table.insert(parsed, s) end) - for i = 1, table.getn(parsed) do - parsed[i] = unescape(parsed[i]) - end - if string.sub(path, 1, 1) == "/" then parsed.is_absolute = 1 end - if string.sub(path, -1, -1) == "/" then parsed.is_directory = 1 end - return parsed +function _M.parse_path(path) + local parsed = {} + path = path or "" + --path = string.gsub(path, "%s", "") + string.gsub(path, "([^/]+)", function (s) table.insert(parsed, s) end) + for i = 1, #parsed do + parsed[i] = _M.unescape(parsed[i]) + end + if string.sub(path, 1, 1) == "/" then parsed.is_absolute = 1 end + if string.sub(path, -1, -1) == "/" then parsed.is_directory = 1 end + return parsed end ----------------------------------------------------------------------------- @@ -270,28 +278,30 @@ end -- Returns -- path: corresponding path stringing ----------------------------------------------------------------------------- -function build_path(parsed, unsafe) - local path = "" - local n = table.getn(parsed) - if unsafe then - for i = 1, n-1 do - path = path .. parsed[i] - path = path .. "/" - end - if n > 0 then - path = path .. parsed[n] - if parsed.is_directory then path = path .. "/" end - end - else - for i = 1, n-1 do - path = path .. protect_segment(parsed[i]) - path = path .. "/" - end - if n > 0 then - path = path .. protect_segment(parsed[n]) - if parsed.is_directory then path = path .. "/" end - end - end - if parsed.is_absolute then path = "/" .. path end - return path +function _M.build_path(parsed, unsafe) + local path = "" + local n = #parsed + if unsafe then + for i = 1, n-1 do + path = path .. parsed[i] + path = path .. "/" + end + if n > 0 then + path = path .. parsed[n] + if parsed.is_directory then path = path .. "/" end + end + else + for i = 1, n-1 do + path = path .. protect_segment(parsed[i]) + path = path .. "/" + end + if n > 0 then + path = path .. protect_segment(parsed[n]) + if parsed.is_directory then path = path .. "/" end + end + end + if parsed.is_absolute then path = "/" .. path end + return path end + +return _M