Commit c06167b1 authored by Sebastian Lackner's avatar Sebastian Lackner Committed by Alexandre Julliard

vcomp: Fix handling of _vcomp_fork with ifval == FALSE.

Forks with ifval == FALSE do not count as "parallel", so nested forks are still allowed. Please note that calling _vcomp_fork(FALSE, ...) or directly calling the callback is still different in some aspects, the synchronization functions run in a different context for example.
parent 889eba36
...@@ -54,6 +54,7 @@ struct vcomp_thread_data ...@@ -54,6 +54,7 @@ struct vcomp_thread_data
struct vcomp_team_data *team; struct vcomp_team_data *team;
struct vcomp_task_data *task; struct vcomp_task_data *task;
int thread_num; int thread_num;
BOOL parallel;
int fork_threads; int fork_threads;
/* only used for concurrent tasks */ /* only used for concurrent tasks */
...@@ -203,6 +204,7 @@ static struct vcomp_thread_data *vcomp_init_thread_data(void) ...@@ -203,6 +204,7 @@ static struct vcomp_thread_data *vcomp_init_thread_data(void)
thread_data->team = NULL; thread_data->team = NULL;
thread_data->task = &data->task; thread_data->task = &data->task;
thread_data->thread_num = 0; thread_data->thread_num = 0;
thread_data->parallel = FALSE;
thread_data->fork_threads = 0; thread_data->fork_threads = 0;
thread_data->section = 1; thread_data->section = 1;
...@@ -410,10 +412,11 @@ void WINAPIV _vcomp_fork(BOOL ifval, int nargs, void *wrapper, ...) ...@@ -410,10 +412,11 @@ void WINAPIV _vcomp_fork(BOOL ifval, int nargs, void *wrapper, ...)
TRACE("(%d, %d, %p, ...)\n", ifval, nargs, wrapper); TRACE("(%d, %d, %p, ...)\n", ifval, nargs, wrapper);
if (prev_thread_data->parallel && !vcomp_nested_fork)
ifval = FALSE;
if (!ifval) if (!ifval)
num_threads = 1; num_threads = 1;
else if (prev_thread_data->team && !vcomp_nested_fork)
num_threads = 1;
else if (prev_thread_data->fork_threads) else if (prev_thread_data->fork_threads)
num_threads = prev_thread_data->fork_threads; num_threads = prev_thread_data->fork_threads;
else else
...@@ -433,6 +436,7 @@ void WINAPIV _vcomp_fork(BOOL ifval, int nargs, void *wrapper, ...) ...@@ -433,6 +436,7 @@ void WINAPIV _vcomp_fork(BOOL ifval, int nargs, void *wrapper, ...)
thread_data.team = &team_data; thread_data.team = &team_data;
thread_data.task = &task_data; thread_data.task = &task_data;
thread_data.thread_num = 0; thread_data.thread_num = 0;
thread_data.parallel = ifval || prev_thread_data->parallel;
thread_data.fork_threads = 0; thread_data.fork_threads = 0;
thread_data.section = 1; thread_data.section = 1;
list_init(&thread_data.entry); list_init(&thread_data.entry);
...@@ -450,6 +454,7 @@ void WINAPIV _vcomp_fork(BOOL ifval, int nargs, void *wrapper, ...) ...@@ -450,6 +454,7 @@ void WINAPIV _vcomp_fork(BOOL ifval, int nargs, void *wrapper, ...)
data->team = &team_data; data->team = &team_data;
data->task = &task_data; data->task = &task_data;
data->thread_num = team_data.num_threads++; data->thread_num = team_data.num_threads++;
data->parallel = thread_data.parallel;
data->fork_threads = 0; data->fork_threads = 0;
data->section = 1; data->section = 1;
list_remove(&data->entry); list_remove(&data->entry);
...@@ -470,6 +475,7 @@ void WINAPIV _vcomp_fork(BOOL ifval, int nargs, void *wrapper, ...) ...@@ -470,6 +475,7 @@ void WINAPIV _vcomp_fork(BOOL ifval, int nargs, void *wrapper, ...)
data->team = &team_data; data->team = &team_data;
data->task = &task_data; data->task = &task_data;
data->thread_num = team_data.num_threads; data->thread_num = team_data.num_threads;
data->parallel = thread_data.parallel;
data->fork_threads = 0; data->fork_threads = 0;
data->section = 1; data->section = 1;
InitializeConditionVariable(&data->cond); InitializeConditionVariable(&data->cond);
......
...@@ -208,17 +208,21 @@ static void CDECL num_threads_cb(BOOL nested, int nested_threads, LONG *count) ...@@ -208,17 +208,21 @@ static void CDECL num_threads_cb(BOOL nested, int nested_threads, LONG *count)
thread_count = 0; thread_count = 0;
p_vcomp_fork(TRUE, 1, num_threads_cb2, &thread_count); p_vcomp_fork(TRUE, 1, num_threads_cb2, &thread_count);
if (nested) if (nested)
ok(thread_count == nested_threads, "expected %d thread, got %d\n", nested_threads, thread_count); ok(thread_count == nested_threads, "expected %d threads, got %d\n", nested_threads, thread_count);
else else
ok(thread_count == 1, "expected 1 thread, got %d\n", thread_count); ok(thread_count == 1, "expected 1 thread, got %d\n", thread_count);
thread_count = 0;
p_vcomp_fork(FALSE, 1, num_threads_cb2, &thread_count);
ok(thread_count == 1, "expected 1 thread, got %d\n", thread_count);
p_vcomp_set_num_threads(4); p_vcomp_set_num_threads(4);
thread_count = 0; thread_count = 0;
p_vcomp_fork(TRUE, 1, num_threads_cb2, &thread_count); p_vcomp_fork(TRUE, 1, num_threads_cb2, &thread_count);
if (nested) if (nested)
ok(thread_count == 4 , "expected 4 thread, got %d\n", thread_count); ok(thread_count == 4, "expected 4 threads, got %d\n", thread_count);
else else
ok(thread_count == 1 , "expected 1 thread, got %d\n", thread_count); ok(thread_count == 1, "expected 1 thread, got %d\n", thread_count);
} }
static void test_omp_get_num_threads(BOOL nested) static void test_omp_get_num_threads(BOOL nested)
...@@ -241,6 +245,19 @@ static void test_omp_get_num_threads(BOOL nested) ...@@ -241,6 +245,19 @@ static void test_omp_get_num_threads(BOOL nested)
p_vcomp_fork(TRUE, 3, num_threads_cb, nested, max_threads, &thread_count); p_vcomp_fork(TRUE, 3, num_threads_cb, nested, max_threads, &thread_count);
ok(thread_count == max_threads, "expected %d threads, got %d\n", max_threads, thread_count); ok(thread_count == max_threads, "expected %d threads, got %d\n", max_threads, thread_count);
num_threads = pomp_get_num_threads();
ok(num_threads == 1, "expected num_threads == 1, got %d\n", num_threads);
thread_count = 0;
p_vcomp_fork(FALSE, 3, num_threads_cb, TRUE, max_threads, &thread_count);
ok(thread_count == 1, "expected 1 thread, got %d\n", thread_count);
pomp_set_num_threads(1);
num_threads = pomp_get_num_threads();
ok(num_threads == 1, "expected num_threads == 1, got %d\n", num_threads);
thread_count = 0;
p_vcomp_fork(TRUE, 3, num_threads_cb, nested, 1, &thread_count);
ok(thread_count == 1, "expected 1 thread, got %d\n", thread_count);
pomp_set_num_threads(2); pomp_set_num_threads(2);
num_threads = pomp_get_num_threads(); num_threads = pomp_get_num_threads();
ok(num_threads == 1, "expected num_threads == 1, got %d\n", num_threads); ok(num_threads == 1, "expected num_threads == 1, got %d\n", num_threads);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment