Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrading repo to TF2 #44

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f0e95fc
some manual attempt to upgrade to tf2 and sonnet 2.0.0
solar464 Apr 4, 2021
0c2a5e2
upgrade with auto upgrade script
solar464 Apr 4, 2021
6419772
fix submodule unit tests
solar464 Apr 25, 2021
0168414
organize tests and rewrite training loop in TF2 with metrics
solar464 May 1, 2021
b4ad854
propagate typing and implement makefile + requirements.txt
solar464 May 22, 2021
6b56cfb
remove tf1 name scopes
solar464 May 22, 2021
af57066
clean up + deterministic tests
solar464 May 23, 2021
71d83f4
update repeat copy train script
solar464 May 23, 2021
cf07bf0
WIP, can't handle complex stats structure
solar464 May 23, 2021
2059de8
migrated to keras.layers.RNN for rnn evaluation
solar464 May 30, 2021
4d99c97
migrate off nametuple rnn states and apply black formatting
solar464 Jun 18, 2021
3fff8b1
update inspection notebook
solar464 Jun 19, 2021
4e74980
pre commit checks
solar464 Jun 21, 2021
f204039
pre commit
solar464 Jun 21, 2021
e476bdb
fix flake8 style
solar464 Jun 21, 2021
68c1eb8
add flake8 configuration
solar464 Jun 21, 2021
1e8e6d4
streamline inspection notebook
solar464 Jun 22, 2021
84aec58
add comments
solar464 Jun 22, 2021
fc0948a
fix comments
solar464 Jun 23, 2021
c69f6ba
Merge branch 'attempt_keras_layer_rnn_migration' into personal_master
solar464 Jun 23, 2021
86e536e
Delete report.txt
solar464 Jun 23, 2021
6a5bc7c
remove interactive.ipynb output for smaller file size
solar464 Jun 23, 2021
a48991a
Merge branch 'master' of https://github.com/solar464/dnc into persona…
solar464 Jun 23, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[flake8]
ignore = E203, E266, E501, W503, F403, F401
max-line-length = 88
11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
repos:
- repo: https://github.com/ambv/black
rev: 21.6b0
hooks:
- id: black
language_version: python3.9
- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8

24 changes: 24 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
all: install

install: venv
: # Activate venv and install requirements
mkdir tmp
source venv/bin/activate && TMPDIR=tmp pip install -r requirements.txt
rm -r tmp/
pre-commit install

venv:
: # Create venv if it doesn't exist
: # test -d venv || virtualenv -p python3 --no-site-packages venv
test -d venv || python -m venv venv

test: venv
source venv/bin/activate && python -m pytest

clean:
rm -rf venv/
find -iname "*.pyc" -delete
rm -rf logs/
rm -rf .pytest_cache
rm -rf tmp/

38 changes: 34 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ architecture.

![DNC architecture](images/dnc_model.png)

## Installation
```shell
make install
```

The above command will create a virtual environment and install the dependencies and pre-commit hooks.

Run `source venv/bin/activate` in the root directory of this repository to activate the installed virtual env.

## Testing
```shell
make test
```

Run unit tests in `tests/` using pytest.


## Train
The `DNC` requires an installation of [TensorFlow](https://www.tensorflow.org/)
and [Sonnet](https://github.com/deepmind/sonnet). An example training script is
Expand All @@ -59,13 +76,26 @@ $ ipython train.py -- --memory_size=64 --num_bits=8 --max_length=3
Periodically saving, or 'checkpointing', the model is disabled by default. To
enable, use the `checkpoint_interval` flag. E.g. `--checkpoint_interval=10000`
will ensure a checkpoint is created every `10,000` steps. The model will be
checkpointed to `/tmp/tf/dnc/` by default. From there training can be resumed.
To specify an alternate checkpoint directory, use the `checkpoint_dir` flag.
Note: ensure that `/tmp/tf/dnc/` is deleted before training is resumed with
checkpointed to `./logs/repeat_copy/checkpoint` by default. From there training can be resumed.
To specify an alternate checkpoint directory, use the `log_dir` flag.
Note: ensure that existing checkpoints are deleted or moved before training is resumed with
different model parameters, to avoid shape inconsistency errors.

More generally, the `DNC` class found within `dnc.py` can be used as a standard
TensorFlow rnn core and unrolled with TensorFlow rnn ops, such as
`tf.nn.dynamic_rnn` on any sequential task.
`keras.layers.RNN` on any sequential task.

## Model Inspection
```shell
jupyter notebook interactive.ipynb
```

Jupyter notebook that loads a trained model from checkpoints. It provides helper functions for evaluating arbitrary input bit sequences and visualizing output and intermediate read/write states.

```shell
tensorboard --logdir logs/repeat_copy/
```

Tensorboard visualization of test/train loss and TensorFlow Graph. Test/Train loss is emitted based on `report_interval`.

Disclaimer: This is not an official Google product
Loading