/*
 * Copyright 2003-2016 The Music Player Daemon Project
 * http://www.musicpd.org
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */

#include "config.h"
#include "ServerSocket.hxx"
#include "net/StaticSocketAddress.hxx"
#include "net/AllocatedSocketAddress.hxx"
#include "net/SocketAddress.hxx"
#include "net/SocketUtil.hxx"
#include "net/SocketError.hxx"
#include "net/Resolver.hxx"
#include "net/ToString.hxx"
#include "event/SocketMonitor.hxx"
#include "system/fd_util.h"
#include "fs/AllocatedPath.hxx"
#include "fs/FileSystem.hxx"
#include "util/RuntimeError.hxx"
#include "util/Domain.hxx"
#include "util/ScopeExit.hxx"
#include "Log.hxx"

#include <exception>
#include <stdexcept>
#include <string>
#include <utility>
#include <algorithm>

#include <sys/stat.h>
#include <string.h>
#include <unistd.h>
#include <assert.h>

#ifdef WIN32
#include <ws2tcpip.h>
#include <winsock.h>
#else
#include <netinet/in.h>
#include <sys/socket.h>
#include <netdb.h>
#endif

55
class OneServerSocket final : private SocketMonitor {
56
	ServerSocket &parent;
57

58
	const unsigned serial;
59

60
#ifdef HAVE_UN
61
	AllocatedPath path;
62
#endif
63

64
	const AllocatedSocketAddress address;
65

66
public:
67
	template<typename A>
68
	OneServerSocket(EventLoop &_loop, ServerSocket &_parent,
69
			unsigned _serial,
70
			A &&_address)
71 72
		:SocketMonitor(_loop),
		 parent(_parent), serial(_serial),
73
#ifdef HAVE_UN
74
		 path(AllocatedPath::Null()),
75
#endif
76
		 address(std::forward<A>(_address))
77 78 79 80 81 82 83
	{
	}

	OneServerSocket(const OneServerSocket &other) = delete;
	OneServerSocket &operator=(const OneServerSocket &other) = delete;

	~OneServerSocket() {
84 85
		if (IsDefined())
			Close();
86 87
	}

88 89 90 91
	unsigned GetSerial() const {
		return serial;
	}

92
#ifdef HAVE_UN
93 94
	void SetPath(AllocatedPath &&_path) {
		assert(path.IsNull());
95

96
		path = std::move(_path);
97
	}
98
#endif
99

100
	void Open();
101

102
	using SocketMonitor::IsDefined;
103
	using SocketMonitor::Close;
104

105 106
	gcc_pure
	std::string ToString() const {
107
		return ::ToString(address);
108
	}
109

110 111 112 113
	void SetFD(int _fd) {
		SocketMonitor::Open(_fd);
		SocketMonitor::ScheduleRead();
	}
114 115

	void Accept();
116 117

private:
118
	virtual bool OnSocketReady(unsigned flags) override;
119 120
};

121
static constexpr Domain server_socket_domain("server_socket");
/**
 * Look up the user id of the process on the other end of a connected
 * socket.
 *
 * Uses SO_PEERCRED when HAVE_STRUCT_UCRED is defined, otherwise
 * getpeereid() when HAVE_GETPEEREID is defined.
 *
 * @param fd the connected socket
 * @return the peer's uid (effective uid for getpeereid()), or -1 if
 * the query failed or no mechanism is available on this platform
 */
static int
get_remote_uid(int fd)
{
#if defined(HAVE_STRUCT_UCRED)
	struct ucred peer_cred;
	socklen_t cred_size = sizeof (peer_cred);

	if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED,
		       &peer_cred, &cred_size) < 0)
		return -1;

	return peer_cred.uid;
#elif defined(HAVE_GETPEEREID)
	uid_t peer_euid;
	gid_t peer_egid;

	if (getpeereid(fd, &peer_euid, &peer_egid) != 0)
		return -1;

	return peer_euid;
#else
	(void)fd;
	return -1;
#endif
}

148 149 150
inline void
OneServerSocket::Accept()
{
151
	StaticSocketAddress peer_address;
152 153
	size_t peer_address_length = sizeof(peer_address);
	int peer_fd =
154
		accept_cloexec_nonblock(Get(), peer_address.GetAddress(),
155 156 157
					&peer_address_length);
	if (peer_fd < 0) {
		const SocketErrorMessage msg;
158 159
		FormatError(server_socket_domain,
			    "accept() failed: %s", (const char *)msg);
160 161 162
		return;
	}

163 164
	peer_address.SetSize(peer_address_length);

165 166
	if (socket_keepalive(peer_fd)) {
		const SocketErrorMessage msg;
167 168 169
		FormatError(server_socket_domain,
			    "Could not set TCP keepalive option: %s",
			    (const char *)msg);
170 171
	}

172
	parent.OnAccept(peer_fd, peer_address,
173
			get_remote_uid(peer_fd));
174 175
}

176
bool
177
OneServerSocket::OnSocketReady(gcc_unused unsigned flags)
178
{
179
	Accept();
180
	return true;
181 182
}

183 184
inline void
OneServerSocket::Open()
185
{
186
	assert(!IsDefined());
187

188
	int _fd = socket_bind_listen(address.GetFamily(),
189
				     SOCK_STREAM, 0,
190
				     address, 5);
191

192
#ifdef HAVE_UN
193 194
	/* allow everybody to connect */

195 196
	if (!path.IsNull())
		chmod(path.c_str(), 0666);
197
#endif
198

199
	/* register in the EventLoop */
200 201 202 203

	SetFD(_fd);
}

204 205
ServerSocket::ServerSocket(EventLoop &_loop)
	:loop(_loop), next_serial(1) {}
206 207 208 209 210

/* this is just here to allow the OneServerSocket forward
   declaration */
ServerSocket::~ServerSocket() {}

211 212
void
ServerSocket::Open()
213
{
214
	OneServerSocket *good = nullptr, *bad = nullptr;
215
	std::exception_ptr last_error;
216

217 218
	for (auto &i : sockets) {
		assert(i.GetSerial() > 0);
219
		assert(good == nullptr || i.GetSerial() >= good->GetSerial());
220

221 222
		if (bad != nullptr && i.GetSerial() != bad->GetSerial()) {
			Close();
223
			std::rethrow_exception(last_error);
224 225
		}

226 227 228
		try {
			i.Open();
		} catch (const std::runtime_error &e) {
229
			if (good != nullptr && good->GetSerial() == i.GetSerial()) {
230 231
				const auto address_string = i.ToString();
				const auto good_string = good->ToString();
232 233 234 235 236 237
				FormatError(e,
					    "bind to '%s' failed "
					    "(continuing anyway, because "
					    "binding to '%s' succeeded)",
					    address_string.c_str(),
					    good_string.c_str());
238 239
			} else if (bad == nullptr) {
				bad = &i;
240

241
				const auto address_string = i.ToString();
242

243 244 245 246 247 248
				try {
					std::throw_with_nested(FormatRuntimeError("Failed to bind to '%s'",
										  address_string.c_str()));
				} catch (...) {
					last_error = std::current_exception();
				}
249 250
			}

251 252 253 254 255 256
			continue;
		}

		/* mark this socket as "good", and clear previous
		   errors */

257
		good = &i;
258

259 260
		if (bad != nullptr) {
			bad = nullptr;
261
			last_error = nullptr;
262 263 264
		}
	}

265
	if (bad != nullptr) {
266
		Close();
267
		std::rethrow_exception(last_error);
268 269 270
	}
}

271
void
272
ServerSocket::Close()
273 274
{
	for (auto &i : sockets)
275 276
		if (i.IsDefined())
			i.Close();
277 278
}

279
OneServerSocket &
280
ServerSocket::AddAddress(SocketAddress address)
281
{
282
	sockets.emplace_back(loop, *this, next_serial,
283
			     address);
284

285
	return sockets.back();
286 287
}

