Multi-condition signalling network inference#

In this notebook we show how to infer signalling networks for a multicondition setting.
Here we will use mouse RNA-seq data of a multicondition study, where we will compare normal to spontaneous leukemia and Sleeping-Beauty (SB) mediated leukemia.

In the first part, we will show how to estimate Transcription Factor activities from gene expression data, following the Decoupler tutorial for functional analysis.
Then, we will infer 2 networks for each of the 2 conditions and analyse differences.

# --- Saezlab tools ---
# https://decoupler-py.readthedocs.io/
import gzip
import os
import shutil
import tempfile
import urllib.request

import decoupler as dc
import numpy as np

# https://omnipathdb.org/
import omnipath as op

# Additional packages
import pandas as pd

# --- Additional libs ---
# Pydeseq for differential expression analysis
from pydeseq2.dds import DefaultInference, DeseqDataSet
from pydeseq2.ds import DeseqStats

# https://saezlab.github.io/
import corneto as cn

cn.info()
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 13
     10 import numpy as np
     12 # https://omnipathdb.org/
---> 13 import omnipath as op
     15 # Additional packages
     16 import pandas as pd

ModuleNotFoundError: No module named 'omnipath'
max_time = 300
seed = 0
# loading GEO GSE148679 dataset
url = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE148679&format=file&file=GSE148679%5Fcounts%5Fgsea%5Fanalysis%2Etxt%2Egz"

adata = None
with tempfile.TemporaryDirectory() as tmpdirname:
    # Path for the gzipped file in the temp folder
    gz_file_path = os.path.join(tmpdirname, "counts.txt.gz")

    # Download the file
    with urllib.request.urlopen(url) as response:
        with open(gz_file_path, "wb") as out_file:
            shutil.copyfileobj(response, out_file)

    # Decompress the file
    decompressed_file_path = gz_file_path[:-3]  # Removing '.gz' extension
    with gzip.open(gz_file_path, "rb") as f_in:
        with open(decompressed_file_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)

    adata = pd.read_csv(decompressed_file_path, index_col=0, sep="\t").T

adata.head()
ID Itm2a Sergef Fam109a Dhx9 Fam71e2 Ssu72 Olfr1018 Eif2b2 Mks1 Hebp2 ... Olfr372 Gosr1 Ctsw Ryk Rhd Pxmp4 Gm25500 4930455C13Rik Prss39 Reg4
Ebf1.1 265 332 306 27823 0 2969 0 2614 231 6 ... 0 2967 674 150 708 830 29 9 7 0
Ebf1.2 175 471 335 35978 0 3108 0 2924 192 0 ... 0 2724 155 142 869 867 5 4 16 0
Ebf1.3 266 382 325 30367 0 3207 0 2940 211 7 ... 0 2547 87 54 291 767 7 7 6 0
Ebf1.4 212 388 331 30406 0 3168 0 2871 209 8 ... 0 2568 169 115 1244 829 5 1 4 0
PE.1 321 312 290 22272 0 2631 0 2236 109 10 ... 0 2466 837 191 921 888 2 2 8 0

5 rows × 24420 columns

from anndata import AnnData

adata = AnnData(adata, dtype=np.float32)
adata.var_names_make_unique()
adata
AnnData object with n_obs × n_vars = 53 × 24420
# renaming conditions for clarity
adata.obs["condition"] = np.select(
    [
        adata.obs.index.str.contains("WT"),
        adata.obs.index.str.contains("Pax5"),
        adata.obs.index.str.contains("Ebf1"),
        adata.obs.index.str.contains("Leuk"),
        adata.obs.index.str.contains("SB"),
    ],
    ["Normal", "Normal", "Normal", "Leukemic_spontaneous", "Leukemic_SB"],
    default="Normal",
)
# Visualize metadata
adata.obs["condition"].value_counts()
condition
Leukemic_SB             31
Normal                  15
Leukemic_spontaneous     7
Name: count, dtype: int64

Splitting data for network inference and validation#

We will validate the network for the comparison Leukemic_SB vs Normal (enough samples)

adata1 = adata[adata.obs["condition"].isin(["Normal", "Leukemic_spontaneous"])].copy()
adata2 = adata[adata.obs["condition"].isin(["Normal", "Leukemic_SB"])].copy()

Differential expression analysis per condition#

Contrast 1: Leukemic_spontaneous vs Normal#

# Obtain genes that pass the thresholds
genes1 = dc.filter_by_expr(
    adata1,
    group="condition",
    min_count=10,
    min_total_count=15,
    large_n=10,
    min_prop=0.8,
)

# Filter by these genes
adata1 = adata1[:, genes1].copy()
adata1
AnnData object with n_obs × n_vars = 22 × 13585
    obs: 'condition'
# Estimation of differential expression
inference1 = DefaultInference()
dds1 = DeseqDataSet(
    adata=adata1,
    design_factors="condition",
    refit_cooks=True,
    inference=inference1,
)
dds1.deseq2()
Using None as control genes, passed at DeseqDataSet initialization
stat_res_leu1 = DeseqStats(dds1, contrast=["condition", "Leukemic_spontaneous", "Normal"], inference=inference1)

stat_res_leu1.summary()
Log2 fold change & Wald test p-value: condition Leukemic_spontaneous vs Normal
             baseMean  log2FoldChange     lfcSE       stat        pvalue  \
ID                                                                         
Itm2a      319.854773        1.040954  0.173013   6.016633  1.780827e-09   
Sergef     518.491011        1.165972  0.102389  11.387652  4.818002e-30   
Fam109a    353.074249        0.270614  0.099378   2.723069  6.467861e-03   
Dhx9     28622.656109       -0.009588  0.058553  -0.163746  8.699308e-01   
Ssu72     3007.691599       -0.047053  0.044517  -1.056965  2.905276e-01   
...               ...             ...       ...        ...           ...   
Ctsw       204.692803       -1.015983  0.513345  -1.979142  4.780005e-02   
Ryk         82.285327       -2.689443  0.325347  -8.266384  1.380837e-16   
Rhd        519.107839       -4.933983  0.449259 -10.982487  4.639733e-28   
Pxmp4      791.774976       -0.367077  0.068277  -5.376332  7.601849e-08   
Gm25500      9.166514        0.879435  0.429185   2.049083  4.045397e-02   

                 padj  
ID                     
Itm2a    5.850674e-09  
Sergef   9.569088e-29  
Fam109a  1.008099e-02  
Dhx9     8.900444e-01  
Ssu72    3.376523e-01  
...               ...  
Ctsw     6.505999e-02  
Ryk      8.512167e-16  
Rhd      7.603230e-27  
Pxmp4    2.122300e-07  
Gm25500  5.581062e-02  

