/*
 * Copyright 2020-2021 The Music Player Daemon Project
 * http://www.musicpd.org
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */

#undef NOUSER // COM needs the "MSG" typedef

#include "WasapiOutputPlugin.hxx"
#include "ForMixer.hxx"
#include "AudioClient.hxx"
#include "Device.hxx"
#include "PropertyStore.hxx"
#include "output/OutputAPI.hxx"
#include "lib/icu/Win32.hxx"
#include "lib/fmt/AudioFormatFormatter.hxx"
#include "mixer/MixerList.hxx"
#include "output/Error.hxx"
#include "pcm/Export.hxx"
#include "thread/Cond.hxx"
#include "thread/Mutex.hxx"
#include "thread/Name.hxx"
#include "thread/Thread.hxx"
#include "util/AllocatedString.hxx"
#include "util/ConstBuffer.hxx"
#include "util/Domain.hxx"
#include "util/RuntimeError.hxx"
#include "util/ScopeExit.hxx"
#include "util/StringBuffer.hxx"
#include "win32/Com.hxx"
#include "win32/ComPtr.hxx"
#include "win32/ComWorker.hxx"
#include "win32/HResult.hxx"
#include "win32/WinEvent.hxx"
#include "Log.hxx"
#include "config.h"

#include <boost/lockfree/spsc_queue.hpp>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cinttypes>
#include <cmath>
#include <exception>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
#include <variant>
#include <vector>

#include <audioclient.h>
#include <initguid.h>
#include <functiondiscoverykeys_devpkey.h>
#include <mmdeviceapi.h>

namespace {
static constexpr Domain wasapi_output_domain("wasapi_output");

67 68 69
constexpr uint32_t
GetChannelMask(const uint8_t channels) noexcept
{
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
	switch (channels) {
	case 1:
		return KSAUDIO_SPEAKER_MONO;
	case 2:
		return KSAUDIO_SPEAKER_STEREO;
	case 3:
		return SPEAKER_FRONT_LEFT | SPEAKER_FRONT_RIGHT | SPEAKER_FRONT_CENTER;
	case 4:
		return KSAUDIO_SPEAKER_QUAD;
	case 5:
		return SPEAKER_FRONT_LEFT | SPEAKER_FRONT_RIGHT | SPEAKER_FRONT_CENTER |
		       SPEAKER_BACK_LEFT | SPEAKER_BACK_RIGHT;
	case 6:
		return KSAUDIO_SPEAKER_5POINT1;
	case 7:
		return KSAUDIO_SPEAKER_5POINT1 | SPEAKER_BACK_CENTER;
	case 8:
		return KSAUDIO_SPEAKER_7POINT1_SURROUND;
	default:
		gcc_unreachable();
	}
}

template <typename Functor>
94 95 96
inline bool
SafeSilenceTry(Functor &&functor) noexcept
{
97 98 99
	try {
		functor();
		return true;
100
	} catch (...) {
101 102 103 104
		return false;
	}
}

105 106 107
std::vector<WAVEFORMATEXTENSIBLE>
GetFormats(const AudioFormat &audio_format) noexcept
{
108 109 110 111 112 113 114 115 116 117 118
#ifdef ENABLE_DSD
	if (audio_format.format == SampleFormat::DSD) {
		AudioFormat dop_format = audio_format;
		PcmExport::Params params;
		params.dsd_mode = PcmExport::DsdMode::DOP;
		dop_format.sample_rate =
			params.CalcOutputSampleRate(audio_format.sample_rate);
		dop_format.format = SampleFormat::S24_P32;
		return GetFormats(dop_format);
	}
#endif
119 120 121 122 123 124 125
	std::vector<WAVEFORMATEXTENSIBLE> Result;
	if (audio_format.format == SampleFormat::S24_P32) {
		Result.resize(2);
		Result[0].Format.wBitsPerSample = 24;
		Result[0].Samples.wValidBitsPerSample = 24;
		Result[1].Format.wBitsPerSample = 32;
		Result[1].Samples.wValidBitsPerSample = 24;
126
	} else {
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
		Result.resize(1);
		Result[0].Format.wBitsPerSample = audio_format.GetSampleSize() * 8;
		Result[0].Samples.wValidBitsPerSample = audio_format.GetSampleSize() * 8;
	}
	const DWORD mask = GetChannelMask(audio_format.channels);
	const GUID guid = audio_format.format == SampleFormat::FLOAT
				  ? KSDATAFORMAT_SUBTYPE_IEEE_FLOAT
				  : KSDATAFORMAT_SUBTYPE_PCM;
	for (auto &device_format : Result) {
		device_format.dwChannelMask = mask;
		device_format.Format.wFormatTag = WAVE_FORMAT_EXTENSIBLE;
		device_format.Format.nChannels = audio_format.channels;
		device_format.Format.nSamplesPerSec = audio_format.sample_rate;
		device_format.Format.cbSize =
			sizeof(WAVEFORMATEXTENSIBLE) - sizeof(WAVEFORMATEX);
		device_format.SubFormat = guid;
		device_format.Format.nBlockAlign = device_format.Format.nChannels *
						   device_format.Format.wBitsPerSample /
						   8;
		device_format.Format.nAvgBytesPerSec =
			audio_format.sample_rate * device_format.Format.nBlockAlign;
	}
	return Result;
150 151
}

152
#ifdef ENABLE_DSD
153 154 155
void
SetDSDFallback(AudioFormat &audio_format) noexcept
{
156 157 158 159 160
	audio_format.format = SampleFormat::FLOAT;
	audio_format.sample_rate = 384000;
}
#endif

161 162
} // namespace

