http.c 30.9 KB
Newer Older
1 2
/*
 * Copyright 2005 Jacek Caban
 * Copyright 2007 Misha Koshelev
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
 */

20 21
#define NONAMELESSUNION

22
#include "urlmon_main.h"
Jacek Caban's avatar
Jacek Caban committed
23
#include "wininet.h"
24

25 26 27
#define NO_SHLWAPI_REG
#include "shlwapi.h"

28 29 30 31 32
#include "wine/debug.h"

WINE_DEFAULT_DEBUG_CHANNEL(urlmon);

typedef struct {
33 34
    Protocol base;

35 36 37
    IInternetProtocolEx IInternetProtocolEx_iface;
    IInternetPriority   IInternetPriority_iface;
    IWinInetHttpInfo    IWinInetHttpInfo_iface;
38

39
    BOOL https;
40
    IHttpNegotiate *http_negotiate;
41
    WCHAR *full_header;
42 43

    LONG ref;
44 45
} HttpProtocol;

46 47 48 49 50 51 52 53 54 55 56 57 58 59
/* Map an IInternetProtocolEx interface pointer back to the HttpProtocol that embeds it. */
static inline HttpProtocol *impl_from_IInternetProtocolEx(IInternetProtocolEx *iface)
{
    HttpProtocol *object = CONTAINING_RECORD(iface, HttpProtocol, IInternetProtocolEx_iface);

    return object;
}

/* Map an IInternetPriority interface pointer back to the HttpProtocol that embeds it. */
static inline HttpProtocol *impl_from_IInternetPriority(IInternetPriority *iface)
{
    HttpProtocol *object = CONTAINING_RECORD(iface, HttpProtocol, IInternetPriority_iface);

    return object;
}

/* Map an IWinInetHttpInfo interface pointer back to the HttpProtocol that embeds it. */
static inline HttpProtocol *impl_from_IWinInetHttpInfo(IWinInetHttpInfo *iface)
{
    HttpProtocol *object = CONTAINING_RECORD(iface, HttpProtocol, IWinInetHttpInfo_iface);

    return object;
}
60

61 62
/*
 * "Accept-Encoding: gzip, deflate" as a NUL-terminated wide string.
 * HttpProtocol_open_request() passes it to IHttpNegotiate_BeginningTransaction()
 * and appends it (terminator included) after any additional headers.
 */
static const WCHAR default_headersW[] = {
    'A','c','c','e','p','t','-','E','n','c','o','d','i','n','g',':',' ','g','z','i','p',',',' ','d','e','f','l','a','t','e',0};
63

64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
/*
 * Fetch one HttpQueryInfoW() item for the current request into a heap buffer.
 *
 * The first call passes no buffer and is used only to learn the required size.
 * Returns the buffer (caller releases it with heap_free()) or NULL when the
 * item could not be retrieved.
 */
static LPWSTR query_http_info(HttpProtocol *This, DWORD option)
{
    LPWSTR buf = NULL;
    DWORD size = 0;

    /* A probe that succeeds without a buffer has produced nothing to return. */
    if(HttpQueryInfoW(This->base.request, option, NULL, &size, NULL))
        return NULL;

    if(GetLastError() == ERROR_INSUFFICIENT_BUFFER) {
        buf = heap_alloc(size);
        if(HttpQueryInfoW(This->base.request, option, buf, &size, NULL))
            return buf;
    }

    TRACE("HttpQueryInfoW(%d) failed: %08x\n", option, GetLastError());
    heap_free(buf);
    return NULL;
}

84
/*
 * Apply INTERNET_OPTION_SECURITY_FLAGS with the given flags to the current request.
 * Logs an error on failure and returns the InternetSetOptionW() result.
 */
static inline BOOL set_security_flag(HttpProtocol *This, DWORD flags)
{
    BOOL ok = InternetSetOptionW(This->base.request, INTERNET_OPTION_SECURITY_FLAGS,
                                 &flags, sizeof(flags));

    if(ok)
        return ok;

    ERR("Failed to set security flags: %x\n", flags);
    return ok;
}

/*
 * Translate a wininet error code into the INET_E_* HRESULT reported to the caller:
 * certificate problems, redirect problems, or a generic download failure.
 */
static inline HRESULT internet_error_to_hres(DWORD error)
{
    HRESULT hres;

    switch(error)
    {
    /* certificate related errors */
    case ERROR_INTERNET_SEC_CERT_DATE_INVALID:
    case ERROR_INTERNET_SEC_CERT_CN_INVALID:
    case ERROR_INTERNET_INVALID_CA:
    case ERROR_INTERNET_CLIENT_AUTH_CERT_NEEDED:
    case ERROR_INTERNET_SEC_INVALID_CERT:
    case ERROR_INTERNET_SEC_CERT_ERRORS:
    case ERROR_INTERNET_SEC_CERT_REV_FAILED:
    case ERROR_INTERNET_SEC_CERT_NO_REV:
    case ERROR_INTERNET_SEC_CERT_REVOKED:
        hres = INET_E_INVALID_CERTIFICATE;
        break;
    /* redirect related errors */
    case ERROR_INTERNET_HTTP_TO_HTTPS_ON_REDIR:
    case ERROR_INTERNET_HTTPS_TO_HTTP_ON_REDIR:
    case ERROR_HTTP_REDIRECT_NEEDS_CONFIRMATION:
        hres = INET_E_REDIRECT_FAILED;
        break;
    default:
        hres = INET_E_DOWNLOAD_FAILURE;
        break;
    }

    return hres;
}

/*
 * Decide what to do about a wininet error raised for the current request.
 *
 * Security related errors are first offered to the client's IHttpSecurity
 * callback (if it provides one).  Anything not settled there either has its
 * revocation check silently disabled or is passed to InternetErrorDlg().
 *
 * Returns RPC_E_RETRY when the request should be sent again, E_ABORT when
 * IServiceProvider is unavailable, the callback aborted, or an accepted
 * problem could not be ignored, and otherwise an INET_E_* code from
 * internet_error_to_hres().
 */