288 289 290 291 292 293 294 295 296
OneServerSocket &
ServerSocket::AddAddress(AllocatedSocketAddress &&address)
{
	sockets.emplace_back(loop, *this, next_serial,
			     std::move(address));

	return sockets.back();
}

297 298
void
ServerSocket::AddFD(int fd)
299 300 301
{
	assert(fd >= 0);

302
	StaticSocketAddress address;
303
	socklen_t address_length = sizeof(address);
304
	if (getsockname(fd, address.GetAddress(),
305 306
			&address_length) < 0)
		throw MakeSocketError("Failed to get socket address");
307

308 309 310
	address.SetSize(address_length);

	OneServerSocket &s = AddAddress(address);
311
	s.SetFD(fd);
312 313
}

#ifdef HAVE_TCP

/**
 * Add a listener on the IPv4 wildcard address (INADDR_ANY) for TCP
 * @p port, using the current serial.
 */
inline void
ServerSocket::AddPortIPv4(unsigned port)
{
	struct sockaddr_in any4;
	memset(&any4, 0, sizeof(any4));
	any4.sin_family = AF_INET;
	any4.sin_addr.s_addr = htonl(INADDR_ANY);
	any4.sin_port = htons(port);

	AddAddress(SocketAddress(reinterpret_cast<const struct sockaddr *>(&any4),
				 sizeof(any4)));
}

#ifdef HAVE_IPV6
329

330 331
inline void
ServerSocket::AddPortIPv6(unsigned port)
332 333 334 335 336 337
{
	struct sockaddr_in6 sin;
	memset(&sin, 0, sizeof(sin));
	sin.sin6_port = htons(port);
	sin.sin6_family = AF_INET6;

338
	AddAddress({(const sockaddr *)&sin, sizeof(sin)});
339
}
340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355

/**
 * Is IPv6 supported by the kernel?
 */
gcc_pure
static bool
SupportsIPv6()
{
	int fd = socket(AF_INET6, SOCK_STREAM, 0);
	if (fd < 0)
		return false;

	close(fd);
	return true;
}

356 357 358 359
#endif /* HAVE_IPV6 */

#endif /* HAVE_TCP */

360 361
void
ServerSocket::AddPort(unsigned port)
362 363
{
#ifdef HAVE_TCP
364 365
	if (port == 0 || port > 0xffff)
		throw std::runtime_error("Invalid TCP port");
366 367

#ifdef HAVE_IPV6
368 369
	if (SupportsIPv6())
		AddPortIPv6(port);
370
#endif
371
	AddPortIPv4(port);
372

373
	++next_serial;
374 375 376
#else /* HAVE_TCP */
	(void)port;

377
	throw std::runtime_error("TCP support is disabled");
378 379 380
#endif /* HAVE_TCP */
}

381 382
void
ServerSocket::AddHost(const char *hostname, unsigned port)
383 384
{
#ifdef HAVE_TCP
385
	struct addrinfo *ai = resolve_host_port(hostname, port,
386 387
						AI_PASSIVE, SOCK_STREAM);
	AtScopeExit(ai) { freeaddrinfo(ai); };
388

389
	for (const struct addrinfo *i = ai; i != nullptr; i = i->ai_next)
390
		AddAddress(SocketAddress(i->ai_addr, i->ai_addrlen));
391

392
	++next_serial;
393 394 395 396
#else /* HAVE_TCP */
	(void)hostname;
	(void)port;

397
	throw std::runtime_error("TCP support is disabled");
398 399 400
#endif /* HAVE_TCP */
}

401 402
void
ServerSocket::AddPath(AllocatedPath &&path)
403 404
{
#ifdef HAVE_UN
405
	unlink(path.c_str());
406

407 408
	AllocatedSocketAddress address;
	address.SetLocal(path.c_str());
409

410
	OneServerSocket &s = AddAddress(std::move(address));
411
	s.SetPath(std::move(path));
412 413 414
#else /* !HAVE_UN */
	(void)path;

415
	throw std::runtime_error("UNIX domain socket support is disabled");
416 417 418
#endif /* !HAVE_UN */
}