163 164
class WasapiOutputThread {
	Thread thread{BIND_THIS_METHOD(Work)};
165 166
	WinEvent event;
	WinEvent data_poped;
167
	IAudioClient &client;
168 169 170
	ComPtr<IAudioRenderClient> render_client;
	const UINT32 frame_size;
	const UINT32 buffer_size_in_frames;
171
	const bool is_exclusive;
172

173 174 175 176 177 178 179 180 181 182
	/**
	 * This flag is only used by the calling thread
	 * (i.e. #OutputThread), and specifies whether the
	 * WasapiOutputThread has been told to play via Play().  This
	 * variable is somewhat redundant because we already have
	 * "state", but using this variable saves some overhead for
	 * atomic operations.
	 */
	bool playing = false;

183 184
	bool started = false;

185 186
	std::atomic_bool cancel = false;

187 188
	std::atomic_bool empty = true;

189
	enum class Status : uint32_t { FINISH, PLAY, PAUSE };
190

191 192 193 194 195 196 197 198 199
	alignas(BOOST_LOCKFREE_CACHELINE_BYTES) std::atomic<Status> status =
		Status::PAUSE;
	alignas(BOOST_LOCKFREE_CACHELINE_BYTES) struct {
		std::atomic_bool occur = false;
		std::exception_ptr ptr = nullptr;
	} error;
	boost::lockfree::spsc_queue<BYTE> spsc_buffer;

public:
200
	WasapiOutputThread(IAudioClient &_client,
201 202
			   ComPtr<IAudioRenderClient> &&_render_client,
			   const UINT32 _frame_size, const UINT32 _buffer_size_in_frames,
203
			   bool _is_exclusive)
204
		:client(_client),
205 206 207 208
		 render_client(std::move(_render_client)), frame_size(_frame_size),
		 buffer_size_in_frames(_buffer_size_in_frames), is_exclusive(_is_exclusive),
		 spsc_buffer(_buffer_size_in_frames * 4 * _frame_size)
	{
209
		SetEventHandle(client, event.handle());
210
		thread.Start();
211 212
	}

213 214
	void Finish() noexcept {
		SetStatus(Status::FINISH);
215
		thread.Join();
216 217
	}

218 219 220 221 222 223 224 225 226 227 228 229
	void Play() noexcept {
		playing = true;
		SetStatus(Status::PLAY);
	}

	void Pause() noexcept {
		if (!playing)
			return;

		playing = false;
		SetStatus(Status::PAUSE);
	}
230

231
	std::size_t Push(ConstBuffer<void> input) noexcept {
232 233
		empty.store(false);

234 235 236 237 238 239 240 241 242 243
		std::size_t consumed =
			spsc_buffer.push(static_cast<const BYTE *>(input.data),
					 input.size);

		if (!playing) {
			playing = true;
			Play();
		}

		return consumed;
244 245
	}

246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263
	/**
	 * Check if the buffer is empty, and if not, wait a bit.
	 *
	 * Throws on error.
	 *
	 * @return true if the buffer is now empty
	 */
	bool Drain() {
		if (empty)
			return true;

		CheckException();
		Wait();
		CheckException();

		return empty;
	}

264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281
	/**
	 * Instruct the thread to discard the buffer (and wait for
	 * completion).  This needs to be done inside this thread,
	 * because only the consumer thread is allowed to do that.
	 */
	void Cancel() noexcept {
		cancel.store(true);
		event.Set();

		while (cancel.load() && !error.occur.load())
			Wait();

		/* not rethrowing the exception here via
		   CheckException() because this method must be
		   "noexcept"; the next WasapiOutput::Play() call will
		   throw */
	}

282 283 284 285 286 287 288 289
	/**
	 * Wait for the thread to finish some work (e.g. until some
	 * buffer space becomes available).
	 */
	void Wait() noexcept {
		data_poped.Wait();
	}

290 291 292 293
	void InterruptWaiter() noexcept {
		data_poped.Set();
	}

294
	void CheckException() {
295
		if (error.occur.load()) {
296
			std::rethrow_exception(error.ptr);
297 298 299 300 301 302
		}
	}

private:
	void SetStatus(Status s) noexcept {
		status.store(s);
303
		event.Set();
304 305 306 307 308
	}
	void Work() noexcept;
};