[13585 rows x 6 columns]
results_df1 = stat_res_leu1.results_df
results_df1.sort_values(by="padj", ascending=True, inplace=False).head()
baseMean log2FoldChange lfcSE stat pvalue padj
ID
Egfl6 958.998149 -6.296339 0.244241 -25.779249 1.515461e-146 1.281368e-142
Polm 2883.081884 -7.255498 0.281540 -25.770766 1.886446e-146 1.281368e-142
Kdm5b 2202.514358 -5.510307 0.220789 -24.957387 1.775396e-137 8.039587e-134
Milr1 1647.365126 1.244237 0.050208 24.781432 1.421631e-135 4.828216e-132
Adgre5 10346.763067 -3.102891 0.125360 -24.751781 2.966356e-135 8.059589e-132

Contrast 2: Leukemic_SB vs Normal#

# Obtain genes that pass the thresholds
genes2 = dc.filter_by_expr(
    adata2,
    group="condition",
    min_count=10,
    min_total_count=15,
    large_n=10,
    min_prop=0.8,
)

# Filter by these genes
adata2 = adata2[:, genes2].copy()
adata2
AnnData object with n_obs × n_vars = 46 × 13404
    obs: 'condition'
# Estimation of differential expression
inference2 = DefaultInference()
dds2 = DeseqDataSet(
    adata=adata2,
    design_factors="condition",
    refit_cooks=True,
    inference=inference2,
)
dds2.deseq2()
Using None as control genes, passed at DeseqDataSet initialization
stat_res_leu2 = DeseqStats(dds2, contrast=["condition", "Leukemic_SB", "Normal"], inference=inference2)

stat_res_leu2.summary()
Log2 fold change & Wald test p-value: condition Leukemic_SB vs Normal
             baseMean  log2FoldChange     lfcSE       stat        pvalue  \
ID                                                                         
Itm2a      336.316323        0.710476  0.224228   3.168537  1.532085e-03   
Sergef     692.078769        1.218302  0.070644  17.245742  1.204559e-66   
Fam109a    382.268647        0.329745  0.097073   3.396873  6.816065e-04   
Dhx9     27794.304719       -0.031645  0.064081  -0.493828  6.214276e-01   
Ssu72     2994.457694        0.004410  0.061412   0.071816  9.427485e-01   
...               ...             ...       ...        ...           ...   
Ctsw       151.303280       -1.150780  0.374313  -3.074381  2.109397e-03   
Ryk         50.267450       -2.400835  0.309995  -7.744750  9.576982e-15   
Rhd        259.540718       -4.740190  0.424503 -11.166440  5.951875e-29   
Pxmp4      783.587183       -0.148554  0.095118  -1.561793  1.183368e-01   
Gm25500      8.618646        0.386326  0.302098   1.278810  2.009640e-01   

                 padj  
ID                     
Itm2a    2.906321e-03  
Sergef   2.409837e-64  
Fam109a  1.365862e-03  
Dhx9     6.749000e-01  
Ssu72    9.529865e-01  
...               ...  
Ctsw     3.916116e-03  
Ryk      6.703387e-14  
Rhd      1.325231e-27  
Pxmp4    1.573130e-01  
Gm25500  2.530504e-01  

[13404 rows x 6 columns]
results_df2 = stat_res_leu2.results_df
results_df2.sort_values(by="padj", ascending=True, inplace=False).head()
baseMean log2FoldChange lfcSE stat pvalue padj
ID
Mfsd2b 2252.744179 -9.412063 0.308793 -30.480194 4.769746e-204 6.393368e-200
Itga2b 8434.243135 -11.914356 0.401437 -29.679287 1.421232e-193 9.525096e-190
Adgrl4 148.349960 -9.117488 0.322229 -28.295027 3.978805e-176 1.777730e-172
Gp1ba 2396.209744 -8.226071 0.323265 -25.446857 7.648284e-143 2.562940e-139
Ppbp 17321.675345 -15.613716 0.621843 -25.108767 3.989417e-139 1.069483e-135

Prior knowledge with Decoupler and Omnipath#

# Retrieve CollecTRI gene regulatory network (through Omnipath)
collectri = dc.get_collectri(organism="mouse", split_complexes=False)
collectri.head()
source target weight PMID
0 Myc Tert 1 10022128;10491298;10606235;10637317;10723141;1...
1 Spi1 Bglap 1 10022617
2 Spi1 Bglap3 1 10022617
3 Spi1 Bglap2 1 10022617
4 Smad3 Jun 1 10022869;12374795

TF activity inference per condition#

Contrast 1: Leukemic_spontaneous vs Normal#

mat1 = results_df1[["stat"]].T.rename(index={"stat": "Leukemic_spontaneous.vs.Normal"})
mat1
ID Itm2a Sergef Fam109a Dhx9 Ssu72 Eif2b2 Mks1 Vps28 Setd6 Psma4 ... Cd37 Rag2 Itgb1bp2 Sec23ip Gosr1 Ctsw Ryk Rhd Pxmp4 Gm25500
Leukemic_spontaneous.vs.Normal 6.016633 11.387652 2.723069 -0.163746 -1.056965 7.738037 3.636643 2.671034 6.008164 3.939368 ... 3.891251 -2.875219 -0.371611 -11.321478 2.340759 -1.979142 -8.266384 -10.982487 -5.376332 2.049083

1 rows × 13585 columns

tf_acts1, tf_pvals1 = dc.run_ulm(mat=mat1, net=collectri, verbose=True)
tf_acts1
Running ulm on mat with 1 samples and 13585 targets for 582 sources.
Abl1 Ahr Aire Apex1 Ar Arid1a Arid3a Arid3b Arid4a Arid4b ... Zic1 Zic2 Zkscan7 Znf143 Znf148 Znf263 Znf354c Znf382 Znf436 Znf76
Leukemic_spontaneous.vs.Normal 1.045655 0.315712 -2.055446 0.340905 -1.094879 -0.176577 -1.167034 -1.821818 0.263631 2.369551 ... 0.530625 -0.030327 1.025575 -0.545273 -1.862967 -0.883297 0.466678 -1.014019 -1.812489 -0.305558

1 rows × 582 columns

dc.plot_barplot(
    acts=tf_acts1,
    contrast="Leukemic_spontaneous.vs.Normal",
    top=25,
    vertical=True,
    figsize=(3, 6),
)
../../../_images/68d972b0bdb7053bbd25ff9e3cee5141cb35bfe8ff6164d922569cf2835bf7e9.png

Contrast 2: Leukemic_SB vs Normal#

mat2 = results_df2[["stat"]].T.rename(index={"stat": "Leukemic_SB.vs.Normal"})
mat2
ID Itm2a Sergef Fam109a Dhx9 Ssu72 Eif2b2 Mks1 Vps28 Setd6 Psma4 ... Cd37 Rag2 Itgb1bp2 Sec23ip Gosr1 Ctsw Ryk Rhd Pxmp4 Gm25500
Leukemic_SB.vs.Normal 3.168537 17.245742 3.396873 -0.493828 0.071816 6.81284 6.314823 3.452622 5.558161 5.259825 ... 1.184029 -2.941684 -2.355502 -7.619793 2.099912 -3.074381 -7.74475 -11.16644 -1.561793 1.27881

1 rows × 13404 columns

tf_acts2, tf_pvals2 = dc.run_ulm(mat=mat2, net=collectri, verbose=True)
tf_acts2
Running ulm on mat with 1 samples and 13404 targets for 571 sources.
Abl1 Ahr Aire Apex1 Ar Arid1a Arid3a Arid3b Arid4a Arid4b ... Zic1 Zic2 Zkscan7 Znf143 Znf148 Znf263 Znf354c Znf382 Znf436 Znf76
Leukemic_SB.vs.Normal 1.053067 0.998199 -0.641643 0.976317 -0.111084 0.676713 -0.908069 -1.489229 0.929773 2.253084 ... 0.658906 -0.049635 1.316132 -0.616619 -1.524429 -0.361584 1.869067 -1.975475 -1.940144 0.405888

1 rows × 571 columns

dc.plot_barplot(
    acts=tf_acts2,
    contrast="Leukemic_SB.vs.Normal",
    top=25,
    vertical=True,
    figsize=(3, 6),
)
../../../_images/d2c825bfedd2dd1cc756f26c836b912d21b093ab11a8ade41b4df0f727a271a8.png

Retrieving potential receptors per condition#

# We obtain ligand-receptor interactions from Omnipath, and we keep only the receptors
# This is our list of a prior potential receptors from which we will infer the network
unique_receptors = set(
    op.interactions.LigRecExtra.get(organisms="mouse", genesymbols=True)["target_genesymbol"].values.tolist()
)
len(unique_receptors)
849

Contrast 1: Leukemic_spontaneous vs Normal#

df_de_receptors1 = results_df1.loc[results_df1.index.intersection(unique_receptors)]
df_de_receptors1 = df_de_receptors1.sort_values(by="stat", ascending=False)
# We will take the top 20 receptors that increased the expression after treatment
df_top_receptors1 = df_de_receptors1.head(30)
df_top_receptors1.head()
baseMean log2FoldChange lfcSE stat pvalue padj
ID
Mrc1 314.324326 3.312083 0.179665 18.434750 6.913703e-76 2.935083e-73
St14 1073.819725 1.904158 0.116095 16.401768 1.857530e-60 3.077383e-58
Cd48 5533.470898 0.664655 0.054604 12.172257 4.368342e-34 1.216064e-32
Ifngr1 6551.268902 1.418280 0.120961 11.725091 9.479917e-32 2.182791e-30
Lrrc4c 82.859784 10.019258 0.895645 11.186639 4.740624e-29 8.575416e-28

Contrast 2: Leukemic_SB vs Normal#

df_de_receptors2 = results_df2.loc[results_df2.index.intersection(unique_receptors)]
df_de_receptors2 = df_de_receptors2.sort_values(by="stat", ascending=False)
# We will take the top 20 receptors that increased the expression after treatment
df_top_receptors2 = df_de_receptors2.head(30)
df_top_receptors2.head()
baseMean log2FoldChange lfcSE stat pvalue padj
ID
Lsr 770.417182 4.360408 0.346097 12.598814 2.143428e-36 8.425371e-35
Ptprs 5164.424536 1.303309 0.109463 11.906392 1.096190e-32 3.250738e-31
Nlgn2 1309.701665 1.950302 0.165263 11.801208 3.847485e-32 1.104319e-30
Cd244 2529.473998 2.992844 0.255033 11.735143 8.418360e-32 2.341073e-30
Lrrc4c 267.125842 10.650499 0.916744 11.617744 3.348637e-31 8.783783e-30

Inferring intracellular signalling network with CORNETO#

cn.info()
Installed version:v1.0.0.dev37+db611743
Available backends:CVXPY v1.6.5
Default backend (corneto.opt):CVXPY
Installed solvers:GUROBI, HIGHS, SCIP, SCIPY
Graphviz version:v0.20.3
Installed path:/Users/pablorodriguezmier/Documents/work/repos/pablormier/corneto/corneto
Repository:https://github.com/saezlab/corneto
from corneto.methods.future import CarnivalFlow

# CarnivalFlow.show_citations()

Setting prior knowledge graph#

pkn = op.interactions.OmniPath.get(organisms="mouse", databases=["SIGNOR"], genesymbols=True)
pkn = pkn[pkn.consensus_direction == True]
pkn.head()
source target source_genesymbol target_genesymbol is_directed is_stimulation is_inhibition consensus_direction consensus_stimulation consensus_inhibition curation_effort references sources n_sources n_primary_sources n_references references_stripped
0 P0C605 Q9QZC1 Prkg1 Trpc3 True False True True False True 9 HPRD:14983059;KEA:14983059;ProtMapper:14983059... HPRD;HPRD_KEA;HPRD_MIMP;KEA;MIMP;PhosphoPoint;... 15 8 2 14983059;16331690
1 P0C605 Q9WVC5 Prkg1 Trpc7 True True False True True False 3 SIGNOR:21402151;TRIP:21402151;iPTMnet:21402151 SIGNOR;TRIP;iPTMnet 3 3 1 21402151
2 Q8K2C7 Q9EPK8 Os9 Trpv4 True True True True True True 3 HPRD:17932042;SIGNOR:17932042;TRIP:17932042 HPRD;SIGNOR;TRIP 3 3 1 17932042
3 P35821 Q91WD2 Ptpn1 Trpv6 True False True True False True 11 DEPOD:15894168;DEPOD:17197020;HPRD:15894168;In... DEPOD;HPRD;IntAct;Lit-BM-17;SIGNOR;SPIKE_LC;TRIP 7 6 2 15894168;17197020
4 P68040 Q8CIR4 Rack1 Trpm6 True False True True False True 2 SIGNOR:18258429;TRIP:18258429 SIGNOR;TRIP 2 2 1 18258429
pkn["interaction"] = pkn["is_stimulation"].astype(int) - pkn["is_inhibition"].astype(int)
sel_pkn = pkn[["source_genesymbol", "interaction", "target_genesymbol"]]
sel_pkn.head()
source_genesymbol interaction target_genesymbol
0 Prkg1 -1 Trpc3
1 Prkg1 1 Trpc7
2 Os9 0 Trpv4
3 Ptpn1 -1 Trpv6
4 Rack1 -1 Trpm6
# We create the CORNETO graph by importing the edges and interaction
G = cn.Graph.from_sif_tuples([(r[0], r[1], r[2]) for _, r in sel_pkn.iterrows() if r[1] != 0])
G.shape  # nodes, edges
(4304, 9505)

