Skip to content

Commit

Permalink
reverting RoseTTAFold2 to default params
Browse files Browse the repository at this point in the history
  • Loading branch information
sokrypton committed Mar 20, 2024
1 parent 6d2969f commit 0614fdf
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions RoseTTAFold2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"source": [
"%%time\n",
"#@title setup **RoseTTAFold2** (~1m)\n",
"params = \"RF2_jan24\" # @param [\"RF2_apr23\",\"RF2_jan24\"]\n",
"params = \"RF2_apr23\" # @param [\"RF2_apr23\",\"RF2_jan24\"]\n",
"\n",
"import os, time, sys\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:512\"\n",
Expand Down Expand Up @@ -92,7 +92,12 @@
" while os.path.isfile(f\"{params}.tgz.aria2\"):\n",
" time.sleep(5)\n",
"\n",
"if not os.path.isfile(f\"{params}.pt\"):\n",
"if params == \"RF2_jan24\":\n",
" model_params = f\"{params}.pt\"\n",
"if params == \"RF2_apr23\":\n",
" model_params = f\"weights/{params}.pt\"\n",
"\n",
"if not os.path.isfile(model_params):\n",
" os.system(f\"tar -zxvf {params}.tgz\")\n",
"\n",
"if not \"IMPORTED\" in dir():\n",
Expand Down Expand Up @@ -120,15 +125,16 @@
"\n",
" IMPORTED = True\n",
"\n",
"if not \"pred\" in dir():\n",
"if not \"pred\" in dir() or model_params_sele != model_params:\n",
" from predict import Predictor\n",
" print(\"compile RoseTTAFold2\")\n",
" model_params = f\"{params}.pt\"\n",
"\n",
" if (torch.cuda.is_available()):\n",
" pred = Predictor(model_params, torch.device(\"cuda:0\"))\n",
" else:\n",
" print (\"WARNING: using CPU\")\n",
" pred = Predictor(model_params, torch.device(\"cpu\"))\n",
" model_params_sele = model_params\n",
"\n",
"def get_unique_sequences(seq_list):\n",
" unique_seqs = list(OrderedDict.fromkeys(seq_list))\n",
Expand Down Expand Up @@ -253,7 +259,7 @@
"use_dropout = False #@param {type:\"boolean\"}\n",
"max_msa = 256 #@param [16, 32, 64, 128, 256, 512] {type:\"raw\"}\n",
"random_seed = 0 #@param {type:\"integer\"}\n",
"num_models = 1 #@param [\"1\", \"2\", \"4\", \"8\", \"16\", \"32\"] {type:\"raw\"}\n",
"num_models = 1 #@param [\"1\", \"5\", \"10\", \"15\", \"20\", \"25\"] {type:\"raw\"}\n",
"\n",
"# process\n",
"max_extra_msa = max_msa * 8\n",
Expand Down Expand Up @@ -341,7 +347,7 @@
],
"metadata": {
"cellView": "form",
"id": "_oJTZGgdeKkO"
"id": "Eh48KV70rQ03"
},
"execution_count": null,
"outputs": []
Expand Down Expand Up @@ -391,7 +397,7 @@
],
"metadata": {
"cellView": "form",
"id": "53wdd2WX70o_"
"id": "3m0H-yCIrpc4"
},
"execution_count": null,
"outputs": []
Expand All @@ -409,6 +415,7 @@
"settings_path = f\"{jobname}/settings.txt\"\n",
"with open(settings_path, \"w\") as text_file:\n",
" text_file.write(f\"method=RoseTTAFold2\\n\")\n",
" text_file.write(f\"params={params}\\n\")\n",
" text_file.write(f\"sequence={sequence}\\n\")\n",
" text_file.write(f\"sym={sym}\\n\")\n",
" text_file.write(f\"order={order}\\n\")\n",
Expand Down

0 comments on commit 0614fdf

Please sign in to comment.