diff --git a/CMakeLists.txt b/CMakeLists.txt
index e149d1c..f7e8414 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -84,6 +84,13 @@ target_sources(${FTPD_TARGET} PRIVATE
source/socket.cpp
)
+if(NOT NINTENDO_DS)
+ target_sources(${FTPD_TARGET} PRIVATE
+ source/mdns.cpp
+ include/mdns.h
+ )
+endif()
+
if(NOT FTPD_CLASSIC AND NOT NINTENDO_DS)
target_sources(${FTPD_TARGET} PRIVATE
include/imconfig.h
diff --git a/include/ftpConfig.h b/include/ftpConfig.h
index 7937c5d..57fef91 100644
--- a/include/ftpConfig.h
+++ b/include/ftpConfig.h
@@ -60,6 +60,9 @@ public:
/// \brief Get password
std::string const &pass () const;
+ /// \brief Get hostname
+ std::string const &hostname () const;
+
/// \brief Get port
std::uint16_t port () const;
@@ -88,6 +91,10 @@ public:
/// \param pass_ Password
void setPass (std::string pass_);
+ /// \brief Set hostname
+ /// \param hostname_ Hostname
+ void setHostname (std::string hostname_);
+
/// \brief Set listen port
/// \param port_ Listen port
bool setPort (std::string_view port_);
@@ -130,6 +137,9 @@ private:
/// \brief Password
std::string m_pass;
+ /// \brief Hostname
+ std::string m_hostname;
+
/// \brief Listen port
std::uint16_t m_port;
diff --git a/include/ftpServer.h b/include/ftpServer.h
index 729b090..a866d7e 100644
--- a/include/ftpServer.h
+++ b/include/ftpServer.h
@@ -110,6 +110,11 @@ private:
/// \brief Listen socket
UniqueSocket m_socket;
+#ifndef __NDS__
+ /// \brief mDNS socket
+ UniqueSocket m_mdnsSocket;
+#endif
+
/// \brief ImGui window name
std::string m_name;
@@ -151,6 +156,9 @@ private:
/// \brief Password setting
std::string m_passSetting;
+ /// \brief Hostname setting
+ std::string m_hostnameSetting;
+
/// \brief Port setting
std::uint16_t m_portSetting = 0;
diff --git a/include/mdns.h b/include/mdns.h
new file mode 100644
index 0000000..4fd8ce5
--- /dev/null
+++ b/include/mdns.h
@@ -0,0 +1,40 @@
+
+// ftpd is a server implementation based on the following:
+// - RFC 959 (https://datatracker.ietf.org/doc/html/rfc959)
+// - RFC 3659 (https://datatracker.ietf.org/doc/html/rfc3659)
+// - suggested implementation details from https://cr.yp.to/ftp/filesystem.html
+//
+// ftpd implements mdns based on the following:
+// - RFC 1035 (https://datatracker.ietf.org/doc/html/rfc1035)
+// - RFC 6762 (https://datatracker.ietf.org/doc/html/rfc6762)
+//
+// Copyright (C) 2024 Michael Theall
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+#pragma once
+
+#include "sockAddr.h"
+#include "socket.h"
+
+#include
+
+namespace mdns
+{
+void setHostname (std::string hostname_);
+
+UniqueSocket createSocket ();
+
+void handleSocket (Socket *socket_, SockAddr const &addr_);
+}
diff --git a/include/platform.h b/include/platform.h
index e108433..eca7933 100644
--- a/include/platform.h
+++ b/include/platform.h
@@ -3,7 +3,7 @@
// - RFC 3659 (https://tools.ietf.org/html/rfc3659)
// - suggested implementation details from https://cr.yp.to/ftp/filesystem.html
//
-// Copyright (C) 2023 Michael Theall
+// Copyright (C) 2024 Michael Theall
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
@@ -34,6 +34,7 @@
#include
#include
#include
+#include
#ifdef CLASSIC
extern PrintConsole g_statusConsole;
@@ -71,6 +72,9 @@ bool networkVisible ();
/// \param[out] addr_ Network address
bool networkAddress (SockAddr &addr_);
+/// \brief Get hostname
+std::string const &hostname ();
+
/// \brief Platform loop
bool loop ();
diff --git a/include/sockAddr.h b/include/sockAddr.h
index ad15cbd..aa05517 100644
--- a/include/sockAddr.h
+++ b/include/sockAddr.h
@@ -23,6 +23,7 @@
#include
#include
+#include
#include
#ifdef __NDS__
@@ -37,10 +38,48 @@ struct sockaddr_storage
class SockAddr
{
public:
+ enum class Domain
+ {
+ IPv4 = AF_INET,
+#ifndef NO_IPV6
+ IPv6 = AF_INET6,
+#endif
+ };
+
+ /// \brief 0.0.0.0
+ static SockAddr const AnyIPv4;
+
+#ifndef NO_IPV6
+ /// \brief ::
+ static SockAddr const AnyIPv6;
+#endif
+
~SockAddr ();
SockAddr ();
+ /// \brief Parameterized constructor
+ /// \param domain_ Socket domain
+ /// \note Initial address is INADDR_ANY/in6addr_any
+ SockAddr (Domain domain_);
+
+ /// \brief Parameterized constructor
+ /// \param addr_ Socket address (network byte order)
+ /// \param port_ Socket port (host byte order)
+ SockAddr (in_addr_t addr_, std::uint16_t port_ = 0);
+
+ /// \brief Parameterized constructor
+ /// \param addr_ Socket address (network byte order)
+ /// \param port_ Socket port (host byte order)
+ SockAddr (in_addr const &addr_, std::uint16_t port_ = 0);
+
+#ifndef NO_IPV6
+ /// \brief Parameterized constructor
+ /// \param addr_ Socket address
+ /// \param port_ Socket port (host byte order)
+ SockAddr (in6_addr const &addr_, std::uint16_t port_ = 0);
+#endif
+
/// \brief Copy constructor
/// \param that_ Object to copy
SockAddr (SockAddr const &that_);
@@ -88,10 +127,32 @@ public:
/// \brief sockaddr const* cast operator (network byte order)
operator sockaddr const * () const;
+ /// \brief Equality operator
+ bool operator== (SockAddr const &that_) const;
+
+ /// \brief Comparison operator
+ std::strong_ordering operator<=> (SockAddr const &that_) const;
+
+ /// \brief sockaddr domain
+ Domain domain () const;
/// \brief sockaddr size
socklen_t size () const;
+ /// \brief Set address
+ /// \param addr_ Address to set (network byte order)
+ void setAddr (in_addr_t addr_);
+
+ /// \brief Set address
+ /// \param addr_ Address to set (network byte order)
+ void setAddr (in_addr const &addr_);
+
+#ifndef NO_IPV6
+ /// \brief Set address
+ /// \param addr_ Address to set (network byte order)
+ void setAddr (in6_addr const &addr_);
+#endif
+
/// \brief Address port (host byte order)
std::uint16_t port () const;
diff --git a/include/socket.h b/include/socket.h
index 6d91680..1e9973f 100644
--- a/include/socket.h
+++ b/include/socket.h
@@ -56,6 +56,12 @@ using SharedSocket = std::shared_ptr;
class Socket
{
public:
+ enum Type
+ {
+ eStream = SOCK_STREAM, ///< Stream socket
+ eDatagram = SOCK_DGRAM, ///< Datagram socket
+ };
+
/// \brief Poll info
struct PollInfo
{
@@ -114,6 +120,18 @@ public:
/// \param size_ Buffer size
bool setSendBufferSize (std::size_t size_);
+#ifndef __NDS__
+ /// \brief Join multicast group
+ /// \param addr_ Multicast group address
+ /// \param iface_ Interface address
+ bool joinMulticastGroup (SockAddr const &addr_, SockAddr const &iface_);
+
+ /// \brief Drop multicast group
+ /// \param addr_ Multicast group address
+ /// \param iface_ Interface address
+ bool dropMulticastGroup (SockAddr const &addr_, SockAddr const &iface_);
+#endif
+
/// \brief Read data
/// \param buffer_ Output buffer
/// \param size_ Size to read
@@ -125,6 +143,12 @@ public:
/// \param oob_ Whether to read from out-of-band
std::make_signed_t read (IOBuffer &buffer_, bool oob_ = false);
+ /// \brief Read data
+ /// \param buffer_ Output buffer
+ /// \param size_ Size to read
+ /// \param[out] addr_ Source address
+ std::make_signed_t readFrom (void *buffer_, std::size_t size_, SockAddr &addr_);
+
/// \brief Write data
/// \param buffer_ Input buffer
/// \param size_ Size to write
@@ -135,13 +159,21 @@ public:
/// \param size_ Size to write
std::make_signed_t write (IOBuffer &buffer_);
+ /// \brief Write data
+ /// \param buffer_ Input buffer
+ /// \param size_ Size to write
+ /// \param[out] addr_ Destination address
+ std::make_signed_t
+ writeTo (void const *buffer_, std::size_t size_, SockAddr const &addr_);
+
/// \brief Local name
SockAddr const &sockName () const;
/// \brief Peer name
SockAddr const &peerName () const;
/// \brief Create socket
- static UniqueSocket create ();
+ /// \param type_ Socket type
+ static UniqueSocket create (Type type_);
/// \brief Poll sockets
/// \param info_ Poll info
diff --git a/source/3ds/platform.cpp b/source/3ds/platform.cpp
index c1012af..2a18770 100644
--- a/source/3ds/platform.cpp
+++ b/source/3ds/platform.cpp
@@ -577,6 +577,12 @@ bool platform::networkAddress (SockAddr &addr_)
return true;
}
+std::string const &platform::hostname ()
+{
+ static std::string const hostname = "3ds-ftpd";
+ return hostname;
+}
+
bool platform::loop ()
{
if (!aptMainLoop ())
diff --git a/source/ftpConfig.cpp b/source/ftpConfig.cpp
index eca6b89..40ca6de 100644
--- a/source/ftpConfig.cpp
+++ b/source/ftpConfig.cpp
@@ -232,6 +232,11 @@ std::string const &FtpConfig::pass () const
return m_pass;
}
+std::string const &FtpConfig::hostname () const
+{
+ return m_hostname;
+}
+
std::uint16_t FtpConfig::port () const
{
return m_port;
@@ -271,6 +276,11 @@ void FtpConfig::setPass (std::string pass_)
m_pass = std::move (pass_);
}
+void FtpConfig::setHostname (std::string hostname_)
+{
+ m_hostname = std::move (hostname_);
+}
+
bool FtpConfig::setPort (std::string_view const port_)
{
std::uint16_t parsed{};
diff --git a/source/ftpServer.cpp b/source/ftpServer.cpp
index 1c96c7d..f753783 100644
--- a/source/ftpServer.cpp
+++ b/source/ftpServer.cpp
@@ -29,6 +29,10 @@
#include "sockAddr.h"
#include "socket.h"
+#ifndef __NDS__
+#include "mdns.h"
+#endif
+
#include "imgui.h"
#ifdef __NDS__
@@ -223,8 +227,14 @@ FtpServer::~FtpServer ()
FtpServer::FtpServer (UniqueFtpConfig config_)
: m_config (std::move (config_))
+#ifndef CLASSIC
+ ,
+ m_hostnameSetting (m_config->hostname ())
+#endif
{
#ifndef __NDS__
+ mdns::setHostname (m_config->hostname ());
+
m_thread = platform::Thread (std::bind (&FtpServer::threadFunc, this));
#endif
@@ -439,7 +449,7 @@ void FtpServer::handleNetworkFound ()
addr.setPort (port);
- auto socket = Socket::create ();
+ auto socket = Socket::create (Socket::eStream);
if (!socket)
return;
@@ -461,6 +471,14 @@ void FtpServer::handleNetworkFound ()
info ("Started server at %s\n", m_name.c_str ());
LOCKED (m_socket = std::move (socket));
+
+#ifndef __NDS__
+ socket = mdns::createSocket ();
+ if (!socket)
+ return;
+
+ LOCKED (m_mdnsSocket = std::move (socket));
+#endif
}
void FtpServer::handleNetworkLost ()
@@ -476,6 +494,11 @@ void FtpServer::handleNetworkLost ()
// destroy command socket
LOCKED (sock = std::move (m_socket));
+
+#ifndef __NDS__
+ // destroy mDNS socket
+ LOCKED (sock = std::move (m_mdnsSocket));
+#endif
}
info ("Stopped server at %s\n", m_name.c_str ());
@@ -574,6 +597,9 @@ void FtpServer::showMenu ()
m_passSetting = m_config->pass ();
m_passSetting.resize (32);
+ m_hostnameSetting = m_config->hostname ();
+ m_hostnameSetting.resize (32);
+
m_portSetting = m_config->port ();
#ifdef __3DS__
@@ -631,6 +657,11 @@ void FtpServer::showSettings ()
m_passSetting.size (),
ImGuiInputTextFlags_AutoSelectAll | ImGuiInputTextFlags_Password);
+ ImGui::InputText ("Hostname",
+ m_hostnameSetting.data (),
+ m_hostnameSetting.size (),
+ ImGuiInputTextFlags_AutoSelectAll);
+
ImGui::InputScalar ("Port",
ImGuiDataType_U16,
&m_portSetting,
@@ -703,6 +734,7 @@ void FtpServer::showSettings ()
m_config->setUser (m_userSetting);
m_config->setPass (m_passSetting);
+ m_config->setHostname (m_hostnameSetting);
m_config->setPort (m_portSetting);
#ifdef __3DS__
@@ -718,6 +750,8 @@ void FtpServer::showSettings ()
UniqueSocket socket;
LOCKED (socket = std::move (m_socket));
+
+ mdns::setHostname (m_hostnameSetting);
}
if (save)
@@ -733,9 +767,10 @@ void FtpServer::showSettings ()
{
static auto const defaults = FtpConfig::create ();
- m_userSetting = defaults->user ();
- m_passSetting = defaults->pass ();
- m_portSetting = defaults->port ();
+ m_userSetting = defaults->user ();
+ m_passSetting = defaults->pass ();
+ m_hostnameSetting = defaults->hostname ();
+ m_portSetting = defaults->port ();
#ifdef __3DS__
m_getMTimeSetting = defaults->getMTime ();
#endif
@@ -966,6 +1001,12 @@ void FtpServer::loop ()
}
}
+#ifndef __NDS__
+ // poll mDNS socket
+ if (m_socket && m_mdnsSocket)
+ mdns::handleSocket (m_mdnsSocket.get (), m_socket->sockName ());
+#endif
+
{
std::vector deadSessions;
{
diff --git a/source/ftpSession.cpp b/source/ftpSession.cpp
index b8bfab3..5f769a1 100644
--- a/source/ftpSession.cpp
+++ b/source/ftpSession.cpp
@@ -22,6 +22,7 @@
#include "ftpServer.h"
#include "log.h"
+#include "mdns.h"
#include "platform.h"
#include "imgui.h"
@@ -837,7 +838,7 @@ bool FtpSession::dataConnect ()
m_port = false;
- auto data = Socket::create ();
+ auto data = Socket::create (Socket::eStream);
LOCKED (m_dataSocket = std::move (data));
if (!m_dataSocket)
return false;
@@ -2384,7 +2385,7 @@ void FtpSession::PASV (char const *args_)
m_port = false;
// create a socket to listen on
- auto pasv = Socket::create ();
+ auto pasv = Socket::create (Socket::eStream);
LOCKED (m_pasvSocket = std::move (pasv));
if (!m_pasvSocket)
{
@@ -2726,6 +2727,9 @@ void FtpSession::SITE (char const *args_)
" Set username: SITE USER \r\n"
" Set password: SITE PASS \r\n"
" Set port: SITE PORT \r\n"
+#ifndef __NDS__
+ " Set hostname: SITE HOST \r\n"
+#endif
#ifdef __3DS__
" Set getMTime: SITE MTIME [0|1]\r\n"
#endif
@@ -2784,6 +2788,16 @@ void FtpSession::SITE (char const *args_)
sendResponse ("200 OK\r\n");
return;
}
+#ifndef __NDS__
+ else if (compare (command, "HOST") == 0)
+ {
+ {
+ auto const lock = m_config.lockGuard ();
+ m_config.setHostname (std::string (arg));
+ mdns::setHostname (std::string (arg));
+ }
+ }
+#endif
#ifdef __3DS__
else if (compare (command, "MTIME") == 0)
{
diff --git a/source/linux/platform.cpp b/source/linux/platform.cpp
index b2d72ce..fc62ed3 100644
--- a/source/linux/platform.cpp
+++ b/source/linux/platform.cpp
@@ -204,6 +204,24 @@ bool platform::networkAddress (SockAddr &addr_)
return true;
}
+std::string const &platform::hostname ()
+{
+ static std::string hostname = "switch-ftpd.local";
+ if (hostname.empty ())
+ {
+ std::string buffer (256, '\0');
+ gethostname (buffer.data (), buffer.size ());
+
+ if (buffer.back () == 0) // check for truncation
+ {
+ hostname = std::move (buffer);
+ hostname.resize (std::strlen (hostname.data ()));
+ }
+ }
+
+ return hostname;
+}
+
bool platform::loop ()
{
bool inactive;
diff --git a/source/mdns.cpp b/source/mdns.cpp
new file mode 100644
index 0000000..10cd111
--- /dev/null
+++ b/source/mdns.cpp
@@ -0,0 +1,599 @@
+// ftpd is a server implementation based on the following:
+// - RFC 959 (https://datatracker.ietf.org/doc/html/rfc959)
+// - RFC 3659 (https://datatracker.ietf.org/doc/html/rfc3659)
+// - suggested implementation details from https://cr.yp.to/ftp/filesystem.html
+//
+// ftpd implements mdns based on the following:
+// - RFC 1035 (https://datatracker.ietf.org/doc/html/rfc1035)
+// - RFC 6762 (https://datatracker.ietf.org/doc/html/rfc6762)
+//
+// Copyright (C) 2024 Michael Theall
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see .
+
+#include "mdns.h"
+
+#include "log.h"
+#include "platform.h"
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+using namespace std::chrono_literals;
+
+static_assert (
+ std::endian::native == std::endian::big || std::endian::native == std::endian::little);
+
+static_assert (sizeof (in_addr_t) == 4);
+
+namespace
+{
+constexpr auto MDNS_TTL = 120;
+
+SockAddr const s_multicastAddress{inet_addr ("224.0.0.251"), 5353};
+
+platform::steady_clock::time_point s_lastAnnounce{};
+platform::steady_clock::time_point s_lastProbe{};
+
+std::string s_hostname = platform::hostname ();
+std::string s_hostnameLocal = s_hostname + ".local";
+
+enum class State
+{
+ Probe1,
+ Probe2,
+ Probe3,
+ Announce1,
+ Announce2,
+ Complete,
+ Conflict,
+};
+
+auto s_state = State::Probe1;
+
+#if __has_cpp_attribute(__cpp_lib_byteswap)
+template
+using byteswap = std::byteswap;
+#else
+template
+constexpr T byteswap (T const value_) noexcept
+{
+ static_assert (std::has_unique_object_representations_v, "T may not have padding bits");
+ auto buffer = std::bit_cast> (value_);
+ std::ranges::reverse (buffer);
+ return std::bit_cast (buffer);
+}
+#endif
+
+template
+constexpr T hton (T const value_) noexcept
+{
+ if constexpr (std::endian::native == std::endian::big)
+ return value_;
+ else
+ return byteswap (value_);
+}
+
+template
+constexpr T ntoh (T const value_) noexcept
+{
+ if constexpr (std::endian::native == std::endian::big)
+ return value_;
+ else
+ return byteswap (value_);
+}
+
+template
+void const *decode (void const *const buffer_, U &size_, T &out_, bool networkToHost_ = true)
+{
+ if (!buffer_)
+ return nullptr;
+
+ if (size_ < 0 || static_cast> (size_) < sizeof (T))
+ return nullptr;
+
+ std::memcpy (&out_, buffer_, sizeof (T));
+
+ if (networkToHost_)
+ out_ = ntoh (out_);
+
+ size_ -= sizeof (T);
+ return static_cast (buffer_) + sizeof (T);
+}
+
+template
+void const *decode (void const *buffer_, T &size_, std::string &out_)
+{
+ auto p = static_cast (buffer_);
+ auto const end = p + size_;
+
+ std::string result;
+ result.reserve (size_);
+
+ while (p < end && *p)
+ {
+ auto const len = *p++;
+
+ // punt on compressed labels
+ if (len & 0xC0)
+ return nullptr;
+
+ if (p + len >= end)
+ return nullptr;
+
+ if (!result.empty ())
+ result.push_back ('.');
+
+ result.insert (std::end (result), p, p + len);
+ p += len;
+ }
+
+ ++p;
+
+ out_ = std::move (result);
+
+ size_ = end - p;
+ return p;
+}
+
+template
+void *encode (void *const buffer_, U &size_, T in_, bool hostToNetwork_ = true)
+{
+ if (!buffer_)
+ return nullptr;
+
+ if (size_ < sizeof (T))
+ return nullptr;
+
+ if (hostToNetwork_)
+ in_ = hton (in_);
+
+ std::memcpy (buffer_, &in_, sizeof (T));
+
+ size_ -= sizeof (T);
+ return static_cast (buffer_) + sizeof (T);
+}
+
+template
+void *encode (void *const buffer_, T &size_, std::string const &in_)
+{
+ // names are limited to 255 bytes
+ if (in_.size () > 0xFF)
+ return nullptr;
+
+ auto p = static_cast (buffer_);
+ auto const end = p + size_;
+
+ std::string::size_type prev = 0;
+ std::string::size_type pos = 0;
+ while (p < end && pos != std::string::npos)
+ {
+ pos = in_.find ('.', prev);
+
+ auto const label = std::string_view (in_).substr (prev, pos);
+
+ // labels are limited to 63 bytes
+ if (label.size () >= size_ || label.size () > 0x3F)
+ return nullptr;
+
+ p = static_cast (encode (p, size_, label.size ()));
+ if (!p)
+ return nullptr;
+
+ std::memcpy (p, label.data (), label.size ());
+
+ p += label.size ();
+
+ if (pos != std::string::npos)
+ prev = pos + 1;
+ }
+
+ if (p == end)
+ return nullptr;
+
+ *p++ = 0;
+
+ size_ = end - p;
+ return p;
+}
+
+struct DNSHeader
+{
+ std::uint16_t id{};
+ std::uint16_t flags{};
+ std::uint16_t qdCount{};
+ std::uint16_t anCount{};
+ std::uint16_t nsCount{};
+ std::uint16_t arCount{};
+
+ template
+ void const *decode (void const *const buffer_, T &size_)
+ {
+ auto in = ::decode (buffer_, size_, id);
+ in = ::decode (buffer_, size_, flags);
+ in = ::decode (buffer_, size_, qdCount);
+ in = ::decode (buffer_, size_, anCount);
+ in = ::decode (buffer_, size_, nsCount);
+ in = ::decode (buffer_, size_, arCount);
+
+ return buffer_;
+ }
+
+ template
+ void *encode (void *buffer_, T &size_)
+ {
+ buffer_ = ::encode (buffer_, size_, id);
+ buffer_ = ::encode (buffer_, size_, flags);
+ buffer_ = ::encode (buffer_, size_, qdCount);
+ buffer_ = ::encode (buffer_, size_, anCount);
+ buffer_ = ::encode (buffer_, size_, nsCount);
+ buffer_ = ::encode (buffer_, size_, arCount);
+
+ return buffer_;
+ }
+};
+
+struct QueryRecord
+{
+ std::string qname{};
+ std::uint16_t qtype{};
+ std::uint16_t qclass{};
+
+ template
+ void const *decode (void const *buffer_, T &size_)
+ {
+ buffer_ = ::decode (buffer_, size_, qname);
+ buffer_ = ::decode (buffer_, size_, qtype);
+ buffer_ = ::decode (buffer_, size_, qclass);
+
+ return buffer_;
+ }
+
+ template
+ void *encode (void *buffer_, T &size_)
+ {
+ buffer_ = ::encode (buffer_, size_, qname);
+ buffer_ = ::encode (buffer_, size_, qtype);
+ buffer_ = ::encode (buffer_, size_, qclass);
+
+ return buffer_;
+ }
+};
+
+struct ResourceRecord
+{
+ std::string rname{};
+ std::uint16_t rtype{};
+ std::uint16_t rclass{};
+ std::uint32_t rttl{};
+ std::uint16_t rlen{};
+ std::vector rdata{};
+
+ template
+ void const *decode (void const *buffer_, T &size_)
+ {
+ buffer_ = ::decode (buffer_, size_, rname);
+ buffer_ = ::decode (buffer_, size_, rtype);
+ buffer_ = ::decode (buffer_, size_, rclass);
+ buffer_ = ::decode (buffer_, size_, rttl);
+ buffer_ = ::decode (buffer_, size_, rlen);
+
+ return buffer_;
+ }
+
+ template
+ void *encode (void *buffer_, T &size_)
+ {
+ if (rttl > std::numeric_limits::max ())
+ return nullptr;
+
+ buffer_ = ::encode (buffer_, size_, rname);
+ buffer_ = ::encode (buffer_, size_, rtype);
+ buffer_ = ::encode (buffer_, size_, rclass);
+ buffer_ = ::encode (buffer_, size_, rttl);
+ buffer_ = ::encode (buffer_, size_, rlen);
+
+ if (rlen > size_)
+ return nullptr;
+
+ rdata.resize (rlen);
+ std::memcpy (rdata.data (), buffer_, rlen);
+
+ size_ -= rlen;
+ return static_cast (buffer_) + rlen;
+ }
+};
+
+void probe (Socket *const socket_, std::string const &qname_)
+{
+ std::vector response (65536);
+ auto available = response.size ();
+
+ auto out = DNSHeader{.qdCount = 1}.encode (response.data (), available);
+ out = QueryRecord{.qname = qname_, .qtype = 255, .qclass = 1}.encode (out, available);
+
+ if (!out)
+ return;
+
+ info ("Probe mDNS %s\n", qname_.c_str ());
+
+ socket_->writeTo (response.data (), response.size () - available, s_multicastAddress);
+ s_lastProbe = platform::steady_clock::now ();
+}
+
+void announce (Socket *const socket_,
+ SockAddr const *srcAddr_,
+ std::uint16_t const id_,
+ std::uint16_t const flags_,
+ QueryRecord const &record_,
+ SockAddr const &addr_)
+{
+ std::vector response (65536);
+ auto available = response.size ();
+
+ // header
+ auto out = encode (response.data (), available, id_);
+ out =
+ encode (out, available, flags_ | (1 << 15) | (1 << 10)); // mark response/AA
+ out = encode (out, available, 0);
+ out = encode (out, available, 1);
+ out = encode (out, available, 0);
+ out = encode (out, available, 0);
+
+ // answer section
+ out = encode (out, available, record_.qname);
+ out = encode (out, available, record_.qtype);
+ out = encode (out, available, record_.qclass | (1 << 15)); // mark unique/flush
+ out = encode (out, available, MDNS_TTL);
+ out = encode (out, available, sizeof (in_addr_t));
+ out = encode (
+ out, available, static_cast (addr_).sin_addr.s_addr, false);
+
+ if (!out)
+ return;
+
+ auto const preferUnicast = srcAddr_ && ((record_.qclass >> 15) & 0x1);
+
+ if (preferUnicast)
+ {
+ auto const name = std::string (addr_.name ());
+ info (
+ "Respond mDNS %s %s to %s\n", record_.qname.c_str (), name.c_str (), srcAddr_->name ());
+ socket_->writeTo (response.data (), response.size () - available, *srcAddr_);
+ }
+
+ auto const now = platform::steady_clock::now ();
+ if (!preferUnicast || now - s_lastAnnounce > std::chrono::seconds (MDNS_TTL / 4))
+ {
+ info ("Announce mDNS %s %s\n", record_.qname.c_str (), addr_.name ());
+ socket_->writeTo (response.data (), response.size () - available, s_multicastAddress);
+ s_lastAnnounce = now;
+ }
+}
+}
+
+void mdns::setHostname (std::string hostname_)
+{
+ if (hostname_.empty ())
+ hostname_ = platform::hostname ();
+
+ if (s_hostname == hostname_)
+ return;
+
+ s_hostname = std::move (hostname_);
+ s_hostnameLocal = s_hostname + ".local";
+
+ s_state = State::Probe1;
+ s_lastProbe = platform::steady_clock::now ();
+}
+
+UniqueSocket mdns::createSocket ()
+{
+ auto socket = Socket::create (Socket::eDatagram);
+ if (!socket)
+ return nullptr;
+
+ if (!socket->setReuseAddress ())
+ return nullptr;
+
+ auto iface = SockAddr::AnyIPv4;
+ iface.setPort (s_multicastAddress.port ());
+ if (!socket->bind (iface))
+ return nullptr;
+
+ if (!socket->joinMulticastGroup (s_multicastAddress, iface))
+ return nullptr;
+
+ s_state = State::Probe1;
+ s_lastProbe = platform::steady_clock::now ();
+
+ return socket;
+}
+
+void mdns::handleSocket (Socket *socket_, SockAddr const &addr_)
+{
+ if (!socket_)
+ return;
+
+ // only support IPv4 for now
+ if (addr_.domain () != SockAddr::Domain::IPv4)
+ return;
+
+ auto const now = platform::steady_clock::now ();
+
+ switch (s_state)
+ {
+ case State::Probe1:
+ case State::Probe2:
+ case State::Probe3:
+ if (now - s_lastProbe > 250ms)
+ {
+ probe (socket_, s_hostname);
+ s_state = static_cast (static_cast (s_state) + 1);
+ }
+ break;
+
+ case State::Announce1:
+ case State::Announce2:
+ if (now - s_lastAnnounce > 1s)
+ {
+ announce (socket_,
+ nullptr,
+ 0,
+ 0,
+ QueryRecord{.qname = s_hostname, .qtype = 1, .qclass = 1},
+ addr_);
+ s_state = static_cast (static_cast (s_state) + 1);
+ }
+
+ default:
+ break;
+ }
+
+ Socket::PollInfo pollInfo{*socket_, POLLIN, 0};
+ auto const rc = Socket::poll (&pollInfo, 1, 0ms);
+ if (rc <= 0 || !(pollInfo.revents & POLLIN))
+ return;
+
+ SockAddr srcAddr;
+ std::vector buffer (65536);
+ auto bytes = socket_->readFrom (buffer.data (), buffer.size (), srcAddr);
+ if (bytes <= 0)
+ return;
+
+ // only support IPv4 for now
+ if (srcAddr.domain () != SockAddr::Domain::IPv4)
+ return;
+
+ // ignore loopback
+ if (std::memcmp (&reinterpret_cast (srcAddr).sin_addr.s_addr,
+ &reinterpret_cast (addr_).sin_addr.s_addr,
+ sizeof (in_addr_t)) == 0)
+ return;
+
+ std::uint16_t id;
+ std::uint16_t flags;
+ std::uint16_t qdCount;
+ std::uint16_t anCount;
+ std::uint16_t nsCount;
+ std::uint16_t arCount;
+
+ // parse header
+ auto in = decode (buffer.data (), bytes, id);
+ in = decode (in, bytes, flags);
+ in = decode (in, bytes, qdCount);
+ in = decode (in, bytes, anCount);
+ in = decode (in, bytes, nsCount);
+ in = decode (in, bytes, arCount);
+
+ if (!in)
+ return;
+
+ auto const qr = (flags >> 15) & 0x1;
+
+ // ill-formed on queries and responses
+ auto const opcode = (flags >> 11) & 0xF;
+ if (opcode != 0)
+ return;
+
+ // ill-formed on queries
+ if (!qr && ((flags >> 10) & 0x1))
+ return;
+
+ // punt on truncated messages
+ if ((flags >> 9) & 0x1)
+ return;
+
+ // ill-formed on queries
+ if (!qr && ((flags >> 7) & 0x1))
+ return;
+
+ // must be zero
+ if ((flags >> 4) & 0x7)
+ return;
+
+ // ill-formed on queries and responses
+ if ((flags >> 0) & 0xF)
+ return;
+
+ // std::vector response (65536);
+ // void *out = response.data ();
+ // auto available = response.size ();
+
+ std::vector answers;
+
+ bool announced = false;
+ for (unsigned i = 0; i < qdCount; ++i)
+ {
+ QueryRecord record;
+ in = record.decode (in, bytes);
+
+ if (!in)
+ return;
+
+ // only respond to queries
+ if (qr)
+ continue;
+
+ // only accept A or ANY type
+ if (record.qtype != 1 && record.qtype != 255)
+ continue;
+
+ // only accept IN or ANY class
+ if ((record.qclass & 0x7FFF) != 1 && (record.qclass & 0x7FFF) != 255)
+ continue;
+
+ if (record.qname != s_hostname && record.qname != s_hostnameLocal)
+ continue;
+
+ if (!announced)
+ {
+ std::vector data (sizeof (in_addr_t));
+ auto n = data.size ();
+ encode (data.data (),
+ n,
+ static_cast (addr_).sin_addr.s_addr,
+ false);
+
+ answers.emplace_back (ResourceRecord{// answer
+ .rname = record.qname,
+ .rtype = 1,
+ .rclass = static_cast (1 | (1 << 15)),
+ .rttl = MDNS_TTL,
+ .rlen = sizeof (in_addr_t),
+ .rdata = std::move (data)});
+
+ announce (socket_, &srcAddr, id, flags, record, addr_);
+ announced = true;
+ }
+ }
+
+ for (unsigned i = 0; i < anCount; ++i)
+ {
+ ResourceRecord record;
+ in = record.decode (in, bytes);
+
+ if (!in)
+ return;
+ }
+}
diff --git a/source/sockAddr.cpp b/source/sockAddr.cpp
index bf56679..165f847 100644
--- a/source/sockAddr.cpp
+++ b/source/sockAddr.cpp
@@ -26,11 +26,79 @@
#include
#include
+#ifdef __3DS__
+static_assert (sizeof (sockaddr_storage) == 0x1c);
+#endif
+
+namespace
+{
+in_addr inaddr_any = {.s_addr = htonl (INADDR_ANY)};
+
+std::strong_ordering
+ strongMemCompare (void const *const a_, void const *const b_, std::size_t const size_)
+{
+ auto const cmp = std::memcmp (a_, b_, size_);
+ if (cmp < 0)
+ return std::strong_ordering::less;
+ if (cmp > 0)
+ return std::strong_ordering::greater;
+ return std::strong_ordering::equal;
+}
+}
+
///////////////////////////////////////////////////////////////////////////
+SockAddr const SockAddr::AnyIPv4{inaddr_any};
+
+#ifndef NO_IPV6
+SockAddr const SockAddr::AnyIPv6{in6addr_any};
+#endif
+
SockAddr::~SockAddr () = default;
SockAddr::SockAddr () = default;
+SockAddr::SockAddr (Domain const domain_)
+{
+ switch (domain_)
+ {
+ case Domain::IPv4:
+ *this = AnyIPv4;
+ break;
+
+#ifndef NO_IPV6
+ case Domain::IPv6:
+ *this = AnyIPv6;
+ break;
+#endif
+
+ default:
+ std::abort ();
+ }
+}
+
+SockAddr::SockAddr (in_addr_t const addr_, std::uint16_t const port_)
+ : SockAddr (in_addr{.s_addr = addr_}, port_)
+{
+}
+
+SockAddr::SockAddr (in_addr const &addr_, std::uint16_t const port_)
+{
+ std::memset (&m_addr, 0, sizeof (m_addr));
+ m_addr.ss_family = AF_INET;
+ setAddr (addr_);
+ setPort (port_);
+}
+
+#ifndef NO_IPV6
+SockAddr::SockAddr (in6_addr const &addr_, std::uint16_t const port_)
+{
+ std::memset (&m_addr, 0, sizeof (m_addr));
+ m_addr.ss_family = AF_INET6;
+ setAddr (addr_);
+ setPort (port_);
+}
+#endif
+
SockAddr::SockAddr (SockAddr const &that_) = default;
SockAddr::SockAddr (SockAddr &&that_) = default;
@@ -101,36 +169,24 @@ SockAddr::operator sockaddr const * () const
return reinterpret_cast (&m_addr);
}
-void SockAddr::setPort (std::uint16_t const port_)
+bool SockAddr::operator== (SockAddr const &that_) const
{
+ if (m_addr.ss_family != that_.m_addr.ss_family)
+ return false;
+
switch (m_addr.ss_family)
{
case AF_INET:
- reinterpret_cast (&m_addr)->sin_port = htons (port_);
- break;
+ if (port () != that_.port ())
+ return false;
+
+ // ignore sin_zero
+ return static_cast (*this).sin_addr.s_addr ==
+ static_cast (that_).sin_addr.s_addr;
#ifndef NO_IPV6
case AF_INET6:
- reinterpret_cast (&m_addr)->sin6_port = htons (port_);
- break;
-#endif
-
- default:
- std::abort ();
- break;
- }
-}
-
-socklen_t SockAddr::size () const
-{
- switch (m_addr.ss_family)
- {
- case AF_INET:
- return sizeof (struct sockaddr_in);
-
-#ifndef NO_IPV6
- case AF_INET6:
- return sizeof (struct sockaddr_in6);
+ return std::memcmp (&m_addr, &that_.m_addr, sizeof (sockaddr_in6)) == 0;
#endif
default:
@@ -138,6 +194,85 @@ socklen_t SockAddr::size () const
}
}
+std::strong_ordering SockAddr::operator<=> (SockAddr const &that_) const
+{
+ if (m_addr.ss_family != that_.m_addr.ss_family)
+ return m_addr.ss_family <=> that_.m_addr.ss_family;
+
+ switch (m_addr.ss_family)
+ {
+ case AF_INET:
+ {
+ auto const cmp =
+ strongMemCompare (&static_cast (*this).sin_addr.s_addr,
+ &static_cast (that_).sin_addr.s_addr,
+ sizeof (in_addr_t));
+
+ if (cmp != std::strong_ordering::equal)
+ return cmp;
+
+ return port () <=> that_.port ();
+ }
+
+#ifndef NO_IPV6
+ case AF_INET6:
+ {
+ auto const &addr1 = static_cast (*this);
+ auto const &addr2 = static_cast (that_);
+
+ if (auto const cmp =
+ strongMemCompare (&addr1.sin6_addr, &addr2.sin6_addr, sizeof (in6_addr));
+ cmp != std::strong_ordering::equal)
+ return cmp;
+
+ auto const p1 = port ();
+ auto const p2 = that_.port ();
+
+ if (p1 < p2)
+ return std::strong_ordering::less;
+ else if (p1 > p2)
+ return std::strong_ordering::greater;
+
+ if (auto const cmp = strongMemCompare (
+ &addr1.sin6_flowinfo, &addr2.sin6_flowinfo, sizeof (std::uint32_t));
+ cmp != std::strong_ordering::equal)
+ return cmp;
+
+ return strongMemCompare (
+ &addr1.sin6_flowinfo, &addr2.sin6_flowinfo, sizeof (std::uint32_t));
+ }
+#endif
+
+ default:
+ std::abort ();
+ }
+}
+
+void SockAddr::setAddr (in_addr_t const addr_)
+{
+ setAddr (in_addr{.s_addr = addr_});
+}
+
+void SockAddr::setAddr (in_addr const &addr_)
+{
+ if (m_addr.ss_family != AF_INET)
+ std::abort ();
+
+ std::memcpy (&reinterpret_cast (m_addr).sin_addr, &addr_, sizeof (addr_));
+ ;
+}
+
+#ifndef NO_IPV6
+void SockAddr::setAddr (in6_addr const &addr_)
+{
+ if (m_addr.ss_family != AF_INET6)
+ std::abort ();
+
+ std::memcpy (&reinterpret_cast (m_addr).sin6_addr, &addr_, sizeof (addr_));
+ ;
+}
+#endif
+
std::uint16_t SockAddr::port () const
{
switch (m_addr.ss_family)
@@ -152,7 +287,57 @@ std::uint16_t SockAddr::port () const
default:
std::abort ();
+ }
+}
+
+void SockAddr::setPort (std::uint16_t const port_)
+{
+ switch (m_addr.ss_family)
+ {
+ case AF_INET:
+ reinterpret_cast (&m_addr)->sin_port = htons (port_);
break;
+
+#ifndef NO_IPV6
+ case AF_INET6:
+ reinterpret_cast (&m_addr)->sin6_port = htons (port_);
+ break;
+#endif
+
+ default:
+ std::abort ();
+ }
+}
+
+SockAddr::Domain SockAddr::domain () const
+{
+ switch (m_addr.ss_family)
+ {
+ case AF_INET:
+#ifndef NO_IPV6
+ case AF_INET6:
+#endif
+ return static_cast (m_addr.ss_family);
+
+ default:
+ std::abort ();
+ }
+}
+
+socklen_t SockAddr::size () const
+{
+ switch (m_addr.ss_family)
+ {
+ case AF_INET:
+ return sizeof (sockaddr_in);
+
+#ifndef NO_IPV6
+ case AF_INET6:
+ return sizeof (sockaddr_in6);
+#endif
+
+ default:
+ std::abort ();
}
}
diff --git a/source/socket.cpp b/source/socket.cpp
index 14a6cad..36df5ab 100644
--- a/source/socket.cpp
+++ b/source/socket.cpp
@@ -255,6 +255,38 @@ bool Socket::setSendBufferSize (std::size_t const size_)
return true;
}
+#ifndef __NDS__
+bool Socket::joinMulticastGroup (SockAddr const &addr_, SockAddr const &iface_)
+{
+ ip_mreq group;
+ group.imr_multiaddr = static_cast (addr_).sin_addr;
+ group.imr_interface = static_cast (iface_).sin_addr;
+
+ if (::setsockopt (m_fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof (group)) != 0)
+ {
+ error ("setsockopt(IP_ADD_MEMBERSHIP, %s): %s\n", addr_.name (), std::strerror (errno));
+ return false;
+ }
+
+ return true;
+}
+
+bool Socket::dropMulticastGroup (SockAddr const &addr_, SockAddr const &iface_)
+{
+ ip_mreq group;
+ group.imr_multiaddr = static_cast (addr_).sin_addr;
+ group.imr_interface = static_cast (iface_).sin_addr;
+
+ if (::setsockopt (m_fd, IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, sizeof (group)) != 0)
+ {
+ error ("setsockopt(IP_DROP_MEMBERSHIP, %s): %s\n", addr_.name (), std::strerror (errno));
+ return false;
+ }
+
+ return true;
+}
+#endif
+
std::make_signed_t
Socket::read (void *const buffer_, std::size_t const size_, bool const oob_)
{
@@ -279,6 +311,21 @@ std::make_signed_t Socket::read (IOBuffer &buffer_, bool const oob_
return rc;
}
+std::make_signed_t
+ Socket::readFrom (void *const buffer_, std::size_t const size_, SockAddr &addr_)
+{
+ assert (buffer_);
+ assert (size_);
+
+ socklen_t addrLen = sizeof (sockaddr_storage);
+
+ auto const rc = ::recvfrom (m_fd, buffer_, size_, 0, addr_, &addrLen);
+ if (rc < 0 && errno != EWOULDBLOCK)
+ error ("recvfrom: %s\n", std::strerror (errno));
+
+ return rc;
+}
+
std::make_signed_t Socket::write (void const *const buffer_, std::size_t const size_)
{
assert (buffer_);
@@ -302,6 +349,19 @@ std::make_signed_t Socket::write (IOBuffer &buffer_)
return rc;
}
+std::make_signed_t
+ Socket::writeTo (void const *buffer_, std::size_t size_, SockAddr const &addr_)
+{
+ assert (buffer_);
+ assert (size_ > 0);
+
+ auto const rc = ::sendto (m_fd, buffer_, size_, 0, addr_, addr_.size ());
+ if (rc < 0 && errno != EWOULDBLOCK)
+ error ("sendto: %s\n", std::strerror (errno));
+
+ return rc;
+}
+
SockAddr const &Socket::sockName () const
{
return m_sockName;
@@ -312,9 +372,9 @@ SockAddr const &Socket::peerName () const
return m_peerName;
}
-UniqueSocket Socket::create ()
+UniqueSocket Socket::create (Type const type_)
{
- auto const fd = ::socket (AF_INET, SOCK_STREAM, 0);
+ auto const fd = ::socket (AF_INET, static_cast (type_), 0);
if (fd < 0)
{
error ("socket: %s\n", std::strerror (errno));
diff --git a/source/switch/platform.cpp b/source/switch/platform.cpp
index 2d8f6b9..ccf1fd1 100644
--- a/source/switch/platform.cpp
+++ b/source/switch/platform.cpp
@@ -793,6 +793,12 @@ bool platform::networkAddress (SockAddr &addr_)
return true;
}
+std::string const &platform::hostname ()
+{
+ static std::string const hostname = "switch-ftpd";
+ return hostname;
+}
+
bool platform::loop ()
{
if (!appletMainLoop ())