class WasapiOutput final : public AudioOutput {
309 310
	const bool is_exclusive;
	const bool enumerate_devices;
311
#ifdef ENABLE_DSD
312
	const bool dop_setting;
313
#endif
314

315 316 317 318 319
	/**
	 * Only valid if the output is open.
	 */
	bool paused;

320 321
	std::atomic_flag not_interrupted = true;

322 323
	const std::string device_config;

324 325 326 327 328 329 330 331
	std::shared_ptr<COMWorker> com_worker;
	ComPtr<IMMDevice> device;
	ComPtr<IAudioClient> client;
	WAVEFORMATEXTENSIBLE device_format;
	std::optional<WasapiOutputThread> thread;
	std::size_t watermark;
	std::optional<PcmExport> pcm_export;

332 333 334
public:
	static AudioOutput *Create(EventLoop &, const ConfigBlock &block);
	WasapiOutput(const ConfigBlock &block);
335 336 337 338 339 340

	auto GetComWorker() noexcept {
		// TODO: protect access to the shard_ptr
		return com_worker;
	}

341
	void Enable() override {
342
		com_worker = std::make_shared<COMWorker>();
343 344

		try {
345
			com_worker->Async([&]() { ChooseDevice(); }).get();
346
		} catch (...) {
347
			com_worker.reset();
348 349
			throw;
		}
350 351
	}
	void Disable() noexcept override {
352 353
		com_worker->Async([&]() { DoDisable(); }).get();
		com_worker.reset();
354 355
	}
	void Open(AudioFormat &audio_format) override {
356
		com_worker->Async([&]() { DoOpen(audio_format); }).get();
357
		paused = false;
358
	}
359 360 361 362
	void Close() noexcept override;
	std::chrono::steady_clock::duration Delay() const noexcept override;
	size_t Play(const void *chunk, size_t size) override;
	void Drain() override;
363
	void Cancel() noexcept override;
364
	bool Pause() override;
365
	void Interrupt() noexcept override;
366 367 368 369 370 371 372 373 374 375 376 377

	constexpr bool Exclusive() const { return is_exclusive; }
	constexpr size_t FrameSize() const { return device_format.Format.nBlockAlign; }
	constexpr size_t SampleRate() const {
		return device_format.Format.nSamplesPerSec;
	}

private:
	friend bool wasapi_is_exclusive(WasapiOutput &output) noexcept;
	friend IMMDevice *wasapi_output_get_device(WasapiOutput &output) noexcept;
	friend IAudioClient *wasapi_output_get_client(WasapiOutput &output) noexcept;

378 379 380
	void DoDisable() noexcept;
	void DoOpen(AudioFormat &audio_format);

381
	void ChooseDevice();
382
	bool TryFormatExclusive(const AudioFormat &audio_format);
383 384
	void FindExclusiveFormatSupported(AudioFormat &audio_format);
	void FindSharedFormatSupported(AudioFormat &audio_format);
385 386 387 388 389
	static void EnumerateDevices(IMMDeviceEnumerator &enumerator);
	static ComPtr<IMMDevice> GetDevice(IMMDeviceEnumerator &enumerator,
					   unsigned index);
	static ComPtr<IMMDevice> SearchDevice(IMMDeviceEnumerator &enumerator,
					      std::string_view name);
390 391
};

392 393 394
WasapiOutput &
wasapi_output_downcast(AudioOutput &output) noexcept
{
395 396 397
	return static_cast<WasapiOutput &>(output);
}

398 399 400 401 402
bool
wasapi_is_exclusive(WasapiOutput &output) noexcept
{
	return output.is_exclusive;
}
403

404 405 406 407 408 409
std::shared_ptr<COMWorker>
wasapi_output_get_com_worker(WasapiOutput &output) noexcept
{
	return output.GetComWorker();
}

410 411 412
IMMDevice *
wasapi_output_get_device(WasapiOutput &output) noexcept
{
413 414 415
	return output.device.get();
}

