This example is taken from the allele-specific binding manuscript that introduces PyProBound.
This example produces a single CTCF binding model by training on paired-end ChIP-seq data. Since the reads were 100 bp long, the count table was created by directly merging overlapping paired-end ChIP-seq reads, without any read-mapping or peak-calling steps. To improve the accuracy of the model, sequence-specific bias in fragmentation is modeled as an unobserved enrichment step, predicted by scoring a bias model at the ends of each read.
import torch
import pyprobound
import pyprobound.plotting
Matplotlib is building the font cache; this may take a moment.
Data specification
# All sequences and PSAMs in this example share the DNA alphabet.
alphabet = pyprobound.alphabets.DNA()

# ChIP-seq count table generated by directly merging the paired-end reads
# from ENCODE, accession codes ENCLB581JXH and ENCLB048DBS.
dataframe = pyprobound.get_dataframe(
    "https://github.com/BussemakerLab/AlleleSpecificBinding/raw/main/in-vivo/"
    "data/CTCF_ChIP-seq/Control-R1_CTCF-R1.1000000.tsv.gz"
)

# NOTE(review): wildcard_pad=True presumably pads variable-length merged
# reads with wildcard characters so they share one width -- confirm in the
# CountTable documentation.
count_table = pyprobound.CountTable(
    dataframe,
    alphabet,
    wildcard_pad=True,
)
Model specification
PSAMs
# Sequence-independent component shared by every round below.
nonspecific = pyprobound.layers.NonSpecific(alphabet=alphabet, name="NS")

# Five unseeded 10-wide PSAMs named Fragmentation0..Fragmentation4; they are
# used as the fragmentation-bias model scored at the read ends.
fragmentation_psams = [
    pyprobound.layers.PSAM(
        alphabet=alphabet,
        kernel_size=10,
        name=f"Fragmentation{idx}",
    )
    for idx in range(5)
]

# 18-wide CTCF PSAM seeded from an IUPAC consensus; the seed string is
# exactly kernel_size characters long, with '-' at the flanking positions.
ctcf_psam = pyprobound.layers.PSAM(
    alphabet=alphabet,
    kernel_size=18,
    seed=["---CGCCMYCTAGTGG--"],
    name="CTCF",
)

# Every PSAM in the model: fragmentation PSAMs first, then CTCF.
psams = [*fragmentation_psams, ctcf_psam]
Modes
# NS: non-specific binding mode over the count table.
mode_ns = pyprobound.Mode.from_nonspecific(nonspecific, count_table)


def _fragmentation_branch(direction, out_channel):
    """Build the Roll layer, Conv1d layers and Modes for one read end.

    Args:
        direction: Either ``"left"`` or ``"right"``; forwarded to
            ``RollSpec(direction=...)``.
        out_channel: Index forwarded as ``out_channel_indexing=[out_channel]``
            to every fragmentation Conv1d.

    Returns:
        Tuple ``(roll, conv1ds, modes)``, with one Conv1d and one Mode for each
        entry of ``fragmentation_psams``, in the same order.
    """
    # NOTE(review): max_length=10 presumably restricts scoring to the
    # 10 positions at this end of each read -- confirm in RollSpec docs.
    roll = pyprobound.layers.Roll.from_spec(
        pyprobound.layers.RollSpec(
            alphabet, direction=direction, max_length=10, include_n=True
        ),
        count_table,
    )
    conv1ds = [
        pyprobound.layers.Conv1d.from_psam(
            psam, roll, out_channel_indexing=[out_channel]
        )
        for psam in fragmentation_psams
    ]
    modes = [pyprobound.Mode([roll, conv]) for conv in conv1ds]
    return roll, conv1ds, modes


# LeftFragmentation (output channel 0), then RightFragmentation (channel 1).
roll_left, conv1ds_left, modes_left = _fragmentation_branch("left", 0)
roll_right, conv1ds_right, modes_right = _fragmentation_branch("right", 1)

# CTCF: "center" roll with no max_length (unlike the fragmentation rolls)
# and a trainable positional bias.
roll_center = pyprobound.layers.Roll.from_spec(
    pyprobound.layers.RollSpec(alphabet, direction="center", include_n=True),
    count_table,
)
conv1d = pyprobound.layers.Conv1d.from_psam(
    ctcf_psam,
    roll_center,
    train_posbias=True,
    bias_bin=5,
    length_specific_bias=False,
    bias_mode="same",
)
mode_ctcf = pyprobound.Mode([roll_center, conv1d])
Rounds
round_initial = pyprobound.rounds.InitialRound()

# Fragmentation chain: each (left, right) pair of fragmentation modes adds two
# rounds (train_depth=False) on top of the current reference round. The
# right-hand round of each pair becomes the reference for the next pair.
reference_round = round_initial
for mode_left, mode_right in zip(modes_left, modes_right):
    round_left = pyprobound.rounds.BoundUnsaturatedRound.from_binding(
        [mode_ns, mode_left],
        reference_round,
        activity_heuristic=0.1,
        train_depth=False,
    )
    round_right = reference_round = (
        pyprobound.rounds.BoundUnsaturatedRound.from_binding(
            [mode_ns, mode_right],
            round_left,
            activity_heuristic=0.1,
            train_depth=False,
        )
    )

# CTCF-bound round, built on the last round of the fragmentation chain
# (or on round_initial if that chain is empty).
round_bound = pyprobound.rounds.BoundUnsaturatedRound.from_binding(
    [mode_ns, mode_ctcf],
    reference_round,
    activity_heuristic=0.4,
)
Experiment
# Experiment containing the two rounds that have counts: round_initial and
# the CTCF-bound round. counts_per_round is taken from the count table
# (presumably total reads per round -- confirm in PyProBound docs).
experiment = pyprobound.Experiment(
    [round_initial, round_bound], counts_per_round=count_table.counts_per_round
)
Model
# Loss over the single experiment. NOTE(review): full_loss=True presumably
# selects the complete likelihood rather than a reduced form -- confirm.
model = pyprobound.MultiExperimentLoss([experiment], full_loss=True)
Fitting
# Fit on CPU. Progress is written to CTCF_ChIP-seq.txt and checkpoints to
# CTCF_ChIP-seq.pt.
optimizer = pyprobound.Optimizer(
    model,
    [count_table],
    device="cpu",
    checkpoint="CTCF_ChIP-seq.pt",
    output="CTCF_ChIP-seq.txt",
    greedy_threshold=0,
    optim_args={"max_iter": 500, "tolerance_grad": 1e-9},
)

# Sequential training schedule (presumably one group per stage -- confirm):
# first the non-specific mode, then each left/right fragmentation pair, and
# finally the CTCF mode.
order = [
    [mode_ns.key()],
    *(
        (left_mode.key(), right_mode.key())
        for left_mode, right_mode in zip(modes_left, modes_right)
    ),
    [mode_ctcf.key()],
]
optimizer.train_sequential(order=order)

# Last expression of the cell, so its return value is displayed.
optimizer.reload()
{'time': 'Wed Aug 14 16:29:00 2024',
'version': '1.4.0',
'flank_lengths': ((0, 0),)}
Loss
# Evaluate the model on the count table with autograd disabled, then print
# the loss, the second returned term (presumably the regularization penalty),
# and their sum.
with torch.inference_mode():
    loss, reg = model([count_table])
print(loss, reg, loss + reg)
tensor(1.6361) tensor(0.0005) tensor(1.6366)













