import copy
import pairk.backend.tools.sequence_utils as tools
import pairk.backend.tools.pairwise_tools as pairwise_tools
import pairk.backend.exceptions as _exceptions
import pairk.backend.tools.esm_tools as esm_tools
import torch
from collections import defaultdict
import torch
def get_subsequences(indices, input_string, k):
    """Return the length-`k` substring of `input_string` starting at each index.

    Parameters
    ----------
    indices : torch.Tensor
        Integer start positions. The tensor is flattened, so any shape is
        accepted and the output order follows the flattened order.
    input_string : str
        The string to slice (the ortholog sequence in this module).
    k : int
        Length of each substring.

    Returns
    -------
    list[str]
        One substring per start index. Normal Python slicing rules apply, so
        a window that runs past the end of the string is simply shorter.
    """
    # Convert to plain Python ints once, rather than slicing with one-element
    # tensors on every iteration.
    starts = indices.reshape(-1).tolist()
    return [input_string[start : start + k] for start in starts]
def _kmer_window_embeddings(embedding, k):
    """Flatten every length-`k` window of a 2-D embedding into one row.

    For an input of shape ``(L, D)`` the result has shape ``(L - k + 1, k * D)``
    and row ``i`` is the concatenation of ``embedding[i]``, ``embedding[i + 1]``,
    ..., ``embedding[i + k - 1]``. If the input has fewer than `k` rows an empty
    ``(0, k * D)`` tensor is returned.
    """
    n_rows, emb_dim = embedding.shape
    if n_rows < k:
        return embedding.new_empty((0, k * emb_dim))
    # unfold gives (L - k + 1, D, k); put the window axis before the embedding
    # axis so that flattening keeps each row's vector contiguous.
    windows = embedding.unfold(0, k, 1).permute(0, 2, 1)
    return windows.reshape(-1, k * emb_dim)


def run_pairwise_kmer_emb_aln(
    query_id: str,
    embedding_dict: dict,
    k: int,
):
    """Find, for every query k-mer, the closest k-mer in each ortholog.

    Each k-mer is represented by the concatenation of the embedding rows of
    its `k` positions. For every ortholog, the Euclidean distance between each
    query k-mer and each ortholog k-mer is computed with ``torch.cdist`` and
    the minimum over ortholog k-mers is kept (``torch.min`` returns a single
    index, so only one match is reported when there are ties).

    This function was originally too slow but was vectorized/optimized by
    Foster Birnbaum. Thanks Foster!

    Parameters
    ----------
    query_id : str
        Key of the query sequence in `embedding_dict`.
    embedding_dict : dict
        Maps sequence id to a two-element sequence ``[idr_string, idr_embedding]``
        where the embedding is a 2-D tensor (assumed to have one row per
        character of ``idr_string``), or to ``["no idr", "no idr"]`` when the
        sequence has no IDR. The dictionary is not modified.
    k : int
        k-mer length.

    Returns
    -------
    tuple
        ``(score_df, orthokmer_df, pos_df)``, each created by
        ``pairwise_tools.make_empty_kmer_ortho_df``. For each query k-mer
        position and ortholog they hold, respectively, the minimum distance,
        the best matching ortholog k-mer string, and the start position of
        that k-mer in the ortholog IDR. For orthologs with no IDR or an IDR
        shorter than `k`, ``orthokmer_df`` is filled with ``"-" * k`` and the
        other two tables are left as created.

    Raises
    ------
    KeyError
        If `query_id` is not a key of `embedding_dict`.
    """
    # Look the query up without popping it so the caller's dict is left intact.
    # The explicit membership test keeps the KeyError even for a defaultdict.
    if query_id not in embedding_dict:
        raise KeyError(query_id)
    ref_seq_str, ref_seq_embedding = embedding_dict[query_id]
    ortholog_ids = [seq_id for seq_id in embedding_dict if seq_id != query_id]

    kmers = tools.gen_kmers(ref_seq_str, k)
    positions = list(range(len(ref_seq_str) - (k - 1)))
    score_df = pairwise_tools.make_empty_kmer_ortho_df(positions, list(ortholog_ids))
    orthokmer_df = pairwise_tools.make_empty_kmer_ortho_df(
        positions, list(ortholog_ids)
    )
    pos_df = pairwise_tools.make_empty_kmer_ortho_df(positions, list(ortholog_ids))
    for result_df in (score_df, orthokmer_df, pos_df):
        result_df.loc[positions, "query_kmer"] = kmers

    # One row per query k-mer: (n_query_kmers, k * embedding_dim)
    expand_ref = _kmer_window_embeddings(ref_seq_embedding, k)

    for ortholog_id in ortholog_ids:
        entry = embedding_dict[ortholog_id]
        ortholog_seq = entry[0]
        ortholog_embedding = entry[1]
        # Only compare against the sentinel when the value really is a string;
        # comparing a tensor with a str relies on torch's fallback behaviour.
        has_no_idr = ortholog_seq == "no idr" or (
            isinstance(ortholog_embedding, str) and ortholog_embedding == "no idr"
        )
        if has_no_idr or len(ortholog_seq) < k:
            orthokmer_df.loc[positions, ortholog_id] = "-" * k
            continue

        # One row per ortholog k-mer: (n_ortholog_kmers, k * embedding_dim)
        expand_ortho = _kmer_window_embeddings(ortholog_embedding, k)

        # Calculate pairwise distances and keep the best ortholog k-mer
        pairwise_dists = torch.cdist(
            expand_ref, expand_ortho, p=2
        )  # Optional: compute_mode='donot_use_mm_for_euclid_dist'
        min_dists, min_dists_pos = torch.min(pairwise_dists, dim=-1)
        score_df.loc[positions, ortholog_id] = min_dists.cpu().numpy()
        pos_df.loc[positions, ortholog_id] = min_dists_pos.cpu().numpy()
        orthokmer_df.loc[positions, ortholog_id] = get_subsequences(
            min_dists_pos, ortholog_seq, k
        )
    return score_df, orthokmer_df, pos_df
def get_idr_embeddings(
    seq_str: str,
    idr_start: int,
    idr_end: int,
    mod: esm_tools.ESM_Model,
    device="cuda",
):
    """Encode a full-length sequence and return its IDR string and embedding.

    Parameters
    ----------
    seq_str : str
        Full-length sequence.
    idr_start : int
        Index of the first IDR character in `seq_str`.
    idr_end : int
        Index of the last IDR character in `seq_str` (inclusive).
    mod : esm_tools.ESM_Model
        Model whose ``encode(seq_str, device=device)`` output is sliced along
        its first axis.
    device : str, optional
        Passed through to ``mod.encode``, by default "cuda".

    Returns
    -------
    tuple
        ``(idr_string, idr_embedding)``, or ``("no idr", "no idr")`` when the
        requested IDR slice is empty (the model is not called in that case).
    """
    idr_seq = seq_str[idr_start : idr_end + 1]
    if not idr_seq:
        # Nothing to embed: report the empty IDR with the sentinel pair.
        return "no idr", "no idr"

    full_embedding = mod.encode(seq_str, device=device)
    # Row 0 of the encoding is the start token, so both bounds shift by one.
    first_row = idr_start + 1
    stop_row = idr_end + 2
    idr_embedding = full_embedding[first_row:stop_row, :]  # type: ignore
    return idr_seq, idr_embedding