416 417 418
IAudioClient *
wasapi_output_get_client(WasapiOutput &output) noexcept
{
419 420 421
	return output.client.get();
}

422
inline void
423
WasapiOutputThread::Work() noexcept
424
try {
425
	SetThreadName("Wasapi Output Worker");
426
	LogDebug(wasapi_output_domain, "Working thread started");
427
	COM com;
428 429 430 431 432 433 434 435 436 437 438

	AtScopeExit(this) {
		if (started) {
			try {
				Stop(client);
			} catch (...) {
				LogError(std::current_exception());
			}
		}
	};

439
	while (true) {
440
		event.Wait();
441

442 443 444
		if (cancel.load()) {
			spsc_buffer.consume_all([](auto &&) {});
			cancel.store(false);
445
			empty.store(true);
446 447 448
			InterruptWaiter();
		}

449
		Status current_state = status.load();
450 451
		switch (current_state) {
		case Status::FINISH:
452 453
			LogDebug(wasapi_output_domain,
				 "Working thread stopped");
454
			return;
455

456 457 458 459 460 461
		case Status::PAUSE:
			if (!started)
				/* don't bother starting the
				   IAudioClient if we're paused */
				continue;

462 463 464 465 466 467
			/* stop the IAudioClient while paused; it will
			   be restarted as soon as we're asked to
			   resume playback */
			Stop(client);
			started = false;
			continue;
468 469 470 471

		case Status::PLAY:
			break;
		}
472

473 474 475 476 477 478
		UINT32 write_in_frames = buffer_size_in_frames;
		if (!is_exclusive) {
			UINT32 data_in_frames =
				GetCurrentPaddingFrames(client);
			if (data_in_frames >= buffer_size_in_frames) {
				continue;
479
			}
480 481
			write_in_frames -= data_in_frames;
		}
482

483 484
		BYTE *data;
		DWORD mode = 0;
485

486 487 488 489 490
		if (HRESULT result =
		    render_client->GetBuffer(write_in_frames, &data);
		    FAILED(result)) {
			throw MakeHResultError(result, "Failed to get buffer");
		}
491

492 493
		AtScopeExit(&) {
			render_client->ReleaseBuffer(write_in_frames, mode);
494 495 496 497 498

			if (!started) {
				Start(client);
				started = true;
			}
499 500
		};

501 502 503
		const UINT32 write_size = write_in_frames * frame_size;
		UINT32 new_data_size = 0;
		new_data_size = spsc_buffer.pop(data, write_size);
504 505 506
		if (new_data_size == 0)
			empty.store(true);

507 508 509
		std::fill_n(data + new_data_size,
			    write_size - new_data_size, 0);
		InterruptWaiter();
510
	}
511 512 513
} catch (...) {
	error.ptr = std::current_exception();
	error.occur.store(true);
514

515 516
	/* wake up the client thread which may be inside Wait() */
	InterruptWaiter();
517 518
}

519 520 521
AudioOutput *
WasapiOutput::Create(EventLoop &, const ConfigBlock &block)
{
522 523 524 525
	return new WasapiOutput(block);
}

WasapiOutput::WasapiOutput(const ConfigBlock &block)
526 527 528
	:AudioOutput(FLAG_ENABLE_DISABLE | FLAG_PAUSE),
	 is_exclusive(block.GetBlockValue("exclusive", false)),
	 enumerate_devices(block.GetBlockValue("enumerate", false)),
529
#ifdef ENABLE_DSD
530
	 dop_setting(block.GetBlockValue("dop", false)),
531
#endif
532 533
	 device_config(block.GetBlockValue("device", ""))
{
534
}
535

536
/// run inside COMWorkerThread
537 538 539
void
WasapiOutput::DoDisable() noexcept
{
540 541
	assert(!thread);

542 543 544
	device.reset();
}

545
/// run inside COMWorkerThread
546 547 548
void
WasapiOutput::DoOpen(AudioFormat &audio_format)
{
549 550
	client.reset();

551
	if (GetState(*device) != DEVICE_STATE_ACTIVE) {
552
		device.reset();
553
		ChooseDevice();
554 555
	}

556
	client = Activate<IAudioClient>(*device);
557 558 559 560 561

	if (audio_format.channels > 8) {
		audio_format.channels = 8;
	}

562
#ifdef ENABLE_DSD
563
	if (!dop_setting && audio_format.format == SampleFormat::DSD) {
564 565 566
		SetDSDFallback(audio_format);
	}
#endif
567 568 569 570 571
	if (Exclusive()) {
		FindExclusiveFormatSupported(audio_format);
	} else {
		FindSharedFormatSupported(audio_format);
	}
572
	bool require_export = audio_format.format == SampleFormat::S24_P32;
573 574 575
#ifdef ENABLE_DSD
	require_export |= audio_format.format == SampleFormat::DSD;
#endif
576 577
	if (require_export) {
		PcmExport::Params params;
578
#ifdef ENABLE_DSD
579
		params.dsd_mode = PcmExport::DsdMode::NONE;
580 581 582 583
		if (audio_format.format == SampleFormat::DSD) {
			params.dsd_mode = PcmExport::DsdMode::DOP;
		}
#endif
584 585 586 587 588 589 590 591 592
		params.shift8 = false;
		params.pack24 = false;
		if (device_format.Format.wBitsPerSample == 32 &&
		    device_format.Samples.wValidBitsPerSample == 24) {
			params.shift8 = true;
		}
		if (device_format.Format.wBitsPerSample == 24) {
			params.pack24 = true;
		}
593 594
		FmtDebug(wasapi_output_domain, "Packing data: shift8={} pack24={}",
			 params.shift8, params.pack24);
595 596 597
		pcm_export.emplace();
		pcm_export->Open(audio_format.format, audio_format.channels, params);
	}
598 599 600 601 602 603 604

	using s = std::chrono::seconds;
	using ms = std::chrono::milliseconds;
	using ns = std::chrono::nanoseconds;
	using hundred_ns = std::chrono::duration<uint64_t, std::ratio<1, 10000000>>;

	// The unit in REFERENCE_TIME is hundred nanoseconds
605 606 607 608 609
	REFERENCE_TIME default_device_period, min_device_period;

	if (HRESULT result =
		    client->GetDevicePeriod(&default_device_period, &min_device_period);
	    FAILED(result)) {
610
		throw MakeHResultError(result, "Unable to get device period");
611
	}
612 613 614 615 616
	FmtDebug(wasapi_output_domain,
		 "Default device period: {} ns, Minimum device period: "
		 "{} ns",
		 ns(hundred_ns(default_device_period)).count(),
		 ns(hundred_ns(min_device_period)).count());
617

618 619 620 621
	REFERENCE_TIME buffer_duration;
	if (Exclusive()) {
		buffer_duration = default_device_period;
	} else {
622
		const REFERENCE_TIME align = hundred_ns(ms(50)).count();
623
		buffer_duration = (align / default_device_period) * default_device_period;
624
	}
625 626
	FmtDebug(wasapi_output_domain, "Buffer duration: {} ns",
		 ns(hundred_ns(buffer_duration)).count());
627 628

	if (Exclusive()) {
629 630 631 632 633 634 635 636
		if (HRESULT result = client->Initialize(
			    AUDCLNT_SHAREMODE_EXCLUSIVE,
			    AUDCLNT_STREAMFLAGS_EVENTCALLBACK, buffer_duration,
			    buffer_duration,
			    reinterpret_cast<WAVEFORMATEX *>(&device_format), nullptr);
		    FAILED(result)) {
			if (result == AUDCLNT_E_BUFFER_SIZE_NOT_ALIGNED) {
				// https://docs.microsoft.com/en-us/windows/win32/api/audioclient/nf-audioclient-iaudioclient-initialize
637 638
				UINT32 buffer_size_in_frames =
					GetBufferSizeInFrames(*client);
639 640 641 642
				buffer_duration =
					std::ceil(double(buffer_size_in_frames *
							 hundred_ns(s(1)).count()) /
						  SampleRate());
643 644 645
				FmtDebug(wasapi_output_domain,
					 "Aligned buffer duration: {} ns",
					 ns(hundred_ns(buffer_duration)).count());
646
				client.reset();
647
				client = Activate<IAudioClient>(*device);
648 649 650 651 652 653
				result = client->Initialize(
					AUDCLNT_SHAREMODE_EXCLUSIVE,
					AUDCLNT_STREAMFLAGS_EVENTCALLBACK,
					buffer_duration, buffer_duration,
					reinterpret_cast<WAVEFORMATEX *>(&device_format),
					nullptr);
654
			}
655

656
			if (FAILED(result)) {
657
				throw MakeHResultError(result, "Unable to initialize audio client");
658 659 660
			}
		}
	} else {
661 662 663 664 665
		if (HRESULT result = client->Initialize(
			    AUDCLNT_SHAREMODE_SHARED, AUDCLNT_STREAMFLAGS_EVENTCALLBACK,
			    buffer_duration, 0,
			    reinterpret_cast<WAVEFORMATEX *>(&device_format), nullptr);
		    FAILED(result)) {
666 667
			throw MakeHResultError(result,
					       "Unable to initialize audio client");
668
		}
669 670
	}

671
	auto render_client = GetService<IAudioRenderClient>(*client);
672

673
	const UINT32 buffer_size_in_frames = GetBufferSizeInFrames(*client);
674 675

	watermark = buffer_size_in_frames * 3 * FrameSize();
676
	thread.emplace(*client, std::move(render_client), FrameSize(),
677
		       buffer_size_in_frames, is_exclusive);
678

679
	paused = false;
680 681
}

682 683 684
void
WasapiOutput::Close() noexcept
{
685 686 687 688
	assert(thread);

	try {
		thread->CheckException();
689
	} catch (...) {
690
		LogError(wasapi_output_domain, "exception while stopping");
691
	}
692
	thread->Finish();
693
	com_worker->Async([&]() {
694 695 696
		thread.reset();
		client.reset();
	}).get();
697
	pcm_export.reset();
698 699
}

700 701 702
std::chrono::steady_clock::duration
WasapiOutput::Delay() const noexcept
{
703
	if (paused) {
704 705
		// idle while paused
		return std::chrono::seconds(1);
706 707
	}

708
	return std::chrono::steady_clock::duration::zero();
709 710
}

711 712 713
size_t
WasapiOutput::Play(const void *chunk, size_t size)
{
714
	assert(thread);
715

716 717
	paused = false;

718 719
	not_interrupted.test_and_set();

720 721 722 723 724 725 726
	ConstBuffer<void> input(chunk, size);
	if (pcm_export) {
		input = pcm_export->Export(input);
	}
	if (input.empty())
		return size;

727
	do {
728
		const size_t consumed_size = thread->Push({input.data, input.size});
729

730
		if (consumed_size == 0) {
731
			thread->Wait();
732
			thread->CheckException();
733 734 735
			if (!not_interrupted.test_and_set()) {
				throw AudioOutputInterrupted{};
			}
736 737 738 739 740
			continue;
		}

		thread->CheckException();

741 742 743
		if (pcm_export) {
			return pcm_export->CalcInputSize(consumed_size);
		}
744 745 746 747
		return consumed_size;
	} while (true);
}

748 749 750
bool
WasapiOutput::Pause()
{
751
	paused = true;
752
	thread->Pause();
753 754 755 756
	thread->CheckException();
	return true;
}

757 758 759
void
WasapiOutput::Interrupt() noexcept
{
760 761
	if (thread) {
		not_interrupted.clear();
762
		thread->InterruptWaiter();
763 764 765
	}
}

766 767 768
void
WasapiOutput::Drain()
{
769 770
	assert(thread);

771 772 773 774 775 776
	not_interrupted.test_and_set();

	while (!thread->Drain()) {
		if (!not_interrupted.test_and_set())
			throw AudioOutputInterrupted{};
	}
777

778 779
	/* TODO: this needs to wait until the hardware has really
	   finished playing */
780 781
}

782 783 784 785 786
void
WasapiOutput::Cancel() noexcept
{
	assert(thread);

787
	thread->Cancel();
788 789
}

790
/// run inside COMWorkerThread
791
void
792
WasapiOutput::ChooseDevice()
793
{
794
	ComPtr<IMMDeviceEnumerator> enumerator;
795 796 797
	enumerator.CoCreateInstance(__uuidof(MMDeviceEnumerator), nullptr,
				    CLSCTX_INPROC_SERVER);

798 799
	if (enumerate_devices) {
		try {
800
			EnumerateDevices(*enumerator);
801 802
		} catch (...) {
			LogError(std::current_exception());
803 804 805 806
		}
	}

	if (!device_config.empty()) {
807
		unsigned int id;
808
		if (!SafeSilenceTry([this, &id]() { id = std::stoul(device_config); })) {
809
			device = SearchDevice(*enumerator, device_config);
810
			if (!device)
811 812
				throw FormatRuntimeError("Device '%s' not found",
							 device_config.c_str());
813
		} else
814
			device = GetDevice(*enumerator, id);
815
	} else {
816
		device = GetDefaultAudioEndpoint(*enumerator);
817 818 819 820
	}
}

/// run inside COMWorkerThread
821 822 823
bool
WasapiOutput::TryFormatExclusive(const AudioFormat &audio_format)
{
824 825
	for (auto test_format : GetFormats(audio_format)) {
		HRESULT result = client->IsFormatSupported(
826
			AUDCLNT_SHAREMODE_EXCLUSIVE,
827 828
			reinterpret_cast<WAVEFORMATEX *>(&test_format), nullptr);
		const auto result_string = std::string(HRESULTToString(result));
829 830 831 832 833
		FmtDebug(wasapi_output_domain, "Trying {} {} {}-{} (exclusive) -> {}",
			 audio_format, test_format.Format.nSamplesPerSec,
			 test_format.Format.wBitsPerSample,
			 test_format.Samples.wValidBitsPerSample,
			 result_string);
834 835 836
		if (SUCCEEDED(result)) {
			device_format = test_format;
			return true;
837
		}
838 839 840

		if (result == AUDCLNT_E_EXCLUSIVE_MODE_NOT_ALLOWED)
			throw std::runtime_error("Exclusive mode not allowed");
841 842 843
	}
	return false;
}
844

845
/// run inside COMWorkerThread
846 847 848
void
WasapiOutput::FindExclusiveFormatSupported(AudioFormat &audio_format)
{
849 850
	for (uint8_t channels : {0, 2, 6, 8, 7, 1, 4, 5, 3}) {
		if (audio_format.channels == channels) {
851 852
			continue;
		}
853 854
		if (channels == 0) {
			channels = audio_format.channels;
855
		}
856
		auto old_channels = std::exchange(audio_format.channels, channels);
857 858 859 860 861 862 863 864 865 866
#ifdef ENABLE_DSD
		bool was_dsd = false;
		if (audio_format.format == SampleFormat::DSD) {
			if (dop_setting && TryFormatExclusive(audio_format)) {
				return;
			}
			was_dsd = true;
			SetDSDFallback(audio_format);
		}
#endif
867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896
		for (uint32_t rate : {0, 384000, 352800, 192000, 176400, 96000, 88200,
				      48000, 44100, 32000, 22050, 16000, 11025, 8000}) {
			if (audio_format.sample_rate <= rate) {
				continue;
			}
			if (rate == 0) {
				rate = audio_format.sample_rate;
			}
			auto old_rate = std::exchange(audio_format.sample_rate, rate);
			for (SampleFormat format : {
				     SampleFormat::UNDEFINED,
				     SampleFormat::S32,
				     SampleFormat::S24_P32,
				     SampleFormat::S16,
				     SampleFormat::S8,
			     }) {
				if (audio_format.format == format) {
					continue;
				}
				if (format == SampleFormat::UNDEFINED) {
					format = audio_format.format;
				}
				auto old_format =
					std::exchange(audio_format.format, format);
				if (TryFormatExclusive(audio_format)) {
					return;
				}
				audio_format.format = old_format;
			}
			audio_format.sample_rate = old_rate;
897
		}
898 899 900 901 902
#ifdef ENABLE_DSD
		if (was_dsd) {
			audio_format.format = SampleFormat::DSD;
		}
#endif
903 904
		audio_format.channels = old_channels;
	}
905 906
}

907
/// run inside COMWorkerThread
908 909 910
void
WasapiOutput::FindSharedFormatSupported(AudioFormat &audio_format)
{
911 912 913
	HRESULT result;

	// In shared mode, different sample rate is always unsupported.
914 915
	auto mixer_format = GetMixFormat(*client);

916 917
	audio_format.sample_rate = mixer_format->nSamplesPerSec;
	device_format = GetFormats(audio_format).front();
918 919 920 921 922 923

	ComHeapPtr<WAVEFORMATEXTENSIBLE> closest_format;
	result = client->IsFormatSupported(
		AUDCLNT_SHAREMODE_SHARED,
		reinterpret_cast<WAVEFORMATEX *>(&device_format),
		closest_format.AddressCast<WAVEFORMATEX>());
924 925
	{
		const auto result_string = std::string(HRESULTToString(result));
926 927 928 929 930
		FmtDebug(wasapi_output_domain, "Trying {} {} {}-{} (shared) -> {}",
			 audio_format, device_format.Format.nSamplesPerSec,
			 device_format.Format.wBitsPerSample,
			 device_format.Samples.wValidBitsPerSample,
			 result_string);
931
	}
932 933

	if (FAILED(result) && result != AUDCLNT_E_UNSUPPORTED_FORMAT) {
934
		throw MakeHResultError(result, "IsFormatSupported failed");
935 936 937 938 939 940 941 942 943 944
	}

	switch (result) {
	case S_OK:
		break;
	case AUDCLNT_E_UNSUPPORTED_FORMAT:
	default:
		// Trying channels fallback.
		audio_format.channels = mixer_format->nChannels;

945
		device_format = GetFormats(audio_format).front();
946 947 948 949 950

		result = client->IsFormatSupported(
			AUDCLNT_SHAREMODE_SHARED,
			reinterpret_cast<WAVEFORMATEX *>(&device_format),
			closest_format.AddressCast<WAVEFORMATEX>());
951 952
		{
			const auto result_string = std::string(HRESULTToString(result));
953 954 955 956 957 958 959
			FmtDebug(wasapi_output_domain,
				 "Trying {} {} {}-{} (shared) -> {}",
				 audio_format,
				 device_format.Format.nSamplesPerSec,
				 device_format.Format.wBitsPerSample,
				 device_format.Samples.wValidBitsPerSample,
				 result_string);
960
		}
961
		if (FAILED(result)) {
962
			throw MakeHResultError(result, "Format is not supported");
963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998
		}
		break;
	case S_FALSE:
		if (closest_format->Format.wFormatTag == WAVE_FORMAT_EXTENSIBLE) {
			device_format = *closest_format;
		} else {
			device_format.Samples.wValidBitsPerSample =
				device_format.Format.wBitsPerSample;
			device_format.Format = closest_format->Format;
			switch (std::exchange(device_format.Format.wFormatTag,
					      WAVE_FORMAT_EXTENSIBLE)) {
			case WAVE_FORMAT_PCM:
				device_format.SubFormat = KSDATAFORMAT_SUBTYPE_PCM;
				break;
			case WAVE_FORMAT_IEEE_FLOAT:
				device_format.SubFormat = KSDATAFORMAT_SUBTYPE_IEEE_FLOAT;
				break;
			default:
				gcc_unreachable();
			}
		}
		break;
	}

	// Copy closest match back to audio_format.
	audio_format.channels = device_format.Format.nChannels;
	audio_format.sample_rate = device_format.Format.nSamplesPerSec;
	if (device_format.SubFormat == KSDATAFORMAT_SUBTYPE_PCM) {
		switch (device_format.Format.wBitsPerSample) {
		case 8:
			audio_format.format = SampleFormat::S8;
			break;
		case 16:
			audio_format.format = SampleFormat::S16;
			break;
		case 32:
999 1000 1001 1002
			audio_format.format =
				device_format.Samples.wValidBitsPerSample == 32
					? SampleFormat::S32
					: SampleFormat::S24_P32;
1003 1004 1005 1006 1007 1008 1009
			break;
		}
	} else if (device_format.SubFormat == KSDATAFORMAT_SUBTYPE_IEEE_FLOAT) {
		audio_format.format = SampleFormat::FLOAT;
	}
}

1010
/// run inside COMWorkerThread
1011
void
1012
WasapiOutput::EnumerateDevices(IMMDeviceEnumerator &enumerator)
1013
{
1014
	const auto device_collection = EnumAudioEndpoints(enumerator);
1015

1016
	const UINT count = GetCount(*device_collection);
1017
	for (UINT i = 0; i < count; ++i) {
1018
		const auto enumerated_device = Item(*device_collection, i);
1019

1020 1021
		const auto property_store =
			OpenPropertyStore(*enumerated_device);
1022

1023 1024 1025 1026
		auto name = GetString(*property_store,
				      PKEY_Device_FriendlyName);
		if (name == nullptr)
			continue;
1027

1028 1029
		FmtNotice(wasapi_output_domain,
			  "Device \"{}\" \"{}\"", i, name);
1030 1031 1032
	}
}

1033
/// run inside COMWorkerThread
1034
ComPtr<IMMDevice>
1035
WasapiOutput::GetDevice(IMMDeviceEnumerator &enumerator, unsigned index)
1036
{
1037
	const auto device_collection = EnumAudioEndpoints(enumerator);
1038
	return Item(*device_collection, index);
1039 1040
}

1041
/// run inside COMWorkerThread
1042
ComPtr<IMMDevice>
1043 1044
WasapiOutput::SearchDevice(IMMDeviceEnumerator &enumerator,
			   std::string_view name)
1045
{
1046
	const auto device_collection = EnumAudioEndpoints(enumerator);
1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058

	const UINT count = GetCount(*device_collection);
	for (UINT i = 0; i < count; ++i) {
		auto d = Item(*device_collection, i);

		const auto property_store = OpenPropertyStore(*d);
		auto n = GetString(*property_store, PKEY_Device_FriendlyName);
		if (n != nullptr && name.compare(n) == 0)
			return d;
	}

	return nullptr;
1059 1060
}

/**
 * Default-device probe for the plugin table; this implementation
 * performs no check and always reports success.
 */
static bool
wasapi_output_test_default_device()
{
	constexpr bool usable = true;
	return usable;
}

/* plugin table entry: name, default-device probe, factory, mixer */
const struct AudioOutputPlugin wasapi_output_plugin = {
	"wasapi",
	&wasapi_output_test_default_device,
	&WasapiOutput::Create,
	&wasapi_mixer_plugin,
};