static HRESULT handle_http_error(HttpProtocol *This, DWORD error)
{
    IServiceProvider *service_provider;
    BOOL is_security_error;
    HRESULT hres;

    TRACE("(%p %u)\n", This, error);

    switch(error) {
    case ERROR_INTERNET_SEC_CERT_DATE_INVALID:
    case ERROR_INTERNET_SEC_CERT_CN_INVALID:
    case ERROR_INTERNET_HTTP_TO_HTTPS_ON_REDIR:
    case ERROR_INTERNET_HTTPS_TO_HTTP_ON_REDIR:
    case ERROR_INTERNET_INVALID_CA:
    case ERROR_INTERNET_CLIENT_AUTH_CERT_NEEDED:
    case ERROR_INTERNET_SEC_INVALID_CERT:
    case ERROR_INTERNET_SEC_CERT_ERRORS:
    case ERROR_INTERNET_SEC_CERT_REV_FAILED:
    case ERROR_INTERNET_SEC_CERT_NO_REV:
    case ERROR_INTERNET_SEC_CERT_REVOKED:
    case ERROR_HTTP_REDIRECT_NEEDS_CONFIRMATION:
        is_security_error = TRUE;
        break;
    default:
        is_security_error = FALSE;
    }

    hres = IInternetProtocolSink_QueryInterface(This->base.protocol_sink, &IID_IServiceProvider,
                                                (void**)&service_provider);
    if(FAILED(hres)) {
        ERR("Failed to get IServiceProvider.\n");
        return E_ABORT;
    }

    if(is_security_error) {
        IHttpSecurity *security;

        hres = IServiceProvider_QueryService(service_provider, &IID_IHttpSecurity, &IID_IHttpSecurity,
                                             (void**)&security);
        if(SUCCEEDED(hres)) {
            hres = IHttpSecurity_OnSecurityProblem(security, error);
            IHttpSecurity_Release(security);

            TRACE("OnSecurityProblem returned %08x\n", hres);

            /* Any answer other than S_FALSE settles the matter right here. */
            if(hres != S_FALSE) {
                BOOL flag_set;

                IServiceProvider_Release(service_provider);

                if(hres == E_ABORT || hres == RPC_E_RETRY)
                    return hres;
                if(hres != S_OK)
                    return internet_error_to_hres(error);

                /* S_OK: the client accepts the problem, so try to make wininet ignore it. */
                switch(error) {
                case ERROR_INTERNET_SEC_CERT_DATE_INVALID:
                    flag_set = set_security_flag(This, SECURITY_FLAG_IGNORE_CERT_DATE_INVALID);
                    break;
                case ERROR_INTERNET_SEC_CERT_CN_INVALID:
                    flag_set = set_security_flag(This, SECURITY_FLAG_IGNORE_CERT_CN_INVALID);
                    break;
                case ERROR_INTERNET_INVALID_CA:
                    flag_set = set_security_flag(This, SECURITY_FLAG_IGNORE_UNKNOWN_CA);
                    break;
                default:
                    flag_set = FALSE;
                }

                if(flag_set)
                    return RPC_E_RETRY;

                FIXME("Don't know how to ignore error %d\n", error);
                return E_ABORT;
            }
        }
    }

    /*
     * hres still holds the result of the last call made above; it is compared
     * against S_FALSE, which is the value OnSecurityProblem returns to decline.
     */
    if(error == ERROR_INTERNET_SEC_CERT_REV_FAILED && hres != S_FALSE) {
        /* Silently ignore the error. We will get more detailed error from wininet anyway. */
        set_security_flag(This, SECURITY_FLAG_IGNORE_REVOCATION);
        hres = RPC_E_RETRY;
    }else {
        IWindowForBindingUI *window_ui;
        DWORD dlg_flags, dlg_res;
        HWND hwnd;

        hres = IServiceProvider_QueryService(service_provider, &IID_IWindowForBindingUI,
                                             &IID_IWindowForBindingUI, (void**)&window_ui);
        if(SUCCEEDED(hres)) {
            const IID *iid_reason = &IID_IWindowForBindingUI;

            if(is_security_error)
                iid_reason = &IID_IHttpSecurity;
            else if(error == ERROR_INTERNET_INCORRECT_PASSWORD)
                iid_reason = &IID_IAuthenticate;

            hres = IWindowForBindingUI_GetWindow(window_ui, iid_reason, &hwnd);
            IWindowForBindingUI_Release(window_ui);
        }

        /* No usable parent window: show the dialog without one. */
        if(FAILED(hres))
            hwnd = NULL;

        dlg_flags = FLAGS_ERROR_UI_FLAGS_CHANGE_OPTIONS | FLAGS_ERROR_UI_FLAGS_GENERATE_DATA;
        if(This->base.bindf & BINDF_NO_UI)
            dlg_flags |= FLAGS_ERROR_UI_FLAGS_NO_UI;

        dlg_res = InternetErrorDlg(hwnd, This->base.request, error, dlg_flags, NULL);
        if(dlg_res == ERROR_INTERNET_FORCE_RETRY || dlg_res == ERROR_SUCCESS)
            hres = RPC_E_RETRY;
        else
            hres = internet_error_to_hres(error);
    }

    IServiceProvider_Release(service_provider);
    return hres;
}

236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277
/*
 * Send the prepared request together with This->full_header.
 *
 * For verbs other than GET the bind info storage medium supplies the body:
 * an HGLOBAL is sent directly, an IStream is remembered as post_stream and
 * rewound.  When a post stream is set the request is started with
 * HttpSendRequestExW(), otherwise it is sent with HttpSendRequestW().
 *
 * Returns 0 on success or the GetLastError() value of the failed send.
 */
static ULONG send_http_request(HttpProtocol *This)
{
    INTERNET_BUFFERSW send_buffer = {sizeof(INTERNET_BUFFERSW)};
    BOOL sent;

    send_buffer.lpcszHeader = This->full_header;
    send_buffer.dwHeadersLength = send_buffer.dwHeadersTotal = strlenW(This->full_header);

    if(This->base.bind_info.dwBindVerb != BINDVERB_GET) {
        DWORD tymed = This->base.bind_info.stgmedData.tymed;

        if(tymed == TYMED_HGLOBAL) {
            /* Native does not use GlobalLock/GlobalUnlock, so we won't either */
            send_buffer.lpvBuffer = This->base.bind_info.stgmedData.u.hGlobal;
            send_buffer.dwBufferLength = send_buffer.dwBufferTotal = This->base.bind_info.cbstgmedData;
        }else if(tymed == TYMED_ISTREAM) {
            LARGE_INTEGER start;

            send_buffer.dwBufferTotal = This->base.bind_info.cbstgmedData;
            if(!This->base.post_stream) {
                This->base.post_stream = This->base.bind_info.stgmedData.u.pstm;
                IStream_AddRef(This->base.post_stream);
            }

            /* Rewind so that a resent request starts from the beginning of the stream. */
            start.QuadPart = 0;
            IStream_Seek(This->base.post_stream, start, STREAM_SEEK_SET, NULL);
        }else {
            FIXME("Unsupported This->base.bind_info.stgmedData.tymed %d\n", This->base.bind_info.stgmedData.tymed);
        }
    }

    if(This->base.post_stream)
        sent = HttpSendRequestExW(This->base.request, &send_buffer, NULL, 0, 0);
    else
        sent = HttpSendRequestW(This->base.request, send_buffer.lpcszHeader, send_buffer.dwHeadersLength,
                send_buffer.lpvBuffer, send_buffer.dwBufferLength);

    if(sent)
        return 0;
    return GetLastError();
}

278 279 280 281
/* Map the embedded generic Protocol state back to its enclosing HttpProtocol. */
static inline HttpProtocol *impl_from_Protocol(Protocol *prot)
{
    HttpProtocol *object = CONTAINING_RECORD(prot, HttpProtocol, base);

    return object;
}
282

283
static HRESULT HttpProtocol_open_request(Protocol *prot, IUri *uri, DWORD request_flags,
284
        HINTERNET internet_session, IInternetBindInfo *bind_info)
285
{
286
    HttpProtocol *This = impl_from_Protocol(prot);
287
    WCHAR *addl_header = NULL, *post_cookie = NULL, *rootdoc_url = NULL;
288 289
    IServiceProvider *service_provider = NULL;
    IHttpNegotiate2 *http_negotiate2 = NULL;
290
    BSTR url, host, user, pass, path;
291
    LPOLESTR accept_mimes[257];
292
    const WCHAR **accept_types;
293
    BYTE security_id[512];
294
    DWORD len, port, flags;
295
    ULONG num, error;
296
    BOOL res, b;
297 298 299 300 301 302 303
    HRESULT hres;

    static const WCHAR wszBindVerb[BINDVERB_CUSTOM][5] =
        {{'G','E','T',0},
         {'P','O','S','T',0},
         {'P','U','T',0}};

304 305 306 307 308 309 310
    hres = IUri_GetPort(uri, &port);
    if(FAILED(hres))
        return hres;

    hres = IUri_GetHost(uri, &host);
    if(FAILED(hres))
        return hres;
311

312 313 314 315 316 317 318 319 320 321 322 323 324 325
    hres = IUri_GetUserName(uri, &user);
    if(SUCCEEDED(hres)) {
        hres = IUri_GetPassword(uri, &pass);

        if(SUCCEEDED(hres)) {
            This->base.connection = InternetConnectW(internet_session, host, port, user, pass,
                    INTERNET_SERVICE_HTTP, This->https ? INTERNET_FLAG_SECURE : 0, (DWORD_PTR)&This->base);
            SysFreeString(pass);
        }
        SysFreeString(user);
    }
    SysFreeString(host);
    if(FAILED(hres))
        return hres;
326 327 328 329 330
    if(!This->base.connection) {
        WARN("InternetConnect failed: %d\n", GetLastError());
        return INET_E_CANNOT_CONNECT;
    }

331 332 333 334 335 336 337
    num = 0;
    hres = IInternetBindInfo_GetBindString(bind_info, BINDSTRING_ROOTDOC_URL, &rootdoc_url, 1, &num);
    if(hres == S_OK && num) {
        FIXME("Use root doc URL %s\n", debugstr_w(rootdoc_url));
        CoTaskMemFree(rootdoc_url);
    }

338 339
    num = sizeof(accept_mimes)/sizeof(accept_mimes[0])-1;
    hres = IInternetBindInfo_GetBindString(bind_info, BINDSTRING_ACCEPT_MIMES, accept_mimes, num, &num);
340 341 342 343 344 345 346 347 348
    if(hres == INET_E_USE_DEFAULT_SETTING) {
        static const WCHAR default_accept_mimeW[] = {'*','/','*',0};
        static const WCHAR *default_accept_mimes[] = {default_accept_mimeW, NULL};

        accept_types = default_accept_mimes;
        num = 0;
    }else if(hres == S_OK) {
        accept_types = (const WCHAR**)accept_mimes;
    }else {
349 350 351 352 353 354 355
        WARN("GetBindString BINDSTRING_ACCEPT_MIMES failed: %08x\n", hres);
        return INET_E_NO_VALID_MEDIA;
    }
    accept_mimes[num] = 0;

    if(This->https)
        request_flags |= INTERNET_FLAG_SECURE;
356 357 358 359 360 361 362 363 364

    hres = IUri_GetPathAndQuery(uri, &path);
    if(SUCCEEDED(hres)) {
        This->base.request = HttpOpenRequestW(This->base.connection,
                This->base.bind_info.dwBindVerb < BINDVERB_CUSTOM
                    ? wszBindVerb[This->base.bind_info.dwBindVerb] : This->base.bind_info.szCustomVerb,
                path, NULL, NULL, accept_types, request_flags, (DWORD_PTR)&This->base);
        SysFreeString(path);
    }
365 366
    while(num--)
        CoTaskMemFree(accept_mimes[num]);
367 368
    if(FAILED(hres))
        return hres;
369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384
    if (!This->base.request) {
        WARN("HttpOpenRequest failed: %d\n", GetLastError());
        return INET_E_RESOURCE_NOT_FOUND;
    }

    hres = IInternetProtocolSink_QueryInterface(This->base.protocol_sink, &IID_IServiceProvider,
            (void **)&service_provider);
    if (hres != S_OK) {
        WARN("IInternetProtocolSink_QueryInterface IID_IServiceProvider failed: %08x\n", hres);
        return hres;
    }

    hres = IServiceProvider_QueryService(service_provider, &IID_IHttpNegotiate,
            &IID_IHttpNegotiate, (void **)&This->http_negotiate);
    if (hres != S_OK) {
        WARN("IServiceProvider_QueryService IID_IHttpNegotiate failed: %08x\n", hres);
385
        IServiceProvider_Release(service_provider);
386 387 388
        return hres;
    }

389 390 391 392 393 394
    hres = IUri_GetAbsoluteUri(uri, &url);
    if(FAILED(hres)) {
        IServiceProvider_Release(service_provider);
        return hres;
    }

395
    hres = IHttpNegotiate_BeginningTransaction(This->http_negotiate, url, default_headersW,
396
            0, &addl_header);
397
    SysFreeString(url);
398 399 400 401 402 403
    if(hres != S_OK) {
        WARN("IHttpNegotiate_BeginningTransaction failed: %08x\n", hres);
        IServiceProvider_Release(service_provider);
        return hres;
    }

404
    len = addl_header ? strlenW(addl_header) : 0;
405

406 407 408 409
    This->full_header = heap_alloc(len*sizeof(WCHAR)+sizeof(default_headersW));
    if(!This->full_header) {
        IServiceProvider_Release(service_provider);
        return E_OUTOFMEMORY;
410 411
    }

412 413 414 415 416
    if(len)
        memcpy(This->full_header, addl_header, len*sizeof(WCHAR));
    CoTaskMemFree(addl_header);
    memcpy(This->full_header+len, default_headersW, sizeof(default_headersW));

417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443
    hres = IServiceProvider_QueryService(service_provider, &IID_IHttpNegotiate2,
            &IID_IHttpNegotiate2, (void **)&http_negotiate2);
    IServiceProvider_Release(service_provider);
    if(hres != S_OK) {
        WARN("IServiceProvider_QueryService IID_IHttpNegotiate2 failed: %08x\n", hres);
        /* No goto done as per native */
    }else {
        len = sizeof(security_id)/sizeof(security_id[0]);
        hres = IHttpNegotiate2_GetRootSecurityId(http_negotiate2, security_id, &len, 0);
        IHttpNegotiate2_Release(http_negotiate2);
        if (hres != S_OK)
            WARN("IHttpNegotiate2_GetRootSecurityId failed: %08x\n", hres);
    }

    /* FIXME: Handle security_id. Native calls undocumented function IsHostInProxyBypassList. */

    if(This->base.bind_info.dwBindVerb == BINDVERB_POST) {
        num = 0;
        hres = IInternetBindInfo_GetBindString(bind_info, BINDSTRING_POST_COOKIE, &post_cookie, 1, &num);
        if(hres == S_OK && num) {
            if(!InternetSetOptionW(This->base.request, INTERNET_OPTION_SECONDARY_CACHE_KEY,
                                   post_cookie, lstrlenW(post_cookie)))
                WARN("InternetSetOption INTERNET_OPTION_SECONDARY_CACHE_KEY failed: %d\n", GetLastError());
            CoTaskMemFree(post_cookie);
        }
    }

444 445 446 447 448
    flags = INTERNET_ERROR_MASK_COMBINED_SEC_CERT;
    res = InternetSetOptionW(This->base.request, INTERNET_OPTION_ERROR_MASK, &flags, sizeof(flags));
    if(!res)
        WARN("InternetSetOption(INTERNET_OPTION_ERROR_MASK) failed: %u\n", GetLastError());

449 450 451
    b = TRUE;
    res = InternetSetOptionW(This->base.request, INTERNET_OPTION_HTTP_DECODING, &b, sizeof(b));
    if(!res)
452
        WARN("InternetSetOption(INTERNET_OPTION_HTTP_DECODING) failed: %u\n", GetLastError());
453

454 455
    do {
        error = send_http_request(This);
456

457 458
        switch(error) {
        case ERROR_IO_PENDING:
459
            return S_OK;
460 461 462 463 464 465 466 467 468
        case ERROR_SUCCESS:
            /*
             * If sending response ended synchronously, it means that we have the whole data
             * available locally (most likely in cache).
             */
            return protocol_syncbinding(&This->base);
        default:
            hres = handle_http_error(This, error);
        }
469
    } while(hres == RPC_E_RETRY);
470

471
    WARN("HttpSendRequest failed: %d\n", error);
472
    return hres;
473 474
}

