103 changes: 71 additions & 32 deletions binchicken/workflow/scripts/cluster_graph.py
@@ -151,21 +151,45 @@ def pipeline(
is_pooled = any(elusive_edges["style"] == "pool")

with pl.StringCache():
elusive_edges = (
elusive_edges
.with_columns(
pl.col("samples")
.str.split(",")
.cast(pl.List(pl.Categorical)),
pl.col("target_ids")
.str.split(",")
.cast(pl.List(pl.UInt64)),
)
.with_columns(
samples_hash = pl.col("samples").list.sort().hash(),
length = pl.col("samples").list.len()
)
)
if MAX_COASSEMBLY_SAMPLES == 1:
elusive_edges = (
elusive_edges
.with_columns(
pl.col("samples")
.str.split(",")
.cast(pl.List(pl.Categorical)),
pl.col("target_ids")
.str.split(",")
.cast(pl.List(pl.UInt64)),
)
.with_columns(
source_sample = pl.col("samples").list.first(),
dest_sample = pl.col("samples").list.get(1, null_on_oob=True),
samples_hash = pl.col("samples").list.sort().hash(),
length = pl.col("samples").list.len()
)
)

directional_rows = elusive_edges.filter(pl.col("style") == "directional").height
singleton_rows = elusive_edges.filter(pl.col("style") == "singleton").height
if directional_rows == 0 and singleton_rows == 0:
raise ValueError("Directional or singleton edges required for single-sample clusters. Please rerun with a clean output folder.")
else:
elusive_edges = (
elusive_edges
.with_columns(
pl.col("samples")
.str.split(",")
.cast(pl.List(pl.Categorical)),
pl.col("target_ids")
.str.split(",")
.cast(pl.List(pl.UInt64)),
)
.with_columns(
samples_hash = pl.col("samples").list.sort().hash(),
length = pl.col("samples").list.len()
)
)

if weightings is not None:
if weightings.height == 0:
@@ -180,7 +204,12 @@ def pipeline(
else:
weightings_dict = {}

if COASSEMBLY_SAMPLES:
if COASSEMBLY_SAMPLES and MAX_COASSEMBLY_SAMPLES == 1:
coassembly_edges = (
elusive_edges
.filter(pl.col("source_sample").is_in(COASSEMBLY_SAMPLES))
)
elif COASSEMBLY_SAMPLES:
coassembly_edges = (
elusive_edges
.with_columns(
@@ -217,14 +246,13 @@ def pipeline(
logging.info("Skipping clustering, using single-sample clusters")
clusters = [
elusive_edges
.explode("samples")
.filter((not COASSEMBLY_SAMPLES) | pl.col("samples").is_in(COASSEMBLY_SAMPLES))
.group_by("samples")
.filter((not COASSEMBLY_SAMPLES) | pl.col("source_sample").is_in(COASSEMBLY_SAMPLES))
.group_by("source_sample")
.agg(pl.col("target_ids").flatten())
.with_columns(
pl.concat_list(pl.col("samples")),
pl.col("target_ids").list.sort().list.unique(),
samples_hash = pl.concat_list(pl.col("samples")).list.sort().hash(),
samples = pl.concat_list(pl.col("source_sample")),
target_ids = pl.col("target_ids").list.sort().list.unique(),
samples_hash = pl.concat_list(pl.col("source_sample")).list.sort().hash(),
)
]
else:
@@ -264,15 +292,26 @@ def pipeline(
.select("samples", "target_ids", "samples_hash")
)

sample_targets = (
elusive_edges
.select("target_ids", recover_candidates = pl.col("samples"))
.explode("recover_candidates")
.explode("target_ids")
.unique()
.group_by("recover_candidates")
.agg("target_ids")
)
if MAX_COASSEMBLY_SAMPLES == 1:
sample_targets = (
elusive_edges
.select("target_ids", recover_candidates = pl.col("dest_sample"))
.filter(pl.col("recover_candidates").is_not_null())
.explode("target_ids")
.unique()
.group_by("recover_candidates")
.agg("target_ids")
)
else:
sample_targets = (
elusive_edges
.select("target_ids", recover_candidates = pl.col("samples"))
.explode("recover_candidates")
.explode("target_ids")
.unique()
.group_by("recover_candidates")
.agg("target_ids")
)

def filter_max_coassembly_size(df, MAX_COASSEMBLY_SIZE):
if MAX_COASSEMBLY_SIZE is None:
@@ -337,7 +376,7 @@ def filter_max_coassembly_size(df, MAX_COASSEMBLY_SIZE):
.then(pl.col("total_targets"))
.otherwise(pl.col("target_ids").list.len()),
)
.sort("total_targets", "total_size", descending=[True, False])
.sort("total_targets", "total_size", "samples", descending=[True, False, True])
.with_row_index("coassembly")
.select(
"samples", "length", "total_targets", "total_size", "recover_samples",
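Aside: a minimal, self-contained sketch of the single-sample clustering pattern used in cluster_graph.py above. The toy frame and values are illustrative assumptions, not data from this PR; the point is that list.get(1, null_on_oob=True) leaves dest_sample null for singleton edges, so per-sample clusters can be built by grouping on source_sample alone.

import polars as pl

# Toy edges (illustrative only): one directional pair and one singleton
elusive_edges = pl.DataFrame({
    "style": ["directional", "singleton"],
    "samples": ["sample_1,sample_2", "sample_3"],
    "target_ids": ["0,1", "2"],
})

with pl.StringCache():
    edges = (
        elusive_edges
        .with_columns(
            pl.col("samples").str.split(",").cast(pl.List(pl.Categorical)),
            pl.col("target_ids").str.split(",").cast(pl.List(pl.UInt64)),
        )
        .with_columns(
            source_sample = pl.col("samples").list.first(),
            # Singleton edges have no second sample, hence null_on_oob=True
            dest_sample = pl.col("samples").list.get(1, null_on_oob=True),
        )
    )
    clusters = (
        edges
        .group_by("source_sample")
        .agg(pl.col("target_ids").flatten())
        .with_columns(
            samples = pl.concat_list(pl.col("source_sample")),
            target_ids = pl.col("target_ids").list.sort().list.unique(),
        )
    )
    print(clusters)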
154 changes: 139 additions & 15 deletions binchicken/workflow/scripts/target_elusive.py
@@ -5,6 +5,7 @@
# Author: Samuel Aroney

import os
import glob
import polars as pl
import logging
import numpy as np
@@ -150,10 +151,6 @@ def streaming_pipeline(
pl.DataFrame(schema=EDGES_COLUMNS).write_csv(edges_path, separator="\t")
return

if MAX_COASSEMBLY_SAMPLES < 2:
# Set to 2 to produce paired edges
MAX_COASSEMBLY_SAMPLES = 2

# Filter TAXA_OF_INTEREST
if TAXA_OF_INTEREST:
logging.info(f"Filtering for taxa of interest: {TAXA_OF_INTEREST}")
@@ -192,6 +189,92 @@
pl.DataFrame(schema=EDGES_COLUMNS).write_csv(edges_path, separator="\t")
return

# Remove any existing chunk files so stale data is not merged into output
chunk_paths = glob.glob(edges_path + "_*")
if chunk_paths:
logging.info(f"Removing existing files: {edges_path}_*")
for chunk_path in chunk_paths:
os.remove(chunk_path)

if MAX_COASSEMBLY_SAMPLES == 1:
sample_list = sorted(list(samples))
filtered_targets = (
targets
.filter(pl.col("coverage") > MIN_COASSEMBLY_COVERAGE)
.select(sample = pl.col("sample"), target = pl.col("target"))
.unique()
)

sample_preclusters = (
sample_preclusters
.with_columns(pl.col("samples").str.split(",").cast(pl.List(pl.Categorical)))
.select(
source_sample = pl.col("samples").list.first(),
dest_sample = pl.col("samples").list.get(1, null_on_oob=True)
)
)

SINGLE_CHUNK_SIZE = CHUNK_SIZE // 5 + 1
num_chunks = (len(sample_list) + SINGLE_CHUNK_SIZE - 1) // SINGLE_CHUNK_SIZE

logging.info("Processing samples in chunks")
with pl.StringCache():
for i in range(num_chunks):
logging.info(f"Processing cluster {str(i+1)} of {str(num_chunks)}")
start_row = i * SINGLE_CHUNK_SIZE
chunk_samples = (
pl.DataFrame({"source_sample": sample_list[start_row:start_row + SINGLE_CHUNK_SIZE]})
.with_columns(pl.col("source_sample").cast(pl.Categorical))
)

chunk_targets = (
pl.concat([
sample_preclusters,
sample_preclusters.select(source_sample = "dest_sample", dest_sample = "source_sample")
])
.join(chunk_samples, on="source_sample")
.lazy()
.join(filtered_targets.select("target", source_sample=pl.col("sample").cast(pl.Categorical)), on="source_sample")
.join(targets.select("target", dest_sample=pl.col("sample").cast(pl.Categorical)), on=["dest_sample", "target"])
)

directional_edges = (
chunk_targets
.group_by("source_sample", "dest_sample")
.agg(target_ids = pl.col("target").cast(pl.Utf8).unique().sort().str.join(","))
.with_columns(
style = pl.lit("directional"),
cluster_size = pl.lit(2),
samples = pl.concat_str(["source_sample", "dest_sample"], separator=","),
)
.select("style", "cluster_size", "samples", "target_ids")
)

singleton_edges = (
filtered_targets
.select("target", source_sample=pl.col("sample").cast(pl.Categorical))
.join(chunk_samples.lazy(), on="source_sample")
.join(chunk_targets, on=["source_sample", "target"], how="anti")
.group_by("source_sample")
.agg(target_ids = pl.col("target").cast(pl.Utf8).unique().sort().str.join(","))
.with_columns(
style = pl.lit("singleton"),
cluster_size = pl.lit(1),
samples = pl.col("source_sample").cast(pl.Utf8),
)
.select("style", "cluster_size", "samples", "target_ids")
)

pl.concat([directional_edges, singleton_edges]).sink_csv(edges_path + f"_{i}", separator="\t")

(
pl.scan_csv(edges_path + "_*", separator="\t", schema_overrides=EDGES_COLUMNS)
.sink_csv(edges_path, separator="\t")
)

logging.info("Done")
return

logging.info("Using chosen clusters to find appropriate targets")
def process_chunk(df):
sparse_edges = (
@@ -214,16 +297,10 @@ def process_chunk(df):
return(sparse_edges)

num_chunks = (sample_preclusters.height + CHUNK_SIZE - 1) // CHUNK_SIZE # Ceiling division to include all rows
# Check if any edges_path_* files exist and remove them
if os.path.exists(edges_path + "_*"):
logging.info(f"Removing existing files: {edges_path}_*")
os.system(f"rm {edges_path}_*")

with pl.StringCache():
logging.info("Processing clusters in chunks")
for i in range(num_chunks):
if True: #i % 100 == 0:
logging.info(f"Processing cluster {str(i+1)} of {str(num_chunks)}")
logging.info(f"Processing cluster {str(i+1)} of {str(num_chunks)}")
start_row = i * CHUNK_SIZE
chunk = sample_preclusters.slice(start_row, CHUNK_SIZE)
processed_chunk = process_chunk(chunk)
@@ -251,10 +328,6 @@ def pipeline(
logging.warning("No unbinned sequences found")
return unbinned.rename({"found_in": "target"}), pl.DataFrame(schema=EDGES_COLUMNS)

if MAX_COASSEMBLY_SAMPLES < 2:
# Set to 2 to produce paired edges
MAX_COASSEMBLY_SAMPLES = 2

# Filter TAXA_OF_INTEREST
if TAXA_OF_INTEREST:
logging.info(f"Filtering for taxa of interest: {TAXA_OF_INTEREST}")
@@ -278,6 +351,57 @@
logging.warning("No SingleM sequences found for the given samples")
return unbinned.with_columns(pl.col("target").cast(pl.Utf8)), pl.DataFrame(schema=EDGES_COLUMNS)

if MAX_COASSEMBLY_SAMPLES == 1:
filtered_targets = (
unbinned
.filter(pl.col("coverage") > MIN_COASSEMBLY_COVERAGE)
.select(
sample = pl.col("sample"),
target = pl.col("target"),
)
.unique()
)
source = filtered_targets.select(pl.col("target"), source_sample = pl.col("sample"))
dest = unbinned.select(pl.col("target"), dest_sample = pl.col("sample")).unique()

directional_edges = (
source
.join(dest, on="target")
.filter(pl.col("source_sample") != pl.col("dest_sample"))
.group_by("source_sample", "dest_sample")
.agg(target_ids = pl.col("target").cast(pl.Utf8).unique().sort().str.join(","))
.with_columns(
style = pl.lit("directional"),
cluster_size = pl.lit(2),
samples = pl.concat_str(["source_sample", "dest_sample"], separator=","),
)
.select("style", "cluster_size", "samples", "target_ids")
)

singleton_edges = (
filtered_targets
.join(
dest
.group_by("target")
.agg(sample_count = pl.n_unique("dest_sample"))
.filter(pl.col("sample_count") == 1)
.select("target"),
on="target",
)
.group_by("sample")
.agg(target_ids = pl.col("target").cast(pl.Utf8).unique().sort().str.join(","))
.with_columns(
style = pl.lit("singleton"),
cluster_size = pl.lit(1),
samples = pl.col("sample"),
)
.select("style", "cluster_size", "samples", "target_ids")
)

sparse_edges = pl.concat([directional_edges, singleton_edges])

return unbinned.with_columns(pl.col("target").cast(pl.Utf8)), sparse_edges

def process_groups(df):
if df.height == 1:
return pl.DataFrame(schema={"style": str, "cluster_size": pl.Int64, "samples": str, "target": pl.UInt32})
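Aside: a minimal sketch of the directional/singleton edge construction used in target_elusive.py above, with made-up samples, targets, and coverage cutoff (assumptions, not values from this PR). Targets shared above the cutoff become directional source-to-dest edges; targets observed in only one sample become singleton edges.

import polars as pl

# Illustrative toy input, mirroring the relevant columns of the unbinned table
unbinned = pl.DataFrame({
    "sample": ["sample_1", "sample_2", "sample_3"],
    "target": [0, 0, 1],
    "coverage": [12.0, 12.0, 12.0],
})
MIN_COASSEMBLY_COVERAGE = 10  # assumed cutoff for this sketch

filtered_targets = (
    unbinned
    .filter(pl.col("coverage") > MIN_COASSEMBLY_COVERAGE)
    .select("sample", "target")
    .unique()
)
source = filtered_targets.select("target", source_sample = pl.col("sample"))
dest = filtered_targets.select("target", dest_sample = pl.col("sample"))

# Directional edges: every (source, dest) sample pair sharing a target
directional_edges = (
    source
    .join(dest, on="target")
    .filter(pl.col("source_sample") != pl.col("dest_sample"))
    .group_by("source_sample", "dest_sample")
    .agg(target_ids = pl.col("target").cast(pl.Utf8).unique().sort().str.join(","))
)

# Singleton edges: targets observed in exactly one sample
singleton_edges = (
    filtered_targets
    .join(
        dest.group_by("target").agg(n = pl.n_unique("dest_sample")).filter(pl.col("n") == 1).select("target"),
        on="target",
    )
    .group_by("sample")
    .agg(target_ids = pl.col("target").cast(pl.Utf8).unique().sort().str.join(","))
)

print(directional_edges, singleton_edges)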
@@ -1,3 +1,4 @@
gene sample sequence num_hits coverage taxonomy
S3.7.ribosomal_protein_S7 sample_1 ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA 5 10 Root; d__Bacteria; p__Pseudomonadota
S3.7.ribosomal_protein_S7 sample_1 TGACTAGCTGGGCTAGCTATATTCTTTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC 5 10 Root; d__Bacteria; p__Pseudomonadota
S3.7.ribosomal_protein_S7 sample_1 ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA 5 11 Root; d__Bacteria; p__Pseudomonadota
S3.7.ribosomal_protein_S7 sample_1 TGACTAGCTGGGCTAGCTATATTCTTTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC 5 11 Root; d__Bacteria; p__Pseudomonadota
S3.7.ribosomal_protein_S7 sample_1 ATGACTAGTCATAGCTAGATTTGAGGTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC 5 11 Root; d__Bacteria; p__Pseudomonadota
@@ -1,3 +1,3 @@
gene sample sequence num_hits coverage taxonomy
S3.7.ribosomal_protein_S7 sample_2 ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA 5 10 Root; d__Bacteria; p__Abyssobacteria
S3.7.ribosomal_protein_S7 sample_2 TGACTAGCTGGGCTAGCTATATTCTTTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC 5 10 Root; d__Bacteria; p__Abyssobacteria
S3.7.ribosomal_protein_S7 sample_2 ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA 5 11 Root; d__Bacteria; p__Abyssobacteria
S3.7.ribosomal_protein_S7 sample_2 TGACTAGCTGGGCTAGCTATATTCTTTTTACGAGCGCGAGGAAAGCGACAGCGGCCAGGC 5 11 Root; d__Bacteria; p__Abyssobacteria
@@ -1,4 +1,4 @@
gene sample sequence num_hits coverage taxonomy
S3.7.ribosomal_protein_S7 sample_3 ATCGACTGACTTGATCGATCTTTGACGACGAGAGAGAGAGCGACGCGCCGAGAGGTTTCA 5 10 Root; d__Bacteria; p__Pseudomonadota
S3.7.ribosomal_protein_S7 sample_3 TACGAGCGGATCGTGCACGTAGTCAGTCGTTATATATCGAAAGCTCATGCGGCCATATCG 5 10 Root; d__Bacteria; p__Pseudomonadota
S3.7.ribosomal_protein_S7 sample_3 TACGAGCGGATCG---------------GTTATATATCGAAAGCTCATGCGGCCATATCG 5 10 Root; d__Bacteria; p__Pseudomonadota
S3.7.ribosomal_protein_S7 sample_3 ATCGACTGACTTGATCGATCTTTGACGACGAGAGAGAGAGCGACGCGCCGAGAGGTTTCA 5 11 Root; d__Bacteria; p__Pseudomonadota
S3.7.ribosomal_protein_S7 sample_3 TACGAGCGGATCGTGCACGTAGTCAGTCGTTATATATCGAAAGCTCATGCGGCCATATCG 5 11 Root; d__Bacteria; p__Pseudomonadota
S3.7.ribosomal_protein_S7 sample_3 TACGAGCGGATCG---------------GTTATATATCGAAAGCTCATGCGGCCATATCG 5 11 Root; d__Bacteria; p__Pseudomonadota
@@ -1,4 +1,4 @@
gene sample sequence num_hits coverage taxonomy
S3.7.ribosomal_protein_S7 sample_5 ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA 5 10 Root; d__Bacteria; p__Abyssobacteria
S3.7.ribosomal_protein_S7 sample_5 TACGAGCGGATCGTGCACGTAGTCAGTCGTTATATATCGAAAGCTCATGCGGCCATATCG 5 10 Root; d__Bacteria; p__Abyssobacteria
S3.7.ribosomal_protein_S7 sample_5 TACGAGCGGATCG---------------GTTATATATCGAAAGCTCATGCGGCCATATCG 5 10 Root; d__Bacteria; p__Abyssobacteria
S3.7.ribosomal_protein_S7 sample_5 ATGACTAGTCATAGCTAGATTTGAGGCAGCAGGAGTTAGGAAAGCCCCCGGAGTTAGCTA 5 11 Root; d__Bacteria; p__Abyssobacteria
S3.7.ribosomal_protein_S7 sample_5 TACGAGCGGATCGTGCACGTAGTCAGTCGTTATATATCGAAAGCTCATGCGGCCATATCG 5 11 Root; d__Bacteria; p__Abyssobacteria
S3.7.ribosomal_protein_S7 sample_5 TACGAGCGGATCG---------------GTTATATATCGAAAGCTCATGCGGCCATATCG 5 11 Root; d__Bacteria; p__Abyssobacteria