Commit abd11749 authored by Bernhard Loos's avatar Bernhard Loos Committed by Alexandre Julliard

msi: Protected primary keys against modification.

parent a52c2bf9
......@@ -809,6 +809,7 @@ extern UINT MSI_RecordSetStreamFromFileW( MSIRECORD *, UINT, LPCWSTR ) DECLSPEC_
extern UINT MSI_RecordCopyField( MSIRECORD *, UINT, MSIRECORD *, UINT ) DECLSPEC_HIDDEN;
extern MSIRECORD *MSI_CloneRecord( MSIRECORD * ) DECLSPEC_HIDDEN;
extern BOOL MSI_RecordsAreEqual( MSIRECORD *, MSIRECORD * ) DECLSPEC_HIDDEN;
extern BOOL MSI_RecordsAreFieldsEqual(MSIRECORD *a, MSIRECORD *b, UINT field) DECLSPEC_HIDDEN;
/* stream internals */
extern void enum_stream_names( IStorage *stg ) DECLSPEC_HIDDEN;
......
......@@ -994,6 +994,34 @@ MSIRECORD *MSI_CloneRecord(MSIRECORD *rec)
return clone;
}
BOOL MSI_RecordsAreFieldsEqual(MSIRECORD *a, MSIRECORD *b, UINT field)
{
if (a->fields[field].type != b->fields[field].type)
return FALSE;
switch (a->fields[field].type)
{
case MSIFIELD_NULL:
break;
case MSIFIELD_INT:
if (a->fields[field].u.iVal != b->fields[field].u.iVal)
return FALSE;
break;
case MSIFIELD_WSTR:
if (strcmpW(a->fields[field].u.szwVal, b->fields[field].u.szwVal))
return FALSE;
break;
case MSIFIELD_STREAM:
default:
return FALSE;
}
return TRUE;
}
BOOL MSI_RecordsAreEqual(MSIRECORD *a, MSIRECORD *b)
{
UINT i;
......@@ -1003,28 +1031,8 @@ BOOL MSI_RecordsAreEqual(MSIRECORD *a, MSIRECORD *b)
for (i = 0; i <= a->count; i++)
{
if (a->fields[i].type != b->fields[i].type)
if (!MSI_RecordsAreFieldsEqual( a, b, i ))
return FALSE;
switch (a->fields[i].type)
{
case MSIFIELD_NULL:
break;
case MSIFIELD_INT:
if (a->fields[i].u.iVal != b->fields[i].u.iVal)
return FALSE;
break;
case MSIFIELD_WSTR:
if (strcmpW(a->fields[i].u.szwVal, b->fields[i].u.szwVal))
return FALSE;
break;
case MSIFIELD_STREAM:
default:
return FALSE;
}
}
return TRUE;
......
......@@ -262,9 +262,10 @@ static UINT WHERE_get_row( struct tagMSIVIEW *view, UINT row, MSIRECORD **rec )
static UINT WHERE_set_row( struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UINT mask )
{
MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view;
UINT r, offset = 0, reduced_mask = 0;
UINT i, r, offset = 0;
JOINTABLE *table = wv->tables;
UINT *rows;
UINT mask_copy = mask;
TRACE("%p %d %p %08x\n", wv, row, rec, mask );
......@@ -275,28 +276,54 @@ static UINT WHERE_set_row( struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UI
if (r != ERROR_SUCCESS)
return r;
if(mask >= 1 << wv->col_count)
if (mask >= 1 << wv->col_count)
return ERROR_INVALID_PARAMETER;
do
{
for (i = 0; i < table->col_count; i++) {
UINT type;
if (!(mask_copy & (1 << i)))
continue;
r = table->view->ops->get_column_info(table->view, i + 1, NULL,
&type, NULL, NULL );
if (r != ERROR_SUCCESS)
return r;
if (type & MSITYPE_KEY)
return ERROR_FUNCTION_FAILED;
}
mask_copy >>= table->col_count;
}
while (mask_copy && (table = table->next));
table = wv->tables;
do
{
const UINT col_count = table->col_count;
UINT i;
MSIRECORD *reduced;
UINT reduced_mask = (mask >> offset) & ((1 << col_count) - 1);
if (!reduced_mask)
{
offset += col_count;
continue;
}
reduced = MSI_CreateRecord(col_count);
if (!reduced)
return ERROR_FUNCTION_FAILED;
for (i = 0; i < col_count; i++)
for (i = 1; i <= col_count; i++)
{
r = MSI_RecordCopyField(rec, i + offset + 1, reduced, i + 1);
r = MSI_RecordCopyField(rec, i + offset, reduced, i);
if (r != ERROR_SUCCESS)
break;
}
offset += col_count;
reduced_mask = mask >> (wv->col_count - offset) & ((1 << col_count) - 1);
if (r == ERROR_SUCCESS)
r = table->view->ops->set_row(table->view, rows[table->table_index], reduced, reduced_mask);
......@@ -644,13 +671,28 @@ static UINT join_find_row( MSIWHEREVIEW *wv, MSIRECORD *rec, UINT *row )
static UINT join_modify_update( struct tagMSIVIEW *view, MSIRECORD *rec )
{
MSIWHEREVIEW *wv = (MSIWHEREVIEW *)view;
UINT r, row;
UINT r, row, i, mask = 0;
MSIRECORD *current;
r = join_find_row( wv, rec, &row );
if (r != ERROR_SUCCESS)
return r;
return WHERE_set_row( view, row, rec, (1 << wv->col_count) - 1 );
r = msi_view_get_row( wv->db, view, row, &current );
if (r != ERROR_SUCCESS)
return r;
assert(MSI_RecordGetFieldCount(rec) == MSI_RecordGetFieldCount(current));
for (i = MSI_RecordGetFieldCount(rec); i > 0; i--)
{
if (!MSI_RecordsAreFieldsEqual(rec, current, i))
mask |= 1 << (i - 1);
}
msiobj_release(&current->hdr);
return WHERE_set_row( view, row, rec, mask );
}
static UINT WHERE_modify( struct tagMSIVIEW *view, MSIMODIFY eModifyMode,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment