Source code for tint.datasets.biobank

import json
import multiprocessing as mp
import numpy as np
import os
import pandas as pd
import torch as th

from torch.nn.utils.rnn import pad_sequence
from typing import List, Union

from tint.utils import get_progress_bars
from .dataset import DataModule
from .utils import create_labels, Fasttext

try:
    from google.cloud import bigquery
except ImportError:
    bigquery = None

try:
    from pandarallel import pandarallel
except ImportError:
    pandarallel = None


tqdm = get_progress_bars()
tqdm.pandas(leave=False)
file_dir = os.path.dirname(__file__)


class BioBank(DataModule):
    """
    BioBank dataset.

    Args:
        label (str): Condition to be used as label. If ``None``, it is set
            to type 2 diabetes. Default to ``None``
        discretised (bool): Whether to return a discretised dataset or not.
            Default to ``False``
        granularity (int): The time granularity. Default to a year.
        maximum_time (int): Maximum time to record. Default to 115 years
        fasttext (Fasttext): A Fasttext model to encode categorical
            features. Default to ``None``
        time_to_task (float): Special arg for the diabetes task: stops the
            recording before diabetes happens. Default to .5
        std_time_to_task (float): Adds randomness to when the recording
            stops. Default to .2
        data_dir (str): Where to download files.
        batch_size (int): Batch size. Default to 32
        prop_val (float): Proportion of validation. Default to .2
        n_folds (int): Number of folds for cross validation. If ``None``,
            the dataset is only split once between train and val using
            ``prop_val``. Default to ``None``
        fold (int): Index of the fold to use with cross-validation.
            Ignored if n_folds is None. Default to ``None``
        num_workers (int): Number of workers for the loaders. Default to 0
        seed (int): For the random split. Default to 42

    References:
        https://www.ukbiobank.ac.uk
    """

    def __init__(
        self,
        label: str = None,
        discretised: bool = False,
        granularity: int = 1,
        maximum_time: int = 115,
        fasttext: Fasttext = None,
        time_to_task: float = 0.5,
        std_time_to_task: float = 0.2,
        data_dir: str = os.path.join(
            os.path.split(file_dir)[0],
            "data",
            "biobank",
        ),
        batch_size: int = 32,
        prop_val: float = 0.2,
        n_folds: int = None,
        fold: int = None,
        num_workers: int = 0,
        seed: int = 42,
    ):
        super().__init__(
            data_dir=data_dir,
            batch_size=batch_size,
            prop_val=prop_val,
            n_folds=n_folds,
            fold=fold,
            num_workers=num_workers,
            seed=seed,
        )

        self.discretised = discretised
        self.granularity = granularity
        self.maximum_time = maximum_time
        self.fasttext = fasttext
        self.time_to_task = time_to_task
        self.std_time_to_task = std_time_to_task

        with open(os.path.join(file_dir, "utils", "read_3_2.json"), "r") as fp:
            self.read_3_2 = json.load(fp=fp)
        with open(os.path.join(file_dir, "utils", "labels.json"), "r") as fp:
            self.labels = json.load(fp=fp)

        # We drop term codes for all labels
        self.labels = {k: [x[:5] for x in v] for k, v in self.labels.items()}

        # Set the label or use the default type 2 diabetes
        self.label = label or self.labels["Type II diabetes mellitus (4)"]

    def download(
        self,
        split: str = "train",
        verbose: Union[int, bool] = False,
    ):
        # Set tqdm if necessary
        if verbose:
            assert tqdm is not None, "tqdm must be installed."
        pbar = tqdm(range(3), leave=True) if verbose else None
        pbar.set_description("Load Metadata") if verbose else None

        # Init pandarallel
        cpu_count = mp.cpu_count()
        assert pandarallel is not None, "pandarallel is not installed."
        pandarallel.initialize(
            nb_workers=max(1, cpu_count - 1),
            progress_bar=max(0, verbose - 1),
            verbose=0,
            use_memory_fs=False,
        )

        # Query Metadata
        assert bigquery is not None, "google-cloud-bigquery must be installed."
        client = bigquery.Client()
        query = """
            SELECT *
            FROM `dsap-prod-uk-biobank-ml.bio_bank.ukb_core` as ukb
        """
        metadata = client.query(query=query).to_dataframe()
        metadata = metadata.dropna(axis=1, thresh=1)

        # Convert remaining object columns to float or datetime
        columns = metadata.select_dtypes(include="object").columns
        for column in columns:
            try:
                metadata[column] = metadata[column].astype("float")
            except (TypeError, ValueError):
                try:
                    metadata[column] = pd.to_datetime(metadata[column])
                except (TypeError, ValueError):
                    continue

        # Convert dates in metadata to years
        columns = metadata.select_dtypes(include="datetime").columns
        columns = list(set(columns) - {"_34_0_0"})
        for column in columns:
            metadata[column] = metadata[[column, "_34_0_0"]].parallel_apply(
                lambda x: np.nan
                if pd.isna(x[0]) or pd.isna(x[1])
                else 1970 + x[0].timestamp() / 3600 / 24 / 365.25 - x[1],
                axis=1,
            )

        # Update dob
        metadata["_34_0_0"] = metadata["_34_0_0"].parallel_apply(
            lambda x: np.nan if pd.isna(x) else x
        )

        # Drop non-recognised columns
        metadata = metadata.select_dtypes(exclude="object")

        # Update tqdm
        pbar.update() if verbose else None
        pbar.set_description("Load GP data") if verbose else None

        # Query GP data
        query = """
            SELECT *
            FROM `dsap-prod-uk-biobank-ml.bio_bank.gp_clinical`
        """
        df = client.query(query=query).to_dataframe()

        # Filter df
        df = df[
            df.read_3.isin(list(self.read_3_2.keys())) | ~df.read_2.isnull()
        ]
        df = df[df["event_dt"].notnull()]

        # Convert read 3 codes to read 2
        df["read"] = df[["read_2", "read_3"]].parallel_apply(
            lambda x: x[0] if x[0] is not None else self.read_3_2[x[1]],
            axis=1,
        )

        # Remove None types and unknown values
        df = df[df["read"].notnull()]
        df.read = df.read.parallel_apply(lambda x: x[:5] if len(x) > 5 else x)

        # Convert dates
        df.event_dt = pd.to_datetime(df.event_dt).parallel_apply(
            lambda x: x.timestamp()
        )
        df.event_dt = df.event_dt.parallel_apply(
            lambda x: x if pd.isna(x) else x / 3600 / 24 / 365.25
        )

        # Sort by timestamp
        df = df.sort_values(by="event_dt")

        # Add year of birth to df and subtract it, remove negative values
        df = pd.merge(df, metadata[["eid", "_34_0_0"]], how="inner", on="eid")
        df["event_dt"] = df[["event_dt", "_34_0_0"]].parallel_apply(
            lambda x: np.nan
            if pd.isna(x[0]) or pd.isna(x[1])
            else 1970 + x[0] - x[1],
            axis=1,
        )
        df = df[df["event_dt"] >= 0]

        # Create codes_to_idx data
        unique_codes = df.read.unique()
        codes_to_idx = {k: i + 1 for i, k in enumerate(unique_codes)}

        # Update tqdm
        pbar.update() if verbose else None
        pbar.set_description("Group patients") if verbose else None

        # Group per patient
        if verbose:
            df = (
                df[["eid", "event_dt", "read"]]
                .groupby(["eid", "event_dt"])
                .progress_aggregate(list)
                .reset_index()
            )
        else:
            df = (
                df[["eid", "event_dt", "read"]]
                .groupby(["eid", "event_dt"])
                .agg(list)
                .reset_index()
            )
        read = df.groupby("eid").read.apply(list).reset_index(name="read")
        times = (
            df.groupby("eid").event_dt.apply(list).reset_index(name="times")
        )

        # Merge with metadata and save dataframe
        df = pd.merge(read, times, how="inner", on="eid")
        df = pd.merge(df, metadata, how="inner", on="eid")
        df.to_csv(os.path.join(self.data_dir, "biobank_data.csv"), index=False)

        # Save codes_to_idx
        with open(os.path.join(self.data_dir, "codes_to_idx.json"), "w") as fp:
            json.dump(obj=codes_to_idx, fp=fp)

        # Save text file for fasttext training
        with open(os.path.join(self.data_dir, "codes_text.txt"), "w") as fp:
            events = df.read
            for patient in events:
                for record in patient:
                    for label in record:
                        label = codes_to_idx.get(label, None)
                        if label is not None:
                            fp.write(str(label) + " ")
                fp.write("\n")

        # Update tqdm
        pbar.update() if verbose else None

    def preprocess(
        self,
        split: str = "train",
        verbose: Union[bool, int] = False,
    ) -> dict:
        # Init pandarallel
        cpu_count = mp.cpu_count()
        assert pandarallel is not None, "pandarallel is not installed."
        pandarallel.initialize(
            nb_workers=max(1, cpu_count - 1),
            progress_bar=max(0, verbose - 1),
            verbose=0,
            use_memory_fs=False,
        )

        # Load data
        df = pd.read_csv(os.path.join(self.data_dir, "biobank_data.csv"))
        df.times = df.times.parallel_apply(eval)
        df.read = df.read.parallel_apply(eval)

        # Load codes_to_idx
        with open(os.path.join(self.data_dir, "codes_to_idx.json"), "r") as fp:
            codes_to_idx = json.load(fp=fp)

        # Extract times and metadata from dataframe
        times = df.times.values
        times = [th.Tensor(x).type(th.float32) for x in times]
        metadata = df.drop(["eid", "times", "read"], axis=1).values
        metadata = [th.Tensor(x).type(th.float32) for x in metadata]

        # Replace codes with int
        events = df.read.apply(
            lambda x: [[codes_to_idx[z] for z in y] for y in x]
        )

        if self.discretised:
            labels, mask = self.build_discretized_labels(
                events=events,
                times=times,
            )
            events, times = self.build_discretized_features(
                events=events,
                times=times,
                verbose=verbose,
            )
            # Zero out events outside the record mask (mask is 1 before the
            # end of record, matching the continuous branch below)
            events[~mask.bool().squeeze(-1)] = 0.0
        else:
            labels, mask = self.build_labels(
                events=events,
                times=times,
            )
            events, times = self.build_features(
                events=events,
                times=times,
                verbose=verbose,
            )
            events = [x[y.bool()] for x, y in zip(events, mask)]
            times = [x[y.bool()] for x, y in zip(times, mask)]

        return {
            "events": events,
            "times": times,
            "metadata": metadata,
            "labels": labels,
            "mask": mask,
        }

    def prepare_data(self):
        """"""
        if not os.path.exists(os.path.join(self.data_dir, "biobank_data.csv")):
            self.download()

    def collate_fn(self, batch: list) -> (th.Tensor, th.Tensor):
        # Get keys
        keys = set(batch[0].keys())
        if self.discretised:
            keys -= {"times"}

        # Group data into a dict of tensors
        batch = {k: [b[k] for b in batch] for k in keys}
        for key in keys:
            batch[key] = pad_sequence(
                batch[key],
                batch_first=True,
                padding_value=0,
            )

        # Transform metadata
        time_shape = batch["events"].shape[1]
        batch["metadata"] = batch["metadata"].repeat(1, time_shape, 1)

        # Group features
        x = th.cat([batch["events"], batch["metadata"]], dim=-1)
        if not self.discretised:
            x = th.cat([x, batch["times"]], dim=-1)

        return x, batch["labels"]
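
    # Collate sketch (intended shapes; M and F are hypothetical sizes for
    # the Fasttext embedding and the metadata vector): with two patients of
    # lengths 3 and 5, pad_sequence gives events of shape [2, 5, M], the
    # static metadata is tiled across the 5 time steps, and concatenation
    # yields
    #
    #   x.shape == (2, 5, M + F)        # discretised
    #   x.shape == (2, 5, M + F + 1)    # continuous time, times appended
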
    def build_features(
        self,
        events: List[th.Tensor],
        times: List[th.Tensor],
        verbose: Union[bool, int] = False,
    ):
        """
        Build features.

        Args:
            events (list): The read codes.
            times (list): Times of each event.
            verbose (bool, int): Verbosity level. Default to ``False``

        Returns:
            Preprocessed features.
        """
        if verbose:
            events = tqdm(events, total=len(events), leave=False)

        if self.fasttext is not None:
            events = [
                th.stack(
                    [
                        self.fasttext.transform(th.Tensor(x).long()).type(
                            th.float32
                        )
                        for x in r
                    ]
                )
                for r in events
            ]  # [D, L, M]
        else:
            max_sim_events = max([max([len(x) for x in y]) for y in events])
            events_ = [
                pad_sequence([th.Tensor(x) for x in y], batch_first=True)
                for y in events
            ]
            events = [th.zeros((x.shape[0], max_sim_events)) for x in events_]
            for x, y in zip(events, events_):
                x[:, : y.shape[1]] = y

        return events, times
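
    # Padding sketch for build_features without a Fasttext model: a patient
    # whose visits carry the (already indexed) code lists [[4, 7], [2]] is
    # first padded per visit by pad_sequence, then copied into a zero tensor
    # of width max_sim_events. With max_sim_events == 3 across the dataset,
    # the patient becomes
    #
    #   [[4., 7., 0.],
    #    [2., 0., 0.]]
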
    def build_discretized_features(
        self,
        events: List[th.Tensor],
        times: List[th.Tensor],
        verbose: Union[bool, int] = False,
    ):
        """
        Build discretized features.

        Args:
            events (list): The read codes.
            times (list): Times of each event.
            verbose (bool, int): Verbosity level. Default to ``False``

        Returns:
            Preprocessed features.
        """
        if verbose:
            events = tqdm(events, total=len(events), leave=False)

        if self.fasttext is not None:
            codes = th.zeros(
                (
                    len(events),
                    int(self.maximum_time / self.granularity) + 1,
                    self.fasttext.emb_dim,
                )
            ).type(th.float32)
            for i, r in enumerate(events):
                f = th.stack(
                    [
                        self.fasttext.transform(th.Tensor(x).long()).type(
                            th.float32
                        )
                        for x in r
                    ]
                )
                codes[i].index_add_(0, times[i].long(), f)
        else:
            raise NotImplementedError(
                "When using discretised data, "
                "categorical features must be encoded."
            )

        return codes, None
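
    # Binning sketch for build_discretized_features: with the default
    # granularity of 1, the year-valued event times are used directly as row
    # indices, so a visit at age 42.7 is embedded by Fasttext and accumulated
    # into row 42 of the patient's [maximum_time + 1, emb_dim] grid by
    # index_add_; embeddings of visits falling in the same year sum up.
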
    def build_labels(
        self,
        events: list,
        times: list,
    ) -> (list, list):
        """
        Build labels.

        Args:
            events (list): List of events.
            times (list): List of times.

        Returns:
            (list, list): Two lists of labels and record masks
        """
        labels, end_of_records = create_labels(
            events=events,
            event_times=times,
            labels=self.label,
            time_to_task=self.time_to_task,
            std_time_to_task=self.std_time_to_task,
            maximum_time=self.maximum_time,
            seed=self.seed,
        )

        mask_records = [
            (x < y).type(th.float32) for x, y in zip(times, end_of_records)
        ]

        return [x for x in labels], mask_records
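
    # Mask sketch for build_labels: if a patient's event times are
    # [10., 20., 30.] and create_labels draws an end of record at 25., the
    # mask (x < y) is [1., 1., 0.], so events at or after the randomised
    # cutoff are dropped from the inputs and cannot leak the label.
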
    def build_discretized_labels(
        self,
        events: list,
        times: list,
    ) -> (th.Tensor, th.Tensor):
        """
        Build discretized labels.

        Args:
            events (list): List of events.
            times (list): List of times.

        Returns:
            (th.Tensor, th.Tensor): Two tensors of labels and record masks
        """
        labels_, end_of_records = create_labels(
            events=events,
            event_times=times,
            labels=self.label,
            time_to_task=self.time_to_task,
            std_time_to_task=self.std_time_to_task,
            maximum_time=self.maximum_time,
            seed=self.seed,
        )

        idx = (
            (end_of_records / self.granularity)
            .long()
            .clamp(max=self.maximum_time / self.granularity)
        )
        mask_records = th.zeros(
            (len(events), int(self.maximum_time / self.granularity) + 1, 1)
        ).type(th.float32)
        for i in range(len(mask_records)):
            if idx[i] >= 0:
                mask_records[i][: idx[i]] = th.ones((idx[i], 1))

        return labels_, mask_records
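

# Usage sketch (assumes tint's DataModule exposes the pre-processing hooks
# shown above and that BigQuery credentials are configured so that
# download() can query the UK Biobank tables):
#
#   datamodule = BioBank(discretised=False, batch_size=32)
#   datamodule.prepare_data()  # runs download() once if the csv is missing
#   data = datamodule.preprocess(split="train", verbose=True)
#   x, y = datamodule.collate_fn(
#       [{k: v[i] for k, v in data.items()} for i in range(4)]
#   )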