# def get_idr_embedding_dict(
# full_length_sequence_dict: dict[str, str],
# idr_position_map: dict[str, list[int]],
# mod: esm_tools.ESM_Model,
# device="cuda",
# ):
# embedding_dict = defaultdict(list)
# # get dictionary keys and values as 2 separate lists
# # fl_seqs = [(i, seq) for i, seq in full_length_sequence_dict.items()]
# seq_ids, seqs = zip(*full_length_sequence_dict.items())
# seq_embeddings = mod.encode_multiple_seqs(seqs, device=device)
# for i, seq, seq_embedding in zip(seq_ids, seqs, seq_embeddings):
# idrst, idrend = idr_position_map[i][0], idr_position_map[i][1]
# idr_str = seq[idrst : idrend + 1]
# if len(idr_str) == 0:
# embedding_dict[i].append("no idr")
# embedding_dict[i].append("no idr")
# continue
# idr_ortho_tensor = seq_embedding[idrst + 1 : idrend + 2, :] # type: ignore # +1 to account for the start token
# embedding_dict[i].append(idr_str)
# embedding_dict[i].append(idr_ortho_tensor)
# return embedding_dict
def get_idr_embedding_dict(
    full_length_sequence_dict: dict[str, str],
    idr_position_map: dict[str, list[int]],
    mod: esm_tools.ESM_Model,
    device="cuda",
):
    """Compute the IDR string and IDR embedding for every input sequence.

    Parameters
    ----------
    full_length_sequence_dict : dict[str, str]
        Sequence id -> full-length sequence.
    idr_position_map : dict[str, list[int]]
        Sequence id -> ``[idr_start, idr_end]``; elements 0 and 1 are read.
    mod : esm_tools.ESM_Model
        Model forwarded to `get_idr_embeddings`.
    device : str, optional
        Device forwarded to `get_idr_embeddings`, by default "cuda".

    Returns
    -------
    collections.defaultdict
        Sequence id -> ``[idr_string, idr_embedding]``, i.e. the two values
        returned by `get_idr_embeddings` for that sequence.
    """
    idr_embeddings = defaultdict(list)
    for seq_id, full_seq in full_length_sequence_dict.items():
        bounds = idr_position_map[seq_id]
        idr_seq, idr_emb = get_idr_embeddings(
            full_seq, bounds[0], bounds[1], mod, device
        )
        idr_embeddings[seq_id].extend((idr_seq, idr_emb))
    return idr_embeddings
def slice_idr_embedding_dict(
    full_length_sequence_dict: dict[str, str],
    idr_position_map: dict[str, list[int]],
    precomputed_embeddings: dict[str, torch.Tensor],
):
    """Cut the IDR string and IDR embedding rows out of precomputed embeddings.

    Parameters
    ----------
    full_length_sequence_dict : dict[str, str]
        Sequence id -> full-length sequence.
    idr_position_map : dict[str, list[int]]
        Sequence id -> ``[idr_start, idr_end]`` (end inclusive).
    precomputed_embeddings : dict[str, torch.Tensor]
        Sequence id -> full-length embedding. Rows ``idr_start + 1`` up to and
        including ``idr_end + 1`` are taken; the offset of one accounts for the
        start token.

    Returns
    -------
    collections.defaultdict
        Sequence id -> ``[idr_string, idr_embedding]``, or
        ``["no idr", "no idr"]`` when the IDR slice of the sequence is empty.

    Raises
    ------
    AssertionError
        If the number of sliced embedding rows differs from the IDR length.
    """
    sliced = defaultdict(list)
    for seq_id, full_seq in full_length_sequence_dict.items():
        bounds = idr_position_map[seq_id]
        start, end = bounds[0], bounds[1]
        idr_seq = full_seq[start : end + 1]
        if not idr_seq:
            # Empty IDR: store the sentinel pair and skip the embedding lookup.
            sliced[seq_id].extend(("no idr", "no idr"))
            continue

        # +1 on both bounds to account for the start token
        idr_emb = precomputed_embeddings[seq_id][start + 1 : end + 2, :]  # type: ignore
        sliced[seq_id].extend((idr_seq, idr_emb))
        assert idr_emb.shape[0] == len(
            idr_seq
        ), f"IDR tensor and string length mismatch for {seq_id}"
    return sliced
