Skip template search when custom_template_path provided #470

Merged
merged 1 commit on Jul 2, 2023
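
Summary of the change: when custom_template_path is provided, get_msa_and_templates() no longer asks the MMseqs2 server for templates. It reuses caller-supplied a3m_lines when present (otherwise it runs the sequence search with use_templates=False) and points every unique query sequence at the custom template directory. A minimal sketch of the new branch, reduced from the diff below (names as in colabfold/batch.py; an illustration, not the full function):

    # Sketch of the new control flow in get_msa_and_templates()
    if custom_template_path is not None:
        # custom templates supplied: never query the template server
        if a3m_lines is None:
            # still need an MSA; use_templates=False runs the sequence search only
            a3m_lines_mmseqs2 = run_mmseqs2(query_seqs_unique, str(result_dir.joinpath(jobname)),
                                            use_env, use_templates=False, host_url=host_url)
        else:
            a3m_lines_mmseqs2 = a3m_lines  # reuse the MSA the caller passed in
        # map every unique sequence to the user-supplied template directory
        template_paths = {index: custom_template_path for index in range(len(query_seqs_unique))}
    else:
        # original behavior: one MMseqs2 call returns both the MSA and template hits
        a3m_lines_mmseqs2, template_paths = run_mmseqs2(query_seqs_unique, str(result_dir.joinpath(jobname)),
                                                        use_env, use_templates=True, host_url=host_url)

The dict comprehension above is equivalent to the explicit index loop in the diff.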
129 changes: 71 additions & 58 deletions colabfold/batch.py
@@ -299,19 +299,19 @@ def relax_me(pdb_filename=None, pdb_lines=None, pdb_obj=None, use_gpu=False):
    from alphafold.common import residue_constants
    from alphafold.relax import relax

    if pdb_obj is None:
        if pdb_lines is None:
            pdb_lines = Path(pdb_filename).read_text()
        pdb_obj = protein.from_pdb_string(pdb_lines)

    amber_relaxer = relax.AmberRelaxation(
        max_iterations=0,
        tolerance=2.39,
        stiffness=10.0,
        exclude_residues=[],
        max_outer_iterations=3,
        use_gpu=use_gpu)

    relaxed_pdb_lines, _, _ = amber_relaxer.process(prot=pdb_obj)
    return relaxed_pdb_lines

@@ -321,7 +321,7 @@ def __init__(self, prefix: str, result_dir: Path):
        self.result_dir = result_dir
        self.tag = None
        self.files = {}

    def get(self, x: str, ext:str) -> Path:
        if self.tag not in self.files:
            self.files[self.tag] = []
@@ -366,13 +366,13 @@ def predict_structure(

    # iterate through random seeds
    for seed_num, seed in enumerate(range(random_seed, random_seed+num_seeds)):

        # iterate through models
        for model_num, (model_name, model_runner, params) in enumerate(model_runner_and_params):

            # swap params to avoid recompiling
            model_runner.params = params

            #########################
            # process input features
            #########################
@@ -383,24 +383,24 @@
                    input_features["asym_id"] = input_features["asym_id"] - input_features["asym_id"][...,0]
            else:
                if model_num == 0:
                    input_features = model_runner.process_features(feature_dict, random_seed=seed)
                    r = input_features["aatype"].shape[0]
                    input_features["asym_id"] = np.tile(feature_dict["asym_id"],r).reshape(r,-1)
                    if seq_len < pad_len:
                        input_features = pad_input(input_features, model_runner,
                                                   model_name, pad_len, use_templates)
                        logger.info(f"Padding length to {pad_len}")

            tag = f"{model_type}_{model_name}_seed_{seed:03d}"
            model_names.append(tag)
            files.set_tag(tag)

            ########################
            # predict
            ########################
            start = time.time()

            # monitor intermediate results
            def callback(result, recycles):
                if recycles == 0: result.pop("tol",None)
@@ -419,12 +419,12 @@ def callback(result, recycles):
                        result=result, b_factors=b_factors,
                        remove_leading_feature_dimension=("multimer" not in model_type))
                    files.get("unrelaxed",f"r{recycles}.pdb").write_text(protein.to_pdb(unrelaxed_protein))

                    if save_all:
                        with files.get("all",f"r{recycles}.pickle").open("wb") as handle:
                            pickle.dump(result, handle)
                    del unrelaxed_protein

            return_representations = save_all or save_single_representations or save_pair_representations

            # predict
@@ -439,9 +439,9 @@ def callback(result, recycles):
            ########################
            # parse results
            ########################

            # summary metrics
            mean_scores.append(result["ranking_confidence"])
            if recycles == 0: result.pop("tol",None)
            if not is_complex: result.pop("iptm",None)
            print_line = ""
@@ -469,7 +469,7 @@ def callback(result, recycles):

            #########################
            # save results
            #########################

            # save pdb
            protein_lines = protein.to_pdb(unrelaxed_protein)
@@ -498,12 +498,12 @@ def callback(result, recycles):
                    del pae
                del plddt
                json.dump(scores, handle)

            del result, unrelaxed_protein

            # early stop criteria fulfilled
            if mean_scores[-1] > stop_at_score: break

        # early stop criteria fulfilled
        if mean_scores[-1] > stop_at_score: break

@@ -514,7 +514,7 @@ def callback(result, recycles):
    ###################################################
    # rerank models based on predicted confidence
    ###################################################

    rank, metric = [],[]
    result_files = []
    logger.info(f"reranking models by '{rank_by}' metric")
@@ -527,7 +527,7 @@ def callback(result, recycles):
        if n < num_relax:
            start = time.time()
            pdb_lines = relax_me(pdb_lines=unrelaxed_pdb_lines[key], use_gpu=use_gpu_relax)
            files.get("relaxed","pdb").write_text(pdb_lines)
            logger.info(f"Relaxation took {(time.time() - start):.1f}s")

        # rename files to include rank
@@ -538,7 +538,7 @@ def callback(result, recycles):
            new_file = result_dir.joinpath(f"{prefix}_{x}_{new_tag}.{ext}")
            file.rename(new_file)
            result_files.append(new_file)

    return {"rank":rank,
            "metric":metric,
            "result_files":result_files}
@@ -649,10 +649,10 @@ def get_queries(
    # sort by seq. len
    if sort_queries_by == "length":
        queries.sort(key=lambda t: len("".join(t[1])))

    elif sort_queries_by == "random":
        random.shuffle(queries)

    is_complex = False
    for job_number, (raw_jobname, query_sequence, a3m_lines) in enumerate(queries):
        if isinstance(query_sequence, list):
@@ -719,6 +719,7 @@ def pad_sequences(
def get_msa_and_templates(
    jobname: str,
    query_sequences: Union[str, List[str]],
+    a3m_lines: Optional[List[str]],
    result_dir: Path,
    msa_mode: str,
    use_templates: bool,
@@ -749,17 +750,29 @@
    # get template features
    template_features = []
    if use_templates:
-        a3m_lines_mmseqs2, template_paths = run_mmseqs2(
-            query_seqs_unique,
-            str(result_dir.joinpath(jobname)),
-            use_env,
-            use_templates=True,
-            host_url=host_url,
-        )
+        # Skip template search when custom_template_path is provided
+        if custom_template_path is not None:
+            if a3m_lines is None:
+                a3m_lines_mmseqs2 = run_mmseqs2(
+                    query_seqs_unique,
+                    str(result_dir.joinpath(jobname)),
+                    use_env,
+                    use_templates=False,
+                    host_url=host_url,
+                )
+            else:
+                a3m_lines_mmseqs2 = a3m_lines
+            template_paths = {}
+            for index in range(0, len(query_seqs_unique)):
+                template_paths[index] = custom_template_path
+        else:
+            a3m_lines_mmseqs2, template_paths = run_mmseqs2(
+                query_seqs_unique,
+                str(result_dir.joinpath(jobname)),
+                use_env,
+                use_templates=True,
+                host_url=host_url,
+            )
        if template_paths is None:
            logger.info("No template detected")
            for index in range(0, len(query_seqs_unique)):
@@ -966,7 +979,7 @@ def generate_input_feature(

        # bugfix
        a3m_lines = f">0\n{full_sequence}\n"
        a3m_lines += pair_msa(query_seqs_unique, query_seqs_cardinality, paired_msa, unpaired_msa)

        input_feature = build_monomer_feature(full_sequence, a3m_lines, mk_mock_template(full_sequence))
        input_feature["residue_index"] = np.concatenate([np.arange(L) for L in Ls])
@@ -987,7 +1000,7 @@ def generate_input_feature(
        chain_cnt = 0
        # for each unique sequence
        for sequence_index, sequence in enumerate(query_seqs_unique):

            # get unpaired msa
            if unpaired_msa is None:
                input_msa = f">{101 + sequence_index}\n{sequence}"
@@ -1243,7 +1256,7 @@ def run(
    if max_msa is not None:
        max_seq, max_extra_seq = [int(x) for x in max_msa.split(":")]

    if kwargs.pop("use_amber", False) and num_relax == 0:
        num_relax = num_models * num_seeds

    if len(kwargs) > 0:
@@ -1263,22 +1276,22 @@ def run(
        L = len("".join(query_sequence))
        if L > max_len: max_len = L
        if N > max_num: max_num = N

    # get max sequences
    # 512 5120 = alphafold_ptm (models 1,3,4)
    # 512 1024 = alphafold_ptm (models 2,5)
    # 508 2048 = alphafold-multimer_v3 (models 1,2,3)
    # 508 1152 = alphafold-multimer_v3 (models 4,5)
    # 252 1152 = alphafold-multimer_v[1,2]

    set_if = lambda x,y: y if x is None else x
    if model_type in ["alphafold2_multimer_v1","alphafold2_multimer_v2"]:
        (max_seq, max_extra_seq) = (set_if(max_seq,252), set_if(max_extra_seq,1152))
    elif model_type == "alphafold2_multimer_v3":
        (max_seq, max_extra_seq) = (set_if(max_seq,508), set_if(max_extra_seq,2048))
    else:
        (max_seq, max_extra_seq) = (set_if(max_seq,512), set_if(max_extra_seq,5120))

    if msa_mode == "single_sequence":
        num_seqs = 1
        if is_complex and "multimer" not in model_type: num_seqs += max_num
@@ -1337,7 +1350,7 @@ def run(
    first_job = True
    for job_number, (raw_jobname, query_sequence, a3m_lines) in enumerate(queries):
        jobname = safe_filename(raw_jobname)

        #######################################
        # check if job has already finished
        #######################################
@@ -1359,59 +1372,59 @@ def run(
        ###########################################
        # generate MSA (a3m_lines) and templates
        ###########################################
        try:
            if a3m_lines is None:
                (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features) \
-                = get_msa_and_templates(jobname, query_sequence, result_dir, msa_mode, use_templates,
+                = get_msa_and_templates(jobname, query_sequence, a3m_lines, result_dir, msa_mode, use_templates,
                                        custom_template_path, pair_mode, pairing_strategy, host_url)
            elif a3m_lines is not None:
                (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features) \
                = unserialize_msa(a3m_lines, query_sequence)
                if use_templates:
                    (_, _, _, _, template_features) \
-                    = get_msa_and_templates(jobname, query_seqs_unique, result_dir, 'single_sequence', use_templates,
+                    = get_msa_and_templates(jobname, query_seqs_unique, a3m_lines, result_dir, 'single_sequence', use_templates,
                                            custom_template_path, pair_mode, pairing_strategy, host_url)

            # save a3m
            msa = msa_to_str(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality)
            result_dir.joinpath(f"{jobname}.a3m").write_text(msa)

        except Exception as e:
            logger.exception(f"Could not get MSA/templates for {jobname}: {e}")
            continue

        #######################
        # generate features
        #######################
        try:
            (feature_dict, domain_names) \
            = generate_input_feature(query_seqs_unique, query_seqs_cardinality, unpaired_msa, paired_msa,
                                     template_features, is_complex, model_type, max_seq=max_seq)

            # to allow display of MSA info during colab/chimera run (thanks tomgoddard)
            if feature_dict_callback is not None:
                feature_dict_callback(feature_dict)

        except Exception as e:
            logger.exception(f"Could not generate input features {jobname}: {e}")
            continue

        ######################
        # predict structures
        ######################
        try:
            # get list of lengths
            query_sequence_len_array = sum([[len(x)] * y
                for x,y in zip(query_seqs_unique, query_seqs_cardinality)],[])

            # decide how much to pad (to avoid recompiling)
            if seq_len > pad_len:
                if isinstance(recompile_padding, float):
                    pad_len = math.ceil(seq_len * recompile_padding)
                else:
                    pad_len = seq_len + recompile_padding
                pad_len = min(pad_len, max_len)

            # prep model and params
            if first_job:
                # if one job input adjust max settings
@@ -1423,7 +1436,7 @@ def run(
                        num_seqs = int(len(feature_dict["msa"]))

                    if use_templates: num_seqs += 4

                    # adjust max settings
                    max_seq = min(num_seqs, max_seq)
                    max_extra_seq = max(min(num_seqs - max_seq, max_extra_seq), 1)
@@ -1498,7 +1511,7 @@ def run(
                scores_file = result_dir.joinpath(f"{jobname}_scores_{r}.json")
                with scores_file.open("r") as handle:
                    scores.append(json.load(handle))

            # write alphafold-db format (pAE)
            if "pae" in scores[0]:
                af_pae_file = result_dir.joinpath(f"{jobname}_predicted_aligned_error_v1.json")
@@ -1535,7 +1548,7 @@ def run(
            with zipfile.ZipFile(result_zip, "w") as result_zip:
                for file in result_files:
                    result_zip.write(file, arcname=file.name)

            # Delete only after the zip was successful, and also not the bibtex and config because we need those again
            for file in result_files[:-2]:
                file.unlink()
@@ -1737,7 +1750,7 @@ def main():
    )

    args = parser.parse_args()

    # disable unified memory
    if args.disable_unified_memory:
        for k in ENV.keys():
@@ -1756,7 +1769,7 @@

    queries, is_complex = get_queries(args.input, args.sort_queries_by)
    model_type = set_model_type(is_complex, args.model_type)

    download_alphafold_params(model_type, data_dir)

    if args.msa_mode != "single_sequence" and not args.templates:
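
For downstream callers, the signature change threads a precomputed MSA through. A hypothetical invocation mirroring the updated call sites (positional argument order taken from this diff; the job name, sequence, and path values are placeholders, and the mode strings are typical ColabFold settings shown only as an example):

    from pathlib import Path

    # Passing a3m_lines=None triggers a fresh MSA search; a list of previously
    # saved A3M strings can be passed instead. With custom_template_path set,
    # no MMseqs2 template search is performed either way.
    (unpaired_msa, paired_msa, query_seqs_unique,
     query_seqs_cardinality, template_features) = get_msa_and_templates(
        "example_job",                          # jobname (placeholder)
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",    # query_sequences (placeholder)
        None,                                   # a3m_lines (the new argument)
        Path("results"),                        # result_dir
        "mmseqs2_uniref_env",                   # msa_mode
        True,                                   # use_templates
        "my_templates/",                        # custom_template_path -> skips template search
        "unpaired_paired",                      # pair_mode
        "greedy",                               # pairing_strategy
        "https://api.colabfold.com",            # host_url
    )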