Multi-sample CARNIVAL

import numpy as np
import pandas as pd

import corneto as cn

# Print the installed corneto version, the available modelling backends and
# the solvers corneto can use in this environment.
cn.info()
Installed version:v1.0.0.dev39+1266a7d
Available backends:CVXPY v1.6.6, PICOS v2.6.1
Default backend (corneto.opt):CVXPY
Installed solvers:CVXOPT, GLPK, GLPK_MI, HIGHS, SCIP, SCIPY
Graphviz version:v0.20.3
Installed path:/home/runner/work/corneto/corneto/corneto
Repository:https://github.com/saezlab/corneto

Toy example

from corneto.graph import Graph

# Toy prior-knowledge network (PKN) for signalling, given as
# (source, sign, target) triples, where sign 1 is an activation and -1 an
# inhibition. Keep the order unchanged: the edge tables printed later in this
# notebook list the edges in this same order.
pkn_edges = [
    ("rec1", 1, "a"),
    ("rec1", -1, "b"),
    ("rec1", 1, "f"),
    ("rec1", -1, "c"),
    ("rec2", 1, "b"),
    ("rec2", 1, "tf2"),
    ("b", 1, "g"),
    ("g", -1, "d"),
    ("rec2", -1, "d"),
    ("a", 1, "c"),
    ("a", -1, "d"),
    ("c", 1, "d"),
    ("c", -1, "e"),
    ("c", 1, "tf3"),
    ("e", 1, "a"),
    ("d", -1, "c"),
    ("e", 1, "tf1"),
    ("a", -1, "tf1"),
    ("d", 1, "tf2"),
    ("c", -1, "tf2"),
    ("tf1", 1, "tf2"),
    ("tf1", -1, "rec2"),
    ("tf2", 1, "rec1"),
    ("tf1", 1, "f"),
]
G = Graph.from_tuples(pkn_edges)
G.plot()
../../_images/4fd3b062c8fc55e25036e3fb65b0e7f7ab9ea9f69f11ac63c04901cda82874f4.svg
from corneto.methods.future.carnival import CarnivalFlow

def _vertex_measurement(value, role):
    """Return one sample entry mapped onto a graph vertex.

    ``value`` is the signed measurement or perturbation, and ``role`` marks the
    vertex as an ``"input"`` or an ``"output"`` of the signalling problem.
    """
    return {"value": value, "mapping": "vertex", "role": role}


# Two conditions ("samples"). Each one maps vertex names to a signed value and
# to the role the vertex plays in that condition.
samples = {
    "c1": {
        "rec2": _vertex_measurement(1, "input"),
        "tf1": _vertex_measurement(-2, "output"),
        "tf2": _vertex_measurement(1, "output"),
    },
    "c2": {
        "rec1": _vertex_measurement(1, "input"),
        "rec2": _vertex_measurement(-1, "input"),
        "tf1": _vertex_measurement(1, "output"),
        "tf2": _vertex_measurement(-1, "output"),
        "tf3": _vertex_measurement(3, "output"),
    },
}

# Convert the nested {sample: {vertex: {...}}} dictionary into a corneto Data
# object. Displaying it summarises the number of samples and the number of
# features in each sample.
data = cn.Data.from_cdict(samples)
data
Data(n_samples=2, n_feats=[3 5])

Multi-sample CARNIVAL

# Build a single CARNIVAL problem covering every sample in `data`.
# lambda_reg weights the regularisation term of the objective.
c = CarnivalFlow(lambda_reg=1e-3)
P = c.build(G, data)

