Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions bluepyparallel/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,15 @@ def _evaluate_dataframe(


def _evaluate_basic(
to_evaluate, input_cols, evaluation_function, func_args, func_kwargs, mapper, task_ids, db
to_evaluate,
input_cols,
evaluation_function,
func_args,
func_kwargs,
mapper,
task_ids,
db,
progress_bar=True,
):
res = []
# Setup the function to apply to the data
Expand All @@ -109,8 +117,11 @@ def _evaluate_basic(
arg_list = list(to_evaluate.loc[task_ids, input_cols].to_dict("index").items())

try:
tasks = mapper(eval_func, arg_list)
if progress_bar:
tasks = tqdm(tasks, total=len(task_ids))
# Compute and collect the results
for task_id, result, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)):
for task_id, result, exception in tasks:
res.append(dict({"df_index": task_id, "exception": exception}, **result))

# Save the results into the DB
Expand Down Expand Up @@ -163,6 +174,7 @@ def evaluate(
func_args=None,
func_kwargs=None,
shuffle_rows=True,
progress_bar=True,
**mapper_kwargs,
):
"""Evaluate and save results in a sqlite database on the fly and return dataframe.
Expand All @@ -185,12 +197,14 @@ def evaluate(
func_args (list): the arguments to pass to the evaluation_function.
func_kwargs (dict): the keyword arguments to pass to the evaluation_function.
shuffle_rows (bool): if :obj:`True`, it will shuffle the rows before computing the results.
progress_bar (bool): if :obj:`True`, a progress bar will be displayed during computation.
**mapper_kwargs: the keyword arguments are passed to the get_mapper() method of the
:class:`ParallelFactory` instance.

Return:
pandas.DataFrame: dataframe with new columns containing the computed results.
"""
# pylint: disable=too-many-branches
# Initialize the parallel factory
if isinstance(parallel_factory, str) or parallel_factory is None:
parallel_factory = init_parallel_factory(parallel_factory)
Expand Down Expand Up @@ -243,6 +257,8 @@ def evaluate(
return to_evaluate

# Get the factory mapper
if isinstance(parallel_factory, DaskDataFrameFactory):
mapper_kwargs["progress_bar"] = progress_bar
mapper = parallel_factory.get_mapper(**mapper_kwargs)

if isinstance(parallel_factory, DaskDataFrameFactory):
Expand All @@ -267,6 +283,7 @@ def evaluate(
mapper,
task_ids,
db,
progress_bar,
)
to_evaluate.loc[res_df.index, res_df.columns] = res_df

Expand Down
10 changes: 8 additions & 2 deletions bluepyparallel/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,13 @@ def _with_batches(self, *args, **kwargs):
yield tmp

def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
"""Get a Dask mapper."""
"""Get a Dask mapper.

If ``progress_bar=True`` is passed as keyword argument, a progress bar will be displayed
during computation.
"""
self._chunksize_to_kwargs(chunk_size, kwargs, label="chunksize")
progress_bar = kwargs.pop("progress_bar", True)
if not kwargs.get("chunksize"):
kwargs["npartitions"] = self.nb_processes or 1

Expand All @@ -389,7 +394,8 @@ def _dask_df_mapper(func, iterable):
df = pd.DataFrame(iterable)
ddf = dd.from_pandas(df, **kwargs)
future = ddf.apply(func, meta=meta, axis=1).persist()
progress(future)
if progress_bar:
progress(future)
# Put into a list because of the 'yield from' in _with_batches
return [future.compute()]

Expand Down
6 changes: 5 additions & 1 deletion tests/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,18 @@ class TestEvaluate:
"""Test the ``bluepyparallel.evaluator.evaluate`` function."""

@pytest.mark.parametrize("with_sql", [True, False])
def test_evaluate(self, input_df, new_columns, expected_df, db_url, with_sql, parallel_factory):
@pytest.mark.parametrize("progress_bar", [True, False])
def test_evaluate(
self, input_df, new_columns, expected_df, db_url, with_sql, progress_bar, parallel_factory
):
"""Test evaluator on a trivial example."""
result_df = evaluate(
input_df,
_evaluation_function,
new_columns,
parallel_factory=parallel_factory,
db_url=db_url if with_sql else None,
progress_bar=progress_bar,
)
if not with_sql:
remove_sql_cols(expected_df)
Expand Down