Identifying target TFs per condition#

max_pval = 0.01

Contrast 1: Leukemic_spontaneous vs Normal#

# As measurements, we take the estimated TFs, we will filter out TFs with p-val > 0.001
significant_tfs1 = (
    tf_acts1[tf_pvals1 <= max_pval].T.dropna().sort_values(by="Leukemic_spontaneous.vs.Normal", ascending=False)
)
significant_tfs1.head()
Leukemic_spontaneous.vs.Normal
Myc 5.768653
Ubtf 3.480495
Mtf2 2.688996
Runx3 -2.588453
Elf3 -2.588482
# We keep only the ones in the PKN graph
measurements1 = significant_tfs1.loc[significant_tfs1.index.intersection(G.V)].to_dict()[
    "Leukemic_spontaneous.vs.Normal"
]
measurements1
{'Myc': 5.768653392791748,
 'Ubtf': 3.4804954528808594,
 'Runx3': -2.5884528160095215,
 'Elf3': -2.58848237991333,
 'Mecom': -2.608494997024536,
 'Cebpa': -2.615208387374878,
 'Nkx3-1': -2.6257927417755127,
 'Esr2': -2.6459813117980957,
 'Pax2': -2.677741527557373,
 'Maf': -2.685983419418335,
 'Mef2d': -2.7207248210906982,
 'Phox2a': -2.782597303390503,
 'Jun': -2.8008198738098145,
 'Egr1': -2.8238978385925293,
 'Pgr': -2.8435566425323486,
 'Hmga1': -2.887751817703247,
 'Yy1': -2.9141457080841064,
 'Tcf3': -2.942206621170044,
 'Ppara': -2.94732928276062,
 'Mef2a': -2.970165491104126,
 'Smad3': -2.971665382385254,
 'Myod1': -2.983170986175537,
 'Smad2': -3.014216899871826,
 'Nfic': -3.015226125717163,
 'Plag1': -3.0289695262908936,
 'Foxo1': -3.139951229095459,
 'Cebpg': -3.163689374923706,
 'Ppard': -3.187347412109375,
 'Rara': -3.2031116485595703,
 'Fos': -3.221236228942871,
 'Esr1': -3.246117115020752,
 'Spi1': -3.3310794830322266,
 'Nr3c1': -3.350456953048706,
 'Pax6': -3.3512790203094482,
 'Nfatc1': -3.4140117168426514,
 'Vdr': -3.468764066696167,
 'Pax3': -3.476954936981201,
 'Srf': -3.7604994773864746,
 'Pknox1': -3.8543660640716553,
 'Cebpb': -4.067713737487793,
 'Ets1': -4.081283092498779,
 'Foxo3': -4.1075849533081055,
 'Fli1': -4.756601810455322,
 'Stat6': -4.842164516448975,
 'Sp3': -5.319258689880371,
 'Sp1': -5.941048622131348,
 'Runx1': -5.983950138092041,
 'Gata1': -6.134815216064453}

Contrast 2: Leukemic_SB vs Normal#

# As measurements, we take the estimated TFs, we will filter out TFs with p-val > 0.001
significant_tfs2 = tf_acts2[tf_pvals2 <= max_pval].T.dropna().sort_values(by="Leukemic_SB.vs.Normal", ascending=False)
significant_tfs2.head()
Leukemic_SB.vs.Normal
Myc 5.395146
Asxl1 3.256416
Bcor 2.956437
Srebf2 2.764174
Ubtf 2.616793
# We keep only the ones in the PKN graph
measurements2 = significant_tfs2.loc[significant_tfs2.index.intersection(G.V)].to_dict()["Leukemic_SB.vs.Normal"]
measurements2
{'Myc': 5.395146369934082,
 'Asxl1': 3.256416082382202,
 'Bcor': 2.9564368724823,
 'Srebf2': 2.764174461364746,
 'Ubtf': 2.616792917251587,
 'Hoxa9': -2.6575746536254883,
 'Pax5': -2.813767671585083,
 'Sp3': -2.854238510131836,
 'Foxo3': -2.857527256011963,
 'Mecom': -2.883071184158325,
 'Cebpb': -2.8957369327545166,
 'Ets1': -3.012101888656616,
 'Dot1l': -3.0354621410369873,
 'Irf9': -3.061843156814575,
 'Pax2': -3.125239133834839,
 'Pax6': -3.176431655883789,
 'Nkx3-1': -3.185379981994629,
 'Stat6': -3.2588818073272705,
 'Phox2a': -4.103046894073486,
 'Runx1': -5.518087387084961,
 'Pknox1': -5.823720455169678,
 'Fli1': -6.548583507537842,
 'Gata1': -8.030933380126953}
def balance_dicts_by_top_abs(dict1, dict2, discretize=False):
    # Decide which is smaller and which is larger
    if len(dict1) <= len(dict2):
        small, large = dict1, dict2
    else:
        small, large = dict2, dict1

    n = len(small)
    # Take the n keys from large with the largest abs(values)
    top_large_items = sorted(large.items(), key=lambda kv: abs(kv[1]), reverse=True)[:n]
    
    # Prepare output dicts
    small_balanced = dict(small)  # already of size n
    large_balanced = dict(top_large_items)
    
    if discretize:
        small_balanced = {k: np.sign(v) for k, v in small_balanced.items()}
        large_balanced = {k: np.sign(v) for k, v in large_balanced.items()}
    
    return small_balanced, large_balanced

# Discretize not needed, but simplifies multi-cond analysis since both conditions have the same cost
d_measurements1, d_measurements2 = balance_dicts_by_top_abs(measurements1, measurements2, discretize=True)
d_measurements1
{'Myc': 1.0,
 'Asxl1': 1.0,
 'Bcor': 1.0,
 'Srebf2': 1.0,
 'Ubtf': 1.0,
 'Hoxa9': -1.0,
 'Pax5': -1.0,
 'Sp3': -1.0,
 'Foxo3': -1.0,
 'Mecom': -1.0,
 'Cebpb': -1.0,
 'Ets1': -1.0,
 'Dot1l': -1.0,
 'Irf9': -1.0,
 'Pax2': -1.0,
 'Pax6': -1.0,
 'Nkx3-1': -1.0,
 'Stat6': -1.0,
 'Phox2a': -1.0,
 'Runx1': -1.0,
 'Pknox1': -1.0,
 'Fli1': -1.0,
 'Gata1': -1.0}
d_measurements2
{'Gata1': -1.0,
 'Runx1': -1.0,
 'Sp1': -1.0,
 'Myc': 1.0,
 'Sp3': -1.0,
 'Stat6': -1.0,
 'Fli1': -1.0,
 'Foxo3': -1.0,
 'Ets1': -1.0,
 'Cebpb': -1.0,
 'Pknox1': -1.0,
 'Srf': -1.0,
 'Ubtf': 1.0,
 'Pax3': -1.0,
 'Vdr': -1.0,
 'Nfatc1': -1.0,
 'Pax6': -1.0,
 'Nr3c1': -1.0,
 'Spi1': -1.0,
 'Esr1': -1.0,
 'Fos': -1.0,
 'Rara': -1.0,
 'Ppard': -1.0}

Creating a CARNIVAL problem per condition#

Contrast 1: Leukemic_spontaneous vs Normal#

# We will infer the direction, so for the inputs, we use a value of 0 (=unknown direction)
inputs1 = {k: 0 for k in df_top_receptors1.index.intersection(G.V).values}
inputs1
{'Ifngr1': 0,
 'Lrrc4c': 0,
 'Tnfrsf11a': 0,
 'Cdon': 0,
 'Notch1': 0,
 'Ptprs': 0,
 'Nlgn2': 0,
 'Flt3': 0,
 'Ephb4': 0,
 'Il2rg': 0,
 'Marco': 0,
 'Axl': 0,
 'Tlr4': 0,
 'Ifngr2': 0,
 'Znrf3': 0}
# Create the dataset in standard format
carnival_data1 = dict()
for inp, v in inputs1.items():
    carnival_data1[inp] = dict(value=v, role="input", mapping="vertex")
for out, v in d_measurements1.items():
    carnival_data1[out] = dict(value=v, role="output", mapping="vertex")
data1 = cn.Data.from_cdict({"sample1": carnival_data1})
data1
Data(n_samples=1, n_feats=[38])

Contrast 2: Leukemic_SB vs Normal#

# We will infer the direction, so for the inputs, we use a value of 0 (=unknown direction)
inputs2 = {k: 0 for k in df_top_receptors2.index.intersection(G.V).values}
inputs2
{'Ptprs': 0,
 'Nlgn2': 0,
 'Lrrc4c': 0,
 'Tnfrsf11a': 0,
 'Ifngr1': 0,
 'Marco': 0,
 'Ifngr2': 0,
 'Ctla4': 0,
 'Znrf3': 0,
 'Flt3': 0,
 'Ephb4': 0,
 'Il9r': 0,
 'Icam1': 0,
 'Axl': 0}
# Create the dataset in standard format
carnival_data2 = dict()
for inp, v in inputs2.items():
    carnival_data2[inp] = dict(value=v, role="input", mapping="vertex")
for out, v in d_measurements2.items():
    carnival_data2[out] = dict(value=v, role="output", mapping="vertex")
data2 = cn.Data.from_cdict({"sample2": carnival_data2})
data2
Data(n_samples=1, n_feats=[37])

Solving multi-condition CARNIVAL problem with CORNETO#

data = cn.Data.from_cdict({"sample1": carnival_data1, "sample2": carnival_data2})
data
Data(n_samples=2, n_feats=[38 37])
from corneto.utils import check_gurobi

check_gurobi()
Gurobipy successfully imported.
Gurobi environment started successfully.
Starting optimization of the test model...
Test optimization was successful.
Gurobi environment disposed.
Gurobi is correctly installed and working.
True
def subgraph(G, inp, out):
    inp_s = set(G.V).intersection(inp)
    out_s = set(G.V).intersection(out)
    Gs = G.prune(inp_s, out_s)
    tot_inp = set(Gs.V).intersection(inp)
    tot_out = set(Gs.V).intersection(out)
    print(f"Inputs: ({len(tot_inp)}/{len(inp)}), Outputs: ({len(tot_out)}/{len(out)})")
    return Gs
    

Gs1 = subgraph(G, df_top_receptors1.index.tolist(), list(d_measurements1.keys()))
Gs2 = subgraph(G, df_top_receptors2.index.tolist(), list(d_measurements2.keys()))
Inputs: (7/30), Outputs: (19/23)
Inputs: (4/30), Outputs: (22/23)
<corneto.graph._graph.Graph at 0x31cfdfd90>
c = CarnivalFlow(lambda_reg=0, indirect_rule_penalty=1)
P = c.build(G, data)
Unreachable vertices for sample: 25
Unreachable vertices for sample: 14
len(set(c.processed_graph.V).intersection(d_measurements1.keys()))
19
len(set(c.processed_graph.V).intersection(d_measurements2.keys()))
22
# How many TFs are in common between conditions?
s1 = set(c.processed_graph.V).intersection(d_measurements1.keys())
s2 = set(c.processed_graph.V).intersection(d_measurements2.keys())
for g in s1.intersection(s2):
    print(g, d_measurements1[g], d_measurements2[g])
Gata1 -1.0 -1.0
Foxo3 -1.0 -1.0
Sp3 -1.0 -1.0
Myc 1.0 1.0
Runx1 -1.0 -1.0
Fli1 -1.0 -1.0
Cebpb -1.0 -1.0
Stat6 -1.0 -1.0
Pax6 -1.0 -1.0
Ets1 -1.0 -1.0
Ubtf 1.0 1.0

Which is the best lambda to choose?#

To choose a robust value of \(\lambda\) we sample multiple solutions, as in the tutorial https://saezlab.github.io/corneto/dev/tutorials/network-sampler.html#vertex-based-perturbation.
We define 3 metrics to choose \(\lambda\):

\(sign\_agreement\_ratio = \frac{n.\:of\:nodes\:matching\:sign\:with\:stat\:value}{total\:n.\:of\:nodes}\)

\(de\_overlap\_ratio = \frac{n.\:of\:nodes\:also\:differentially\:expressed}{total\:n.\:of\:nodes}\)

\(overall\_agreement = \frac{sign\_agreement\_ratio\:+\:de\_overlap\_ratio}{2}\)

# retrieving alternative network solutions. This snippet will take a while to run
from corneto.methods.sampler import sample_alternative_solutions

lambda_val = [0, 0.01, 0.1, 0.2, 0.3, 0.5, 0.7, 0.9, 0.95]
other_optimal_v = dict()  # collecting the vertex sampled solutions for each lambda
index = dict()
for i in lambda_val:
    print("lambda: ", i)
    c = CarnivalFlow(lambda_reg=i, indirect_rule_penalty=1)
    P = c.build(G, data)
    vertex_results = sample_alternative_solutions(
        P,
        "vertex_value",
        percentage=0.03,
        scale=0.03,
        rel_opt_tol=0.05,
        max_samples=30,  # number of alternative solutions to sample
        solver_kwargs=dict(solver="gurobi", max_seconds=max_time, mip_gap=0.01, seed=seed),
    )
    other_optimal_v["lambda" + str(i)] = vertex_results["vertex_value"]
    index["lambda" + str(i)] = c.processed_graph.V
lambda:  0
Unreachable vertices for sample: 25
Unreachable vertices for sample: 14
Set parameter Username
Set parameter LicenseID to value 2593994
Academic license - for non-commercial use only - expires 2025-12-02
lambda:  0.01
Unreachable vertices for sample: 25
Unreachable vertices for sample: 14
lambda:  0.1
Unreachable vertices for sample: 25
Unreachable vertices for sample: 14
lambda:  0.2
Unreachable vertices for sample: 25
Unreachable vertices for sample: 14
lambda:  0.3
Unreachable vertices for sample: 25
Unreachable vertices for sample: 14
lambda:  0.5
Unreachable vertices for sample: 25
Unreachable vertices for sample: 14
lambda:  0.7
Unreachable vertices for sample: 25
Unreachable vertices for sample: 14
lambda:  0.9
Unreachable vertices for sample: 25
Unreachable vertices for sample: 14
lambda:  0.95
Unreachable vertices for sample: 25
Unreachable vertices for sample: 14
def compute_nodes_metrics(nodes, inputs, measurements, results_df, column):
    """Computes the metrics for the nodes in the network based on the provided inputs, measurements, and results dataframe.

    Args:
        nodes (pd.DataFrame): data frame containing the node scores per condition.
        inputs (dict): Dictionary of input nodes with their values.
        measurements (dict): Dictionary of output nodes with their values.
        results_df (pd.DataFrame): DataFrame containing the results of differential expression analysis.
        column (str): Column name to be used for filtering conditions.

    Returns:
        sign_agreement_ratio (float): Ratio of nodes with the same sign as the DE analysis.
        de_overlap_ratio (float): Ratio of DE genes that overlap with the nodes.
        overall_agreement (float): Overall agreement between DE genes and nodes.
    """
    nodes = nodes.loc[(nodes[column] != 0.0)][column]  # nodes with a score different from 0
    de = results_df.loc[results_df["padj"] < 0.05].index.tolist()  # de genes
    in_out_nodes = set(measurements.keys()).union(
        set(nodes.index[nodes.index.isin(inputs.keys())])
    )  # collecting input and output nodes which we don´t want to evaluate
    middle_nodes = set(nodes.index).difference(in_out_nodes)
    commmon_genes = list(
        set(results_df.index).intersection(middle_nodes)
    )  # common genes between network and DE analysis genes
    stat = results_df.loc[commmon_genes, "stat"]
    stat_sign = np.sign(stat)
    middle_nodes_sign = np.sign(nodes.loc[commmon_genes])
    sign_agreement_ratio = 0
    if len(middle_nodes_sign) > 0:
        sign_agreement_ratio = (middle_nodes_sign == stat_sign).sum() / len(middle_nodes_sign)
    de_overlap_ratio = 0
    if len(middle_nodes) > 0:
        de_overlap_ratio = len(set(de).intersection(middle_nodes)) / len(middle_nodes)
    overall_agreement = (np.array(de_overlap_ratio) + np.array(sign_agreement_ratio)) / 2

    return sign_agreement_ratio, de_overlap_ratio, overall_agreement
# Initialize metric containers
metrics_avg_1_per_lambda = []
metrics_avg_2_per_lambda = []
metrics_sd_1_per_lambda = []
metrics_sd_2_per_lambda = []

for k in other_optimal_v.keys():
    metrics1_per_sample = []
    metrics2_per_sample = []

    for i in range(len(other_optimal_v[k])):
        n = pd.DataFrame(
            other_optimal_v[k][i],
            index=index[k],
            columns=["vertex_activity1", "vertex_activity2"],
        )  # retrieving node scores

        sign_agreement1, de_overlap1, overall_agreement1 = compute_nodes_metrics(
            n, inputs1, d_measurements1, results_df1, "vertex_activity1"
        )
        sign_agreement2, de_overlap2, overall_agreement2 = compute_nodes_metrics(
            n, inputs2, d_measurements2, results_df2, "vertex_activity2"
        )

        metrics1_per_sample.append(
            {
                "sign_agreement": sign_agreement1,
                "de_overlap": de_overlap1,
                "overall": overall_agreement1,
            }
        )
        metrics2_per_sample.append(
            {
                "sign_agreement": sign_agreement2,
                "de_overlap": de_overlap2,
                "overall": overall_agreement2,
            }
        )

    metrics1_per_sample_df = pd.DataFrame(metrics1_per_sample)
    metrics2_per_sample_df = pd.DataFrame(metrics2_per_sample)
    metrics_avg_1_per_lambda.append(metrics1_per_sample_df.mean().to_dict())
    metrics_avg_2_per_lambda.append(metrics2_per_sample_df.mean().to_dict())
    metrics_sd_1_per_lambda.append(metrics1_per_sample_df.std().to_dict())
    metrics_sd_2_per_lambda.append(metrics2_per_sample_df.std().to_dict())

# Convert lists of dictionaries to DataFrames
m_avg_df1 = pd.DataFrame(metrics_avg_1_per_lambda)
m_avg_df2 = pd.DataFrame(metrics_avg_2_per_lambda)
m_sd_df1 = pd.DataFrame(metrics_sd_1_per_lambda)
m_sd_df2 = pd.DataFrame(metrics_sd_2_per_lambda)

##plotting metrics
import matplotlib.pyplot as plt

metric_name = ["sign_agreement", "de_overlap", "overall"]
metric_label = ["Sign agreement ratio", "DE overlap ratio", "Overall agreement"]

fig, axes = plt.subplots(1, 3, figsize=(25, 5), sharex=True)
for i in range(len(metric_name)):
    axes[i].errorbar(
        lambda_val,
        m_avg_df1[metric_name[i]],
        yerr=m_sd_df1[metric_name[i]],
        label="Leukemic_spontaneous",
        marker="o",
    )
    axes[i].errorbar(
        lambda_val,
        m_avg_df2[metric_name[i]],
        yerr=m_sd_df2[metric_name[i]],
        label="Leukemic_SB",
        marker="o",
    )
    axes[i].grid()
    axes[i].set_xscale("linear")  # adjust scale according to your results
    axes[i].set_xlabel("Lambda")
    axes[i].set_ylabel(metric_label[i])
    axes[i].set_title(metric_label[i] + " vs Lambda")
    if i == 0:
        axes[i].legend(loc="upper left")
../../../_images/3072a818f9fa19a67a381114749a071902af04700a0b8bbb3570b53c55d6ec87.png

We choose the lambda value basing on the Overall agreement metric. In this case a reasonable value of \(\lambda\) is 2

c = CarnivalFlow(lambda_reg=0.5,indirect_rule_penalty=1)
P = c.build(G, data)

# Optimal solutions are found when Gap is close to 0
P.solve(solver="GUROBI", verbosity=1)
Unreachable vertices for sample: 25
Unreachable vertices for sample: 14
===============================================================================
                                     CVXPY                                     
                                     v1.6.5                                    
===============================================================================
(CVXPY) Jun 06 04:25:39 PM: Your problem has 16902 variables, 47640 constraints, and 1 parameters.
(CVXPY) Jun 06 04:25:39 PM: It is compliant with the following grammars: DCP, DQCP
(CVXPY) Jun 06 04:25:39 PM: CVXPY will first compile your problem; then, it will invoke a numerical solver to obtain a solution.
(CVXPY) Jun 06 04:25:39 PM: Your problem is compiled with the CPP canonicalization backend.
-------------------------------------------------------------------------------
                                  Compilation                                  
-------------------------------------------------------------------------------
(CVXPY) Jun 06 04:25:39 PM: Compiling problem (target solver=GUROBI).
(CVXPY) Jun 06 04:25:39 PM: Reduction chain: CvxAttr2Constr -> Qp2SymbolicQp -> QpMatrixStuffing -> GUROBI
(CVXPY) Jun 06 04:25:39 PM: Applying reduction CvxAttr2Constr
(CVXPY) Jun 06 04:25:39 PM: Applying reduction Qp2SymbolicQp
(CVXPY) Jun 06 04:25:39 PM: Applying reduction QpMatrixStuffing
(CVXPY) Jun 06 04:25:39 PM: Applying reduction GUROBI
(CVXPY) Jun 06 04:25:39 PM: Finished problem compilation (took 7.523e-02 seconds).
(CVXPY) Jun 06 04:25:39 PM: (Subsequent compilations of this problem, using the same arguments, should take less time.)
-------------------------------------------------------------------------------
                                Numerical solver                               
-------------------------------------------------------------------------------
(CVXPY) Jun 06 04:25:39 PM: Invoking solver GUROBI  to obtain a solution.
Set parameter OutputFlag to value 1
Set parameter QCPDual to value 1
Gurobi Optimizer version 12.0.2 build v12.0.2rc0 (mac64[arm] - Darwin 24.5.0 24F74)

CPU model: Apple M4 Pro
Thread count: 12 physical cores, 12 logical processors, using up to 12 threads

Non-default parameters:
QCPDual  1

Optimize a model with 47640 rows, 16902 columns and 169821 nonzeros
Model fingerprint: 0x5a0a2a3f
Variable types: 4062 continuous, 12840 integer (12840 binary)
Coefficient statistics:
  Matrix range     [5e-01, 7e+02]
  Objective range  [5e-01, 2e+00]
  Bounds range     [1e+00, 1e+00]
  RHS range        [1e+00, 1e+03]
Found heuristic solution: objective 0.0000000
Presolve removed 23196 rows and 1817 columns
Presolve time: 0.36s
Presolved: 24444 rows, 15085 columns, 140655 nonzeros
Variable types: 3564 continuous, 11521 integer (11505 binary)
Deterministic concurrent LP optimizer: primal and dual simplex
Showing primal log only...

Concurrent spin time: 0.00s

Solved with dual simplex

Root relaxation: objective -1.056250e+01, 2374 iterations, 0.11 seconds (0.28 work units)

    Nodes    |    Current Node    |     Objective Bounds      |     Work
 Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time

     0     0  -10.56250    0   73    0.00000  -10.56250      -     -    0s
     0     0  -10.56250    0   73    0.00000  -10.56250      -     -    0s
     0     0   -9.82812    0  140    0.00000   -9.82812      -     -    0s
     0     0   -9.75000    0   49    0.00000   -9.75000      -     -    0s
     0     0   -9.35547    0  165    0.00000   -9.35547      -     -    0s
     0     0   -9.33594    0  161    0.00000   -9.33594      -     -    0s
     0     0   -9.17204    0  178    0.00000   -9.17204      -     -    0s
     0     0   -9.12500    0  109    0.00000   -9.12500      -     -    1s
     0     0   -9.12500    0   97    0.00000   -9.12500      -     -    1s
     0     0   -9.12500    0   98    0.00000   -9.12500      -     -    1s
H    0     0                      -0.5000000   -9.12500  1725%     -    1s
     0     0   -9.12500    0   98   -0.50000   -9.12500  1725%     -    1s
     0     0   -9.00000    0   62   -0.50000   -9.00000  1700%     -    1s
     0     0   -9.00000    0   79   -0.50000   -9.00000  1700%     -    1s
     0     0   -9.00000    0  110   -0.50000   -9.00000  1700%     -    1s
     0     0   -9.00000    0  110   -0.50000   -9.00000  1700%     -    1s
     0     0   -9.00000    0   13   -0.50000   -9.00000  1700%     -    1s
     0     0   -9.00000    0   13   -0.50000   -9.00000  1700%     -    1s
     0     0   -9.00000    0   13   -0.50000   -9.00000  1700%     -    1s
     0     0   -9.00000    0   13   -0.50000   -9.00000  1700%     -    2s
     0     0   -9.00000    0   64   -0.50000   -9.00000  1700%     -    2s
     0     0   -9.00000    0   64   -0.50000   -9.00000  1700%     -    2s
     0     0   -9.00000    0   13   -0.50000   -9.00000  1700%     -    2s
     0     0   -9.00000    0   13   -0.50000   -9.00000  1700%     -    2s
     0     2   -8.93817    0   12   -0.50000   -8.93817  1688%     -    2s
