
Commit

Add validation dataset (#1719)
sachinprasadhs authored Jan 10, 2024
1 parent 1d8d323 commit 5aff074
Showing 5 changed files with 16,102 additions and 635 deletions.
Binary file modified examples/vision/img/pointnet/pointnet_10_0.png
Binary file added examples/vision/img/pointnet/pointnet_28_2.png
17 changes: 11 additions & 6 deletions examples/vision/ipynb/pointnet.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [David Griffiths](https://dgriffiths3.github.io)<br>\n",
"**Date created:** 2020/05/25<br>\n",
"**Last modified:** 2020/05/26<br>\n",
"**Last modified:** 2024/01/09<br>\n",
"**Description:** Implementation of PointNet for ModelNet10 classification."
]
},
@@ -68,7 +68,7 @@
"from keras import layers\n",
"from matplotlib import pyplot as plt\n",
"\n",
"keras.random.SeedGenerator(seed=42)"
"keras.utils.set_random_seed(seed=42)"
]
},
{
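A note on the hunk above: `keras.random.SeedGenerator(seed=42)` only constructs a seed-generator object (and its return value was discarded), whereas `keras.utils.set_random_seed` seeds Python's `random` module, NumPy, and the backend framework globally, which is what makes the notebook's shuffling and initialization repeatable. A minimal sketch of the replacement call:

import keras

# One call seeds Python's `random`, NumPy, and the backend RNGs, so
# shuffling and weight initialization are reproducible across runs.
keras.utils.set_random_seed(42)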
@@ -259,11 +259,16 @@
" return points, label\n",
"\n",
"\n",
"train_dataset = tf_data.Dataset.from_tensor_slices((train_points, train_labels))\n",
"train_size = 0.8\n",
"dataset = tf_data.Dataset.from_tensor_slices((train_points, train_labels))\n",
"test_dataset = tf_data.Dataset.from_tensor_slices((test_points, test_labels))\n",
"train_dataset_size = int(len(dataset) * train_size)\n",
"\n",
"train_dataset = train_dataset.shuffle(len(train_points)).map(augment).batch(BATCH_SIZE)\n",
"test_dataset = test_dataset.shuffle(len(test_points)).batch(BATCH_SIZE)"
"dataset = dataset.shuffle(len(train_points)).map(augment)\n",
"test_dataset = test_dataset.shuffle(len(test_points)).batch(BATCH_SIZE)\n",
"\n",
"train_dataset = dataset.take(train_dataset_size).batch(BATCH_SIZE)\n",
"validation_dataset = dataset.skip(train_dataset_size).batch(BATCH_SIZE)"
]
},
{
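The hunk above carves an 80/20 train/validation split out of the training data with `take()`/`skip()`, leaving the test set untouched. Below is a self-contained sketch of that pattern; the placeholder arrays, `BATCH_SIZE`, the `tf_data` import, and the `reshuffle_each_iteration` flag are assumptions added for illustration, and the `augment` map is omitted:

import numpy as np
from tensorflow import data as tf_data  # assuming `tf_data` is `tensorflow.data`

BATCH_SIZE = 32
train_size = 0.8

# Placeholder point clouds and labels so the sketch runs on its own
# (the notebook loads ModelNet10 instead).
train_points = np.random.rand(100, 2048, 3).astype("float32")
train_labels = np.random.randint(0, 10, size=(100,))

dataset = tf_data.Dataset.from_tensor_slices((train_points, train_labels))
train_dataset_size = int(len(dataset) * train_size)  # 80% for training

# take()/skip() split the shuffled dataset into train and validation parts.
# Pinning the shuffle order (reshuffle_each_iteration=False) keeps the two
# splits disjoint across epochs; with the default (True), the order is
# redrawn on every iteration, so the splits can overlap between epochs.
dataset = dataset.shuffle(len(train_points), reshuffle_each_iteration=False)
train_dataset = dataset.take(train_dataset_size).batch(BATCH_SIZE)
validation_dataset = dataset.skip(train_dataset_size).batch(BATCH_SIZE)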
@@ -445,7 +450,7 @@
" metrics=[\"sparse_categorical_accuracy\"],\n",
")\n",
"\n",
"model.fit(train_dataset, epochs=20, validation_data=test_dataset)"
"model.fit(train_dataset, epochs=20, validation_data=validation_dataset)"
]
},
{
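With the change above, training monitors the new validation split while `test_dataset` is reserved for a final evaluation. A self-contained toy sketch of that pattern (the tiny dense model and random data are stand-ins, not the PointNet model or ModelNet10 data):

import numpy as np
import keras
from keras import layers
from tensorflow import data as tf_data

# Stand-in data and splits, purely to show the train/validation/test roles.
x = np.random.rand(80, 8).astype("float32")
y = np.random.randint(0, 10, size=(80,))
ds = tf_data.Dataset.from_tensor_slices((x, y))
train_dataset = ds.take(64).batch(16)
validation_dataset = ds.skip(64).batch(16)
test_dataset = tf_data.Dataset.from_tensor_slices((x, y)).batch(16)

model = keras.Sequential(
    [layers.Dense(32, activation="relu"), layers.Dense(10, activation="softmax")]
)
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["sparse_categorical_accuracy"],
)

# fit() reports val_loss / val_sparse_categorical_accuracy from the
# validation split each epoch; the test set is only touched afterwards.
history = model.fit(train_dataset, epochs=2, validation_data=validation_dataset)
val_acc = history.history["val_sparse_categorical_accuracy"]
test_loss, test_acc = model.evaluate(test_dataset)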

0 comments on commit 5aff074