475 476 477 478 479 480 481 482 483 484 485 486 487
/*
 * ProtocolVtbl end_request callback: finish a request via HttpEndRequestW().
 * A pending asynchronous completion counts as success; any other failure
 * yields E_FAIL.
 */
static HRESULT HttpProtocol_end_request(Protocol *protocol)
{
    if(HttpEndRequestW(protocol->request, NULL, 0, 0))
        return S_OK;

    if(GetLastError() == ERROR_IO_PENDING)
        return S_OK;

    FIXME("HttpEndRequest failed: %u\n", GetLastError());
    return E_FAIL;
}

488 489 490 491 492 493 494 495 496 497 498 499
/* Tell whether an HTTP status code is one of the redirect codes handled here. */
static BOOL is_redirect_response(DWORD status_code)
{
    if(status_code == HTTP_STATUS_REDIRECT
            || status_code == HTTP_STATUS_MOVED
            || status_code == HTTP_STATUS_REDIRECT_KEEP_VERB
            || status_code == HTTP_STATUS_REDIRECT_METHOD)
        return TRUE;

    return FALSE;
}

500 501
/*
 * ProtocolVtbl start_downloading callback: inspect the response before data
 * is delivered.
 *
 * Passes the status code and raw headers to IHttpNegotiate_OnResponse(),
 * reports accept-ranges support and the MIME type (text/html when the server
 * sent none) to the protocol sink and records the content length.
 *
 * Returns INET_E_REDIRECT_FAILED for a redirect response while automatic
 * redirects are disabled; S_OK in every other case.
 */
static HRESULT HttpProtocol_start_downloading(Protocol *prot)
{
    HttpProtocol *This = impl_from_Protocol(prot);
    LPWSTR content_type, content_length, ranges;
    DWORD status_size = sizeof(DWORD);
    DWORD status_code;
    ULONG mime_status;
    HRESULT hres;

    static const WCHAR wszDefaultContentType[] =
        {'t','e','x','t','/','h','t','m','l',0};

    if(!This->http_negotiate) {
        WARN("Expected IHttpNegotiate pointer to be non-NULL\n");
        return S_OK;
    }

    if(!HttpQueryInfoW(This->base.request, HTTP_QUERY_STATUS_CODE | HTTP_QUERY_FLAG_NUMBER,
            &status_code, &status_size, NULL)) {
        WARN("HttpQueryInfo failed: %d\n", GetLastError());
    }else {
        WCHAR *response_headers;

        if((This->base.bind_info.dwOptions & BINDINFO_OPTIONS_DISABLEAUTOREDIRECTS) && is_redirect_response(status_code)) {
            WCHAR *location;

            TRACE("Got redirect with disabled auto redirects\n");

            /* Report the redirect target together with the failure. */
            location = query_http_info(This, HTTP_QUERY_LOCATION);
            This->base.flags |= FLAG_RESULT_REPORTED | FLAG_LAST_DATA_REPORTED;
            IInternetProtocolSink_ReportResult(This->base.protocol_sink, INET_E_REDIRECT_FAILED, 0, location);
            heap_free(location);
            return INET_E_REDIRECT_FAILED;
        }

        response_headers = query_http_info(This, HTTP_QUERY_RAW_HEADERS_CRLF);
        if(response_headers) {
            hres = IHttpNegotiate_OnResponse(This->http_negotiate, status_code, response_headers,
                    NULL, NULL);
            heap_free(response_headers);
            if (hres != S_OK) {
                WARN("IHttpNegotiate_OnResponse failed: %08x\n", hres);
                return S_OK;
            }
        }
    }

    ranges = query_http_info(This, HTTP_QUERY_ACCEPT_RANGES);
    if(ranges) {
        IInternetProtocolSink_ReportProgress(This->base.protocol_sink, BINDSTATUS_ACCEPTRANGES, NULL);
        heap_free(ranges);
    }

    content_type = query_http_info(This, HTTP_QUERY_CONTENT_TYPE);
    if(!content_type) {
        WARN("HttpQueryInfo failed: %d\n", GetLastError());

        mime_status = (This->base.bindf & BINDF_FROMURLMON)
                ? BINDSTATUS_MIMETYPEAVAILABLE : BINDSTATUS_RAWMIMETYPE;
        IInternetProtocolSink_ReportProgress(This->base.protocol_sink, mime_status,
                wszDefaultContentType);
    }else {
        /* remove the charset, if present */
        LPWSTR params = strchrW(content_type, ';');
        if (params) *params = '\0';

        mime_status = (This->base.bindf & BINDF_FROMURLMON)
                ? BINDSTATUS_MIMETYPEAVAILABLE : BINDSTATUS_RAWMIMETYPE;
        IInternetProtocolSink_ReportProgress(This->base.protocol_sink, mime_status, content_type);
        heap_free(content_type);
    }

    content_length = query_http_info(This, HTTP_QUERY_CONTENT_LENGTH);
    if(content_length) {
        This->base.content_length = atoiW(content_length);
        heap_free(content_length);
    }

    return S_OK;
}

582 583
/*
 * ProtocolVtbl close_connection callback: drop the HTTP specific resources,
 * i.e. the IHttpNegotiate reference and the request header buffer.  Both
 * fields are reset so that a repeated call has nothing left to release.
 */
static void HttpProtocol_close_connection(Protocol *prot)
{
    HttpProtocol *This = impl_from_Protocol(prot);

    if(This->http_negotiate != NULL) {
        IHttpNegotiate_Release(This->http_negotiate);
        This->http_negotiate = NULL;
    }

    heap_free(This->full_header);
    This->full_header = NULL;
}

595 596
/*
 * ProtocolVtbl on_error callback: react to a wininet error for this request.
 *
 * Errors arriving once FLAG_FIRST_CONTINUE_COMPLETE is set are only logged.
 * Otherwise the request is resent for as long as handle_http_error() asks for
 * a retry; when it stops asking, the binding is aborted with its result and
 * the connection is closed.
 */
static void HttpProtocol_on_error(Protocol *prot, DWORD error)
{
    HttpProtocol *This = impl_from_Protocol(prot);
    HRESULT hres;

    TRACE("(%p) %d\n", prot, error);

    if(prot->flags & FLAG_FIRST_CONTINUE_COMPLETE) {
        FIXME("Not handling error %d\n", error);
        return;
    }

    for(;;) {
        hres = handle_http_error(This, error);
        if(hres != RPC_E_RETRY)
            break;

        error = send_http_request(This);

        /* The resend is under way or already finished; nothing more to do here. */
        if(error == ERROR_IO_PENDING || error == ERROR_SUCCESS)
            return;
    }

    protocol_abort(prot, hres);
    protocol_close_connection(prot);
}

619
static const ProtocolVtbl AsyncProtocolVtbl = {
620
    HttpProtocol_open_request,
621
    HttpProtocol_end_request,
622
    HttpProtocol_start_downloading,
623 624
    HttpProtocol_close_connection,
    HttpProtocol_on_error
625
};
626

627
/*
 * IUnknown::QueryInterface for the protocol object.
 *
 * IUnknown and the IInternetProtocol family resolve to the
 * IInternetProtocolEx interface; IInternetPriority and the IWinInetInfo
 * family resolve to their own embedded interfaces.  *ppv is set to NULL and
 * E_NOINTERFACE is returned for anything else.
 */
static HRESULT WINAPI HttpProtocol_QueryInterface(IInternetProtocolEx *iface, REFIID riid, void **ppv)
{
    HttpProtocol *This = impl_from_IInternetProtocolEx(iface);

    *ppv = NULL;

    if(IsEqualGUID(&IID_IUnknown, riid)) {
        TRACE("(%p)->(IID_IUnknown %p)\n", This, ppv);
        *ppv = &This->IInternetProtocolEx_iface;
    }else if(IsEqualGUID(&IID_IInternetProtocolRoot, riid)) {
        TRACE("(%p)->(IID_IInternetProtocolRoot %p)\n", This, ppv);
        *ppv = &This->IInternetProtocolEx_iface;
    }else if(IsEqualGUID(&IID_IInternetProtocol, riid)) {
        TRACE("(%p)->(IID_IInternetProtocol %p)\n", This, ppv);
        *ppv = &This->IInternetProtocolEx_iface;
    }else if(IsEqualGUID(&IID_IInternetProtocolEx, riid)) {
        TRACE("(%p)->(IID_IInternetProtocolEx %p)\n", This, ppv);
        *ppv = &This->IInternetProtocolEx_iface;
    }else if(IsEqualGUID(&IID_IInternetPriority, riid)) {
        TRACE("(%p)->(IID_IInternetPriority %p)\n", This, ppv);
        *ppv = &This->IInternetPriority_iface;
    }else if(IsEqualGUID(&IID_IWinInetInfo, riid)) {
        TRACE("(%p)->(IID_IWinInetInfo %p)\n", This, ppv);
        *ppv = &This->IWinInetHttpInfo_iface;
    }else if(IsEqualGUID(&IID_IWinInetHttpInfo, riid)) {
        TRACE("(%p)->(IID_IWinInetHttpInfo %p)\n", This, ppv);
        *ppv = &This->IWinInetHttpInfo_iface;
    }

    if(!*ppv) {
        WARN("not supported interface %s\n", debugstr_guid(riid));
        return E_NOINTERFACE;
    }

    /* All interfaces share the object's single reference count. */
    IInternetProtocolEx_AddRef(iface);
    return S_OK;
}

664
/* IUnknown::AddRef - atomically increments the reference count and returns the new value. */
static ULONG WINAPI HttpProtocol_AddRef(IInternetProtocolEx *iface)
{
    HttpProtocol *This = impl_from_IInternetProtocolEx(iface);
    LONG ref;

    ref = InterlockedIncrement(&This->ref);
    TRACE("(%p) ref=%d\n", This, ref);

    return ref;
}

672
/*
 * IUnknown::Release - atomically decrements the reference count.  When it
 * reaches zero the connection is closed, the object is freed and the module
 * lock taken at creation is dropped.  Returns the new count.
 */
static ULONG WINAPI HttpProtocol_Release(IInternetProtocolEx *iface)
{
    HttpProtocol *This = impl_from_IInternetProtocolEx(iface);
    LONG ref;

    ref = InterlockedDecrement(&This->ref);
    TRACE("(%p) ref=%d\n", This, ref);

    if(ref)
        return ref;

    protocol_close_connection(&This->base);
    heap_free(This);

    URLMON_UnlockModule();
    return ref;
}

689
/*
 * IInternetProtocolRoot::Start - parses szUrl into an IUri and delegates to
 * StartEx.  Returns the CreateUri() failure if the URL cannot be parsed,
 * otherwise the StartEx result.
 */
static HRESULT WINAPI HttpProtocol_Start(IInternetProtocolEx *iface, LPCWSTR szUrl,
        IInternetProtocolSink *pOIProtSink, IInternetBindInfo *pOIBindInfo,
        DWORD grfPI, HANDLE_PTR dwReserved)
{
    HttpProtocol *This = impl_from_IInternetProtocolEx(iface);
    HRESULT hres;
    IUri *uri;

    TRACE("(%p)->(%s %p %p %08x %lx)\n", This, debugstr_w(szUrl), pOIProtSink,
            pOIBindInfo, grfPI, dwReserved);

    hres = CreateUri(szUrl, 0, 0, &uri);
    if(SUCCEEDED(hres)) {
        hres = IInternetProtocolEx_StartEx(&This->IInternetProtocolEx_iface, uri, pOIProtSink,
                pOIBindInfo, grfPI, (HANDLE*)dwReserved);
        IUri_Release(uri);
    }

    return hres;
}

711
/* IInternetProtocolRoot::Continue - forwards the protocol data to protocol_continue(). */
static HRESULT WINAPI HttpProtocol_Continue(IInternetProtocolEx *iface, PROTOCOLDATA *pProtocolData)
{
    HttpProtocol *This = impl_from_IInternetProtocolEx(iface);
    HRESULT hres;

    TRACE("(%p)->(%p)\n", This, pProtocolData);

    hres = protocol_continue(&This->base, pProtocolData);
    return hres;
}

720
/* IInternetProtocolRoot::Abort - forwards hrReason to protocol_abort(); dwOptions is only traced. */
static HRESULT WINAPI HttpProtocol_Abort(IInternetProtocolEx *iface, HRESULT hrReason,
        DWORD dwOptions)
{
    HttpProtocol *This = impl_from_IInternetProtocolEx(iface);
    HRESULT hres;

    TRACE("(%p)->(%08x %08x)\n", This, hrReason, dwOptions);

    hres = protocol_abort(&This->base, hrReason);
    return hres;
}

730
/* IInternetProtocolRoot::Terminate - closes the connection; always returns S_OK. */
static HRESULT WINAPI HttpProtocol_Terminate(IInternetProtocolEx *iface, DWORD dwOptions)
{
    HttpProtocol *This;

    This = impl_from_IInternetProtocolEx(iface);
    TRACE("(%p)->(%08x)\n", This, dwOptions);

    protocol_close_connection(&This->base);

    return S_OK;
}

740
/* IInternetProtocolRoot::Suspend - not implemented; logs a FIXME and returns E_NOTIMPL. */
static HRESULT WINAPI HttpProtocol_Suspend(IInternetProtocolEx *iface)
{
    HttpProtocol *This;

    This = impl_from_IInternetProtocolEx(iface);
    FIXME("(%p)\n", This);

    return E_NOTIMPL;
}

747
/* IInternetProtocolRoot::Resume - not implemented; logs a FIXME and returns E_NOTIMPL. */
static HRESULT WINAPI HttpProtocol_Resume(IInternetProtocolEx *iface)
{
    HttpProtocol *This;

    This = impl_from_IInternetProtocolEx(iface);
    FIXME("(%p)\n", This);

    return E_NOTIMPL;
}

754
/* IInternetProtocol::Read - forwards the read request to protocol_read(). */
static HRESULT WINAPI HttpProtocol_Read(IInternetProtocolEx *iface, void *pv,
        ULONG cb, ULONG *pcbRead)
{
    HttpProtocol *This = impl_from_IInternetProtocolEx(iface);
    HRESULT hres;

    TRACE("(%p)->(%p %u %p)\n", This, pv, cb, pcbRead);

    hres = protocol_read(&This->base, pv, cb, pcbRead);
    return hres;
}

764
/* IInternetProtocol::Seek - not implemented; logs a FIXME and returns E_NOTIMPL. */
static HRESULT WINAPI HttpProtocol_Seek(IInternetProtocolEx *iface, LARGE_INTEGER dlibMove,
        DWORD dwOrigin, ULARGE_INTEGER *plibNewPosition)
{
    HttpProtocol *This;

    This = impl_from_IInternetProtocolEx(iface);
    FIXME("(%p)->(%d %d %p)\n", This, dlibMove.u.LowPart, dwOrigin, plibNewPosition);

    return E_NOTIMPL;
}

772
/* IInternetProtocol::LockRequest - forwards to protocol_lock_request(); dwOptions is only traced. */
static HRESULT WINAPI HttpProtocol_LockRequest(IInternetProtocolEx *iface, DWORD dwOptions)
{
    HttpProtocol *This = impl_from_IInternetProtocolEx(iface);
    HRESULT hres;

    TRACE("(%p)->(%08x)\n", This, dwOptions);

    hres = protocol_lock_request(&This->base);
    return hres;
}

781
/* IInternetProtocol::UnlockRequest - forwards to protocol_unlock_request(). */
static HRESULT WINAPI HttpProtocol_UnlockRequest(IInternetProtocolEx *iface)
{
    HttpProtocol *This = impl_from_IInternetProtocolEx(iface);
    HRESULT hres;

    TRACE("(%p)\n", This);

    hres = protocol_unlock_request(&This->base);
    return hres;
}

790 791 792 793
/*
 * IInternetProtocolEx::StartEx - starts the binding for an already parsed URI.
 *
 * The URI scheme has to match the flavour of this object (https or http),
 * otherwise MK_E_SYNTAX is returned.  On a match the work is handed to
 * protocol_start().
 */
static HRESULT WINAPI HttpProtocol_StartEx(IInternetProtocolEx *iface, IUri *pUri,
        IInternetProtocolSink *pOIProtSink, IInternetBindInfo *pOIBindInfo,
        DWORD grfPI, HANDLE *dwReserved)
{
    HttpProtocol *This = impl_from_IInternetProtocolEx(iface);
    DWORD scheme = 0, expected_scheme;
    HRESULT hres;

    TRACE("(%p)->(%p %p %p %08x %p)\n", This, pUri, pOIProtSink,
            pOIBindInfo, grfPI, dwReserved);

    hres = IUri_GetScheme(pUri, &scheme);
    if(FAILED(hres))
        return hres;

    expected_scheme = This->https ? URL_SCHEME_HTTPS : URL_SCHEME_HTTP;
    if(scheme != expected_scheme)
        return MK_E_SYNTAX;

    hres = protocol_start(&This->base, (IInternetProtocol*)&This->IInternetProtocolEx_iface, pUri,
                          pOIProtSink, pOIBindInfo);
    return hres;
}

static const IInternetProtocolExVtbl HttpProtocolVtbl = {
812 813 814 815 816 817 818 819 820 821 822 823
    HttpProtocol_QueryInterface,
    HttpProtocol_AddRef,
    HttpProtocol_Release,
    HttpProtocol_Start,
    HttpProtocol_Continue,
    HttpProtocol_Abort,
    HttpProtocol_Terminate,
    HttpProtocol_Suspend,
    HttpProtocol_Resume,
    HttpProtocol_Read,
    HttpProtocol_Seek,
    HttpProtocol_LockRequest,
824 825
    HttpProtocol_UnlockRequest,
    HttpProtocol_StartEx
826 827
};

828 829
/* IInternetPriority QueryInterface - delegates to the object's IInternetProtocolEx interface. */
static HRESULT WINAPI HttpPriority_QueryInterface(IInternetPriority *iface, REFIID riid, void **ppv)
{
    HttpProtocol *This;

    This = impl_from_IInternetPriority(iface);

    return IInternetProtocolEx_QueryInterface(&This->IInternetProtocolEx_iface, riid, ppv);
}

/* IInternetPriority AddRef - delegates to the object's IInternetProtocolEx interface. */
static ULONG WINAPI HttpPriority_AddRef(IInternetPriority *iface)
{
    HttpProtocol *This;

    This = impl_from_IInternetPriority(iface);

    return IInternetProtocolEx_AddRef(&This->IInternetProtocolEx_iface);
}

/* IInternetPriority Release - delegates to the object's IInternetProtocolEx interface. */
static ULONG WINAPI HttpPriority_Release(IInternetPriority *iface)
{
    HttpProtocol *This;

    This = impl_from_IInternetPriority(iface);

    return IInternetProtocolEx_Release(&This->IInternetProtocolEx_iface);
}

/* IInternetPriority::SetPriority - stores nPriority in the protocol state; always returns S_OK. */
static HRESULT WINAPI HttpPriority_SetPriority(IInternetPriority *iface, LONG nPriority)
{
    HttpProtocol *This;

    This = impl_from_IInternetPriority(iface);
    TRACE("(%p)->(%d)\n", This, nPriority);

    This->base.priority = nPriority;

    return S_OK;
}

/*
 * IInternetPriority::GetPriority - writes the stored priority to *pnPriority
 * and returns S_OK.  pnPriority is dereferenced without a NULL check.
 */
static HRESULT WINAPI HttpPriority_GetPriority(IInternetPriority *iface, LONG *pnPriority)
{
    HttpProtocol *This;

    This = impl_from_IInternetPriority(iface);
    TRACE("(%p)->(%p)\n", This, pnPriority);

    *pnPriority = This->base.priority;

    return S_OK;
}

/* IInternetPriority method table: the IUnknown methods followed by SetPriority/GetPriority. */
static const IInternetPriorityVtbl HttpPriorityVtbl = {
    HttpPriority_QueryInterface,
    HttpPriority_AddRef,
    HttpPriority_Release,
    HttpPriority_SetPriority,
    HttpPriority_GetPriority
};

874 875
/* IWinInetHttpInfo QueryInterface - delegates to the object's IInternetProtocolEx interface. */
static HRESULT WINAPI HttpInfo_QueryInterface(IWinInetHttpInfo *iface, REFIID riid, void **ppv)
{
    HttpProtocol *This;

    This = impl_from_IWinInetHttpInfo(iface);

    return IInternetProtocolEx_QueryInterface(&This->IInternetProtocolEx_iface, riid, ppv);
}

/* IWinInetHttpInfo AddRef - delegates to the object's IInternetProtocolEx interface. */
static ULONG WINAPI HttpInfo_AddRef(IWinInetHttpInfo *iface)
{
    HttpProtocol *This;

    This = impl_from_IWinInetHttpInfo(iface);

    return IInternetProtocolEx_AddRef(&This->IInternetProtocolEx_iface);
}

/* IWinInetHttpInfo Release - delegates to the object's IInternetProtocolEx interface. */
static ULONG WINAPI HttpInfo_Release(IWinInetHttpInfo *iface)
{
    HttpProtocol *This;

    This = impl_from_IWinInetHttpInfo(iface);

    return IInternetProtocolEx_Release(&This->IInternetProtocolEx_iface);
}

/*
 * IWinInetInfo::QueryOption - queries a wininet option of the current request.
 * Returns E_FAIL when there is no request, S_FALSE when
 * InternetQueryOptionW() fails and S_OK otherwise.
 */
static HRESULT WINAPI HttpInfo_QueryOption(IWinInetHttpInfo *iface, DWORD dwOption,
        void *pBuffer, DWORD *pcbBuffer)
{
    HttpProtocol *This = impl_from_IWinInetHttpInfo(iface);

    TRACE("(%p)->(%x %p %p)\n", This, dwOption, pBuffer, pcbBuffer);

    if(!This->base.request)
        return E_FAIL;

    return InternetQueryOptionW(This->base.request, dwOption, pBuffer, pcbBuffer) ? S_OK : S_FALSE;
}

/*
 * IWinInetHttpInfo::QueryInfo - queries HTTP information about the current
 * request through HttpQueryInfoW().
 *
 * iface        interface pointer of the protocol object
 * dwOption     HTTP_QUERY_* item (and flags) to retrieve
 * pBuffer      receives the data; zero-filled when the query fails
 * pcbBuffer    in: size of pBuffer in bytes, out: as set by HttpQueryInfoW()
 * pdwFlags     passed to HttpQueryInfoW() as its index parameter
 * pdwReserved  unused, only traced
 *
 * Returns E_FAIL when there is no request.  Otherwise S_OK is returned even
 * when the query fails; the failure is visible only through the zeroed buffer.
 */
static HRESULT WINAPI HttpInfo_QueryInfo(IWinInetHttpInfo *iface, DWORD dwOption,
        void *pBuffer, DWORD *pcbBuffer, DWORD *pdwFlags, DWORD *pdwReserved)
{
    HttpProtocol *This = impl_from_IWinInetHttpInfo(iface);
    DWORD buffer_size;

    TRACE("(%p)->(%x %p %p %p %p)\n", This, dwOption, pBuffer, pcbBuffer, pdwFlags, pdwReserved);

    if(!This->base.request)
        return E_FAIL;

    /*
     * Remember how large the caller's buffer really is.  When the query fails
     * with ERROR_INSUFFICIENT_BUFFER, *pcbBuffer is overwritten with the
     * required size, which is larger than the buffer and must not be used as
     * the length to clear.
     */
    buffer_size = pcbBuffer ? *pcbBuffer : 0;

    if(!HttpQueryInfoW(This->base.request, dwOption, pBuffer, pcbBuffer, pdwFlags)) {
        if(pBuffer && buffer_size)
            memset(pBuffer, 0, buffer_size);
    }

    return S_OK;
}

/* IWinInetHttpInfo method table: the IUnknown methods followed by QueryOption/QueryInfo. */
static const IWinInetHttpInfoVtbl WinInetHttpInfoVtbl = {
    HttpInfo_QueryInterface,
    HttpInfo_AddRef,
    HttpInfo_Release,
    HttpInfo_QueryOption,
    HttpInfo_QueryInfo
};

931
/*
 * Allocate and initialise a protocol object.
 *
 * https selects the HTTPS flavour.  On success *ppobj receives the
 * IInternetProtocolEx interface with a reference count of one and the module
 * is locked until the object is released.  Returns E_OUTOFMEMORY when the
 * allocation fails (*ppobj is left untouched in that case).
 */
static HRESULT create_http_protocol(BOOL https, void **ppobj)
{
    HttpProtocol *protocol = heap_alloc_zero(sizeof(HttpProtocol));

    if(!protocol)
        return E_OUTOFMEMORY;

    protocol->base.vtbl = &AsyncProtocolVtbl;
    protocol->IInternetProtocolEx_iface.lpVtbl = &HttpProtocolVtbl;
    protocol->IInternetPriority_iface.lpVtbl   = &HttpPriorityVtbl;
    protocol->IWinInetHttpInfo_iface.lpVtbl    = &WinInetHttpInfoVtbl;

    protocol->https = https;
    protocol->ref = 1;

    *ppobj = &protocol->IInternetProtocolEx_iface;

    URLMON_LockModule();
    return S_OK;
}
952

953 954 955 956 957 958 959
/* Create a plain HTTP protocol object; pUnkOuter is only traced. */
HRESULT HttpProtocol_Construct(IUnknown *pUnkOuter, LPVOID *ppobj)
{
    HRESULT hres;

    TRACE("(%p %p)\n", pUnkOuter, ppobj);

    hres = create_http_protocol(FALSE, ppobj);
    return hres;
}

960 961
/* Create an HTTPS protocol object; pUnkOuter is only traced. */
HRESULT HttpSProtocol_Construct(IUnknown *pUnkOuter, LPVOID *ppobj)
{
    HRESULT hres;

    TRACE("(%p %p)\n", pUnkOuter, ppobj);

    hres = create_http_protocol(TRUE, ppobj);
    return hres;
}