H  231   239                      -5.0000000   -8.93817  78.8%   248    3s
H  271   273                      -5.5000000   -8.93817  62.5%   234    3s
H  408   381                      -6.0000000   -8.93817  49.0%   199    4s
H  805   676                      -8.0000000   -8.93817  11.7%   156    4s
  1229   609   -8.50000   10   12   -8.00000   -8.93817  11.7%   127    5s

Cutting planes:
  Gomory: 17
  Cover: 105
  Implied bound: 57
  Clique: 42
  MIR: 64
  Inf proof: 4
  Zero half: 13
  RLT: 11
  Relax-and-lift: 10

Explored 3276 nodes (288205 simplex iterations) in 6.08 seconds (14.72 work units)
Thread count was 12 (of 12 available processors)

Solution count 6: -8 -6 -5.5 ... 0
No other solutions better than -8

Optimal solution found (tolerance 1.00e-04)
Best objective -8.000000000000e+00, best bound -8.000000000000e+00, gap 0.0000%
-------------------------------------------------------------------------------
                                    Summary                                    
-------------------------------------------------------------------------------
(CVXPY) Jun 06 04:25:45 PM: Problem status: optimal
(CVXPY) Jun 06 04:25:45 PM: Optimal value: 3.300e+01
(CVXPY) Jun 06 04:25:45 PM: Compilation took 7.523e-02 seconds
(CVXPY) Jun 06 04:25:45 PM: Solver (including time spent in interface) took 6.276e+00 seconds
Problem(Minimize(Expression(AFFINE, UNKNOWN, (1,))), [Inequality(Constant(CONSTANT, ZERO, (2568,))), Inequality(Variable((2568,), _flow)), Inequality(Constant(CONSTANT, ZERO, (747, 2))), Inequality(Variable((747, 2), _dag_layer)), Equality(Expression(AFFINE, UNKNOWN, (747,)), Constant(CONSTANT, ZERO, ())), Inequality(Expression(AFFINE, NONNEGATIVE, (2568, 2))), Equality(Expression(AFFINE, UNKNOWN, (25,)), Constant(CONSTANT, ZERO, ())), Equality(Expression(AFFINE, NONNEGATIVE, (25,)), Constant(CONSTANT, ZERO, ())), Equality(Expression(AFFINE, NONNEGATIVE, (25,)), Constant(CONSTANT, ZERO, ())), Equality(Expression(AFFINE, UNKNOWN, (14,)), Constant(CONSTANT, ZERO, ())), Equality(Expression(AFFINE, NONNEGATIVE, (14,)), Constant(CONSTANT, ZERO, ())), Equality(Expression(AFFINE, NONNEGATIVE, (14,)), Constant(CONSTANT, ZERO, ())), Inequality(Expression(AFFINE, NONNEGATIVE, (747, 2))), Inequality(Expression(AFFINE, UNKNOWN, (2531,))), Inequality(Expression(AFFINE, UNKNOWN, (2531,))), Inequality(Constant(CONSTANT, NONNEGATIVE, (688,))), Inequality(Expression(AFFINE, UNKNOWN, (2531,))), Inequality(Expression(AFFINE, UNKNOWN, (2531,))), Inequality(Constant(CONSTANT, NONNEGATIVE, (688,))), Inequality(Expression(AFFINE, NONNEGATIVE, (2568, 2))), Inequality(Expression(AFFINE, NONNEGATIVE, (2561, 2))), Inequality(Expression(AFFINE, NONNEGATIVE, (2561, 2))), Equality(Expression(AFFINE, NONNEGATIVE, (3,)), Constant(CONSTANT, ZERO, ())), Equality(Expression(AFFINE, NONNEGATIVE, (3,)), Constant(CONSTANT, ZERO, ())), Inequality(Expression(AFFINE, NONNEGATIVE, (2568,))), Inequality(Variable((2568,), edge_has_signal_OR, boolean=True))])
# optimization metrics
for o in P.objectives:
    print(o.name, o.value)
error_sample1_0 9.0
penalty_indirect_rules_0 [0.]
error_sample2_1 10.0
penalty_indirect_rules_1 [0.]
regularization_edge_has_signal_OR 28.0
# having a look to edge values per condition
pd.DataFrame(
    P.expr.edge_value.value,
    index=c.processed_graph.E,
    columns=["edge_activity1", "edge_activity2"],
).head(5)
edge_activity1 edge_activity2
(Mapk14) (Mapkapk2) 0.0 0.0
(Map2k3) (Dyrk1b) 0.0 0.0
(Akt1) (Chuk) 0.0 0.0
(Cdkn1a) (Cdk1) 0.0 0.0
(Mdc1) (Rnf8) 0.0 0.0
# having a look to node values per condition
pd.DataFrame(
    P.expr.vertex_value.value,
    index=c.processed_graph.V,
    columns=["vertex_activity1", "vertex_activity2"],
).head(5)
vertex_activity1 vertex_activity2
Ulk1 0.0 0.0
Siah1a 0.0 0.0
Sp1 0.0 -1.0
Hck 0.0 0.0
Ghr 0.0 0.0

Inferred network for contrast 1: Leukemic_spontaneous vs Normal#

sol_edges1 = np.flatnonzero(np.abs(P.expr.edge_value.value[:, 0]) > 0.5)
c.processed_graph.plot_values(
    vertex_values=P.expr.vertex_value.value[:, 0],
    edge_values=P.expr.edge_value.value[:, 0],
    edge_indexes=sol_edges1,
)
../../../_images/f46c4c596d790444564eda9a85329e968dac94d3ff99368d5f487ac428489860.svg

Inferred network for contrast 2: Leukemic_SB vs Normal#

sol_edges2 = np.flatnonzero(np.abs(P.expr.edge_value.value[:, 1]) > 0.5)
c.processed_graph.plot_values(
    vertex_values=P.expr.vertex_value.value[:, 1],
    edge_values=P.expr.edge_value.value[:, 1],
    edge_indexes=sol_edges2,
)
../../../_images/bac9d77d0452be602088078503e115db4d55242df1064f3b8a97bd57940563a8.svg