This commit is contained in:
Michael Theall 2024-11-03 17:55:56 -06:00
parent 41e467f47b
commit 663c02172e
16 changed files with 1134 additions and 33 deletions

View File

@ -84,6 +84,13 @@ target_sources(${FTPD_TARGET} PRIVATE
source/socket.cpp
)
if(NOT NINTENDO_DS)
target_sources(${FTPD_TARGET} PRIVATE
source/mdns.cpp
include/mdns.h
)
endif()
if(NOT FTPD_CLASSIC AND NOT NINTENDO_DS)
target_sources(${FTPD_TARGET} PRIVATE
include/imconfig.h

View File

@ -60,6 +60,9 @@ public:
/// \brief Get password
std::string const &pass () const;
/// \brief Get hostname
std::string const &hostname () const;
/// \brief Get port
std::uint16_t port () const;
@ -88,6 +91,10 @@ public:
/// \param pass_ Password
void setPass (std::string pass_);
/// \brief Set hostname
/// \param hostname_ Hostname
void setHostname (std::string hostname_);
/// \brief Set listen port
/// \param port_ Listen port
bool setPort (std::string_view port_);
@ -130,6 +137,9 @@ private:
/// \brief Password
std::string m_pass;
/// \brief Hostname
std::string m_hostname;
/// \brief Listen port
std::uint16_t m_port;

View File

@ -110,6 +110,11 @@ private:
/// \brief Listen socket
UniqueSocket m_socket;
#ifndef __NDS__
/// \brief mDNS socket
UniqueSocket m_mdnsSocket;
#endif
/// \brief ImGui window name
std::string m_name;
@ -151,6 +156,9 @@ private:
/// \brief Password setting
std::string m_passSetting;
/// \brief Hostname setting
std::string m_hostnameSetting;
/// \brief Port setting
std::uint16_t m_portSetting = 0;

40
include/mdns.h Normal file
View File

@ -0,0 +1,40 @@
// ftpd is a server implementation based on the following:
// - RFC 959 (https://datatracker.ietf.org/doc/html/rfc959)
// - RFC 3659 (https://datatracker.ietf.org/doc/html/rfc3659)
// - suggested implementation details from https://cr.yp.to/ftp/filesystem.html
//
// ftpd implements mdns based on the following:
// - RFC 1035 (https://datatracker.ietf.org/doc/html/rfc1035)
// - RFC 6762 (https://datatracker.ietf.org/doc/html/rfc6762)
//
// Copyright (C) 2024 Michael Theall
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
#pragma once
#include "sockAddr.h"
#include "socket.h"
#include <cstddef>
namespace mdns
{
void setHostname (std::string hostname_);
UniqueSocket createSocket ();
void handleSocket (Socket *socket_, SockAddr const &addr_);
}

View File

@ -3,7 +3,7 @@
// - RFC 3659 (https://tools.ietf.org/html/rfc3659)
// - suggested implementation details from https://cr.yp.to/ftp/filesystem.html
//
// Copyright (C) 2023 Michael Theall
// Copyright (C) 2024 Michael Theall
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
@ -34,6 +34,7 @@
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#ifdef CLASSIC
extern PrintConsole g_statusConsole;
@ -71,6 +72,9 @@ bool networkVisible ();
/// \param[out] addr_ Network address
bool networkAddress (SockAddr &addr_);
/// \brief Get hostname
std::string const &hostname ();
/// \brief Platform loop
bool loop ();

View File

@ -23,6 +23,7 @@
#include <netinet/in.h>
#include <sys/socket.h>
#include <compare>
#include <cstdint>
#ifdef __NDS__
@ -37,10 +38,48 @@ struct sockaddr_storage
class SockAddr
{
public:
enum class Domain
{
IPv4 = AF_INET,
#ifndef NO_IPV6
IPv6 = AF_INET6,
#endif
};
/// \brief 0.0.0.0
static SockAddr const AnyIPv4;
#ifndef NO_IPV6
/// \brief ::
static SockAddr const AnyIPv6;
#endif
~SockAddr ();
SockAddr ();
/// \brief Parameterized constructor
/// \param domain_ Socket domain
/// \note Initial address is INADDR_ANY/in6addr_any
SockAddr (Domain domain_);
/// \brief Parameterized constructor
/// \param addr_ Socket address (network byte order)
/// \param port_ Socket port (host byte order)
SockAddr (in_addr_t addr_, std::uint16_t port_ = 0);
/// \brief Parameterized constructor
/// \param addr_ Socket address (network byte order)
/// \param port_ Socket port (host byte order)
SockAddr (in_addr const &addr_, std::uint16_t port_ = 0);
#ifndef NO_IPV6
/// \brief Parameterized constructor
/// \param addr_ Socket address
/// \param port_ Socket port (host byte order)
SockAddr (in6_addr const &addr_, std::uint16_t port_ = 0);
#endif
/// \brief Copy constructor
/// \param that_ Object to copy
SockAddr (SockAddr const &that_);
@ -88,10 +127,32 @@ public:
/// \brief sockaddr const* cast operator (network byte order)
operator sockaddr const * () const;
/// \brief Equality operator
bool operator== (SockAddr const &that_) const;
/// \brief Comparison operator
std::strong_ordering operator<=> (SockAddr const &that_) const;
/// \brief sockaddr domain
Domain domain () const;
/// \brief sockaddr size
socklen_t size () const;
/// \brief Set address
/// \param addr_ Address to set (network byte order)
void setAddr (in_addr_t addr_);
/// \brief Set address
/// \param addr_ Address to set (network byte order)
void setAddr (in_addr const &addr_);
#ifndef NO_IPV6
/// \brief Set address
/// \param addr_ Address to set (network byte order)
void setAddr (in6_addr const &addr_);
#endif
/// \brief Address port (host byte order)
std::uint16_t port () const;

View File

@ -56,6 +56,12 @@ using SharedSocket = std::shared_ptr<Socket>;
class Socket
{
public:
enum Type
{
eStream = SOCK_STREAM, ///< Stream socket
eDatagram = SOCK_DGRAM, ///< Datagram socket
};
/// \brief Poll info
struct PollInfo
{
@ -114,6 +120,18 @@ public:
/// \param size_ Buffer size
bool setSendBufferSize (std::size_t size_);
#ifndef __NDS__
/// \brief Join multicast group
/// \param addr_ Multicast group address
/// \param iface_ Interface address
bool joinMulticastGroup (SockAddr const &addr_, SockAddr const &iface_);
/// \brief Drop multicast group
/// \param addr_ Multicast group address
/// \param iface_ Interface address
bool dropMulticastGroup (SockAddr const &addr_, SockAddr const &iface_);
#endif
/// \brief Read data
/// \param buffer_ Output buffer
/// \param size_ Size to read
@ -125,6 +143,12 @@ public:
/// \param oob_ Whether to read from out-of-band
std::make_signed_t<std::size_t> read (IOBuffer &buffer_, bool oob_ = false);
/// \brief Read data
/// \param buffer_ Output buffer
/// \param size_ Size to read
/// \param[out] addr_ Source address
std::make_signed_t<std::size_t> readFrom (void *buffer_, std::size_t size_, SockAddr &addr_);
/// \brief Write data
/// \param buffer_ Input buffer
/// \param size_ Size to write
@ -135,13 +159,21 @@ public:
/// \param size_ Size to write
std::make_signed_t<std::size_t> write (IOBuffer &buffer_);
/// \brief Write data
/// \param buffer_ Input buffer
/// \param size_ Size to write
/// \param[out] addr_ Destination address
std::make_signed_t<std::size_t>
writeTo (void const *buffer_, std::size_t size_, SockAddr const &addr_);
/// \brief Local name
SockAddr const &sockName () const;
/// \brief Peer name
SockAddr const &peerName () const;
/// \brief Create socket
static UniqueSocket create ();
/// \param type_ Socket type
static UniqueSocket create (Type type_);
/// \brief Poll sockets
/// \param info_ Poll info

View File

@ -577,6 +577,12 @@ bool platform::networkAddress (SockAddr &addr_)
return true;
}
std::string const &platform::hostname ()
{
static std::string const hostname = "3ds-ftpd";
return hostname;
}
bool platform::loop ()
{
if (!aptMainLoop ())

View File

@ -232,6 +232,11 @@ std::string const &FtpConfig::pass () const
return m_pass;
}
std::string const &FtpConfig::hostname () const
{
return m_hostname;
}
std::uint16_t FtpConfig::port () const
{
return m_port;
@ -271,6 +276,11 @@ void FtpConfig::setPass (std::string pass_)
m_pass = std::move (pass_);
}
void FtpConfig::setHostname (std::string hostname_)
{
m_hostname = std::move (hostname_);
}
bool FtpConfig::setPort (std::string_view const port_)
{
std::uint16_t parsed{};

View File

@ -29,6 +29,10 @@
#include "sockAddr.h"
#include "socket.h"
#ifndef __NDS__
#include "mdns.h"
#endif
#include "imgui.h"
#ifdef __NDS__
@ -223,8 +227,14 @@ FtpServer::~FtpServer ()
FtpServer::FtpServer (UniqueFtpConfig config_)
: m_config (std::move (config_))
#ifndef CLASSIC
,
m_hostnameSetting (m_config->hostname ())
#endif
{
#ifndef __NDS__
mdns::setHostname (m_config->hostname ());
m_thread = platform::Thread (std::bind (&FtpServer::threadFunc, this));
#endif
@ -439,7 +449,7 @@ void FtpServer::handleNetworkFound ()
addr.setPort (port);
auto socket = Socket::create ();
auto socket = Socket::create (Socket::eStream);
if (!socket)
return;
@ -461,6 +471,14 @@ void FtpServer::handleNetworkFound ()
info ("Started server at %s\n", m_name.c_str ());
LOCKED (m_socket = std::move (socket));
#ifndef __NDS__
socket = mdns::createSocket ();
if (!socket)
return;
LOCKED (m_mdnsSocket = std::move (socket));
#endif
}
void FtpServer::handleNetworkLost ()
@ -476,6 +494,11 @@ void FtpServer::handleNetworkLost ()
// destroy command socket
LOCKED (sock = std::move (m_socket));
#ifndef __NDS__
// destroy mDNS socket
LOCKED (sock = std::move (m_mdnsSocket));
#endif
}
info ("Stopped server at %s\n", m_name.c_str ());
@ -574,6 +597,9 @@ void FtpServer::showMenu ()
m_passSetting = m_config->pass ();
m_passSetting.resize (32);
m_hostnameSetting = m_config->hostname ();
m_hostnameSetting.resize (32);
m_portSetting = m_config->port ();
#ifdef __3DS__
@ -631,6 +657,11 @@ void FtpServer::showSettings ()
m_passSetting.size (),
ImGuiInputTextFlags_AutoSelectAll | ImGuiInputTextFlags_Password);
ImGui::InputText ("Hostname",
m_hostnameSetting.data (),
m_hostnameSetting.size (),
ImGuiInputTextFlags_AutoSelectAll);
ImGui::InputScalar ("Port",
ImGuiDataType_U16,
&m_portSetting,
@ -703,6 +734,7 @@ void FtpServer::showSettings ()
m_config->setUser (m_userSetting);
m_config->setPass (m_passSetting);
m_config->setHostname (m_hostnameSetting);
m_config->setPort (m_portSetting);
#ifdef __3DS__
@ -718,6 +750,8 @@ void FtpServer::showSettings ()
UniqueSocket socket;
LOCKED (socket = std::move (m_socket));
mdns::setHostname (m_hostnameSetting);
}
if (save)
@ -733,9 +767,10 @@ void FtpServer::showSettings ()
{
static auto const defaults = FtpConfig::create ();
m_userSetting = defaults->user ();
m_passSetting = defaults->pass ();
m_portSetting = defaults->port ();
m_userSetting = defaults->user ();
m_passSetting = defaults->pass ();
m_hostnameSetting = defaults->hostname ();
m_portSetting = defaults->port ();
#ifdef __3DS__
m_getMTimeSetting = defaults->getMTime ();
#endif
@ -966,6 +1001,12 @@ void FtpServer::loop ()
}
}
#ifndef __NDS__
// poll mDNS socket
if (m_socket && m_mdnsSocket)
mdns::handleSocket (m_mdnsSocket.get (), m_socket->sockName ());
#endif
{
std::vector<UniqueFtpSession> deadSessions;
{

View File

@ -22,6 +22,7 @@
#include "ftpServer.h"
#include "log.h"
#include "mdns.h"
#include "platform.h"
#include "imgui.h"
@ -837,7 +838,7 @@ bool FtpSession::dataConnect ()
m_port = false;
auto data = Socket::create ();
auto data = Socket::create (Socket::eStream);
LOCKED (m_dataSocket = std::move (data));
if (!m_dataSocket)
return false;
@ -2384,7 +2385,7 @@ void FtpSession::PASV (char const *args_)
m_port = false;
// create a socket to listen on
auto pasv = Socket::create ();
auto pasv = Socket::create (Socket::eStream);
LOCKED (m_pasvSocket = std::move (pasv));
if (!m_pasvSocket)
{
@ -2726,6 +2727,9 @@ void FtpSession::SITE (char const *args_)
" Set username: SITE USER <NAME>\r\n"
" Set password: SITE PASS <PASS>\r\n"
" Set port: SITE PORT <PORT>\r\n"
#ifndef __NDS__
" Set hostname: SITE HOST <HOSTNAME>\r\n"
#endif
#ifdef __3DS__
" Set getMTime: SITE MTIME [0|1]\r\n"
#endif
@ -2784,6 +2788,16 @@ void FtpSession::SITE (char const *args_)
sendResponse ("200 OK\r\n");
return;
}
#ifndef __NDS__
else if (compare (command, "HOST") == 0)
{
{
auto const lock = m_config.lockGuard ();
m_config.setHostname (std::string (arg));
mdns::setHostname (std::string (arg));
}
}
#endif
#ifdef __3DS__
else if (compare (command, "MTIME") == 0)
{

View File

@ -204,6 +204,24 @@ bool platform::networkAddress (SockAddr &addr_)
return true;
}
std::string const &platform::hostname ()
{
static std::string hostname = "switch-ftpd.local";
if (hostname.empty ())
{
std::string buffer (256, '\0');
gethostname (buffer.data (), buffer.size ());
if (buffer.back () == 0) // check for truncation
{
hostname = std::move (buffer);
hostname.resize (std::strlen (hostname.data ()));
}
}
return hostname;
}
bool platform::loop ()
{
bool inactive;

599
source/mdns.cpp Normal file
View File

@ -0,0 +1,599 @@
// ftpd is a server implementation based on the following:
// - RFC 959 (https://datatracker.ietf.org/doc/html/rfc959)
// - RFC 3659 (https://datatracker.ietf.org/doc/html/rfc3659)
// - suggested implementation details from https://cr.yp.to/ftp/filesystem.html
//
// ftpd implements mdns based on the following:
// - RFC 1035 (https://datatracker.ietf.org/doc/html/rfc1035)
// - RFC 6762 (https://datatracker.ietf.org/doc/html/rfc6762)
//
// Copyright (C) 2024 Michael Theall
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
#include "mdns.h"
#include "log.h"
#include "platform.h"
#include <arpa/inet.h>
#include <algorithm>
#include <array>
#include <bit>
#include <chrono>
#include <concepts>
#include <cstdlib>
#include <cstring>
#include <string>
#include <type_traits>
#include <vector>
using namespace std::chrono_literals;
static_assert (
std::endian::native == std::endian::big || std::endian::native == std::endian::little);
static_assert (sizeof (in_addr_t) == 4);
namespace
{
constexpr auto MDNS_TTL = 120;
SockAddr const s_multicastAddress{inet_addr ("224.0.0.251"), 5353};
platform::steady_clock::time_point s_lastAnnounce{};
platform::steady_clock::time_point s_lastProbe{};
std::string s_hostname = platform::hostname ();
std::string s_hostnameLocal = s_hostname + ".local";
enum class State
{
Probe1,
Probe2,
Probe3,
Announce1,
Announce2,
Complete,
Conflict,
};
auto s_state = State::Probe1;
#if __has_cpp_attribute(__cpp_lib_byteswap)
template <std::integral T>
using byteswap = std::byteswap<T>;
#else
template <std::integral T>
constexpr T byteswap (T const value_) noexcept
{
static_assert (std::has_unique_object_representations_v<T>, "T may not have padding bits");
auto buffer = std::bit_cast<std::array<std::byte, sizeof (T)>> (value_);
std::ranges::reverse (buffer);
return std::bit_cast<T> (buffer);
}
#endif
template <std::integral T>
constexpr T hton (T const value_) noexcept
{
if constexpr (std::endian::native == std::endian::big)
return value_;
else
return byteswap (value_);
}
template <std::integral T>
constexpr T ntoh (T const value_) noexcept
{
if constexpr (std::endian::native == std::endian::big)
return value_;
else
return byteswap (value_);
}
template <std::integral T, std::integral U>
void const *decode (void const *const buffer_, U &size_, T &out_, bool networkToHost_ = true)
{
if (!buffer_)
return nullptr;
if (size_ < 0 || static_cast<std::make_unsigned_t<T>> (size_) < sizeof (T))
return nullptr;
std::memcpy (&out_, buffer_, sizeof (T));
if (networkToHost_)
out_ = ntoh (out_);
size_ -= sizeof (T);
return static_cast<std::uint8_t const *> (buffer_) + sizeof (T);
}
template <std::integral T>
void const *decode (void const *buffer_, T &size_, std::string &out_)
{
auto p = static_cast<char const *> (buffer_);
auto const end = p + size_;
std::string result;
result.reserve (size_);
while (p < end && *p)
{
auto const len = *p++;
// punt on compressed labels
if (len & 0xC0)
return nullptr;
if (p + len >= end)
return nullptr;
if (!result.empty ())
result.push_back ('.');
result.insert (std::end (result), p, p + len);
p += len;
}
++p;
out_ = std::move (result);
size_ = end - p;
return p;
}
template <std::integral T, std::integral U>
void *encode (void *const buffer_, U &size_, T in_, bool hostToNetwork_ = true)
{
if (!buffer_)
return nullptr;
if (size_ < sizeof (T))
return nullptr;
if (hostToNetwork_)
in_ = hton (in_);
std::memcpy (buffer_, &in_, sizeof (T));
size_ -= sizeof (T);
return static_cast<std::uint8_t *> (buffer_) + sizeof (T);
}
template <std::integral T>
void *encode (void *const buffer_, T &size_, std::string const &in_)
{
// names are limited to 255 bytes
if (in_.size () > 0xFF)
return nullptr;
auto p = static_cast<char *> (buffer_);
auto const end = p + size_;
std::string::size_type prev = 0;
std::string::size_type pos = 0;
while (p < end && pos != std::string::npos)
{
pos = in_.find ('.', prev);
auto const label = std::string_view (in_).substr (prev, pos);
// labels are limited to 63 bytes
if (label.size () >= size_ || label.size () > 0x3F)
return nullptr;
p = static_cast<char *> (encode<std::uint8_t> (p, size_, label.size ()));
if (!p)
return nullptr;
std::memcpy (p, label.data (), label.size ());
p += label.size ();
if (pos != std::string::npos)
prev = pos + 1;
}
if (p == end)
return nullptr;
*p++ = 0;
size_ = end - p;
return p;
}
struct DNSHeader
{
std::uint16_t id{};
std::uint16_t flags{};
std::uint16_t qdCount{};
std::uint16_t anCount{};
std::uint16_t nsCount{};
std::uint16_t arCount{};
template <std::integral T>
void const *decode (void const *const buffer_, T &size_)
{
auto in = ::decode (buffer_, size_, id);
in = ::decode (buffer_, size_, flags);
in = ::decode (buffer_, size_, qdCount);
in = ::decode (buffer_, size_, anCount);
in = ::decode (buffer_, size_, nsCount);
in = ::decode (buffer_, size_, arCount);
return buffer_;
}
template <std::integral T>
void *encode (void *buffer_, T &size_)
{
buffer_ = ::encode (buffer_, size_, id);
buffer_ = ::encode (buffer_, size_, flags);
buffer_ = ::encode (buffer_, size_, qdCount);
buffer_ = ::encode (buffer_, size_, anCount);
buffer_ = ::encode (buffer_, size_, nsCount);
buffer_ = ::encode (buffer_, size_, arCount);
return buffer_;
}
};
struct QueryRecord
{
std::string qname{};
std::uint16_t qtype{};
std::uint16_t qclass{};
template <std::integral T>
void const *decode (void const *buffer_, T &size_)
{
buffer_ = ::decode (buffer_, size_, qname);
buffer_ = ::decode (buffer_, size_, qtype);
buffer_ = ::decode (buffer_, size_, qclass);
return buffer_;
}
template <std::integral T>
void *encode (void *buffer_, T &size_)
{
buffer_ = ::encode (buffer_, size_, qname);
buffer_ = ::encode (buffer_, size_, qtype);
buffer_ = ::encode (buffer_, size_, qclass);
return buffer_;
}
};
struct ResourceRecord
{
std::string rname{};
std::uint16_t rtype{};
std::uint16_t rclass{};
std::uint32_t rttl{};
std::uint16_t rlen{};
std::vector<std::uint8_t> rdata{};
template <std::integral T>
void const *decode (void const *buffer_, T &size_)
{
buffer_ = ::decode (buffer_, size_, rname);
buffer_ = ::decode (buffer_, size_, rtype);
buffer_ = ::decode (buffer_, size_, rclass);
buffer_ = ::decode (buffer_, size_, rttl);
buffer_ = ::decode (buffer_, size_, rlen);
return buffer_;
}
template <std::integral T>
void *encode (void *buffer_, T &size_)
{
if (rttl > std::numeric_limits<std::int32_t>::max ())
return nullptr;
buffer_ = ::encode (buffer_, size_, rname);
buffer_ = ::encode (buffer_, size_, rtype);
buffer_ = ::encode (buffer_, size_, rclass);
buffer_ = ::encode (buffer_, size_, rttl);
buffer_ = ::encode (buffer_, size_, rlen);
if (rlen > size_)
return nullptr;
rdata.resize (rlen);
std::memcpy (rdata.data (), buffer_, rlen);
size_ -= rlen;
return static_cast<std::uint8_t *> (buffer_) + rlen;
}
};
void probe (Socket *const socket_, std::string const &qname_)
{
std::vector<std::uint8_t> response (65536);
auto available = response.size ();
auto out = DNSHeader{.qdCount = 1}.encode (response.data (), available);
out = QueryRecord{.qname = qname_, .qtype = 255, .qclass = 1}.encode (out, available);
if (!out)
return;
info ("Probe mDNS %s\n", qname_.c_str ());
socket_->writeTo (response.data (), response.size () - available, s_multicastAddress);
s_lastProbe = platform::steady_clock::now ();
}
void announce (Socket *const socket_,
SockAddr const *srcAddr_,
std::uint16_t const id_,
std::uint16_t const flags_,
QueryRecord const &record_,
SockAddr const &addr_)
{
std::vector<std::uint8_t> response (65536);
auto available = response.size ();
// header
auto out = encode<std::uint16_t> (response.data (), available, id_);
out =
encode<std::uint16_t> (out, available, flags_ | (1 << 15) | (1 << 10)); // mark response/AA
out = encode<std::uint16_t> (out, available, 0);
out = encode<std::uint16_t> (out, available, 1);
out = encode<std::uint16_t> (out, available, 0);
out = encode<std::uint16_t> (out, available, 0);
// answer section
out = encode (out, available, record_.qname);
out = encode<std::uint16_t> (out, available, record_.qtype);
out = encode<std::uint16_t> (out, available, record_.qclass | (1 << 15)); // mark unique/flush
out = encode<std::uint32_t> (out, available, MDNS_TTL);
out = encode<std::uint16_t> (out, available, sizeof (in_addr_t));
out = encode<in_addr_t> (
out, available, static_cast<sockaddr_in const &> (addr_).sin_addr.s_addr, false);
if (!out)
return;
auto const preferUnicast = srcAddr_ && ((record_.qclass >> 15) & 0x1);
if (preferUnicast)
{
auto const name = std::string (addr_.name ());
info (
"Respond mDNS %s %s to %s\n", record_.qname.c_str (), name.c_str (), srcAddr_->name ());
socket_->writeTo (response.data (), response.size () - available, *srcAddr_);
}
auto const now = platform::steady_clock::now ();
if (!preferUnicast || now - s_lastAnnounce > std::chrono::seconds (MDNS_TTL / 4))
{
info ("Announce mDNS %s %s\n", record_.qname.c_str (), addr_.name ());
socket_->writeTo (response.data (), response.size () - available, s_multicastAddress);
s_lastAnnounce = now;
}
}
}
void mdns::setHostname (std::string hostname_)
{
if (hostname_.empty ())
hostname_ = platform::hostname ();
if (s_hostname == hostname_)
return;
s_hostname = std::move (hostname_);
s_hostnameLocal = s_hostname + ".local";
s_state = State::Probe1;
s_lastProbe = platform::steady_clock::now ();
}
UniqueSocket mdns::createSocket ()
{
auto socket = Socket::create (Socket::eDatagram);
if (!socket)
return nullptr;
if (!socket->setReuseAddress ())
return nullptr;
auto iface = SockAddr::AnyIPv4;
iface.setPort (s_multicastAddress.port ());
if (!socket->bind (iface))
return nullptr;
if (!socket->joinMulticastGroup (s_multicastAddress, iface))
return nullptr;
s_state = State::Probe1;
s_lastProbe = platform::steady_clock::now ();
return socket;
}
void mdns::handleSocket (Socket *socket_, SockAddr const &addr_)
{
if (!socket_)
return;
// only support IPv4 for now
if (addr_.domain () != SockAddr::Domain::IPv4)
return;
auto const now = platform::steady_clock::now ();
switch (s_state)
{
case State::Probe1:
case State::Probe2:
case State::Probe3:
if (now - s_lastProbe > 250ms)
{
probe (socket_, s_hostname);
s_state = static_cast<State> (static_cast<int> (s_state) + 1);
}
break;
case State::Announce1:
case State::Announce2:
if (now - s_lastAnnounce > 1s)
{
announce (socket_,
nullptr,
0,
0,
QueryRecord{.qname = s_hostname, .qtype = 1, .qclass = 1},
addr_);
s_state = static_cast<State> (static_cast<int> (s_state) + 1);
}
default:
break;
}
Socket::PollInfo pollInfo{*socket_, POLLIN, 0};
auto const rc = Socket::poll (&pollInfo, 1, 0ms);
if (rc <= 0 || !(pollInfo.revents & POLLIN))
return;
SockAddr srcAddr;
std::vector<std::uint8_t> buffer (65536);
auto bytes = socket_->readFrom (buffer.data (), buffer.size (), srcAddr);
if (bytes <= 0)
return;
// only support IPv4 for now
if (srcAddr.domain () != SockAddr::Domain::IPv4)
return;
// ignore loopback
if (std::memcmp (&reinterpret_cast<sockaddr_in const &> (srcAddr).sin_addr.s_addr,
&reinterpret_cast<sockaddr_in const &> (addr_).sin_addr.s_addr,
sizeof (in_addr_t)) == 0)
return;
std::uint16_t id;
std::uint16_t flags;
std::uint16_t qdCount;
std::uint16_t anCount;
std::uint16_t nsCount;
std::uint16_t arCount;
// parse header
auto in = decode (buffer.data (), bytes, id);
in = decode (in, bytes, flags);
in = decode (in, bytes, qdCount);
in = decode (in, bytes, anCount);
in = decode (in, bytes, nsCount);
in = decode (in, bytes, arCount);
if (!in)
return;
auto const qr = (flags >> 15) & 0x1;
// ill-formed on queries and responses
auto const opcode = (flags >> 11) & 0xF;
if (opcode != 0)
return;
// ill-formed on queries
if (!qr && ((flags >> 10) & 0x1))
return;
// punt on truncated messages
if ((flags >> 9) & 0x1)
return;
// ill-formed on queries
if (!qr && ((flags >> 7) & 0x1))
return;
// must be zero
if ((flags >> 4) & 0x7)
return;
// ill-formed on queries and responses
if ((flags >> 0) & 0xF)
return;
// std::vector<std::uint8_t> response (65536);
// void *out = response.data ();
// auto available = response.size ();
std::vector<ResourceRecord> answers;
bool announced = false;
for (unsigned i = 0; i < qdCount; ++i)
{
QueryRecord record;
in = record.decode (in, bytes);
if (!in)
return;
// only respond to queries
if (qr)
continue;
// only accept A or ANY type
if (record.qtype != 1 && record.qtype != 255)
continue;
// only accept IN or ANY class
if ((record.qclass & 0x7FFF) != 1 && (record.qclass & 0x7FFF) != 255)
continue;
if (record.qname != s_hostname && record.qname != s_hostnameLocal)
continue;
if (!announced)
{
std::vector<std::uint8_t> data (sizeof (in_addr_t));
auto n = data.size ();
encode (data.data (),
n,
static_cast<sockaddr_in const &> (addr_).sin_addr.s_addr,
false);
answers.emplace_back (ResourceRecord{// answer
.rname = record.qname,
.rtype = 1,
.rclass = static_cast<std::uint16_t> (1 | (1 << 15)),
.rttl = MDNS_TTL,
.rlen = sizeof (in_addr_t),
.rdata = std::move (data)});
announce (socket_, &srcAddr, id, flags, record, addr_);
announced = true;
}
}
for (unsigned i = 0; i < anCount; ++i)
{
ResourceRecord record;
in = record.decode (in, bytes);
if (!in)
return;
}
}

View File

@ -26,11 +26,79 @@
#include <cstdlib>
#include <cstring>
#ifdef __3DS__
static_assert (sizeof (sockaddr_storage) == 0x1c);
#endif
namespace
{
in_addr inaddr_any = {.s_addr = htonl (INADDR_ANY)};
std::strong_ordering
strongMemCompare (void const *const a_, void const *const b_, std::size_t const size_)
{
auto const cmp = std::memcmp (a_, b_, size_);
if (cmp < 0)
return std::strong_ordering::less;
if (cmp > 0)
return std::strong_ordering::greater;
return std::strong_ordering::equal;
}
}
///////////////////////////////////////////////////////////////////////////
SockAddr const SockAddr::AnyIPv4{inaddr_any};
#ifndef NO_IPV6
SockAddr const SockAddr::AnyIPv6{in6addr_any};
#endif
SockAddr::~SockAddr () = default;
SockAddr::SockAddr () = default;
SockAddr::SockAddr (Domain const domain_)
{
switch (domain_)
{
case Domain::IPv4:
*this = AnyIPv4;
break;
#ifndef NO_IPV6
case Domain::IPv6:
*this = AnyIPv6;
break;
#endif
default:
std::abort ();
}
}
SockAddr::SockAddr (in_addr_t const addr_, std::uint16_t const port_)
: SockAddr (in_addr{.s_addr = addr_}, port_)
{
}
SockAddr::SockAddr (in_addr const &addr_, std::uint16_t const port_)
{
std::memset (&m_addr, 0, sizeof (m_addr));
m_addr.ss_family = AF_INET;
setAddr (addr_);
setPort (port_);
}
#ifndef NO_IPV6
SockAddr::SockAddr (in6_addr const &addr_, std::uint16_t const port_)
{
std::memset (&m_addr, 0, sizeof (m_addr));
m_addr.ss_family = AF_INET6;
setAddr (addr_);
setPort (port_);
}
#endif
SockAddr::SockAddr (SockAddr const &that_) = default;
SockAddr::SockAddr (SockAddr &&that_) = default;
@ -101,36 +169,24 @@ SockAddr::operator sockaddr const * () const
return reinterpret_cast<sockaddr const *> (&m_addr);
}
void SockAddr::setPort (std::uint16_t const port_)
bool SockAddr::operator== (SockAddr const &that_) const
{
if (m_addr.ss_family != that_.m_addr.ss_family)
return false;
switch (m_addr.ss_family)
{
case AF_INET:
reinterpret_cast<struct sockaddr_in *> (&m_addr)->sin_port = htons (port_);
break;
if (port () != that_.port ())
return false;
// ignore sin_zero
return static_cast<sockaddr_in const &> (*this).sin_addr.s_addr ==
static_cast<sockaddr_in const &> (that_).sin_addr.s_addr;
#ifndef NO_IPV6
case AF_INET6:
reinterpret_cast<struct sockaddr_in6 *> (&m_addr)->sin6_port = htons (port_);
break;
#endif
default:
std::abort ();
break;
}
}
socklen_t SockAddr::size () const
{
switch (m_addr.ss_family)
{
case AF_INET:
return sizeof (struct sockaddr_in);
#ifndef NO_IPV6
case AF_INET6:
return sizeof (struct sockaddr_in6);
return std::memcmp (&m_addr, &that_.m_addr, sizeof (sockaddr_in6)) == 0;
#endif
default:
@ -138,6 +194,85 @@ socklen_t SockAddr::size () const
}
}
std::strong_ordering SockAddr::operator<=> (SockAddr const &that_) const
{
if (m_addr.ss_family != that_.m_addr.ss_family)
return m_addr.ss_family <=> that_.m_addr.ss_family;
switch (m_addr.ss_family)
{
case AF_INET:
{
auto const cmp =
strongMemCompare (&static_cast<sockaddr_in const &> (*this).sin_addr.s_addr,
&static_cast<sockaddr_in const &> (that_).sin_addr.s_addr,
sizeof (in_addr_t));
if (cmp != std::strong_ordering::equal)
return cmp;
return port () <=> that_.port ();
}
#ifndef NO_IPV6
case AF_INET6:
{
auto const &addr1 = static_cast<sockaddr_in6 const &> (*this);
auto const &addr2 = static_cast<sockaddr_in6 const &> (that_);
if (auto const cmp =
strongMemCompare (&addr1.sin6_addr, &addr2.sin6_addr, sizeof (in6_addr));
cmp != std::strong_ordering::equal)
return cmp;
auto const p1 = port ();
auto const p2 = that_.port ();
if (p1 < p2)
return std::strong_ordering::less;
else if (p1 > p2)
return std::strong_ordering::greater;
if (auto const cmp = strongMemCompare (
&addr1.sin6_flowinfo, &addr2.sin6_flowinfo, sizeof (std::uint32_t));
cmp != std::strong_ordering::equal)
return cmp;
return strongMemCompare (
&addr1.sin6_flowinfo, &addr2.sin6_flowinfo, sizeof (std::uint32_t));
}
#endif
default:
std::abort ();
}
}
void SockAddr::setAddr (in_addr_t const addr_)
{
setAddr (in_addr{.s_addr = addr_});
}
void SockAddr::setAddr (in_addr const &addr_)
{
if (m_addr.ss_family != AF_INET)
std::abort ();
std::memcpy (&reinterpret_cast<sockaddr_in &> (m_addr).sin_addr, &addr_, sizeof (addr_));
;
}
#ifndef NO_IPV6
void SockAddr::setAddr (in6_addr const &addr_)
{
if (m_addr.ss_family != AF_INET6)
std::abort ();
std::memcpy (&reinterpret_cast<sockaddr_in6 &> (m_addr).sin6_addr, &addr_, sizeof (addr_));
;
}
#endif
std::uint16_t SockAddr::port () const
{
switch (m_addr.ss_family)
@ -152,7 +287,57 @@ std::uint16_t SockAddr::port () const
default:
std::abort ();
}
}
void SockAddr::setPort (std::uint16_t const port_)
{
switch (m_addr.ss_family)
{
case AF_INET:
reinterpret_cast<sockaddr_in *> (&m_addr)->sin_port = htons (port_);
break;
#ifndef NO_IPV6
case AF_INET6:
reinterpret_cast<sockaddr_in6 *> (&m_addr)->sin6_port = htons (port_);
break;
#endif
default:
std::abort ();
}
}
SockAddr::Domain SockAddr::domain () const
{
switch (m_addr.ss_family)
{
case AF_INET:
#ifndef NO_IPV6
case AF_INET6:
#endif
return static_cast<Domain> (m_addr.ss_family);
default:
std::abort ();
}
}
socklen_t SockAddr::size () const
{
switch (m_addr.ss_family)
{
case AF_INET:
return sizeof (sockaddr_in);
#ifndef NO_IPV6
case AF_INET6:
return sizeof (sockaddr_in6);
#endif
default:
std::abort ();
}
}

View File

@ -255,6 +255,38 @@ bool Socket::setSendBufferSize (std::size_t const size_)
return true;
}
#ifndef __NDS__
bool Socket::joinMulticastGroup (SockAddr const &addr_, SockAddr const &iface_)
{
ip_mreq group;
group.imr_multiaddr = static_cast<sockaddr_in const &> (addr_).sin_addr;
group.imr_interface = static_cast<sockaddr_in const &> (iface_).sin_addr;
if (::setsockopt (m_fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof (group)) != 0)
{
error ("setsockopt(IP_ADD_MEMBERSHIP, %s): %s\n", addr_.name (), std::strerror (errno));
return false;
}
return true;
}
bool Socket::dropMulticastGroup (SockAddr const &addr_, SockAddr const &iface_)
{
ip_mreq group;
group.imr_multiaddr = static_cast<sockaddr_in const &> (addr_).sin_addr;
group.imr_interface = static_cast<sockaddr_in const &> (iface_).sin_addr;
if (::setsockopt (m_fd, IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, sizeof (group)) != 0)
{
error ("setsockopt(IP_DROP_MEMBERSHIP, %s): %s\n", addr_.name (), std::strerror (errno));
return false;
}
return true;
}
#endif
std::make_signed_t<std::size_t>
Socket::read (void *const buffer_, std::size_t const size_, bool const oob_)
{
@ -279,6 +311,21 @@ std::make_signed_t<std::size_t> Socket::read (IOBuffer &buffer_, bool const oob_
return rc;
}
std::make_signed_t<std::size_t>
Socket::readFrom (void *const buffer_, std::size_t const size_, SockAddr &addr_)
{
assert (buffer_);
assert (size_);
socklen_t addrLen = sizeof (sockaddr_storage);
auto const rc = ::recvfrom (m_fd, buffer_, size_, 0, addr_, &addrLen);
if (rc < 0 && errno != EWOULDBLOCK)
error ("recvfrom: %s\n", std::strerror (errno));
return rc;
}
std::make_signed_t<std::size_t> Socket::write (void const *const buffer_, std::size_t const size_)
{
assert (buffer_);
@ -302,6 +349,19 @@ std::make_signed_t<std::size_t> Socket::write (IOBuffer &buffer_)
return rc;
}
std::make_signed_t<std::size_t>
Socket::writeTo (void const *buffer_, std::size_t size_, SockAddr const &addr_)
{
assert (buffer_);
assert (size_ > 0);
auto const rc = ::sendto (m_fd, buffer_, size_, 0, addr_, addr_.size ());
if (rc < 0 && errno != EWOULDBLOCK)
error ("sendto: %s\n", std::strerror (errno));
return rc;
}
SockAddr const &Socket::sockName () const
{
return m_sockName;
@ -312,9 +372,9 @@ SockAddr const &Socket::peerName () const
return m_peerName;
}
UniqueSocket Socket::create ()
UniqueSocket Socket::create (Type const type_)
{
auto const fd = ::socket (AF_INET, SOCK_STREAM, 0);
auto const fd = ::socket (AF_INET, static_cast<int> (type_), 0);
if (fd < 0)
{
error ("socket: %s\n", std::strerror (errno));

View File

@ -793,6 +793,12 @@ bool platform::networkAddress (SockAddr &addr_)
return true;
}
std::string const &platform::hostname ()
{
static std::string const hostname = "switch-ftpd";
return hostname;
}
bool platform::loop ()
{
if (!appletMainLoop ())