Knowledge-primed Neural Networks (KPNNs) for single cell data

In this tutorial we show how CORNETO can be used to build custom neural network architectures informed by prior knowledge, by implementing a knowledge-primed neural network (KPNN) [1]. We use the single-cell data from the publication “Knowledge-primed neural networks enable biologically interpretable deep learning on single-cell sequencing data” by Nikolaus Fortelny and Christoph Bock. The authors used a single-cell RNA-seq dataset they had previously generated [2], which measures cellular responses to T cell receptor (TCR) stimulation in a standardized in vitro model. The dataset was chosen because of the complexity of the TCR signaling pathway and its well-characterized role in orchestrating transcriptional responses to antigen detection in T cells.

Why CORNETO?

In the original publication, the authors built a KPNN by searching databases and constructing a Directed Acyclic Graph (DAG) from shortest paths between the TCR and the genes. However, this approach is not guaranteed to produce an optimal architecture, for example the smallest DAG that connects the genes to the receptor. CORNETO, thanks to its capabilities for modeling and optimization on networks, provides methods to find DAG architectures automatically by solving an optimization problem.

In addition, CORNETO provides methods to easily build DAG NN architectures using Keras 3+, which makes the KPNN implementation flexible and interoperable with backends such as PyTorch, TensorFlow and JAX.

How does it work?

Thanks to CORNETO’s building blocks for optimization over networks, we can easily formulate optimization problems that find DAG architectures within a Prior Knowledge Network (PKN). Once we have this backbone, we can convert it into a neural network using the utility functions included in CORNETO.
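To make the idea concrete, the following minimal sketch (plain Python, not CORNETO's implementation) evaluates a tiny DAG as a neural network: every non-input node is a neuron that only receives signal from its parents in the graph. The node names, weights and biases are made up for illustration.

```python
import math
from graphlib import TopologicalSorter

# Hypothetical toy DAG: each key is a neuron, each value lists its parents.
parents = {
    "TF_A": ["gene1", "gene2"],
    "TF_B": ["gene2", "gene3"],
    "receptor": ["TF_A", "TF_B"],
}

# Made-up parameters: one weight per edge (parent, child), one bias per neuron.
weights = {
    ("gene1", "TF_A"): 1.0,
    ("gene2", "TF_A"): -1.0,
    ("gene2", "TF_B"): 0.5,
    ("gene3", "TF_B"): 0.5,
    ("TF_A", "receptor"): 2.0,
    ("TF_B", "receptor"): -2.0,
}
biases = {"TF_A": 0.0, "TF_B": 0.0, "receptor": 0.0}


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def forward(inputs):
    """Evaluate the DAG in topological order, parents before children."""
    values = dict(inputs)
    for node in TopologicalSorter(parents).static_order():
        if node in values:  # input node, its value is already given
            continue
        z = biases[node] + sum(weights[(p, node)] * values[p] for p in parents[node])
        values[node] = sigmoid(z)
    return values


out = forward({"gene1": 1.0, "gene2": 1.0, "gene3": 0.0})
print(out["receptor"])
```

Because the connectivity comes from the graph, the number of trainable weights equals the number of edges rather than the product of consecutive layer sizes, which is what keeps such networks sparse.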

References

  1. Fortelny, N., & Bock, C. (2020). Knowledge-primed neural networks enable biologically interpretable deep learning on single-cell sequencing data. Genome biology, 21, 1-36.

  2. Datlinger, P., Rendeiro, A. F., Schmidl, C., Krausgruber, T., Traxler, P., Klughammer, J., … & Bock, C. (2017). Pooled CRISPR screening with single-cell transcriptome readout. Nature methods, 14(3), 297-301.

Download and import the single cell dataset

import os
import tempfile
import urllib.parse
import urllib.request

import numpy as np
import pandas as pd
import scanpy as sc

import corneto as cn

with urllib.request.urlopen("http://kpnn.computational-epigenetics.org/") as response:
    web_input = response.geturl()
print("Effective URL:", web_input)

files = ["TCR_Edgelist.csv", "TCR_ClassLabels.csv", "TCR_Data.h5"]

temp_dir = tempfile.mkdtemp()

# Download files
file_paths = []
for file in files:
    url = urllib.parse.urljoin(web_input, file)
    output_path = os.path.join(temp_dir, file)
    print(f"Downloading {url} to {output_path}")
    try:
        with urllib.request.urlopen(url) as response:
            with open(output_path, "wb") as f:
                f.write(response.read())
        file_paths.append(output_path)
    except Exception as e:
        print(f"Failed to download {url}: {e}")

print("Downloaded files:")
for path in file_paths:
    print(path)
Effective URL: https://medical-epigenomics.org/papers/fortelny2019/
Downloading https://medical-epigenomics.org/papers/fortelny2019/TCR_Edgelist.csv to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpgy9f05hv/TCR_Edgelist.csv
Downloading https://medical-epigenomics.org/papers/fortelny2019/TCR_ClassLabels.csv to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpgy9f05hv/TCR_ClassLabels.csv
Downloading https://medical-epigenomics.org/papers/fortelny2019/TCR_Data.h5 to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpgy9f05hv/TCR_Data.h5
Downloaded files:
/var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpgy9f05hv/TCR_Edgelist.csv
/var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpgy9f05hv/TCR_ClassLabels.csv
/var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpgy9f05hv/TCR_Data.h5
# The data also contains the original network the authors built with shortest paths.
# We will use it to replicate the study
df_edges = pd.read_csv(file_paths[0])
df_labels = pd.read_csv(file_paths[1])
# Import the 10x data with Scanpy
adata = sc.read_10x_h5(file_paths[2])
df_labels
      barcode             TCR
0     AAACCTGCACACATGT-1  0
1     AAACCTGCACGTCTCT-1  0
2     AAACCTGTCAATACCG-1  0
3     AAACCTGTCGTGGTCG-1  0
4     AAACGGGTCTGAGTGT-1  0
...   ...                 ...
1730  TTTCCTCGTCATGCCG-2  1
1731  TTTGCGCGTAGCCTCG-2  1
1732  TTTGGTTAGATACACA-2  1
1733  TTTGGTTGTATGAATG-2  1
1734  TTTGGTTTCCAAGTAC-2  1

1735 rows × 2 columns

df_edges
       parent  child
0      TCR     ZAP70
1      ZAP70   MAPK14
2      MAPK14  FOXO3
3      MAPK14  STAT1
4      MAPK14  STAT3
...    ...     ...
27574  HMGA1   MTRNR2L9_gene
27575  MYB     C12orf50_gene
27576  MYB     TRPC5OS_gene
27577  SOX2    TRPC5OS_gene
27578  CRTC1   MTRNR2L9_gene

27579 rows × 2 columns

adata.var
                         gene_ids
DDX11L1                  ENSG00000223972
WASH7P                   ENSG00000227232
MIR6859-2                ENSG00000278267
MIR1302-10               ENSG00000243485
MIR1302-11               ENSG00000274890
...                      ...
Tcrlibrary_RUNX2_3_gene  Tcrlibrary_RUNX2_3_gene
Tcrlibrary_ZAP70_1_gene  Tcrlibrary_ZAP70_1_gene
Tcrlibrary_ZAP70_2_gene  Tcrlibrary_ZAP70_2_gene
Tcrlibrary_ZAP70_3_gene  Tcrlibrary_ZAP70_3_gene
Cas9_blast_gene          Cas9_blast_gene

64370 rows × 1 columns

# Normalization could be applied here. However, any preprocessing step that
# estimates parameters from the data should be fitted after splitting into
# training and test sets (for example inside the cross-validation loop)
# to avoid data leakage.
# sc.pp.normalize_total(adata, target_sum=1e6)

# The log-transform is applied element-wise and does not estimate anything
# from the data, so it cannot leak information between splits.
sc.pp.log1p(adata)
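The leakage concern can be illustrated with a small sketch. Any transformation that estimates parameters from the data (here a per-gene mean and standard deviation, chosen only as an example) should be fitted on the training split and then applied unchanged to the held-out split. The helper names and the toy values are hypothetical and are not part of the tutorial's pipeline.

```python
from statistics import mean, pstdev


def fit_scaler(rows):
    """Estimate per-column mean and standard deviation from training rows only."""
    columns = list(zip(*rows))
    means = [mean(col) for col in columns]
    # Constant columns have zero deviation; fall back to 1.0 to avoid dividing by zero.
    stds = [pstdev(col) or 1.0 for col in columns]
    return means, stds


def apply_scaler(rows, means, stds):
    """Scale rows with previously estimated parameters, without re-estimating them."""
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in rows]


train = [[0.0, 2.0], [2.0, 2.0], [4.0, 2.0]]
test = [[10.0, 3.0]]

means, stds = fit_scaler(train)                # parameters come from the training split only
train_scaled = apply_scaler(train, means, stds)
test_scaled = apply_scaler(test, means, stds)  # the test split never influences the fit
print(means, stds, test_scaled)
```

In a cross-validation setting, the same two steps would be repeated inside every fold, so that each validation fold is scaled with statistics computed from its own training folds.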
adata.obs
AAACCTGAGAAACCAT-1
AAACCTGAGAAACCGC-1
AAACCTGAGAAACCTA-1
AAACCTGAGAAACGAG-1
AAACCTGAGAAACGCC-1
...
TTTGTCATCTTTACAC-2
TTTGTCATCTTTACGT-2
TTTGTCATCTTTAGGG-2
TTTGTCATCTTTAGTC-2
TTTGTCATCTTTCCTC-2

1474560 rows × 0 columns

barcodes = adata.obs_names
barcodes
Index(['AAACCTGAGAAACCAT-1', 'AAACCTGAGAAACCGC-1', 'AAACCTGAGAAACCTA-1',
       'AAACCTGAGAAACGAG-1', 'AAACCTGAGAAACGCC-1', 'AAACCTGAGAAAGTGG-1',
       'AAACCTGAGAACAACT-1', 'AAACCTGAGAACAATC-1', 'AAACCTGAGAACTCGG-1',
       'AAACCTGAGAACTGTA-1',
       ...
       'TTTGTCATCTTGGGTA-2', 'TTTGTCATCTTGTACT-2', 'TTTGTCATCTTGTATC-2',
       'TTTGTCATCTTGTCAT-2', 'TTTGTCATCTTGTTTG-2', 'TTTGTCATCTTTACAC-2',
       'TTTGTCATCTTTACGT-2', 'TTTGTCATCTTTAGGG-2', 'TTTGTCATCTTTAGTC-2',
       'TTTGTCATCTTTCCTC-2'],
      dtype='object', length=1474560)
gene_names = adata.var.index
print(gene_names)
Index(['DDX11L1', 'WASH7P', 'MIR6859-2', 'MIR1302-10', 'MIR1302-11', 'FAM138A',
       'OR4G4P', 'OR4G11P', 'OR4F5', 'RP11-34P13.7',
       ...
       'Tcrlibrary_RUNX1_1_gene', 'Tcrlibrary_RUNX1_2_gene',
       'Tcrlibrary_RUNX1_3_gene', 'Tcrlibrary_RUNX2_1_gene',
       'Tcrlibrary_RUNX2_2_gene', 'Tcrlibrary_RUNX2_3_gene',
       'Tcrlibrary_ZAP70_1_gene', 'Tcrlibrary_ZAP70_2_gene',
       'Tcrlibrary_ZAP70_3_gene', 'Cas9_blast_gene'],
      dtype='object', length=64370)
len(set(df_labels.barcode.tolist()))
1735
len(set(barcodes.tolist()))
1474560
matched_barcodes = sorted(set(barcodes.tolist()) & set(df_labels.barcode.tolist()))
len(matched_barcodes)
1735
# This is the InPathsY data in the original code of KPNNs
df_labels
      barcode             TCR
0     AAACCTGCACACATGT-1  0
1     AAACCTGCACGTCTCT-1  0
2     AAACCTGTCAATACCG-1  0
3     AAACCTGTCGTGGTCG-1  0
4     AAACGGGTCTGAGTGT-1  0
...   ...                 ...
1730  TTTCCTCGTCATGCCG-2  1
1731  TTTGCGCGTAGCCTCG-2  1
1732  TTTGGTTAGATACACA-2  1
1733  TTTGGTTGTATGAATG-2  1
1734  TTTGGTTTCCAAGTAC-2  1

1735 rows × 2 columns

Import PKN with CORNETO

cn.info()
Installed version: v1.0.0b2.post25.dev0+fe9ca1ee
Available backends: CVXPY v1.7.5
Default backend (corneto.opt): CVXPY
Installed solvers: CLARABEL, HIGHS, OSQP, SCIPY, SCS
Graphviz version: v0.21
Installed path: /Users/pablorodriguezmier/Documents/work/projects/corneto/corneto
Repository: https://github.com/saezlab/corneto
outputs_pkn = list(set(df_edges.parent.tolist()) - set(df_edges.child.tolist()))
inputs_pkn = set(df_edges.child.tolist()) - set(df_edges.parent.tolist())
input_pkn_genes = list(set(g.split("_")[0] for g in inputs_pkn))
len(inputs_pkn), len(outputs_pkn)
(13121, 1)
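The set differences above identify the two ends of the network directly from the edge list: nodes that never appear as a child form the root (here TCR), and nodes that never appear as a parent are the leaf genes. The sketch below reproduces that logic on a made-up edge list with the same (parent, child) layout; all node names except TCR and ZAP70 are hypothetical.

```python
# Toy edge list in the same (parent, child) layout as df_edges.
edges = [
    ("TCR", "ZAP70"),
    ("ZAP70", "TF_A"),
    ("ZAP70", "TF_B"),
    ("TF_A", "gene1_gene"),
    ("TF_B", "gene1_gene"),
    ("TF_B", "gene2_gene"),
]

parents = {p for p, _ in edges}
children = {c for _, c in edges}

# Nodes that are never a child have no incoming edge: the root of the PKN.
roots = parents - children
# Nodes that are never a parent have no outgoing edge: the measured genes.
leaves = children - parents
# Keep the text before the first underscore to recover the gene symbol.
leaf_genes = sorted({leaf.split("_")[0] for leaf in leaves})

print(roots, sorted(leaves), leaf_genes)
```

Since the neural network reads gene expression and predicts the TCR state, the leaves act as NN inputs and the root as the NN output, which is why the edges are reversed to (child, 1, parent) when the graph is built.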
tuples = [(r.child, 1, r.parent) for _, r in df_edges.iterrows()]
G = cn.Graph.from_sif_tuples(tuples)
G = G.prune(inputs_pkn, outputs_pkn)

# Size of the original PKN provided by the authors
G.shape
(13439, 27579)
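This tutorial does not describe how the pruning step works internally. One plausible reading, assumed here purely for illustration, is that pruning keeps only the part of the graph that lies on some path from an input to an output. The sketch below implements that idea with two breadth-first searches; the function names and the toy graph are hypothetical and do not reflect CORNETO's code.

```python
from collections import defaultdict, deque


def reachable(start_nodes, adjacency):
    """Breadth-first search returning every node reachable from start_nodes."""
    seen = set(start_nodes)
    queue = deque(start_nodes)
    while queue:
        node = queue.popleft()
        for nxt in adjacency[node]:
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen


def prune_edges(edges, sources, targets):
    """Keep edges whose endpoints lie on some path from a source to a target."""
    forward, backward = defaultdict(list), defaultdict(list)
    for u, v in edges:
        forward[u].append(v)
        backward[v].append(u)
    # A node is kept if a source reaches it and it reaches a target.
    keep = reachable(sources, forward) & reachable(targets, backward)
    return [(u, v) for u, v in edges if u in keep and v in keep]


edges = [("g1", "TF"), ("TF", "TCR"), ("g2", "orphan"), ("other", "TF")]
pruned = prune_edges(edges, sources=["g1", "g2"], targets=["TCR"])
print(pruned)
```

In the toy graph, the edge into "orphan" is dropped because it never reaches the output, and the edge from "other" is dropped because no input reaches it.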

Select the single cell data for training

adata_matched = adata[adata.obs_names.isin(matched_barcodes), adata.var_names.isin(input_pkn_genes)]
adata_matched.shape
(1735, 14229)
non_zero_genes = set(adata_matched.to_df().columns[adata_matched.to_df().sum(axis=0) >= 1e-6].values)
len(non_zero_genes)
12459
len(non_zero_genes.intersection(adata_matched.var_names))
12459
adata_matched = adata_matched[:, adata_matched.var_names.isin(non_zero_genes)]
# Some duplicated gene columns still have zero total counts; drop them
adata_matched = adata_matched[:, adata_matched.to_df().sum(axis=0) != 0]
adata_matched.shape
(1735, 12487)
df_expr = adata_matched.to_df()
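# Collapse duplicated gene symbols by taking the maximum value per cell.
# groupby(..., axis=1) is deprecated in recent pandas, so we group on the transpose.
df_expr = df_expr.T.groupby(level=0).max().T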
df_expr
A1BG A2ML1 AAAS AACS AADAT AAED1 AAGAB AAK1 AAMDC AAMP ... ZSWIM8 ZUFSP ZW10 ZWILCH ZXDC ZYG11A ZYG11B ZYX ZZEF1 ZZZ3
AAACCTGCACACATGT-1 0.0 0.0 0.000000 0.693147 0.0 0.000000 0.000000 0.000000 0.0 1.098612 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.000000
AAACCTGCACGTCTCT-1 0.0 0.0 0.000000 0.000000 0.0 0.693147 0.000000 0.000000 0.0 0.693147 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.000000
AAACCTGTCAATACCG-1 0.0 0.0 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 0.0 0.693147 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.000000
AAACCTGTCGTGGTCG-1 0.0 0.0 1.098612 0.000000 0.0 0.000000 0.000000 0.000000 0.0 1.098612 ... 0.000000 0.693147 0.000000 0.693147 0.0 0.0 0.0 0.000000 0.000000 0.000000
AAACGGGTCTGAGTGT-1 0.0 0.0 0.000000 0.000000 0.0 0.693147 0.000000 0.000000 0.0 0.000000 ... 0.693147 0.000000 0.693147 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.000000
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
TTTCCTCGTCATGCCG-2 0.0 0.0 0.000000 0.000000 0.0 0.000000 0.000000 0.693147 0.0 0.693147 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.000000
TTTGCGCGTAGCCTCG-2 0.0 0.0 1.098612 0.000000 0.0 0.000000 0.000000 0.000000 0.0 1.386294 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.000000
TTTGGTTAGATACACA-2 0.0 0.0 1.098612 0.000000 0.0 0.000000 0.000000 1.098612 0.0 0.693147 ... 0.693147 0.000000 0.693147 0.000000 0.0 0.0 0.0 0.693147 0.693147 0.000000
TTTGGTTGTATGAATG-2 0.0 0.0 0.000000 0.000000 0.0 0.693147 1.098612 0.000000 0.0 0.693147 ... 0.000000 0.693147 0.000000 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.693147
TTTGGTTTCCAAGTAC-2 0.0 0.0 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 0.0 0.693147 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 1.098612 0.000000 0.000000

1735 rows × 12459 columns
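The drop from 12487 to 12459 columns comes from collapsing columns that share the same gene symbol, keeping the maximum value per cell. A minimal sketch of that operation in plain Python, on a made-up table with one duplicated symbol, could look as follows.

```python
from collections import defaultdict

# Toy expression table in which the symbol "GENEA" appears twice.
columns = ["GENEA", "GENEB", "GENEA"]
rows = {
    "cell1": [0.0, 1.1, 0.7],
    "cell2": [1.4, 0.0, 0.0],
}

# Map each unique symbol to the column positions where it occurs.
positions = defaultdict(list)
for idx, name in enumerate(columns):
    positions[name].append(idx)

# One output column per unique symbol, holding the per-cell maximum.
unique_columns = sorted(positions)
collapsed = {
    cell: [max(values[i] for i in positions[name]) for name in unique_columns]
    for cell, values in rows.items()
}
print(unique_columns, collapsed)
```

Taking the maximum is only one possible choice; summing the duplicated columns would be another, and the tutorial uses the maximum.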

Building and training the KPNN

We will now use the PKN provided by the authors and the utility functions in CORNETO to build a KPNN similar to the one used in the original manuscript.

import os

# The backend must be selected before Keras is imported
os.environ["KERAS_BACKEND"] = "jax"
import keras
from keras.callbacks import EarlyStopping
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import StratifiedKFold
# Use the data from the experiment
X = df_expr.values
y = df_labels.set_index("barcode").loc[df_expr.index, "TCR"].values
X.shape, y.shape
((1735, 12459), (1735,))
# We can prefilter on top N genes to make this faster
top_n = None

# From the given PKN
outputs_pkn = list(set(df_edges.parent.tolist()) - set(df_edges.child.tolist()))
inputs_pkn = set(df_edges.child.tolist()) - set(df_edges.parent.tolist())
input_pkn_genes = list(set(g.split("_")[0] for g in inputs_pkn))

if top_n is not None and top_n > 0:
    input_pkn_genes = list(
        set(input_pkn_genes).intersection(df_expr.var(axis=0).sort_values(ascending=False).head(top_n).index)
    )
    inputs_pkn = list(g + "_gene" for g in input_pkn_genes)

len(inputs_pkn), len(outputs_pkn)
(13121, 1)
input_nn_genes = list(set(input_pkn_genes).intersection(df_expr.columns))
input_nn = [g + "_gene" for g in input_nn_genes]
len(input_nn)
12459
# Build corneto graph
tuples = [(r.child, 1, r.parent) for _, r in df_edges.iterrows()]
G = cn.Graph.from_sif_tuples(tuples)
G = G.prune(input_nn, outputs_pkn)
G.shape
(12767, 25928)
len(input_nn), len(input_nn_genes)
(12459, 12459)
len(set(input_nn).intersection(G.V))
12459
X = df_expr.loc[:, input_nn_genes].values
y = df_labels.set_index("barcode").loc[df_expr.index, "TCR"].values
X.shape, y.shape
((1735, 12459), (1735,))
from corneto._ml import build_dagnn


def stratified_kfold(
    G,
    inputs,
    outputs,
    n_splits=5,
    shuffle=True,
    random_state=42,
    lr=0.001,
    patience=10,
    file_weights="weights",
    dagnn_config=dict(
        batch_norm_input=True,
        batch_norm_center=False,
        batch_norm_scale=False,
        bias_reg_l1=1e-3,
        bias_reg_l2=1e-2,
        dropout=0.20,
        default_hidden_activation="sigmoid",
        default_output_activation="sigmoid",
        force_sign=False,
        verbose=False,
    ),
):
    kfold = StratifiedKFold(n_splits=n_splits, shuffle=shuffle, random_state=random_state)
    models = []
    metrics = {m: [] for m in ["accuracy", "precision", "recall", "f1", "roc_auc"]}
    # NOTE: X and y are taken from the notebook's global scope
    for i, (train_idx, val_idx) in enumerate(kfold.split(X, y)):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        print("Building DAG NN model with CORNETO using Keras with JAX...")
        print(f" > N. inputs: {len(inputs)}")
        print(f" > N. outputs: {len(outputs)}")
        # Use the function arguments instead of the global input_nn/outputs_pkn
        model = build_dagnn(G, inputs, outputs, **dagnn_config)
        print(f" > N. parameters: {model.count_params()}")

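        # Train the model with Adam
        opt = keras.optimizers.Adam(learning_rate=lr)
        # NOTE: early stopping monitors the same validation fold that is later
        # used to compute the metrics, so the reported scores are optimistic
        early_stopping = EarlyStopping(monitor="val_loss", patience=patience, restore_best_weights=True)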
        print("Compiling...")
        model.compile(optimizer=opt, loss="binary_crossentropy", metrics=["accuracy"])
        print("Fitting...")
        model.fit(
            X_train,
            y_train,
            validation_data=(X_val, y_val),
            epochs=200,
            batch_size=64,
            verbose=0,
            callbacks=[early_stopping],
        )

        if file_weights is not None:
            filename = f"{file_weights}_{i}.keras"
            model.save(filename)
            print(f"Weights saved to {filename}")

        # Predictions and metrics calculation
        y_pred_proba = model.predict(X_val).flatten()
        y_pred = (y_pred_proba > 0.5).astype(int)
        acc = accuracy_score(y_val, y_pred)
        precision = precision_score(y_val, y_pred)
        recall = recall_score(y_val, y_pred)
        f1 = f1_score(y_val, y_pred)
        roc_auc = roc_auc_score(y_val, y_pred_proba)
        metrics["accuracy"].append(acc)
        metrics["precision"].append(precision)
        metrics["recall"].append(recall)
        metrics["f1"].append(f1)
        metrics["roc_auc"].append(roc_auc)
        print(f" > Fold {i} validation ROC-AUC={roc_auc:.3f}")
        models.append(model)
    return models, metrics


temp_weights = tempfile.mkdtemp()
models, metrics = stratified_kfold(G, input_nn, outputs_pkn, file_weights=os.path.join(temp_weights, "weights"))

print("Validation metrics:")
for k, v in metrics.items():
    print(f" - {k}: {np.mean(v):.3f}")
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 26236
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpflwx1756/weights_0.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 2s 113ms/step
 > Fold 0 validation ROC-AUC=0.995
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 26236
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpflwx1756/weights_1.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 2s 111ms/step
 > Fold 1 validation ROC-AUC=0.978
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 26236
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpflwx1756/weights_2.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 2s 110ms/step
 > Fold 2 validation ROC-AUC=0.988
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 26236
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpflwx1756/weights_3.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 2s 112ms/step
 > Fold 3 validation ROC-AUC=0.993
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 26236
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpflwx1756/weights_4.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 2s 113ms/step
 > Fold 4 validation ROC-AUC=0.993
Validation metrics:
 - accuracy: 0.961
 - precision: 0.967
 - recall: 0.954
 - f1: 0.960
 - roc_auc: 0.989
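For readers less familiar with these scores, the sketch below recomputes them by hand on six made-up predictions, using the same 0.5 threshold as the function above. It is only an illustration of the definitions; the tutorial itself relies on the scikit-learn implementations.

```python
# Made-up labels and predicted probabilities for six cells.
y_true = [1, 1, 1, 0, 0, 0]
y_prob = [0.9, 0.8, 0.4, 0.6, 0.2, 0.1]
y_pred = [int(p > 0.5) for p in y_prob]

# Confusion-matrix counts.
tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
tn = sum(t == 0 and p == 0 for t, p in zip(y_true, y_pred))

accuracy = (tp + tn) / len(y_true)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

# ROC-AUC as the fraction of (positive, negative) pairs that are ranked
# correctly by the predicted probability; ties count as one half.
pos = [p for t, p in zip(y_true, y_prob) if t == 1]
neg = [p for t, p in zip(y_true, y_prob) if t == 0]
roc_auc = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg) / (len(pos) * len(neg))
print(accuracy, precision, recall, f1, roc_auc)
```

Note that ROC-AUC is computed from the probabilities and does not depend on the 0.5 threshold, whereas the other four scores do.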

Next, we analyze the biases learned for each node in the graph. Note that the authors of the KPNN paper describe a way to extract node weights based on the learned interactions while accounting for biases in the structure of the NN. Here we only show the learned biases of the NN nodes across the 5 folds, so please be careful when interpreting these values.

# We collect the weights obtained in each fold


def load_biases(file="weights", folds=5):
    # Load the model saved for each fold and collect one bias per layer
    biases = []
    mean_inputs = []
    for i in range(folds):
        model = keras.models.load_model(f"{file}_{i}.keras")
        for layer in model.layers:
            weights = layer.get_weights()
            if weights:
                biases.append((i, layer.name, weights[1][0]))
                mean_inputs.append((i, layer.name, weights[0].mean()))
    df_biases = pd.DataFrame(biases, columns=["fold", "gene", "bias"])
    df_biases["abs_bias"] = df_biases.bias.abs()
    df_biases["pow2_bias"] = df_biases.bias.pow(2)
    df_biases = df_biases.set_index(["fold", "gene"])
    return df_biases


df_biases = load_biases(file=os.path.join(temp_weights, "weights"), folds=5)
df_biases.sort_values(by="abs_bias", ascending=False).head(10)
                       bias  abs_bias  pow2_bias
fold gene
2    SUZ12.EZH2   -2.319227  2.319227   5.378813
3    TCR          -2.314526  2.314526   5.357029
1    ZAP70        -2.168838  2.168838   4.703859
2    TCR           2.163686  2.163686   4.681535
3    NfKb.p65.p50 -2.100779  2.100779   4.413271
4    SUZ12.EZH2   -1.978842  1.978842   3.915815
3    ZAP70         1.965354  1.965354   3.862615
4    PRC2         -1.936076  1.936076   3.748389
0    TCR           1.794367  1.794367   3.219754
1    NfKb.p65.p50 -1.788792  1.788792   3.199777
df_biases_full = df_biases.copy().reset_index()
gene_biases_score = df_biases_full.groupby("gene")["pow2_bias"].mean().sort_values(ascending=False)
gene_biases_score
gene
TCR             3.567696
ZAP70           2.803867
SUZ12.EZH2      2.255738
NfKb.p65.p50    2.109870
PRKCD           1.688119
                  ...   
CBX3            0.000156
KMT2D           0.000054
HMT             0.000016
GATA3           0.000015
SRY             0.000011
Name: pow2_bias, Length: 308, dtype: float32
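The ranking above averages the squared bias of each node over the folds, which measures magnitude and ignores sign. The following sketch, on made-up numbers, shows why this matters: a node whose bias flips sign between folds (as TCR does in the table of individual folds) can rank high even though its signed mean is close to zero.

```python
from collections import defaultdict

# Made-up (fold, node, bias) records mimicking the layout of df_biases.
records = [
    (0, "node_A", 2.0),
    (1, "node_A", -2.0),
    (0, "node_B", 1.0),
    (1, "node_B", 1.0),
    (0, "node_C", 3.0),
    (1, "node_C", 0.0),
]

squared = defaultdict(list)
signed = defaultdict(list)
for _, node, bias in records:
    squared[node].append(bias**2)
    signed[node].append(bias)

# Mean squared bias (used for the ranking) and plain mean bias per node.
mean_sq = {n: sum(v) / len(v) for n, v in squared.items()}
mean_signed = {n: sum(v) / len(v) for n, v in signed.items()}
ranking = sorted(mean_sq, key=mean_sq.get, reverse=True)
print(ranking, mean_sq, mean_signed)
```

Here node_A ranks above node_B although its biases cancel out across folds, which is one reason to inspect the per-fold distributions before drawing conclusions.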
gene_biases_score.head(30).plot.bar()
[Figure: bar plot of the mean squared bias (pow2_bias) of the 30 top-ranked nodes; x-axis: gene]
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Get the top 30 genes ranked by mean squared bias across folds
top_genes = gene_biases_score.head(30).index

# Filter the DataFrame to include only rows with these top genes
filtered_df = df_biases_full[df_biases_full["gene"].isin(top_genes)]

plt.figure(figsize=(10, 6))
sns.violinplot(data=filtered_df, x="gene", y="bias", order=top_genes, dodge=True)
sns.stripplot(data=filtered_df, x="gene", y="bias", order=top_genes, dodge=True)
plt.xticks(rotation=90)
plt.title("Biases of the trained neurons")
plt.xlabel("Gene")
plt.ylabel("Bias")
plt.axhline(0, linestyle="--", color="k")
plt.tight_layout()
[Figure: violin and strip plots of the biases of the 30 top-ranked nodes across folds, titled "Biases of the trained neurons"; x-axis: Gene, y-axis: Bias]

Use CORNETO for NN pruning

We now show how CORNETO can be used to extract a smaller, yet complete, DAG from the original PKN provided by the authors. We add an input edge to each input node and an output edge from TCR to indicate which nodes are the inputs and which one is the output. We then use Acyclic Flow to search for the smallest DAG comprising these nodes. Because the solver runs with a time limit, the returned DAG may be slightly larger than the true optimum.
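The selected subgraph has to be acyclic so that it can be evaluated as a feed-forward network, with every node computed after all of its parents. The sketch below is not the optimization performed by Acyclic Flow; it only shows, under that reading, how acyclicity of an edge list can be checked with the standard library. The function name and the toy edges are hypothetical.

```python
from graphlib import CycleError, TopologicalSorter


def is_dag(edges):
    """Return True if the directed edge list (source, target) has no cycle."""
    predecessors = {}
    for source, target in edges:
        predecessors.setdefault(target, set()).add(source)
        predecessors.setdefault(source, set())
    try:
        # static_order raises CycleError when no topological order exists.
        list(TopologicalSorter(predecessors).static_order())
    except CycleError:
        return False
    return True


acyclic = [("gene1", "TF_A"), ("TF_A", "TCR"), ("gene1", "TCR")]
cyclic = acyclic + [("TCR", "gene1")]
print(is_dag(acyclic), is_dag(cyclic))
```

A check like this can only validate a given subgraph; choosing which edges to keep so that the result is both acyclic and as small as possible is the combinatorial problem handed to the solver.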

G_dag = G.copy()
new_edges = []
for g in input_nn:
    new_edges.append(G_dag.add_edge((), g))
new_edges.append(G_dag.add_edge("TCR", ()))
print(G_dag.shape)

# Find a small DAG. We use Acyclic Flow to search over the space of DAGs
P = cn.opt.AcyclicFlow(G_dag)
# We enforce that the input genes and the output gene are part of the solution
P += P.expr.with_flow[new_edges] == 1
# Minimize the number of active edges
P.add_objectives(sum(P.expr.with_flow), weights=1)
P.solve(solver="HIGHS", verbosity=1, max_seconds=300);
(12767, 38388)
===============================================================================
                                     CVXPY                                     
                                     v1.7.5                                    
===============================================================================
-------------------------------------------------------------------------------
                                  Compilation                                  
-------------------------------------------------------------------------------
-------------------------------------------------------------------------------
                                Numerical solver                               
-------------------------------------------------------------------------------
Running HiGHS 1.13.0 (git hash: 1bce6d5): Copyright (c) 2026 under MIT licence terms
MIP has 332945 rows; 127931 cols; 628290 nonzeros; 76776 integer variables (76776 binary)
Coefficient ranges:
  Matrix  [1e-04, 1e+04]
  Cost    [1e+00, 1e+00]
  Bound   [1e+00, 1e+00]
  RHS     [1e+00, 1e+04]
Presolving model
61423 rows, 50344 cols, 162081 nonzeros  0s
41020 rows, 42186 cols, 129115 nonzeros  0s
39333 rows, 38152 cols, 131120 nonzeros  0s
Presolve reductions: rows 39333(-293612); columns 38152(-89779); nonzeros 131120(-497170) 
Objective function is integral with scale 1

Solving MIP model with:
   39333 rows
   38152 cols (18820 binary, 0 integer, 0 implied int., 19332 continuous, 0 domain fixed)
   131120 nonzeros

Src: B => Branching; C => Central rounding; F => Feasibility pump; H => Heuristic;
     I => Shifting; J => Feasibility jump; L => Sub-MIP; P => Empty MIP; R => Randomized rounding;
     S => Solve LP; T => Evaluate node; U => Unbounded; X => User solution; Y => HiGHS solution;
     Z => ZI Round; l => Trivial lower; p => Trivial point; u => Trivial upper; z => Trivial zero

        Nodes      |    B&B Tree     |            Objective Bounds              |  Dynamic Constraints |       Work      
Src  Proc. InQueue |  Leaves   Expl. | BestBound       BestSol              Gap |   Cuts   InLp Confl. | LpIters     Time

         0       0         0   0.00%   19719           inf                  inf        0      0      0         0     0.9s
         0       0         0   0.00%   19719           inf                  inf        0      0      5     16346     1.8s
         0       0         0   0.00%   24886.380366    inf                  inf     6492   5461     58     75101    22.7s
 L       0       0         0   0.00%   24924.373044    25087              0.65%     8224   5729    113     77722    36.5s

0.2% inactive integer columns, restarting
Model after restart has 39257 rows, 38106 cols (18777 bin., 0 int., 0 impl., 19329 cont., 0 dom.fix.), and 130968 nonzeros

         0       0         0   0.00%   24924.373044    25087              0.65%     5499      0      0    142175    36.9s
         0       0         0   0.00%   24924.373059    25087              0.65%     5499   5433      4    254650    69.2s

Symmetry detection completed in 4.6s
Found 2475 generator(s)
        48       0         1   0.00%   24927.248967    25087              0.64%     8027   5631     99    322948   100.7s
       192     188         1   0.00%   24996.054495    25087              0.36%     8054   5631    170    337853   113.6s
       232     187         2   0.00%   24996.054495    25087              0.36%     8482   5712    180    410758   151.7s
       370     309         6   0.00%   24996.054495    25087              0.36%     8831   5745    211    429962   162.5s
 T     401     239         7   0.00%   24996.054495    25079              0.33%     8831   5745    216    430098   166.5s
       503     359         9   0.00%   24996.054495    25079              0.33%     9205   5774    225    451338   176.4s
 T     503     311         9   0.00%   24996.054495    25073              0.31%     9205   5774    229    451338   178.6s
 L     577     331        10   0.00%   24996.054495    25067              0.28%     9409   5474    231    453184   185.6s
       630     331        11   0.00%   24996.054495    25067              0.28%     9411   5474    239    479529   192.5s
       766     430        13   0.00%   24997.054475    25067              0.28%     9497   5488    261    501683   205.1s
       853     595        13   0.00%   24997.054485    25067              0.28%     9676   5515    283    529168   220.8s
       905     594        14   0.00%   24997.054485    25067              0.28%    10056   5493    293    561558   236.5s
 L     935     632        14   0.00%   24997.054485    25066              0.28%    10295   5510    299    563727   244.6s
       984     632        15   0.00%   24997.054485    25066              0.28%    10295   5510    308    589295   249.6s

        Nodes      |    B&B Tree     |            Objective Bounds              |  Dynamic Constraints |       Work      
Src  Proc. InQueue |  Leaves   Expl. | BestBound       BestSol              Gap |   Cuts   InLp Confl. | LpIters     Time

      1068     692        16   0.00%   24997.054485    25066              0.28%    10386   5524    323    649063   279.6s
      1139     776        17   0.00%   24997.054485    25066              0.28%    10392   5538    336    685024   299.3s
      1151     838        18   0.00%   25006.0518      25066              0.24%    10392   5538    341    687194   300.0s

Solving report
  Status            Time limit reached
  Primal bound      25066
  Dual bound        25007
  Gap               0.235% (tolerance: 0.01%)
  P-D integral      1.06568195332
  Solution status   feasible
                    25066 (objective)
                    0 (bound viol.)
                    0 (int. viol.)
                    0 (row viol.)
  Timing            300.02
  Max sub-MIP depth 10
  Nodes             1151
  Repair LPs        0
  LP iterations     687194
                    146288 (strong br.)
                    95062 (separation)
                    193529 (heuristics)
-------------------------------------------------------------------------------
                                    Summary                                    
-------------------------------------------------------------------------------
G_subdag = G_dag.edge_subgraph(P.expr.with_flow.value > 0.5)
G_dag.shape, G_subdag.shape
((12767, 38388), (12607, 25066))
rel_dag_compression = (1 - (G_subdag.num_edges / G_dag.num_edges)) * 100
print(f"KPNN edge compression (0-100%): {rel_dag_compression:.2f}%")
KPNN edge compression (0-100%): 34.70%
pruned_models, pruned_metrics = stratified_kfold(
    G_subdag,
    input_nn,
    outputs_pkn,
    file_weights=os.path.join(temp_weights, "pruned_weights"),
)

print("Validation metrics:")
for k, v in pruned_metrics.items():
    print(f" - {k}: {np.mean(v):.3f}")
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 12754
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpflwx1756/pruned_weights_0.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 1s 53ms/step
 > Fold 0 validation ROC-AUC=0.988
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 12754
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpflwx1756/pruned_weights_1.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 1s 53ms/step
 > Fold 1 validation ROC-AUC=0.980
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 12754
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpflwx1756/pruned_weights_2.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 1s 53ms/step
 > Fold 2 validation ROC-AUC=0.987
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 12754
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpflwx1756/pruned_weights_3.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 1s 53ms/step
 > Fold 3 validation ROC-AUC=0.992
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 12754
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpflwx1756/pruned_weights_4.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 1s 54ms/step
 > Fold 4 validation ROC-AUC=0.994
Validation metrics:
 - accuracy: 0.963
 - precision: 0.969
 - recall: 0.954
 - f1: 0.962
 - roc_auc: 0.988
df_biases_pruned = load_biases(file=os.path.join(temp_weights, "pruned_weights"), folds=5)
df_biases_pruned.sort_values(by="abs_bias", ascending=False).head(10)
                       bias  abs_bias  pow2_bias
fold gene
2    ZAP70        -2.833718  2.833718   8.029957
3    SUZ12.EZH2   -2.608386  2.608386   6.803678
4    PRKCD        -2.220839  2.220839   4.932124
3    TCR           2.169856  2.169856   4.708274
4    TCR          -2.168031  2.168031   4.700360
     ZAP70         2.140123  2.140123   4.580125
3    ZAP70         2.107303  2.107303   4.440725
2    NFYA          2.093171  2.093171   4.381363
1    NFYA          2.036027  2.036027   4.145407
2    MAX          -2.015005  2.015005   4.060244
df_biases_prunedr = df_biases_pruned.copy().reset_index()
gene_biases_score = df_biases_prunedr.groupby("gene")["pow2_bias"].mean().sort_values(ascending=False)
gene_biases_score
gene
ZAP70                 4.356112
TCR                   3.455697
NFYA                  3.271365
SUZ12.EZH2            2.938754
PRKCD                 2.340390
                        ...   
ZNF217                0.003361
MLL.SET.subcomplex    0.002986
REST                  0.000634
ASH2L                 0.000447
GATA3                 0.000002
Name: pow2_bias, Length: 148, dtype: float32
gene_biases_score.head(30).plot.bar()
[Figure: bar plot of the mean squared bias (pow2_bias) of the 30 top-ranked nodes in the pruned KPNN; x-axis: gene]
# Get the top 30 genes ranked by mean squared bias across folds
top_genes = gene_biases_score.head(30).index

# Filter the DataFrame to include only rows with these top genes
filtered_df = df_biases_prunedr[df_biases_prunedr["gene"].isin(top_genes)]

plt.figure(figsize=(10, 6))
sns.violinplot(data=filtered_df, x="gene", y="bias", order=top_genes, dodge=True)
sns.stripplot(data=filtered_df, x="gene", y="bias", order=top_genes, dodge=True)
plt.xticks(rotation=90)
plt.title("Biases of the trained neurons")
plt.xlabel("Gene")
plt.ylabel("Bias")
plt.axhline(0, linestyle="--", color="k")
plt.tight_layout()
[Figure: violin and strip plots of the biases of the 30 top-ranked nodes of the pruned KPNN across folds, titled "Biases of the trained neurons"; x-axis: Gene, y-axis: Bias]
param_compression = (1 - (pruned_models[0].count_params() / models[0].count_params())) * 100
print(f"Parameter compression: {param_compression:.2f}%")
Parameter compression: 51.39%
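Both compression figures can be reproduced from the counts printed earlier in this tutorial. The short check below also computes the edge compression without the auxiliary input and output edges, under the assumption that all of them are kept in the pruned graph because their flow is forced to 1. The helper name is made up for illustration.

```python
# Counts copied from the outputs printed in this tutorial.
edges_full, edges_pruned = 38388, 25066    # G_dag and G_subdag (with auxiliary edges)
params_full, params_pruned = 26236, 12754  # original and pruned KPNN
n_auxiliary = 12459 + 1                    # one edge per input gene plus the TCR output edge


def compression(before, after):
    """Relative reduction, in percent."""
    return (1 - after / before) * 100


edge_pct = compression(edges_full, edges_pruned)
# Assumption: every auxiliary edge is kept in the pruned graph.
pkn_edge_pct = compression(edges_full - n_auxiliary, edges_pruned - n_auxiliary)
param_pct = compression(params_full, params_pruned)
print(f"{edge_pct:.2f}% {pkn_edge_pct:.2f}% {param_pct:.2f}%")
```

Under that assumption, the reduction in PKN edges alone is close to the reduction in parameters, which is consistent with a network whose trainable weights are tied to the edges of the graph.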
perf_degradation = (
    (np.mean(metrics["roc_auc"]) - np.mean(pruned_metrics["roc_auc"])) / np.mean(metrics["roc_auc"])
) * 100
print(
    f"Degradation in ROC-AUC after compression (positive = decrease in performance, negative = increase in performance): {perf_degradation:.2f}%"
)
Degradation in ROC-AUC after compression (positive = decrease in performance, negative = increase in performance): 0.11%