Commit af045a52 authored by Alexandre Julliard's avatar Alexandre Julliard

ntdll: Add helper functions for verifying write access to a memory range.

parent 8bc95aa7
......@@ -838,9 +838,10 @@ static inline int mprotect_exec( void *base, size_t size, int unix_prot )
static void mprotect_range( void *base, size_t size, BYTE set, BYTE clear )
{
size_t i, count;
char *addr = base;
char *addr = ROUND_ADDR( base, page_mask );
int prot, next;
size = ROUND_SIZE( base, size );
prot = VIRTUAL_GetUnixProt( (get_page_vprot( addr ) & ~clear ) | set );
for (count = i = 1; i < size >> page_shift; i++, count++)
{
......@@ -925,6 +926,19 @@ static NTSTATUS set_protection( struct file_view *view, void *base, SIZE_T size,
/***********************************************************************
* update_write_watches
*/
static void update_write_watches( void *base, size_t size, size_t accessed_size )
{
TRACE( "updating watch %p-%p-%p\n", base, (char *)base + accessed_size, (char *)base + size );
/* clear write watch flag on accessed pages */
set_page_vprot_bits( base, accessed_size, 0, VPROT_WRITEWATCH );
/* restore page protections on the entire range */
mprotect_range( base, size, 0, 0 );
}
/***********************************************************************
* reset_write_watches
*
* Reset write watches in a memory range.
......@@ -1800,6 +1814,30 @@ NTSTATUS virtual_handle_fault( LPCVOID addr, DWORD err, BOOL on_signal_stack )
}
/***********************************************************************
* check_write_access
*
* Check if the memory range is writable, temporarily disabling write watches if necessary.
*/
static NTSTATUS check_write_access( void *base, size_t size, BOOL *has_write_watch )
{
size_t i;
char *addr = ROUND_ADDR( base, page_mask );
size = ROUND_SIZE( base, size );
for (i = 0; i < size; i += page_size)
{
BYTE vprot = get_page_vprot( addr + i );
if (vprot & VPROT_WRITEWATCH) *has_write_watch = TRUE;
if (!(VIRTUAL_GetUnixProt( vprot & ~VPROT_WRITEWATCH ) & PROT_WRITE))
return STATUS_INVALID_USER_BUFFER;
}
if (*has_write_watch)
mprotect_range( addr, size, 0, VPROT_WRITEWATCH ); /* temporarily enable write access */
return STATUS_SUCCESS;
}
/***********************************************************************
* virtual_is_valid_code_address
......@@ -1963,32 +2001,18 @@ SIZE_T virtual_uninterrupted_read_memory( const void *addr, void *buffer, SIZE_T
*/
NTSTATUS virtual_uninterrupted_write_memory( void *addr, const void *buffer, SIZE_T size )
{
struct file_view *view;
BOOL has_write_watch = FALSE;
sigset_t sigset;
NTSTATUS ret = STATUS_ACCESS_VIOLATION;
NTSTATUS ret;
if (!size) return STATUS_SUCCESS;
server_enter_uninterrupted_section( &csVirtual, &sigset );
if ((view = VIRTUAL_FindView( addr, size )) && !(view->protect & VPROT_SYSTEM))
if (!(ret = check_write_access( addr, size, &has_write_watch )))
{
char *page = ROUND_ADDR( addr, page_mask );
size_t i, total = ROUND_SIZE( addr, size );
for (i = 0; i < total; i += page_size)
{
int prot = VIRTUAL_GetUnixProt( get_page_vprot( page + i ) & ~VPROT_WRITEWATCH );
if (!(prot & PROT_WRITE)) goto done;
}
if (view->protect & VPROT_WRITEWATCH) /* enable write access by clearing write watches */
{
set_page_vprot_bits( addr, size, 0, VPROT_WRITEWATCH );
mprotect_range( addr, size, 0, 0 );
}
memcpy( addr, buffer, size );
ret = STATUS_SUCCESS;
if (has_write_watch) update_write_watches( addr, size, size );
}
done:
server_leave_uninterrupted_section( &csVirtual, &sigset );
return ret;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment