# Source code for gseapy.base

#! python
# -*- coding: utf-8 -*-

import logging
import os
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from pandas.api.types import is_object_dtype, is_string_dtype

from gseapy.enrichr import EnrichrAPI
from gseapy.plot import GSEAPlot, TracePlot, gseaplot, heatmap
from gseapy.utils import DEFAULT_CACHE_PATH, log_init, mkdirs


class GMT:
    """A collection of gene set dictionaries with metadata.

    Attributes:
        _collections: Dict[str, Dict[str, Any]] - Stores gene set collections
            key: collection name
            value: {
                'genes': Dict[str, List[str]] - Gene set mappings
                'description': str - Collection description
                'source': str - Source of the gene sets
            }
    """

    def __init__(
        self,
        mapping: Optional[Dict[str, List[str]]] = None,
        description: Optional[str] = None,
        source: Optional[str] = None,
        name: Optional[str] = "default",
    ):
        """Initialize a GMT collection.

        Args:
            mapping: Initial gene set dictionary
            description: Description of the gene sets
            source: Source of the gene sets
            name: Name of this collection
        """
        self._collections: Dict[str, Dict[str, Any]] = {}
        if mapping is not None:
            self.add(mapping, description, source, name)

    def add(
        self,
        mapping: Dict[str, List[str]],
        description: Optional[str] = None,
        source: Optional[str] = None,
        name: Optional[str] = "default",
    ):
        """Add a gene set collection with metadata.

        An existing collection stored under the same ``name`` is replaced.

        Args:
            mapping: Gene set dictionary to add
            description: Description of the gene sets
            source: Source of the gene sets
            name: Name for this collection
        """
        self._collections[name] = {
            "genes": mapping,
            "description": description,
            "source": source,
        }

    def get(self, name: str = "default") -> Dict[str, List[str]]:
        """Get gene sets by collection name.

        Raises:
            KeyError: if no collection is stored under ``name``.
        """
        return self._collections[name]["genes"]

    def get_metadata(self, name: str = "default") -> Dict[str, Any]:
        """Get metadata (``description`` and ``source``) for a collection.

        Raises:
            KeyError: if no collection is stored under ``name``.
        """
        collection = self._collections[name]
        return {
            "description": collection["description"],
            "source": collection["source"],
        }

    def write(self, ofname: str):
        """Write every collection to a single tab-separated GMT file.

        Each line is ``term<TAB>collection description<TAB>gene...``.
        Gene identifiers are converted with ``str`` so numeric IDs
        (e.g. Entrez IDs) and non-list gene containers can be written too.

        Args:
            ofname: Output file path (overwritten if it exists)
        """
        with open(ofname, "w") as out:
            for collection in self._collections.values():
                desc = collection["description"] or ""
                for key, genes in collection["genes"].items():
                    fields = [str(key), desc]
                    fields.extend(str(g) for g in genes)
                    out.write("\t".join(fields) + "\n")

    def filter(
        self,
        min_size: Optional[int] = None,
        max_size: Optional[int] = None,
        gene_list: Optional[List[str]] = None,
        collections: Optional[List[str]] = None,
    ) -> "GMT":
        """Filter gene sets based on size and gene membership.

        Args:
            min_size: Minimum number of genes in a set
            max_size: Maximum number of genes in a set
            gene_list: Only keep genes present in this list
            collections: Only keep these named collections; unknown names
                are silently skipped

        Returns:
            A new filtered GMT object. Gene sets that end up empty, and
            collections that end up without gene sets, are left out.
        """
        filtered = GMT()
        # Build the membership set once instead of scanning the list for
        # every gene of every gene set.
        allowed = set(gene_list) if gene_list is not None else None
        # Filter collections if specified
        colls = collections or self._collections.keys()
        for name in colls:
            if name not in self._collections:
                continue
            collection = self._collections[name]
            filtered_mapping = {}
            for term, genes in collection["genes"].items():
                # Filter genes if gene_list provided
                if allowed is not None:
                    genes = [g for g in genes if g in allowed]
                # Apply size filters
                if min_size is not None and len(genes) < min_size:
                    continue
                if max_size is not None and len(genes) > max_size:
                    continue
                # Only add if genes remain after filtering
                if genes:
                    filtered_mapping[term] = genes
            # Only add collection if it has gene sets after filtering
            if filtered_mapping:
                filtered.add(
                    filtered_mapping,
                    description=collection["description"],
                    source=collection["source"],
                    name=name,
                )
        return filtered

    @classmethod
    def read(cls, paths: str, source: Optional[str] = None) -> "GMT":
        """Read GMT files into a collection.

        One collection is created per file, named after the file's base
        name. Blank lines are skipped and a line without a description
        column is accepted with an empty description.

        Args:
            paths: Comma-separated list of GMT file paths
            source: Source annotation for the files

        Returns:
            A GMT object holding one collection per input file.
        """
        gmt = cls()
        for path in paths.split(","):
            path = path.strip()
            if not path:
                # tolerate "a.gmt," or "a.gmt,,b.gmt"
                continue
            name = os.path.basename(path)
            mapping = {}
            # Reset per file so an empty file neither raises nor inherits
            # the description of the previously read file.
            desc = None
            with open(path) as f:
                for line in f:
                    stripped = line.strip()
                    if not stripped:
                        continue
                    items = stripped.split("\t")
                    key = items[0]
                    # NOTE: as before, the collection description ends up
                    # being the description column of the last line read.
                    desc = items[1] if len(items) > 1 else ""
                    mapping[key] = items[2:]
            gmt.add(mapping, desc, source, name)
        return gmt

    def __getitem__(self, name: str) -> Dict[str, List[str]]:
        """Dictionary-like access to gene sets."""
        return self.get(name)

    def __iter__(self):
        """Iterate over collection names."""
        return iter(self._collections)

    def items(self):
        """Iterate over (name, gene_sets) pairs."""
        return ((name, coll["genes"]) for name, coll in self._collections.items())
class GSEAbase(object):
    """Base class of GSEA.

    Holds the shared state (output directory, logger, gene sets, results)
    and the helpers for loading expression/ranking data, parsing gene sets,
    computing enrichment scores and reporting. Several methods read
    attributes that are not set here (``min_size``, ``max_size``,
    ``graph_num``, ``format``, ``figsize``, ``permutation_type``,
    ``_noplot``, ``heatmat``); presumably subclasses define them.
    """

    def __init__(
        self,
        outdir: Optional[str] = None,
        gene_sets: Union[List[str], str, Dict[str, str]] = "KEGG_2016",
        module: str = "base",
        threads: int = 1,
        organism: str = "human",
        verbose: bool = False,
    ):
        """Set default state, clamp the thread count and create the logger.

        Args:
            outdir: Output directory; ``None`` disables file output.
            gene_sets: Gene set input (library name(s), GMT path(s) or dict).
            module: Name of the analysis module, used in log/report names.
            threads: Requested number of worker threads.
            organism: Organism passed to ``EnrichrAPI``.
            verbose: Log at INFO level when true, WARNING otherwise.
        """
        self.outdir = outdir
        self.gene_sets = gene_sets
        self.fdr = 0.05
        self.module = module
        self.res2d = None
        self.ranking = None
        self.ascending = False
        self.verbose = verbose
        self._threads = threads
        self.pheno_pos = ""
        self.pheno_neg = ""
        self.permutation_num = 0
        self._LIBRARY_LIST_URL = "https://maayanlab.cloud/speedrichr/api/listlibs"
        self._gene_isupper = True
        self._gene_toupper = False
        self._set_cores()
        # init logger
        self.prepare_outdir()
        # NOTE(review): the extracted source is ambiguous about whether the
        # next statement was commented out. It is kept live because
        # get_libraries() and _download_libraries() read self._enrichrapi.
        self._enrichrapi = EnrichrAPI(organism)

    def __del__(self):
        """Close and detach every handler of this instance's logger."""
        if hasattr(self, "_logger"):
            handlers = self._logger.handlers[:]
            for handler in handlers:
                handler.close()  # close file
                self._logger.removeHandler(handler)

    def prepare_outdir(self):
        """Create the output directory (if any) and initialise the logger.

        When ``outdir`` is a string the directory is created and a log file
        named ``gseapy.<module>.<id>.log`` is placed in it; otherwise the
        logger is created without a file.
        """
        self._outdir = self.outdir
        logfile = None
        if isinstance(self.outdir, str):
            mkdirs(self.outdir)
            logfile = os.path.join(self.outdir, "gseapy.%s.%s.log" % (self.module, id(self)))
        self._logfile = logfile
        self._logger = log_init(
            name=str(self.module) + str(id(self)),
            log_level=logging.INFO if self.verbose else logging.WARNING,
            filename=logfile,
        )

    def _set_cores(self):
        """Clamp ``self._threads`` to ``[1, cpu_count - 1]`` and make it an int.

        ``os.cpu_count()`` may return ``None`` and a single-CPU machine
        would leave zero spare cores, so the upper bound is never below 1.
        """
        cpu_num = max(1, (os.cpu_count() or 1) - 1)
        cores = min(max(self._threads, 1), cpu_num)
        # have to be int if user input is float
        self._threads = int(cores)

    def _read_file(self, path: str) -> pd.DataFrame:
        """Read a ``.gct``/``.csv``/``.rnk``/tab-separated file.

        Returns a dataframe whose first column holds the gene IDs (the
        file's first column) followed by the numeric columns only.
        """
        # just txt file like input
        header, sep = "infer", "\t"
        # GCT input format?
        if path.endswith(".gct"):
            rank_metric = pd.read_csv(path, skiprows=1, comment="#", index_col=0, sep=sep)
        else:
            if path.endswith(".csv"):
                sep = ","
            if path.endswith(".rnk"):
                header = None
            rank_metric = pd.read_csv(path, comment="#", index_col=0, sep=sep, header=header)
            if rank_metric.shape[1] == 1:
                # rnk file like input
                rank_metric.columns = rank_metric.columns.astype(str)
        return rank_metric.select_dtypes(include=[np.number]).reset_index()

    def _reset_index(self, rank_metric: pd.DataFrame):
        """Move the index into a column when it appears to hold gene names.

        An object/string index whose values cannot all be parsed as numbers
        is treated as gene names and reset; any other index is left alone.
        """
        index = rank_metric.index
        # pandas 3.0 compatibility: check for both object and string dtypes
        if is_object_dtype(index.dtype) or is_string_dtype(index.dtype):
            try:
                # all elements are numeric strings: don't reset index
                pd.to_numeric(index)
            except (ValueError, TypeError):
                # contains non-numeric strings, likely gene names
                rank_metric = rank_metric.reset_index()
        return rank_metric

    def _load_data(self, exprs: Union[str, pd.Series, pd.DataFrame]) -> pd.DataFrame:
        """Load expression/ranking data from a DataFrame, Series or file path.

        The input object is never modified; DataFrame and Series inputs
        are copied first.

        Raises:
            Exception: if ``exprs`` is none of the supported inputs.
        """
        if isinstance(exprs, pd.DataFrame):
            rank_metric = exprs.copy()
            self._logger.debug("Input data is a DataFrame with gene names")
            # handle index is already gene_names
            rank_metric = self._reset_index(rank_metric)
            # pandas 3.0 compatibility: check for both object and string dtypes
            columns = rank_metric.columns
            if not (is_object_dtype(columns.dtype) or is_string_dtype(columns.dtype)):
                rank_metric.columns = columns.astype(str)
        elif isinstance(exprs, pd.Series):
            # change to DataFrame
            self._logger.debug("Input data is a Series with gene names")
            # shallow copy so renaming does not leak into the caller's Series
            series = exprs.copy(deep=False)
            if not isinstance(series.name, str):
                if series.name is None:
                    # rename col if name attr is none
                    series.name = "sample1"
                elif hasattr(series.name, "dtype"):
                    # pandas 3.0 compatibility: check for both object and string dtypes
                    name_dtype = series.name.dtype
                    if not (is_object_dtype(name_dtype) or is_string_dtype(name_dtype)):
                        series.name = series.name.astype(str)
                else:
                    series.name = str(series.name)
            rank_metric = series.reset_index()
        elif os.path.isfile(exprs):
            rank_metric = self._read_file(exprs)
        else:
            raise Exception("Error parsing expression values!")
        return rank_metric

    def _check_data(self, exprs: pd.DataFrame) -> pd.DataFrame:
        """Check NAs, duplicates and infinities, and select numeric columns.

        Args:
            exprs: dataframe whose first column must be gene identifiers.
                It is not modified.

        Returns:
            Numeric dataframe indexed by gene ID. Rows with a missing gene
            ID are dropped, remaining NAs are filled with 0, duplicated gene
            IDs are averaged and +/-inf are replaced by the finite
            maximum/minimum of their own column.
        """
        # if gene names contain NA, drop them (the result must be kept:
        # dropna is not in-place)
        if exprs.iloc[:, 0].isnull().any():
            exprs = exprs.dropna(subset=[exprs.columns[0]])
        # then fill na for numeric columns
        if exprs.isnull().any().sum() > 0:
            self._logger.warning("Input data contains NA, filled NA with 0")
            # drop rows with all NAs, then fill what is left
            exprs = exprs.dropna(how="all").fillna(0)
        # set gene name as index
        exprs = exprs.set_index(keys=exprs.columns[0])
        # select numeric columns
        df = exprs.select_dtypes(include=[np.number])
        # microarray data may contain multiple probes of the same gene, average them
        if df.index.duplicated().sum() > 0:
            self._logger.warning("Found duplicated gene names, values averaged by gene names!")
            df = df.groupby(level=0).mean()
        # check whether contains infinity values
        if np.isinf(df).values.sum() > 0:
            self._logger.warning("Input gene rankings contains inf values!")
            finite = df[np.isfinite(df)]
            col_max = finite.max()  # column-wise max
            col_min = finite.min()  # column-wise min
            # one scalar replacement per column, not a whole Series
            df = df.replace({col: {np.inf: col_max[col], -np.inf: col_min[col]} for col in df.columns})
        return df

    def _is_entrez_id(self, idx: Union[int, str]) -> bool:
        """Check if an index is an Entrez ID, i.e. convertible by ``int()``.

        Parameters
        ----------
        idx : str or int
            An index to be checked.

        Returns
        -------
        bool
            Whether the index is Entrez ID.
        """
        try:
            int(idx)
        except (TypeError, ValueError, OverflowError):
            return False
        return True

    def check_uppercase(self, gene_list: List[Union[str, int]]) -> bool:
        """Check whether a list of gene names are mostly in uppercase.

        Parameters
        ----------
        gene_list : list
            A list of gene names or Entrez IDs

        Returns
        -------
        bool
            True when at least 90% of the names are uppercase. Always False
            when every entry is an Entrez ID (this includes an empty list).
        """
        # if all gene names are Entrez IDs, don't check uppercase
        if all(self._is_entrez_id(g) for g in gene_list):
            return False
        is_upper = [str(s).isupper() for s in gene_list]
        return sum(is_upper) / len(is_upper) >= 0.9

    def make_unique(self, rank_metric: pd.DataFrame, col_idx: int) -> pd.DataFrame:
        """Make the gene id column unique by adding a digit suffix.

        Similar to R's make.unique: the first occurrence keeps its name,
        later ones get ``_1``, ``_2``, ... The frame is modified in place
        and also returned.
        """
        id_col = rank_metric.columns[col_idx]
        if rank_metric.duplicated(subset=id_col).sum() > 0:
            self._logger.info("Input gene rankings contains duplicated IDs")
            mask = rank_metric.duplicated(subset=id_col, keep=False)
            dups = (
                rank_metric.loc[mask, id_col]
                .to_frame()
                .groupby(id_col)
                .cumcount()
                .map(lambda c: "_" + str(c) if c else "")
            )
            rank_metric.loc[mask, id_col] = rank_metric.loc[mask, id_col] + dups
        return rank_metric

    def load_gmt_only(self, gmt: Union[List[str], str, Dict[str, str]]) -> Dict[str, List[str]]:
        """Parse gene_sets given as a dict, a string or a list.

        Different gene set sources are merged into one big dict to save
        computation time later. When several sources are merged, keys are
        prefixed with ``<source>__`` where the source is the file/library
        name, or the list position for dict items.

        Raises:
            Exception: on an unsupported type or when nothing was parsed.
        """
        genesets_dict = dict()
        if isinstance(gmt, dict):
            genesets_dict = gmt.copy()
        elif isinstance(gmt, str):
            gmts = gmt.split(",")
            if len(gmts) > 1:
                for gm in gmts:
                    prefix = os.path.split(gm)[-1]
                    for k, v in self.parse_gmt(gm).items():
                        genesets_dict[prefix + "__" + k] = v
            else:
                genesets_dict = self.parse_gmt(gmt)
        elif isinstance(gmt, list):
            for i, gm in enumerate(gmt):
                prefix = str(i)
                if isinstance(gm, dict):
                    tdt = gm.copy()
                elif isinstance(gm, str):
                    tdt = self.parse_gmt(gm)
                    prefix = os.path.split(gm)[-1]
                else:
                    continue
                for k, v in tdt.items():
                    genesets_dict[prefix + "__" + k] = v
        else:
            raise Exception("Error parsing gmt parameter for gene sets")
        if len(genesets_dict) == 0:
            raise Exception("Error parsing gmt parameter for gene sets")
        return genesets_dict

    def load_gmt(self, gene_list: Iterable[str], gmt: Union[List[str], str, Dict[str, str]]) -> Dict[str, List[str]]:
        """Load gene sets and keep only genes found in ``gene_list``.

        A gene set is kept when the number of overlapping genes lies within
        ``[self.min_size, self.max_size]`` and is smaller than the gene list.
        Overlapping genes keep their order of first appearance in the set.

        Raises:
            ValueError: if the parsed gene set dictionary is empty.
            LookupError: if no gene set passes the filter.
        """
        # materialise once: gene_list is enumerated and measured several times
        genes = list(gene_list)
        genesets_dict = self.load_gmt_only(gmt)
        subsets = list(genesets_dict.keys())
        if not subsets:  # Check if empty
            raise ValueError("Empty gene sets dictionary")
        entry1st = genesets_dict[subsets[0]]
        # Check uppercase for up to 20 sets
        ups = [self.check_uppercase(genesets_dict[s]) for s in subsets[:20]]
        use_upper = (not self._gene_isupper) and all(ups)
        if use_upper:
            # set flag to True, means use uppercase version of gene symbols
            self._gene_toupper = True
            lookup = {str(g).upper(): i for i, g in enumerate(genes)}
        else:
            lookup = {g: i for i, g in enumerate(genes)}
        n_genes = len(genes)
        # filter gene sets
        for subset in subsets:
            # remove duplicates while keeping a deterministic order
            unique_genes = dict.fromkeys(genesets_dict.get(subset))
            # drop genes not found in the gene list
            gene_overlap = [g for g in unique_genes if g in lookup]
            tag_len = len(gene_overlap)
            # tag_len should < gene_list
            if (self.min_size <= tag_len <= self.max_size) and tag_len < n_genes:
                genesets_dict[subset] = gene_overlap
            else:
                del genesets_dict[subset]
        filsets_num = len(subsets) - len(genesets_dict)
        self._logger.info(
            "%04d gene_sets have been filtered out when max_size=%s and min_size=%s"
            % (filsets_num, self.max_size, self.min_size)
        )
        if filsets_num == len(subsets):
            msg = (
                "No gene sets passed through filtering condition !!! \n"
                + "Hint 1: Try to lower min_size or increase max_size !\n"
                + "Hint 2: Check gene symbols are identifiable to your gmt input.\n"
                + "Hint 3: Gene symbols curated in Enrichr web services are all upcases.\n"
            )
            self._logger.error(msg)
            # str() each entry: numeric IDs must not turn this into a TypeError
            dict_head = "{ %s: [%s]}" % (subsets[0], ", ".join(str(g) for g in entry1st))
            self._logger.error("The first entry of your gene_sets (gmt) look like this : %s" % dict_head)
            self._logger.error(
                "The first 5 genes look like this : [ %s ]" % (", ".join([str(g) for g in genes[:5]]))
            )
            raise LookupError(msg)
        return genesets_dict

    def parse_gmt(self, gmt: str) -> Dict[str, List[str]]:
        """GMT parser when input is a string.

        A value ending in ``.gmt`` (case-insensitive) is read as a local
        file. Anything else is treated as an Enrichr library name: a cached
        copy is used if present, otherwise the library is downloaded when
        its name is listed by ``get_libraries``. Unknown names log an error
        and return an empty dict.
        """
        if gmt.lower().endswith(".gmt"):
            genesets_dict = {}
            with open(gmt) as genesets:
                for line in genesets:
                    stripped = line.strip()
                    if not stripped:
                        # a blank line would otherwise create an "" gene set
                        continue
                    entries = stripped.split("\t")
                    # column 1 is the term, column 2 the description; gene
                    # entries may look like "GENE,weight" - keep the name
                    genesets_dict[entries[0]] = [g.split(",")[0] for g in entries[2:]]
            return genesets_dict

        tmpname = "Enrichr." + gmt + ".gmt"
        tempath = os.path.join(DEFAULT_CACHE_PATH, tmpname)
        # if file already download
        if os.path.isfile(tempath):
            self._logger.info(
                "Enrichr library gene sets already downloaded in: %s, use local file" % DEFAULT_CACHE_PATH
            )
            return self.parse_gmt(tempath)
        elif gmt in self.get_libraries():
            return self._download_libraries(gmt)
        else:
            self._logger.error("No supported gene_sets: %s" % gmt)
        return dict()

    def get_libraries(self) -> List[str]:
        """Return the sorted library names reported by the Enrichr API."""
        libs = self._enrichrapi.get_libraries()
        return sorted(libs)

    def _download_libraries(self, libname: str) -> Dict[str, List[str]]:
        """Download enrichr libraries. Only Support Enrichr libraries now"""
        self._logger.info("Downloading and generating Enrichr library gene sets......")
        return self._enrichrapi.download_libraries(libname)

    def _heatmat(self, df: pd.DataFrame, classes: List[str]):
        """Only used for the gsea heatmap.

        Stores in ``self.heatmat`` the columns of ``df`` whose class equals
        ``pheno_pos`` followed by those whose class equals ``pheno_neg``.
        """
        cls_booA = [x == self.pheno_pos for x in classes]
        cls_booB = [x == self.pheno_neg for x in classes]
        datA = df.loc[:, cls_booA]
        datB = df.loc[:, cls_booB]
        self.heatmat = pd.concat([datA, datB], axis=1)
        return

    def _plotting(self, metric: Dict[str, Union[pd.Series, pd.DataFrame]]):
        """Plotting API.

        Writes plots for the rows of ``self.res2d``: rows with
        ``FDR q-val > 0.25`` are skipped (except for ssgsea) and the loop
        stops at row position ``self.graph_num``. Nothing is done when no
        output directory is set.

        :param metric: dict of sorted pd.Series with rankings values,
            keyed by the ``Name`` column of ``res2d``.
        """
        # no values need to be returned
        if self._outdir is None:
            return
        indices = self.res2d.index
        # Plotting
        for i, idx in enumerate(indices):
            record = self.res2d.iloc[idx]
            if self.module != "ssgsea" and record["FDR q-val"] > 0.25:
                continue
            if i >= self.graph_num:  # already sorted by abs(NES) in descending order
                break
            key = record["Name"]
            rank_metric = metric[key]
            hit = self._results[key][record["Term"]]["hits"]
            RES = self._results[key][record["Term"]]["RES"]

            outdir = os.path.join(self.outdir, record["Name"])
            mkdirs(outdir)
            # make the term usable as a file name
            term = record["Term"].replace("/", "-").replace(":", "_")
            outfile = os.path.join(outdir, "{0}.{1}".format(term, self.format))
            if self.module == "gsea":
                outfile2 = "{0}/{1}.heatmap.{2}".format(outdir, term, self.format)
                heatmat = self.heatmat.iloc[hit, :]
                width = np.clip(heatmat.shape[1], 4, 20)
                height = np.clip(heatmat.shape[0], 4, 20)
                heatmap(
                    df=heatmat,
                    title=record["Term"].split("__")[-1],
                    ofname=outfile2,
                    z_score=0,
                    figsize=(width, height),
                    xticklabels=True,
                    yticklabels=True,
                )
            if self.permutation_num > 0:
                # skip plotting when nperm=0
                gseaplot(
                    term=record["Term"].split("__")[-1],
                    hits=hit,
                    nes=record["NES"],
                    pval=record["NOM p-val"],
                    fdr=record["FDR q-val"],
                    RES=RES,
                    rank_metric=rank_metric,
                    pheno_pos=self.pheno_pos,
                    pheno_neg=self.pheno_neg,
                    figsize=self.figsize,
                    ofname=outfile,
                )

    def _to_df(
        self,
        gsea_summary: List[Dict],
        gmt: Dict[str, List[str]],
        metric: Dict[str, pd.Series],
    ) -> pd.DataFrame:
        """Convert GSEASummary records to a DataFrame.

        Each record is read through its ``index``, ``term``, ``es``,
        ``nes``, ``pval``, ``fdr``, ``fwerp``, ``hits`` and ``run_es``
        attributes. The series in ``metric`` must be sorted in descending
        order already.
        """
        res_df = pd.DataFrame(
            index=range(len(gsea_summary)),
            columns=[
                "name",
                "term",
                "es",
                "nes",
                "pval",
                "fdr",
                "fwerp",
                "tag %",
                "gene %",
                "lead_genes",
                "matched_genes",
                "hits",
                "RES",
            ],
        )
        for i, gs in enumerate(gsea_summary):
            # reformat gene list.
            name = self._metric_dict[str(gs.index)] if (gs.index is not None) else self.module
            _genes = metric[name].index.values[gs.hits]
            genes = ";".join([str(g).strip() for g in _genes])
            RES = np.array(gs.run_es)
            lead_genes = ""
            tag_frac = ""
            gene_frac = ""
            if len(RES) > 1:
                # extract leading edge genes
                if float(gs.es) >= 0:
                    # RES -> ndarray, ind -> list
                    es_i = RES.argmax()
                    ldg_pos = list(filter(lambda x: x <= es_i, gs.hits))
                    gene_frac = (es_i + 1) / len(metric[name])
                else:
                    es_i = RES.argmin()
                    ldg_pos = list(filter(lambda x: x >= es_i, gs.hits))
                    ldg_pos.reverse()
                    gene_frac = (len(metric[name]) - es_i) / len(metric[name])
                gene_frac = "{0:.2%}".format(gene_frac)
                lead_genes = ";".join(list(map(str, metric[name].iloc[ldg_pos].index)))
                tag_frac = "%s/%s" % (len(ldg_pos), len(gmt[gs.term]))
            e = pd.Series(
                [
                    name,
                    gs.term,
                    gs.es,
                    gs.nes,
                    gs.pval,
                    gs.fdr,
                    gs.fwerp,
                    tag_frac,
                    gene_frac,
                    lead_genes,
                    genes,
                    gs.hits,
                    gs.run_es,
                ],
                index=res_df.columns,
            )
            res_df.iloc[i, :] = e
        return res_df

    def to_df(
        self,
        gsea_summary: List[Dict],
        gmt: Dict[str, List[str]],
        rank_metric: Union[pd.Series, pd.DataFrame],
        indices: Optional[List] = None,
    ):
        """Convert GSEASummary to DataFrame and store/write the results.

        Fills ``self._results`` (per-name dict of per-term records) and
        ``self.res2d`` (report table sorted by absolute NES). When an
        output directory is set, the report CSV and ``gene_sets.gmt`` are
        written, and plots are generated unless ``self._noplot`` is true.

        rank_metric: if a Series, then it must be sorted in descending order
            already. If a DataFrame, indices must not be None.
        indices: Only works for DataFrame input. Stores the indices of the
            sorted array (one list per column).
        """
        if isinstance(rank_metric, pd.DataFrame) and (indices is not None):
            self._metric_dict = {str(c): n for c, n in enumerate(rank_metric.columns)}
            metric = {
                n: rank_metric.iloc[indices[i], i]  # series with gene_name indexed
                for i, n in enumerate(rank_metric.columns)  # indices is a 2d list
            }
        else:
            metric = {self.module: rank_metric}
            self._metric_dict = {self.module: self.module}
        res_df = self._to_df(gsea_summary, gmt, metric)
        self._results = {}
        # save dict
        for name, dd in res_df.groupby("name"):
            self._results[name] = dd.set_index("term").to_dict(orient="index")
        # trim
        res_df.rename(
            columns={
                "name": "Name",
                "term": "Term",
                "es": "ES",
                "nes": "NES",
                "pval": "NOM p-val",
                "fdr": "FDR q-val",
                "fwerp": "FWER p-val",
                "tag %": "Tag %",
                "gene %": "Gene %",
                "lead_genes": "Lead_genes",
            },
            inplace=True,
        )
        # Replace 0 p-values with the minimum detectable value (1/permutation_num)
        # This avoids exact zeros from permutation testing, which would cause issues
        # in downstream analysis and visualization (e.g. log-scale plots).
        if self.permutation_num > 0:
            min_pval = 1.0 / self.permutation_num
            for col in ["NOM p-val", "FDR q-val", "FWER p-val"]:
                if col in res_df.columns:
                    res_df[col] = res_df[col].astype(float).clip(lower=min_pval)
        # columns to drop from the report
        dc = ["RES", "hits", "matched_genes"]
        if self.permutation_num == 0:
            dc += [
                "NOM p-val",
                "FWER p-val",
                "FDR q-val",
                "Tag %",
                "Gene %",
                "Lead_genes",
            ]
        if self.module == "gsva":
            dc += ["NES"]
        # re-order by NES
        # for pandas > 1.1, use df.sort_values(by='B', key=abs) will sort by abs value
        self.res2d = res_df.reindex(res_df["NES"].abs().sort_values(ascending=False).index).reset_index(drop=True)
        self.res2d.drop(dc, axis=1, inplace=True)

        if self._outdir is not None:
            out = os.path.join(
                self.outdir,
                "gseapy.{b}.{c}.report.csv".format(b=self.permutation_type, c=self.module),
            )
            self.res2d.to_csv(out, index=False, float_format="%.6e")

            with open(os.path.join(self.outdir, "gene_sets.gmt"), "w") as gout:
                for term, genes in gmt.items():
                    # "<collection>__<term>": split on the first "__" only so
                    # a term that itself contains "__" is not truncated
                    prefix, sep, rest = term.partition("__")
                    collection = ""
                    if sep:
                        collection, term = prefix, rest
                    gg = "\t".join(str(g) for g in genes)
                    gout.write(f"{term}\t{collection}\t{gg}\n")
        # generate gseaplots (a no-op inside _plotting when no outdir is set)
        if not self._noplot:
            self._plotting(metric)
        return

    @property
    def results(self):
        """Results dict; compatible with the old style.

        With a single dataset the inner per-term dict is returned directly,
        otherwise the dict keyed by dataset name.
        """
        keys = list(self._results.keys())
        if len(keys) == 1:
            return self._results[keys[0]]
        return self._results

    def enrichment_score(
        self,
        gene_list: Iterable[str],
        correl_vector: Iterable[float],
        gene_set: Dict[str, List[str]],
        weight: float = 1.0,
        nperm: int = 1000,
        seed: int = 123,
        single: bool = False,
        scale: bool = False,
    ):
        """This is the most important function of GSEApy.
        It has the same algorithm with GSEA and ssGSEA.

        :param gene_list: The ordered gene list gene_name_list, rank_metric.index.values
        :param gene_set: gene_sets in gmt file, please use gmt_parser to get gene_set.
        :param weight: It's the same with gsea's weighted_score method. Weighting by the
            correlation is a very reasonable choice that allows significant gene sets
            with less than perfect coherence. options: 0(classic),1,1.5,2. default:1.
            If one is interested in penalizing sets for lack of coherence or to discover
            sets with any type of nonrandom distribution of tags, a value p < 1 might be
            appropriate. On the other hand, if one uses sets with large number of genes
            and only a small subset of those is expected to be coherent, then one could
            consider using p > 1. Our recommendation is to use p = 1 and use other
            settings only if you are very experienced with the method and its behavior.
        :param correl_vector: A vector with the correlations (e.g. signal to noise scores)
            corresponding to the genes in the gene list. Or rankings, rank_metric.values
        :param nperm: Only use this parameter when computing esnull for statistical
            testing. Set the esnull value equal to the permutation number.
        :param seed: Random state for initializing gene list shuffling. Default: seed=123
        :param single: If True, the score is the sum of the running score instead of
            its largest-magnitude extreme.
        :param scale: If True, the running score is divided by the gene list length.

        :return:
            ES: Enrichment score (real number between -1 and +1)
            ESNULL: Enrichment score calculated from random permutations.
            Hits_Indices: Index of a gene in gene_list, if gene is included in gene_set.
            RES: Numerical vector containing the running enrichment score for all
                locations in the gene list.
        """
        N = len(gene_list)
        # Test whether each element of a 1-D array is also present in a second array
        # use .astype to convert bool to integer: 0 (no tag) or 1 (tag)
        tag_indicator = np.isin(gene_list, gene_set, assume_unique=True).astype(int)

        if weight == 0:
            correl_vector = np.repeat(1, N)
        else:
            correl_vector = np.abs(correl_vector) ** weight

        # get indices of tag_indicator
        hit_ind = np.flatnonzero(tag_indicator).tolist()
        # rows 0..nperm-1 hold the permutations, the last row the observed data
        # set axis to 1, because we have 2D array
        axis = 1
        tag_indicator = np.tile(tag_indicator, (nperm + 1, 1))
        correl_vector = np.tile(correl_vector, (nperm + 1, 1))
        # gene list permutation
        rs = np.random.RandomState(seed)
        for i in range(nperm):
            rs.shuffle(tag_indicator[i])

        Nhint = tag_indicator.sum(axis=axis, keepdims=True)
        sum_correl_tag = np.sum(correl_vector * tag_indicator, axis=axis, keepdims=True)
        # compute ES score, the code below is identical to gsea enrichment_score method.
        no_tag_indicator = 1 - tag_indicator
        Nmiss = N - Nhint
        norm_tag = 1.0 / sum_correl_tag
        norm_no_tag = 1.0 / Nmiss

        RES = np.cumsum(
            tag_indicator * correl_vector * norm_tag - no_tag_indicator * norm_no_tag,
            axis=axis,
        )

        if scale:
            RES = RES / N
        if single:
            es_vec = RES.sum(axis=axis)
        else:
            max_ES, min_ES = RES.max(axis=axis), RES.min(axis=axis)
            es_vec = np.where(np.abs(max_ES) > np.abs(min_ES), max_ES, min_ES)
        # extract values
        es, esnull, RES = es_vec[-1], es_vec[:-1], RES[-1, :]

        return es, esnull, hit_ind, RES

    def plot(
        self,
        terms: Union[str, List[str]],
        colors: Optional[Union[str, List[str]]] = None,
        legend_kws: Optional[Dict[str, Any]] = None,
        figsize: Tuple[float, float] = (4, 5),
        show_ranking: bool = True,
        ofname: Optional[str] = None,
    ):
        """Plot one term (GSEA plot) or several terms (trace plot).

        terms: str, list. terms/pathways to show
        colors: str, list. list of colors for each term/pathway
        legend_kws: kwargs to pass to ax.legend. e.g. `loc`, `bbox_to_achor`.
        figsize: figure size
        show_ranking: pass ``self.ranking`` to the plot when true
        ofname: savefig

        Returns the figure; for unsupported ``terms`` a message is printed
        and None is returned.

        Raises:
            NotImplementedError: for the ssgsea/gsva modules or when results
                for more than one dataset are stored.
        """
        if self.module in ["ssgsea", "gsva"]:
            raise NotImplementedError("not for ssgsea")
        keys = list(self._results.keys())
        if len(keys) > 1:
            raise NotImplementedError("Multiple Dataset input No supported yet!")

        ranking = self.ranking if show_ranking else None
        if isinstance(terms, str):
            gsdict = self.results[terms]
            g = GSEAPlot(
                term=terms,
                tag=gsdict["hits"],
                rank_metric=ranking,
                runes=gsdict["RES"],
                nes=gsdict["nes"],
                pval=gsdict["pval"],
                fdr=gsdict["fdr"],
                ofname=ofname,
                pheno_pos=self.pheno_pos,
                pheno_neg=self.pheno_neg,
                color=colors,
                figsize=figsize,
            )
            g.add_axes()
            g.savefig()
            return g.fig
        elif hasattr(terms, "__len__"):  # means iterable
            terms = list(terms)
            tags = [self.results[t]["hits"] for t in terms]
            runes = [self.results[t]["RES"] for t in terms]
            t = TracePlot(
                terms=terms,
                tags=tags,
                runes=runes,
                rank_metric=ranking,
                colors=colors,
                legend_kws=legend_kws,
                ofname=ofname,
                figsize=figsize,
            )
            t.add_axes()
            t.savefig(ofname)
            return t.fig
        else:
            print("not supported input: terms")