# (stray "[docs]" link text from the rendered documentation removed here)
def pairk_alignment_embedding_distance(
    full_length_sequence_dict: dict[str, str],
    idr_position_map: dict[str, list[int]],
    query_id: str,
    k: int,
    mod: esm_tools.ESM_Model | None = None,
    device: str = "cuda",
    precomputed_embeddings: None | dict[str, torch.Tensor] = None,
):
    """Run the pairwise k-mer alignment method using residue embeddings from a
    large language model to find the best k-mer matches from each homolog.

    By default, the ESM2 protein large language model is used to generate
    residue embeddings. Other residue embeddings (e.g. from different LLMs) can
    be used by providing the embeddings directly to the
    `precomputed_embeddings` argument.

    If an ortholog IDR is shorter than the k-mer, a string of "-" characters
    (``"-" * k``) is assigned as the best matching ortholog k-mer for that
    ortholog.

    **Note**: if there are multiple top-scoring matches, only one is returned.

    If `precomputed_embeddings` is not provided, sequence embeddings are
    calculated with `mod` for each full-length sequence in the input
    dictionary. The `idr_position_map` dictionary is used to extract the IDR
    and the IDR embeddings from each sequence.

    If `precomputed_embeddings` is provided, the function uses these
    embeddings and `mod` is not needed. The provided embeddings must include
    extra positions for the start and end tokens: the start token is at index
    0 and the end token is at index -1. These are stripped out before
    calculating the pairwise distances.

    The Euclidean distance is calculated between each query k-mer embedding
    slice and each ortholog k-mer embedding slice to find the best matching
    ortholog k-mer from each ortholog.

    Parameters
    ----------
    full_length_sequence_dict : dict[str, str]
        input sequences in dictionary format with the key being the sequence
        id and the value being the sequence as a string. The dictionary is
        only read, never modified.
    idr_position_map : dict[str, list[int]]
        a dictionary where the keys are the sequence ids in
        `full_length_sequence_dict` and the values are the start and end
        positions of the IDR in the sequence (using python indexing). This is
        used to slice out the IDR embeddings/sequences from the full-length
        embeddings/sequences.
    query_id : str
        the id of the query sequence within the `full_length_sequence_dict`
        dictionary and the `idr_position_map` dictionary. The query id must be
        present in both dictionaries.
    k : int
        the length of the k-mers to use for the alignment
    mod : esm_tools.ESM_Model | None, optional
        ESM2 model used to generate the embeddings. Required when
        `precomputed_embeddings` is not provided; ignored otherwise.
    device : str, optional
        whether to use cuda or cpu for pytorch, must be either "cpu" or
        "cuda", by default "cuda". If "cuda" fails, it will default to "cpu".
        This argument is passed to the `esm_tools.ESM_Model.encode` method.
    precomputed_embeddings : None | dict[str, torch.Tensor], optional
        a dictionary where the keys are the sequence ids in
        `full_length_sequence_dict` and the values are the precomputed
        embeddings for each sequence. If this is provided, the function will
        use these embeddings instead of computing them. Allows you to pass any
        precomputed embeddings (from any LLM) for the sequences. The provided
        embeddings must include extra positions for the start and end tokens.
        The start token is at index 0 and the end token is at index -1. These
        are stripped out before calculating the pairwise distances.

    Returns
    -------
    pairwise_tools.PairkAln
        an object containing the alignment results. See the `pairk.PairkAln`
        class for more information.

    Raises
    ------
    ValueError
        If neither `mod` nor `precomputed_embeddings` is provided.
    AssertionError
        If the keys of `idr_position_map` (or of `precomputed_embeddings`,
        when given) differ from the keys of `full_length_sequence_dict`.
    """
    _exceptions.check_queryid_in_idr_dict(full_length_sequence_dict, query_id)
    _exceptions.check_queryid_in_idr_dict(idr_position_map, query_id)
    assert set(full_length_sequence_dict.keys()) == set(
        idr_position_map.keys()
    ), "Keys in full_length_dict and idr_position_map must be the same"

    # The helpers below only read the sequence dict, so no defensive copy is
    # needed.
    if precomputed_embeddings is None:
        if mod is None:
            raise ValueError(
                "either `mod` or `precomputed_embeddings` must be provided"
            )
        embedding_dict = get_idr_embedding_dict(
            full_length_sequence_dict, idr_position_map, mod, device
        )
    else:
        _exceptions.check_queryid_in_idr_dict(precomputed_embeddings, query_id)
        assert set(full_length_sequence_dict.keys()) == set(
            precomputed_embeddings.keys()
        ), "Keys in full_length_dict and precomputed_embeddings must be the same"
        embedding_dict = slice_idr_embedding_dict(
            full_length_sequence_dict=full_length_sequence_dict,
            idr_position_map=idr_position_map,
            precomputed_embeddings=precomputed_embeddings,
        )

    score_df, orthokmer_df, pos_df = run_pairwise_kmer_emb_aln(
        query_id,
        embedding_dict,
        k,
    )
    return pairwise_tools.PairkAln(
        orthokmer_df=orthokmer_df,
        pos_df=pos_df,
        score_df=score_df,
    )