Code by TomMakesThings
!pip install pytorch-lightning
!pip install kneed
!pip install kmapper
import gzip
import rpy2.robjects as robjects
import tarfile
import matplotlib.pyplot as plot
import torch
import torch.nn as nn
import torch.optim as optim
import random
import string
import math
import numpy as np
import pandas
import warnings
import urllib
import pytorch_lightning as pl
import plotly.graph_objects as go
import scipy.cluster.hierarchy
import seaborn as sns
import kmapper as km
from rpy2.robjects import pandas2ri
from io import BytesIO
from zipfile import ZipFile
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from termcolor import colored
from google.colab import files, output, drive
from sklearn.model_selection import KFold
from sklearn.cluster import KMeans, AgglomerativeClustering, Birch, MiniBatchKMeans, SpectralClustering
from sklearn.cluster import SpectralBiclustering
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA, FastICA, NMF
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score, accuracy_score
from sklearn.metrics.cluster import adjusted_rand_score
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.optimize import linear_sum_assignment
from plotly.graph_objs import *
from kneed import KneeLocator
from kmapper import jupyter
%load_ext tensorboard
Although the notebook was designed to run in Colab, it may also be run in a Jupyter notebook. The code below will set the plotly graphing library to render graphs for Jupyter if detected.
if get_ipython().__class__.__name__ == 'ZMQInteractiveShell':
print("Setting plotly for Jupyter notebook")
import plotly
import plotly.offline as pyo
pyo.init_notebook_mode()
Optionally connect to Google Drive. This is not required if you wish to use one of the pre-set datasets: benchmark_dataset, splat_dataset or cortex_dataset. These datasets are opened in the cells below by downloading them from URL. If you want to use a different dataset, this will have to be set manually.
connect_to_drive = False # Allow reading files from Google Drive
if connect_to_drive:
# Connect to Google Drive
drive.mount('/content/gdrive')
Open and view the first dataset, referred to in this project as benchmark_dataset. This dataset, known as sc_10x, was created by Luyi Tian and is composed of human lung adenocarcinoma cells from three cell lines. Further information about this dataset is available in the paper scPipe: A flexible R/Bioconductor preprocessing pipeline for single-cell RNA-sequencing data.
try:
# Try open the original dataset from Luyi Tian's GitHub as a DataFrame
benchmark_dataset = pandas.read_csv("https://github.com/LuyiTian/sc_mixology/raw/master/data/csv/sc_10x.count.csv.gz", index_col=0)
except:
try:
# If this fails, try open the dataset from my GitHub
benchmark_dataset = pandas.read_csv("https://github.com/TomMakesThings/Clustering-and-TDA-of-scRNA-seq-Data/raw/main/Data/Datasets/benchmark_counts.csv.gz", index_col=0)
except IOError as exc:
raise RuntimeError("File not found: failed to open benchmarking dataset") from exc
print(colored("Benchmarking Dataset:\n", attrs=['bold']))
benchmark_dataset
Benchmarking Dataset:
CELL_000001 | CELL_000002 | CELL_000003 | CELL_000004 | CELL_000005 | CELL_000006 | CELL_000007 | CELL_000008 | CELL_000009 | CELL_000010 | ... | CELL_000931 | CELL_000932 | CELL_000933 | CELL_000934 | CELL_000935 | CELL_000939 | CELL_000943 | CELL_000946 | CELL_000955 | CELL_000965 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
ENSG00000272758 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 1 | 1 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
ENSG00000154678 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 1 | 2 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
ENSG00000148737 | 0 | 0 | 0 | 1 | 3 | 2 | 0 | 2 | 1 | 0 | ... | 0 | 0 | 0 | 2 | 0 | 4 | 0 | 0 | 2 | 0 |
ENSG00000196968 | 0 | 0 | 0 | 1 | 2 | 2 | 5 | 0 | 4 | 1 | ... | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 |
ENSG00000134297 | 0 | 0 | 0 | 1 | 1 | 1 | 2 | 1 | 3 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
ENSG00000237289 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
ENSG00000238098 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 |
ENSG00000133433 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
ENSG00000054219 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
ENSG00000137691 | 0 | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | ... | 0 | 0 | 0 | 2 | 0 | 1 | 0 | 0 | 0 | 1 |
16468 rows × 902 columns
View the three cell lines of the samples. These are used as labels to evaluate accuracy during clustering.
# View metadata for benchmarking dataset
try:
benchmark_metadata = pandas.read_csv("https://github.com/LuyiTian/sc_mixology/raw/master/data/csv/sc_10x.metadata.csv.gz", index_col=0).transpose()
except:
try:
benchmark_metadata = pandas.read_csv("https://github.com/TomMakesThings/Clustering-and-TDA-of-scRNA-seq-Data/raw/main/Data/Datasets/benchmark_metadata.csv.gz", index_col=0).transpose()
except IOError as exc:
raise RuntimeError("File not found: failed to open benchmarking metadata") from exc
# Get cell lines
benchmark_metadata = pandas.DataFrame(benchmark_metadata.loc['cell_line'])
print(colored("Cell Lines in Benchmarking Dataset:\n", attrs=['bold']))
benchmark_metadata
Cell Lines in Benchmarking Dataset:
cell_line | |
---|---|
CELL_000001 | HCC827 |
CELL_000002 | H1975 |
CELL_000003 | HCC827 |
CELL_000004 | HCC827 |
CELL_000005 | HCC827 |
... | ... |
CELL_000939 | H1975 |
CELL_000943 | H1975 |
CELL_000946 | H2228 |
CELL_000955 | HCC827 |
CELL_000965 | HCC827 |
902 rows × 1 columns
View the adjusted Rand index (ARI) from experiments using the same dataset. This provides a comparison against experiments in this notebook.
These results are from the paper: scRNA-seq mixology: towards better benchmarking of single cell RNA-seq protocols and analysis methods.
Note that downloading this data is not essential to run the rest of the notebook.
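As a quick, self-contained illustration of the metric itself (separate from the downloaded benchmark results, and using made-up labels), scikit-learn's adjusted_rand_score compares two label assignments while correcting for chance agreement: 1 means identical partitions and values near 0 mean agreement no better than random.
# Toy example of the ARI metric (labels below are invented for demonstration only)
true_labels = ["HCC827", "HCC827", "H1975", "H1975", "H2228", "H2228"]
predicted_clusters = [0, 0, 1, 1, 1, 2]
print(adjusted_rand_score(true_labels, predicted_clusters)) # ~0.44 as the partitions only partly agree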
pandas2ri.activate()
readRDS = robjects.r['readRDS']
# Try open benchmark clustering results
try:
# Download file from URL
urllib.request.urlretrieve("https://github.com/LuyiTian/sc_mixology/raw/master/data/benchmark_results/clustering/cluster_evaluation_result.Rds",
"cluster_evaluation_result.Rds")
except:
try:
urllib.request.urlretrieve("https://github.com/TomMakesThings/Clustering-and-TDA-of-scRNA-seq-Data/raw/main/Data/Datasets/benchmark_comparison_result.Rds",
"cluster_evaluation_result.Rds")
except IOError as exc:
raise RuntimeError("File not found: failed to open benchmark results") from exc
# Open R object as a DataFrame
benchmark_cluster_metadata = readRDS("cluster_evaluation_result.Rds")
# Uncomment to show all rows
#pandas.set_option('display.max_rows', None)
# Filter to find adjusted Rand index (ARI) for sc_10x dataset
benchmark_cluster_metadata = benchmark_cluster_metadata.loc[(benchmark_cluster_metadata['data'] == 'sc_10x') & (benchmark_cluster_metadata['clustering_evaluation'] == "ARI")]
# Drop any rows with missing values
benchmark_cluster_metadata = benchmark_cluster_metadata.dropna()
# Sort from best ARI to worst ARI
benchmark_cluster_metadata = benchmark_cluster_metadata.sort_values(by=['result'], ascending=False)
# Drop all columns in which all values are the same
benchmark_cluster_metadata = benchmark_cluster_metadata[benchmark_cluster_metadata.columns[benchmark_cluster_metadata.nunique() > 1]]
print(colored("Best ARI: ", color="magenta", attrs=['bold']) + str(round(benchmark_cluster_metadata["result"].max(), 3)))
print(colored("Worst ARI: ", color="cyan", attrs=['bold']) + str(round(benchmark_cluster_metadata["result"].min(), 3)))
print(colored("Mean: ", color="green", attrs=['bold']) + str(round(benchmark_cluster_metadata["result"].mean(), 3)) + "\n")
print(colored("ARI for Clustering Experiments on Benchmarking Dataset:\n", attrs=['bold']))
benchmark_cluster_metadata
Best ARI: 0.742 Worst ARI: 0.095 Mean: 0.436 ARI for Clustering Experiments on Benchmarking Dataset:
norm_method | impute_method | clustering_method | result | |
---|---|---|---|---|
9701 | TMM | DrImpute | SC3 | 0.741936 |
10373 | scone | DrImpute | SC3 | 0.741936 |
9477 | DESeq2 | DrImpute | SC3 | 0.741291 |
9253 | scran | DrImpute | SC3 | 0.741291 |
10157 | Linnorm | SAVER | SC3 | 0.741144 |
... | ... | ... | ... | ... |
4293 | scone | knn_smooth2 | Seurat_1.6 | 0.228908 |
4357 | SCnorm | knn_smooth2 | Seurat_1.6 | 0.226527 |
4165 | logCPM | knn_smooth2 | Seurat_1.6 | 0.225011 |
4229 | Linnorm | knn_smooth2 | Seurat_1.6 | 0.216451 |
10381 | scone | SAVER | SC3 | 0.095372 |
127 rows × 4 columns
Open the second dataset, referred to as splat_dataset. This data was created to mimic the gene expression of the benchmarking dataset using the Splat simulator, which is part of Splatter. The creation of this dataset is explained in the report Clustering and Topological Data Analysis of Single-Cell RNA Sequencing Data and it is available to download from my GitHub.
# Open Splat simulated dataset
try:
splat_dataset = pandas.read_csv("https://github.com/TomMakesThings/Clustering-and-TDA-of-scRNA-seq-Data/raw/main/Data/Datasets/simulated_counts.csv.gz", index_col=0)
except IOError as exc:
raise RuntimeError("File not found: failed to open simulated dataset") from exc
print(colored("Splat Simulated Dataset:\n", attrs=['bold']))
splat_dataset
Splat Simulated Dataset:
Cell1 | Cell2 | Cell3 | Cell4 | Cell5 | Cell6 | Cell7 | Cell8 | Cell9 | Cell10 | Cell11 | Cell12 | Cell13 | Cell14 | Cell15 | Cell16 | Cell17 | Cell18 | Cell19 | Cell20 | Cell21 | Cell22 | Cell23 | Cell24 | Cell25 | Cell26 | Cell27 | Cell28 | Cell29 | Cell30 | Cell31 | Cell32 | Cell33 | Cell34 | Cell35 | Cell36 | Cell37 | Cell38 | Cell39 | Cell40 | ... | Cell1961 | Cell1962 | Cell1963 | Cell1964 | Cell1965 | Cell1966 | Cell1967 | Cell1968 | Cell1969 | Cell1970 | Cell1971 | Cell1972 | Cell1973 | Cell1974 | Cell1975 | Cell1976 | Cell1977 | Cell1978 | Cell1979 | Cell1980 | Cell1981 | Cell1982 | Cell1983 | Cell1984 | Cell1985 | Cell1986 | Cell1987 | Cell1988 | Cell1989 | Cell1990 | Cell1991 | Cell1992 | Cell1993 | Cell1994 | Cell1995 | Cell1996 | Cell1997 | Cell1998 | Cell1999 | Cell2000 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Gene1 | 2 | 6 | 1 | 7 | 4 | 2 | 0 | 1 | 1 | 2 | 0 | 4 | 6 | 3 | 0 | 4 | 0 | 1 | 0 | 4 | 1 | 6 | 7 | 1 | 6 | 8 | 10 | 0 | 1 | 4 | 0 | 1 | 5 | 3 | 0 | 5 | 1 | 6 | 4 | 3 | ... | 1 | 0 | 1 | 5 | 2 | 2 | 2 | 2 | 3 | 4 | 1 | 3 | 3 | 14 | 1 | 1 | 4 | 2 | 0 | 0 | 12 | 13 | 1 | 7 | 1 | 0 | 1 | 0 | 0 | 8 | 2 | 4 | 4 | 1 | 0 | 2 | 0 | 2 | 0 | 1 |
Gene2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 2 | 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 1 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
Gene3 | 2 | 4 | 2 | 5 | 0 | 7 | 0 | 2 | 0 | 6 | 0 | 1 | 2 | 1 | 1 | 0 | 0 | 1 | 0 | 1 | 0 | 9 | 4 | 4 | 0 | 1 | 1 | 4 | 1 | 2 | 1 | 7 | 2 | 3 | 5 | 3 | 0 | 4 | 6 | 4 | ... | 5 | 1 | 1 | 1 | 0 | 8 | 0 | 0 | 3 | 0 | 0 | 1 | 8 | 3 | 8 | 5 | 1 | 3 | 2 | 1 | 2 | 3 | 0 | 6 | 0 | 1 | 0 | 2 | 0 | 5 | 2 | 1 | 5 | 9 | 1 | 0 | 2 | 0 | 7 | 1 |
Gene4 | 0 | 0 | 7 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 3 | 0 | 2 | 1 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 8 | 1 | 1 | 0 | 1 | 0 | 3 | 1 | 0 | 3 | 2 | ... | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 14 | 0 | 3 | 4 | 0 | 0 | 5 | 1 | 2 | 0 | 2 | 0 | 0 | 1 | 1 | 0 | 2 | 4 | 0 | 1 | 0 | 3 | 4 | 0 | 0 | 0 | 0 |
Gene5 | 2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 3 | 0 | 0 | 0 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 2 | 0 | 0 | 0 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
Gene16464 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
Gene16465 | 1 | 4 | 0 | 0 | 2 | 6 | 2 | 9 | 0 | 8 | 1 | 14 | 6 | 4 | 7 | 4 | 1 | 9 | 4 | 0 | 2 | 7 | 14 | 0 | 0 | 4 | 5 | 3 | 5 | 5 | 4 | 10 | 5 | 7 | 3 | 6 | 0 | 0 | 11 | 4 | ... | 8 | 0 | 2 | 2 | 3 | 1 | 2 | 3 | 2 | 8 | 0 | 1 | 3 | 16 | 4 | 1 | 3 | 2 | 3 | 2 | 1 | 23 | 0 | 5 | 2 | 5 | 13 | 1 | 3 | 10 | 5 | 1 | 8 | 3 | 10 | 2 | 11 | 3 | 22 | 4 |
Gene16466 | 5 | 12 | 1 | 8 | 5 | 7 | 0 | 7 | 1 | 7 | 2 | 3 | 12 | 5 | 7 | 8 | 1 | 2 | 0 | 0 | 2 | 14 | 11 | 3 | 9 | 5 | 0 | 4 | 11 | 1 | 0 | 0 | 1 | 35 | 1 | 0 | 6 | 14 | 2 | 5 | ... | 13 | 0 | 6 | 0 | 10 | 8 | 1 | 3 | 5 | 11 | 0 | 4 | 4 | 9 | 0 | 3 | 3 | 7 | 16 | 9 | 1 | 21 | 1 | 2 | 2 | 10 | 0 | 2 | 7 | 1 | 3 | 4 | 7 | 7 | 1 | 2 | 13 | 10 | 5 | 2 |
Gene16467 | 0 | 5 | 0 | 0 | 0 | 1 | 0 | 0 | 5 | 2 | 5 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 2 | 2 | 1 | 0 | 0 | 0 | 7 | 0 | 2 | 5 | 2 | 3 | 0 | 6 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 5 | 0 | 1 | 1 | 0 | 0 | 0 | 3 | 4 | 1 | 7 | 2 | 5 | 0 | 0 | 0 | 3 | 0 | 0 | 0 | 0 | 1 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 1 |
Gene16468 | 3 | 0 | 0 | 0 | 1 | 2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 2 | 0 | 4 | 0 | 0 | 0 | 0 | ... | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 3 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 1 | 0 | 0 | 0 |
16468 rows × 2000 columns
View the four groups of the cell samples. These are used as labels to evaluate accuracy during clustering.
try:
splat_metadata = pandas.read_csv("https://github.com/TomMakesThings/Clustering-and-TDA-of-scRNA-seq-Data/raw/main/Data/Datasets/simulated_metadata.csv.gz", index_col=0)
except IOError as exc:
raise RuntimeError("File not found: failed to open simulated metadata") from exc
# Rename cell names (rows) to match
splat_metadata.index = splat_dataset.columns.values
splat_metadata.columns = ['Group']
print(colored("Groups in Simulated Dataset:\n", attrs=['bold']))
splat_metadata
Groups in Simulated Dataset:
Group | |
---|---|
Cell1 | Group4 |
Cell2 | Group3 |
Cell3 | Group1 |
Cell4 | Group3 |
Cell5 | Group2 |
... | ... |
Cell1996 | Group3 |
Cell1997 | Group1 |
Cell1998 | Group1 |
Cell1999 | Group3 |
Cell2000 | Group4 |
2000 rows × 1 columns
Open the third dataset, referred to as cortex_dataset. This is a dataset by Zeisel et al. containing mRNA reads from mouse cortex and hippocampus cells. It is available to download from the Linnarsson Lab and is discussed in the paper Cell types in the mouse cortex and hippocampus revealed by single-cell RNA-seq.
try:
# Try open the dataset from my GitHub as this is faster to download
cortex_data = pandas.read_csv("https://github.com/TomMakesThings/Clustering-and-TDA-of-scRNA-seq-Data/raw/main/Data/Datasets/evaluation_counts.txt",
sep='\t', index_col=1, header=None, low_memory=False)
except:
try:
# If this fails, try open the original dataset from Linnarsson Lab
cortex_data = pandas.read_csv("https://storage.googleapis.com/linnarsson-lab-www-blobs/blobs/cortex/expression_mRNA_17-Aug-2014.txt",
sep='\t', index_col=1, header=None, low_memory=False)
except IOError as exc:
raise RuntimeError("File not found: failed to open evaluation dataset") from exc
# Drop first column as NaN
cortex_data = cortex_data.drop(cortex_data.columns[0], axis=1)
cortex_data = cortex_data[cortex_data.index.notnull()]
# Get ids of cells
cortex_cell_names = cortex_data.loc['cell_id'].values
# Rename cells
cortex_data.columns = cortex_cell_names
cortex_data.index.name = None
try:
cortex_dataset = pandas.read_csv("https://github.com/TomMakesThings/Clustering-and-TDA-of-scRNA-seq-Data/raw/main/Data/Datasets/evaluation_counts.txt",
sep='\t', index_col=0, header=None, low_memory=False)
except:
try:
cortex_dataset = pandas.read_csv("https://storage.googleapis.com/linnarsson-lab-www-blobs/blobs/cortex/expression_mRNA_17-Aug-2014.txt",
sep='\t', index_col=0, header=None, low_memory=False)
except IOError as exc:
raise RuntimeError("File not found: failed to open evaluation dataset") from exc
print(colored("Evaluation Dataset:\n", attrs=['bold']))
# Get gene reads, filter out metadata
# cortex_dataset = cortex_dataset[pandas.notnull(cortex_dataset.index)]
cortex_dataset = pandas.DataFrame(cortex_dataset.loc['r_HY1':]) # r_HY1 is the first gene
# Drop first column as NaN
cortex_dataset = cortex_dataset.drop(cortex_dataset.columns[0], axis=1)
cortex_dataset.columns = cortex_cell_names
cortex_dataset.index.name = None
# Convert gene reads from string to integer
cortex_dataset = cortex_dataset.astype(int)
cortex_dataset
Evaluation Dataset:
1772071015_C02 | 1772071017_G12 | 1772071017_A05 | 1772071014_B06 | 1772067065_H06 | 1772071017_E02 | 1772067065_B07 | 1772067060_B09 | 1772071014_E04 | 1772071015_D04 | 1772071015_C11 | 1772071017_D04 | 1772071017_D06 | 1772067082_D07 | 1772071017_F09 | 1772071017_A09 | 1772067094_C05 | 1772067059_B06 | 1772067096_E05 | 1772066089_C05 | 1772067094_F04 | 1772071045_A01 | 1772071015_C08 | 1772071045_D06 | 1772071017_A03 | 1772071017_F07 | 1772071017_E06 | 1772067066_C10 | 1772071017_B05 | 1772071014_E06 | 1772067058_D11 | 1772071014_B04 | 1772067066_B09 | 1772071017_E10 | 1772071015_B08 | 1772071014_C11 | 1772067066_E10 | 1772067065_F11 | 1772071014_H11 | 1772071017_B11 | ... | 1772063061_C07 | 1772062111_D02 | 1772063077_A05 | 1772063079_F03 | 1772063077_B05 | 1772058171_E08 | 1772062113_H10 | 1772067074_G05 | 1772062113_D02 | 1772066076_C01 | 1772066080_E09 | 1772067057_E04 | 1772066080_D11 | 1772067074_C09 | 1772063077_B06 | 1772066076_H04 | 1772062116_E06 | 1772067064_D03 | 1772063063_G07 | 1772067063_E08 | 1772058148_G02 | 1772066080_G05 | 1772063074_F11 | 1772063074_C04 | 1772062111_E09 | 1772062111_F01 | 1772063074_H03 | 1772063062_H10 | 1772058148_D02 | 1772063078_G10 | 1772066110_D12 | 1772071017_A07 | 1772063071_G10 | 1772058148_C03 | 1772063061_D09 | 1772067059_B04 | 1772066097_D04 | 1772063068_D01 | 1772066098_A12 | 1772058148_F03 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
r_HY1 | 0 | 1 | 0 | 1 | 2 | 1 | 1 | 1 | 4 | 0 | 0 | 0 | 2 | 1 | 0 | 0 | 14 | 0 | 6 | 1 | 3 | 1 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 1 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 4 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 3 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
r_HY3 | 4 | 0 | 0 | 1 | 5 | 0 | 0 | 0 | 13 | 1 | 2 | 0 | 1 | 4 | 0 | 1 | 11 | 0 | 5 | 0 | 5 | 2 | 2 | 0 | 2 | 1 | 2 | 2 | 1 | 3 | 1 | 2 | 4 | 5 | 3 | 0 | 0 | 1 | 3 | 0 | ... | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 4 | 0 | 0 | 5 | 0 | 2 | 6 | 0 | 1 | 0 | 0 | 0 | 1 | 3 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 2 | 0 | 5 | 0 | 0 | 0 | 1 | 0 | 2 | 0 |
r_LTR33A_ | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 1 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
r_RLTR10A | 0 | 1 | 0 | 3 | 0 | 5 | 0 | 2 | 0 | 0 | 3 | 0 | 13 | 0 | 0 | 0 | 2 | 2 | 6 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 3 | 0 | 0 | 0 | 2 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 14 | 0 | 0 | 0 | 0 | 0 | 0 | 3 | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 3 | 0 | 0 | 2 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 4 | 0 | 6 |
r_IAPLTR1_Mm | 50 | 25 | 45 | 30 | 5 | 62 | 22 | 27 | 27 | 9 | 31 | 42 | 16 | 0 | 41 | 59 | 6 | 15 | 38 | 16 | 10 | 59 | 47 | 5 | 114 | 36 | 45 | 6 | 23 | 89 | 5 | 13 | 16 | 24 | 22 | 11 | 15 | 42 | 31 | 9 | ... | 35 | 14 | 137 | 60 | 10 | 128 | 27 | 16 | 101 | 83 | 156 | 70 | 87 | 15 | 91 | 45 | 31 | 24 | 62 | 29 | 86 | 71 | 87 | 44 | 28 | 5 | 1 | 49 | 17 | 33 | 10 | 5 | 9 | 38 | 48 | 41 | 27 | 25 | 34 | 59 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
r_U3 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 3 | 8 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 6 | 4 | 0 | 0 | 0 | 6 | 1 | 1 | 0 | 4 | 0 | 8 | 0 | 0 | 0 | 0 | ... | 66 | 0 | 0 | 0 | 0 | 0 | 0 | 27 | 0 | 3 | 49 | 4 | 0 | 2 | 0 | 8 | 1 | 13 | 25 | 62 | 21 | 3 | 0 | 55 | 0 | 4 | 18 | 0 | 0 | 3 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 2 | 0 | 9 |
r_tRNA-Arg-CGY_ | 1 | 6 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 3 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 4 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 7 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
r_tRNA-Ala-GCY | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
r_U4 | 0 | 7 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 1 | 1 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | ... | 7 | 0 | 0 | 0 | 0 | 3 | 0 | 23 | 0 | 0 | 0 | 0 | 5 | 4 | 0 | 0 | 4 | 9 | 0 | 9 | 0 | 0 | 0 | 4 | 0 | 4 | 18 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 4 | 1 | 0 | 0 | 0 |
r_tRNA-Ser-TCG | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
1072 rows × 3005 columns
View the assigned types or tissue of the cell samples. This dataset provides four possible options for a ground truth:
- group #: nine numerical groups determined from previous experiments by Zeisel et al. using the BackSPIN algorithm.
- level1class: the labelled version of group #, in which groups 1 and 2 are jointly assigned the class "pyramidal CA1", and 7 and 8 "astrocytes_ependymal".
- level2class: the 47 labelled sub-classes of the groups.
- tissue: binary labels determining whether the data is from the cortex or hippocampus.
cortex_metadata = pandas.DataFrame(cortex_data.loc['level1class'])
print(colored("Groups in Evaluation Dataset:\n", attrs=['bold']))
cortex_metadata
Groups in Evaluation Dataset:
level1class | |
---|---|
1772071015_C02 | interneurons |
1772071017_G12 | interneurons |
1772071017_A05 | interneurons |
1772071014_B06 | interneurons |
1772067065_H06 | interneurons |
... | ... |
1772067059_B04 | endothelial-mural |
1772066097_D04 | endothelial-mural |
1772063068_D01 | endothelial-mural |
1772066098_A12 | endothelial-mural |
1772058148_F03 | endothelial-mural |
3005 rows × 1 columns
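The cell above uses level1class as the ground truth. Any of the other three options listed above could be used in the same way; for example, the snippet below (an optional variant, not used elsewhere in this notebook) would evaluate against the tissue labels instead.
# Optional: use the binary tissue labels as the ground truth instead of level1class
# cortex_metadata = pandas.DataFrame(cortex_data.loc['tissue'])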
If not using one of the three given datasets, open a new dataset here as a DataFrame.
# Replace with the new dataset in the format shown below
custom_dataset = pandas.DataFrame(data=[[1, 2], [3, 4], [5, 6]], index=['Gene 1', 'Gene 2', 'Gene 3'], columns=['Cell 1', 'Cell 2'])
display(custom_dataset)
Cell 1 | Cell 2 | |
---|---|---|
Gene 1 | 1 | 2 |
Gene 2 | 3 | 4 |
Gene 3 | 5 | 6 |
Optionally, also create a DataFrame matching samples to target labels for the clusters. If these are unknown, set custom_metadata = None.
# Set as labels if known as a DataFrame in format shown below, otherwise set as None
custom_metadata = pandas.DataFrame(data=['Label 1', 'Label 2'], index=['Cell 1', 'Cell 2'], columns=['Target Cluster'])
display(custom_metadata)
Target Cluster | |
---|---|
Cell 1 | Label 1 |
Cell 2 | Label 2 |
Select which dataset you want to use by setting dataset as either benchmark_dataset, splat_dataset, cortex_dataset, or custom_dataset if you set a new dataset as a DataFrame above. Leave metadata = None if using one of the four options above.
# Set dataset and metadata to use
dataset = benchmark_dataset
metadata = None
# Override metadata if using one of the given datasets
if dataset.equals(benchmark_dataset):
metadata = benchmark_metadata
if dataset.equals(splat_dataset):
metadata = splat_metadata
if dataset.equals(cortex_dataset):
metadata = cortex_metadata
if dataset.equals(custom_dataset):
metadata = custom_metadata
print(colored("Selected Dataset:\n", attrs=['bold']))
display(dataset.reindex(sorted(dataset.columns), axis=1))
Selected Dataset:
CELL_000001 | CELL_000002 | CELL_000003 | CELL_000004 | CELL_000005 | CELL_000006 | CELL_000007 | CELL_000008 | CELL_000009 | CELL_000010 | ... | CELL_000931 | CELL_000932 | CELL_000933 | CELL_000934 | CELL_000935 | CELL_000939 | CELL_000943 | CELL_000946 | CELL_000955 | CELL_000965 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
ENSG00000272758 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 1 | 1 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
ENSG00000154678 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 1 | 2 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
ENSG00000148737 | 0 | 0 | 0 | 1 | 3 | 2 | 0 | 2 | 1 | 0 | ... | 0 | 0 | 0 | 2 | 0 | 4 | 0 | 0 | 2 | 0 |
ENSG00000196968 | 0 | 0 | 0 | 1 | 2 | 2 | 5 | 0 | 4 | 1 | ... | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 |
ENSG00000134297 | 0 | 0 | 0 | 1 | 1 | 1 | 2 | 1 | 3 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
ENSG00000237289 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
ENSG00000238098 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 |
ENSG00000133433 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
ENSG00000054219 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
ENSG00000137691 | 0 | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | ... | 0 | 0 | 0 | 2 | 0 | 1 | 0 | 0 | 0 | 1 |
16468 rows × 902 columns
An autoencoder is a neural network that learns a compressed version of the input through encoding and decoding. For scRNA-seq data, this is useful for dimensionality reduction as these datasets contain many features. For example, the benchmarking dataset has 16,468 genes, but the autoencoder can capture the underlying trend of the data in far fewer features, such as 16. Although some of the information is lost, this can be beneficial for removing noise.
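As a minimal sketch of the idea, separate from the LitAutoencoder class defined below, an autoencoder is just an encoder network followed by a decoder network trained so that the reconstruction matches the input. The layer sizes and random input here are assumptions chosen purely for illustration.
# Minimal autoencoder sketch: compress 16,468 genes to 16 features and reconstruct
sketch_encoder = nn.Sequential(nn.Linear(16468, 128), nn.ReLU(), nn.Linear(128, 16))
sketch_decoder = nn.Sequential(nn.Linear(16, 128), nn.Linear(128, 16468))
example_cell = torch.rand(1, 16468) # Stand-in for one cell's gene counts
reconstruction = sketch_decoder(sketch_encoder(example_cell))
print(nn.MSELoss()(reconstruction, example_cell)) # Reconstruction loss minimised during training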
DatasetRNASeq and LitAutoencoder classes.
class DatasetRNASeq(Dataset):
""" Custom map-style Dataset class for RNA-seq dataset """
def __init__(self, data):
self.dataset = data
def __len__(self):
# Return the number of cells
return len(self.dataset.columns)
def __getitem__(self, index):
# Return the ith sample
column = self.dataset.columns[index]
item = self.dataset[column]
# Convert to tensor
tensor = torch.tensor(item, dtype=torch.float32)
return tensor
class LitAutoencoder(pl.LightningModule):
""" PyTorch Lightning Autoencoder
Parameters:
train_dataloader: the batched training data
validation_dataloader: the batched validation data
test_dataloader: the batched testing data
input_features: set as the number of genes
extra_hidden_features: number of nodes for additional hidden layer if hidden_layer=5
hidden_features: number of nodes for hidden layer
encoded_features: number of dimensions to compress data
loss_function: function to calculate loss
learning_rate: learning rate of optimizer
optimizer: algorithm to update network weights through backpropagation
amsgrad: whether to use AMSGrad if using Adam or AdamW optimizer
activation_function: the function applied between layers
hidden_layers: the number of hidden layers of the network, including the encoding layer
"""
def __init__(self, train_dataloader, validation_dataloader, test_dataloader,
input_features, extra_hidden_features=1024, hidden_features=128,
encoded_features=16, loss_function=nn.MSELoss(),
learning_rate=1e-3, optimizer="Adam", amsgrad=False,
activation_function=nn.ReLU(), hidden_layers=3):
super().__init__()
# Save hyperparameters to hparams.yaml
self.save_hyperparameters('input_features', 'extra_hidden_features',
'hidden_features', 'encoded_features', 'learning_rate',
'loss_function', 'activation_function', 'optimizer',
'amsgrad', 'hidden_layers')
# Data not saved so must be initialised
self.training_dataloader = train_dataloader
self.validation_dataloader = validation_dataloader
self.testing_dataloader = test_dataloader
# Input -> Encoded -> Output
if hidden_layers == 1:
# Encoder layers
self.encoder = nn.Sequential(nn.Linear(input_features, encoded_features))
# Decoder layers
self.decoder = nn.Sequential(nn.Linear(encoded_features, input_features))
# Input -> Hidden -> Hidden2 -> Encoded - > Hidden2 -> Hidden -> Output
elif hidden_layers == 5:
self.encoder = nn.Sequential(
nn.Linear(input_features, extra_hidden_features),
self.hparams.activation_function,
nn.Linear(extra_hidden_features, hidden_features),
self.hparams.activation_function,
nn.Linear(hidden_features, encoded_features)
)
self.decoder = nn.Sequential(
nn.Linear(encoded_features, hidden_features),
self.hparams.activation_function,
nn.Linear(hidden_features, extra_hidden_features),
self.hparams.activation_function,
nn.Linear(extra_hidden_features, input_features)
)
# Input -> Hidden -> Encoded -> Hidden -> Output
else:
# Default to 3
self.encoder = nn.Sequential(
nn.Linear(input_features, hidden_features),
self.hparams.activation_function,
nn.Linear(hidden_features, encoded_features)
)
self.decoder = nn.Sequential(
nn.Linear(encoded_features, hidden_features),
# Uncomment to use non-linear activation function in decoder
#self.hparams.activation_function,
nn.Linear(hidden_features, input_features),
)
def shared_step(self, batch):
# Repeated code from step methods
encoded = self.encoder(batch)
decoded = self.decoder(encoded)
loss = self.hparams.loss_function(batch, decoded)
return loss
def training_step(self, batch, batch_idx):
# Training loop
loss = self.shared_step(batch)
# Current batch id + current epoch * number of batches per epoch
batch_number = batch_idx + (self.current_epoch * len(self.training_dataloader))
if self.logger is not None:
self.logger.experiment.add_scalar("Training Loss / Mini-Batch", loss,
batch_number)
return {'loss': loss}
def validation_step(self, batch, batch_idx):
# Validation loop
loss = self.shared_step(batch)
self.log('val_loss', loss)
return {'validation_loss': loss}
def test_step(self, batch, batch_idx):
# Testing loop
loss = self.shared_step(batch)
if self.logger is not None:
self.logger.experiment.add_scalar("Testing Loss / Mini-Batch", loss, batch_idx)
return {'test_loss': loss}
def training_epoch_end(self, outputs):
# Record average training loss
average_loss = torch.stack([x['loss'] for x in outputs]).mean()
if self.logger is not None:
self.logger.experiment.add_scalars("Loss / Epoch",
{"Train_Loss" : average_loss},
self.current_epoch)
def validation_epoch_end(self, outputs):
# Record average validation loss
average_loss = torch.stack([x['validation_loss'] for x in outputs]).mean()
if self.logger is not None:
self.logger.experiment.add_scalars("Loss / Epoch",
{"Validation_Loss" : average_loss},
self.current_epoch)
def test_epoch_end(self, outputs):
# Record average testing loss
average_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
print("Average test loss: " + str(average_loss.item()))
return {'Average test loss': average_loss}
def train_dataloader(self):
return self.training_dataloader
def val_dataloader(self):
return self.validation_dataloader
def test_dataloader(self):
return self.testing_dataloader
def configure_optimizers(self):
# Set the optimizer
if (self.hparams.optimizer.lower() == "rprop"):
opt = torch.optim.Rprop(self.parameters(), lr=self.hparams.learning_rate)
elif (self.hparams.optimizer.lower() == "sgd"):
opt = torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate)
elif (self.hparams.optimizer.lower() == "adadelta"):
opt = torch.optim.Adadelta(self.parameters(), lr=self.hparams.learning_rate)
elif (self.hparams.optimizer.lower() == "adagrad"):
opt = torch.optim.Adadelta(self.parameters(), lr=self.hparams.learning_rate)
elif (self.hparams.optimizer.lower() == "adamw"):
opt = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate,
amsgrad=self.hparams.amsgrad)
else:
# Default to Adam
opt = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate,
amsgrad=self.hparams.amsgrad)
return opt
Functions convert_to_weights_only and read_autoencoder_from_url.
The function below is not run anywhere in this notebook, though it can be called to reduce the file size of the autoencoder's .ckpt (checkpoint) file if you trained a new model with save_weights_only = False.
def convert_to_weights_only(checkpoint_file, new_file_name="checkpoint_weights.ckpt"):
key_to_remove = {'optimizer_states', 'lr_schedulers', 'callbacks'}
# Load the checkpoint from file
checkpoint_dict = torch.load(checkpoint_file, map_location=torch.device('cpu'))
# Remove the keys not found in save_weights_only
for key in key_to_remove:
checkpoint_dict.pop(key, None)
# Save the checkpoint
if not new_file_name.endswith('.ckpt'):
new_file_name = new_file_name + '.ckpt'
torch.save(checkpoint_dict, new_file_name)
print("Saved checkpoint to " + str(new_file_name))
This function is used to find the file paths required to construct the training, validation and testing dataloaders if using a pre-trained model.
def read_autoencoder_from_url(url):
# Open the zip file from URL
with ZipFile(BytesIO(url.read())) as zipped_model:
# Check each file in the zip
for file_name in zipped_model.namelist():
if file_name.endswith('.ckpt'):
# Set the file path of the checkpoint
pre_trained_file = zipped_model.extract(file_name, 'pre-trained_autoencoder')
elif 'train' in file_name:
# Set the file path of the training data
train_file = zipped_model.extract(file_name, 'pre-trained_autoencoder')
elif 'val' in file_name:
# Set the file path of the validation data
val_file = zipped_model.extract(file_name, 'pre-trained_autoencoder')
elif 'test' in file_name:
# Set the file path of the testing data
test_file = zipped_model.extract(file_name, 'pre-trained_autoencoder')
try:
# Attempt to return all variables
return pre_trained_file, train_file, val_file, test_file
except NameError as exc:
# A NameError is raised if any of the expected files was missing from the zip
raise RuntimeError("Error! Could not load autoencoder state from file") from exc
To train a new model, set train_new_model = True.
Alternatively, to load a pre-trained model, set train_new_model = False.
- continue_training = True will continue training the pre-trained model
- load_test_data = False if using different test data from when the model was created
- pre_trained_file_path should be set as the pre-trained model's .ckpt file
train_new_model = False # Set as False to load a model from a file
# Options if loading a pre-trained model
continue_training = False # Set as True to continue training the pre-loaded model
load_test_data = True # Setting as True means clustering experiments are performed on test data assigned when the model was trained
Override some parameters if loading a pre-trained model for one of the three given datasets.
if dataset.equals(benchmark_dataset):
# Read the model state from a zip file
model_weights_url = urllib.request.urlopen("https://github.com/TomMakesThings/Clustering-and-TDA-of-scRNA-seq-Data/raw/main/Data/Benchmark_Autoencoder/Trained_Benchmark_Autoencoder.zip")
pre_trained_file_path, train_path, val_path, test_path = read_autoencoder_from_url(model_weights_url)
load_test_data = True
elif dataset.equals(splat_dataset):
model_weights_url = urllib.request.urlopen("https://github.com/TomMakesThings/Clustering-and-TDA-of-scRNA-seq-Data/raw/main/Data/Simulated_Autoencoder/Trained_Simulated_Autoencoder.zip")
pre_trained_file_path, train_path, val_path, test_path = read_autoencoder_from_url(model_weights_url)
load_test_data = True
elif dataset.equals(cortex_dataset):
model_weights_url = urllib.request.urlopen("https://github.com/TomMakesThings/Clustering-and-TDA-of-scRNA-seq-Data/raw/main/Data/Evaluation_Autoencoder/Trained_Evaluation_Autoencoder.zip")
pre_trained_file_path, train_path, val_path, test_path = read_autoencoder_from_url(model_weights_url)
load_test_data = True
If training a new model:
- Set test_size as the number of samples to hold out for testing. All remaining data is used for training.
- Set batch_size as the number of samples to include in each training batch.
test_size = 150 # Number of samples to use for testing
batch_size = 4 # Number of samples in a training batch
# Get cell names and their gene counts
cell_names = list(set(dataset.columns.values)) # Converting to set then list shuffles column order from original
cell_data = np.transpose(dataset[cell_names].values) # Transpose switches the row and column indices
# Number of genes for each cell
input_feature_number = len(dataset.index)
# Total number of samples
sample_size = len(cell_data)
if train_new_model:
# Split into training and testing datasets
x_train = dataset[cell_names[:sample_size - test_size]] # Last n cells in dataset
x_test = dataset[cell_names[sample_size - test_size:]] # Remaining cells
# Convert training and testing data to custom Dataset class, then DataLoader
train_data = DatasetRNASeq(x_train)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
print("Created new train dataloader")
test_data = DatasetRNASeq(x_test)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
print("Created new test dataloader")
else:
# If using a pre-trained model, try create dataloaders from file
try:
with open(train_path) as train_file:
train_samples = train_file.read().splitlines()
train_data = DatasetRNASeq(dataset[train_samples])
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
except IOError as exc:
raise RuntimeError("Error reading train data from file") from exc
try:
with open(test_path) as test_file:
test_samples = test_file.read().splitlines()
test_data = DatasetRNASeq(dataset[test_samples])
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
except IOError as exc:
raise RuntimeError("Error reading test data from file") from exc
try:
with open(val_path) as validation_file:
validation_samples = validation_file.read().splitlines()
validation_data = DatasetRNASeq(dataset[validation_samples])
validation_dataloader = DataLoader(validation_data, batch_size=batch_size, shuffle=False)
except IOError as exc:
raise RuntimeError("Error reading test data from file") from exc
If training a new autoencoder:
- Set the model's hyperparameters in autoencoder_kwargs.
- save_weights_only = True will save only the autoencoder's weights when creating a checkpoint. This can massively reduce the file size of the model, though it is not recommended if you wish to load a model from file and continue training.
save_weights_only = True # Whether to only save the model's weights to the checkpoint file
# Autoencoder arguments
autoencoder_kwargs = {"test_dataloader": test_dataloader, "input_features": input_feature_number,
"hidden_features": 128, "encoded_features": 16, "hidden_layers": 3,
"learning_rate": 1e-3, "optimizer": "Adam", "amsgrad": True,
"activation_function": nn.PReLU()}
# Callback to save state
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_weights_only = save_weights_only,
filename='autoencoder-{epoch:02d}-{val_loss:.2f}')
Load the weights of a pre-trained autoencoder. This will only run if train_new_model = False.
if not train_new_model:
try:
autoencoder = LitAutoencoder.load_from_checkpoint(pre_trained_file_path,
train_dataloader=train_dataloader,
validation_dataloader=validation_dataloader,
test_dataloader=test_dataloader,
input_features=input_feature_number)
autoencoder.freeze()
print("Loaded pre-trained model from \"" + str(pre_trained_file_path) + "\"")
except IOError as exc:
raise RuntimeError("Failed to load pre-trained autoencoder from file! To train a new model set train_new_model = True") from exc
Loaded pre-trained model from "pre-trained_autoencoder\benchmark_autoencoder_weights.ckpt"
Split training data into 80% train and 20% validation using k-fold cross-validation to find the arrangement that performs best.
- Set k_folds as the number of splits to test.
- Set epochs as the number of epochs to test each fold over.
This part does not run if using a pre-trained model.
k_folds = 5
epochs = 20
if train_new_model:
# Split into k-folds
kfold = KFold(n_splits=k_folds, shuffle=True)
split = kfold.split(train_data)
# Record indexes of training / validation samples in a fold
fold_train_indexes = []
fold_val_indexes = []
# Record final validation loss of each fold
fold_loss = []
# Test each fold
for fold, (train_indexes, val_indexes) in enumerate(split):
fold_train_indexes.append(train_indexes)
fold_val_indexes.append(val_indexes)
# Get the data specified by the indexes returned by the KFold function
train_subsampler = SubsetRandomSampler(train_indexes)
validation_subsampler = SubsetRandomSampler(val_indexes)
train_dataloader = DataLoader(train_data, batch_size=batch_size, sampler=train_subsampler)
validation_dataloader = DataLoader(train_data, batch_size=batch_size, sampler=validation_subsampler)
# Create the autoencoder
k_fold_autoencoder = LitAutoencoder(train_dataloader, validation_dataloader, **autoencoder_kwargs)
logger = TensorBoardLogger('tb_logs', name='k_fold_logger') # Log in a separate folder to the final model training
# Create a trainer
k_fold_trainer = pl.Trainer(gpus=torch.cuda.device_count(),
max_epochs=epochs,
logger=logger,
checkpoint_callback=False,
progress_bar_refresh_rate=50) # Reduced refresh rate
k_fold_trainer.fit(k_fold_autoencoder, train_dataloader)
# Record loss for the fold
fold_loss.append(k_fold_trainer.callback_metrics.get('val_loss').item())
# Find the best fold
best_fold_index = fold_loss.index(min(fold_loss))
print(colored('Number of Training Samples:', 'cyan', attrs=['bold']), len(fold_train_indexes[0]))
print(colored('Number of Validation Samples:', 'green', attrs=['bold']), len(fold_val_indexes[0]))
print(colored('Loss per fold:', 'blue', attrs=['bold']), fold_loss)
print(colored('Best fold:', 'red', attrs=['bold']), best_fold_index)
# Set training and validation dataset to use best fold
train_subsampler = SubsetRandomSampler(fold_train_indexes[best_fold_index])
validation_subsampler = SubsetRandomSampler(fold_val_indexes[best_fold_index])
train_dataloader = DataLoader(train_data, batch_size=batch_size, sampler=train_subsampler)
validation_dataloader = DataLoader(train_data, batch_size=batch_size, sampler=validation_subsampler)
Runs if either:
- train_new_model = True
- continue_training = True
The number of epochs is the number of training iterations. Setting use_lr_finder = True will use PyTorch Lightning's learning rate finder to set the learning rate automatically. This is switched off by default as I found it did not improve performance.
epochs = 250
use_lr_finder = False # Set True to use learning rate finder for training
# Create a logger and specify the folder to save the results to
logger = TensorBoardLogger('tb_logs', name='autoencoder_logger')
# Create a trainer
trainer = pl.Trainer(gpus=torch.cuda.device_count(), # Use available GPUs
max_epochs=epochs,
auto_lr_find=use_lr_finder, # Set whether to use learning rate finder
logger=logger,
callbacks=[checkpoint_callback], # Callback allows saving best model state
progress_bar_refresh_rate=50) # Reduced refresh rate
if train_new_model:
# Create the autoencoder
autoencoder = LitAutoencoder(train_dataloader, validation_dataloader, **autoencoder_kwargs)
if use_lr_finder:
# Create learning rate finder
lr_finder = trainer.tuner.lr_find(autoencoder, min_lr=1e-4, max_lr=1e-2)
# Plot graph of best learning rate
lr_graph = lr_finder.plot(suggest=True)
plot.title('Learning Rate Range Test')
lr_graph.show()
new_lr = lr_finder.suggestion()
# Set model's learning rate as new learning rate
autoencoder.hparams.lr = new_lr
print("New learning rate is " + str(new_lr))
if train_new_model or continue_training:
# Unfreeze layers
autoencoder.unfreeze()
# Train the model with the training dataset
trainer.fit(autoencoder, train_dataloader)
GPU available: False, used: False TPU available: None, using: 0 TPU cores
Optionally train the autoencoder over more epochs by setting train_further = True and running this cell. This may be used if training is not complete.
train_further = False
additional_epochs = 10
if train_further:
autoencoder.unfreeze()
trainer.max_epochs = additional_epochs
trainer.fit(autoencoder, train_dataloader)
For a new autoencoder's state to be saved to file, download_model must be set to True. This downloads the best model checkpoint along with text files listing the cells used for training, validation and testing.
download_model = False # Set True to download the files of the trained autoencoder
if train_new_model or continue_training:
best_model = checkpoint_callback.best_model_path
print(best_model)
if download_model:
train_file_name = "train_data.txt"
val_file_name = "validation_data.txt"
test_file_name = "test_data.txt"
# Save to file
test_textfile = open(test_file_name, "w")
for sample in test_dataloader.dataset.dataset.columns:
test_textfile.write(sample + "\n")
test_textfile.close()
# Find all the cell names for the training data
train_sample_names = []
for idx in train_dataloader.sampler.indices:
train_sample_names.append(train_dataloader.dataset.dataset.columns[idx])
# Save training data cells to text file
train_textfile = open(train_file_name, "w")
for sample in train_sample_names:
train_textfile.write(sample + "\n")
train_textfile.close()
# Find all the cell names for the validation data
validation_sample_names = []
for idx in validation_dataloader.sampler.indices:
validation_sample_names.append(validation_dataloader.dataset.dataset.columns[idx])
val_textfile = open(val_file_name, "w")
for sample in validation_sample_names:
val_textfile.write(sample + "\n")
val_textfile.close()
# Download the files
files.download(best_model)
files.download(train_file_name)
files.download(val_file_name)
files.download(test_file_name)
Test the trained autoencoder's performance by returning the average loss on unseen test data.
# Print average test loss
trainer.test(autoencoder, test_dataloaders=test_dataloader)
View graphs of the autoencoder's training and testing performance using TensorBoard.
# Load the autoencoder logger's results
%tensorboard --logdir tb_logs/autoencoder_logger/
Now that the autoencoder has been trained, it is used to produce an encoding that will reduce the dimensionality of the dataset. Set print_batches = True to print each testing batch's input, output and accuracy.
print_batches = True
enc = autoencoder.encoder
dec = autoencoder.decoder
encoded_data = []
batch_accuracies = []
for batch in test_dataloader:
if print_batches:
print(colored('Input:', 'cyan', attrs=['bold']))
print(batch)
print(colored('Encoded:', 'magenta', attrs=['bold']))
try:
encoded = enc(batch.cuda() if torch.cuda.is_available() else batch)
except:
encoded = enc(batch)
for sample in encoded:
# Convert to numpy array for clustering
encoded_data.append(sample.cpu().detach().numpy())
if print_batches:
print(encoded)
print(colored('Decoded:', 'green', attrs=['bold']))
decoded = dec(encoded)
decoded[decoded < 0] = 0 # Set all negatives to 0
decoded = torch.round(decoded) # Round all floats
if print_batches:
print(decoded)
# Calculate accuracy between the original gene counts and decoded output for each sample
total_batch_accuracy = 0
for i in range(len(batch)):
total_batch_accuracy += accuracy_score(batch[i].detach().numpy(), decoded[i].detach().numpy())
batch_accuracy = total_batch_accuracy / len(batch)
batch_accuracies.append(batch_accuracy)
if print_batches:
print(colored("Batch average accuracy: ", 'blue', attrs=['bold']) + str(round(batch_accuracy * 100, 2)) + "%" + "\n")
print(colored("Testing average accuracy: ", 'red', attrs=['bold']) + str(round(sum(batch_accuracies)/len(batch_accuracies) * 100, 2)) + "%")
Input: tensor([[0., 0., 1., ..., 1., 0., 0.], [1., 0., 4., ..., 0., 0., 0.], [1., 0., 4., ..., 1., 0., 2.], [0., 2., 1., ..., 0., 0., 3.]]) Encoded: tensor([[-3.0994e+01, -3.2482e+01, -1.4423e+01, 4.9324e+01, -1.3693e+02, 3.9318e+01, -4.8515e+01, -9.8878e+01, -1.5166e+01, -5.6354e+01, 6.8232e+01, 5.9742e+01, -1.3862e+01, 2.3978e+02, -9.0629e+00, -1.0410e+00], [ 3.7525e-01, 2.7726e+01, -1.5853e+01, 8.0319e+01, -2.7313e+02, -8.9801e+01, 3.3336e+01, -1.2845e+02, 2.7614e+01, -6.9375e+01, 1.0661e+02, 3.7762e+00, -8.7107e+00, 3.5847e+02, 1.9205e+01, -1.8302e+01], [-2.1119e+01, -1.4323e+02, 1.7389e+01, 1.7998e+01, -2.5508e+02, 3.5625e+01, -7.4107e+01, -1.5386e+02, -1.8553e+01, -1.0496e+02, 1.0653e+02, 2.6962e+01, 3.4770e+01, 3.9025e+02, -9.5394e-02, -4.5079e+01], [-1.0034e+02, 6.3732e+01, -4.5656e+01, -6.1508e+01, -1.5451e+02, 8.3894e+01, -1.6131e+01, -2.7142e+02, -3.3708e+01, 1.6080e+02, 2.7337e+01, 6.7642e+01, -5.7335e+01, 3.3015e+02, 6.2503e+01, 4.0152e+01]]) Decoded: tensor([[0., 0., 1., ..., 0., 1., 0.], [0., 0., 2., ..., 0., 0., 1.], [0., 0., 2., ..., 0., 0., 1.], [0., 1., 4., ..., 0., 0., 1.]]) Batch average accuracy: 36.78% Input: tensor([[1., 0., 2., ..., 1., 0., 0.], [0., 2., 4., ..., 0., 0., 2.], [2., 2., 2., ..., 0., 0., 6.], [1., 0., 3., ..., 0., 0., 0.]]) Encoded: tensor([[ 9.7791, -7.6208, -1.1892, 63.5273, -151.0350, -25.6433, 27.5999, -88.7464, 11.2517, -33.8007, 39.3804, 38.8026, -0.9239, 223.1910, 30.7383, 11.3617], [ -39.5862, -33.0210, -47.8409, 139.9470, -178.2782, 48.3359, -13.9197, -131.3504, 25.0904, 97.0889, 84.2029, -27.0123, 32.8361, 355.3621, 120.7151, 70.7619], [ -68.1642, -29.0632, -55.6574, 67.4719, -145.0308, 38.7798, -26.6443, -144.9100, 35.3830, 109.7968, 66.6640, -13.2037, 18.5094, 311.8013, 86.1862, 57.4970], [ -35.7887, 92.7150, 35.2633, -159.5737, -103.8245, -9.0223, -31.2519, 25.2292, 85.6539, -7.9943, -7.0383, 14.7428, 47.8822, 439.4055, 41.0245, -39.2605]]) Decoded: tensor([[0., 0., 1., ..., 0., 0., 1.], [0., 2., 2., ..., 0., 0., 2.], [0., 2., 3., ..., 0., 0., 2.], [0., 1., 1., ..., 0., 0., 0.]]) Batch average accuracy: 39.75% Input: tensor([[0., 0., 0., ..., 0., 0., 0.], [2., 0., 0., ..., 0., 0., 0.], [0., 0., 1., ..., 0., 0., 0.], [0., 2., 2., ..., 0., 1., 1.]]) Encoded: tensor([[ 8.9053e+00, 8.3170e+01, 8.2415e+01, -4.4482e+01, -1.0325e+02, 7.7243e+00, -4.9876e+01, 9.2439e+00, 2.0837e+01, -1.6196e+01, -5.0828e+01, -4.5349e+01, 8.6863e+01, 3.5154e+02, 4.9343e+00, 3.5649e+01], [-1.2057e+01, 2.7214e+01, -1.1884e+01, 3.8449e+01, -1.1911e+02, -3.3671e+01, -1.1157e+01, -8.5212e+01, 2.0503e+01, -2.1842e+01, 7.0391e+01, 4.4281e+01, -1.2317e+01, 1.9183e+02, 7.2989e+00, -3.3741e+00], [-5.8283e-02, -2.5054e+01, 1.5481e+01, 2.9118e+01, -1.1727e+02, 3.6777e+01, -6.6018e+01, -7.0021e+01, -2.2936e+01, -3.7551e+01, 6.0384e+01, 9.9904e+00, 2.9804e+00, 1.8527e+02, -2.3773e+01, -3.3716e+01], [-1.0066e+02, 6.1030e+01, -8.8662e+01, 7.4248e+00, -1.1248e+02, 1.6228e+01, 4.0040e+01, -1.9785e+02, -4.9380e+01, 1.1928e+02, 5.9847e+01, -7.8613e+00, -7.8185e+01, 2.9815e+02, 7.5267e+00, 1.0144e+01]]) Decoded: tensor([[0., 0., 0., ..., 0., 0., 0.], [0., 0., 1., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 1., 0.], [0., 1., 2., ..., 0., 0., 1.]]) Batch average accuracy: 43.47% Output omitted... 
Input: tensor([[0., 0., 3., ..., 0., 0., 1.], [1., 0., 1., ..., 0., 0., 1.], [0., 0., 0., ..., 0., 0., 0.], [1., 0., 2., ..., 0., 0., 1.]]) Encoded: tensor([[-119.2607, -52.6912, -61.5833, 34.5369, -283.4898, 65.4566, -127.4109, -263.4840, -76.5547, -32.7067, 137.5651, 71.9857, -7.7636, 505.8143, -14.4719, 11.6506], [ -28.4978, 25.1791, -9.7694, 54.7468, -232.3482, -88.8151, 9.2038, -146.1536, 23.2263, -47.2812, 86.8774, 53.8991, -30.6187, 331.8239, 37.8359, -43.8041], [ 19.9082, 66.6941, -17.3319, -55.1068, -56.3407, -13.1293, -55.4012, -34.2282, 68.8391, -19.3155, -7.1700, -33.4250, 51.1531, 448.4697, 6.1075, -9.5957], [ -21.2346, -105.8277, 30.8435, 24.9150, -277.0805, 47.6661, -35.6607, -196.0434, -77.7742, -59.7684, 27.6061, 52.9808, 26.9755, 377.2787, 3.0401, -81.2074]]) Decoded: tensor([[0., 1., 3., ..., 0., 1., 1.], [0., 0., 2., ..., 0., 0., 1.], [0., 0., 0., ..., 1., 0., 0.], [1., 0., 2., ..., 0., 0., 1.]]) Batch average accuracy: 35.01% Input: tensor([[1., 0., 0., ..., 0., 0., 0.], [0., 1., 0., ..., 0., 0., 0.], [1., 1., 2., ..., 0., 0., 2.], [0., 0., 0., ..., 0., 0., 1.]]) Encoded: tensor([[ 40.6085, 39.9126, -29.5988, -18.3418, -61.2797, -18.7595, -30.8322, -26.2547, 56.1965, -32.8186, 8.3026, -31.4386, 13.5161, 258.4942, 23.2640, -2.7048], [-123.1021, 26.0830, -92.5752, 138.6674, -97.3086, -83.5439, 28.9378, -82.0378, -77.3154, 96.5343, 100.5889, -120.1304, -5.7743, 252.8478, -22.6112, 18.3558], [ 17.6993, -97.9791, 28.9530, 50.1693, -161.5858, 67.3367, -66.0552, -61.9287, -58.1357, -64.7236, 92.1746, 34.8906, -4.6904, 278.9036, -8.4625, -39.4232], [ 6.5479, 20.1505, -8.7921, 53.3130, -198.2090, -51.1953, 55.4308, -132.7181, -11.1066, -45.9261, 47.8415, 45.4972, -27.8748, 280.4395, 47.6055, -4.0002]]) Decoded: tensor([[0., 0., 0., ..., 0., 0., 0.], [0., 1., 0., ..., 0., 0., 2.], [0., 0., 0., ..., 0., 1., 0.], [0., 0., 2., ..., 0., 0., 1.]]) Batch average accuracy: 42.26% Input: tensor([[0., 0., 1., ..., 0., 0., 0.], [0., 0., 4., ..., 0., 0., 3.]]) Encoded: tensor([[ -70.1456, -0.5271, 2.2225, -61.6991, -70.3221, -1.5425, -26.1585, -188.2877, 34.5224, 38.2465, 80.6506, 63.2564, -22.8477, 233.7037, -105.3474, -47.9143], [ -38.9478, 134.2433, -22.5568, -42.3777, -271.6851, -85.6587, 51.5674, -378.4795, -20.4906, -1.7884, 17.0424, 145.6059, -89.4490, 449.1552, 24.4117, -81.9819]]) Decoded: tensor([[0., 0., 1., ..., 0., 0., 0.], [1., 0., 4., ..., 1., 0., 1.]]) Batch average accuracy: 37.25% Testing average accuracy: 39.95%
Custom-made functions used during the clustering experiments.
Functions elbow_method and silhouette_coefficient.
def elbow_method(max_k, features, **k_means_kwargs):
""" Elbow method to find the recommended number of clusters for k-means """
sse = []
# Test different numbers of clusters up to max_k
for k in range(1, max_k+1):
k_means = KMeans(n_clusters=k, **k_means_kwargs)
k_means.fit(features)
# Record the Sum of Squared Error (SSE)
sse.append(k_means.inertia_)
# Find the elbow point
knee_locator = KneeLocator(range(1, max_k+1), sse, curve="convex", direction="decreasing")
best_k = knee_locator.elbow
# Plot the graph
fig = plot.figure()
plot.plot(range(1, max_k+1), sse, "orange")
plot.plot(best_k, sse[best_k-1], marker="*", markersize=12, color="blueviolet")
plot.title("K-Means Elbow Method")
plot.xlabel('Number of Clusters')
plot.ylabel('SSE')
plot.show()
return sse, best_k
def silhouette_coefficient(max_k, features, **k_means_kwargs):
""" Silhouette coefficient to find the recommended number of clusters for k-means """
silhouette_coefficients = []
for k in range(2, max_k+1):
k_means = KMeans(n_clusters=k, **k_means_kwargs)
k_means.fit(features)
silhouette_coefficients.append(silhouette_score(features, k_means.labels_))
max_value = max(silhouette_coefficients)
best_k = silhouette_coefficients.index(max_value) + 2
fig = plot.figure()
plot.plot(range(2, max_k+1), silhouette_coefficients, "deepskyblue")
plot.plot(best_k, max_value, marker="*", markersize=12, color="blueviolet")
plot.title("K-Means Silhouette Coefficient")
plot.xlabel('Number of Clusters')
plot.ylabel('Silhouette Coefficient')
plot.show()
return silhouette_coefficients, best_k
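As a quick sanity check, both helpers can be run on synthetic data before applying them to the scRNA-seq features later on. The snippet below is only a sketch; make_blobs and the toy variables are not part of the original analysis.
from sklearn.datasets import make_blobs
# Hypothetical synthetic data with four well-separated groups
toy_features, _ = make_blobs(n_samples=200, centers=4, n_features=10, random_state=0)
toy_kwargs = {"init": "k-means++", "n_init": 10, "max_iter": 300}
toy_sse, toy_elbow_k = elbow_method(8, toy_features, **toy_kwargs)
toy_silhouettes, toy_silhouette_k = silhouette_coefficient(8, toy_features, **toy_kwargs)
print("Elbow method k:", toy_elbow_k, "Silhouette coefficient k:", toy_silhouette_k)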
Function bicluster
.
def bicluster(data, sample_names, features_names, sample_labels=None, cluster_sizes=(6,6),
display_graphs=True, using_genes=False):
# Perform biclustering
spectral_biclustering = SpectralBiclustering(n_clusters=cluster_sizes, method='log').fit(data)
try:
biclustered_data = data[np.argsort(spectral_biclustering.row_labels_)][:, np.argsort(spectral_biclustering.column_labels_)]
except:
# Convert data to correct format
data = np.array([list(array) for array in data])
biclustered_data = data[np.argsort(spectral_biclustering.row_labels_)][:, np.argsort(spectral_biclustering.column_labels_)]
# Identify the biclusters
biclustered_regions = np.outer(np.sort(spectral_biclustering.row_labels_) + 1, np.sort(spectral_biclustering.column_labels_) + 1)
# Update the order of cell names, their labels and genes/features as biclustering rearranges the matrix order
updated_names = np.array(sample_names)[np.argsort(spectral_biclustering.row_labels_)]
updated_labels = list(np.array(sample_labels)[np.argsort(spectral_biclustering.row_labels_)]) if sample_labels else None
updated_features = np.array(features_names)[np.argsort(spectral_biclustering.column_labels_)]
if display_graphs:
# Plot heatmaps of original data and biclusters
fig, ax = plot.subplots(2, 2, figsize=(22, 18))
plot.subplots_adjust(wspace=0.3, hspace=0.4)
sns.heatmap(pandas.DataFrame(data, index=sample_names, columns=features_names), ax=ax[0, 0], cmap="magma")
sns.heatmap(pandas.DataFrame(biclustered_data, index=updated_names, columns=updated_features), ax=ax[1, 0], cmap="magma")
sns.heatmap(pandas.DataFrame(biclustered_regions, index=updated_names, columns=updated_features), ax=ax[1, 1], cmap="magma")
if using_genes:
ax[0, 0].set_title('Gene Expression of Cells', fontsize = 17)
else:
ax[0, 0].set_title('Feature Expression of Cells', fontsize = 17)
ax[1, 0].set_title('Biclustered', fontsize = 17)
ax[1, 1].set_title('Bicluster Regions', fontsize = 17)
for axis in ax.flatten():
axis.set_xlabel('Genes', fontsize = 15) if using_genes else axis.set_xlabel('Features', fontsize = 15)
axis.set_ylabel('Cells', fontsize = 15)
# Hide second figure as not plotted
ax[0, 1].axis('off')
plot.show()
return biclustered_data, updated_names, updated_labels
Function plot_cells
to plot either a 2D or 3D graph of all the cells and the cells selected for testing.
def plot_cells(data, sample_names, sample_labels,
test_sample_names=None, test_sample_labels=None,
reduction="PCA", components=2, standardize=True,
apply_tsne=False, tsne_components=2, tsne_perplexity=30,
graph_title="All Cells and Test Cells", graph_type="2D", graph_text="", test_graph_text="",
colours=None, test_colours=None, save_to_file=False, file_name="cluster_groups.html"):
# Prevent error if not using enough components
if components < 3 and graph_type.lower() == "3d":
components = 3
reducer, data = reduce_data(reduction, components, data, apply_tsne=apply_tsne,
tsne_components=tsne_components, standardize=standardize,
graph_type=graph_type)
testing_data = None
if test_sample_labels:
# Find all the testing samples in the processed data
test_indexes = [np.where(sample_names == sample) for sample in test_sample_names]
testing_data = np.array([data[idx[0][0]] for idx in test_indexes])
# Figure parameters
figure_kwargs = {'paper_bgcolor': 'rgba(0,0,0,0)', 'plot_bgcolor': 'rgba(0,0,0,0)', 'height': 600, 'width': 800}
trace_kwargs = {'mode': "markers"}
# Set graph colours
colorscale = None
test_colorscale = None
if colours != None:
try:
# Set custom colours of true clusters
if sample_labels != None:
colorscale = list(colours.values())[0:len(np.unique(sample_labels))]
else:
colorscale = list(colours.values())[0:2]
except:
print("ERROR: Colours for the cells must be hex\nFor example: {0: '#E52592', 1: '#84E51F', 2: '#12B5CB'}")
if test_colours != None:
try:
# Set custom colours of true clusters
if test_sample_labels != None:
test_colorscale = list(test_colours.values())[0:len(np.unique(test_sample_labels))]
else:
test_colorscale = list(test_colours.values())[0:2]
except:
print("ERROR: Colours for the test cells must be hex\nFor example: {0: '#E52592', 1: '#84E51F', 2: '#12B5CB'}")
# Check labels have been set
if sample_labels != None:
fig = go.Figure(layout=Layout(title=graph_title, **figure_kwargs))
# Add graph of actual clusters to plot
if graph_type.lower() != "3d":
marker_size = 7
fig.add_trace(go.Scatter(x=data[:, 0], y=data[:, 1], opacity=0.65, marker_symbol="diamond-open",
**trace_kwargs, name="All Cells", text=graph_text,
marker=dict(color=sample_labels, colorscale=colorscale,
size=marker_size)))
if isinstance(testing_data, np.ndarray):
fig.add_trace(go.Scatter(x=testing_data[:, 0], y=testing_data[:, 1],
**trace_kwargs, name="Test Cells", visible="legendonly",
text=test_graph_text,
marker=dict(color=test_sample_labels, colorscale=test_colorscale,
size=marker_size)))
else:
marker_size = 5
fig.add_trace(go.Scatter3d(x=data[:, 0], y=data[:, 1], z=data[:, 2], opacity=0.55, marker_symbol="diamond-open",
**trace_kwargs, name="All Cells", text=graph_text,
marker=dict(color=sample_labels, colorscale=colorscale, size=marker_size)))
if isinstance(testing_data, np.ndarray):
fig.add_trace(go.Scatter3d(x=testing_data[:, 0], y=testing_data[:, 1], z=testing_data[:, 2],
**trace_kwargs, name="Test Cells", visible="legendonly", text=test_graph_text,
marker=dict(color=test_sample_labels, colorscale=test_colorscale, size=marker_size)))
else:
print("ERROR: Labels not set")
return
# Display graph
axis_names = {PCA: 'PC', FastICA: 'IC', NMF: 'NMF Basis Component', TSNE: 't-SNE'}
fig.update_xaxes(showline=True, linecolor='black', mirror=True, title=str(axis_names[type(reducer)]) + " 0")
fig.update_yaxes(showline=True, linecolor='black', mirror=True, title=str(axis_names[type(reducer)]) + " 1")
if save_to_file:
# Save as HTML file
fig.write_html(file_name)
# Check if using Jupyter notebook
if get_ipython().__class__.__name__ == 'ZMQInteractiveShell':
pyo.iplot(fig)
# Otherwise assume using Colab
else:
fig.show(renderer="colab")
Functions cluster_accuracy
, cluster_metrics
, reduce_data
and plot_clusters
.
cluster_accuracy
will match the expected labels and predicted clusters using the Hungarian matching algorithm, then calculate accuracy between them. This will be returned, along with the mapping between the labels and clusters. cluster_metrics
will return this accuracy, along with ARI and difference between labels and clusters.
def cluster_accuracy(actual_clusters, pred_clusters):
"""
Calculate the accuracy and the best match to the true labels of clusters
Input: actual and predicted cluster labels
Return: highest accuracy and mapping between equivalent labels
"""
n_clusters = max(actual_clusters) + 1
n_samples = len(pred_clusters)
# Construct a bipartite graph (represented by an adjacency matrix)
bipartite = np.zeros((n_clusters, n_clusters))
for i in range(n_samples):
# Add 1 for every intersection between rows and columns
bipartite[pred_clusters[i], actual_clusters[i]] += 1
# Find the best mapping between cluster numbers and label numbers using the Hungarian matching algorithm
label_map = linear_sum_assignment(bipartite.max() - bipartite)
label_map = np.transpose(np.asarray(label_map))
accuracy = sum([bipartite[i, j] for i, j in label_map]) / n_samples
label_dict = {label[0]: label[1:][0] for label in label_map}
return accuracy, label_dict
def cluster_metrics(actual_labels, predicted_labels):
"""
Calculate accuracy, ARI and label difference between ideal and actual clusters
Input: labels of the actual and predicted clusters
Return: accuracy, ARI and binary list of correct/incorrect clusters
"""
accuracy, label_dict = cluster_accuracy(actual_labels, predicted_labels)
ARI = adjusted_rand_score(actual_labels, predicted_labels)
# Map the predicted labels so that their representation matches the actual labels
updated_labels = list(map(label_dict.get, predicted_labels))
label_difference = []
for i in range(len(updated_labels)):
# Add 0 if actual and predicted labels match
if updated_labels[i] == actual_labels[i]:
label_difference.append(0)
# Add 1 if labels do not match
else:
label_difference.append(1)
return accuracy, ARI, label_difference
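To illustrate how the Hungarian matching works, here is a small hypothetical example in which the predicted cluster numbers are a permutation of the true labels: the raw values disagree, but after matching only one of the nine samples is counted as wrong.
# Toy labels for illustration only
toy_actual = [0, 0, 0, 1, 1, 1, 2, 2, 2]
toy_predicted = [2, 2, 2, 0, 0, 1, 1, 1, 1]
toy_accuracy, toy_label_map = cluster_accuracy(toy_actual, toy_predicted)
print("Mapping from predicted to actual labels:", toy_label_map)
toy_accuracy, toy_ARI, toy_difference = cluster_metrics(toy_actual, toy_predicted)
print("Accuracy:", np.round(toy_accuracy, 3), "ARI:", np.round(toy_ARI, 3), "Differences:", toy_difference)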
reduce_data
will perform dimensionality reduction and optionally apply t-SNE and standardization.
def reduce_data(reduction, components, data, standardize=True, apply_tsne=False,
tsne_components=2, tsne_perplexity=30, graph_type="2D"):
if reduction:
# Independent Component Analysis (ICA)
if reduction[0].lower() == "i":
reducer_algorithm = FastICA(n_components=components)
# Non-Negative Matrix Factorization (NMF)
elif reduction[0].lower() == "n":
reducer_algorithm = NMF(n_components=components)
# Find the lowest value
smallest_value = min([min(x) for x in data])
# If negative values present, make all values in data at least 0
if smallest_value < 0:
data = [[x - smallest_value for x in sample] for sample in data]
else:
# Default to PCA
reducer_algorithm = PCA(n_components=components)
# Reduce dimensions of the data to top components
data = reducer_algorithm.fit_transform(data)
else:
# Otherwise use PCA for plotting in 2D / 3D but do not change the data
reducer_algorithm = PCA(n_components=components).fit(data)
# If less components were given than features of the data, select the first n features
if components < len(data[0]):
data = [sample[:components] for sample in data]
# Optionally apply t-SNE
if apply_tsne:
if graph_type.lower() != "3d":
if tsne_components < 2 or tsne_components > 3:
# Default as 2 to prevent errors
tsne_components = 2
tsne = TSNE(n_components=tsne_components, perplexity=tsne_perplexity)
else:
# Use three for 3D graphs to prevent errors
tsne = TSNE(n_components=3)
# Use t-SNE
data = tsne.fit_transform(data)
# Scale the features
if standardize:
data = scaler.fit_transform(data)
return reducer_algorithm, data
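A brief standalone sketch of reduce_data on random data (the toy variables are hypothetical). Note that standardize=True relies on the global StandardScaler instance scaler, which is only created further down the notebook, so standardization is turned off here.
# Hypothetical random data: 50 samples with 30 features
toy_data = np.random.default_rng(0).normal(size=(50, 30))
toy_reducer, toy_reduced = reduce_data("PCA", 3, toy_data, standardize=False)
print(type(toy_reducer).__name__, np.shape(toy_reduced))
# NMF cannot handle negative values, so reduce_data first shifts the data to be non-negative
toy_reducer, toy_reduced = reduce_data("NMF", 3, toy_data, standardize=False)
print(type(toy_reducer).__name__, np.shape(toy_reduced))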
plot_clusters
performs dimensionality reduction, applies clustering, then plots either a 2D or 3D scatter graph of the clusters.
def plot_clusters(data, sample_names, clusterer, reduction="PCA", components=2,
labels=None, apply_tsne=False, tsne_components=2, tsne_perplexity=30,
standardize=True, graph_title="", graph_type="2D", graph_text="",
colours=None, correct_colours=None, wrong_colours=None,
save_to_file=False, file_name="cluster_graph.html"):
""" Plot a 2D or 3D graph of clusters
Parameters:
data: encoded or unencoded samples
sample_names: the names, e.g. [cell_01, cell_02, ...]
reduction: reduction technique to reduce dimensions, can be either PCA, ICA or NMF
clusterer: the clustering algorithm such as KMeans, AgglomerativeClustering
or GaussianMixture
components: number of components to use for the clustering algorithm
apply_tsne: whether to use t-SNE in addition to another reduction technique,
such as PCA, before clustering
tsne_components: components if using t-SNE, this can be either 2 or 3
standardize: whether to apply standardization before clustering
graph_title: title to display above graph
graph_type: 2D or 3D
graph_text: custom text displayed on graph when cursor hovering over sample
labels: true labels of the clusters
colours: dictionary mapping label indexes to hex colours for predicted clusters
correct_colours: dictionary mapping label indexes to hex colours for true
clusters
wrong_colours: dictionary mapping 0 for false predicted colour in hex and 1
for correct
save_to_file: set as True to save the graph as an HTML file
file_name: the name of the HTML file to save to
"""
# Prevent error if not using enough components
if components < 3 and graph_type.lower() == "3d":
components = 3
reducer, data = reduce_data(reduction, components, data, apply_tsne=apply_tsne,
tsne_components=tsne_components, tsne_perplexity=tsne_perplexity,
standardize=standardize, graph_type=graph_type)
# Apply clustering
try:
clusterer.fit(data)
clusterer.fit_predict(data)
if isinstance(clusterer, GaussianMixture):
predicted_labels = clusterer.predict(data)
else:
predicted_labels = clusterer.labels_
except:
# Throw error if clustering algorithm is unexpected
raise Exception("Only sci-kit learn clustering algorithms are supported!\nFor example, set clusterer=KMeans(n_clusters=5)")
# Figure parameters
figure_kwargs = {'paper_bgcolor': 'rgba(0,0,0,0)', 'plot_bgcolor': 'rgba(0,0,0,0)', 'height': 600, 'width': 800}
trace_kwargs = {'mode': "markers", 'text': graph_text if graph_text != "" else sample_names}
# Set graph colours
colorscale = None
correct_colorscale = None
wrong_colorscale = list({0: "#b3b3b3", 1: "#ff0000"}.values())[0:2] # Grey and red default
if colours != None:
try:
# Set custom colours for predicted clusters
colorscale = list(colours.values())[0:cluster_number]
except:
print("ERROR: Predicted colours must be hex\nFor example: {0: '#E52592', 1: '#84E51F', 2: '#12B5CB'}")
if correct_colours != None:
try:
# Set custom colours of true clusters
correct_colorscale = list(correct_colours.values())[0:cluster_number]
except:
print("ERROR: Colours for true clusters must be hex\nFor example: {0: '#E52592', 1: '#84E51F', 2: '#12B5CB'}")
if wrong_colours != None:
try:
# Set custom colours of incorrectly labelled clusters
wrong_colorscale = list(wrong_colours.values())[0:cluster_number]
except:
print("ERROR: Colours for incorrect clusters must be hex\nFor example: {0: '#E52592', 1: '#84E51F'}")
# Create graph of predicted clusters
if graph_title == "":
clusterer_names = {KMeans: 'K-Means', AgglomerativeClustering: 'Hierarchical',
Birch: 'Birch', MiniBatchKMeans: 'Mini-Batch K-Means',
SpectralClustering: 'Spectral',
GaussianMixture: 'Gaussian Mixture'}
reducer_names = {PCA: 'PCA', FastICA: 'ICA', NMF: 'NMF', TSNE: 't-SNE'}
if reduction:
graph_title = str(clusterer_names[type(clusterer)]) + " Clustering on " + str(components) + " Components Using " + str(reducer_names[type(reducer)])
else:
graph_title = str(clusterer_names[type(clusterer)]) + " Clustering"
fig = go.Figure(layout=Layout(title=graph_title, **figure_kwargs))
if graph_type.lower() != "3d": # Create 2D graph
marker_size = 7
fig.add_trace(go.Scatter(x=data[:, 0], y=data[:, 1], **trace_kwargs, name="Predicted",
marker=dict(color=predicted_labels, colorscale=colorscale, size=marker_size)))
else: # Create 3D graph
marker_size = 5
fig.add_trace(go.Scatter3d(x=data[:, 0], y=data[:, 1], z=data[:, 2],
**trace_kwargs, name="Predicted",
marker=dict(color=predicted_labels, colorscale=colorscale, size=marker_size)))
# Check metadata has been set
if labels != None:
# Add graph of actual clusters to plot
if graph_type.lower() != "3d":
fig.add_trace(go.Scatter(x=data[:, 0], y=data[:, 1],
**trace_kwargs, name="Actual", visible="legendonly",
marker=dict(color=labels, colorscale=correct_colorscale,
size=marker_size)))
else:
fig.add_trace(go.Scatter3d(x=data[:, 0], y=data[:, 1], z=data[:, 2],
**trace_kwargs, name="Actual", visible="legendonly",
marker=dict(color=labels, colorscale=correct_colorscale, size=marker_size)))
# Calculate metrics
accuracy, ARI, label_difference = cluster_metrics(labels, predicted_labels)
# Set marker colours for correct / incorrect clustered cells
if accuracy == 1:
marker_colour = wrong_colorscale[0]
else:
marker_colour = label_difference
# Add graph showing difference between actual and predicted clusters
if graph_type.lower() != "3d":
fig.add_trace(go.Scatter(x=data[:, 0], y=data[:, 1],
**trace_kwargs, name="Difference", visible="legendonly",
marker=dict(color=marker_colour, colorscale=wrong_colorscale, size=marker_size)))
else:
fig.add_trace(go.Scatter3d(x=data[:, 0], y=data[:, 1], z=data[:, 2],
**trace_kwargs, name="Difference", visible="legendonly",
marker=dict(color=marker_colour, colorscale=wrong_colorscale, size=5)))
# Display graph
axis_names = {PCA: 'PC', FastICA: 'IC', NMF: 'NMF Basis Component', TSNE: 't-SNE'}
fig.update_xaxes(showline=True, linecolor='black', mirror=True, title=str(axis_names[type(reducer)]) + " 0")
fig.update_yaxes(showline=True, linecolor='black', mirror=True, title=str(axis_names[type(reducer)]) + " 1")
if save_to_file:
# Save as HTML file
fig.write_html(file_name)
# Check if using Jupyter notebook
if get_ipython().__class__.__name__ == 'ZMQInteractiveShell':
pyo.iplot(fig)
# Otherwise assume using Colab
else:
fig.show(renderer="colab")
# Print metrics
if labels != None:
print(colored("Accuracy: ", 'magenta', attrs=['bold']) + str(np.round(100 * accuracy, 3)) + "%")
print(colored("Adjusted Rand Index: ", 'cyan', attrs=['bold']) + str(np.round(ARI, 3)))
silhouette = silhouette_score(data, predicted_labels)
print(colored("Silhouette Coefficient: ", 'green', attrs=['bold']) + str(np.round(silhouette, 3)))
Functions find_best_components
, plot_components
and plot_variance
.
def find_best_components(max_components, data, labels, clusterer, reduction="PCA",
iterations=3, measure="accuracy", apply_tsne=False, tsne_perplexity=30):
"""
Find the best number of principal components (n) based on a given measure
Parameters:
max_components: the maximum number to check
data: samples and their features
labels: list of labels for corresponding samples
clusterer: the clustering model, e.g. KMeans
reduction: reduce dimensions using either PCA, ICA, NMF or t-SNE
iterations: number of times to run clustering for each component estimate, after which the mean is calculated
measure: either accuracy, ARI or silhouette coefficient
Return:
results: a dictionary containing the best number of components ('best n'),
the average score for each n ('scores'), the explained variance ratios
('variance', PCA only) and the extracted components ('components', not for t-SNE)
"""
if max_components > 150:
# Limit the search to 150 components, as PCA cannot use more components than samples
max_components = 150
if reduction[0].lower() == "t" and max_components > 3:
# t-SNE (with the default Barnes-Hut implementation) supports at most 3 components
max_components = 3
# Record average score (accuracy or ARI), for each n
component_scores = []
for n in range(2, max_components + 1):
# Record scores for each iteration
n_scores = []
for j in range(iterations):
# Independent Component Analysis (ICA)
if reduction[0].lower() == "i":
reducer = FastICA(n_components=n)
# Non-Negative Matrix Factorization (NMF)
elif reduction[0].lower() == "n":
reducer = NMF(n_components=n)
# Find the lowest value
smallest_value = min([min(x) for x in data])
# If negative values present, make all values in data at least 0
if smallest_value < 0:
data = [[x - smallest_value for x in sample] for sample in data]
# t-distributed Stochastic Neighbor Embedding (t-SNE)
elif reduction[0].lower() == "t":
reducer = TSNE(n_components=n, perplexity=tsne_perplexity)
# Principal Component Analysis (PCA)
else:
reducer = PCA(n_components=n)
# Perform reduction to find the n components
components = reducer.fit_transform(data)
# Optionally apply t-SNE before clustering
if apply_tsne:
tsne = TSNE(n_components=2)
# Use t-SNE
components = tsne.fit_transform(components)
scaled_components = scaler.fit_transform(components)
clusterer.fit(scaled_components)
# Record accuracy or ARI
if isinstance(clusterer, GaussianMixture):
predicted_labels = clusterer.predict(scaled_components)
else:
predicted_labels = clusterer.labels_
if measure.lower()[0] != "s":
accuracy, ARI, label_difference = cluster_metrics(labels, predicted_labels)
if measure.lower() == "ari":
n_scores.append(ARI)
else:
n_scores.append(accuracy)
else:
n_scores.append(silhouette_score(scaled_components, predicted_labels))
component_scores.append(np.mean(n_scores))
# Find number of components with highest score
best_n = np.argmax(component_scores) + 2
results = {'best n': best_n, 'scores': component_scores}
if isinstance(reducer, PCA):
# Get the varience contributions from each principal component
results['variance'] = reducer.explained_variance_ratio_
if not isinstance(reducer, TSNE):
# Return the components if using PCA, ICA or NMF
results['components'] = reducer.components_
return results
def plot_components(component_scores, measure="accuracy"):
""" Plot graph of accuracy / ARI / silhouette coefficient over number of components """
fig = plot.figure()
plot.plot([i for i in range(2, len(component_scores)+2)], component_scores, "lime")
best_n_index = np.argmax(component_scores)
plot.plot(best_n_index + 2, component_scores[best_n_index], marker="*", markersize=12, color="dodgerblue")
plot.title("Best Number of Components")
plot.xlabel('Number of Components')
if measure.lower() == "ari":
plot.ylabel('Adjusted Rand Index')
elif measure.lower()[0] == "s":
plot.ylabel('Silhouette Coefficient')
else:
plot.ylabel('Accuracy')
plot.xticks(np.arange(2,len(component_scores)+2,1))
plot.show()
def plot_variance(variance):
""" Plot variance contribution of each principal component """
fig, (ax1, ax2) = plot.subplots(1, 2, figsize=(12,4))
fig.suptitle("Contribution of Principal Component to Variance")
# Plot each component against the amount of variance it contributes
ax1.plot([i for i in range(0, len(variance))], variance, "slateblue")
ax1.set(xlabel='Principal Component', ylabel='Variance Ratio')
ax1.set_xticks(np.arange(0,len(variance),1))
# Plot the cumulative variance
ax2.plot([i for i in range(0, len(variance))], np.cumsum(variance), "mediumorchid")
ax2.set(xlabel='Principal Component', ylabel='Cumulative Variance Ratio')
ax2.set_xticks(np.arange(0,len(variance),1))
plot.show()
Read the original and encoded data, then create pre-standardized versions to be used during the experiments. This provides four different variations to test in the following experiments.
# Get cell names
test_cell_names = test_dataloader.dataset.dataset.columns
# Get gene names
gene_names = test_dataloader.dataset.dataset.index.values
# Get unencoded data
unstandardized_unencoded = np.transpose(test_dataloader.dataset.dataset.values)
# Scale features so that they have a mean of 0 and standard deviation 1
scaler = StandardScaler()
standardized_encoded = scaler.fit_transform(encoded_data)
standardized_unencoded = scaler.fit_transform(unstandardized_unencoded)
# Convert encoded values to dataframe
encoded_dataframe = pandas.DataFrame(np.transpose(standardized_encoded), columns=test_cell_names, index=["Feature_" + str(i) for i in range(0, len(standardized_encoded[0]))])
print(colored("Encoded Data:\n", attrs=['bold']))
encoded_dataframe.reindex(sorted(encoded_dataframe.columns), axis=1)
Encoded Data:
CELL_000010 | CELL_000020 | CELL_000025 | CELL_000027 | CELL_000030 | CELL_000032 | CELL_000033 | CELL_000060 | CELL_000061 | CELL_000071 | ... | CELL_000850 | CELL_000852 | CELL_000854 | CELL_000871 | CELL_000875 | CELL_000879 | CELL_000887 | CELL_000897 | CELL_000900 | CELL_000918 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Feature_0 | 0.234398 | 0.812075 | -1.667359 | 1.722621 | 1.300376 | 0.782870 | 1.243788 | 0.291369 | 0.608002 | 0.714236 | ... | 0.507372 | 0.436038 | 0.034836 | -0.730599 | -0.345844 | -0.486102 | 0.097766 | -0.194457 | -0.471935 | 0.462136 |
Feature_1 | 2.740618 | 2.563481 | -1.199272 | 1.463323 | 1.783733 | 1.355311 | 0.801223 | 1.603983 | -0.591048 | 0.309632 | ... | -0.683954 | 0.030278 | 0.902614 | 0.417851 | 0.346201 | -0.705821 | -0.317017 | -0.375020 | -0.511888 | 0.485226 |
Feature_2 | -1.423052 | -0.844285 | -1.220447 | 1.508093 | 2.178617 | 2.178892 | 0.946271 | 1.996534 | 0.610028 | 1.640319 | ... | 0.311965 | -0.100322 | -0.234494 | -0.690127 | -0.878440 | 0.374460 | -0.748471 | -0.499533 | -0.161664 | 0.179707 |
Feature_3 | -3.182918 | -1.565403 | 0.347536 | -2.309087 | -2.190185 | -2.364824 | -1.521195 | -1.030488 | -0.028608 | -1.579000 | ... | 0.272805 | 1.004740 | 0.670810 | -0.584705 | 0.318752 | -0.053887 | 0.561160 | 0.075975 | 0.325324 | -1.058655 |
Feature_4 | 0.357589 | 0.735483 | -2.464719 | -0.982008 | -1.412255 | -1.090757 | -1.123284 | -0.116053 | -0.964981 | -0.121380 | ... | 0.137528 | 0.133983 | 0.281143 | 1.171156 | 0.788480 | 1.036110 | 0.980019 | 1.166814 | 0.980964 | 1.702771 |
Feature_5 | -0.437209 | -0.588417 | 1.266392 | -0.719493 | 0.621131 | 0.467957 | 0.489806 | 0.771173 | -1.859422 | -0.392233 | ... | 0.683937 | -0.768002 | -1.180558 | 0.548222 | -0.269256 | 0.222907 | 0.693668 | 0.577967 | 0.497945 | -0.565712 |
Feature_6 | -2.570539 | -0.948594 | -1.999611 | 0.190231 | -0.743908 | -0.537724 | -0.419772 | -0.829934 | 1.427763 | -0.727194 | ... | -0.876794 | 0.310663 | 0.315140 | 0.944080 | 0.309068 | 0.365828 | 0.186795 | 0.829618 | 0.739259 | 0.614755 |
Feature_7 | -0.584635 | 0.843677 | -1.908625 | -1.020752 | 1.094430 | 1.684888 | 0.163533 | 1.132958 | -1.326178 | 0.798913 | ... | 0.284485 | 0.227868 | -0.384249 | -0.401133 | -0.125269 | -0.257755 | 0.333989 | 0.347562 | 0.500798 | 1.065853 |
Feature_8 | 0.446670 | 1.116654 | -1.749891 | -1.343971 | -0.363419 | 0.505157 | -1.098744 | 0.691356 | -0.310842 | 2.217893 | ... | -0.663388 | -0.076282 | 0.610590 | 0.504271 | -0.746510 | -0.038530 | 0.028291 | 0.396468 | 0.390674 | 0.195753 |
Feature_9 | -0.815687 | -0.656062 | -0.708197 | -0.163578 | -0.536713 | -0.617265 | -0.705187 | -0.829711 | 1.393768 | -0.291901 | ... | -0.770000 | -0.789063 | -0.192432 | 0.739492 | 0.674174 | -0.068702 | 0.394031 | 0.477576 | 0.449476 | -0.518045 |
Feature_10 | -2.168331 | -1.599092 | 1.501471 | -1.489532 | -1.710723 | -1.964833 | -2.981893 | -1.445425 | 1.433631 | -2.909092 | ... | 0.419420 | 0.258892 | 0.310924 | 0.315592 | 0.396661 | 0.307815 | 0.409331 | 0.421551 | 0.314963 | 0.061838 |
Feature_11 | 0.034021 | 1.067638 | 1.241142 | -0.922199 | -0.225340 | 0.432108 | -0.089165 | -1.380536 | -1.846933 | -0.642088 | ... | 0.152886 | 0.352023 | 0.650518 | 0.399274 | -0.307777 | 0.794110 | 0.009613 | 0.151319 | -0.390006 | -0.224341 |
Feature_12 | 0.109958 | 0.987362 | -0.266863 | -1.824618 | 2.476899 | 0.319647 | 0.392270 | 1.468212 | 0.487338 | 1.610619 | ... | -0.044362 | -0.424481 | -0.379947 | -0.909291 | 0.126146 | -1.518282 | 0.028395 | -0.927453 | -0.806546 | 0.175452 |
Feature_13 | 3.100341 | 3.079734 | 1.597680 | 2.131201 | 2.827287 | 2.695341 | 2.146547 | 2.767436 | 1.667244 | 2.146611 | ... | -1.158078 | -1.090725 | -0.947854 | -1.376621 | -1.180690 | -1.018069 | -1.327804 | -1.391233 | -1.438953 | -1.231664 |
Feature_14 | 1.010896 | 1.369850 | -0.533248 | 1.320653 | 0.425064 | 1.625886 | 2.003527 | -1.310808 | 0.316317 | 0.681729 | ... | -0.702407 | -0.116137 | -0.418362 | -0.424331 | -0.078280 | -2.160142 | 0.302622 | 0.174060 | -0.024346 | 0.120345 |
Feature_15 | -1.007947 | -1.010822 | 0.146622 | -0.575521 | 1.526624 | 0.449822 | 0.709557 | 1.855954 | -0.141362 | -0.478407 | ... | -0.884952 | 0.293472 | -0.647150 | -0.003909 | 0.393356 | -0.894180 | 1.087355 | 0.331456 | 0.553880 | -0.436264 |
16 rows × 150 columns
Get the true labels for the cell lines / groups and convert to numerical representations. This will only run if metadata
has been set.
test_cell_labels = None
cell_graph_text = None
# Check metadata has been set
if isinstance(metadata, pandas.DataFrame):
# Find the true cluster labels for test cells
test_cell_lines = metadata.loc[test_cell_names]
# Create a map between labels and a numeric value
unique_labels = np.unique(test_cell_lines.values)
cluster_label_map = {unique_labels[i]: [j for j in range(len(unique_labels))][i] for i in range(len(unique_labels))}
print(colored('Mapping between cell lines and numerical value:', 'cyan', attrs=['bold']))
print(cluster_label_map)
print("\n")
test_cell_labels = [cluster_label_map.get(key[0]) for key in test_cell_lines.values]
print(colored('Numerical values of target clusters:', 'green', attrs=['bold']))
print(test_cell_labels)
print("\n")
display(test_cell_lines)
# Number of clusters
true_cluster_num = len(unique_labels)
# Create custom text for labelling cells on the graphs so that cell lines/groups are also displayed
cell_graph_text = [str(name) + ", " + str(label[0]) for name, label in zip(test_cell_names, test_cell_lines.values)]
Mapping between cell lines and numerical value: {'H1975': 0, 'H2228': 1, 'HCC827': 2} Numerical values of target clusters: [0, 0, 0, 1, 0, 1, 1, 2, 2, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 2, 1, 0, 0, 2, 1, 2, 1, 0, 2, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 1, 0, 2, 1, 0, 0, 0, 0, 1, 1, 1, 2, 1, 1, 0, 1, 0, 0, 2, 0, 2, 2, 0, 2, 0, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 0, 0, 2, 2, 1, 1, 2, 0, 2, 2, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 2, 0, 0, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 0, 0, 2, 0, 1, 0, 0, 0, 1, 1, 1, 0, 2, 2, 0, 0, 1, 1, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 2, 0, 2, 1, 0, 0, 0, 0]
cell_line | |
---|---|
CELL_000572 | H1975 |
CELL_000233 | H1975 |
CELL_000098 | H1975 |
CELL_000200 | H2228 |
CELL_000677 | H1975 |
... | ... |
CELL_000765 | H2228 |
CELL_000404 | H1975 |
CELL_000305 | H1975 |
CELL_000730 | H1975 |
CELL_000082 | H1975 |
150 rows × 1 columns
Produce an encoded version of every sample in the dataset
# Create a DataLoader containing every sample
all_data = DatasetRNASeq(dataset)
all_dataloader = DataLoader(all_data)
all_encoded = []
all_unencoded = dataset.transpose().values
# Produce an encoding
for batch in all_dataloader:
encoded = enc(batch.cuda() if torch.cuda.is_available() else batch)
for sample in encoded:
# Convert to numpy array for clustering
all_encoded.append(sample.cpu().detach().numpy())
# Create standardized version of data
all_standardized_encoded = scaler.fit_transform(all_encoded)
all_standardized_unencoded = scaler.fit_transform(all_unencoded)
all_cell_labels = None
if test_cell_labels != None:
all_labels = metadata.values
# Encode the cell line / group labels
all_unique_labels = np.unique(all_labels)
all_label_map = {all_unique_labels[i]: [j for j in range(len(all_unique_labels))][i] for i in range(len(all_unique_labels))}
all_cell_labels = [all_label_map.get(key[0]) for key in all_labels]
all_cell_names = dataset.columns.values
all_cell_graph_text = [str(name) + ", " + str(label[0]) for name, label in zip(all_cell_names, all_labels)]
Specify the colours to use for the clusters.
# Standard cluster colours
cluster_colours = {0: "#E52592", 1: "#84E51F", 2: "#12B5CB", 3: "#9334E6",
4: "#F9AB00", 5: "#0500FF", 6: "#FFE300", 7: "#0E8602",
8: "#fd91ff", 9: "#ff4d00"}
# Alternative colours
cluster_colours1 = {0: "#2655ff", 1: "#16d900", 2: "#7729ff", 3: "#51ede8"} # Blue, green, purple, orange
cluster_colours2 = {0: "#d102e3", 1: "#94e30b", 2: "#4a07a3", 3: "#7affad", 4: "#c37aff"} # Pink, light green, dark purple, light blue, violet
cluster_colours3 = {0: "#ff0000", 1: "#ffb300", 2: "#95ff00", 3: "#00ff2f",
4: "#00ffbb", 5: "#0091ff", 6: "#001aff", 7: "#8000ff",
8: "#ea00ff"} # Rainbow
actual_cluster_colours = cluster_colours1
if dataset.equals(cortex_dataset):
actual_cluster_colours = cluster_colours3
# Correct and incorrectly labelled clusters colours
false_colours = {0: "#b3b3b3", 1: "#ff0000"}
Plot the cells selected for testing against all cells. If labels were given for the cell lines / groups, these are coloured.
if test_cell_labels:
plot_cells(all_unencoded, all_cell_names, all_cell_labels,
test_cell_names, test_cell_labels,
graph_text=all_cell_graph_text, test_graph_text=cell_graph_text,
reduction="pca", components=2, apply_tsne=False,
colours=actual_cluster_colours, test_colours=actual_cluster_colours,
graph_type="3D", graph_title="Plot of All and Test Cells",
save_to_file=True, file_name="groups_plot.html")
For clustering algorithms such as k-means and hierarchical clustering, the number of clusters must be specified beforehand. For k-means, the elbow method and the silhouette coefficient are two methods that can be used to find the optimal number, while for agglomerative hierarchical clustering, the best number can be determined by plotting a dendrogram.
Run the elbow method and silhouette coefficient for k-means:
max_k
determines the maximum number of clusters to check.
max_k = 10
# K-means arguments
k_means_kwards = {"init": "k-means++", "n_init": 20, "max_iter": 300}
# Find the best value for k
sse, elbow_k = elbow_method(max_k, standardized_encoded, **k_means_kwards)
silhouette_coefficients, silhouette_k = silhouette_coefficient(max_k, standardized_encoded, **k_means_kwards)
print(colored("Elbow Method Best k: ", 'green', attrs=['bold']) + str(elbow_k))
print(colored("Silhouette Coefficient Best k: ", 'cyan', attrs=['bold']) + str(silhouette_k))
# Set number of clusters
if isinstance(metadata, pandas.DataFrame):
# Use true number of labels for known datasets
cluster_number = true_cluster_num
else:
# Alternatively, use elbow method k or silhouette k
cluster_number = elbow_k
Elbow Method Best k: 3 Silhouette Coefficient Best k: 5
Plot a dendrogram for agglomerative hierarchical clustering:
# Plot dendrogram
plot.figure(figsize=(10, 6), dpi=80)
plot.title("Dendrogram")
plot.xlabel('Cell')
plot.ylabel('Distance')
scipy.cluster.hierarchy.set_link_color_palette(['cyan', 'magenta', 'lime', 'purple', 'orange'])
dend = dendrogram(linkage(standardized_encoded, method='ward'))
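The dendrogram only suggests a cut point visually. If cluster assignments are wanted directly from the same linkage, scipy's fcluster can cut the tree into a fixed number of groups; this is just an optional sketch (fcluster is not used elsewhere in this notebook), assuming three clusters.
from scipy.cluster.hierarchy import fcluster
# Cut the same Ward linkage into three flat clusters
linkage_matrix = linkage(standardized_encoded, method='ward')
hierarchy_labels = fcluster(linkage_matrix, t=3, criterion='maxclust')
print("Cluster sizes:", np.unique(hierarchy_labels, return_counts=True)[1])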
Set cluster_number
as either:
true_cluster_num
if labels are given
elbow_k
to use the elbow method's recommendation
silhouette_k
to use the silhouette coefficient's recommendation
cluster_number = 5
# Set number of clusters
if isinstance(metadata, pandas.DataFrame):
# Use true number of labels for known datasets
cluster_number = true_cluster_num
else:
# Alternatively, use elbow method k, silhouette k or number from dendrogram
cluster_number = elbow_k
Set algorithm
as the clustering algorithm to run. This can be either:
kmeans
: K-means
hierarchical
: Agglomerative hierarchical clustering
birch
: BIRCH
minibatch
: Mini batch k-means
spectral
: Spectral clustering
gaussian
: Gaussian Mixture
Four variations of the data can be tested by changing these parameters:
use_standardized
specifies whether to use the pre-standardized data. This is the version in which standardization is applied to the data before the dimensionality reduction technique specified below.
use_encoded
determines whether to perform the experiment on the encoded version of the data produced by the autoencoder, or on the original data.
A dimensionality reduction technique can be applied to extract features that best represent variation in the data in a lower-dimensional space. Variable reduction_technique
can be set as one of the following:
PCA
: Principal Component Analysis
ICA
: Independent Component Analysis
NMF
: Non-Negative Matrix Factorization
None
or ""
: Don't apply a dimensionality reduction technique
t-distributed Stochastic Neighbor Embedding can be applied before clustering by setting use_tsne = True
. The parameter perplexity
can be changed, though ideally this value should be between 5 and 50 and less than the number of cells.
To perform spectral biclustering on the data before running standard clustering, set run_biclustering = True
. This is a form of clustering that will cluster both the genes/encoded features and the cell samples.
# Set the clustering algorithm as a key from clustering_algorithms, e.g. "birch"
algorithm = "kmeans"
# Set parameters for data pre-processing before applying clustering
use_standardized = False
use_encoded = False
reduction_technique = "pca"
use_tsne = False
perplexity = 30
run_biclustering = False
# Dictionary of supported clustering algorithms
clustering_algorithms = {"kmeans": KMeans(n_clusters=cluster_number, **k_means_kwards),
"hierarchical": AgglomerativeClustering(n_clusters=cluster_number, affinity='euclidean', linkage='ward'),
"birch": Birch(threshold=0.01, n_clusters=cluster_number),
"minibatch": MiniBatchKMeans(n_clusters=cluster_number),
"spectral": SpectralClustering(n_clusters=cluster_number),
"gaussian": GaussianMixture(n_components=cluster_number)}
# These parameters are set automatically
if use_encoded:
selected_data = standardized_encoded if use_standardized else encoded_data
else:
selected_data = standardized_unencoded if use_standardized else unstandardized_unencoded
algorithm = algorithm.translate(str.maketrans('', '', string.punctuation)).lower()
# Set the clustering algorithm from the dictionary
selected_algorithm = clustering_algorithms.get(algorithm)
if not selected_algorithm:
# Default to k-means
selected_algorithm = clustering_algorithms.get("kmeans")
print("Desired algorithm not found, defaulting to k-means")
else:
print("Algorithm set as " + str(type(selected_algorithm).__name__))
selected_labels = test_cell_labels
selected_cell_names = test_cell_names
selected_feature_names = [i for i in range(len(selected_data[0]))] if use_encoded else gene_names
if run_biclustering:
print("Running Spectral Biclustering\n")
selected_data, selected_cell_names, selected_labels = bicluster(selected_data, selected_cell_names,
selected_feature_names, selected_labels,
cluster_sizes=(6, 6), using_genes=(not use_encoded),
display_graphs=True)
Algorithm set as KMeans
Now create 2D and 3D interactive graphs plotted against 2 / 3 components extracted by the dimensionality reduction technique.
Click on the legend labels to toggle which clusters to display:
# Plot 2 principal components
plot_clusters(selected_data, selected_cell_names, selected_algorithm, reduction=reduction_technique, labels=selected_labels,
graph_text=cell_graph_text, colours=cluster_colours, correct_colours=actual_cluster_colours,
graph_type="2D", components=2, apply_tsne=use_tsne, tsne_perplexity=perplexity,
save_to_file=True, file_name="cluster_2d.html")
# Plot 3 principal components
plot_clusters(selected_data, selected_cell_names, selected_algorithm, reduction=reduction_technique, labels=selected_labels,
graph_text=cell_graph_text, colours=cluster_colours, correct_colours=actual_cluster_colours,
graph_type="3D", components=3, apply_tsne=use_tsne, tsne_perplexity=perplexity,
save_to_file=True, file_name="cluster_3d.html")
Accuracy: 64.0% Adjusted Rand Index: 0.387 Silhouette Coefficient: 0.489
Accuracy: 99.333% Adjusted Rand Index: 0.98 Silhouette Coefficient: 0.478
Experiment using higher numbers of components for the dimensionality reduction technique specified as reduction_technique
to see the effect on either accuracy, ARI or silhouette coefficient.
Set test_measure
as one of the three metrics: accuracy
, ARI
or silhouette coefficient.
test_measure = "accuracy"
if selected_labels != None:
maximum_components = len(selected_data[0]) if use_encoded else min(20, len(selected_data[0]))
# Find number of components with the highest accuracy
component_results = find_best_components(maximum_components, selected_data, selected_labels, selected_algorithm,
apply_tsne=use_tsne, tsne_perplexity=perplexity, reduction=reduction_technique,
iterations=5, measure=test_measure)
component_accuracy = component_results.get('scores')
plot_components(component_accuracy, measure=test_measure)
print("\n")
display(pandas.DataFrame(component_accuracy, columns=[test_measure[0].upper() + test_measure[1:]], index=[i for i in range(2, len(component_accuracy)+2)]).transpose())
2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Accuracy | 0.64 | 0.993333 | 0.993333 | 0.993333 | 0.993333 | 0.990667 | 0.990667 | 0.992 | 0.985333 | 0.992 | 0.974667 | 0.908 | 0.892 | 0.814667 | 0.786667 | 0.84 | 0.641333 | 0.741333 | 0.564 |
Plot a 2D and 3D graph to view the clustering predictions when using the best number of components according to the metric set above. Note results may differ between the two plots as k-means is non-deterministic and is run twice.
if selected_labels != None:
best_n_components = component_results.get('best n')
# Plot onto 2D graph
plot_clusters(selected_data, test_cell_names, selected_algorithm, labels=selected_labels, reduction=reduction_technique,
graph_text=cell_graph_text, apply_tsne=use_tsne, tsne_perplexity=perplexity,
colours=cluster_colours, correct_colours=actual_cluster_colours,
graph_type="2D", components=best_n_components,
save_to_file=True, file_name="cluster_best_2d.html")
if best_n_components > 2:
# Plot onto 3D graph if more than 2 components
plot_clusters(selected_data, test_cell_names, selected_algorithm, labels=selected_labels, reduction=reduction_technique,
graph_text=cell_graph_text, apply_tsne=use_tsne, tsne_perplexity=perplexity,
colours=cluster_colours, correct_colours=actual_cluster_colours,
graph_type="3D", components=best_n_components,
save_to_file=True, file_name="cluster_best_3d.html")
Accuracy: 99.333% Adjusted Rand Index: 0.98 Silhouette Coefficient: 0.478
Accuracy: 99.333% Adjusted Rand Index: 0.98 Silhouette Coefficient: 0.478
If using PCA, view how much the top principal components contribute to the variance of the data.
if selected_labels != None and "variance" in component_results:
# Plot ratio of variance contribution for each principal component
component_variance = component_results.get('variance')
plot_variance(component_variance)
print("\n")
display(pandas.DataFrame(component_variance, columns=['Variance Contribution'], index=[i for i in range(2, len(component_variance)+2)]).transpose())
2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Variance Contribution | 0.412039 | 0.158785 | 0.101232 | 0.052688 | 0.029578 | 0.027386 | 0.022385 | 0.019398 | 0.017894 | 0.013528 | 0.013175 | 0.010506 | 0.008801 | 0.006881 | 0.006653 | 0.005996 | 0.005814 | 0.00486 | 0.004518 | 0.004198 |
If run on the original data, visualise the contribution of genes to each component.
use_encoded
must be set to False
before running this cell, so that selected_data
contains the original gene counts.
n_heatmap_genes = 20
n_top_genes = 10
if not use_encoded and selected_labels != None and "components" in component_results:
# Get the eigen vectors
top_components = component_results.get('components')
# Create a heat map graph
heat_map = pandas.DataFrame(top_components, columns=gene_names)
heat_map = heat_map.drop(columns=gene_names[n_heatmap_genes:])
plot.figure(figsize=(15,10))
sns.heatmap(heat_map, cmap="mako")
plot.title('Contribution of Genes to Components', fontsize = 20)
plot.xlabel('Gene', fontsize = 15)
plot.ylabel('Component', fontsize = 15)
plot.show()
print("\n")
# Create a DataFrame containing the eigenvectors of the first top three components
eigenvectors = pandas.DataFrame(data={'Gene': gene_names,
'Component 1': list(map(abs, top_components[0])),
'Component 2': list(map(abs, top_components[1])),
'Component 3': list(map(abs, top_components[2]))})
# Sort by loading score of first component
eigenvectors = eigenvectors.sort_values(by=['Component 1'], ascending=False)
# Make DataFrame in suitable form for plotting a bar chart with multiple columns
vector_melt = pandas.melt(eigenvectors[:n_top_genes], id_vars="Gene", var_name="Components", value_name="Loading Score (Magnitude)")
# Plot the bar chart
fig = plot.figure(figsize=(19, 5))
plot.title("Highest Gene Contribution of Top 3 Components")
#sns.barplot(x=eigenvectors['Gene'].values[:n_top_genes], y=eigenvectors['Component 1'].values[:n_top_genes], palette="magma")
sns.barplot(x="Gene", y="Loading Score (Magnitude)", hue="Components", data=vector_melt, palette="viridis")
plot.show()
print("\n")
print(colored('Percentage of genes contributing to component 1:', 'cyan', attrs=['bold']), str(np.round((eigenvectors[eigenvectors['Component 1'] > 0].count().values[0] / len(gene_names) * 100), 3)) + "%")
print(colored('Percentage of genes contributing to component 2:', 'cyan', attrs=['bold']), str(np.round((eigenvectors[eigenvectors['Component 2'] > 0].count().values[0] / len(gene_names) * 100), 3)) + "%")
print(colored('Percentage of genes contributing to component 3:', 'cyan', attrs=['bold']), str(np.round((eigenvectors[eigenvectors['Component 3'] > 0].count().values[0] / len(gene_names) * 100), 3)) + "%")
Percentage of genes contributing to component 1: 99.484% Percentage of genes contributing to component 2: 99.484% Percentage of genes contributing to component 3: 99.484%
Experiment with six different clustering algorithms and parameter set-ups to determine the best overall performance.
run_alternatives
is set as True
as this can be slow.metric
can be set as either accuracy
or silhouette coefficient
.
Set dimension_reduction
as the dimensionality reduction technique to use:
PCA
: Principal Component Analysis
ICA
: Independent Component Analysis
NMF
: Non-Negative Matrix Factorization
test_iterations
is the number of times to test each experiment set up.
run_tsne
specifies whether to apply t-SNE after dimensionality reduction and run_tsne_perplexity
is the parameter perplexity that can be set if using t-SNE.
use_all_data
determines whether to run the experiments on all cells in the dataset.
False
will run the experiments on just the test samples.
run_alternatives = True
metric = "accuracy"
dimension_reduction = "pca"
test_iterations = 3
run_tsne = False
run_tsne_perplexity = 30
use_all_data = True
Run experiments:
# List of scikit learn clustering algorithms to try
clustering_algorithms = [KMeans(n_clusters=cluster_number, **k_means_kwards),
AgglomerativeClustering(n_clusters=cluster_number, affinity='euclidean', linkage='ward'),
Birch(threshold=0.01, n_clusters=cluster_number),
MiniBatchKMeans(n_clusters=cluster_number),
SpectralClustering(n_clusters=cluster_number),
GaussianMixture(n_components=cluster_number)]
# Clustering algorithm names
algorithm_names = ["K-Means", "Hierarchical", "Birch", "Mini Batch K-Means",
"Spectral Clustering", "Gaussian Mixture"]
# Use every sample in the dataset
if use_all_data:
# Test data before and after encoding, with and without standardization
data_types = [all_unencoded, all_standardized_unencoded, all_encoded, all_standardized_encoded]
# Use only testing samples
else:
data_types = [unstandardized_unencoded, standardized_unencoded, encoded_data, standardized_encoded]
if test_cell_labels == None:
# Prevent accuracy being used if no labels available
metric = "silhouette"
# Set proper name for metric
metric = "Silhouette Coefficient" if metric[0].lower() == "s" else "Accuracy"
cell_labels = all_cell_labels if use_all_data else test_cell_labels
if run_alternatives:
# Record the highest accuracy or silhouette coefficient of each clustering algorithm
algorithms_best_score = [0 for i in range(len(clustering_algorithms))]
# Record the best number of components for each algorithm
algorithms_best_n = [2 for i in range(len(clustering_algorithms))]
# Record if the encoded data outperformed the unencoded data
algorithms_use_encoded = [False for i in range(len(clustering_algorithms))]
# Record if standardization outperformed unstandardized
algorithms_use_standardized = [False for i in range(len(clustering_algorithms))]
with warnings.catch_warnings():
# Prevent printing convergence warnings
warnings.simplefilter("ignore")
# Iterate through each combination of data
for data_idx, data_combo in enumerate(data_types):
# Set the maximum number of components to check
max_n = len(data_combo[0]) if data_idx >= 2 else 20
# Test different numbers of components
for n in range(2, max_n):
# Perform dimensionality reduction
reducer, scaled_components = reduce_data(dimension_reduction, n, data_combo, apply_tsne=run_tsne,
tsne_perplexity=run_tsne_perplexity)
# Test each algorithm over several attempts
for algo_idx, algorithm in enumerate(clustering_algorithms):
total_score = 0
for i in range(test_iterations):
algorithm.fit(scaled_components)
if isinstance(algorithm, GaussianMixture):
predicted_labels = algorithm.predict(scaled_components)
else:
predicted_labels = algorithm.labels_
if metric[0].lower() == "s":
# Calculate the silhouette coefficient
score = silhouette_score(scaled_components, predicted_labels)
else:
# Calculate accuracy
score, ARI, label_difference = cluster_metrics(cell_labels, predicted_labels)
total_score += score
# Find the average over all attempts
average_score = total_score / test_iterations
# Check if found best accuracy / silhouette coefficient
if average_score > algorithms_best_score[algo_idx]:
# Record the results
algorithms_best_score[algo_idx] = average_score
algorithms_best_n[algo_idx] = n
algorithms_use_encoded[algo_idx] = True if data_idx >= 2 else False
algorithms_use_standardized[algo_idx] = True if data_idx % 2 == 1 else False
# Calculate how many cells were not clustered correctly
incorrect_cells = [int(len(data_types[0]) - np.round(algorithms_best_score[i] * len(data_types[0]))) for i in range(len(clustering_algorithms))]
# Create DataFrame of results
algorithm_results = pandas.DataFrame(data={metric: algorithms_best_score,
'Number Incorrect': incorrect_cells,
'Components': algorithms_best_n,
'Encoded': algorithms_use_encoded,
'Standardized': algorithms_use_standardized},
index=algorithm_names)
if test_cell_labels != None and run_alternatives:
display(algorithm_results)
Accuracy | Number Incorrect | Components | Encoded | Standardized | |
---|---|---|---|---|---|
K-Means | 0.997783 | 2 | 3 | False | True |
Hierarchical | 0.998891 | 1 | 3 | False | True |
Birch | 0.998891 | 1 | 3 | False | True |
Mini Batch K-Means | 0.997783 | 2 | 3 | False | True |
Spectral Clustering | 0.998891 | 1 | 4 | False | True |
Gaussian Mixture | 0.997783 | 2 | 3 | False | True |
if test_cell_labels != None and run_alternatives:
# Plot highest accuracy / silhouette coefficient of each algorithm
fig = plot.subplots(figsize=(12, 5))
plot.bar(algorithm_names, algorithms_best_score, width=0.6,
color=["red", "darkorange", "gold", "lawngreen", "deepskyblue", "blueviolet"])
plot.xlabel('Clustering Algorithm')
plot.ylabel(metric)
for i, comp in enumerate(algorithms_best_n):
# Add number of components
plot.annotate(str(comp) + " " + dimension_reduction.upper() + " - " + ("encoded" if algorithms_use_encoded[i] else "unencoded"),
(algorithm_names[i], algorithms_best_score[i]),
textcoords="offset points", xytext=(0,2), ha='center')
if test_cell_labels != None and run_alternatives:
# Get the best clustering algorithm, optimal number of principal components and data to use
best_index = np.argmax(algorithms_best_score)
best_components = algorithms_best_n[best_index]
best_algorithm = clustering_algorithms[best_index]
if algorithms_use_standardized[best_index]:
if use_all_data:
best_data = all_standardized_encoded if algorithms_use_encoded[best_index] else all_standardized_unencoded
else:
best_data = standardized_encoded if algorithms_use_encoded[best_index] else standardized_unencoded
else:
if use_all_data:
best_data = all_encoded if algorithms_use_encoded[best_index] else all_unencoded
else:
best_data = encoded_data if algorithms_use_encoded[best_index] else unstandardized_unencoded
graph_text = all_cell_graph_text if use_all_data else cell_graph_text
cell_labels = all_cell_labels if use_all_data else test_cell_labels
cell_names = all_cell_names if use_all_data else test_cell_names
# Plot 2D graph
plot_clusters(best_data, cell_names, best_algorithm, labels=cell_labels,
reduction=dimension_reduction, apply_tsne=run_tsne, tsne_perplexity=run_tsne_perplexity,
graph_text=graph_text, colours=cluster_colours,
correct_colours=actual_cluster_colours,
graph_type="2D", components=best_components,
save_to_file=True, file_name="cluster_best_result_2d.html")
if best_components > 2:
# Plot 3D graph
plot_clusters(best_data, cell_names, best_algorithm, labels=cell_labels,
reduction=dimension_reduction, apply_tsne=run_tsne, tsne_perplexity=run_tsne_perplexity,
graph_text=graph_text, colours=cluster_colours,
correct_colours=actual_cluster_colours,
graph_type="3D", components=best_components,
save_to_file=True, file_name="cluster_best_result_3d.html")
Accuracy: 99.889% Adjusted Rand Index: 0.997 Silhouette Coefficient: 0.538
Accuracy: 99.889% Adjusted Rand Index: 0.997 Silhouette Coefficient: 0.538
Run Kepler Mapper to create a visualisation of the higher-dimensional shape of the dataset.
tda_data = cell_data
# Create a mapper
mapper = km.KeplerMapper(verbose=2) # Set verbose to 2 to show logging, or 0 to show none
# Project data onto a lower-dimensional space
projected_data = mapper.project(tda_data, projection=PCA(n_components=2), distance_matrix="euclidean")
KeplerMapper(verbose=2) ..Projecting on data shaped (902, 16468) Created distance matrix, shape: (902, 902), with distance metric `euclidean` ..Projecting data using: PCA(copy=True, iterated_power='auto', n_components=2, random_state=None, svd_solver='auto', tol=0.0, whiten=False) ..Scaling with: MinMaxScaler(copy=True, feature_range=(0, 1))
graph_colors = [[0.0, '#fc0303'], [0.1, '#fc6b03'], [0.2, '#fcf003'], [0.3, '#88fc03'],
[0.4, '#03fc39'], [0.5, '#03fcba'], [0.6, '#0380fc'], [0.7, '#0324fc'],
[0.8, '#862bfc'], [0.9, '#fc03fc'], [1.0, '#fc0377']]
Apply the Mapper algorithm to build the simplicial complex.
n_cubes
will increase the granularity of the complex.
perc_overlap
will result in more connections forming.
# Build a simplicial complex
graph = mapper.map(projected_data, tda_data,
cover=km.Cover(n_cubes=10, perc_overlap=0.15),
clusterer=KMeans(n_clusters=cluster_number, **k_means_kwards))
Mapping on data shaped (902, 16468) using lens shaped (902, 2) Minimal points in hypercube before clustering: 3 Creating 100 hypercubes. > Found 3 clusters in hypercube 0. > Found 3 clusters in hypercube 1. > Found 3 clusters in hypercube 2. > Found 3 clusters in hypercube 3. > Found 3 clusters in hypercube 4. > Found 3 clusters in hypercube 5. > Found 3 clusters in hypercube 6. > Found 3 clusters in hypercube 7. > Found 3 clusters in hypercube 8. > Found 3 clusters in hypercube 9. > Found 3 clusters in hypercube 10. > Found 3 clusters in hypercube 11. > Found 3 clusters in hypercube 12. > Found 3 clusters in hypercube 13. Cube_14 is empty. > Found 3 clusters in hypercube 15. > Found 3 clusters in hypercube 16. > Found 3 clusters in hypercube 17. > Found 3 clusters in hypercube 18. > Found 3 clusters in hypercube 19. Cube_20 is empty. > Found 3 clusters in hypercube 21. > Found 3 clusters in hypercube 22. Cube_23 is empty. Cube_24 is empty. Cube_25 is empty. > Found 3 clusters in hypercube 26. > Found 3 clusters in hypercube 27. Cube_28 is empty. Cube_29 is empty. Cube_30 is empty. Cube_31 is empty. Created 95 edges and 69 nodes in 0:00:19.215567.
Download the graph as an HTML file so that it can be visualised
# Visualize it
graph_file = "scRNAseq_graph.html"
graph_title = "scRNA-seq Simplicial Complex"
html = mapper.visualize(graph, colorscale = graph_colors, path_html=graph_file, title=graph_title)
#files.download(graph_file)
Wrote visualization to: scRNAseq_graph.html