Commit

Merge pull request #28 from henrypinkard/main

add gh action

henrypinkard authored Sep 9, 2024
2 parents 4200d1f + 319585c commit 9edfe8b
Showing 18 changed files with 2,137 additions and 1,080 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/publish_pypi.yaml
@@ -0,0 +1,38 @@
name: Build and Publish to PyPI using Flit

on:
  push:
    paths:
      - 'src/encoding_information/_version.py'

jobs:
  build-and-publish:
    runs-on: ubuntu-latest
    permissions:
      id-token: write # This is required for requesting the JWT
      contents: read  # This is required for actions/checkout
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.x'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install flit
      - name: Install package dependencies
        run: |
          pip install -e .
      - name: Set PYTHONPATH
        run: echo "PYTHONPATH=$PYTHONPATH:${{ github.workspace }}/src" >> $GITHUB_ENV
      - name: Debug information
        run: |
          pwd
          ls -R
          echo $PYTHONPATH
          pip list
      - name: Build package
        run: flit build
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
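Because the workflow's `on: push: paths:` filter watches only src/encoding_information/_version.py, publishing a release amounts to bumping the version tuple in that file and pushing. A sketch of such a bump (the new number is hypothetical):

# src/encoding_information/_version.py
# Pushing an edit to this file is what triggers the publish workflow above.
version_info = (0, 1, 1)  # hypothetical bump from (0, 1, 0)
__version__ = ".".join(map(str, version_info))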
21 changes: 21 additions & 0 deletions .gitignore
@@ -35,3 +35,24 @@ imager_experiments/archive_code/*
*.cached_old

*.npy


Icon
venv

build/
dist/
*.egg-info
*.egg-info test
*.pyc
__pycache__/


#jupyter
.ipynb_checkpoints
#vim
*.swp
#vs code
.vscode/
#OS files
.DS_Store
29 changes: 29 additions & 0 deletions LICENSE
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2024, Henry Pinkard
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
File renamed without changes.
20 changes: 20 additions & 0 deletions pyproject.toml
@@ -0,0 +1,20 @@
[build-system]
requires = ["flit_core >=3.2,<4"]
build-backend = "flit_core.buildapi"

[project]
name = "encoding_information"
authors = [{name = "Henry Pinkard"}]
license = {file = "LICENSE"}
classifiers = [
    "License :: OSI Approved :: BSD License"
]
dependencies = [
    "numpy"
]
dynamic = ["version"]
description = "Estimating encoded information"
readme = "README.md"

[project.urls]
Home = "https://github.com/Waller-Lab/EncodingInformation"
7 changes: 7 additions & 0 deletions src/encoding_information/__init__.py
@@ -0,0 +1,7 @@
"""
encoding_information package
Information estimators
"""
from ._version import __version__, version_info

2 changes: 2 additions & 0 deletions src/encoding_information/_version.py
@@ -0,0 +1,2 @@
version_info = (0, 1, 0)
__version__ = ".".join(map(str, version_info))
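With flit, the `dynamic = ["version"]` entry in pyproject.toml is resolved by importing the package and reading its `__version__` attribute, which `__init__.py` re-exports from this `_version.py` module. A quick sanity check, assuming the package has been installed (e.g. via `pip install -e .`):

import encoding_information

print(encoding_information.__version__)   # "0.1.0"
print(encoding_information.version_info)  # (0, 1, 0)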
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -596,7 +596,10 @@ def apply_fn(params, x):
         return val_loss_history


-    def compute_negative_log_likelihood(self, images, seed=None, verbose=True):
+    def compute_negative_log_likelihood(self, images, data_seed=None, verbose=True, seed=None):
+        if seed is not None:
+            warnings.warn('seed argument is deprecated. Use data_seed instead')
+            data_seed = seed
         eig_vals, eig_vecs, mean_vec = self._get_current_params()
         cov_mat = eig_vecs @ np.diag(eig_vals) @ eig_vecs.T

@@ -609,7 +612,7 @@ def compute_negative_log_likelihood(self, images, seed=None, verbose=True):
         eig_vals = np.where(eig_vals < floor, floor, eig_vals)
         cov_mat = eig_vecs @ np.diag(eig_vals) @ eig_vecs.T

-        images = match_to_generator_data(images, seed=seed)
+        images = match_to_generator_data(images, seed=data_seed)

         lls = _compute_stationary_log_likelihood(images, cov_mat, mean_vec, verbose=verbose)
         return -lls.mean()
@@ -695,8 +698,11 @@ def fit(self, *args, **kwargs):
         warnings.warn('Full Gaussian process does not require fitting. Skipping fit method.')


-    def compute_negative_log_likelihood(self, images, seed=True, verbose=True):
-        images = match_to_generator_data(images, seed=seed)
+    def compute_negative_log_likelihood(self, images, data_seed=None, verbose=True, seed=None):
+        if seed is not None:
+            warnings.warn('seed argument is deprecated. Use data_seed instead')
+            data_seed = seed
+        images = match_to_generator_data(images, seed=data_seed)
         # average nll per pixel
         return -gaussian_likelihood(self.cov_mat, self.mean_vec, images).mean() / np.prod(np.array(images.shape[1:]))
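Each of the seed → data_seed renames in this commit uses the same backwards-compatible shim: accept the old keyword, warn, and forward its value to the new name. A minimal standalone sketch of the idiom (the function and its body are hypothetical, not from the commit):

import warnings

def compute_nll(images, data_seed=None, seed=None):
    # Deprecation shim mirroring the diffs above: honor the old keyword,
    # warn, and forward its value to the new name.
    if seed is not None:
        warnings.warn('seed argument is deprecated. Use data_seed instead')
        data_seed = seed
    return f'would compute NLL of {len(images)} images with data_seed={data_seed}'

print(compute_nll([1, 2, 3], seed=42))  # warns, then behaves like data_seed=42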
@@ -52,7 +52,7 @@ def fit(self, train_images, learning_rate=1e-2, max_epochs=200, steps_per_epoch=
         pass

     @abstractmethod
-    def compute_negative_log_likelihood(self, images, seed=123, verbose=True):
+    def compute_negative_log_likelihood(self, images, data_seed=123, average=True, verbose=True):
         """
         Compute the NLL of the images under the model
@@ -62,8 +62,10 @@ def compute_negative_log_likelihood(self, images, seed=123, verbose=True):
             Array of images, shape (N, H, W)
         verbose : bool, optional
             Whether to print progress
-        seed : int, optional
+        data_seed : int, optional
             Random seed for shuffling images (and possibly adding noise)
+        average : bool, optional
+            Whether to average the NLL over all images

         Returns
         -------
@@ -171,32 +173,43 @@ def make_dataset_generators(images, batch_size, num_val_samples, add_uniform_noi
     return train_ds.as_numpy_iterator(), lambda : val_ds.as_numpy_iterator()


-def _evaluate_nll(data_iterator, state, eval_step=None, batch_size=16):
+def _evaluate_nll(data_iterator, state, eval_step=None, batch_size=16, return_average=True, verbose=False):
     """
     Compute negative log likelihood over many batches
     batch_size only comes into play if data_iterator is a numpy array
+    if return_average is False, it's up to the user to ensure that the batch size of the data_iterator is 1
     """

     if eval_step is None: # default eval step
         eval_step = jax.jit(lambda state, imgs: state.apply_fn(state.params, imgs))

     total_nll, count = 0, 0
+    nlls = []
     if isinstance(data_iterator, np.ndarray) or isinstance(data_iterator, onp.ndarray):
+        if not return_average:
+            batch_size = 1
+            print('return_average is False but batch_size is not 1. Setting batch_size to 1.')
         data_iterator = np.array_split(data_iterator, len(data_iterator) // batch_size + 1) # split into batches of batch_size or less
-    data_iterator = tqdm(data_iterator, desc='Computing loss')
+    if verbose:
+        data_iterator = tqdm(data_iterator, desc='Evaluating NLL')
     for batch in data_iterator:
         if isinstance(batch, tuple):
             images, condition_vector = batch
         else:
             images = batch
             condition_vector = None
         batch_nll_per_pixel = eval_step(state, images) if condition_vector is None else eval_step(state, images, condition_vector)
-        total_nll += images.shape[0] * batch_nll_per_pixel
-        count += images.shape[0]
-    # compute average nll per pixel
-    nll = (total_nll / count).item()
-    return nll
+        if return_average:
+            total_nll += images.shape[0] * batch_nll_per_pixel
+            count += images.shape[0]
+        else:
+            nlls.append(batch_nll_per_pixel)
+    if return_average:
+        return (total_nll / count).item()
+    else:
+        return np.array(nlls)


 def train_model(train_images, state, batch_size, num_val_samples,
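With return_average=False, _evaluate_nll now collects one NLL per batch instead of a running mean, and forces batch_size to 1 for raw-array inputs so each entry corresponds to a single image. A hypothetical call, assuming `state` is a trained Flax train state and `test_images` is an array of shape (N, H, W):

avg_nll = _evaluate_nll(test_images, state)                          # scalar: mean NLL per pixel
per_image = _evaluate_nll(test_images, state, return_average=False)  # array of N per-image NLLs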
@@ -259,6 +259,8 @@ def forward_pass(self, x, condition_vectors=None):
         # Apply ELU before 1x1 convolution for non-linearity on residual connection
         out = self.conv_out(nn.elu(h_stack))

+        # TODO: maybe append absolute spatial position here to allow for more accurate spatially patterned outputs?
+
         # must be positive and within data range
         mu = np.clip(self.mu_dense(out), self.train_data_min, self.train_data_max)
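The TODO added above suggests giving the network access to absolute spatial position; one common way to do that (a sketch, not part of this commit) is to concatenate normalized coordinate channels onto the feature map:

import jax.numpy as np  # matching the `np` alias used in this file

def append_coords(features):
    # Append normalized (y, x) coordinate channels so downstream layers
    # can condition on absolute position, CoordConv-style.
    n, h, w, c = features.shape
    ys = np.broadcast_to(np.linspace(0, 1, h)[None, :, None, None], (n, h, w, 1))
    xs = np.broadcast_to(np.linspace(0, 1, w)[None, None, :, None], (n, h, w, 1))
    return np.concatenate([features, ys, xs], axis=-1)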

@@ -274,7 +276,8 @@ def forward_pass(self, x, condition_vectors=None):

 class PixelCNN(ProbabilisticImageModel):
     """
-    Translation layer between the PixelCNNFlaxImpl and the probabilistic image model API
+    This class handles the training and evaluation of the PixelCNN model, which is implemented in Flax.
+    It also wraps the model in the ProbabilisticImageModel interface for easy comparison with other models.
     """

@@ -383,8 +386,11 @@ def train_step(state, imgs, condition_vecs):



-    def compute_negative_log_likelihood(self, data, conditioning_vecs=None, verbose=True, seed=None):
+    def compute_negative_log_likelihood(self, data, conditioning_vecs=None, data_seed=None, average=True, verbose=True, seed=None):
         # See superclass for docstring
+        if seed is not None:
+            warnings.warn("seed argument is deprecated. Use data_seed instead")
+            data_seed = seed

         if data.ndim == 3:
             # add a trailing channel dimension if necessary
@@ -401,15 +407,15 @@ def compute_negative_log_likelihood(self, data, conditioning_vecs=None, verbose=

         # get test data generator. Here all data is "validation", because the data passed into this should already be
         # (in the typical case) a test set
-        _, dataset_fn = make_dataset_generators(data, batch_size=32, num_val_samples=data.shape[0],
+        _, dataset_fn = make_dataset_generators(data, batch_size=32 if average else 1, num_val_samples=data.shape[0],
                                                 add_gaussian_noise=self.add_gaussian_noise, add_uniform_noise=self.add_uniform_noise,
-                                                condition_vectors=conditioning_vecs, seed=seed)
+                                                condition_vectors=conditioning_vecs, seed=data_seed)
         @jax.jit
         def conditional_eval_step(state, imgs, condition_vecs):
             return state.apply_fn(state.params, imgs, condition_vecs)

-        return _evaluate_nll(dataset_fn(), self._state,
-                             eval_step=conditional_eval_step if conditioning_vecs is not None else None)
+        return _evaluate_nll(dataset_fn(), self._state, return_average=average,
+                             eval_step=conditional_eval_step if conditioning_vecs is not None else None, verbose=verbose)

     def generate_samples(self, num_samples, conditioning_vecs=None, sample_shape=None, ensure_nonnegative=True, seed=None, verbose=True):
         if seed is None:
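Taken together, the PixelCNN changes expose the new keywords end to end. A hypothetical usage sketch, assuming `train_images` and `test_images` are arrays of shape (N, H, W):

model = PixelCNN()
model.fit(train_images)

# Mean NLL per pixel over the test set (default behavior):
mean_nll = model.compute_negative_log_likelihood(test_images, data_seed=123)

# Per-image NLLs; internally selects batch_size=1 and return_average=False:
nlls = model.compute_negative_log_likelihood(test_images, data_seed=123, average=False)

# The old keyword still works but emits a deprecation warning:
mean_nll = model.compute_negative_log_likelihood(test_images, seed=123)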
File renamed without changes.