tutorial.md
May 7, 2024
CeLEry Tutorial
Author: Qihuang Zhang*, Shunzhou Jiang, Amelia Schroeder, Jian Hu, Kejie Li, Baohong Zhang, David Dai, Edward B. Lee, Rui Xiao, Mingyao Li*
Outline
- Installation
- Import modules
- Data Loading
- Run CeLEry
1. Installation
To install the CeLEry package, make sure that your Python version is 3.5 or later (3.8 is recommended). If you do not know which version of Python you have, you can check it with:
import platform
platform.python_version()
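The same check can be made programmatic. The sketch below compares the running interpreter against the minimum version stated above; the helper name python_is_supported is hypothetical and not part of CeLEry.

```python
import sys

def python_is_supported(version_info=sys.version_info, minimum=(3, 5)):
    """Return True if the (major, minor) version is at least `minimum`."""
    return tuple(version_info[:2]) >= tuple(minimum)

# Report whether the interpreter running this script is recent enough.
print(python_is_supported())
```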
Note: Because CeLEry depends on PyTorch, make sure that torch is correctly installed first. You can then install the current release of CeLEry in one of the following three ways:
1.1 PyPI: Directly install the package from PyPI.
pip3 install CeLEryPy
# Note: make sure that pip belongs to Python 3; if in doubt, install CeLEry with
python3 -m pip install CeLEryPy
# If you get a permission-denied error, install CeLEry for the current user only
pip3 install --user CeLEryPy
1.2 Github
Download the package from Github and install it locally:
git clone https://github.com/QihuangZhang/CeLEry
cd CeLEry/CeLEry_package/
python3 setup.py install --user
1.3 Anaconda
If you do not have Python 3.5 or Python 3.6 installed, consider installing Anaconda (see Installing Anaconda). After installing Anaconda, you can create a new environment, for example CeLEry (you can use any name you like).
#create an environment called CeLEry
conda create -n CeLEry python=3.8
#activate your environment
conda activate CeLEry
git clone https://github.com/QihuangZhang/CeLEry
cd CeLEry/CeLEry_package/
python3 setup.py build
python3 setup.py install
conda deactivate
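Whichever route you choose, it can help to confirm that the required packages are visible to the interpreter before moving on. The sketch below looks up the module names used in the import section (torch and CeLEry) without importing them; the helper is_installed is hypothetical.

```python
import importlib.util

def is_installed(module_name):
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None

# Module names taken from the import statements used in this tutorial.
for name in ("torch", "CeLEry"):
    print(name, "found" if is_installed(name) else "missing")
```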
2. Import python modules
import scanpy as sc
import torch
import CeLEry as cel
import os,csv,re
import pandas as pd
import numpy as np
import math
from skimage import io, color
from scipy.sparse import issparse
import random
import warnings
warnings.filterwarnings("ignore")
import pickle
import json
cel.__version__
3. Load-in data
The current version of CeLEry takes two input datasets: the reference data and the query data. The reference data are used to train the model, and the query data are the dataset on which predictions (or classifications) are made.
- The Reference Data (the spatial transcriptomics data): AnnData format including:
  - the gene expression matrix (spots by genes);
  - the spot-specific information (e.g., coordinates, layer, etc.)
- The Query Data (the scRNA-seq data): AnnData format including:
  - the gene expression matrix (cells by genes);
  - the demographic information for each cell (e.g., cell type, layer, etc.)
AnnData stores a data matrix .X together with annotations of observations .obs, variables .var and unstructured annotations .uns.
"""
#Read original 10x_h5 data and save it to h5ad
from scanpy import read_10x_h5
adata = read_10x_h5("../tutorial/data/151673/expression_matrix.h5")
spatial = pd.read_csv("../tutorial/data/151673/positions.txt",sep=",",header=None,na_filter=False,index_col=0)
adata.obs["x1"] = spatial[1]
adata.obs["x2"] = spatial[2]
adata.obs["x3"] = spatial[3]
adata.obs["x4"] = spatial[4]
adata.obs["x5"] = spatial[5]
adata.obs["x_array"] = adata.obs["x2"]
adata.obs["y_array"] = adata.obs["x3"]
adata.obs["x_pixel"] = adata.obs["x4"]
adata.obs["y_pixel"] = adata.obs["x5"]
#Select captured samples
adata = adata[adata.obs["x1"]==1]
adata.var_names = [i.upper() for i in list(adata.var_names)]
adata.var["genename"] = adata.var.index.astype("str")
adata.write_h5ad("../tutorial/data/151673/sample_data.h5ad")
"""
#Read in gene expression and spatial location
Qdata = sc.read("data/tutorial/MouseSCToy.h5ad")
Rdata = sc.read("data/tutorial/MousePosteriorToy.h5ad")
Rdata
Output:
AnnData object with n_obs × n_vars = 5824 × 356
obs: 'x', 'y', 'inner'
var: 'genename'
Here, Qdata stores the annotated query data (scRNA-seq/snRNA-seq data) and Rdata stores the annotated reference data collected from spatial transcriptomics.
Before implementing our methods, we often normalize both the reference data and the query data:
cel.get_zscore(Qdata)
cel.get_zscore(Rdata)
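Z-score normalization typically rescales each gene to mean 0 and standard deviation 1 across cells or spots. The sketch below shows that arithmetic in plain Python on a toy matrix. It is an illustration only: whether cel.get_zscore uses exactly this formula (for example, the population rather than the sample standard deviation) is an assumption, and the helper zscore_columns is hypothetical.

```python
import math
from statistics import mean, pstdev

def zscore_columns(matrix):
    """Standardize each column (gene) of a row-major matrix (cells x genes)."""
    n_genes = len(matrix[0])
    columns = [[row[j] for row in matrix] for j in range(n_genes)]
    stats = [(mean(col), pstdev(col)) for col in columns]
    result = []
    for row in matrix:
        new_row = []
        for j, value in enumerate(row):
            mu, sigma = stats[j]
            # A gene with no variation has sigma == 0; its z-score is undefined.
            new_row.append((value - mu) / sigma if sigma > 0 else math.nan)
        result.append(new_row)
    return result

# Toy data: 2 cells x 2 genes; the second gene is constant.
toy = [[1.0, 5.0], [3.0, 5.0]]
print(zscore_columns(toy))
```

The constant gene in the toy data comes out as NaN, which is the situation that the NaN check later in this tutorial is meant to catch.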
4. Run CeLEry
We demonstrate the implementation of CeLEry in three tasks. In the first task, CeLEry predicts the 2D coordinates of the cells. In the second task, it classifies the cells into different layers. In the third task, it classifies the cells into domains that are not assumed to be ordered.
4.1 Analysis Task 1: Coordinates Recovery
In the first task, we train a deep neural network using the reference data, and then apply the trained model to predict the location of the cells (or spots) in the query data.
Training
First, we train the model using the spatial transcriptomics data. The trained model is also saved automatically as an .obj file in the specified path. This step can take an hour, depending on the structure of the neural network.
model_train = cel.Fit_cord (data_train = Rdata, hidden_dims = [30, 25, 15], num_epochs_max = 500, path = "output/tutorial", filename = "Org_Mousesc")
The fitting function Fit_cord takes the following parameters:
- data_train (an annotated matrix): the input data.
- hidden_dims (a list of length three): the width of each layer of the neural network. The network has three layers in total.
- num_epochs_max: the maximum number of epochs used in the training procedure.
- path: the directory in which the model object is saved.
- filename: the name under which the model object is saved in path.
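To get a feel for what hidden_dims = [30, 25, 15] implies, the sketch below lists the weight-matrix shapes of a plain fully connected network and counts its parameters. It assumes dense layers with biases, 356 input genes (the number of variables in the Rdata object shown earlier) and a 2D output. The actual architecture inside Fit_cord may differ, and both helper names are hypothetical.

```python
def dense_layer_shapes(n_inputs, hidden_dims, n_outputs):
    """Return (inputs, outputs) for each dense layer, in order."""
    sizes = [n_inputs, *hidden_dims, n_outputs]
    return list(zip(sizes[:-1], sizes[1:]))

def count_parameters(shapes):
    """Weights plus one bias per output unit, summed over all layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in shapes)

# Assumed setup: 356 genes in, three hidden layers, (x, y) coordinates out.
shapes = dense_layer_shapes(356, [30, 25, 15], 2)
print(shapes)
print(count_parameters(shapes))
```

Under these assumptions the first layer dominates the parameter count, so the number of input genes matters more for model size than the exact hidden widths.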
Training Tips: Tuning matters a lot!!
Optimizing the parameters of our training algorithm is crucial for achieving the best performance. Let’s explore the key settings that you’ll need to tweak to improve the predictivity:
Choosing the Right Batch Size
Batch size is a critical determinant of your training process's success. While beginning with a batch size of 32 is conventional, we encourage experimenting with this figure. Should you encounter 'NA' errors, indicative of numerical instability, reducing the batch size to 8 may prove beneficial. Ultimately, the objective is to identify the optimal batch size that best suits the unique requirements of your specific scenario.
Try adding the option 'batch_size = 8' to your Fit_layer() function call.
Setting the Initial Learning Rate
The initial learning rate sets the pace of progress during training. Starting with a learning rate of 0.1 lets the algorithm make large adjustments at the outset. Importantly, the algorithm reduces the learning rate if it detects excessive increases in the loss, which helps avoid divergence and gives smoother convergence. This adaptive adjustment is controlled by the number_error_try option (see below), which tracks how often the loss deteriorates and lowers the rate accordingly.
Try adding the option 'initial_learning_rate = 0.01' to your Fit_layer() function call.
Adjusting for Error Tries
Speaking of error tries, this setting is your safeguard against unstable changes in the loss. Align it with the number of epochs you plan to run (given by the option num_epochs_max). It helps adjust the learning rate dynamically.
Try adding the option 'number_error_try = 60' to your Fit_layer() function call.
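The interaction between the learning rate and number_error_try can be pictured with a small simulation. The sketch below counts epochs in which the loss fails to improve and multiplies the learning rate by a decay factor once that count reaches number_error_try. This is an illustrative assumption about the mechanism, not CeLEry's actual implementation; the decay factor, the function name and the toy losses are all hypothetical.

```python
def adaptive_learning_rates(losses, initial_learning_rate=0.1,
                            number_error_try=3, decay=0.5):
    """Return the learning rate in effect after each epoch's loss is seen."""
    learning_rate = initial_learning_rate
    best_loss = float("inf")
    error_count = 0
    rates = []
    for loss in losses:
        if loss < best_loss:
            best_loss = loss          # improvement: remember the new best
        else:
            error_count += 1          # no improvement: count one "error"
            if error_count >= number_error_try:
                learning_rate *= decay
                error_count = 0       # start counting again at the new rate
        rates.append(learning_rate)
    return rates

# Toy loss curve with two non-improving epochs in a row.
print(adaptive_learning_rates([1.0, 0.8, 0.9, 0.85, 0.7, 0.75], number_error_try=2))
```

In this picture, a small number_error_try shrinks the learning rate quickly after a few bad epochs, while a large value tolerates more fluctuation before reacting.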
Prediction
Then, we apply the trained model to the query data to predict the coordinates of the cells.
The prediction function Predict_cord takes three arguments:
- data_test (an annotated matrix): the input query data.
- path: the directory in which the model object was saved.
- filename: the name of the saved model object.
The function returns the predicted 2D coordinates, stored here in pred_cord. A .csv file named predmatrix is also saved, and the predictions are added to Qdata.obs.
Example code:
pred_cord = cel.Predict_cord (data_test = Qdata, path = "output/tutorial", filename = "Org_Mousesc")
pred_cord
Output:
array([[0.67726576, 0.49435037],
[0.42489582, 0.51810944],
[0.07367212, 0.4977431 ],
...,
[0.72734278, 0.43093637],
[0.63597023, 0.10852443],
[0.3674576 , 0.50103331]])
Each row of the output matrix contains the predicted 2D coordinates of one cell. The results also appear in the updated Qdata.obs.
Qdata.obs
| sample_name | exp_component_name | platform_label | cluster_color | cluster_order | cluster_label | class_color | class_order | class_label | subclass_color | subclass_order | ... | injection_roi_label | injection_type_color | injection_type_id | injection_type_label | cortical_layer_label | outlier_call | outlier_type | n_counts | x_cord_pred | y_cord_pred |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| CAGCGACAGAGACTTA-L8TX_180115_01_F11 | CAGCGACAGAGACTTA-17L8TX_180115_01_F11 | 10x | #07C6D9 | 196.0 | 196_L4/5 IT CTX | #00ADEE | 2.0 | Glutamatergic | #00E5E5 | 17.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 602.0 | 0.700048 | 0.534017 |
| AGCTCTCCATCGACGC-L8TX_180115_01_F09 | AGCTCTCCATCGACGC-23L8TX_180115_01_F09 | 10x | #00FFFF | 188.0 | 188_L4/5 IT CTX | #00ADEE | 2.0 | Glutamatergic | #00E5E5 | 17.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 1597.0 | 0.427784 | 0.516523 |
| AGAGCGACACCTCGTT-L8TX_180406_01_C08 | AGAGCGACACCTCGTT-7L8TX_180406_01_C08 | 10x | #297F98 | 331.0 | 331_L6 CT ENTm | #00ADEE | 2.0 | Glutamatergic | #174596 | 34.0 | ... | NaN | #FF7373 | 1 | NaN | 0 | False | nan | 869.0 | 0.071109 | 0.467619 |
| TGGCGCAAGTACACCT-L8TX_180115_01_C11 | TGGCGCAAGTACACCT-11L8TX_180115_01_C11 | 10x | #28758B | 329.0 | 329_L6 CT CTX | #00ADEE | 2.0 | Glutamatergic | #2D8CB8 | 33.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 1876.0 | 0.900394 | 0.176353 |
| ACCTTTAGTTATCACG-L8TX_180115_01_D11 | ACCTTTAGTTATCACG-12L8TX_180115_01_D11 | 10x | #00FFFF | 188.0 | 188_L4/5 IT CTX | #00ADEE | 2.0 | Glutamatergic | #00E5E5 | 17.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 2068.0 | 0.345863 | 0.418795 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| TTTGTCATCGTACCGG-L8TX_180221_01_B11 | TTTGTCATCGTACCGG-1L8TX_180221_01_B11 | 10x | #B36C76 | 13.0 | 13_Lamp5 | #F05A28 | 1.0 | GABAergic | #DA808C | 3.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 966.0 | 0.222131 | 0.285191 |
| ACCTTTAGTGATGATA-L8TX_180221_01_D11 | ACCTTTAGTGATGATA-3L8TX_180221_01_D11 | 10x | #30E6BA | 131.0 | 131_L2 IT RSPv | #00ADEE | 2.0 | Glutamatergic | #2DB38A | 10.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 833.0 | 0.248778 | 0.135483 |
| AGCTCTCCATAGAAAC-L8TX_180115_01_D11 | AGCTCTCCATAGAAAC-12L8TX_180115_01_D11 | 10x | #02F970 | 183.0 | 183_L2/3 IT CTX | #00ADEE | 2.0 | Glutamatergic | #0BE652 | 15.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 1423.0 | 0.726116 | 0.442133 |
| AGAGCGACACAACGTT-L8TX_180406_01_H01 | AGAGCGACACAACGTT-1L8TX_180406_01_H01 | 10x | #299337 | 141.0 | 141_L3 IT ENTm | #00ADEE | 2.0 | Glutamatergic | #65CA2F | 12.0 | ... | NaN | #FF7373 | 1 | NaN | 0 | False | nan | 911.0 | 0.632229 | 0.108702 |
| AGCTCTCCATATACGC-L8TX_180406_01_B06 | AGCTCTCCATATACGC-5L8TX_180406_01_B06 | 10x | #297F98 | 331.0 | 331_L6 CT ENTm | #00ADEE | 2.0 | Glutamatergic | #174596 | 34.0 | ... | NaN | #FF7373 | 1 | NaN | 0 | False | nan | 708.0 | 0.330912 | 0.482509 |
3000 rows × 59 columns
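If true locations are available for some cells (for example, when the model is tested on held-out spots), the predicted coordinates can be compared with them. The following sketch computes the mean Euclidean distance between predicted and true 2D coordinates. The function name and the toy numbers are illustrative and not part of CeLEry.

```python
import math

def mean_euclidean_error(predicted, truth):
    """Average straight-line distance between matching (x, y) pairs."""
    if len(predicted) != len(truth):
        raise ValueError("predicted and truth must have the same length")
    distances = [
        math.hypot(px - tx, py - ty)
        for (px, py), (tx, ty) in zip(predicted, truth)
    ]
    return sum(distances) / len(distances)

# Toy example: one cell is off by a 3-4-5 triangle, the other is exact.
predicted = [(0.0, 0.0), (1.0, 1.0)]
truth = [(3.0, 4.0), (1.0, 1.0)]
print(mean_euclidean_error(predicted, truth))
```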
Confidence Score
To quantify the uncertainty of the predictions, we can produce a confidence score for each predicted cell. We first train the deep neural network using cel.Fit_region.
The usage of Fit_region is similar to that of Fit_cord. An extra parameter, alpha, specifies the confidence level.
model_train = cel.Fit_region (data_train = Rdata, alpha = 0.95, hidden_dims = [30, 25, 15], num_epochs_max = 500, path = "output/example", filename = "ConfRegion_Mousesc")
model_train
Then, we use Predict_region to evaluate the confidence score of each prediction. This function adds two new columns to the query data object: area_record and conf_score.
The area_record column stores the area of the predicted circle, which covers the true location with probability alpha. The confidence score quantifies how certain the prediction is and is defined as 1 - area_record. A higher confidence score indicates lower uncertainty in the prediction.
pred_region = cel.Predict_region (data_test = Qdata, path = "output/example", filename = "ConfRegion_Mousesc")
Qdata.obs
Output:
| sample_name | exp_component_name | platform_label | cluster_color | cluster_order | cluster_label | class_color | class_order | class_label | subclass_color | subclass_order | ... | injection_roi_label | injection_type_color | injection_type_id | injection_type_label | cortical_layer_label | outlier_call | outlier_type | n_counts | area_record | conf_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| CAGCGACAGAGACTTA-L8TX_180115_01_F11 | CAGCGACAGAGACTTA-17L8TX_180115_01_F11 | 10x | #07C6D9 | 196.0 | 196_L4/5 IT CTX | #00ADEE | 2.0 | Glutamatergic | #00E5E5 | 17.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 602.0 | 0.521508 | 0.478492 |
| AGCTCTCCATCGACGC-L8TX_180115_01_F09 | AGCTCTCCATCGACGC-23L8TX_180115_01_F09 | 10x | #00FFFF | 188.0 | 188_L4/5 IT CTX | #00ADEE | 2.0 | Glutamatergic | #00E5E5 | 17.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 1597.0 | 0.464650 | 0.535350 |
| AGAGCGACACCTCGTT-L8TX_180406_01_C08 | AGAGCGACACCTCGTT-7L8TX_180406_01_C08 | 10x | #297F98 | 331.0 | 331_L6 CT ENTm | #00ADEE | 2.0 | Glutamatergic | #174596 | 34.0 | ... | NaN | #FF7373 | 1 | NaN | 0 | False | nan | 869.0 | 0.478806 | 0.521194 |
| TGGCGCAAGTACACCT-L8TX_180115_01_C11 | TGGCGCAAGTACACCT-11L8TX_180115_01_C11 | 10x | #28758B | 329.0 | 329_L6 CT CTX | #00ADEE | 2.0 | Glutamatergic | #2D8CB8 | 33.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 1876.0 | 0.533365 | 0.466635 |
| ACCTTTAGTTATCACG-L8TX_180115_01_D11 | ACCTTTAGTTATCACG-12L8TX_180115_01_D11 | 10x | #00FFFF | 188.0 | 188_L4/5 IT CTX | #00ADEE | 2.0 | Glutamatergic | #00E5E5 | 17.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 2068.0 | 0.419726 | 0.580274 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| TTTGTCATCGTACCGG-L8TX_180221_01_B11 | TTTGTCATCGTACCGG-1L8TX_180221_01_B11 | 10x | #B36C76 | 13.0 | 13_Lamp5 | #F05A28 | 1.0 | GABAergic | #DA808C | 3.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 966.0 | 0.599244 | 0.400756 |
| ACCTTTAGTGATGATA-L8TX_180221_01_D11 | ACCTTTAGTGATGATA-3L8TX_180221_01_D11 | 10x | #30E6BA | 131.0 | 131_L2 IT RSPv | #00ADEE | 2.0 | Glutamatergic | #2DB38A | 10.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 833.0 | 0.490772 | 0.509228 |
| AGCTCTCCATAGAAAC-L8TX_180115_01_D11 | AGCTCTCCATAGAAAC-12L8TX_180115_01_D11 | 10x | #02F970 | 183.0 | 183_L2/3 IT CTX | #00ADEE | 2.0 | Glutamatergic | #0BE652 | 15.0 | ... | NaN | #FF7373 | 1 | NaN | All | False | nan | 1423.0 | 0.445069 | 0.554931 |
| AGAGCGACACAACGTT-L8TX_180406_01_H01 | AGAGCGACACAACGTT-1L8TX_180406_01_H01 | 10x | #299337 | 141.0 | 141_L3 IT ENTm | #00ADEE | 2.0 | Glutamatergic | #65CA2F | 12.0 | ... | NaN | #FF7373 | 1 | NaN | 0 | False | nan | 911.0 | 0.674069 | 0.325931 |
| AGCTCTCCATATACGC-L8TX_180406_01_B06 | AGCTCTCCATATACGC-5L8TX_180406_01_B06 | 10x | #297F98 | 331.0 | 331_L6 CT ENTm | #00ADEE | 2.0 | Glutamatergic | #174596 | 34.0 | ... | NaN | #FF7373 | 1 | NaN | 0 | False | nan | 708.0 | 0.442143 | 0.557857 |
3000 rows × 59 columns
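The relation between the two new columns can be verified by hand. The sketch below recomputes the confidence score of the first cell in the table from its area_record value and, treating the region as a circle as described above, also derives its radius. The helper names are hypothetical, and the radius is expressed in the same rescaled units as the predicted coordinates, which is an assumption.

```python
import math

def confidence_score(area):
    """Confidence score as defined in the text: 1 - area."""
    return 1.0 - area

def circle_radius(area):
    """Radius of a circle with the given area (area = pi * r**2)."""
    return math.sqrt(area / math.pi)

area = 0.521508  # area_record of the first cell in the table above
print(round(confidence_score(area), 6))  # 0.478492
print(circle_radius(area))
```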
4.2 Analysis Task 2: Layer Recovery
In the second task, we use CeLEry to classify the cells into different layers. First, we load the spatial transcriptomics data annotated with layers, together with single-cell RNA data collected from an Alzheimer's study.
Qdata = sc.read("data/tutorial/AlzheimerToy.h5ad")
Rdata = sc.read("data/tutorial/DataLayerToy.h5ad")
cel.get_zscore(Qdata)
cel.get_zscore(Rdata)
Qdata
Output:
AnnData object with n_obs × n_vars = 3000 × 26423
obs: 'cellname', 'sample', 'groupid', 'final_celltype', 'maxprob', 'imaxprob', 'trem2', 'atscore', 'apoe', 'sampleID', 'n_counts'
var: 'Ensembl', 'genename', 'n_cells'
Rdata
Output:
AnnData object with n_obs × n_vars = 3611 × 1134
obs: 'x2', 'x3', 'Layer', 'Layer_character', 'n_counts'
var: 'gene_ids', 'feature_types', 'genome', 'genename'
uns: 'wilcoxon'
A common problem is that, after normalization, the gene expression matrix can contain NaN values for genes that show no variability (usually because the data were not well preprocessed). This can cause problems when training the model or evaluating the query data. It is good practice to check whether the data contain NaN:
np.isnan(Rdata.X).any()
Output:
ArrayView(False)
The training data contain no NaN.
np.isnan(Qdata.X).any()
Output:
ArrayView(True)
The testing data has NaN.
Now we use the drop_NaN function in the CeLEry package to remove the genes that contain NaN (due to no variation in the data). This step is usually not needed if the data are properly preprocessed.
Qdata = cel.drop_NaN(Qdata)
Qdata
Output:
AnnData object with n_obs × n_vars = 3000 × 23796
obs: 'cellname', 'sample', 'groupid', 'final_celltype', 'maxprob', 'imaxprob', 'trem2', 'atscore', 'apoe', 'sampleID', 'n_counts'
var: 'Ensembl', 'genename', 'n_cells'
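Conceptually, this step removes every gene (column) that contains at least one NaN. The plain-Python sketch below shows the idea on a toy matrix. It is not the code of cel.drop_NaN, whose internals are not shown here, and the helper drop_nan_columns and the gene names are hypothetical.

```python
import math

def drop_nan_columns(matrix, gene_names):
    """Keep only the columns (genes) that contain no NaN in any row."""
    keep = [
        j for j in range(len(gene_names))
        if not any(math.isnan(row[j]) for row in matrix)
    ]
    filtered = [[row[j] for j in keep] for row in matrix]
    return filtered, [gene_names[j] for j in keep]

# Toy data: gene "G2" is NaN for every cell and is removed.
toy = [[0.5, math.nan, -1.0], [-0.5, math.nan, 1.0]]
print(drop_nan_columns(toy, ["G1", "G2", "G3"]))
```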
Often, the gene set in the query data (Qdata) differs from that in the reference data (Rdata). To ensure that the model trained on the reference data is applicable to the query data, the gene sets must be identical in both datasets.
common_gene = list(set(Qdata.var_names) & set(Rdata.var_names))
#
Query_select = Qdata[:,common_gene]
Reference_select = Rdata[:,common_gene]
Output of comparison after gene filtering:
>>> Qdata
AnnData object with n_obs × n_vars = 2452 × 23796
obs: 'cellname', 'sample', 'groupid', 'final_celltype', 'maxprob', 'imaxprob', 'trem2', 'atscore', 'apoe', 'sampleID', 'n_counts'
var: 'Ensembl', 'genename', 'n_cells'
>>> Reference_select
View of AnnData object with n_obs × n_vars = 3611 × 1120
obs: 'x2', 'x3', 'Layer', 'Layer_character', 'n_counts'
var: 'gene_ids', 'feature_types', 'genome', 'genename'
uns: 'wilcoxon'
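Note that list(set(...)) returns the shared genes in arbitrary order, and for strings that order can change between Python sessions. This is harmless when training and prediction use the same common_gene list, as here, but a deterministic order may be preferable if a saved model is reused later. A minimal sketch that keeps the order of the reference genes (the helper name is hypothetical):

```python
def common_genes_in_order(reference_genes, query_genes):
    """Genes present in both lists, in the order of the reference list."""
    query_set = set(query_genes)  # set lookup keeps the filter fast
    return [gene for gene in reference_genes if gene in query_set]

# Toy gene names; with AnnData objects one would pass Rdata.var_names
# and Qdata.var_names instead.
print(common_genes_in_order(["A", "B", "C"], ["C", "A", "D"]))
```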
The number of spots in each layer can differ substantially, which can lead to poor classification performance in some layers. We therefore weight the samples from each layer. A typical choice of weight is $1/\text{sample size}$. In the code below, the weights are scaled so that the last layer has weight 1.
layer_count = Reference_select.obs["Layer"].value_counts().sort_index()
layer_weight = layer_count[7]/layer_count[0:7]
layer_weight
Output:
[1.8791, 2.0277, 0.5187, 2.3532, 0.7623, 0.7413, 1.0000]
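The weights above are inverse sample sizes rescaled so that the reference layer has weight 1. The same arithmetic can be sketched without pandas as follows; the helper name and the toy labels are hypothetical.

```python
from collections import Counter

def layer_weights(labels, reference_layer):
    """Weight of each layer = count(reference layer) / count(layer)."""
    counts = Counter(labels)
    reference_count = counts[reference_layer]
    return {layer: reference_count / counts[layer] for layer in sorted(counts)}

# Toy labels: layer 1 has 2 spots, layer 2 has 4 spots, layer 3 has 1 spot.
toy_labels = [1, 1, 2, 2, 2, 2, 3]
print(layer_weights(toy_labels, reference_layer=3))
```

Layers with fewer spots than the reference layer receive weights above 1, and larger layers receive weights below 1.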
We train the model using the function Fit_layer. The model is returned and is also saved as an .obj file to be loaded later. This step can take an hour, depending on the structure of the neural network.
model_train = cel.Fit_layer (data_train = Reference_select, layer_weights = layer_weight, layerkey = "Layer",
hidden_dims = [30, 25, 15], num_epochs_max = 500, path = "output/tutorial", filename = "Org_layer")
Then, we apply the trained model to the scRNA-seq/snRNA-seq data:
pred_layer = cel.Predict_layer(data_test = Query_select, class_num = 7, path = "output/tutorial", filename = "Org_layer", predtype = "deterministic")
pred_layer
Output:
array([4., 4., 6., ..., 6., 5., 4.])
probability_each_layer = cel.Predict_layer(data_test = Query_select, class_num = 7, path = "output/tutorial", filename = "Org_layer", predtype = "probabilistic")
probability_each_layer
Output:
array([[ 2.27034092e-04, 1.87861919e-02, 3.39182794e-01, ...,
6.73830733e-02, 1.59902696e-03, 5.36697644e-06],
[ 3.27110291e-04, 2.68441439e-02, 4.18585479e-01, ...,
4.77917679e-02, 1.11017306e-03, 3.72435829e-06],
[-0.00000000e+00, 4.52995300e-06, 1.26361847e-04, ...,
1.24190569e-01, 8.50280046e-01, 2.23746095e-02],
...,
[ 1.19209290e-07, 9.77516174e-06, 2.75790691e-04, ...,
2.34772027e-01, 7.47992218e-01, 1.03732459e-02],
[ 2.26497650e-06, 1.93536282e-04, 5.41210175e-03, ...,
7.42786407e-01, 1.36679530e-01, 5.30853984e-04],
[ 7.78675079e-04, 6.15803003e-02, 5.94597340e-01, ...,
2.06700731e-02, 4.66533442e-04, 1.56409408e-06]])
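Each row of the probabilistic output holds one probability per layer. One plausible way to reduce such a row to a single label is to take the most probable layer, as sketched below. Whether Predict_layer derives its deterministic output exactly this way is an assumption, as are the 1-based layer labels, the helper name and the toy probabilities.

```python
def most_likely_layer(probabilities, first_label=1):
    """Label of the layer with the highest probability in one row."""
    best_index = max(range(len(probabilities)), key=probabilities.__getitem__)
    return best_index + first_label

# Toy probabilities for 7 layers (not taken from the output above).
row = [0.0, 0.01, 0.04, 0.1, 0.12, 0.7, 0.03]
print(most_likely_layer(row))
```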
4.3 Analysis Task 3: Domain Recovery
In the third task, we use CeLEry to classify the cells into different domains without a layer structure; that is, we do not assume that the domains have an ordinal relationship.
Qdata = sc.read("tutorial/data/AlzheimerToy.h5ad")
Rdata = sc.read("tutorial/data/DataLayerToy.h5ad")
cel.get_zscore(Qdata)
cel.get_zscore(Rdata)
common_gene = list(set(Qdata.var_names) & set(Rdata.var_names))
#
Query_select = Qdata[:,common_gene]
Reference_select = Rdata[:,common_gene]
Important: the class labels must span from 0 to N-1.
Reference_select.obs["domain_id"] = Rdata.obs["Layer"] -1
domain_count = Reference_select.obs["domain_id"].value_counts().sort_index()
domain_weight = domain_count[len(domain_count)-1]/domain_count[0:len(domain_count)]
domain_weights = torch.tensor(domain_weight.to_numpy(), dtype=torch.float32)
domain_weights
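Subtracting 1 works here because the layers are labelled 1 to 7. For domain labels that are not consecutive integers, a more general remapping to 0..N-1 could look like the sketch below; the helper name and the toy labels are hypothetical.

```python
def to_zero_based_ids(labels):
    """Map arbitrary sortable labels to consecutive ids 0..N-1."""
    mapping = {label: index for index, label in enumerate(sorted(set(labels)))}
    return [mapping[label] for label in labels], mapping

# Toy labels with gaps; the mapping is returned so ids can be translated back.
ids, mapping = to_zero_based_ids([3, 1, 7, 1])
print(ids, mapping)
```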
We train the model using the function Fit_domain. The fitted model is returned and is also saved as an .obj file to be loaded later.
model_train = cel.Fit_domain (data_train = Reference_select, domain_weights = domain_weights, domainkey = "domain_id",
hidden_dims = [30, 25, 15], num_epochs_max = 500, path = "output/example", filename = "PreOrg_domain")
model_train
To obtain predictions, we use the Predict_domain function. It returns a probabilistic classification matrix and attaches the deterministic domain prediction to .obs.
pred_domain = cel.Predict_domain (data_test = Query_select, class_num = 7, path = "output/example", filename = "PreOrg_domain")
Query_select.obs
Output:
| index | cellname | sample | groupid | final_celltype | maxprob | imaxprob | trem2 | atscore | apoe | sampleID | n_counts | pred_domain | pred_domain_str | domain_cel_pred |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| GGGAATGGTTATTCTC-1-C2 | GGGAATGGTTATTCTC-1 | C2 | C | Oli | 0.827 | 827 | WT | A-T- | E3/E3 | 1 | 1277.0 | 6 | 6 | 6.0 |
| ATCATGGCAGTGAGTG-1-C3 | ATCATGGCAGTGAGTG-1 | C3 | C | Ast | 0.822 | 822 | WT | A-T- | E3/E3 | 2 | 1961.0 | 6 | 6 | 6.0 |
| CGACCTTAGTGTCCCG-1-I4 | CGACCTTAGTGTCCCG-1 | I4 | I | Oli | 0.827 | 827 | WT | A+T- | E3/E3 | 3 | 706.0 | 6 | 6 | 6.0 |
| ATCACGAGTTGGAGGT-1-I1 | ATCACGAGTTGGAGGT-1 | I1 | I | Ex | 0.779 | 779 | WT | A+T- | E3/E4 | 6 | 2541.0 | 6 | 6 | 6.0 |
| CGACCTTCACAGATTC-1-I4 | CGACCTTCACAGATTC-1 | I4 | I | In | 0.775 | 775 | WT | A+T- | E3/E3 | 3 | 4185.0 | 6 | 6 | 6.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| GGAATAACATACGCTA-1-T3 | GGAATAACATACGCTA-1 | T3 | T | Ast | 0.811 | 811 | R47H | A+T+ | E3/E3 | 9 | 583.0 | 6 | 6 | 6.0 |
| TTTCCTCGTCGGGTCT-1-I2 | TTTCCTCGTCGGGTCT-1 | I2 | I | Ex | 0.545 | 545 | WT | A+T- | E3/E4 | 5 | 4217.0 | 6 | 6 | 6.0 |
| TCTGAGAAGAATAGGG-T1COMB | TCTGAGAAGAATAGGG | T1 | T | Ast | 0.821 | 821 | R47H | A+T+ | E4/E4 | 15 | 1455.0 | 6 | 6 | 6.0 |
| TTTCCTCTCGTCCAGG-1-I2 | TTTCCTCTCGTCCAGG-1 | I2 | I | End | 0.721 | 721 | WT | A+T- | E3/E4 | 5 | 1166.0 | 6 | 6 | 6.0 |
| GCAGCCATCTGGCGTG-1-T2 | GCAGCCATCTGGCGTG-1 | T2 | T | Ex | 0.355 | 355 | R47H | A+T+ | E3/E4 | 11 | 2434.0 | 6 | 6 | 6.0 |
3000 rows × 14 columns