Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Summarise function to summarise results -- with more flexibilitythan previous utility function #1457

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 62 additions & 15 deletions src/tlo/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import Mapping
from pathlib import Path
from types import MappingProxyType
from typing import Callable, Dict, Iterable, List, Optional, TextIO, Tuple, Union
from typing import Callable, Dict, Iterable, List, Optional, TextIO, Tuple, Union, Literal

import git
import matplotlib.colors as mcolors
Expand Down Expand Up @@ -306,43 +306,90 @@ def generate_series(dataframe: pd.DataFrame) -> pd.Series:
return _concat


def summarize(results: pd.DataFrame, only_mean: bool = False, collapse_columns: bool = False) -> pd.DataFrame:
def summarise(
results: pd.DataFrame,
central_measure: Literal["mean", "median"] = "median",
width_of_range: float = 0.95,
only_central: bool = False,
collapse_columns: bool = False,
) -> pd.DataFrame:
"""Utility function to compute summary statistics

Finds mean value and 95% interval across the runs for each draw.
Finds a central value and a specified interval across the runs for each draw. By default, this uses a central
measure of the median and a 95% interval range.

:Param: results: The pd.DataFame of results.
:Param: central_measure: The name of the central measure to use - either 'mean' or 'median'.
:Param: width_of_range: The width of the range to compute the statistics (e.g. 0.95 for the 95% interval).
:Param: collapse_columns: Whether to simplify the columnar index if there is only one run (cannot be done otherwise)
:Param: only_central: Whether to only report the central value (dropping the range).
Comment on lines +321 to +325
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
:Param: results: The pd.DataFame of results.
:Param: central_measure: The name of the central measure to use - either 'mean' or 'median'.
:Param: width_of_range: The width of the range to compute the statistics (e.g. 0.95 for the 95% interval).
:Param: collapse_columns: Whether to simplify the columnar index if there is only one run (cannot be done otherwise)
:Param: only_central: Whether to only report the central value (dropping the range).
:param results: The dataframe of results to compute summary statistics of.
:param central_measure: The name of the central measure to use - either 'mean' or 'median'.
:param width_of_range: The width of the range to compute the statistics (e.g. 0.95 for the 95% interval).
:param collapse_columns: Whether to simplify the columnar index if there is only one run (cannot be done otherwise).
:param only_central: Whether to only report the central value (dropping the range).
:return: A dataframe with computed summary statistics.

Small update to fix parameter directive syntax in docstring and adding return information.


"""
stats = dict()

if central_measure == 'mean':
stats.update({'central': results.groupby(axis=1, by='draw', sort=False).mean()})
elif central_measure == 'median':
stats.update({'central': results.groupby(axis=1, by='draw', sort=False).median()})
Comment on lines +330 to +333
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if central_measure == 'mean':
stats.update({'central': results.groupby(axis=1, by='draw', sort=False).mean()})
elif central_measure == 'median':
stats.update({'central': results.groupby(axis=1, by='draw', sort=False).median()})
if central_measure == 'mean':
stats['central'] = results.groupby(axis=1, by='draw', sort=False).mean()
elif central_measure == 'median':
stats['central'] = results.groupby(axis=1, by='draw', sort=False).median()

Indexed assignment to a dict is generally preferable over update when just adding a key-value pair as it avoids creating an unecessary intermediate dictionary and is slightly more readable.

We could also avoid the repetition across the different conditions by changing to something like

    if central_measure in ('mean', 'median'):
        grouped_results = results.groupby(axis=1, by='draw', sort=False)
        stats['central'] = getattr(grouped_results, central_measure)()

but I think on balance probably the loss of readability outweighs the slight gain in avoiding code redundancy.

else:
raise ValueError(f"Unknown stat: {central_measure}")

summary = pd.concat(
stats.update(
{
'mean': results.groupby(axis=1, by='draw', sort=False).mean(),
'lower': results.groupby(axis=1, by='draw', sort=False).quantile(0.025),
'upper': results.groupby(axis=1, by='draw', sort=False).quantile(0.975),
},
axis=1
'lower': results.groupby(axis=1, by='draw', sort=False).quantile((1.-width_of_range)/2.),
'upper': results.groupby(axis=1, by='draw', sort=False).quantile(1.-(1.-width_of_range)/2.),
}
)
Comment on lines +337 to 342
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid computing the groupby operation multiple times here and when computing central measure, it would be better to have one line something like

