Commit 40986cd7 authored by Huw Davies's avatar Huw Davies Committed by Alexandre Julliard

oleaut32: Correctly handle the case when the number of bytes in a BSTR is odd.

parent 86d9457c
...@@ -227,20 +227,20 @@ static void test_marshal_LPSAFEARRAY(void) ...@@ -227,20 +227,20 @@ static void test_marshal_LPSAFEARRAY(void)
static void check_bstr(void *buffer, BSTR b) static void check_bstr(void *buffer, BSTR b)
{ {
DWORD *wireb = buffer; DWORD *wireb = buffer;
DWORD len = SysStringLen(b); DWORD len = SysStringByteLen(b);
ok(*wireb == len, "wv[0] %08lx\n", *wireb); ok(*wireb == (len + 1) / 2, "wv[0] %08lx\n", *wireb);
wireb++; wireb++;
if(len) if(b)
ok(*wireb == len * 2, "wv[1] %08lx\n", *wireb); ok(*wireb == len, "wv[1] %08lx\n", *wireb);
else else
ok(*wireb == 0xffffffff, "wv[1] %08lx\n", *wireb); ok(*wireb == 0xffffffff, "wv[1] %08lx\n", *wireb);
wireb++; wireb++;
ok(*wireb == len, "wv[2] %08lx\n", *wireb); ok(*wireb == (len + 1) / 2, "wv[2] %08lx\n", *wireb);
if(len) if(len)
{ {
wireb++; wireb++;
ok(!memcmp(wireb, b, len * 2), "strings differ\n"); ok(!memcmp(wireb, b, (len + 1) & ~1), "strings differ\n");
} }
return; return;
} }
...@@ -250,7 +250,7 @@ static void test_marshal_BSTR(void) ...@@ -250,7 +250,7 @@ static void test_marshal_BSTR(void)
unsigned long size; unsigned long size;
MIDL_STUB_MESSAGE stubMsg = { 0 }; MIDL_STUB_MESSAGE stubMsg = { 0 };
USER_MARSHAL_CB umcb = { 0 }; USER_MARSHAL_CB umcb = { 0 };
unsigned char *buffer; unsigned char *buffer, *next;
BSTR b, b2; BSTR b, b2;
WCHAR str[] = {'m','a','r','s','h','a','l',' ','t','e','s','t','1',0}; WCHAR str[] = {'m','a','r','s','h','a','l',' ','t','e','s','t','1',0};
DWORD len; DWORD len;
...@@ -271,14 +271,16 @@ static void test_marshal_BSTR(void) ...@@ -271,14 +271,16 @@ static void test_marshal_BSTR(void)
ok(size == 38, "size %ld\n", size); ok(size == 38, "size %ld\n", size);
buffer = HeapAlloc(GetProcessHeap(), 0, size); buffer = HeapAlloc(GetProcessHeap(), 0, size);
BSTR_UserMarshal(&umcb.Flags, buffer, &b); next = BSTR_UserMarshal(&umcb.Flags, buffer, &b);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
check_bstr(buffer, b); check_bstr(buffer, b);
if (BSTR_UNMARSHAL_WORKS) if (BSTR_UNMARSHAL_WORKS)
{ {
b2 = NULL; b2 = NULL;
BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2); next = BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2);
ok(b2 != NULL, "NULL LPSAFEARRAY didn't unmarshal\n"); ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
ok(b2 != NULL, "BSTR didn't unmarshal\n");
ok(!memcmp(b, b2, (len + 1) * 2), "strings differ\n"); ok(!memcmp(b, b2, (len + 1) * 2), "strings differ\n");
BSTR_UserFree(&umcb.Flags, &b2); BSTR_UserFree(&umcb.Flags, &b2);
} }
...@@ -291,11 +293,75 @@ static void test_marshal_BSTR(void) ...@@ -291,11 +293,75 @@ static void test_marshal_BSTR(void)
ok(size == 12, "size %ld\n", size); ok(size == 12, "size %ld\n", size);
buffer = HeapAlloc(GetProcessHeap(), 0, size); buffer = HeapAlloc(GetProcessHeap(), 0, size);
BSTR_UserMarshal(&umcb.Flags, buffer, &b); next = BSTR_UserMarshal(&umcb.Flags, buffer, &b);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
check_bstr(buffer, b); check_bstr(buffer, b);
if (BSTR_UNMARSHAL_WORKS)
{
b2 = NULL;
next = BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
ok(b2 == NULL, "NULL BSTR didn't unmarshal\n");
BSTR_UserFree(&umcb.Flags, &b2);
}
HeapFree(GetProcessHeap(), 0, buffer); HeapFree(GetProcessHeap(), 0, buffer);
b = SysAllocStringByteLen("abc", 3);
*(((char*)b) + 3) = 'd';
len = SysStringLen(b);
ok(len == 1, "get %ld\n", len);
len = SysStringByteLen(b);
ok(len == 3, "get %ld\n", len);
size = BSTR_UserSize(&umcb.Flags, 0, &b);
ok(size == 16, "size %ld\n", size);
buffer = HeapAlloc(GetProcessHeap(), 0, size);
memset(buffer, 0xcc, size);
next = BSTR_UserMarshal(&umcb.Flags, buffer, &b);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
check_bstr(buffer, b);
ok(buffer[15] == 'd', "buffer[15] %02x\n", buffer[15]);
if (BSTR_UNMARSHAL_WORKS)
{
b2 = NULL;
next = BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
ok(b2 != NULL, "BSTR didn't unmarshal\n");
ok(!memcmp(b, b2, len), "strings differ\n");
BSTR_UserFree(&umcb.Flags, &b2);
}
HeapFree(GetProcessHeap(), 0, buffer);
SysFreeString(b);
b = SysAllocStringByteLen("", 0);
len = SysStringLen(b);
ok(len == 0, "get %ld\n", len);
len = SysStringByteLen(b);
ok(len == 0, "get %ld\n", len);
size = BSTR_UserSize(&umcb.Flags, 0, &b);
ok(size == 12, "size %ld\n", size);
buffer = HeapAlloc(GetProcessHeap(), 0, size);
next = BSTR_UserMarshal(&umcb.Flags, buffer, &b);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
check_bstr(buffer, b);
if (BSTR_UNMARSHAL_WORKS)
{
b2 = NULL;
next = BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
ok(b2 != NULL, "NULL LPSAFEARRAY didn't unmarshal\n");
len = SysStringByteLen(b2);
ok(len == 0, "byte len %ld\n", len);
BSTR_UserFree(&umcb.Flags, &b2);
}
HeapFree(GetProcessHeap(), 0, buffer);
SysFreeString(b);
} }
static void check_variant_header(DWORD *wirev, VARIANT *v, unsigned long size) static void check_variant_header(DWORD *wirev, VARIANT *v, unsigned long size)
......
...@@ -152,7 +152,7 @@ unsigned long WINAPI BSTR_UserSize(unsigned long *pFlags, unsigned long Start, B ...@@ -152,7 +152,7 @@ unsigned long WINAPI BSTR_UserSize(unsigned long *pFlags, unsigned long Start, B
TRACE("(%lx,%ld,%p) => %p\n", *pFlags, Start, pstr, *pstr); TRACE("(%lx,%ld,%p) => %p\n", *pFlags, Start, pstr, *pstr);
if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr)); if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr));
ALIGN_LENGTH(Start, 3); ALIGN_LENGTH(Start, 3);
Start += sizeof(bstr_wire_t) + sizeof(OLECHAR) * (SysStringLen(*pstr)); Start += sizeof(bstr_wire_t) + ((SysStringByteLen(*pstr) + 1) & ~1);
TRACE("returning %ld\n", Start); TRACE("returning %ld\n", Start);
return Start; return Start;
} }
...@@ -160,19 +160,21 @@ unsigned long WINAPI BSTR_UserSize(unsigned long *pFlags, unsigned long Start, B ...@@ -160,19 +160,21 @@ unsigned long WINAPI BSTR_UserSize(unsigned long *pFlags, unsigned long Start, B
unsigned char * WINAPI BSTR_UserMarshal(unsigned long *pFlags, unsigned char *Buffer, BSTR *pstr) unsigned char * WINAPI BSTR_UserMarshal(unsigned long *pFlags, unsigned char *Buffer, BSTR *pstr)
{ {
bstr_wire_t *header; bstr_wire_t *header;
DWORD len = SysStringByteLen(*pstr);
TRACE("(%lx,%p,%p) => %p\n", *pFlags, Buffer, pstr, *pstr); TRACE("(%lx,%p,%p) => %p\n", *pFlags, Buffer, pstr, *pstr);
if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr)); if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr));
ALIGN_POINTER(Buffer, 3); ALIGN_POINTER(Buffer, 3);
header = (bstr_wire_t*)Buffer; header = (bstr_wire_t*)Buffer;
header->len = header->len2 = SysStringLen(*pstr); header->len = header->len2 = (len + 1) / 2;
if (header->len) if (*pstr)
{ {
header->byte_len = header->len * sizeof(OLECHAR); header->byte_len = len;
memcpy(header + 1, *pstr, header->byte_len); memcpy(header + 1, *pstr, header->len * 2);
} }
else else
header->byte_len = 0xffffffff; /* special case for an empty string */ header->byte_len = 0xffffffff; /* special case for a null bstr */
return Buffer + sizeof(*header) + sizeof(OLECHAR) * header->len; return Buffer + sizeof(*header) + sizeof(OLECHAR) * header->len;
} }
...@@ -187,14 +189,15 @@ unsigned char * WINAPI BSTR_UserUnmarshal(unsigned long *pFlags, unsigned char * ...@@ -187,14 +189,15 @@ unsigned char * WINAPI BSTR_UserUnmarshal(unsigned long *pFlags, unsigned char *
if(header->len != header->len2) if(header->len != header->len2)
FIXME("len %08lx != len2 %08lx\n", header->len, header->len2); FIXME("len %08lx != len2 %08lx\n", header->len, header->len2);
if(header->len) if(*pstr)
SysReAllocStringLen(pstr, (OLECHAR*)(header + 1), header->len);
else if (*pstr)
{ {
SysFreeString(*pstr); SysFreeString(*pstr);
*pstr = NULL; *pstr = NULL;
} }
if(header->byte_len != 0xffffffff)
*pstr = SysAllocStringByteLen((char*)(header + 1), header->byte_len);
if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr)); if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr));
return Buffer + sizeof(*header) + sizeof(OLECHAR) * header->len; return Buffer + sizeof(*header) + sizeof(OLECHAR) * header->len;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment