
Commit

Merge pull request #27 from henrypinkard/main
Expose random seeds
henrypinkard authored Sep 6, 2024
2 parents 3fc6204 + c82a296 commit 4200d1f
Showing 40 changed files with 5,872 additions and 1,699 deletions.
1,183 changes: 924 additions & 259 deletions 1d_simulations/6_rayleigh_2_point_1_point.ipynb

Large diffs are not rendered by default.

240 changes: 240 additions & 0 deletions 1d_simulations/demo_of_optimizing_psf.ipynb

Large diffs are not rendered by default.

1,090 changes: 120 additions & 970 deletions 1d_simulations/mi_vs_sampling_density_objects.ipynb

Large diffs are not rendered by default.

296 changes: 206 additions & 90 deletions 1d_simulations/resolution1d_utils.py

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions 1d_simulations/signal_utils_1D.py
@@ -221,7 +221,7 @@ def make_convolutional_encoder(conv_kernel):

def run_optimzation(loss_fn, prox_fn, parameters, learning_rate=1e0, verbose=False,
tolerance=1e-6, momentum=0.9, loss_improvement_patience=800, max_epochs=100000,
learning_rate_decay=0.999, transition_begin=500,
learning_rate_decay=0.999, transition_begin=500, return_param_history=False,
key=None):
"""
Run optimization with optax, return optimized parameters
@@ -252,6 +252,7 @@ def tolerance_check(loss, loss_history):
best_params = parameters

loss_history = []
param_history = []
for i in range(max_epochs):
if key is not None:
key, subkey = jax.random.split(key)
@@ -282,11 +283,16 @@ def tolerance_check(loss, loss_history):
# Apply proximal function if provided.
parameters = prox_fn(parameters)

if return_param_history:
param_history.append(parameters)

if verbose == 'very':
print(f'{i}: {loss:.7f}')
elif verbose:
print(f'{i}: {loss:.7f}\r', end='')

if return_param_history:
return best_params, param_history
return best_params

def make_convolutional_forward_model_and_target_signal_MSE_loss_fn(object, target_integrated_signal, sampling_indices=None):
@@ -596,4 +602,5 @@ def generate_uniform_random_bandlimited_signals(num_nyquist_samples, num_signals
signals.extend(valid_signals)
pbar.update(valid_signals.shape[0])
pbar.close()
return np.array(signals)
print('concatenating...')
return np.array(signals)[:num_signals]
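
Note: the hunks above add a return_param_history flag and a key argument to run_optimzation. Below is a minimal sketch of how a caller might use them, assuming 1d_simulations/ is on the Python path and that the optimizer passes any per-step subkey on to the loss function; the toy quadratic loss and identity prox are placeholders for illustration only.

import jax
import jax.numpy as np
from signal_utils_1D import run_optimzation  # assumed import path

# Toy problem: minimize ||x - 3||^2 with an identity proximal step (placeholders).
loss_fn = lambda params, key=None: np.sum((params - 3.0) ** 2)
prox_fn = lambda params: params

best_params, param_history = run_optimzation(
    loss_fn, prox_fn, np.zeros(4),
    learning_rate=1e-1, max_epochs=2000,
    return_param_history=True,   # new flag: also return the per-iteration parameters
    key=jax.random.PRNGKey(0),   # new argument: PRNG key for stochastic loss functions
)
print(len(param_history), best_params)
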
2 changes: 1 addition & 1 deletion encoding_information/bsccm_utils.py
Original file line number Diff line number Diff line change
@@ -339,7 +339,7 @@ def add_shot_noise_to_experimenal_data(image_stack, photon_fraction, seed=None):
image_stack: stack of images to add noise to
photon_fraction: fraction of photons to keep
seed: random seed
"""
if seed is None:
seed = onp.random.randint(0, 100000)
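
With the seed now documented and drawn explicitly when omitted, the shot-noise subsampling in add_shot_noise_to_experimenal_data can be made repeatable. A short hedged sketch follows; the image stack is placeholder data, and it assumes the seed fully determines the added noise.

import numpy as onp
from encoding_information.bsccm_utils import add_shot_noise_to_experimenal_data

images = onp.random.poisson(200, size=(16, 32, 32)).astype(float)  # placeholder stack

noisy_a = add_shot_noise_to_experimenal_data(images, photon_fraction=0.1, seed=1234)
noisy_b = add_shot_noise_to_experimenal_data(images, photon_fraction=0.1, seed=1234)
# Same seed, same photon_fraction -> the two noisy stacks should match (assumed);
# with seed=None a fresh random seed is drawn on every call.
print(onp.allclose(noisy_a, noisy_b))
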
Empty file.
1 change: 1 addition & 0 deletions encoding_information/information_estimation.py
@@ -298,6 +298,7 @@ def estimate_mutual_information(noisy_images, clean_images=None, entropy_model=
return_entropy_model : bool, whether to return the noisy image entropy model
verbose : bool, whether to print out the estimated values
"""
warnings.warn("This function is deprecated. Use estimate_information() instead.")
clean_images_if_available = clean_images if clean_images is not None else noisy_images
if np.any(clean_images_if_available < 0):
warnings.warn(f"{np.sum(clean_images_if_available < 0) / clean_images_if_available.size:.2%} of pixels are negative.")
44 changes: 29 additions & 15 deletions encoding_information/models/gaussian_process.py
@@ -15,6 +15,18 @@
from flax.training.train_state import TrainState
from encoding_information.models.image_distribution_models import ProbabilisticImageModel, train_model, make_dataset_generators


def match_to_generator_data(images, seed=None):
"""
Important: during the training process, noise is added to the pixel values to account for the
fact that discrete pixel values are used with a continuous density in the model. This is handled by the
make_dataset_generators function in the image_distribution_models module. So we call this here on the images
to ensure that the same noise is added to the images here as was added during training, and then convert back
to a jax array
"""
_, dataset_fn = make_dataset_generators(images, batch_size=images.shape[0], num_val_samples=images.shape[0], seed=seed)
return next(dataset_fn())

def estimate_full_cov_mat(patches):
"""
Take an NxWxH stack of patches, and compute the covariance matrix of the vectorized patches
@@ -474,7 +486,7 @@ def compute_loss(self, mean_vec, cov_mat, images):

class StationaryGaussianProcess(ProbabilisticImageModel):

def __init__(self, images, eigenvalue_floor=1e-3, verbose=False):
def __init__(self, images, eigenvalue_floor=1e-3, seed=None, verbose=False):
"""
Create a StationaryGaussianProcess model and initialize it to the plugin estimate of the stationary covariance matrix
"""
@@ -484,8 +496,10 @@ def __init__(self, images, eigenvalue_floor=1e-3, verbose=False):
self.initial_params = self._flax_model.init(jax.random.PRNGKey(0)) # Note: this RNG doesn't actually matter because there's no random initialization

# initialize parameters
initial_cov_mat = plugin_estimate_stationary_cov_mat(images, eigenvalue_floor=eigenvalue_floor, suppress_warning=True, verbose=verbose)
mean_vec = np.ones(self.image_shape[0]**2) * np.mean(images)
self.images = images
data_generator_matched = match_to_generator_data(images, seed=seed)
initial_cov_mat = plugin_estimate_stationary_cov_mat(data_generator_matched, eigenvalue_floor=eigenvalue_floor, suppress_warning=True, verbose=verbose)
mean_vec = np.ones(self.image_shape[0]**2) * np.mean(data_generator_matched)

eig_vals, eig_vecs = np.linalg.eigh(initial_cov_mat)
self.initial_params['params']['eig_vals'] = eig_vals
@@ -496,13 +510,16 @@ def __init__(self, images, eigenvalue_floor=1e-3, verbose=False):



def fit(self, train_images, learning_rate=1e2, max_epochs=60, steps_per_epoch=1, patience=15,
def fit(self, train_images=None, data_seed=None,
learning_rate=1e2, max_epochs=60, steps_per_epoch=1, patience=15,
batch_size=12, num_val_samples=None, percent_samples_for_validation=0.1,
eigenvalue_floor=1e-3, gradient_clip=1, momentum=0.9,
precondition_gradient=False, verbose=True):

num_val_samples = int(train_images.shape[0] * percent_samples_for_validation) if num_val_samples is None else num_val_samples
if train_images is None:
train_images = self.images

num_val_samples = int(train_images.shape[0] * percent_samples_for_validation) if num_val_samples is None else num_val_samples


self._optimizer = optax.chain(
@@ -558,6 +575,7 @@ def apply_fn(params, x):

best_params, val_loss_history = train_model(train_images=train_images, state=self._state, batch_size=batch_size, num_val_samples=int(num_val_samples),
steps_per_epoch=steps_per_epoch, num_epochs=max_epochs, patience=patience, train_step=_train_step,
seed=data_seed,
verbose=verbose)
# ensure that eigenvalues are positive definite
if best_params['params']['eig_vals'].min() < 0:
@@ -578,7 +596,7 @@ def apply_fn(params, x):
return val_loss_history


def compute_negative_log_likelihood(self, images, verbose=True):
def compute_negative_log_likelihood(self, images, seed=None, verbose=True):
eig_vals, eig_vecs, mean_vec = self._get_current_params()
cov_mat = eig_vecs @ np.diag(eig_vals) @ eig_vecs.T

@@ -591,13 +609,7 @@ def compute_negative_log_likelihood(self, images, verbose=True):
eig_vals = np.where(eig_vals < floor, floor, eig_vals)
cov_mat = eig_vecs @ np.diag(eig_vals) @ eig_vecs.T

# Important: during the training process, noise is added to the pixel values to account for the
# fact that discrete pixel values are used with a continuous density in the model. This is handled by the
# make_dataset_generator function in the image_distribution_models module. So we call this here on a the images
# to ensure that the same noise is added to the images here as was added during training, and then convert back
# to a jax array
_, dataset_fn = make_dataset_generators(images, batch_size=images.shape[0], num_val_samples=images.shape[0])
images = next(dataset_fn())
images = match_to_generator_data(images, seed=seed)

lls = _compute_stationary_log_likelihood(images, cov_mat, mean_vec, verbose=verbose)
return -lls.mean()
@@ -651,12 +663,13 @@ def _get_current_params(self):

class FullGaussianProcess(ProbabilisticImageModel):

def __init__(self, images, eigenvalue_floor=1e-3, verbose=False):
def __init__(self, images, eigenvalue_floor=1e-3, seed=None, verbose=False):
"""
Estimate mean and covariance matrix of a full Gaussian process from images
"""
self.image_shape = images.shape[1:]

images = match_to_generator_data(images, seed=seed)
# initialize parameters
if verbose:
print('computing full covariance matrix')
@@ -682,7 +695,8 @@ def fit(self, *args, **kwargs):
warnings.warn('Full Gaussian process does not require fitting. Skipping fit method.')


def compute_negative_log_likelihood(self, images, verbose=True):
def compute_negative_log_likelihood(self, images, seed=None, verbose=True):
images = match_to_generator_data(images, seed=seed)
# average nll per pixel
return -gaussian_likelihood(self.cov_mat, self.mean_vec, images).mean() / np.prod(np.array(images.shape[1:]))

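
Taken together, the gaussian_process.py changes thread seeds through construction, fitting, and evaluation so that the noise added by match_to_generator_data is reproducible. A hedged sketch of the intended usage follows; the patch stack is placeholder data and the keyword values are illustrative only.

import numpy as onp
from encoding_information.models.gaussian_process import (
    StationaryGaussianProcess, FullGaussianProcess)

patches = onp.random.randint(0, 255, size=(500, 16, 16)).astype(float)  # placeholder patches

# Stationary GP: `seed` fixes the noise added at initialization via match_to_generator_data,
# and `data_seed` fixes shuffling/noise inside the training data pipeline.
sgp = StationaryGaussianProcess(patches, seed=0)
val_loss_history = sgp.fit(data_seed=0, max_epochs=10, verbose=False)  # train_images now defaults to the init images
nll_stationary = sgp.compute_negative_log_likelihood(patches, seed=0, verbose=False)

# Full GP: no fitting required; the same `seed` matches evaluation noise to construction.
fgp = FullGaussianProcess(patches, seed=0)
nll_full = fgp.compute_negative_log_likelihood(patches, seed=0, verbose=False)
print(nll_stationary, nll_full)
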
43 changes: 26 additions & 17 deletions encoding_information/models/image_distribution_models.py
@@ -15,7 +15,7 @@ class ProbabilisticImageModel(ABC):
@abstractmethod
def fit(self, train_images, learning_rate=1e-2, max_epochs=200, steps_per_epoch=100, patience=10,
batch_size=64, num_val_samples=None, percent_samples_for_validation=0.1,
seed=0, verbose=True):
data_seed=None, model_seed=None, verbose=True):
"""
Fit the model to the images
@@ -37,8 +37,10 @@ def fit(self, train_images, learning_rate=1e-2, max_epochs=200, steps_per_epoch=
Number of validation samples to use. If None, use percent_samples_for_validation
percent_samples_for_validation : float, optional
Percentage of samples to use for validation
seed : int, optional
Random seed to initialize the model
data_seed : int, optional
Random seed that controls shuffling and adding noise to data
model_seed : int, optional
Random seed that controls initialization of weights
verbose : bool, optional
Whether to print training progress
@@ -50,7 +52,7 @@ def fit(self, train_images, learning_rate=1e-2, max_epochs=200, steps_per_epoch=
pass

@abstractmethod
def compute_negative_log_likelihood(self, images, verbose=True):
def compute_negative_log_likelihood(self, images, seed=123, verbose=True):
"""
Compute the NLL of the images under the model
@@ -60,6 +62,8 @@ def compute_negative_log_likelihood(self, images, verbose=True):
Array of images, shape (N, H, W)
verbose : bool, optional
Whether to print progress
seed : int, optional
Random seed for shuffling images (and possibly adding noise)
Returns
-------
@@ -84,7 +88,7 @@ def generate_samples(self, num_samples, sample_shape=None, verbose=True):
"""
pass

def add_gaussian_noise_fn(images, condition_vectors=None):
def _add_gaussian_noise_fn(images, condition_vectors=None):
"""
Add gaussian noise to images
"""
@@ -94,7 +98,7 @@ def add_uniform_noise_fn(images, condition_vectors=None):
else:
return noisy_images

def add_uniform_noise_fn(images, condition_vectors=None):
def _add_uniform_noise_fn(images, condition_vectors=None):
"""
Add uniform noise to images
"""
@@ -104,10 +108,14 @@ def add_uniform_noise_fn(images, condition_vectors=None):
else:
return noisy_images

def make_dataset_generators(images, batch_size, num_val_samples, add_uniform_noise=True, add_gaussian_noise=False, condition_vectors=None):
def make_dataset_generators(images, batch_size, num_val_samples, add_uniform_noise=True,
add_gaussian_noise=False, condition_vectors=None, seed=None):
"""
Use tensorflow datasets to make fast data pipelines
"""
if seed is not None:
tf.random.set_seed(seed)

if num_val_samples > images.shape[0]:
raise ValueError("Number of validation samples must be less than the number of training samples")

@@ -150,20 +158,20 @@ def make_dataset_generators(images, batch_size, num_val_samples, add_uniform_noi
raise ValueError("Cannot add both gaussian and uniform noise")

if add_gaussian_noise:
train_ds = train_ds.map(add_gaussian_noise_fn)
val_ds = val_ds.map(add_gaussian_noise_fn)
train_ds = train_ds.map(_add_gaussian_noise_fn)
val_ds = val_ds.map(_add_gaussian_noise_fn)

if add_uniform_noise:
train_ds = train_ds.map(add_uniform_noise_fn)
val_ds = val_ds.map(add_uniform_noise_fn)
train_ds = train_ds.map(_add_uniform_noise_fn)
val_ds = val_ds.map(_add_uniform_noise_fn)

train_ds = train_ds.repeat().shuffle(1024).batch(batch_size, drop_remainder=False).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.shuffle(1024).batch(batch_size, drop_remainder=False).prefetch(tf.data.AUTOTUNE)

return train_ds.as_numpy_iterator(), lambda : val_ds.as_numpy_iterator()


def _evaluate_nll(data_iterator, state, eval_step=None, seed=0, batch_size=16, verbose=True):
def _evaluate_nll(data_iterator, state, eval_step=None, batch_size=16):
"""
Compute negative log likelihood over many batches
@@ -173,7 +181,6 @@ def _evaluate_nll(data_iterator, state, eval_step=None, seed=0, batch_size=16, v
if eval_step is None: # default eval step
eval_step = jax.jit(lambda state, imgs: state.apply_fn(state.params, imgs))

key = jax.random.PRNGKey(seed)
total_nll, count = 0, 0
if isinstance(data_iterator, np.ndarray) or isinstance(data_iterator, onp.ndarray):
data_iterator = np.array_split(data_iterator, len(data_iterator) // batch_size + 1) # split into batches of batch_size or less
@@ -194,7 +201,7 @@

def train_model(train_images, state, batch_size, num_val_samples,
steps_per_epoch, num_epochs, patience, train_step, condition_vectors=None,
add_gaussian_noise=False, add_uniform_noise=True,
add_gaussian_noise=False, add_uniform_noise=True, seed=None,
verbose=True):
"""
Training loop with early stopping. Returns a callable with
@@ -204,7 +211,7 @@ def train_model(train_images, state, batch_size, num_val_samples,
warnings.warn(f'Number of validation samples must be less than the number of training samples. Using {num_val_samples} validation samples instead.')
train_ds_iterator, val_loader_maker_fn = make_dataset_generators(train_images,
batch_size=batch_size, num_val_samples=num_val_samples, condition_vectors=condition_vectors,
add_gaussian_noise=add_gaussian_noise, add_uniform_noise=add_uniform_noise
add_gaussian_noise=add_gaussian_noise, add_uniform_noise=add_uniform_noise, seed=seed
)

if condition_vectors is not None:
Expand All @@ -213,7 +220,7 @@ def train_model(train_images, state, batch_size, num_val_samples,
eval_step = jax.jit(lambda state, imgs: state.apply_fn(state.params, imgs))

best_params = state.params
eval_nll = _evaluate_nll(val_loader_maker_fn(), state, eval_step=eval_step, verbose=verbose)
eval_nll = _evaluate_nll(val_loader_maker_fn(), state, eval_step=eval_step)
if verbose:
print(f'Initial validation NLL: {eval_nll:.2f}')

@@ -224,15 +231,17 @@ def train_model(train_images, state, batch_size, num_val_samples,
avg_loss = 0
iter = range(steps_per_epoch)
for _ in iter if not verbose else tqdm(iter, desc=f'Epoch {epoch_idx}'):

batch = next(train_ds_iterator)

if condition_vectors is None:
state, loss = train_step(state, batch)
else:
state, loss = train_step(state, batch[0], batch[1])

avg_loss += loss / steps_per_epoch

eval_nll = _evaluate_nll(val_loader_maker_fn(), state, eval_step=eval_step, verbose=verbose)
eval_nll = _evaluate_nll(val_loader_maker_fn(), state, eval_step=eval_step)
if np.isnan(eval_nll):
warnings.warn('NaN encountered in validation loss. Stopping early.')
break
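
In image_distribution_models.py the new seed argument is applied through tf.random.set_seed before the tf.data pipeline is built, and train_model forwards it to make_dataset_generators. A small sketch of the reproducibility this is meant to give follows; the images are placeholder data, and exact determinism is assumed to follow from the graph-level TensorFlow seed.

import numpy as onp
from encoding_information.models.image_distribution_models import make_dataset_generators

images = onp.random.randint(0, 255, size=(64, 8, 8)).astype(float)  # placeholder images

def first_train_batch(seed):
    # Build a fresh pipeline each time; uniform noise is added by default.
    train_iter, _ = make_dataset_generators(images, batch_size=16, num_val_samples=8, seed=seed)
    return next(train_iter)

# Same seed -> same shuffling and same noise draws (assumed; op-level nondeterminism
# could still differ across TensorFlow versions).
batch_a = first_train_batch(seed=42)
batch_b = first_train_batch(seed=42)
print(onp.allclose(batch_a, batch_b))
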