grouped_results = results.groupby(axis=1, by='draw', sort=False)

at beginning of function and then using grouped_results in place of repeated results.groupby(axis=1, by='draw', sort=False) calls.

I would possibly also say writing this as two separate indexed assignments to the stats dict rather than using update would be a bit more readable

lower_quantile = (1 - width_of_range) / 2
stats["lower"] = grouped_results.quantile(lower_quantile)
stats["upper"] = grouped_results.quantile(1 - lower_quantile)

but using dict.update here to update dictionary with two entries is reasonable so I think this is more of a personal preference thing!


summary = pd.concat(stats, axis=1)
summary.columns = summary.columns.swaplevel(1, 0)
summary.columns.names = ['draw', 'stat']
summary = summary.sort_index(axis=1)
summary = summary.sort_index(axis=1).reindex(columns=['lower', 'central', 'upper'], level=1)

if only_mean and (not collapse_columns):
if only_central and (not collapse_columns):
# Remove other metrics and simplify if 'only_mean' across runs for each draw is required:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Remove other metrics and simplify if 'only_mean' across runs for each draw is required:
# Remove other metrics and simplify if 'only_central' across runs for each draw is required:

om: pd.DataFrame = summary.loc[:, (slice(None), "mean")]
om: pd.DataFrame = summary.loc[:, (slice(None), "central")]
om.columns = [c[0] for c in om.columns.to_flat_index()]
om.columns.name = 'draw'
return om
Comment on lines +351 to 354
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if om was short for only_mean then we have to update this to oc for only_central

Suggested change
om: pd.DataFrame = summary.loc[:, (slice(None), "central")]
om.columns = [c[0] for c in om.columns.to_flat_index()]
om.columns.name = 'draw'
return om
oc: pd.DataFrame = summary.loc[:, (slice(None), "central")]
oc.columns = [c[0] for c in oc.columns.to_flat_index()]
oc.columns.name = 'draw'
return oc


elif collapse_columns and (len(summary.columns.levels[0]) == 1):
# With 'collapse_columns', if number of draws is 1, then collapse columns multi-index:
summary_droppedlevel = summary.droplevel('draw', axis=1)
if only_mean:
return summary_droppedlevel['mean']
if only_central:
return summary_droppedlevel['central']
else:
return summary_droppedlevel

else:
return summary


def summarize(
Copy link
Collaborator

@mnjowe mnjowe Oct 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

function name summarise and summarize looks confusing. can't we suggest a better name for this function? If we don't want to interfere with how this is used outside utils we can rename summarise?

also, do we really need this function? I think this is just calling/copying summarise function with an argument mean instead of the default median. Can't we have summarise default to mean on central_measure and delete this function? in that case summarise will be renamed to summarize

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @mnjowe that having two different but similar functions that differ only the in whether they use an ize or ise is likely to be confusing.

I suspect the reason you've kept summarize with the original default behaviour of computing mean and same call signature as previously @tbhallett is to ensure this doesn't break code using the previous version?

There are few alternative ways we could deal with this:

  • Just making a breaking change and removing the old function and using summarize for new function (potentially with a helpful error message if users try to call with previous signature).
  • Using a more differentiating name for new function - perhaps something like compute_summary_statistics?
  • Changing default behaviour of new function to replicate old behaviour and using same summarize name, but raising deprecation warning if relying on old defaults. This would require doing somethin like defaulting tocentral_measure="mean", adding a **kwargs argument to new function, checking if only_mean is in kwargs and using to set value of only_central as well as raising a deprecation warning to indicate that this argument is deprecated and users should use use_central instead, and also dealing with the adjustment to column name from central to mean in this case. Overall this is probably a bit complex.

results: pd.DataFrame,
only_mean: bool = False,
collapse_columns: bool = False
):
"""Utility function to compute summary statistics

Finds mean value and 95% interval across the runs for each draw.

NOTE: This provides the legacy functionality of `summarize` that is hard-wired to use `means` (the kwarg is
`only_mean` and the name of the column in the output is `mean`). Please move to using the new and more flexible
version of `summarize` that allows the use of medians and is flexible to allow other forms of summary measure in
the future.
"""
output = summarise(
results=results,
central_measure='mean',
only_central=only_mean,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

passing only_mean to only_central here looks confusing. can't we use one name for consistency?

collapse_columns=collapse_columns,
)
if output.columns.nlevels > 1:
output = output.rename(columns={'central': 'mean'}, level=1) # rename 'central' to 'mean'
return output
Comment on lines +388 to +390
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't we do this in summarise function, delete this function and rename summarise to summarize?



def get_grid(params: pd.DataFrame, res: pd.Series):
"""Utility function to create the arrays needed to plot a heatmap.

Expand Down Expand Up @@ -1129,7 +1176,7 @@ def get_parameters_for_status_quo() -> Dict:
"equip_availability": "all", # <--- NB. Existing calibration is assuming all equipment is available
},
}

def get_parameters_for_standard_mode2_runs() -> Dict:
"""
Returns a dictionary of parameters and their updated values to indicate
Expand Down
28 changes: 21 additions & 7 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
order_of_coarse_appt,
order_of_short_treatment_ids,
parse_log_file,
summarise,
summarize,
unflatten_flattened_multi_index_in_logging,
)
Expand Down Expand Up @@ -571,7 +572,7 @@ def check_parameters(self) -> None:
sim.simulate(end_date=Date(year_of_change + 2, 1, 1))


def test_summarize():
def test_summarise():
"""Check that the summarize utility function works as expected."""

results_multiple_draws = pd.DataFrame(
Expand Down Expand Up @@ -602,10 +603,10 @@ def test_summarize():
columns=pd.MultiIndex.from_tuples(
[
("DrawA", "lower"),
("DrawA", "mean"),
("DrawA", "central"),
("DrawA", "upper"),
("DrawB", "lower"),
("DrawB", "mean"),
("DrawB", "central"),
("DrawB", "upper"),
],
names=("draw", "stat"),
Expand All @@ -618,7 +619,7 @@ def test_summarize():
]
),
),
summarize(results_multiple_draws),
summarise(results_multiple_draws, central_measure='mean'),
)

# Without collapsing and only mean
Expand All @@ -628,19 +629,32 @@ def test_summarize():
index=["TimePoint0", "TimePoint1"],
data=np.array([[10.0, 1500.0], [10.0, 1500.0]]),
),
summarize(results_multiple_draws, only_mean=True),
summarise(results_multiple_draws, central_measure='mean', only_central=True),
)

# With collapsing (as only one draw)
pd.testing.assert_frame_equal(
pd.DataFrame(
columns=pd.Index(["lower", "mean", "upper"], name="stat"),
columns=pd.Index(["lower", "central", "upper"], name="stat"),
index=["TimePoint0", "TimePoint1"],
data=np.array([[0.5, 10.0, 19.5], [0.5, 10.0, 19.5], ]),
),
summarize(results_one_draw, collapse_columns=True),
summarise(results_one_draw, central_measure='mean', collapse_columns=True),
)

# Check that summarize() produces legacy behaviour:
pd.testing.assert_frame_equal(
summarise(results_multiple_draws, central_measure='mean').rename(columns={'central': 'mean'}, level=1),
summarize(results_multiple_draws)
)
pd.testing.assert_frame_equal(
summarise(results_multiple_draws, central_measure='mean', only_central=True),
summarize(results_multiple_draws, only_mean=True)
)
pd.testing.assert_frame_equal(
summarise(results_one_draw, central_measure='mean', collapse_columns=True),
summarize(results_one_draw, collapse_columns=True)
)

def test_control_loggers_from_same_module_independently(seed, tmpdir):
"""Check that detailed/summary loggers in the same module can configured independently."""
Expand Down
Loading