Commit 0b69c706 authored by Hans Leidekker's avatar Hans Leidekker Committed by Alexandre Julliard

wininet: Reuse cached basic authorization across sessions.

parent f7e09276
......@@ -863,7 +863,7 @@ static void destroy_authinfo( struct HttpAuthInfo *authinfo )
heap_free(authinfo);
}
static UINT retrieve_cached_basic_authorization(LPWSTR host, LPWSTR realm, LPSTR *auth_data)
static UINT retrieve_cached_basic_authorization(const WCHAR *host, const WCHAR *realm, char **auth_data)
{
basicAuthorizationData *ad;
UINT rc = 0;
......@@ -873,7 +873,7 @@ static UINT retrieve_cached_basic_authorization(LPWSTR host, LPWSTR realm, LPSTR
EnterCriticalSection(&authcache_cs);
LIST_FOR_EACH_ENTRY(ad, &basicAuthorizationCache, basicAuthorizationData, entry)
{
if (!strcmpiW(host,ad->host) && !strcmpW(realm,ad->realm))
if (!strcmpiW(host, ad->host) && (!realm || !strcmpW(realm, ad->realm)))
{
TRACE("Authorization found in cache\n");
*auth_data = heap_alloc(ad->authorizationLen);
......@@ -1620,6 +1620,21 @@ static UINT HTTP_DecodeBase64( LPCWSTR base64, LPSTR bin )
return n;
}
static WCHAR *encode_auth_data( const WCHAR *scheme, const char *data, UINT data_len )
{
WCHAR *ret;
UINT len, scheme_len = strlenW( scheme );
/* scheme + space + base64 encoded data (3/2/1 bytes data -> 4 bytes of characters) */
len = scheme_len + 1 + ((data_len + 2) * 4) / 3;
if (!(ret = heap_alloc( (len + 1) * sizeof(WCHAR) ))) return NULL;
memcpy( ret, scheme, scheme_len * sizeof(WCHAR) );
ret[scheme_len] = ' ';
HTTP_EncodeBase64( data, data_len, ret + scheme_len + 1 );
return ret;
}
/***********************************************************************
* HTTP_InsertAuthorization
*
......@@ -1627,27 +1642,16 @@ static UINT HTTP_DecodeBase64( LPCWSTR base64, LPSTR bin )
*/
static BOOL HTTP_InsertAuthorization( http_request_t *request, struct HttpAuthInfo *pAuthInfo, LPCWSTR header )
{
if (pAuthInfo)
{
static const WCHAR wszSpace[] = {' ',0};
static const WCHAR wszBasic[] = {'B','a','s','i','c',0};
unsigned int len;
WCHAR *authorization = NULL;
WCHAR *host, *authorization = NULL;
if (pAuthInfo)
{
if (pAuthInfo->auth_data_len)
{
/* scheme + space + base64 encoded data (3/2/1 bytes data -> 4 bytes of characters) */
len = strlenW(pAuthInfo->scheme)+1+((pAuthInfo->auth_data_len+2)*4)/3;
authorization = heap_alloc((len+1)*sizeof(WCHAR));
if (!authorization)
if (!(authorization = encode_auth_data(pAuthInfo->scheme, pAuthInfo->auth_data, pAuthInfo->auth_data_len)))
return FALSE;
strcpyW(authorization, pAuthInfo->scheme);
strcatW(authorization, wszSpace);
HTTP_EncodeBase64(pAuthInfo->auth_data,
pAuthInfo->auth_data_len,
authorization+strlenW(authorization));
/* clear the data as it isn't valid now that it has been sent to the
* server, unless it's Basic authentication which doesn't do
* connection tracking */
......@@ -1664,6 +1668,30 @@ static BOOL HTTP_InsertAuthorization( http_request_t *request, struct HttpAuthIn
HTTP_ProcessHeader(request, header, authorization, HTTP_ADDHDR_FLAG_REQ | HTTP_ADDHDR_FLAG_REPLACE);
heap_free(authorization);
}
else if (!strcmpW(header, szAuthorization) && (host = get_host_header(request)))
{
UINT data_len;
char *data;
if ((data_len = retrieve_cached_basic_authorization(host, NULL, &data)))
{
TRACE("Found cached basic authorization for %s\n", debugstr_w(host));
if (!(authorization = encode_auth_data(wszBasic, data, data_len)))
{
heap_free(data);
heap_free(host);
return FALSE;
}
TRACE("Inserting authorization: %s\n", debugstr_w(authorization));
HTTP_ProcessHeader(request, header, authorization, HTTP_ADDHDR_FLAG_REQ | HTTP_ADDHDR_FLAG_REPLACE);
heap_free(data);
heap_free(authorization);
}
heap_free(host);
}
return TRUE;
}
......
......@@ -2321,6 +2321,20 @@ static DWORD CALLBACK server_thread(LPVOID param)
else
send(c, notokmsg, sizeof notokmsg-1, 0);
}
if (strstr(buffer, "HEAD /upload.txt"))
{
if (strstr(buffer, "Authorization: Basic dXNlcjpwd2Q="))
send(c, okmsg, sizeof okmsg-1, 0);
else
send(c, noauthmsg, sizeof noauthmsg-1, 0);
}
if (strstr(buffer, "PUT /upload2.txt"))
{
if (strstr(buffer, "Authorization: Basic dXNlcjpwd2Q="))
send(c, okmsg, sizeof okmsg-1, 0);
else
send(c, notokmsg, sizeof notokmsg-1, 0);
}
shutdown(c, 2);
closesocket(c);
c = -1;
......@@ -4212,6 +4226,59 @@ static void test_accept_encoding(int port)
InternetCloseHandle(ses);
}
static void test_basic_auth_credentials_reuse(int port)
{
HINTERNET ses, con, req;
DWORD status, size;
BOOL ret;
ses = InternetOpenA( "winetest", 0, NULL, NULL, 0 );
ok( ses != NULL, "InternetOpenA failed\n" );
con = InternetConnectA( ses, "localhost", port, "user", "pwd",
INTERNET_SERVICE_HTTP, 0, 0 );
ok( con != NULL, "InternetConnectA failed %u\n", GetLastError() );
req = HttpOpenRequestA( con, "HEAD", "/upload.txt", NULL, NULL, NULL, 0, 0 );
ok( req != NULL, "HttpOpenRequestA failed %u\n", GetLastError() );
ret = HttpSendRequestA( req, NULL, 0, NULL, 0 );
ok( ret, "HttpSendRequestA failed %u\n", GetLastError() );
status = 0xdeadbeef;
size = sizeof(status);
ret = HttpQueryInfoA( req, HTTP_QUERY_STATUS_CODE|HTTP_QUERY_FLAG_NUMBER, &status, &size, NULL );
ok( ret, "HttpQueryInfoA failed %u\n", GetLastError() );
ok( status == 200, "got %u\n", status );
InternetCloseHandle( req );
InternetCloseHandle( con );
InternetCloseHandle( ses );
ses = InternetOpenA( "winetest", 0, NULL, NULL, 0 );
ok( ses != NULL, "InternetOpenA failed\n" );
con = InternetConnectA( ses, "localhost", port, NULL, NULL,
INTERNET_SERVICE_HTTP, 0, 0 );
ok( con != NULL, "InternetConnectA failed %u\n", GetLastError() );
req = HttpOpenRequestA( con, "PUT", "/upload2.txt", NULL, NULL, NULL, 0, 0 );
ok( req != NULL, "HttpOpenRequestA failed %u\n", GetLastError() );
ret = HttpSendRequestA( req, NULL, 0, NULL, 0 );
ok( ret, "HttpSendRequestA failed %u\n", GetLastError() );
status = 0xdeadbeef;
size = sizeof(status);
ret = HttpQueryInfoA( req, HTTP_QUERY_STATUS_CODE|HTTP_QUERY_FLAG_NUMBER, &status, &size, NULL );
ok( ret, "HttpQueryInfoA failed %u\n", GetLastError() );
ok( status == 200, "got %u\n", status );
InternetCloseHandle( req );
InternetCloseHandle( con );
InternetCloseHandle( ses );
}
static void test_http_connection(void)
{
struct server_info si;
......@@ -4259,6 +4326,7 @@ static void test_http_connection(void)
test_head_request(si.port);
test_request_content_length(si.port);
test_accept_encoding(si.port);
test_basic_auth_credentials_reuse(si.port);
/* send the basic request again to shutdown the server thread */
test_basic_request(si.port, "GET", "/quit");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment