Commit 67a213fc authored by Rob Shearman's avatar Rob Shearman Committed by Alexandre Julliard

oleaut32: Try to re-use existing memory when unmarshalling variants with byref types.

parent baccba31
......@@ -607,11 +607,15 @@ static void test_marshal_VARIANT(void)
ok(*(short*)wirev == s, "wv[6] %08x\n", *wirev);
if (VARIANT_UNMARSHAL_WORKS)
{
void *mem;
VariantInit(&v2);
V_VT(&v2) = VT_I2 | VT_BYREF;
V_BYREF(&v2) = mem = CoTaskMemAlloc(sizeof(V_I2(&v2)));
stubMsg.Buffer = buffer;
next = VARIANT_UserUnmarshal(&umcb.Flags, buffer, &v2);
ok(next == buffer + stubMsg.BufferLength, "got %p expect %p\n", next, buffer + stubMsg.BufferLength);
ok(V_VT(&v) == V_VT(&v2), "got vt %d expect %d\n", V_VT(&v), V_VT(&v2));
ok(V_BYREF(&v2) == mem, "didn't reuse existing memory\n");
ok(*V_I2REF(&v) == *V_I2REF(&v2), "got i2 ref %x expect ui4 ref %x\n", *V_I2REF(&v), *V_I2REF(&v2));
VARIANT_UserFree(&umcb.Flags, &v2);
......
......@@ -220,11 +220,11 @@ typedef struct
DWORD switch_is;
} variant_wire_t;
static unsigned int get_type_size(ULONG *pFlags, const VARIANT *pvar)
static unsigned int get_type_size(ULONG *pFlags, VARTYPE vt)
{
if (V_VT(pvar) & VT_ARRAY) return 4;
if (vt & VT_ARRAY) return 4;
switch (V_VT(pvar) & ~VT_BYREF) {
switch (vt & ~VT_BYREF) {
case VT_EMPTY:
case VT_NULL:
return 0;
......@@ -263,15 +263,15 @@ static unsigned int get_type_size(ULONG *pFlags, const VARIANT *pvar)
case VT_RECORD:
return 0;
default:
FIXME("unhandled VT %d\n", V_VT(pvar));
FIXME("unhandled VT %d\n", vt);
return 0;
}
}
static unsigned int get_type_alignment(ULONG *pFlags, const VARIANT *pvar)
static unsigned int get_type_alignment(ULONG *pFlags, VARTYPE vt)
{
unsigned int size = get_type_size(pFlags, pvar);
if(V_VT(pvar) & VT_BYREF) return 3;
unsigned int size = get_type_size(pFlags, vt);
if(vt & VT_BYREF) return 3;
if(size == 0) return 0;
if(size <= 4) return size - 1;
return 7;
......@@ -441,12 +441,12 @@ ULONG WINAPI VARIANT_UserSize(ULONG *pFlags, ULONG Start, VARIANT *pvar)
if(V_VT(pvar) & VT_BYREF)
Start += 4;
align = get_type_alignment(pFlags, pvar);
align = get_type_alignment(pFlags, V_VT(pvar));
ALIGN_LENGTH(Start, align);
if(V_VT(pvar) == (VT_VARIANT | VT_BYREF))
Start += 4;
else
Start += get_type_size(pFlags, pvar);
Start += get_type_size(pFlags, V_VT(pvar));
Start = wire_extra_user_size(pFlags, Start, pvar);
TRACE("returning %d\n", Start);
......@@ -478,8 +478,8 @@ unsigned char * WINAPI VARIANT_UserMarshal(ULONG *pFlags, unsigned char *Buffer,
header->switch_is &= ~VT_TYPEMASK;
Pos = (unsigned char*)(header + 1);
type_size = get_type_size(pFlags, pvar);
align = get_type_alignment(pFlags, pvar);
type_size = get_type_size(pFlags, V_VT(pvar));
align = get_type_alignment(pFlags, V_VT(pvar));
ALIGN_POINTER(Pos, align);
if(header->vt & VT_BYREF)
......@@ -565,32 +565,33 @@ unsigned char * WINAPI VARIANT_UserUnmarshal(ULONG *pFlags, unsigned char *Buffe
TRACE("(%x,%p,%p)\n", *pFlags, Buffer, pvar);
ALIGN_POINTER(Buffer, 7);
VariantClear(pvar);
header = (variant_wire_t *)Buffer;
pvar->n1.n2.vt = header->vt;
pvar->n1.n2.wReserved1 = header->wReserved1;
pvar->n1.n2.wReserved2 = header->wReserved2;
pvar->n1.n2.wReserved3 = header->wReserved3;
Pos = (unsigned char*)(header + 1);
type_size = get_type_size(pFlags, pvar);
align = get_type_alignment(pFlags, pvar);
type_size = get_type_size(pFlags, header->vt);
align = get_type_alignment(pFlags, header->vt);
ALIGN_POINTER(Pos, align);
if(header->vt & VT_BYREF)
{
Pos += 4;
pvar->n1.n2.n3.byref = CoTaskMemAlloc(type_size);
memcpy(pvar->n1.n2.n3.byref, Pos, type_size);
if (V_VT(pvar) != header->vt)
{
VariantClear(pvar);
V_BYREF(pvar) = CoTaskMemAlloc(type_size);
}
else if (!V_BYREF(pvar))
V_BYREF(pvar) = CoTaskMemAlloc(type_size);
memcpy(V_BYREF(pvar), Pos, type_size);
if((header->vt & VT_TYPEMASK) != VT_VARIANT)
Pos += type_size;
Pos += type_size;
else
Pos += 4;
}
else
{
VariantClear(pvar);
if((header->vt & VT_TYPEMASK) == VT_DECIMAL)
memcpy(pvar, Pos, type_size);
else
......@@ -598,6 +599,11 @@ unsigned char * WINAPI VARIANT_UserUnmarshal(ULONG *pFlags, unsigned char *Buffe
Pos += type_size;
}
pvar->n1.n2.vt = header->vt;
pvar->n1.n2.wReserved1 = header->wReserved1;
pvar->n1.n2.wReserved2 = header->wReserved2;
pvar->n1.n2.wReserved3 = header->wReserved3;
if(header->vt & VT_ARRAY)
{
if(header->vt & VT_BYREF)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment