diff --git a/CMakeLists.txt b/CMakeLists.txt index e149d1c..f7e8414 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,6 +84,13 @@ target_sources(${FTPD_TARGET} PRIVATE source/socket.cpp ) +if(NOT NINTENDO_DS) + target_sources(${FTPD_TARGET} PRIVATE + source/mdns.cpp + include/mdns.h + ) +endif() + if(NOT FTPD_CLASSIC AND NOT NINTENDO_DS) target_sources(${FTPD_TARGET} PRIVATE include/imconfig.h diff --git a/include/ftpConfig.h b/include/ftpConfig.h index 7937c5d..57fef91 100644 --- a/include/ftpConfig.h +++ b/include/ftpConfig.h @@ -60,6 +60,9 @@ public: /// \brief Get password std::string const &pass () const; + /// \brief Get hostname + std::string const &hostname () const; + /// \brief Get port std::uint16_t port () const; @@ -88,6 +91,10 @@ public: /// \param pass_ Password void setPass (std::string pass_); + /// \brief Set hostname + /// \param hostname_ Hostname + void setHostname (std::string hostname_); + /// \brief Set listen port /// \param port_ Listen port bool setPort (std::string_view port_); @@ -130,6 +137,9 @@ private: /// \brief Password std::string m_pass; + /// \brief Hostname + std::string m_hostname; + /// \brief Listen port std::uint16_t m_port; diff --git a/include/ftpServer.h b/include/ftpServer.h index 729b090..a866d7e 100644 --- a/include/ftpServer.h +++ b/include/ftpServer.h @@ -110,6 +110,11 @@ private: /// \brief Listen socket UniqueSocket m_socket; +#ifndef __NDS__ + /// \brief mDNS socket + UniqueSocket m_mdnsSocket; +#endif + /// \brief ImGui window name std::string m_name; @@ -151,6 +156,9 @@ private: /// \brief Password setting std::string m_passSetting; + /// \brief Hostname setting + std::string m_hostnameSetting; + /// \brief Port setting std::uint16_t m_portSetting = 0; diff --git a/include/mdns.h b/include/mdns.h new file mode 100644 index 0000000..4fd8ce5 --- /dev/null +++ b/include/mdns.h @@ -0,0 +1,40 @@ + +// ftpd is a server implementation based on the following: +// - RFC 959 (https://datatracker.ietf.org/doc/html/rfc959) +// - RFC 3659 (https://datatracker.ietf.org/doc/html/rfc3659) +// - suggested implementation details from https://cr.yp.to/ftp/filesystem.html +// +// ftpd implements mdns based on the following: +// - RFC 1035 (https://datatracker.ietf.org/doc/html/rfc1035) +// - RFC 6762 (https://datatracker.ietf.org/doc/html/rfc6762) +// +// Copyright (C) 2024 Michael Theall +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +#pragma once + +#include "sockAddr.h" +#include "socket.h" + +#include + +namespace mdns +{ +void setHostname (std::string hostname_); + +UniqueSocket createSocket (); + +void handleSocket (Socket *socket_, SockAddr const &addr_); +} diff --git a/include/platform.h b/include/platform.h index e108433..eca7933 100644 --- a/include/platform.h +++ b/include/platform.h @@ -3,7 +3,7 @@ // - RFC 3659 (https://tools.ietf.org/html/rfc3659) // - suggested implementation details from https://cr.yp.to/ftp/filesystem.html // -// Copyright (C) 2023 Michael Theall +// Copyright (C) 2024 Michael Theall // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by @@ -34,6 +34,7 @@ #include #include #include +#include #ifdef CLASSIC extern PrintConsole g_statusConsole; @@ -71,6 +72,9 @@ bool networkVisible (); /// \param[out] addr_ Network address bool networkAddress (SockAddr &addr_); +/// \brief Get hostname +std::string const &hostname (); + /// \brief Platform loop bool loop (); diff --git a/include/sockAddr.h b/include/sockAddr.h index ad15cbd..aa05517 100644 --- a/include/sockAddr.h +++ b/include/sockAddr.h @@ -23,6 +23,7 @@ #include #include +#include #include #ifdef __NDS__ @@ -37,10 +38,48 @@ struct sockaddr_storage class SockAddr { public: + enum class Domain + { + IPv4 = AF_INET, +#ifndef NO_IPV6 + IPv6 = AF_INET6, +#endif + }; + + /// \brief 0.0.0.0 + static SockAddr const AnyIPv4; + +#ifndef NO_IPV6 + /// \brief :: + static SockAddr const AnyIPv6; +#endif + ~SockAddr (); SockAddr (); + /// \brief Parameterized constructor + /// \param domain_ Socket domain + /// \note Initial address is INADDR_ANY/in6addr_any + SockAddr (Domain domain_); + + /// \brief Parameterized constructor + /// \param addr_ Socket address (network byte order) + /// \param port_ Socket port (host byte order) + SockAddr (in_addr_t addr_, std::uint16_t port_ = 0); + + /// \brief Parameterized constructor + /// \param addr_ Socket address (network byte order) + /// \param port_ Socket port (host byte order) + SockAddr (in_addr const &addr_, std::uint16_t port_ = 0); + +#ifndef NO_IPV6 + /// \brief Parameterized constructor + /// \param addr_ Socket address + /// \param port_ Socket port (host byte order) + SockAddr (in6_addr const &addr_, std::uint16_t port_ = 0); +#endif + /// \brief Copy constructor /// \param that_ Object to copy SockAddr (SockAddr const &that_); @@ -88,10 +127,32 @@ public: /// \brief sockaddr const* cast operator (network byte order) operator sockaddr const * () const; + /// \brief Equality operator + bool operator== (SockAddr const &that_) const; + + /// \brief Comparison operator + std::strong_ordering operator<=> (SockAddr const &that_) const; + + /// \brief sockaddr domain + Domain domain () const; /// \brief sockaddr size socklen_t size () const; + /// \brief Set address + /// \param addr_ Address to set (network byte order) + void setAddr (in_addr_t addr_); + + /// \brief Set address + /// \param addr_ Address to set (network byte order) + void setAddr (in_addr const &addr_); + +#ifndef NO_IPV6 + /// \brief Set address + /// \param addr_ Address to set (network byte order) + void setAddr (in6_addr const &addr_); +#endif + /// \brief Address port (host byte order) std::uint16_t port () const; diff --git a/include/socket.h b/include/socket.h index 6d91680..1e9973f 100644 --- a/include/socket.h +++ b/include/socket.h @@ -56,6 +56,12 @@ using SharedSocket = std::shared_ptr; class Socket { public: + enum Type + { + eStream = SOCK_STREAM, ///< Stream socket + eDatagram = SOCK_DGRAM, ///< Datagram socket + }; + /// \brief Poll info struct PollInfo { @@ -114,6 +120,18 @@ public: /// \param size_ Buffer size bool setSendBufferSize (std::size_t size_); +#ifndef __NDS__ + /// \brief Join multicast group + /// \param addr_ Multicast group address + /// \param iface_ Interface address + bool joinMulticastGroup (SockAddr const &addr_, SockAddr const &iface_); + + /// \brief Drop multicast group + /// \param addr_ Multicast group address + /// \param iface_ Interface address + bool dropMulticastGroup (SockAddr const &addr_, SockAddr const &iface_); +#endif + /// \brief Read data /// \param buffer_ Output buffer /// \param size_ Size to read @@ -125,6 +143,12 @@ public: /// \param oob_ Whether to read from out-of-band std::make_signed_t read (IOBuffer &buffer_, bool oob_ = false); + /// \brief Read data + /// \param buffer_ Output buffer + /// \param size_ Size to read + /// \param[out] addr_ Source address + std::make_signed_t readFrom (void *buffer_, std::size_t size_, SockAddr &addr_); + /// \brief Write data /// \param buffer_ Input buffer /// \param size_ Size to write @@ -135,13 +159,21 @@ public: /// \param size_ Size to write std::make_signed_t write (IOBuffer &buffer_); + /// \brief Write data + /// \param buffer_ Input buffer + /// \param size_ Size to write + /// \param[out] addr_ Destination address + std::make_signed_t + writeTo (void const *buffer_, std::size_t size_, SockAddr const &addr_); + /// \brief Local name SockAddr const &sockName () const; /// \brief Peer name SockAddr const &peerName () const; /// \brief Create socket - static UniqueSocket create (); + /// \param type_ Socket type + static UniqueSocket create (Type type_); /// \brief Poll sockets /// \param info_ Poll info diff --git a/source/3ds/platform.cpp b/source/3ds/platform.cpp index c1012af..2a18770 100644 --- a/source/3ds/platform.cpp +++ b/source/3ds/platform.cpp @@ -577,6 +577,12 @@ bool platform::networkAddress (SockAddr &addr_) return true; } +std::string const &platform::hostname () +{ + static std::string const hostname = "3ds-ftpd"; + return hostname; +} + bool platform::loop () { if (!aptMainLoop ()) diff --git a/source/ftpConfig.cpp b/source/ftpConfig.cpp index eca6b89..40ca6de 100644 --- a/source/ftpConfig.cpp +++ b/source/ftpConfig.cpp @@ -232,6 +232,11 @@ std::string const &FtpConfig::pass () const return m_pass; } +std::string const &FtpConfig::hostname () const +{ + return m_hostname; +} + std::uint16_t FtpConfig::port () const { return m_port; @@ -271,6 +276,11 @@ void FtpConfig::setPass (std::string pass_) m_pass = std::move (pass_); } +void FtpConfig::setHostname (std::string hostname_) +{ + m_hostname = std::move (hostname_); +} + bool FtpConfig::setPort (std::string_view const port_) { std::uint16_t parsed{}; diff --git a/source/ftpServer.cpp b/source/ftpServer.cpp index 1c96c7d..f753783 100644 --- a/source/ftpServer.cpp +++ b/source/ftpServer.cpp @@ -29,6 +29,10 @@ #include "sockAddr.h" #include "socket.h" +#ifndef __NDS__ +#include "mdns.h" +#endif + #include "imgui.h" #ifdef __NDS__ @@ -223,8 +227,14 @@ FtpServer::~FtpServer () FtpServer::FtpServer (UniqueFtpConfig config_) : m_config (std::move (config_)) +#ifndef CLASSIC + , + m_hostnameSetting (m_config->hostname ()) +#endif { #ifndef __NDS__ + mdns::setHostname (m_config->hostname ()); + m_thread = platform::Thread (std::bind (&FtpServer::threadFunc, this)); #endif @@ -439,7 +449,7 @@ void FtpServer::handleNetworkFound () addr.setPort (port); - auto socket = Socket::create (); + auto socket = Socket::create (Socket::eStream); if (!socket) return; @@ -461,6 +471,14 @@ void FtpServer::handleNetworkFound () info ("Started server at %s\n", m_name.c_str ()); LOCKED (m_socket = std::move (socket)); + +#ifndef __NDS__ + socket = mdns::createSocket (); + if (!socket) + return; + + LOCKED (m_mdnsSocket = std::move (socket)); +#endif } void FtpServer::handleNetworkLost () @@ -476,6 +494,11 @@ void FtpServer::handleNetworkLost () // destroy command socket LOCKED (sock = std::move (m_socket)); + +#ifndef __NDS__ + // destroy mDNS socket + LOCKED (sock = std::move (m_mdnsSocket)); +#endif } info ("Stopped server at %s\n", m_name.c_str ()); @@ -574,6 +597,9 @@ void FtpServer::showMenu () m_passSetting = m_config->pass (); m_passSetting.resize (32); + m_hostnameSetting = m_config->hostname (); + m_hostnameSetting.resize (32); + m_portSetting = m_config->port (); #ifdef __3DS__ @@ -631,6 +657,11 @@ void FtpServer::showSettings () m_passSetting.size (), ImGuiInputTextFlags_AutoSelectAll | ImGuiInputTextFlags_Password); + ImGui::InputText ("Hostname", + m_hostnameSetting.data (), + m_hostnameSetting.size (), + ImGuiInputTextFlags_AutoSelectAll); + ImGui::InputScalar ("Port", ImGuiDataType_U16, &m_portSetting, @@ -703,6 +734,7 @@ void FtpServer::showSettings () m_config->setUser (m_userSetting); m_config->setPass (m_passSetting); + m_config->setHostname (m_hostnameSetting); m_config->setPort (m_portSetting); #ifdef __3DS__ @@ -718,6 +750,8 @@ void FtpServer::showSettings () UniqueSocket socket; LOCKED (socket = std::move (m_socket)); + + mdns::setHostname (m_hostnameSetting); } if (save) @@ -733,9 +767,10 @@ void FtpServer::showSettings () { static auto const defaults = FtpConfig::create (); - m_userSetting = defaults->user (); - m_passSetting = defaults->pass (); - m_portSetting = defaults->port (); + m_userSetting = defaults->user (); + m_passSetting = defaults->pass (); + m_hostnameSetting = defaults->hostname (); + m_portSetting = defaults->port (); #ifdef __3DS__ m_getMTimeSetting = defaults->getMTime (); #endif @@ -966,6 +1001,12 @@ void FtpServer::loop () } } +#ifndef __NDS__ + // poll mDNS socket + if (m_socket && m_mdnsSocket) + mdns::handleSocket (m_mdnsSocket.get (), m_socket->sockName ()); +#endif + { std::vector deadSessions; { diff --git a/source/ftpSession.cpp b/source/ftpSession.cpp index b8bfab3..5f769a1 100644 --- a/source/ftpSession.cpp +++ b/source/ftpSession.cpp @@ -22,6 +22,7 @@ #include "ftpServer.h" #include "log.h" +#include "mdns.h" #include "platform.h" #include "imgui.h" @@ -837,7 +838,7 @@ bool FtpSession::dataConnect () m_port = false; - auto data = Socket::create (); + auto data = Socket::create (Socket::eStream); LOCKED (m_dataSocket = std::move (data)); if (!m_dataSocket) return false; @@ -2384,7 +2385,7 @@ void FtpSession::PASV (char const *args_) m_port = false; // create a socket to listen on - auto pasv = Socket::create (); + auto pasv = Socket::create (Socket::eStream); LOCKED (m_pasvSocket = std::move (pasv)); if (!m_pasvSocket) { @@ -2726,6 +2727,9 @@ void FtpSession::SITE (char const *args_) " Set username: SITE USER \r\n" " Set password: SITE PASS \r\n" " Set port: SITE PORT \r\n" +#ifndef __NDS__ + " Set hostname: SITE HOST \r\n" +#endif #ifdef __3DS__ " Set getMTime: SITE MTIME [0|1]\r\n" #endif @@ -2784,6 +2788,16 @@ void FtpSession::SITE (char const *args_) sendResponse ("200 OK\r\n"); return; } +#ifndef __NDS__ + else if (compare (command, "HOST") == 0) + { + { + auto const lock = m_config.lockGuard (); + m_config.setHostname (std::string (arg)); + mdns::setHostname (std::string (arg)); + } + } +#endif #ifdef __3DS__ else if (compare (command, "MTIME") == 0) { diff --git a/source/linux/platform.cpp b/source/linux/platform.cpp index b2d72ce..fc62ed3 100644 --- a/source/linux/platform.cpp +++ b/source/linux/platform.cpp @@ -204,6 +204,24 @@ bool platform::networkAddress (SockAddr &addr_) return true; } +std::string const &platform::hostname () +{ + static std::string hostname = "switch-ftpd.local"; + if (hostname.empty ()) + { + std::string buffer (256, '\0'); + gethostname (buffer.data (), buffer.size ()); + + if (buffer.back () == 0) // check for truncation + { + hostname = std::move (buffer); + hostname.resize (std::strlen (hostname.data ())); + } + } + + return hostname; +} + bool platform::loop () { bool inactive; diff --git a/source/mdns.cpp b/source/mdns.cpp new file mode 100644 index 0000000..10cd111 --- /dev/null +++ b/source/mdns.cpp @@ -0,0 +1,599 @@ +// ftpd is a server implementation based on the following: +// - RFC 959 (https://datatracker.ietf.org/doc/html/rfc959) +// - RFC 3659 (https://datatracker.ietf.org/doc/html/rfc3659) +// - suggested implementation details from https://cr.yp.to/ftp/filesystem.html +// +// ftpd implements mdns based on the following: +// - RFC 1035 (https://datatracker.ietf.org/doc/html/rfc1035) +// - RFC 6762 (https://datatracker.ietf.org/doc/html/rfc6762) +// +// Copyright (C) 2024 Michael Theall +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +#include "mdns.h" + +#include "log.h" +#include "platform.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace std::chrono_literals; + +static_assert ( + std::endian::native == std::endian::big || std::endian::native == std::endian::little); + +static_assert (sizeof (in_addr_t) == 4); + +namespace +{ +constexpr auto MDNS_TTL = 120; + +SockAddr const s_multicastAddress{inet_addr ("224.0.0.251"), 5353}; + +platform::steady_clock::time_point s_lastAnnounce{}; +platform::steady_clock::time_point s_lastProbe{}; + +std::string s_hostname = platform::hostname (); +std::string s_hostnameLocal = s_hostname + ".local"; + +enum class State +{ + Probe1, + Probe2, + Probe3, + Announce1, + Announce2, + Complete, + Conflict, +}; + +auto s_state = State::Probe1; + +#if __has_cpp_attribute(__cpp_lib_byteswap) +template +using byteswap = std::byteswap; +#else +template +constexpr T byteswap (T const value_) noexcept +{ + static_assert (std::has_unique_object_representations_v, "T may not have padding bits"); + auto buffer = std::bit_cast> (value_); + std::ranges::reverse (buffer); + return std::bit_cast (buffer); +} +#endif + +template +constexpr T hton (T const value_) noexcept +{ + if constexpr (std::endian::native == std::endian::big) + return value_; + else + return byteswap (value_); +} + +template +constexpr T ntoh (T const value_) noexcept +{ + if constexpr (std::endian::native == std::endian::big) + return value_; + else + return byteswap (value_); +} + +template +void const *decode (void const *const buffer_, U &size_, T &out_, bool networkToHost_ = true) +{ + if (!buffer_) + return nullptr; + + if (size_ < 0 || static_cast> (size_) < sizeof (T)) + return nullptr; + + std::memcpy (&out_, buffer_, sizeof (T)); + + if (networkToHost_) + out_ = ntoh (out_); + + size_ -= sizeof (T); + return static_cast (buffer_) + sizeof (T); +} + +template +void const *decode (void const *buffer_, T &size_, std::string &out_) +{ + auto p = static_cast (buffer_); + auto const end = p + size_; + + std::string result; + result.reserve (size_); + + while (p < end && *p) + { + auto const len = *p++; + + // punt on compressed labels + if (len & 0xC0) + return nullptr; + + if (p + len >= end) + return nullptr; + + if (!result.empty ()) + result.push_back ('.'); + + result.insert (std::end (result), p, p + len); + p += len; + } + + ++p; + + out_ = std::move (result); + + size_ = end - p; + return p; +} + +template +void *encode (void *const buffer_, U &size_, T in_, bool hostToNetwork_ = true) +{ + if (!buffer_) + return nullptr; + + if (size_ < sizeof (T)) + return nullptr; + + if (hostToNetwork_) + in_ = hton (in_); + + std::memcpy (buffer_, &in_, sizeof (T)); + + size_ -= sizeof (T); + return static_cast (buffer_) + sizeof (T); +} + +template +void *encode (void *const buffer_, T &size_, std::string const &in_) +{ + // names are limited to 255 bytes + if (in_.size () > 0xFF) + return nullptr; + + auto p = static_cast (buffer_); + auto const end = p + size_; + + std::string::size_type prev = 0; + std::string::size_type pos = 0; + while (p < end && pos != std::string::npos) + { + pos = in_.find ('.', prev); + + auto const label = std::string_view (in_).substr (prev, pos); + + // labels are limited to 63 bytes + if (label.size () >= size_ || label.size () > 0x3F) + return nullptr; + + p = static_cast (encode (p, size_, label.size ())); + if (!p) + return nullptr; + + std::memcpy (p, label.data (), label.size ()); + + p += label.size (); + + if (pos != std::string::npos) + prev = pos + 1; + } + + if (p == end) + return nullptr; + + *p++ = 0; + + size_ = end - p; + return p; +} + +struct DNSHeader +{ + std::uint16_t id{}; + std::uint16_t flags{}; + std::uint16_t qdCount{}; + std::uint16_t anCount{}; + std::uint16_t nsCount{}; + std::uint16_t arCount{}; + + template + void const *decode (void const *const buffer_, T &size_) + { + auto in = ::decode (buffer_, size_, id); + in = ::decode (buffer_, size_, flags); + in = ::decode (buffer_, size_, qdCount); + in = ::decode (buffer_, size_, anCount); + in = ::decode (buffer_, size_, nsCount); + in = ::decode (buffer_, size_, arCount); + + return buffer_; + } + + template + void *encode (void *buffer_, T &size_) + { + buffer_ = ::encode (buffer_, size_, id); + buffer_ = ::encode (buffer_, size_, flags); + buffer_ = ::encode (buffer_, size_, qdCount); + buffer_ = ::encode (buffer_, size_, anCount); + buffer_ = ::encode (buffer_, size_, nsCount); + buffer_ = ::encode (buffer_, size_, arCount); + + return buffer_; + } +}; + +struct QueryRecord +{ + std::string qname{}; + std::uint16_t qtype{}; + std::uint16_t qclass{}; + + template + void const *decode (void const *buffer_, T &size_) + { + buffer_ = ::decode (buffer_, size_, qname); + buffer_ = ::decode (buffer_, size_, qtype); + buffer_ = ::decode (buffer_, size_, qclass); + + return buffer_; + } + + template + void *encode (void *buffer_, T &size_) + { + buffer_ = ::encode (buffer_, size_, qname); + buffer_ = ::encode (buffer_, size_, qtype); + buffer_ = ::encode (buffer_, size_, qclass); + + return buffer_; + } +}; + +struct ResourceRecord +{ + std::string rname{}; + std::uint16_t rtype{}; + std::uint16_t rclass{}; + std::uint32_t rttl{}; + std::uint16_t rlen{}; + std::vector rdata{}; + + template + void const *decode (void const *buffer_, T &size_) + { + buffer_ = ::decode (buffer_, size_, rname); + buffer_ = ::decode (buffer_, size_, rtype); + buffer_ = ::decode (buffer_, size_, rclass); + buffer_ = ::decode (buffer_, size_, rttl); + buffer_ = ::decode (buffer_, size_, rlen); + + return buffer_; + } + + template + void *encode (void *buffer_, T &size_) + { + if (rttl > std::numeric_limits::max ()) + return nullptr; + + buffer_ = ::encode (buffer_, size_, rname); + buffer_ = ::encode (buffer_, size_, rtype); + buffer_ = ::encode (buffer_, size_, rclass); + buffer_ = ::encode (buffer_, size_, rttl); + buffer_ = ::encode (buffer_, size_, rlen); + + if (rlen > size_) + return nullptr; + + rdata.resize (rlen); + std::memcpy (rdata.data (), buffer_, rlen); + + size_ -= rlen; + return static_cast (buffer_) + rlen; + } +}; + +void probe (Socket *const socket_, std::string const &qname_) +{ + std::vector response (65536); + auto available = response.size (); + + auto out = DNSHeader{.qdCount = 1}.encode (response.data (), available); + out = QueryRecord{.qname = qname_, .qtype = 255, .qclass = 1}.encode (out, available); + + if (!out) + return; + + info ("Probe mDNS %s\n", qname_.c_str ()); + + socket_->writeTo (response.data (), response.size () - available, s_multicastAddress); + s_lastProbe = platform::steady_clock::now (); +} + +void announce (Socket *const socket_, + SockAddr const *srcAddr_, + std::uint16_t const id_, + std::uint16_t const flags_, + QueryRecord const &record_, + SockAddr const &addr_) +{ + std::vector response (65536); + auto available = response.size (); + + // header + auto out = encode (response.data (), available, id_); + out = + encode (out, available, flags_ | (1 << 15) | (1 << 10)); // mark response/AA + out = encode (out, available, 0); + out = encode (out, available, 1); + out = encode (out, available, 0); + out = encode (out, available, 0); + + // answer section + out = encode (out, available, record_.qname); + out = encode (out, available, record_.qtype); + out = encode (out, available, record_.qclass | (1 << 15)); // mark unique/flush + out = encode (out, available, MDNS_TTL); + out = encode (out, available, sizeof (in_addr_t)); + out = encode ( + out, available, static_cast (addr_).sin_addr.s_addr, false); + + if (!out) + return; + + auto const preferUnicast = srcAddr_ && ((record_.qclass >> 15) & 0x1); + + if (preferUnicast) + { + auto const name = std::string (addr_.name ()); + info ( + "Respond mDNS %s %s to %s\n", record_.qname.c_str (), name.c_str (), srcAddr_->name ()); + socket_->writeTo (response.data (), response.size () - available, *srcAddr_); + } + + auto const now = platform::steady_clock::now (); + if (!preferUnicast || now - s_lastAnnounce > std::chrono::seconds (MDNS_TTL / 4)) + { + info ("Announce mDNS %s %s\n", record_.qname.c_str (), addr_.name ()); + socket_->writeTo (response.data (), response.size () - available, s_multicastAddress); + s_lastAnnounce = now; + } +} +} + +void mdns::setHostname (std::string hostname_) +{ + if (hostname_.empty ()) + hostname_ = platform::hostname (); + + if (s_hostname == hostname_) + return; + + s_hostname = std::move (hostname_); + s_hostnameLocal = s_hostname + ".local"; + + s_state = State::Probe1; + s_lastProbe = platform::steady_clock::now (); +} + +UniqueSocket mdns::createSocket () +{ + auto socket = Socket::create (Socket::eDatagram); + if (!socket) + return nullptr; + + if (!socket->setReuseAddress ()) + return nullptr; + + auto iface = SockAddr::AnyIPv4; + iface.setPort (s_multicastAddress.port ()); + if (!socket->bind (iface)) + return nullptr; + + if (!socket->joinMulticastGroup (s_multicastAddress, iface)) + return nullptr; + + s_state = State::Probe1; + s_lastProbe = platform::steady_clock::now (); + + return socket; +} + +void mdns::handleSocket (Socket *socket_, SockAddr const &addr_) +{ + if (!socket_) + return; + + // only support IPv4 for now + if (addr_.domain () != SockAddr::Domain::IPv4) + return; + + auto const now = platform::steady_clock::now (); + + switch (s_state) + { + case State::Probe1: + case State::Probe2: + case State::Probe3: + if (now - s_lastProbe > 250ms) + { + probe (socket_, s_hostname); + s_state = static_cast (static_cast (s_state) + 1); + } + break; + + case State::Announce1: + case State::Announce2: + if (now - s_lastAnnounce > 1s) + { + announce (socket_, + nullptr, + 0, + 0, + QueryRecord{.qname = s_hostname, .qtype = 1, .qclass = 1}, + addr_); + s_state = static_cast (static_cast (s_state) + 1); + } + + default: + break; + } + + Socket::PollInfo pollInfo{*socket_, POLLIN, 0}; + auto const rc = Socket::poll (&pollInfo, 1, 0ms); + if (rc <= 0 || !(pollInfo.revents & POLLIN)) + return; + + SockAddr srcAddr; + std::vector buffer (65536); + auto bytes = socket_->readFrom (buffer.data (), buffer.size (), srcAddr); + if (bytes <= 0) + return; + + // only support IPv4 for now + if (srcAddr.domain () != SockAddr::Domain::IPv4) + return; + + // ignore loopback + if (std::memcmp (&reinterpret_cast (srcAddr).sin_addr.s_addr, + &reinterpret_cast (addr_).sin_addr.s_addr, + sizeof (in_addr_t)) == 0) + return; + + std::uint16_t id; + std::uint16_t flags; + std::uint16_t qdCount; + std::uint16_t anCount; + std::uint16_t nsCount; + std::uint16_t arCount; + + // parse header + auto in = decode (buffer.data (), bytes, id); + in = decode (in, bytes, flags); + in = decode (in, bytes, qdCount); + in = decode (in, bytes, anCount); + in = decode (in, bytes, nsCount); + in = decode (in, bytes, arCount); + + if (!in) + return; + + auto const qr = (flags >> 15) & 0x1; + + // ill-formed on queries and responses + auto const opcode = (flags >> 11) & 0xF; + if (opcode != 0) + return; + + // ill-formed on queries + if (!qr && ((flags >> 10) & 0x1)) + return; + + // punt on truncated messages + if ((flags >> 9) & 0x1) + return; + + // ill-formed on queries + if (!qr && ((flags >> 7) & 0x1)) + return; + + // must be zero + if ((flags >> 4) & 0x7) + return; + + // ill-formed on queries and responses + if ((flags >> 0) & 0xF) + return; + + // std::vector response (65536); + // void *out = response.data (); + // auto available = response.size (); + + std::vector answers; + + bool announced = false; + for (unsigned i = 0; i < qdCount; ++i) + { + QueryRecord record; + in = record.decode (in, bytes); + + if (!in) + return; + + // only respond to queries + if (qr) + continue; + + // only accept A or ANY type + if (record.qtype != 1 && record.qtype != 255) + continue; + + // only accept IN or ANY class + if ((record.qclass & 0x7FFF) != 1 && (record.qclass & 0x7FFF) != 255) + continue; + + if (record.qname != s_hostname && record.qname != s_hostnameLocal) + continue; + + if (!announced) + { + std::vector data (sizeof (in_addr_t)); + auto n = data.size (); + encode (data.data (), + n, + static_cast (addr_).sin_addr.s_addr, + false); + + answers.emplace_back (ResourceRecord{// answer + .rname = record.qname, + .rtype = 1, + .rclass = static_cast (1 | (1 << 15)), + .rttl = MDNS_TTL, + .rlen = sizeof (in_addr_t), + .rdata = std::move (data)}); + + announce (socket_, &srcAddr, id, flags, record, addr_); + announced = true; + } + } + + for (unsigned i = 0; i < anCount; ++i) + { + ResourceRecord record; + in = record.decode (in, bytes); + + if (!in) + return; + } +} diff --git a/source/sockAddr.cpp b/source/sockAddr.cpp index bf56679..165f847 100644 --- a/source/sockAddr.cpp +++ b/source/sockAddr.cpp @@ -26,11 +26,79 @@ #include #include +#ifdef __3DS__ +static_assert (sizeof (sockaddr_storage) == 0x1c); +#endif + +namespace +{ +in_addr inaddr_any = {.s_addr = htonl (INADDR_ANY)}; + +std::strong_ordering + strongMemCompare (void const *const a_, void const *const b_, std::size_t const size_) +{ + auto const cmp = std::memcmp (a_, b_, size_); + if (cmp < 0) + return std::strong_ordering::less; + if (cmp > 0) + return std::strong_ordering::greater; + return std::strong_ordering::equal; +} +} + /////////////////////////////////////////////////////////////////////////// +SockAddr const SockAddr::AnyIPv4{inaddr_any}; + +#ifndef NO_IPV6 +SockAddr const SockAddr::AnyIPv6{in6addr_any}; +#endif + SockAddr::~SockAddr () = default; SockAddr::SockAddr () = default; +SockAddr::SockAddr (Domain const domain_) +{ + switch (domain_) + { + case Domain::IPv4: + *this = AnyIPv4; + break; + +#ifndef NO_IPV6 + case Domain::IPv6: + *this = AnyIPv6; + break; +#endif + + default: + std::abort (); + } +} + +SockAddr::SockAddr (in_addr_t const addr_, std::uint16_t const port_) + : SockAddr (in_addr{.s_addr = addr_}, port_) +{ +} + +SockAddr::SockAddr (in_addr const &addr_, std::uint16_t const port_) +{ + std::memset (&m_addr, 0, sizeof (m_addr)); + m_addr.ss_family = AF_INET; + setAddr (addr_); + setPort (port_); +} + +#ifndef NO_IPV6 +SockAddr::SockAddr (in6_addr const &addr_, std::uint16_t const port_) +{ + std::memset (&m_addr, 0, sizeof (m_addr)); + m_addr.ss_family = AF_INET6; + setAddr (addr_); + setPort (port_); +} +#endif + SockAddr::SockAddr (SockAddr const &that_) = default; SockAddr::SockAddr (SockAddr &&that_) = default; @@ -101,36 +169,24 @@ SockAddr::operator sockaddr const * () const return reinterpret_cast (&m_addr); } -void SockAddr::setPort (std::uint16_t const port_) +bool SockAddr::operator== (SockAddr const &that_) const { + if (m_addr.ss_family != that_.m_addr.ss_family) + return false; + switch (m_addr.ss_family) { case AF_INET: - reinterpret_cast (&m_addr)->sin_port = htons (port_); - break; + if (port () != that_.port ()) + return false; + + // ignore sin_zero + return static_cast (*this).sin_addr.s_addr == + static_cast (that_).sin_addr.s_addr; #ifndef NO_IPV6 case AF_INET6: - reinterpret_cast (&m_addr)->sin6_port = htons (port_); - break; -#endif - - default: - std::abort (); - break; - } -} - -socklen_t SockAddr::size () const -{ - switch (m_addr.ss_family) - { - case AF_INET: - return sizeof (struct sockaddr_in); - -#ifndef NO_IPV6 - case AF_INET6: - return sizeof (struct sockaddr_in6); + return std::memcmp (&m_addr, &that_.m_addr, sizeof (sockaddr_in6)) == 0; #endif default: @@ -138,6 +194,85 @@ socklen_t SockAddr::size () const } } +std::strong_ordering SockAddr::operator<=> (SockAddr const &that_) const +{ + if (m_addr.ss_family != that_.m_addr.ss_family) + return m_addr.ss_family <=> that_.m_addr.ss_family; + + switch (m_addr.ss_family) + { + case AF_INET: + { + auto const cmp = + strongMemCompare (&static_cast (*this).sin_addr.s_addr, + &static_cast (that_).sin_addr.s_addr, + sizeof (in_addr_t)); + + if (cmp != std::strong_ordering::equal) + return cmp; + + return port () <=> that_.port (); + } + +#ifndef NO_IPV6 + case AF_INET6: + { + auto const &addr1 = static_cast (*this); + auto const &addr2 = static_cast (that_); + + if (auto const cmp = + strongMemCompare (&addr1.sin6_addr, &addr2.sin6_addr, sizeof (in6_addr)); + cmp != std::strong_ordering::equal) + return cmp; + + auto const p1 = port (); + auto const p2 = that_.port (); + + if (p1 < p2) + return std::strong_ordering::less; + else if (p1 > p2) + return std::strong_ordering::greater; + + if (auto const cmp = strongMemCompare ( + &addr1.sin6_flowinfo, &addr2.sin6_flowinfo, sizeof (std::uint32_t)); + cmp != std::strong_ordering::equal) + return cmp; + + return strongMemCompare ( + &addr1.sin6_flowinfo, &addr2.sin6_flowinfo, sizeof (std::uint32_t)); + } +#endif + + default: + std::abort (); + } +} + +void SockAddr::setAddr (in_addr_t const addr_) +{ + setAddr (in_addr{.s_addr = addr_}); +} + +void SockAddr::setAddr (in_addr const &addr_) +{ + if (m_addr.ss_family != AF_INET) + std::abort (); + + std::memcpy (&reinterpret_cast (m_addr).sin_addr, &addr_, sizeof (addr_)); + ; +} + +#ifndef NO_IPV6 +void SockAddr::setAddr (in6_addr const &addr_) +{ + if (m_addr.ss_family != AF_INET6) + std::abort (); + + std::memcpy (&reinterpret_cast (m_addr).sin6_addr, &addr_, sizeof (addr_)); + ; +} +#endif + std::uint16_t SockAddr::port () const { switch (m_addr.ss_family) @@ -152,7 +287,57 @@ std::uint16_t SockAddr::port () const default: std::abort (); + } +} + +void SockAddr::setPort (std::uint16_t const port_) +{ + switch (m_addr.ss_family) + { + case AF_INET: + reinterpret_cast (&m_addr)->sin_port = htons (port_); break; + +#ifndef NO_IPV6 + case AF_INET6: + reinterpret_cast (&m_addr)->sin6_port = htons (port_); + break; +#endif + + default: + std::abort (); + } +} + +SockAddr::Domain SockAddr::domain () const +{ + switch (m_addr.ss_family) + { + case AF_INET: +#ifndef NO_IPV6 + case AF_INET6: +#endif + return static_cast (m_addr.ss_family); + + default: + std::abort (); + } +} + +socklen_t SockAddr::size () const +{ + switch (m_addr.ss_family) + { + case AF_INET: + return sizeof (sockaddr_in); + +#ifndef NO_IPV6 + case AF_INET6: + return sizeof (sockaddr_in6); +#endif + + default: + std::abort (); } } diff --git a/source/socket.cpp b/source/socket.cpp index 14a6cad..36df5ab 100644 --- a/source/socket.cpp +++ b/source/socket.cpp @@ -255,6 +255,38 @@ bool Socket::setSendBufferSize (std::size_t const size_) return true; } +#ifndef __NDS__ +bool Socket::joinMulticastGroup (SockAddr const &addr_, SockAddr const &iface_) +{ + ip_mreq group; + group.imr_multiaddr = static_cast (addr_).sin_addr; + group.imr_interface = static_cast (iface_).sin_addr; + + if (::setsockopt (m_fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof (group)) != 0) + { + error ("setsockopt(IP_ADD_MEMBERSHIP, %s): %s\n", addr_.name (), std::strerror (errno)); + return false; + } + + return true; +} + +bool Socket::dropMulticastGroup (SockAddr const &addr_, SockAddr const &iface_) +{ + ip_mreq group; + group.imr_multiaddr = static_cast (addr_).sin_addr; + group.imr_interface = static_cast (iface_).sin_addr; + + if (::setsockopt (m_fd, IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, sizeof (group)) != 0) + { + error ("setsockopt(IP_DROP_MEMBERSHIP, %s): %s\n", addr_.name (), std::strerror (errno)); + return false; + } + + return true; +} +#endif + std::make_signed_t Socket::read (void *const buffer_, std::size_t const size_, bool const oob_) { @@ -279,6 +311,21 @@ std::make_signed_t Socket::read (IOBuffer &buffer_, bool const oob_ return rc; } +std::make_signed_t + Socket::readFrom (void *const buffer_, std::size_t const size_, SockAddr &addr_) +{ + assert (buffer_); + assert (size_); + + socklen_t addrLen = sizeof (sockaddr_storage); + + auto const rc = ::recvfrom (m_fd, buffer_, size_, 0, addr_, &addrLen); + if (rc < 0 && errno != EWOULDBLOCK) + error ("recvfrom: %s\n", std::strerror (errno)); + + return rc; +} + std::make_signed_t Socket::write (void const *const buffer_, std::size_t const size_) { assert (buffer_); @@ -302,6 +349,19 @@ std::make_signed_t Socket::write (IOBuffer &buffer_) return rc; } +std::make_signed_t + Socket::writeTo (void const *buffer_, std::size_t size_, SockAddr const &addr_) +{ + assert (buffer_); + assert (size_ > 0); + + auto const rc = ::sendto (m_fd, buffer_, size_, 0, addr_, addr_.size ()); + if (rc < 0 && errno != EWOULDBLOCK) + error ("sendto: %s\n", std::strerror (errno)); + + return rc; +} + SockAddr const &Socket::sockName () const { return m_sockName; @@ -312,9 +372,9 @@ SockAddr const &Socket::peerName () const return m_peerName; } -UniqueSocket Socket::create () +UniqueSocket Socket::create (Type const type_) { - auto const fd = ::socket (AF_INET, SOCK_STREAM, 0); + auto const fd = ::socket (AF_INET, static_cast (type_), 0); if (fd < 0) { error ("socket: %s\n", std::strerror (errno)); diff --git a/source/switch/platform.cpp b/source/switch/platform.cpp index 2d8f6b9..ccf1fd1 100644 --- a/source/switch/platform.cpp +++ b/source/switch/platform.cpp @@ -793,6 +793,12 @@ bool platform::networkAddress (SockAddr &addr_) return true; } +std::string const &platform::hostname () +{ + static std::string const hostname = "switch-ftpd"; + return hostname; +} + bool platform::loop () { if (!appletMainLoop ())