# Solve with the HiGHS solver that ships with SciPy (solver="scipy").
# The trailing semicolon keeps the notebook from echoing the return value.
P.solve(solver="scipy", verbosity=0);
Unreachable vertices for sample: 1
Unreachable vertices for sample: 0
# Plot the preprocessed graph (the "flow graph") on which CarnivalFlow builds
# the problem. Unlike G, it contains extra edges with an empty endpoint
# attached to input/output vertices, and some PKN edges (e.g. the ones into
# "f") are no longer present.
c.processed_graph.plot()
../../_images/6625b3301d88983a5dece22b33532aba99311e2852fa795694c924035736763b.svg
# Show the sub-graph where the signal propagates in at least one sample
# (union graph). Solver outputs are floats that can carry tiny numerical
# residues (e.g. 1e-12). Compare against a small tolerance instead of exact
# zero so such residues are not drawn as used edges.
flow_tol = 1e-6
c.processed_graph.edge_subgraph(np.flatnonzero(np.abs(P.expr.flow.value) > flow_tol)).plot()
../../_images/c15ba3441dc234e59c5231ea207c091214aec21bc89acef7199fcf19010834cc.svg
# Tabulate the signed value of every edge (rows) in each sample (columns).
# Solver values are floats that may land slightly off an integer
# (e.g. 0.9999999 or -0.9999999). astype(int) truncates toward zero and would
# turn those into 0, so round to the nearest integer first.
pd.DataFrame(
    P.expr.edge_value.value,
    index=c.processed_graph.E,
    columns=["condition_1", "condition_2"],
).round().astype(int)
condition_1 condition_2
(rec1) (a) 1 1
(b) 0 0
(c) 0 0
(rec2) (b) 0 0
(tf2) 1 -1
(b) (g) 0 0
(g) (d) 0 0
(rec2) (d) 0 0
(a) (c) 0 1
(d) 0 0
(c) (d) 0 0
(e) 0 0
(tf3) 0 1
(e) (a) 0 0
(d) (c) 0 0
(e) (tf1) 0 0
(a) (tf1) -1 0
(d) (tf2) 0 0
(c) (tf2) 0 0
(tf1) (tf2) 0 0
(rec2) 0 0
(tf2) (rec1) 1 0
(tf3) () 0 0
(tf1) () 0 0
(tf2) () 0 0
() (rec2) 1 -1
(rec1) 0 1
# Sub-graph of sample 1: edges whose edge_has_signal entry (column 0) exceeds
# 0.5, i.e. the indicator treated as 0/1 despite solver round-off.
sample_1_edges = np.flatnonzero(P.expr.edge_has_signal.value[:, 0] > 0.5)
c.processed_graph.edge_subgraph(sample_1_edges).plot()
../../_images/f241de182f9f6b73e736b744d931f928b2f4274204a15fe91d745cf9008fa499.svg
# Sub-graph of sample 2: edges whose edge_has_signal entry (column 1) exceeds
# 0.5, i.e. the indicator treated as 0/1 despite solver round-off.
sample_2_edges = np.flatnonzero(P.expr.edge_has_signal.value[:, 1] > 0.5)
c.processed_graph.edge_subgraph(sample_2_edges).plot()
../../_images/1542ea34797a98a0b0329061800a8bd8dbe88e791e4af03a84896f7d12a8c02b.svg
# Print the value of every objective term of the solved problem, one per line.
for objective in P.objectives:
    print(objective.value)
0.0
1.0
8.0

Plotting solutions on top of the PKN

# Overlay the first sample's vertex and edge values on the processed graph.
sample_idx = 0
c.processed_graph.plot_values(
    vertex_values=P.expr.vertex_value.value[:, sample_idx],
    edge_values=P.expr.edge_value.value[:, sample_idx],
)
../../_images/eacc2b197e26b1408aa1b4b70e44827a30e0489f576366cbd48fbc38a1c06094.svg
# Overlay the second sample's vertex and edge values on the processed graph.
sample_idx = 1
c.processed_graph.plot_values(
    vertex_values=P.expr.vertex_value.value[:, sample_idx],
    edge_values=P.expr.edge_value.value[:, sample_idx],
)
../../_images/01794b89c73ccde6d5d4561cee5c1b9c70e97424cd01f7c6c51ac1942aabc58c.svg

Penalizing indirect signaling rules

The original version of CARNIVAL treats indirect rules (A -> B: if A is inhibited, B is inhibited; A -| B: if A is inhibited, B is activated) the same way as direct rules (A -> B: if A is activated, B is activated; A -| B: if A is activated, B is inhibited). These reverse, or inactive-state, implications generally rely on additional assumptions, such as A being the sole regulator of B or the existence of efficient off-switch mechanisms, and may therefore be less reliable. To account for this, we added a tunable parameter, `indirect_rule_penalty`, that penalizes the number of indirect rules selected in each sample.

# Same multi-sample problem, now penalising the indirect rules used in each
# sample (one penalty_indirect_rules_<i> objective term per sample).
c = CarnivalFlow(lambda_reg=1e-3, indirect_rule_penalty=1)
P = c.build(G, data)
P.solve(solver="scipy", verbosity=0)

# Report every objective term by name, next to its value.
for objective in P.objectives:
    print(objective.name, objective.value)
Unreachable vertices for sample: 1
Unreachable vertices for sample: 0
error_c1_0 0.0
penalty_indirect_rules_0 [0.]
error_c2_1 1.0
penalty_indirect_rules_1 [1.]
regularization_edge_has_signal_OR 9.0
# Overlay the first sample's vertex and edge values on the processed graph.
sample_idx = 0
c.processed_graph.plot_values(
    vertex_values=P.expr.vertex_value.value[:, sample_idx],
    edge_values=P.expr.edge_value.value[:, sample_idx],
)
../../_images/eacc2b197e26b1408aa1b4b70e44827a30e0489f576366cbd48fbc38a1c06094.svg
# Overlay the second sample's vertex and edge values on the processed graph.
sample_idx = 1
c.processed_graph.plot_values(
    vertex_values=P.expr.vertex_value.value[:, sample_idx],
    edge_values=P.expr.edge_value.value[:, sample_idx],
)
../../_images/24d2f30c4e5323e99d2ca8588dcc24b805e561d0eeff2d9e91eff4678b55488e.svg

Single sample using the same multi-sample method

# Only c1: run the same multi-sample method on a one-sample subset.

subset = cn.Data.from_cdict({"c1": samples["c1"]})
c = CarnivalFlow(lambda_reg=1e-3)
P = c.build(G, subset)
P.solve(verbosity=0, solver="scipy")
for o in P.objectives:
    print(o.value)

# edge_has_signal holds solver floats. Threshold at 0.5, as the multi-sample
# plots do, instead of testing for exact non-zero, so residues such as 1e-9
# are not drawn as active edges. Assumes one value per edge for a single
# sample (shape (E,) or (E, 1)), so flat indices are edge indices.
c.processed_graph.edge_subgraph(np.flatnonzero(P.expr.edge_has_signal.value > 0.5)).plot()
Unreachable vertices for sample: 0
0.0
5.0
../../_images/f241de182f9f6b73e736b744d931f928b2f4274204a15fe91d745cf9008fa499.svg
# Only c2: run the same multi-sample method on a one-sample subset.

subset = cn.Data.from_cdict({"c2": samples["c2"]})
c = CarnivalFlow(lambda_reg=1e-3)
P = c.build(G, subset)
P.solve(verbosity=0, solver="scipy")
for o in P.objectives:
    print(o.value)

# edge_has_signal holds solver floats. Threshold at 0.5, as the multi-sample
# plots do, instead of testing for exact non-zero, so residues such as 1e-9
# are not drawn as active edges. Assumes one value per edge for a single
# sample (shape (E,) or (E, 1)), so flat indices are edge indices.
c.processed_graph.edge_subgraph(np.flatnonzero(P.expr.edge_has_signal.value > 0.5)).plot()
Unreachable vertices for sample: 0
1.0
6.0
../../_images/1542ea34797a98a0b0329061800a8bd8dbe88e791e4af03a84896f7d12a8c02b.svg