idata to plotting data

Ari Hartikainen edited this page Aug 21, 2020 · 2 revisions

Here are some notes about the plotting interface.

To have a consistent interface for (realtime) plotting, we should have a good idea of how to do the following:

  • common data structure for data --> plots
  • possibility to update data "easily"

This can be done by translating idata --> pandas.DataFrame (wide format).

  • On the Bokeh side, the dataframe can be inserted into a ColumnDataSource (CDS) object.
  • On the matplotlib side, the dataframe can be used directly as a data source.
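As a minimal sketch of the wide layout and of an "easy" update, the snippet below builds a small frame with one row per (chain, draw) pair and then appends a batch of new draws. The variable names (`mu`, `tau`) and the helpers `make_wide_frame` and `append_draws` are hypothetical and used only for illustration.

```python
import numpy as np
import pandas as pd


def make_wide_frame(samples):
    """Flatten a dict of (chain, draw) arrays into a wide DataFrame."""
    first = next(iter(samples.values()))
    n_chain, n_draw = first.shape[:2]
    data = {
        "chain": np.repeat(np.arange(n_chain), n_draw),
        "draw": np.tile(np.arange(n_draw), n_chain),
    }
    for name, array in samples.items():
        data[name] = np.asarray(array).reshape(-1)
    return pd.DataFrame(data)


def append_draws(df, new_df):
    """Return a new frame with the rows of new_df added after df."""
    return pd.concat([df, new_df], ignore_index=True)


rng = np.random.default_rng(0)
df = make_wide_frame({"mu": rng.normal(size=(2, 3)), "tau": rng.normal(size=(2, 3))})

# One new draw (draw index 3) for each of the two chains.
new = pd.DataFrame({"chain": [0, 1], "draw": [3, 3], "mu": [0.1, -0.2], "tau": [1.0, 0.9]})
df = append_draws(df, new)
```

On the Bokeh side, `ColumnDataSource.stream` is intended for this kind of append; how the batch has to be shaped for it depends on how the CDS was created.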

Some plots can be updated with raw data, some need a bit more processing, and some need a lot of processing. DataFrame provides GroupBy functionality, which can be used for grouped plotting (e.g. by chain).
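A short sketch of by-chain grouping on a wide frame, assuming a `chain` column and a single illustrative variable `mu` (both names are placeholders):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
n_chain, n_draw = 4, 50
df = pd.DataFrame(
    {
        "chain": np.repeat(np.arange(n_chain), n_draw),
        "draw": np.tile(np.arange(n_draw), n_chain),
        "mu": rng.normal(size=n_chain * n_draw),
    }
)

# Each group is the sub-frame for one chain; a plotting function could draw
# one line (or one density) per group.
per_chain = {chain: group["mu"].to_numpy() for chain, group in df.groupby("chain")}

# Summary statistics per chain come from the same GroupBy machinery.
chain_means = df.groupby("chain")["mu"].mean()
```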

InferenceData to (wide) DataFrame

Here is some code to translate idata to a dataframe. It is still a bit brittle and needs some way to access coordinate information (maybe a class could be used).

import numpy as np
import pandas as pd

def idata_to_dataframe(idata):
    """Convert InferenceData to one wide DataFrame per group.

    Skips data groups (observed_data, constant_data).

    Returns a dict of DataFrames keyed by group name, and a dict mapping
    (group, column) to a list of (dim, coord_value) pairs.
    """
    idict = idata.to_dict()
    dims = idict["dims"]
    coords = idict["coords"]
    idfs = {}
    coords_info = {}
    for group, values in idict.items():
        if group in ("dims", "coords"):
            continue
        if group in ("observed_data", "constant_data", "attrs"):
            continue
        group_dict = {}
        chain = draw = None
        for variable, array in values.items():
            if chain is None:
                # Build the chain/draw index columns once per group.
                chain, draw = array.shape[:2]
                group_dict["chain"] = np.repeat(np.arange(chain), draw)
                group_dict["draw"] = np.tile(np.arange(draw), chain)
            if not array.shape[2:]:
                # Scalar variable: a single column.
                group_dict[variable] = array.ravel()
                continue
            # Non-scalar variable: one column per element, e.g. "theta[0,1]".
            for idx in np.ndindex(array.shape[2:]):
                key = f"{variable}[{','.join(map(str, idx))}]"
                idx_tuple = (slice(None), slice(None)) + idx
                group_dict[key] = array[idx_tuple].ravel()

                # Record the coordinate value for every dimension of this
                # column (a single tuple per column would keep only the last).
                for i, dim in enumerate(dims.get(variable, [])):
                    if dim in coords:
                        info = coords_info.setdefault((group, key), [])
                        info.append((dim, coords[dim][idx[i]]))
        group_df = pd.DataFrame.from_dict(group_dict, orient="columns")
        idfs[group] = group_df
    return idfs, coords_info
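The coordinate lookup mentioned above could be wrapped in a small class. The sketch below is one hypothetical design, not part of ArviZ: `ColumnIndex` and `parse_column` are illustrative names, and it assumes column names follow the `variable[i,j]` pattern produced by `idata_to_dataframe`.

```python
import re
from dataclasses import dataclass, field

# Matches names such as "theta[0,1]"; scalar columns such as "mu" do not match.
_COLUMN_RE = re.compile(r"^(?P<name>.+?)\[(?P<idx>\d+(?:,\d+)*)\]$")


def parse_column(column):
    """Split "theta[0,1]" into ("theta", (0, 1)); scalars give (name, ())."""
    match = _COLUMN_RE.match(column)
    if match is None:
        return column, ()
    idx = tuple(int(i) for i in match.group("idx").split(","))
    return match.group("name"), idx


@dataclass
class ColumnIndex:
    """Hypothetical helper mapping wide column names to coordinate labels."""

    dims: dict = field(default_factory=dict)    # variable -> list of dim names
    coords: dict = field(default_factory=dict)  # dim name -> coordinate values

    def labels(self, column):
        """Return {dim: coordinate value} for one wide-format column."""
        name, idx = parse_column(column)
        out = {}
        for dim, i in zip(self.dims.get(name, []), idx):
            if dim in self.coords:
                out[dim] = self.coords[dim][i]
        return out


# Illustrative dims/coords only.
index = ColumnIndex(dims={"theta": ["school"]}, coords={"school": ["A", "B"]})
label = index.labels("theta[1]")
```

Keeping the lookup separate from the DataFrame means the frame itself stays a plain table that Bokeh or matplotlib can consume, while axis labels are resolved on demand.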

Plot groups

Here is a rough idea of what kind of processing the different plots need.

Raw (no processing; only grouping if needed)

pair
parallel

Density (kde/hist)

density
distcomparison
dist
energy
forest
hdi
joint
kde
posterior
ppc
trace
violin
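For the plots in this group, each update means recomputing a density estimate from the raw draws. Below is a minimal sketch using a histogram (a KDE would follow the same pattern), assuming a wide frame with a `chain` column and one variable column. `histogram_by_chain` is a hypothetical helper.

```python
import numpy as np
import pandas as pd


def histogram_by_chain(df, column, bins=20):
    """Return shared bin edges and per-chain density heights for one column."""
    values = df[column].to_numpy()
    # Shared edges keep the per-chain histograms comparable.
    edges = np.histogram_bin_edges(values, bins=bins)
    heights = {}
    for chain, group in df.groupby("chain"):
        heights[chain], _ = np.histogram(group[column].to_numpy(), bins=edges, density=True)
    return edges, heights


rng = np.random.default_rng(2)
df = pd.DataFrame(
    {
        "chain": np.repeat(np.arange(2), 500),
        "mu": rng.normal(size=1000),
    }
)
edges, heights = histogram_by_chain(df, "mu", bins=10)
```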

autocorr (autocorrelation function is called; needs an update for each new value)

autocorr
ess
mcse
rank
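The plots in this group depend on the autocorrelation of each chain, which has to be recomputed when new draws arrive. A minimal NumPy sketch of an FFT-based autocorrelation estimate follows. It is an illustration, not ArviZ's own implementation, and `autocorr_1d` is a hypothetical name.

```python
import numpy as np


def autocorr_1d(x):
    """Normalized autocorrelation of a 1-D series for lags 0..len(x)-1."""
    x = np.asarray(x, dtype=float)
    n = x.size
    centered = x - x.mean()
    # Zero-pad to 2n so the circular FFT convolution does not wrap around.
    spectrum = np.fft.rfft(centered, n=2 * n)
    acov = np.fft.irfft(spectrum * np.conj(spectrum), n=2 * n)[:n]
    return acov / acov[0]


rng = np.random.default_rng(3)
chains = rng.normal(size=(4, 200))  # shape (chain, draw)
rho = np.stack([autocorr_1d(chain) for chain in chains])
```

Because the whole series enters the estimate, appending a draw changes every lag, which is why these plots cannot be updated from the new raw values alone.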

loo / stats (needs recalculation for each new value)

compare
bpv
elpd
khat
loopit

Plot parts

The same plots, categorized by the parts they are built from:

  • raw --> scatter, rug
  • density --> kde / hist
  • autocorr
  • loo / stats

  • pair: raw, density, density[contour]
  • parallel: raw
  • density: density
  • distcomparison: density
  • dist: density
  • energy: density
  • forest: density, density[boxplot]
  • hdi: density
  • joint: raw, density
  • kde: density
  • posterior: density
  • ppc: density
  • trace: raw, density, autocorr
  • violin: density
  • autocorr: autocorr
  • ess: autocorr
  • mcse: autocorr
  • rank: autocorr
  • compare: loo
  • bpv: stats
  • elpd: loo
  • khat: loo
  • loopit: loo
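The plot-to-parts listing can also be kept as a plain mapping, which makes it easy to ask which plots are affected when one kind of part changes. The names `PLOT_PARTS` and `plots_using` are hypothetical; the contents mirror the listing above.

```python
PLOT_PARTS = {
    "pair": ["raw", "density", "density[contour]"],
    "parallel": ["raw"],
    "density": ["density"],
    "distcomparison": ["density"],
    "dist": ["density"],
    "energy": ["density"],
    "forest": ["density", "density[boxplot]"],
    "hdi": ["density"],
    "joint": ["raw", "density"],
    "kde": ["density"],
    "posterior": ["density"],
    "ppc": ["density"],
    "trace": ["raw", "density", "autocorr"],
    "violin": ["density"],
    "autocorr": ["autocorr"],
    "ess": ["autocorr"],
    "mcse": ["autocorr"],
    "rank": ["autocorr"],
    "compare": ["loo"],
    "bpv": ["stats"],
    "elpd": ["loo"],
    "khat": ["loo"],
    "loopit": ["loo"],
}


def plots_using(part):
    """Return the plots that use `part`; variants like density[contour] count."""
    return [
        name
        for name, parts in PLOT_PARTS.items()
        if any(p.split("[")[0] == part for p in parts)
    ]


raw_plots = plots_using("raw")
```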