Commit e653a2e0 authored by Rémi Bernon's avatar Rémi Bernon Committed by Alexandre Julliard

hidclass.sys: Validate report IDs in hid_device_xfer_report.

parent 3d9a6fe6
...@@ -303,7 +303,7 @@ static void hid_device_xfer_report( BASE_DEVICE_EXTENSION *ext, ULONG code, IRP ...@@ -303,7 +303,7 @@ static void hid_device_xfer_report( BASE_DEVICE_EXTENSION *ext, ULONG code, IRP
{ {
const WINE_HIDP_PREPARSED_DATA *preparsed = ext->u.pdo.preparsed_data; const WINE_HIDP_PREPARSED_DATA *preparsed = ext->u.pdo.preparsed_data;
IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation( irp ); IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation( irp );
BYTE report_id = HID_INPUT_VALUE_CAPS( preparsed )->report_id; struct hid_value_caps *caps = NULL, *caps_end = NULL;
ULONG report_len = 0, buffer_len = 0; ULONG report_len = 0, buffer_len = 0;
HID_XFER_PACKET packet; HID_XFER_PACKET packet;
BYTE *buffer = NULL; BYTE *buffer = NULL;
...@@ -326,13 +326,19 @@ static void hid_device_xfer_report( BASE_DEVICE_EXTENSION *ext, ULONG code, IRP ...@@ -326,13 +326,19 @@ static void hid_device_xfer_report( BASE_DEVICE_EXTENSION *ext, ULONG code, IRP
{ {
case IOCTL_HID_GET_INPUT_REPORT: case IOCTL_HID_GET_INPUT_REPORT:
report_len = preparsed->caps.InputReportByteLength; report_len = preparsed->caps.InputReportByteLength;
caps = HID_INPUT_VALUE_CAPS( preparsed );
caps_end = caps + preparsed->value_caps_count[HidP_Input];
break; break;
case IOCTL_HID_SET_OUTPUT_REPORT: case IOCTL_HID_SET_OUTPUT_REPORT:
report_len = preparsed->caps.OutputReportByteLength; report_len = preparsed->caps.OutputReportByteLength;
caps = HID_OUTPUT_VALUE_CAPS( preparsed );
caps_end = caps + preparsed->value_caps_count[HidP_Output];
break; break;
case IOCTL_HID_GET_FEATURE: case IOCTL_HID_GET_FEATURE:
case IOCTL_HID_SET_FEATURE: case IOCTL_HID_SET_FEATURE:
report_len = preparsed->caps.FeatureReportByteLength; report_len = preparsed->caps.FeatureReportByteLength;
caps = HID_FEATURE_VALUE_CAPS( preparsed );
caps_end = caps + preparsed->value_caps_count[HidP_Feature];
break; break;
} }
...@@ -347,11 +353,18 @@ static void hid_device_xfer_report( BASE_DEVICE_EXTENSION *ext, ULONG code, IRP ...@@ -347,11 +353,18 @@ static void hid_device_xfer_report( BASE_DEVICE_EXTENSION *ext, ULONG code, IRP
return; return;
} }
for (; caps != caps_end; ++caps) if (!caps->report_id || caps->report_id == buffer[0]) break;
if (caps == caps_end)
{
irp->IoStatus.Status = STATUS_INVALID_PARAMETER;
return;
}
packet.reportId = buffer[0]; packet.reportId = buffer[0];
packet.reportBuffer = buffer; packet.reportBuffer = buffer;
packet.reportBufferLen = buffer_len; packet.reportBufferLen = buffer_len;
if (!report_id) if (!caps->report_id)
{ {
packet.reportId = 0; packet.reportId = 0;
packet.reportBuffer++; packet.reportBuffer++;
......
...@@ -532,7 +532,6 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp) ...@@ -532,7 +532,6 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp)
ok(!in_size, "got input size %u\n", in_size); ok(!in_size, "got input size %u\n", in_size);
ok(out_size == sizeof(*packet), "got output size %u\n", out_size); ok(out_size == sizeof(*packet), "got output size %u\n", out_size);
todo_wine_if(packet->reportId == 0x5a)
ok(packet->reportId == report_id, "got id %u\n", packet->reportId); ok(packet->reportId == report_id, "got id %u\n", packet->reportId);
ok(packet->reportBufferLen >= expected_size, "got len %u\n", packet->reportBufferLen); ok(packet->reportBufferLen >= expected_size, "got len %u\n", packet->reportBufferLen);
ok(!!packet->reportBuffer, "got buffer %p\n", packet->reportBuffer); ok(!!packet->reportBuffer, "got buffer %p\n", packet->reportBuffer);
...@@ -552,7 +551,6 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp) ...@@ -552,7 +551,6 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp)
ok(in_size == sizeof(*packet), "got input size %u\n", in_size); ok(in_size == sizeof(*packet), "got input size %u\n", in_size);
ok(!out_size, "got output size %u\n", out_size); ok(!out_size, "got output size %u\n", out_size);
todo_wine_if(packet->reportId == 0x5a)
ok(packet->reportId == report_id, "got id %u\n", packet->reportId); ok(packet->reportId == report_id, "got id %u\n", packet->reportId);
ok(packet->reportBufferLen >= expected_size, "got len %u\n", packet->reportBufferLen); ok(packet->reportBufferLen >= expected_size, "got len %u\n", packet->reportBufferLen);
ok(!!packet->reportBuffer, "got buffer %p\n", packet->reportBuffer); ok(!!packet->reportBuffer, "got buffer %p\n", packet->reportBuffer);
...@@ -569,7 +567,6 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp) ...@@ -569,7 +567,6 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp)
ok(!in_size, "got input size %u\n", in_size); ok(!in_size, "got input size %u\n", in_size);
ok(out_size == sizeof(*packet), "got output size %u\n", out_size); ok(out_size == sizeof(*packet), "got output size %u\n", out_size);
todo_wine_if(packet->reportId == 0x5a)
ok(packet->reportId == report_id, "got id %u\n", packet->reportId); ok(packet->reportId == report_id, "got id %u\n", packet->reportId);
ok(packet->reportBufferLen >= expected_size, "got len %u\n", packet->reportBufferLen); ok(packet->reportBufferLen >= expected_size, "got len %u\n", packet->reportBufferLen);
ok(!!packet->reportBuffer, "got buffer %p\n", packet->reportBuffer); ok(!!packet->reportBuffer, "got buffer %p\n", packet->reportBuffer);
...@@ -588,7 +585,6 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp) ...@@ -588,7 +585,6 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp)
ok(in_size == sizeof(*packet), "got input size %u\n", in_size); ok(in_size == sizeof(*packet), "got input size %u\n", in_size);
ok(!out_size, "got output size %u\n", out_size); ok(!out_size, "got output size %u\n", out_size);
todo_wine_if(packet->reportId == 0x5a)
ok(packet->reportId == report_id, "got id %u\n", packet->reportId); ok(packet->reportId == report_id, "got id %u\n", packet->reportId);
ok(packet->reportBufferLen >= expected_size, "got len %u\n", packet->reportBufferLen); ok(packet->reportBufferLen >= expected_size, "got len %u\n", packet->reportBufferLen);
ok(!!packet->reportBuffer, "got buffer %p\n", packet->reportBuffer); ok(!!packet->reportBuffer, "got buffer %p\n", packet->reportBuffer);
......
...@@ -2454,9 +2454,7 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled ...@@ -2454,9 +2454,7 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled
ret = HidD_GetInputReport(file, buffer, caps.InputReportByteLength); ret = HidD_GetInputReport(file, buffer, caps.InputReportByteLength);
if (report_id || broken(!ret) /* w7u */) if (report_id || broken(!ret) /* w7u */)
{ {
todo_wine
ok(!ret, "HidD_GetInputReport succeeded, last error %u\n", GetLastError()); ok(!ret, "HidD_GetInputReport succeeded, last error %u\n", GetLastError());
todo_wine
ok(GetLastError() == ERROR_INVALID_PARAMETER || broken(GetLastError() == ERROR_CRC), ok(GetLastError() == ERROR_INVALID_PARAMETER || broken(GetLastError() == ERROR_CRC),
"HidD_GetInputReport returned error %u\n", GetLastError()); "HidD_GetInputReport returned error %u\n", GetLastError());
} }
...@@ -2499,9 +2497,7 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled ...@@ -2499,9 +2497,7 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled
ret = HidD_GetFeature(file, buffer, caps.FeatureReportByteLength); ret = HidD_GetFeature(file, buffer, caps.FeatureReportByteLength);
if (report_id || broken(!ret)) if (report_id || broken(!ret))
{ {
todo_wine
ok(!ret, "HidD_GetFeature succeeded, last error %u\n", GetLastError()); ok(!ret, "HidD_GetFeature succeeded, last error %u\n", GetLastError());
todo_wine
ok(GetLastError() == ERROR_INVALID_PARAMETER || broken(GetLastError() == ERROR_CRC), ok(GetLastError() == ERROR_INVALID_PARAMETER || broken(GetLastError() == ERROR_CRC),
"HidD_GetFeature returned error %u\n", GetLastError()); "HidD_GetFeature returned error %u\n", GetLastError());
} }
...@@ -2544,9 +2540,7 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled ...@@ -2544,9 +2540,7 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled
ret = HidD_SetFeature(file, buffer, caps.FeatureReportByteLength); ret = HidD_SetFeature(file, buffer, caps.FeatureReportByteLength);
if (report_id || broken(!ret)) if (report_id || broken(!ret))
{ {
todo_wine
ok(!ret, "HidD_SetFeature succeeded, last error %u\n", GetLastError()); ok(!ret, "HidD_SetFeature succeeded, last error %u\n", GetLastError());
todo_wine
ok(GetLastError() == ERROR_INVALID_PARAMETER || broken(GetLastError() == ERROR_CRC), ok(GetLastError() == ERROR_INVALID_PARAMETER || broken(GetLastError() == ERROR_CRC),
"HidD_SetFeature returned error %u\n", GetLastError()); "HidD_SetFeature returned error %u\n", GetLastError());
} }
...@@ -2593,9 +2587,7 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled ...@@ -2593,9 +2587,7 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled
ret = HidD_SetOutputReport(file, buffer, caps.OutputReportByteLength); ret = HidD_SetOutputReport(file, buffer, caps.OutputReportByteLength);
if (report_id || broken(!ret)) if (report_id || broken(!ret))
{ {
todo_wine
ok(!ret, "HidD_SetOutputReport succeeded, last error %u\n", GetLastError()); ok(!ret, "HidD_SetOutputReport succeeded, last error %u\n", GetLastError());
todo_wine
ok(GetLastError() == ERROR_INVALID_PARAMETER || broken(GetLastError() == ERROR_CRC), ok(GetLastError() == ERROR_INVALID_PARAMETER || broken(GetLastError() == ERROR_CRC),
"HidD_SetOutputReport returned error %u\n", GetLastError()); "HidD_SetOutputReport returned error %u\n", GetLastError());
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment