netio.c 31.2 KB
Newer Older
1 2 3
/*
 * WSK (Winsock Kernel) driver library.
 *
 * Copyright 2020 Paul Gofman <pgofman@codeweavers.com> for CodeWeavers
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
 */

#include <stdarg.h>

#define NONAMELESSUNION
#define NONAMELESSSTRUCT

#include "ntstatus.h"
#define WIN32_NO_STATUS
#include "windef.h"
#include "winioctl.h"
#include "winternl.h"
#include "ddk/wdm.h"
32
#include "ddk/wsk.h"
33
#include "wine/debug.h"
34 35
#include "winsock2.h"
#include "ws2tcpip.h"
36

37 38
#include "wine/heap.h"

39 40
WINE_DEFAULT_DEBUG_CHANNEL(netio);

41 42 43 44 45 46
/* Per-registration client object. WskRegister() allocates it and stores it in
 * WSK_REGISTRATION::ReservedRegistrationContext; WskDeregister() frees it. */
struct _WSK_CLIENT
{
    WSK_REGISTRATION *registration; /* registration passed to WskRegister() */
    WSK_CLIENT_NPI *client_npi;     /* client NPI passed to WskRegister() */
};

47 48
/* Accept state of a listen socket, shared by wsk_accept(), accept_callback()
 * and create_accept_socket(). */
struct listen_socket_callback_context
{
    SOCKADDR *local_address;     /* caller buffer for the accepted connection's local address, may be NULL */
    SOCKADDR *remote_address;    /* caller buffer for the peer address, may be NULL */
    const void *client_dispatch; /* client dispatch table given to the accepted socket */
    void *client_context;        /* client context given to the accepted socket */
    /* AcceptEx() output area: a local and a remote address slot of
     * sizeof(SOCKADDR) + 16 bytes each, matching the lengths passed to AcceptEx(). */
    char addr_buffer[2 * (sizeof(SOCKADDR) + 16)];
    SOCKET acceptor;             /* socket the outstanding AcceptEx() accepts into; 0 when none */
};

57 58 59 60 61 62
/* Maximum number of asynchronous requests that may be outstanding on one socket. */
#define MAX_PENDING_IO 10

/* One slot for an outstanding asynchronous request on a socket. */
struct wsk_pending_io
{
    OVERLAPPED ovr;   /* overlapped block of the Winsock call; ovr.hEvent is what tp_wait waits on */
    TP_WAIT *tp_wait; /* threadpool wait that runs 'callback' once ovr.hEvent is signaled */
    void *callback;   /* callback tp_wait was created with; lets allocate_pending_io() reuse the slot */
    IRP *irp;         /* IRP being served by this slot; NULL when the slot is free */
};

67 68 69 70 71 72
/* Provider-side socket object. Clients only ever see the embedded wsk_socket;
 * wsk_socket_internal_from_wsk_socket() maps it back to this structure. */
struct wsk_socket_internal
{
    WSK_SOCKET wsk_socket;          /* client-visible part, holds the provider dispatch table */
    SOCKET s;                       /* underlying overlapped Winsock socket */
    const void *client_dispatch;    /* client dispatch table supplied at creation/accept */
    void *client_context;           /* client context supplied at creation/accept */
    ULONG flags;                    /* WSK_FLAG_* category the socket was created with */
    ADDRESS_FAMILY address_family;
    USHORT socket_type;
    ULONG protocol;
    BOOL bound;                     /* set by a successful wsk_bind(); wsk_connect() requires it */

    /* Taken around accesses to pending_io and callback_context. */
    CRITICAL_SECTION cs_socket;

    struct wsk_pending_io pending_io[MAX_PENDING_IO];

    /* Category-specific state; only the listen socket variant exists so far. */
    union
    {
        struct listen_socket_callback_context listen_socket_callback_context;
    }
    callback_context;
};

90
/* Winsock extension functions, resolved once via SIO_GET_EXTENSION_FUNCTION_POINTER
 * by init_accept_functions() and init_connect_functions(). */
static LPFN_ACCEPTEX pAcceptEx;
static LPFN_GETACCEPTEXSOCKADDRS pGetAcceptExSockaddrs;
static LPFN_CONNECTEX pConnectEx;

/* Forward declaration: accepted sockets are given this table before its definition. */
static const WSK_PROVIDER_CONNECTION_DISPATCH wsk_provider_connection_dispatch;

96 97 98 99 100
/* Recover the provider object from the WSK_SOCKET pointer a client hands back. */
static inline struct wsk_socket_internal *wsk_socket_internal_from_wsk_socket(WSK_SOCKET *wsk_socket)
{
    struct wsk_socket_internal *internal;

    internal = CONTAINING_RECORD(wsk_socket, struct wsk_socket_internal, wsk_socket);
    return internal;
}

101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
/* Translate a Winsock error code into the NTSTATUS reported to WSK clients.
 * Codes without a mapping are logged and reported as STATUS_UNSUCCESSFUL. */
static NTSTATUS sock_error_to_ntstatus(DWORD err)
{
    static const struct
    {
        DWORD error;
        NTSTATUS status;
    }
    map[] =
    {
        {0,                  STATUS_SUCCESS},
        {WSAEBADF,           STATUS_INVALID_HANDLE},
        {WSAEACCES,          STATUS_ACCESS_DENIED},
        {WSAEFAULT,          STATUS_NO_MEMORY},
        {WSAEINVAL,          STATUS_INVALID_PARAMETER},
        {WSAEMFILE,          STATUS_TOO_MANY_OPENED_FILES},
        {WSAEWOULDBLOCK,     STATUS_CANT_WAIT},
        {WSAEINPROGRESS,     STATUS_PENDING},
        {WSAEALREADY,        STATUS_NETWORK_BUSY},
        {WSAENOTSOCK,        STATUS_OBJECT_TYPE_MISMATCH},
        {WSAEDESTADDRREQ,    STATUS_INVALID_PARAMETER},
        {WSAEMSGSIZE,        STATUS_BUFFER_OVERFLOW},
        {WSAEPROTONOSUPPORT, STATUS_NOT_SUPPORTED},
        {WSAESOCKTNOSUPPORT, STATUS_NOT_SUPPORTED},
        {WSAEPFNOSUPPORT,    STATUS_NOT_SUPPORTED},
        {WSAEAFNOSUPPORT,    STATUS_NOT_SUPPORTED},
        {WSAEPROTOTYPE,      STATUS_NOT_SUPPORTED},
        {WSAENOPROTOOPT,     STATUS_INVALID_PARAMETER},
        {WSAEOPNOTSUPP,      STATUS_NOT_IMPLEMENTED},
        {WSAEADDRINUSE,      STATUS_ADDRESS_ALREADY_ASSOCIATED},
        {WSAEADDRNOTAVAIL,   STATUS_INVALID_PARAMETER},
        {WSAECONNREFUSED,    STATUS_CONNECTION_REFUSED},
        {WSAESHUTDOWN,       STATUS_PIPE_DISCONNECTED},
        {WSAENOTCONN,        STATUS_CONNECTION_DISCONNECTED},
        {WSAETIMEDOUT,       STATUS_IO_TIMEOUT},
        {WSAENETUNREACH,     STATUS_NETWORK_UNREACHABLE},
        {WSAENETDOWN,        STATUS_NETWORK_BUSY},
        {WSAECONNRESET,      STATUS_CONNECTION_RESET},
        {WSAECONNABORTED,    STATUS_CONNECTION_ABORTED},
        {WSAHOST_NOT_FOUND,  STATUS_NOT_FOUND},
    };
    unsigned int idx;

    for (idx = 0; idx < ARRAY_SIZE(map); ++idx)
    {
        if (map[idx].error == err)
            return map[idx].status;
    }

    FIXME("Unmapped error %u.\n", err);
    return STATUS_UNSUCCESSFUL;
}

141 142 143 144 145 146 147 148 149 150
static inline void lock_socket(struct wsk_socket_internal *socket)
{
    EnterCriticalSection(&socket->cs_socket);
}

static inline void unlock_socket(struct wsk_socket_internal *socket)
{
    LeaveCriticalSection(&socket->cs_socket);
}

151
static void socket_init(struct wsk_socket_internal *socket)
152 153 154 155
{
    InitializeCriticalSection(&socket->cs_socket);
}

156 157 158 159 160 161 162 163
/* Complete 'irp' with 'status'. Callers fill in IoStatus.Information first.
 * The current stack location is stepped back by one before completion
 * (presumably undoing the advance made when the IRP was prepared for this
 * call - TODO confirm against the IRP allocation path). */
static void dispatch_irp(IRP *irp, NTSTATUS status)
{
    irp->CurrentLocation--;
    irp->Tail.Overlay.s.u2.CurrentStackLocation--;
    irp->IoStatus.u.Status = status;
    IoCompleteRequest(irp, IO_NO_INCREMENT);
}

164 165 166 167
/* Reserve a pending I/O slot of 'socket' for 'irp'. Callers hold the socket lock.
 *
 * A free slot (irp == NULL) whose threadpool wait was already created for
 * 'socket_async_callback' is reused as is. Otherwise the first free slot gets
 * an event for its OVERLAPPED (if it has none yet) and a new threadpool wait
 * bound to 'socket_async_callback', replacing any previous wait.
 *
 * Returns NULL, without touching the slot state or the IRP, when every slot is
 * busy or the event/threadpool wait cannot be created. */
static struct wsk_pending_io *allocate_pending_io(struct wsk_socket_internal *socket,
        PTP_WAIT_CALLBACK socket_async_callback, IRP *irp)
{
    struct wsk_pending_io *io = socket->pending_io, *slot;
    unsigned int i, io_index;
    BOOL event_created = FALSE;
    TP_WAIT *wait;

    io_index = ~0u;
    for (i = 0; i < ARRAY_SIZE(socket->pending_io); ++i)
    {
        if (!io[i].irp)
        {
            if (io[i].callback == socket_async_callback)
            {
                io[i].irp = irp;
                return &io[i];
            }

            if (io_index == ~0u)
                io_index = i;
        }
    }

    if (io_index == ~0u)
    {
        FIXME("Pending io requests count exceeds limit.\n");
        return NULL;
    }

    slot = &io[io_index];

    if (!slot->ovr.hEvent)
    {
        if (!(slot->ovr.hEvent = CreateEventA(NULL, FALSE, FALSE, NULL)))
        {
            ERR("Could not create event, error %u.\n", GetLastError());
            return NULL;
        }
        event_created = TRUE;
    }

    /* Create the new wait before dropping the old one so that a failure leaves
     * the slot exactly as it was (event and wait either both present or both absent). */
    if (!(wait = CreateThreadpoolWait(socket_async_callback, socket, NULL)))
    {
        ERR("Could not create threadpool wait, error %u.\n", GetLastError());
        if (event_created)
        {
            CloseHandle(slot->ovr.hEvent);
            slot->ovr.hEvent = NULL;
        }
        return NULL;
    }

    if (slot->tp_wait)
        CloseThreadpoolWait(slot->tp_wait);

    slot->tp_wait = wait;
    slot->callback = socket_async_callback;
    slot->irp = irp;
    return slot;
}

/* Return the pending I/O slot whose threadpool wait is 'tp_wait', or NULL
 * (after logging a FIXME) if no slot uses it. */
static struct wsk_pending_io *find_pending_io(struct wsk_socket_internal *socket, TP_WAIT *tp_wait)
{
    struct wsk_pending_io *cur = socket->pending_io;
    struct wsk_pending_io *end = cur + ARRAY_SIZE(socket->pending_io);

    for (; cur != end; ++cur)
    {
        if (cur->tp_wait == tp_wait)
            return cur;
    }

    FIXME("Pending io not found for tp_wait %p.\n", tp_wait);
    return NULL;
}

/* Complete the IRP held by slot 'io' with 'status' / 'information', then mark
 * the slot free again. */
static void dispatch_pending_io(struct wsk_pending_io *io, NTSTATUS status, ULONG_PTR information)
{
    IRP *pending_irp = io->irp;

    TRACE("io %p, status %#x, information %#lx.\n", io, status, information);

    pending_irp->IoStatus.Information = information;
    dispatch_irp(pending_irp, status);
    io->irp = NULL;
}

228 229 230 231 232 233 234 235 236 237 238 239 240 241
/* WskControlSocket: not implemented; logs the request and leaves the IRP untouched. */
static NTSTATUS WINAPI wsk_control_socket(WSK_SOCKET *socket, WSK_CONTROL_SOCKET_TYPE request_type,
        ULONG control_code, ULONG level, SIZE_T input_size, void *input_buffer, SIZE_T output_size,
        void *output_buffer, SIZE_T *output_size_returned, IRP *irp)
{
    NTSTATUS status = STATUS_NOT_IMPLEMENTED;

    FIXME("socket %p, request_type %u, control_code %#x, level %u, "
            "input_size %lu, input_buffer %p, output_size %lu, output_buffer %p, "
            "output_size_returned %p, irp %p stub.\n",
            socket, request_type, control_code, level,
            input_size, input_buffer, output_size, output_buffer,
            output_size_returned, irp);
    return status;
}

/* WskCloseSocket: cancel every outstanding request (its IRP is completed with
 * STATUS_CANCELLED), close a listen socket's pending acceptor, close the Winsock
 * socket and free the provider object. 'irp' is then completed with the
 * closesocket() result.
 *
 * Returns STATUS_PENDING on success, otherwise the failure status the IRP was
 * completed with. */
static NTSTATUS WINAPI wsk_close_socket(WSK_SOCKET *socket, IRP *irp)
{
    struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(socket);
    NTSTATUS status;
    unsigned int i;

    TRACE("socket %p, irp %p.\n", socket, irp);

    lock_socket(s);

    for (i = 0; i < ARRAY_SIZE(s->pending_io); ++i)
    {
        struct wsk_pending_io *io = &s->pending_io[i];

        if (io->tp_wait)
        {
            CancelIoEx((HANDLE)s->s, &io->ovr);
            SetThreadpoolWait(io->tp_wait, NULL, NULL);
            /* An already running callback takes the socket lock; release it while
             * waiting so that callback can finish. */
            unlock_socket(s);
            WaitForThreadpoolWaitCallbacks(io->tp_wait, FALSE);
            lock_socket(s);
            CloseThreadpoolWait(io->tp_wait);
            CloseHandle(io->ovr.hEvent);
        }

        if (io->irp)
            dispatch_pending_io(io, STATUS_CANCELLED, 0);
    }

    if (s->flags & WSK_FLAG_LISTEN_SOCKET && s->callback_context.listen_socket_callback_context.acceptor)
        closesocket(s->callback_context.listen_socket_callback_context.acceptor);

    status = closesocket(s->s) ? sock_error_to_ntstatus(WSAGetLastError()) : STATUS_SUCCESS;

    unlock_socket(s);
    DeleteCriticalSection(&s->cs_socket);
    /* Free the allocation itself, not the embedded WSK_SOCKET: the two pointers
     * only coincide while wsk_socket happens to be the first member. */
    heap_free(s);

    irp->IoStatus.Information = 0;
    dispatch_irp(irp, status);

    return status ? status : STATUS_PENDING;
}

/* WskBind: bind the socket to 'local_address'. Listen sockets also start
 * listening with a SOMAXCONN backlog. On success the socket is marked bound,
 * which wsk_connect() requires.
 *
 * Returns STATUS_INVALID_PARAMETER without touching anything when 'irp' is
 * NULL. Otherwise the IRP is completed with the result and STATUS_PENDING is returned. */
static NTSTATUS WINAPI wsk_bind(WSK_SOCKET *socket, SOCKADDR *local_address, ULONG flags, IRP *irp)
{
    struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(socket);
    NTSTATUS status;
    int address_len;

    TRACE("socket %p, local_address %p, flags %#x, irp %p.\n",
            socket, local_address, flags, irp);

    if (!irp)
        return STATUS_INVALID_PARAMETER;

    /* sizeof(SOCKADDR) is shorter than an IPv6 address; bind() must be given
     * the full sockaddr_in6 length for AF_INET6 or it fails. */
    if (local_address && local_address->sa_family == AF_INET6)
        address_len = (int)sizeof(SOCKADDR_IN6);
    else
        address_len = (int)sizeof(*local_address);

    if (bind(s->s, local_address, address_len))
        status = sock_error_to_ntstatus(WSAGetLastError());
    else if (s->flags & WSK_FLAG_LISTEN_SOCKET && listen(s->s, SOMAXCONN))
        status = sock_error_to_ntstatus(WSAGetLastError());
    else
        status = STATUS_SUCCESS;

    if (status == STATUS_SUCCESS)
        s->bound = TRUE;

    TRACE("status %#x.\n", status);
    irp->IoStatus.Information = 0;
    dispatch_irp(irp, status);
    return STATUS_PENDING;
}

311
static void create_accept_socket(struct wsk_socket_internal *socket, struct wsk_pending_io *io)
312 313 314
{
    struct listen_socket_callback_context *context
            = &socket->callback_context.listen_socket_callback_context;
315 316
    INT local_address_len, remote_address_len;
    SOCKADDR *local_address, *remote_address;
317 318 319 320 321
    struct wsk_socket_internal *accept_socket;

    if (!(accept_socket = heap_alloc_zero(sizeof(*accept_socket))))
    {
        ERR("No memory.\n");
322
        dispatch_pending_io(io, STATUS_NO_MEMORY, 0);
323 324 325 326 327 328 329 330 331 332 333 334
    }
    else
    {
        TRACE("accept_socket %p.\n", accept_socket);
        accept_socket->wsk_socket.Dispatch = &wsk_provider_connection_dispatch;
        accept_socket->s = context->acceptor;
        accept_socket->client_dispatch = context->client_dispatch;
        accept_socket->client_context = context->client_context;
        accept_socket->socket_type = socket->socket_type;
        accept_socket->address_family = socket->address_family;
        accept_socket->protocol = socket->protocol;
        accept_socket->flags = WSK_FLAG_CONNECTION_SOCKET;
335
        socket_init(accept_socket);
336 337 338 339 340 341 342 343 344 345 346

        pGetAcceptExSockaddrs(context->addr_buffer, 0, sizeof(SOCKADDR) + 16, sizeof(SOCKADDR) + 16,
                &local_address, &local_address_len, &remote_address, &remote_address_len);

        if (context->local_address)
            memcpy(context->local_address, local_address,
                    min(sizeof(*context->local_address), local_address_len));

        if (context->remote_address)
            memcpy(context->remote_address, remote_address,
                    min(sizeof(*context->remote_address), remote_address_len));
347

348
        dispatch_pending_io(io, STATUS_SUCCESS, (ULONG_PTR)&accept_socket->wsk_socket);
349 350 351 352 353 354 355 356
    }
}

/* Threadpool wait callback for an AcceptEx() that returned ERROR_IO_PENDING.
 * On success the accepted socket is handed to create_accept_socket(). On
 * failure the acceptor is closed and the IRP is completed with the status
 * stored in the OVERLAPPED. */
static void WINAPI accept_callback(TP_CALLBACK_INSTANCE *instance, void *socket_, TP_WAIT *wait,
        TP_WAIT_RESULT wait_result)
{
    struct listen_socket_callback_context *context;
    struct wsk_socket_internal *socket = socket_;
    struct wsk_pending_io *io;
    DWORD size;

    TRACE("instance %p, socket %p, wait %p, wait_result %#x.\n", instance, socket, wait, wait_result);

    lock_socket(socket);
    context = &socket->callback_context.listen_socket_callback_context;

    /* find_pending_io() reports a missing slot by returning NULL; nothing to complete then. */
    if (!(io = find_pending_io(socket, wait)))
    {
        unlock_socket(socket);
        return;
    }

    if (GetOverlappedResult((HANDLE)socket->s, &io->ovr, &size, FALSE))
    {
        create_accept_socket(socket, io);
    }
    else
    {
        closesocket(context->acceptor);
        context->acceptor = 0;
        dispatch_pending_io(io, io->ovr.Internal, 0);
    }
    unlock_socket(socket);
}

/* InitOnce callback: resolve the AcceptEx() and GetAcceptExSockaddrs()
 * extension functions through the socket passed in 'param'.
 * Returns FALSE (after logging which lookup failed) if either is unavailable. */
static BOOL WINAPI init_accept_functions(INIT_ONCE *once, void *param, void **context)
{
    GUID get_acceptex_guid = WSAID_GETACCEPTEXSOCKADDRS;
    GUID acceptex_guid = WSAID_ACCEPTEX;
    SOCKET s = (SOCKET)param;
    DWORD size;

    if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &acceptex_guid, sizeof(acceptex_guid),
            &pAcceptEx, sizeof(pAcceptEx), &size, NULL, NULL))
    {
        ERR("Could not get AcceptEx address, error %u.\n", WSAGetLastError());
        return FALSE;
    }

    if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &get_acceptex_guid, sizeof(get_acceptex_guid),
            &pGetAcceptExSockaddrs, sizeof(pGetAcceptExSockaddrs), &size, NULL, NULL))
    {
        ERR("Could not get GetAcceptExSockaddrs address, error %u.\n", WSAGetLastError());
        return FALSE;
    }

    return TRUE;
}

403 404 405 406
/* WskAccept: accept one incoming connection on a listen socket.
 *
 * A new overlapped socket is created as the AcceptEx() target. The IRP is
 * completed here on immediate success or failure, otherwise from
 * accept_callback(). On success IoStatus.Information is the new WSK_SOCKET.
 *
 * Returns STATUS_INVALID_PARAMETER (IRP untouched) for a NULL 'irp',
 * STATUS_PENDING otherwise. */
static NTSTATUS WINAPI wsk_accept(WSK_SOCKET *listen_socket, ULONG flags, void *accept_socket_context,
        const WSK_CLIENT_CONNECTION_DISPATCH *accept_socket_dispatch, SOCKADDR *local_address,
        SOCKADDR *remote_address, IRP *irp)
{
    struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(listen_socket);
    static INIT_ONCE init_once = INIT_ONCE_STATIC_INIT;
    struct listen_socket_callback_context *context;
    struct wsk_pending_io *io;
    SOCKET acceptor;
    DWORD size;
    int error;

    TRACE("listen_socket %p, flags %#x, accept_socket_context %p, accept_socket_dispatch %p, "
            "local_address %p, remote_address %p, irp %p.\n",
            listen_socket, flags, accept_socket_context, accept_socket_dispatch, local_address,
            remote_address, irp);

    if (!irp)
        return STATUS_INVALID_PARAMETER;

    if (!InitOnceExecuteOnce(&init_once, init_accept_functions, (void *)s->s, NULL))
    {
        irp->IoStatus.Information = 0;
        dispatch_irp(irp, STATUS_UNSUCCESSFUL);
        return STATUS_PENDING;
    }

    lock_socket(s);
    if (!(io = allocate_pending_io(s, accept_callback, irp)))
    {
        irp->IoStatus.Information = 0;
        dispatch_irp(irp, STATUS_UNSUCCESSFUL);
        unlock_socket(s);
        return STATUS_PENDING;
    }

    context = &s->callback_context.listen_socket_callback_context;
    if ((acceptor = WSASocketW(s->address_family, s->socket_type, s->protocol, NULL, 0, WSA_FLAG_OVERLAPPED))
            == INVALID_SOCKET)
    {
        dispatch_pending_io(io, sock_error_to_ntstatus(WSAGetLastError()), 0);
        unlock_socket(s);
        return STATUS_PENDING;
    }

    /* NOTE(review): there is a single accept context per listen socket, so a
     * second accept issued while one is still pending overwrites this state. */
    context->local_address = local_address;
    context->remote_address = remote_address;
    context->client_dispatch = accept_socket_dispatch;
    context->client_context = accept_socket_context;
    context->acceptor = acceptor;

    if (pAcceptEx(s->s, acceptor, context->addr_buffer, 0,
            sizeof(SOCKADDR) + 16, sizeof(SOCKADDR) + 16, &size, &io->ovr))
    {
        create_accept_socket(s, io);
    }
    else if ((error = WSAGetLastError()) == WSA_IO_PENDING)
    {
        SetThreadpoolWait(io->tp_wait, io->ovr.hEvent, NULL);
    }
    else
    {
        closesocket(acceptor);
        context->acceptor = 0;
        dispatch_pending_io(io, sock_error_to_ntstatus(error), 0);
    }
    unlock_socket(s);

    return STATUS_PENDING;
}

/* WskInspectComplete: not implemented; logs the call and leaves the IRP untouched. */
static NTSTATUS WINAPI wsk_inspect_complete(WSK_SOCKET *listen_socket, WSK_INSPECT_ID *inspect_id,
        WSK_INSPECT_ACTION action, IRP *irp)
{
    NTSTATUS status = STATUS_NOT_IMPLEMENTED;

    FIXME("listen_socket %p, inspect_id %p, action %u, irp %p stub.\n", listen_socket, inspect_id,
            action, irp);
    return status;
}

/* WskGetLocalAddress: not implemented; logs the call and leaves the IRP untouched. */
static NTSTATUS WINAPI wsk_get_local_address(WSK_SOCKET *socket, SOCKADDR *local_address, IRP *irp)
{
    NTSTATUS status = STATUS_NOT_IMPLEMENTED;

    FIXME("socket %p, local_address %p, irp %p stub.\n",
            socket, local_address, irp);
    return status;
}

/* Provider dispatch table installed by wsk_socket() for WSK_FLAG_LISTEN_SOCKET
 * sockets. The initializer is positional: entries must stay in the
 * WSK_PROVIDER_LISTEN_DISPATCH field order (inner braces: the common
 * control/close entries). */
static const WSK_PROVIDER_LISTEN_DISPATCH wsk_provider_listen_dispatch =
{
    {
        wsk_control_socket,
        wsk_close_socket,
    },
    wsk_bind,
    wsk_accept,
    wsk_inspect_complete,
    wsk_get_local_address,
};

501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532
/* Threadpool wait callback for a ConnectEx() that returned ERROR_IO_PENDING:
 * complete the IRP with the status stored in the OVERLAPPED. */
static void WINAPI connect_callback(TP_CALLBACK_INSTANCE *instance, void *socket_, TP_WAIT *wait,
        TP_WAIT_RESULT wait_result)
{
    struct wsk_socket_internal *socket = socket_;
    struct wsk_pending_io *io;
    DWORD size;

    TRACE("instance %p, socket %p, wait %p, wait_result %#x.\n", instance, socket, wait, wait_result);

    lock_socket(socket);
    /* find_pending_io() reports a missing slot by returning NULL; nothing to complete then. */
    if ((io = find_pending_io(socket, wait)))
    {
        /* The return value is not used: the IRP status is taken from ovr.Internal. */
        GetOverlappedResult((HANDLE)socket->s, &io->ovr, &size, FALSE);
        dispatch_pending_io(io, io->ovr.Internal, 0);
    }
    unlock_socket(socket);
}

/* InitOnce callback: resolve the ConnectEx() extension function through the
 * socket passed in 'param'. Returns FALSE if it is unavailable. */
static BOOL WINAPI init_connect_functions(INIT_ONCE *once, void *param, void **context)
{
    GUID connectex_guid = WSAID_CONNECTEX;
    SOCKET s = (SOCKET)param;
    DWORD size;

    if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &connectex_guid, sizeof(connectex_guid),
            &pConnectEx, sizeof(pConnectEx), &size, NULL, NULL))
    {
        ERR("Could not get ConnectEx address, error %u.\n", WSAGetLastError());
        return FALSE;
    }
    return TRUE;
}

533 534
/* WskConnect: connect a bound connection socket to 'remote_address' using ConnectEx().
 *
 * Returns STATUS_INVALID_PARAMETER (IRP untouched) for a NULL 'irp'.
 * Returns STATUS_INVALID_DEVICE_STATE, with the IRP completed with that
 * status, when wsk_bind() has not succeeded first. Otherwise the IRP is
 * completed here on immediate success or failure, or later from
 * connect_callback(), and STATUS_PENDING is returned. */
static NTSTATUS WINAPI wsk_connect(WSK_SOCKET *socket, SOCKADDR *remote_address, ULONG flags, IRP *irp)
{
    struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(socket);
    static INIT_ONCE init_once = INIT_ONCE_STATIC_INIT;
    struct wsk_pending_io *io;
    int address_len;
    int error;

    TRACE("socket %p, remote_address %p, flags %#x, irp %p.\n",
            socket, remote_address, flags, irp);

    if (!irp)
        return STATUS_INVALID_PARAMETER;

    if (!InitOnceExecuteOnce(&init_once, init_connect_functions, (void *)s->s, NULL))
    {
        irp->IoStatus.Information = 0;
        dispatch_irp(irp, STATUS_UNSUCCESSFUL);
        return STATUS_PENDING;
    }

    lock_socket(s);

    if (!(io = allocate_pending_io(s, connect_callback, irp)))
    {
        irp->IoStatus.Information = 0;
        dispatch_irp(irp, STATUS_UNSUCCESSFUL);
        unlock_socket(s);
        return STATUS_PENDING;
    }

    if (!s->bound)
    {
        dispatch_pending_io(io, STATUS_INVALID_DEVICE_STATE, 0);
        unlock_socket(s);
        return STATUS_INVALID_DEVICE_STATE;
    }

    /* sizeof(SOCKADDR) is shorter than an IPv6 address; ConnectEx() must be
     * given the full sockaddr_in6 length for AF_INET6. */
    if (remote_address && remote_address->sa_family == AF_INET6)
        address_len = (int)sizeof(SOCKADDR_IN6);
    else
        address_len = (int)sizeof(*remote_address);

    if (pConnectEx(s->s, remote_address, address_len, NULL, 0, NULL, &io->ovr))
        dispatch_pending_io(io, STATUS_SUCCESS, 0);
    else if ((error = WSAGetLastError()) == ERROR_IO_PENDING)
        SetThreadpoolWait(io->tp_wait, io->ovr.hEvent, NULL);
    else
        dispatch_pending_io(io, sock_error_to_ntstatus(error), 0);

    unlock_socket(s);

    return STATUS_PENDING;
}

/* WskGetRemoteAddress: not implemented; logs the call and leaves the IRP untouched. */
static NTSTATUS WINAPI wsk_get_remote_address(WSK_SOCKET *socket, SOCKADDR *remote_address, IRP *irp)
{
    NTSTATUS status = STATUS_NOT_IMPLEMENTED;

    FIXME("socket %p, remote_address %p, irp %p stub.\n",
            socket, remote_address, irp);
    return status;
}

588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607
/* Threadpool wait callback for a WSASend()/WSARecv() that returned
 * WSA_IO_PENDING. On success the IRP is completed with the transferred byte
 * count; on failure, with the status stored in the OVERLAPPED. */
static void WINAPI send_receive_callback(TP_CALLBACK_INSTANCE *instance, void *socket_, TP_WAIT *wait,
        TP_WAIT_RESULT wait_result)
{
    struct wsk_socket_internal *socket = socket_;
    struct wsk_pending_io *io;
    DWORD length, flags;

    TRACE("instance %p, socket %p, wait %p, wait_result %#x.\n", instance, socket, wait, wait_result);

    lock_socket(socket);
    /* find_pending_io() reports a missing slot by returning NULL; nothing to complete then. */
    if ((io = find_pending_io(socket, wait)))
    {
        if (WSAGetOverlappedResult(socket->s, &io->ovr, &length, FALSE, &flags))
            dispatch_pending_io(io, STATUS_SUCCESS, length);
        else
            dispatch_pending_io(io, io->ovr.Internal, 0);
    }
    unlock_socket(socket);
}

608
/* Shared implementation of WskSend and WskReceive. Issues an overlapped
 * WSASend() ('is_send' TRUE) or WSARecv() on the single MDL described by
 * 'wsk_buf'. Chained MDLs and non-zero 'flags' are not supported.
 *
 * Returns STATUS_INVALID_PARAMETER, without completing the IRP, for a NULL
 * 'irp' or a non-empty buffer without an MDL. Otherwise returns STATUS_PENDING.
 * By then the IRP has either been completed (immediate success or failure) or
 * will be completed by send_receive_callback(). */
static NTSTATUS do_send_receive(WSK_SOCKET *socket, WSK_BUF *wsk_buf, ULONG flags, IRP *irp, BOOL is_send)
{
    struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(socket);
    struct wsk_pending_io *pending;
    DWORD transferred, io_flags = 0;
    WSABUF buf;
    MDL *mdl;
    int rc, err;

    TRACE("socket %p, buffer %p, flags %#x, irp %p, is_send %#x.\n",
            socket, wsk_buf, flags, irp, is_send);

    if (!irp)
        return STATUS_INVALID_PARAMETER;

    mdl = wsk_buf->Mdl;
    if (!mdl && wsk_buf->Length)
        return STATUS_INVALID_PARAMETER;

    if (mdl && mdl->Next)
    {
        FIXME("Chained MDLs are not supported.\n");
        irp->IoStatus.Information = 0;
        dispatch_irp(irp, STATUS_UNSUCCESSFUL);
        return STATUS_PENDING;
    }

    if (flags)
        FIXME("flags %#x not implemented.\n", flags);

    lock_socket(s);

    pending = allocate_pending_io(s, send_receive_callback, irp);
    if (!pending)
    {
        irp->IoStatus.Information = 0;
        dispatch_irp(irp, STATUS_UNSUCCESSFUL);
        unlock_socket(s);
        return STATUS_PENDING;
    }

    buf.len = wsk_buf->Length;
    buf.buf = NULL;
    if (mdl)
        buf.buf = (CHAR *)mdl->StartVa + mdl->ByteOffset + wsk_buf->Offset;

    if (is_send)
        rc = WSASend(s->s, &buf, 1, &transferred, io_flags, &pending->ovr, NULL);
    else
        rc = WSARecv(s->s, &buf, 1, &transferred, &io_flags, &pending->ovr, NULL);

    if (!rc)
        dispatch_pending_io(pending, STATUS_SUCCESS, transferred);
    else if ((err = WSAGetLastError()) == WSA_IO_PENDING)
        SetThreadpoolWait(pending->tp_wait, pending->ovr.hEvent, NULL);
    else
        dispatch_pending_io(pending, sock_error_to_ntstatus(err), 0);

    unlock_socket(s);
    return STATUS_PENDING;
}

670 671
/* WskSend: send 'buffer' on a connection socket (see do_send_receive()). */
static NTSTATUS WINAPI wsk_send(WSK_SOCKET *socket, WSK_BUF *buffer, ULONG flags, IRP *irp)
{
    const BOOL sending = TRUE;

    TRACE("socket %p, buffer %p, flags %#x, irp %p.\n", socket, buffer, flags, irp);
    return do_send_receive(socket, buffer, flags, irp, sending);
}

/* WskReceive: receive into 'buffer' on a connection socket (see do_send_receive()). */
static NTSTATUS WINAPI wsk_receive(WSK_SOCKET *socket, WSK_BUF *buffer, ULONG flags, IRP *irp)
{
    const BOOL sending = FALSE;

    TRACE("socket %p, buffer %p, flags %#x, irp %p.\n", socket, buffer, flags, irp);
    return do_send_receive(socket, buffer, flags, irp, sending);
}

/* WskDisconnect: not implemented; logs the call and leaves the IRP untouched. */
static NTSTATUS WINAPI wsk_disconnect(WSK_SOCKET *socket, WSK_BUF *buffer, ULONG flags, IRP *irp)
{
    NTSTATUS status = STATUS_NOT_IMPLEMENTED;

    FIXME("socket %p, buffer %p, flags %#x, irp %p stub.\n",
            socket, buffer, flags, irp);
    return status;
}

/* WskRelease: not implemented; only logs the call. */
static NTSTATUS WINAPI wsk_release(WSK_SOCKET *socket, WSK_DATA_INDICATION *data_indication)
{
    NTSTATUS status = STATUS_NOT_IMPLEMENTED;

    FIXME("socket %p, data_indication %p stub.\n",
            socket, data_indication);
    return status;
}

/* WskConnectEx slot: not implemented; logs the call and leaves the IRP untouched. */
static NTSTATUS WINAPI wsk_connext_ex(WSK_SOCKET *socket, SOCKADDR *remote_address, WSK_BUF *buffer,
        ULONG flags, IRP *irp)
{
    NTSTATUS status = STATUS_NOT_IMPLEMENTED;

    FIXME("socket %p, remote_address %p, buffer %p, flags %#x, irp %p stub.\n",
            socket, remote_address, buffer, flags, irp);
    return status;
}

/* Placeholder for the WskSendEx dispatch slot. Its real prototype is not
 * declared here, so the argument list does not match what callers pass. */
static NTSTATUS WINAPI wsk_send_ex(void)
{
    NTSTATUS status = STATUS_NOT_IMPLEMENTED;

    FIXME("stub (no prototype, will crash).\n");
    return status;
}

/* Placeholder for the WskReceiveEx dispatch slot. Its real prototype is not
 * declared here, so the argument list does not match what callers pass. */
static NTSTATUS WINAPI wsk_receive_ex(void)
{
    NTSTATUS status = STATUS_NOT_IMPLEMENTED;

    FIXME("stub (no prototype, will crash).\n");
    return status;
}

/* Provider dispatch table for WSK_FLAG_CONNECTION_SOCKET sockets, used both by
 * wsk_socket() and for sockets produced by create_accept_socket(). The
 * initializer is positional: entries must stay in the
 * WSK_PROVIDER_CONNECTION_DISPATCH field order (inner braces: the common
 * control/close entries). */
static const WSK_PROVIDER_CONNECTION_DISPATCH wsk_provider_connection_dispatch =
{
    {
        wsk_control_socket,
        wsk_close_socket,
    },
    wsk_bind,
    wsk_connect,
    wsk_get_local_address,
    wsk_get_remote_address,
    wsk_send,
    wsk_receive,
    wsk_disconnect,
    wsk_release,
    wsk_connext_ex,
    wsk_send_ex,
    wsk_receive_ex,
};

740
/* WskSocket: create an overlapped Winsock socket and wrap it in a provider
 * socket object. Only listen and connection sockets are supported.
 *
 * NULL 'irp' or 'client' fail immediately without touching the IRP. Otherwise
 * the IRP is always completed: on success IoStatus.Information is the new
 * WSK_SOCKET and STATUS_PENDING is returned; on failure the status the IRP was
 * completed with is returned. */
static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_family, USHORT socket_type,
        ULONG protocol, ULONG flags, void *socket_context, const void *dispatch, PEPROCESS owning_process,
        PETHREAD owning_thread, SECURITY_DESCRIPTOR *security_descriptor, IRP *irp)
{
    struct wsk_socket_internal *obj;
    const void *provider_dispatch;
    NTSTATUS status;
    SOCKET sock;

    TRACE("client %p, address_family %#x, socket_type %#x, protocol %#x, flags %#x, socket_context %p, dispatch %p, "
            "owning_process %p, owning_thread %p, security_descriptor %p, irp %p.\n",
            client, address_family, socket_type, protocol, flags, socket_context, dispatch, owning_process,
            owning_thread, security_descriptor, irp);

    if (!irp)
        return STATUS_INVALID_PARAMETER;
    if (!client)
        return STATUS_INVALID_HANDLE;

    irp->IoStatus.Information = 0;

    sock = WSASocketW(address_family, socket_type, protocol, NULL, 0, WSA_FLAG_OVERLAPPED);
    if (sock == INVALID_SOCKET)
    {
        status = sock_error_to_ntstatus(WSAGetLastError());
        goto done;
    }

    obj = heap_alloc_zero(sizeof(*obj));
    if (!obj)
    {
        status = STATUS_NO_MEMORY;
        closesocket(sock);
        goto done;
    }

    obj->s = sock;
    obj->client_dispatch = dispatch;
    obj->client_context = socket_context;
    obj->socket_type = socket_type;
    obj->flags = flags;
    obj->address_family = address_family;
    obj->protocol = protocol;

    if (flags == WSK_FLAG_LISTEN_SOCKET)
    {
        provider_dispatch = &wsk_provider_listen_dispatch;
    }
    else if (flags == WSK_FLAG_CONNECTION_SOCKET)
    {
        provider_dispatch = &wsk_provider_connection_dispatch;
    }
    else
    {
        FIXME("Flags %#x not implemented.\n", flags);
        closesocket(sock);
        heap_free(obj);
        status = STATUS_NOT_IMPLEMENTED;
        goto done;
    }
    obj->wsk_socket.Dispatch = provider_dispatch;

    socket_init(obj);

    irp->IoStatus.Information = (ULONG_PTR)&obj->wsk_socket;
    status = STATUS_SUCCESS;

done:
    dispatch_irp(irp, status);
    return status ? status : STATUS_PENDING;
}

/* WskSocketConnect: not implemented; logs the call and leaves the IRP untouched. */
static NTSTATUS WINAPI wsk_socket_connect(WSK_CLIENT *client, USHORT socket_type, ULONG protocol,
        SOCKADDR *local_address, SOCKADDR *remote_address, ULONG flags, void *socket_context,
        const WSK_CLIENT_CONNECTION_DISPATCH *dispatch, PEPROCESS owning_process, PETHREAD owning_thread,
        SECURITY_DESCRIPTOR *security_descriptor, IRP *irp)
{
    NTSTATUS status = STATUS_NOT_IMPLEMENTED;

    FIXME("client %p, socket_type %#x, protocol %#x, "
            "local_address %p, remote_address %p, flags %#x, socket_context %p, "
            "dispatch %p, owning_process %p, owning_thread %p, security_descriptor %p, irp %p stub.\n",
            client, socket_type, protocol,
            local_address, remote_address, flags, socket_context,
            dispatch, owning_process, owning_thread, security_descriptor, irp);
    return status;
}

/* WskControlClient: not implemented; logs the call and leaves the IRP untouched. */
static NTSTATUS WINAPI wsk_control_client(WSK_CLIENT *client, ULONG control_code, SIZE_T input_size,
        void *input_buffer, SIZE_T output_size, void *output_buffer, SIZE_T *output_size_returned,
        IRP *irp)
{
    NTSTATUS status = STATUS_NOT_IMPLEMENTED;

    FIXME("client %p, control_code %#x, input_size %lu, "
            "input_buffer %p, output_size %lu, output_buffer %p, "
            "output_size_returned %p, irp %p, stub.\n",
            client, control_code, input_size,
            input_buffer, output_size, output_buffer,
            output_size_returned, irp);
    return status;
}

837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863
/* Arguments of one wsk_get_address_info() call, carried to
 * get_address_info_callback(). The callback runs GetAddrInfoExW() on a
 * threadpool thread, completes 'irp' and frees this context. */
struct wsk_get_address_info_context
{
    UNICODE_STRING *node_name;    /* may be NULL */
    UNICODE_STRING *service_name; /* may be NULL */
    ULONG namespace;
    GUID *provider;
    ADDRINFOEXW *hints;
    ADDRINFOEXW **result;         /* receives the list; released with wsk_free_address_info() */
    IRP *irp;                     /* completed by the callback */
};

/* Threadpool worker for wsk_get_address_info(): run the blocking
 * GetAddrInfoExW(), complete the IRP with the mapped result and free the context.
 * NOTE(review): UNICODE_STRING buffers are counted strings and are not
 * guaranteed to be NUL-terminated, but they are passed on as C strings here. */
static void WINAPI get_address_info_callback(TP_CALLBACK_INSTANCE *instance, void *context_)
{
    struct wsk_get_address_info_context *context = context_;
    const WCHAR *node = NULL, *service = NULL;
    INT rc;

    TRACE("instance %p, context %p.\n", instance, context);

    if (context->node_name)
        node = context->node_name->Buffer;
    if (context->service_name)
        service = context->service_name->Buffer;

    rc = GetAddrInfoExW(node, service, context->namespace, context->provider, context->hints,
            context->result, NULL, NULL, NULL, NULL);

    context->irp->IoStatus.Information = 0;
    dispatch_irp(context->irp, sock_error_to_ntstatus(rc));
    heap_free(context);
}

864 865 866 867
/* WskGetAddressInfo: resolve names asynchronously by submitting
 * get_address_info_callback() to the threadpool.
 *
 * Returns STATUS_INVALID_PARAMETER (IRP untouched) for a NULL 'irp' and
 * STATUS_PENDING once the callback is queued. If the context cannot be
 * allocated or the callback cannot be submitted, the IRP is completed with the
 * failure status, which is also returned. */
static NTSTATUS WINAPI wsk_get_address_info(WSK_CLIENT *client, UNICODE_STRING *node_name,
        UNICODE_STRING *service_name, ULONG name_space, GUID *provider, ADDRINFOEXW *hints,
        ADDRINFOEXW **result, PEPROCESS owning_process, PETHREAD owning_thread, IRP *irp)
{
    struct wsk_get_address_info_context *context;
    NTSTATUS status;

    TRACE("client %p, node_name %p, service_name %p, name_space %#x, provider %p, hints %p, "
            "result %p, owning_process %p, owning_thread %p, irp %p.\n",
            client, node_name, service_name, name_space, provider, hints, result,
            owning_process, owning_thread, irp);

    if (!irp)
        return STATUS_INVALID_PARAMETER;

    if (!(context = heap_alloc(sizeof(*context))))
    {
        ERR("No memory.\n");
        status = STATUS_NO_MEMORY;
        /* Every other completion path sets Information; do not leave it stale. */
        irp->IoStatus.Information = 0;
        dispatch_irp(irp, status);
        return status;
    }

    context->node_name = node_name;
    context->service_name = service_name;
    context->namespace = name_space;
    context->provider = provider;
    context->hints = hints;
    context->result = result;
    context->irp = irp;

    if (!TrySubmitThreadpoolCallback(get_address_info_callback, context, NULL))
    {
        ERR("Could not submit thread pool callback.\n");
        status = STATUS_UNSUCCESSFUL;
        irp->IoStatus.Information = 0;
        dispatch_irp(irp, status);
        heap_free(context);
        return status;
    }
    TRACE("Submitted threadpool callback, context %p.\n", context);
    return STATUS_PENDING;
}

/* WskFreeAddressInfo: release a result list produced by wsk_get_address_info(). */
static void WINAPI wsk_free_address_info(WSK_CLIENT *client, ADDRINFOEXW *addr_info)
{
    ADDRINFOEXW *list = addr_info;

    TRACE("client %p, addr_info %p.\n", client, list);
    FreeAddrInfoExW(list);
}

/* WskGetNameInfo: not implemented; logs the call and leaves the IRP untouched. */
static NTSTATUS WINAPI wsk_get_name_info(WSK_CLIENT *client, SOCKADDR *sock_addr, ULONG sock_addr_length,
        UNICODE_STRING *node_name, UNICODE_STRING *service_name, ULONG flags, PEPROCESS owning_process,
        PETHREAD owning_thread, IRP *irp)
{
    NTSTATUS status = STATUS_NOT_IMPLEMENTED;

    FIXME("client %p, sock_addr %p, sock_addr_length %u, "
            "node_name %p, service_name %p, flags %#x, "
            "owning_process %p, owning_thread %p, irp %p stub.\n",
            client, sock_addr, sock_addr_length,
            node_name, service_name, flags,
            owning_process, owning_thread, irp);
    return status;
}

/* Client-level provider dispatch table handed out by WskCaptureProviderNPI().
 * It advertises WSK version 1.0. The initializer is positional and must
 * follow the WSK_PROVIDER_DISPATCH field order. */
static const WSK_PROVIDER_DISPATCH wsk_dispatch =
{
    MAKE_WSK_VERSION(1, 0), 0,
    wsk_socket,
    wsk_socket_connect,
    wsk_control_client,
    wsk_get_address_info,
    wsk_free_address_info,
    wsk_get_name_info,
};

937 938 939
/* Give the client the provider dispatch table and the client object that
 * WskRegister() stored in the registration. 'wait_timeout' is only traced;
 * the call always succeeds. */
NTSTATUS WINAPI WskCaptureProviderNPI(WSK_REGISTRATION *wsk_registration, ULONG wait_timeout,
        WSK_PROVIDER_NPI *wsk_provider_npi)
{
    TRACE("wsk_registration %p, wait_timeout %u, wsk_provider_npi %p.\n",
            wsk_registration, wait_timeout, wsk_provider_npi);

    wsk_provider_npi->Dispatch = &wsk_dispatch;
    wsk_provider_npi->Client = wsk_registration->ReservedRegistrationContext;
    return STATUS_SUCCESS;
}

950 951
/* Counterpart of WskCaptureProviderNPI(). Capturing takes no references,
 * so there is nothing to release; the call is only traced. */
void WINAPI WskReleaseProviderNPI(WSK_REGISTRATION *wsk_registration)
{
    TRACE("wsk_registration %p.\n",
            wsk_registration);
}

956 957
/* Register a WSK client: initialise Winsock 2.2 and attach a new client object
 * to 'wsk_registration' (read back by WskCaptureProviderNPI(), freed by
 * WskDeregister()).
 *
 * Returns STATUS_NO_MEMORY or STATUS_INTERNAL_ERROR on failure. In that case
 * nothing is left allocated and the registration is not modified. */
NTSTATUS WINAPI WskRegister(WSK_CLIENT_NPI *wsk_client_npi, WSK_REGISTRATION *wsk_registration)
{
    static const WORD version = MAKEWORD( 2, 2 );
    WSADATA data;
    WSK_CLIENT *client;

    TRACE("wsk_client_npi %p, wsk_registration %p.\n", wsk_client_npi, wsk_registration);

    if (!(client = heap_alloc(sizeof(*client))))
    {
        ERR("No memory.\n");
        return STATUS_NO_MEMORY;
    }

    if (WSAStartup(version, &data))
    {
        /* The registration failed: release the client instead of leaking it
         * through a registration the caller will not deregister. */
        heap_free(client);
        return STATUS_INTERNAL_ERROR;
    }

    client->registration = wsk_registration;
    client->client_npi = wsk_client_npi;
    wsk_registration->ReservedRegistrationContext = client;

    return STATUS_SUCCESS;
}

981 982
/* Undo WskRegister(): free the client object, clear the registration's
 * reference to it and drop the Winsock reference taken by WSAStartup(). */
void WINAPI WskDeregister(WSK_REGISTRATION *wsk_registration)
{
    TRACE("wsk_registration %p.\n", wsk_registration);

    heap_free(wsk_registration->ReservedRegistrationContext);
    /* Do not leave a dangling pointer behind in the caller-owned registration. */
    wsk_registration->ReservedRegistrationContext = NULL;
    /* Every successful WSAStartup() (done in WskRegister()) needs a matching WSACleanup(). */
    WSACleanup();
}

988 989 990 991 992 993 994 995 996 997 998 999
/* DriverUnload routine installed by DriverEntry(); there is no driver-wide
 * state to tear down, so it only traces. */
static void WINAPI driver_unload(DRIVER_OBJECT *driver)
{
    TRACE("driver %p.\n",
            driver);
}

/* Driver entry point: installs the unload routine and always succeeds. */
NTSTATUS WINAPI DriverEntry(DRIVER_OBJECT *driver, UNICODE_STRING *path)
{
    /* A UNICODE_STRING is counted and not guaranteed to be NUL-terminated,
     * so trace exactly Length bytes of it. */
    TRACE("driver %p, path %s.\n", driver,
            debugstr_wn(path->Buffer, (int)(path->Length / sizeof(WCHAR))));

    driver->DriverUnload = driver_unload;
    return STATUS_SUCCESS;
}