# Source code for tint.datasets.mimic3

import numpy as np
import os
import pandas as pd
import pickle as pkl
import random
import torch as th
import warnings

from datetime import timedelta
from getpass import getpass
from sklearn.impute import SimpleImputer

from tint.utils import get_progress_bars

from .dataset import DataModule

try:
    import psycopg2
except ImportError:
    psycopg2 = None


# Silence pandas / sklearn chained-assignment and deprecation warnings that
# the preprocessing below would otherwise emit in bulk.
warnings.filterwarnings("ignore")

# Directory containing this file; used to build the default data_dir.
file_dir = os.path.dirname(__file__)

# Vital-sign channel names, in the order they are written into the feature
# array by Mimic3.download (channels len(lab_IDs) .. len(lab_IDs)+7).
vital_IDs = [
    "HeartRate",
    "SysBP",
    "DiasBP",
    "MeanBP",
    "RespRate",
    "SpO2",
    "Glucose",
    "Temp",
]
# Lab measurement channel names matched against the SQL `label` column.
lab_IDs = [
    "ANION GAP",
    "ALBUMIN",
    "BICARBONATE",
    "BILIRUBIN",
    "CREATININE",
    "CHLORIDE",
    "GLUCOSE",
    "HEMATOCRIT",
    # NOTE(review): missing comma — Python concatenates the next two strings
    # into the single bogus label "HEMOGLOBINLACTATE", so HEMOGLOBIN and
    # LACTATE labs are silently skipped in download() (the lookup fails and
    # the bare `except` swallows it). Do NOT fix this in isolation: the list
    # length (19) is baked into Mimic3.preprocess, whose hard-coded channel
    # indices (22, :20, 23:) select MeanBP as the blood-pressure label only
    # for a 19-entry lab list. Fixing the comma requires a coordinated change
    # to those indices (and re-generating the cached pickles).
    "HEMOGLOBIN" "LACTATE",
    "MAGNESIUM",
    "PHOSPHATE",
    "PLATELET",
    "POTASSIUM",
    "PTT",
    "INR",
    "PT",
    "SODIUM",
    "BUN",
    "WBC",
]
# Ethnicity buckets used by ethnicity_encoder; download() maps the raw
# MIMIC ethnicity strings onto these five values.
eth_list = ["white", "black", "hispanic", "asian", "other"]

# Numerical-stability constant added to the std during normalisation.
EPS = 1e-5


class Mimic3(DataModule):
    r"""
    MIMIC-III dataset.

    Download is set up according to this repository:
    https://github.com/sanatonek/time_series_explainability.

    .. warning::
        Using this dataset requires to have the MIMIC III data running on a
        local server. Please see
        https://mimic.mit.edu/docs/gettingstarted/local/install-mimic-locally-ubuntu/
        for more information.

    Args:
        task (str): Name of the task to perform. Either ``'mortality'``
            or ``'blood_pressure'``. Default to ``'mortality'``
        data_dir (str): Where to download files.
        batch_size (int): Batch size. Default to 32
        n_folds (int): Number of folds for cross validation. If ``None``,
            the dataset is only split once between train and val using
            ``prop_val``. Default to ``None``
        fold (int): Index of the fold to use with cross-validation.
            Ignored if n_folds is None. Default to ``None``
        prop_val (float): Proportion of validation. Default to .2
        num_workers (int): Number of workers for the loaders. Default to 0
        seed (int): For the random split. Default to 42

    References:
        #. https://physionet.org/content/mimiciii/1.4/
        #. https://github.com/sanatonek/time_series_explainability/blob/master/data_generator/icu_mortality.py

    Examples:
        >>> from tint.datasets import Mimic3
        <BLANKLINE>
        >>> mimci3 = Mimic3()
        >>> mimci3.download(sqluser="your_username", split="train")
        >>> x_train = mimci3.preprocess(split="train")["x"]
        >>> y_train = mimci3.preprocess(split="train")["y"]
    """

    def __init__(
        self,
        task: str = "mortality",
        data_dir: str = os.path.join(
            os.path.split(file_dir)[0],
            "data",
            "mimic3",
        ),
        batch_size: int = 32,
        prop_val: float = 0.2,
        n_folds: int = None,
        fold: int = None,
        num_workers: int = 0,
        seed: int = 42,
    ):
        super().__init__(
            data_dir=data_dir,
            batch_size=batch_size,
            prop_val=prop_val,
            n_folds=n_folds,
            fold=fold,
            num_workers=num_workers,
            seed=seed,
        )
        assert task in [
            "mortality",
            "blood_pressure",
        ], f"task must be either mortality or blood_pressure, got {task}."
        self.task = task

        # Normalisation statistics, computed on the train split by
        # preprocess("train") and reused for the test split.
        self._mean = None
        self._std = None

    def download(
        self,
        sqluser: str = "mimicuser",
        prop_train: float = 0.8,
        split: str = "train",
    ):
        """
        Query a local MIMIC-III Postgres server, build 48-hour
        lab + vital + demographic feature arrays, and pickle the
        train/test splits into ``self.data_dir``.

        Args:
            sqluser (str): Postgres user name. Default to ``'mimicuser'``
            prop_train (float): Proportion of patients assigned to the
                train split (contiguous, not shuffled). Default to 0.8
            split (str): Unused; kept for interface compatibility.
        """
        assert psycopg2 is not None, "You need to install psycopg2."
        random.seed(22891)
        sqlpass = getpass(prompt="sqlpass: ")

        # Create a database connection and connect to the local postgres
        # version of mimic.
        dbname = "mimic"
        schema_name = "mimiciii"
        con = psycopg2.connect(
            dbname=dbname,
            user=sqluser,
            host="127.0.0.1",
            password=sqlpass,
        )
        cur = con.cursor()
        cur.execute("SET search_path to " + schema_name)

        # ======== get the icu details
        # This query extracts: unique ids for the admission, patient and icu
        # stay; gender; diagnosis; age; ethnicity; admission type; first
        # hospital stay; first icu stay; mortality within a week.
        denquery = """
        --ie is the icustays table
        --adm is the admissions table
        SELECT ie.subject_id, ie.hadm_id, ie.icustay_id
        , pat.gender
        , adm.admittime, adm.dischtime, adm.diagnosis
        , ROUND((CAST(adm.dischtime AS DATE) - CAST(adm.admittime AS DATE)), 4)
          AS los_hospital
        , ROUND((CAST(adm.admittime AS DATE) - CAST(pat.dob AS DATE)) / 365, 4)
          AS age
        , adm.ethnicity, adm.ADMISSION_TYPE
        --, adm.hospital_expire_flag
        , CASE when adm.deathtime between ie.intime and ie.outtime
               THEN 1 ELSE 0 END AS mort_icu
        , DENSE_RANK() OVER (PARTITION BY adm.subject_id ORDER BY adm.admittime)
          AS hospstay_seq
        , CASE WHEN DENSE_RANK() OVER
               (PARTITION BY adm.subject_id ORDER BY adm.admittime) = 1
               THEN 1 ELSE 0 END AS first_hosp_stay
        -- icu level factors
        , ie.intime, ie.outtime
        , ie.FIRST_CAREUNIT
        , ROUND((CAST(ie.outtime AS DATE) - CAST(ie.intime AS DATE)), 4)
          AS los_icu
        , DENSE_RANK() OVER (PARTITION BY ie.hadm_id ORDER BY ie.intime)
          AS icustay_seq
        , CASE WHEN adm.deathtime between ie.intime
               and ie.intime + interval '168' hour
               THEN 1 ELSE 0 END AS mort_week
        -- first ICU stay *for the current hospitalization*
        , CASE WHEN DENSE_RANK() OVER
               (PARTITION BY ie.hadm_id ORDER BY ie.intime) = 1
               THEN 1 ELSE 0 END AS first_icu_stay
        FROM icustays ie
        INNER JOIN admissions adm ON ie.hadm_id = adm.hadm_id
        INNER JOIN patients pat ON ie.subject_id = pat.subject_id
        WHERE adm.has_chartevents_data = 1
        ORDER BY ie.subject_id, adm.admittime, ie.intime;
        """
        den = pd.read_sql_query(denquery, con)

        # Drop patients with less than 48 hour stays, and implausible ages
        # (MIMIC shifts ages > 89 to ~300).
        den["los_icu_hr"] = (den.outtime - den.intime).astype("timedelta64[h]")
        den = den[(den.los_icu_hr >= 48)]
        den = den[(den.age < 300)]
        # BUG FIX: positional `axis` argument to drop() was removed in
        # pandas 2.x — use the keyword.
        den.drop("los_icu_hr", axis=1, inplace=True)

        # Clean up: flag adult ICUs, binarise gender, bucket ethnicity.
        den["adult_icu"] = np.where(
            den["first_careunit"].isin(["PICU", "NICU"]), 0, 1
        )
        den["gender"] = np.where(den["gender"] == "M", 1, 0)
        den.ethnicity = den.ethnicity.str.lower()
        den.ethnicity.loc[(den.ethnicity.str.contains("^white"))] = "white"
        den.ethnicity.loc[(den.ethnicity.str.contains("^black"))] = "black"
        den.ethnicity.loc[
            (den.ethnicity.str.contains("^hisp"))
            | (den.ethnicity.str.contains("^latin"))
        ] = "hispanic"
        den.ethnicity.loc[(den.ethnicity.str.contains("^asia"))] = "asian"
        den.ethnicity.loc[
            ~(
                den.ethnicity.str.contains(
                    "|".join(["white", "black", "hispanic", "asian"])
                )
            )
        ] = "other"
        den.drop(
            [
                "hospstay_seq",
                "los_icu",
                "icustay_seq",
                "admittime",
                "dischtime",
                "los_hospital",
                "intime",
                "outtime",
                "first_careunit",
            ],
            axis=1,
            inplace=True,
        )

        # ========= 48 hour vitals query
        # The ranges below are the normal ranges, useful to clean up the data.
        vitquery = """
        -- This query pivots the vital signs for the first 48 hours of a
        -- patient's stay.
        -- Vital signs include heart rate, blood pressure, respiration rate,
        -- and temperature.
        -- DROP MATERIALIZED VIEW IF EXISTS vitalsfirstday CASCADE;
        -- create materialized view vitalsfirstday as
        SELECT pvt.subject_id, pvt.hadm_id, pvt.icustay_id, pvt.VitalID,
               pvt.VitalValue, pvt.VitalChartTime
        FROM (
          select ie.subject_id, ie.hadm_id, ie.icustay_id,
                 ce.charttime as VitalChartTime
          , case
            when itemid in (211,220045) and valuenum > 0 and valuenum < 300
              then 'HeartRate'
            when itemid in (51,442,455,6701,220179,220050)
              and valuenum > 0 and valuenum < 400 then 'SysBP'
            when itemid in (8368,8440,8441,8555,220180,220051)
              and valuenum > 0 and valuenum < 300 then 'DiasBP'
            when itemid in (456,52,6702,443,220052,220181,225312)
              and valuenum > 0 and valuenum < 300 then 'MeanBP'
            when itemid in (615,618,220210,224690)
              and valuenum > 0 and valuenum < 70 then 'RespRate'
            when itemid in (223761,678)
              and valuenum > 70 and valuenum < 120
              then 'Temp' -- converted to degC in valuenum call
            when itemid in (223762,676)
              and valuenum > 10 and valuenum < 50 then 'Temp'
            when itemid in (646,220277)
              and valuenum > 0 and valuenum <= 100 then 'SpO2'
            when itemid in (807,811,1529,3745,3744,225664,220621,226537)
              and valuenum > 0 then 'Glucose'
            else null end as VitalID
          , case
            when itemid in (211,220045) and valuenum > 0 and valuenum < 300
              then valuenum -- HeartRate
            when itemid in (51,442,455,6701,220179,220050)
              and valuenum > 0 and valuenum < 400 then valuenum -- SysBP
            when itemid in (8368,8440,8441,8555,220180,220051)
              and valuenum > 0 and valuenum < 300 then valuenum -- DiasBP
            when itemid in (456,52,6702,443,220052,220181,225312)
              and valuenum > 0 and valuenum < 300 then valuenum -- MeanBP
            when itemid in (615,618,220210,224690)
              and valuenum > 0 and valuenum < 70 then valuenum -- RespRate
            when itemid in (223761,678)
              and valuenum > 70 and valuenum < 120
              then (valuenum-32)/1.8 -- TempF, convert to degC
            when itemid in (223762,676)
              and valuenum > 10 and valuenum < 50 then valuenum -- TempC
            when itemid in (646,220277)
              and valuenum > 0 and valuenum <= 100 then valuenum -- SpO2
            when itemid in (807,811,1529,3745,3744,225664,220621,226537)
              and valuenum > 0 then valuenum -- Glucose
            else null end as VitalValue
          from icustays ie
          left join chartevents ce
            on ie.subject_id = ce.subject_id
            and ie.hadm_id = ce.hadm_id
            and ie.icustay_id = ce.icustay_id
            and ce.charttime between ie.intime
                and ie.intime + interval '48' hour
            -- exclude rows marked as error
            and ce.error IS DISTINCT FROM 1
          where ce.itemid in (
            -- HEART RATE
            211, --"Heart Rate"
            220045, --"Heart Rate"
            -- Systolic/diastolic
            51, -- Arterial BP [Systolic]
            442, -- Manual BP [Systolic]
            455, -- NBP [Systolic]
            6701, -- Arterial BP #2 [Systolic]
            220179, -- Non Invasive Blood Pressure systolic
            220050, -- Arterial Blood Pressure systolic
            8368, -- Arterial BP [Diastolic]
            8440, -- Manual BP [Diastolic]
            8441, -- NBP [Diastolic]
            8555, -- Arterial BP #2 [Diastolic]
            220180, -- Non Invasive Blood Pressure diastolic
            220051, -- Arterial Blood Pressure diastolic
            -- MEAN ARTERIAL PRESSURE
            456, --"NBP Mean"
            52, --"Arterial BP Mean"
            6702, -- Arterial BP Mean #2
            443, -- Manual BP Mean(calc)
            220052, --"Arterial Blood Pressure mean"
            220181, --"Non Invasive Blood Pressure mean"
            225312, --"ART BP mean"
            -- RESPIRATORY RATE
            618,-- Respiratory Rate
            615,-- Resp Rate (Total)
            220210,-- Respiratory Rate
            224690, -- Respiratory Rate (Total)
            -- SPO2, peripheral
            646, 220277,
            -- GLUCOSE, both lab and fingerstick
            807,-- Fingerstick Glucose
            811,-- Glucose (70-105)
            1529,-- Glucose
            3745,-- BloodGlucose
            3744,-- Blood Glucose
            225664,-- Glucose finger stick
            220621,-- Glucose (serum)
            226537,-- Glucose (whole blood)
            -- TEMPERATURE
            223762, -- "Temperature Celsius"
            676, -- "Temperature C"
            223761, -- "Temperature Fahrenheit"
            678 -- "Temperature F"
          )
        ) pvt
        where VitalID is not null
        order by pvt.subject_id, pvt.hadm_id, pvt.icustay_id, pvt.VitalID,
                 pvt.VitalChartTime;
        """
        vit48 = pd.read_sql_query(vitquery, con)

        # =============== 48 hour labs query
        # This query extracts the lab events in the first 48 hours.
        labquery = """
        WITH pvt AS (
          --- ie is the icu stay
          --- ad is the admissions table
          --- le is the lab events table
          SELECT ie.subject_id, ie.hadm_id, ie.icustay_id,
                 le.charttime as LabChartTime
          , CASE
            when le.itemid = 50868 then 'ANION GAP'
            when le.itemid = 50862 then 'ALBUMIN'
            when le.itemid = 50882 then 'BICARBONATE'
            when le.itemid = 50885 then 'BILIRUBIN'
            when le.itemid = 50912 then 'CREATININE'
            when le.itemid = 50806 then 'CHLORIDE'
            when le.itemid = 50902 then 'CHLORIDE'
            when le.itemid = 50809 then 'GLUCOSE'
            when le.itemid = 50931 then 'GLUCOSE'
            when le.itemid = 50810 then 'HEMATOCRIT'
            when le.itemid = 51221 then 'HEMATOCRIT'
            when le.itemid = 50811 then 'HEMOGLOBIN'
            when le.itemid = 51222 then 'HEMOGLOBIN'
            when le.itemid = 50813 then 'LACTATE'
            when le.itemid = 50960 then 'MAGNESIUM'
            when le.itemid = 50970 then 'PHOSPHATE'
            when le.itemid = 51265 then 'PLATELET'
            when le.itemid = 50822 then 'POTASSIUM'
            when le.itemid = 50971 then 'POTASSIUM'
            when le.itemid = 51275 then 'PTT'
            when le.itemid = 51237 then 'INR'
            when le.itemid = 51274 then 'PT'
            when le.itemid = 50824 then 'SODIUM'
            when le.itemid = 50983 then 'SODIUM'
            when le.itemid = 51006 then 'BUN'
            when le.itemid = 51300 then 'WBC'
            when le.itemid = 51301 then 'WBC'
            ELSE null END AS label
          , -- add in some sanity checks on the values
            CASE
            when le.itemid = 50862 and le.valuenum > 10
              then null -- g/dL 'ALBUMIN'
            when le.itemid = 50868 and le.valuenum > 10000
              then null -- mEq/L 'ANION GAP'
            when le.itemid = 50882 and le.valuenum > 10000
              then null -- mEq/L 'BICARBONATE'
            when le.itemid = 50885 and le.valuenum > 150
              then null -- mg/dL 'BILIRUBIN'
            when le.itemid = 50806 and le.valuenum > 10000
              then null -- mEq/L 'CHLORIDE'
            when le.itemid = 50902 and le.valuenum > 10000
              then null -- mEq/L 'CHLORIDE'
            when le.itemid = 50912 and le.valuenum > 150
              then null -- mg/dL 'CREATININE'
            when le.itemid = 50809 and le.valuenum > 10000
              then null -- mg/dL 'GLUCOSE'
            when le.itemid = 50931 and le.valuenum > 10000
              then null -- mg/dL 'GLUCOSE'
            when le.itemid = 50810 and le.valuenum > 100
              then null -- % 'HEMATOCRIT'
            when le.itemid = 51221 and le.valuenum > 100
              then null -- % 'HEMATOCRIT'
            when le.itemid = 50811 and le.valuenum > 50
              then null -- g/dL 'HEMOGLOBIN'
            when le.itemid = 51222 and le.valuenum > 50
              then null -- g/dL 'HEMOGLOBIN'
            when le.itemid = 50813 and le.valuenum > 50
              then null -- mmol/L 'LACTATE'
            when le.itemid = 50960 and le.valuenum > 60
              then null -- mmol/L 'MAGNESIUM'
            when le.itemid = 50970 and le.valuenum > 60
              then null -- mg/dL 'PHOSPHATE'
            when le.itemid = 51265 and le.valuenum > 10000
              then null -- K/uL 'PLATELET'
            when le.itemid = 50822 and le.valuenum > 30
              then null -- mEq/L 'POTASSIUM'
            when le.itemid = 50971 and le.valuenum > 30
              then null -- mEq/L 'POTASSIUM'
            when le.itemid = 51275 and le.valuenum > 150
              then null -- sec 'PTT'
            when le.itemid = 51237 and le.valuenum > 50
              then null -- 'INR'
            when le.itemid = 51274 and le.valuenum > 150
              then null -- sec 'PT'
            when le.itemid = 50824 and le.valuenum > 200
              then null -- mEq/L == mmol/L 'SODIUM'
            when le.itemid = 50983 and le.valuenum > 200
              then null -- mEq/L == mmol/L 'SODIUM'
            when le.itemid = 51006 and le.valuenum > 300
              then null -- 'BUN'
            when le.itemid = 51300 and le.valuenum > 1000
              then null -- 'WBC'
            when le.itemid = 51301 and le.valuenum > 1000
              then null -- 'WBC'
            ELSE le.valuenum END AS LabValue
          FROM icustays ie
          LEFT JOIN labevents le
            ON le.subject_id = ie.subject_id
            AND le.hadm_id = ie.hadm_id
            AND le.charttime between (ie.intime)
                AND (ie.intime + interval '48' hour)
            AND le.itemid IN (
              -- comment is: LABEL | CATEGORY | FLUID | NUMBER OF ROWS IN LABEVENTS
              50868, -- ANION GAP | CHEMISTRY | BLOOD | 769895
              50862, -- ALBUMIN | CHEMISTRY | BLOOD | 146697
              50882, -- BICARBONATE | CHEMISTRY | BLOOD | 780733
              50885, -- BILIRUBIN, TOTAL | CHEMISTRY | BLOOD | 238277
              50912, -- CREATININE | CHEMISTRY | BLOOD | 797476
              50902, -- CHLORIDE | CHEMISTRY | BLOOD | 795568
              50806, -- CHLORIDE, WHOLE BLOOD | BLOOD GAS | BLOOD | 48187
              50931, -- GLUCOSE | CHEMISTRY | BLOOD | 748981
              50809, -- GLUCOSE | BLOOD GAS | BLOOD | 196734
              51221, -- HEMATOCRIT | HEMATOLOGY | BLOOD | 881846
              50810, -- HEMATOCRIT, CALCULATED | BLOOD GAS | BLOOD | 89715
              51222, -- HEMOGLOBIN | HEMATOLOGY | BLOOD | 752523
              50811, -- HEMOGLOBIN | BLOOD GAS | BLOOD | 89712
              50813, -- LACTATE | BLOOD GAS | BLOOD | 187124
              50960, -- MAGNESIUM | CHEMISTRY | BLOOD | 664191
              50970, -- PHOSPHATE | CHEMISTRY | BLOOD | 590524
              51265, -- PLATELET COUNT | HEMATOLOGY | BLOOD | 778444
              50971, -- POTASSIUM | CHEMISTRY | BLOOD | 845825
              50822, -- POTASSIUM, WHOLE BLOOD | BLOOD GAS | BLOOD | 192946
              51275, -- PTT | HEMATOLOGY | BLOOD | 474937
              51237, -- INR(PT) | HEMATOLOGY | BLOOD | 471183
              51274, -- PT | HEMATOLOGY | BLOOD | 469090
              50983, -- SODIUM | CHEMISTRY | BLOOD | 808489
              50824, -- SODIUM, WHOLE BLOOD | BLOOD GAS | BLOOD | 71503
              51006, -- UREA NITROGEN | CHEMISTRY | BLOOD | 791925
              51301, -- WHITE BLOOD CELLS | HEMATOLOGY | BLOOD | 753301
              51300 -- WBC COUNT | HEMATOLOGY | BLOOD | 2371
            )
            AND le.valuenum IS NOT null
            AND le.valuenum > 0
            -- lab values cannot be 0 and cannot be negative
          LEFT JOIN admissions ad
            ON ie.subject_id = ad.subject_id
            AND ie.hadm_id = ad.hadm_id
        )
        SELECT pvt.subject_id, pvt.hadm_id, pvt.icustay_id, pvt.LabChartTime,
               pvt.label, pvt.LabValue
        From pvt
        where pvt.label is not NULL
        ORDER BY pvt.subject_id, pvt.hadm_id, pvt.icustay_id, pvt.label,
                 pvt.LabChartTime;
        """
        lab48 = pd.read_sql_query(labquery, con)

        # ===== combine all variables
        mort_vital = den.merge(
            vit48, how="left", on=["subject_id", "hadm_id", "icustay_id"]
        )
        mort_lab = den.merge(
            lab48, how="left", on=["subject_id", "hadm_id", "icustay_id"]
        )

        # Create means by age group and gender.
        age_bins = [-1, 5, 10, 15, 20, 25, 40, 60, 80, 200]
        age_labels = [
            "l5",
            "5_10",
            "10_15",
            "15_20",
            "20_25",
            "25_40",
            "40_60",
            "60_80",
            "80p",
        ]
        mort_vital["age_group"] = pd.cut(
            mort_vital["age"], age_bins, labels=age_labels
        )
        mort_lab["age_group"] = pd.cut(
            mort_lab["age"], age_bins, labels=age_labels
        )

        # Keep adult ICU stays only.
        adult_vital = mort_vital[(mort_vital.adult_icu == 1)]
        adult_lab = mort_lab[(mort_lab.adult_icu == 1)]
        adult_vital.drop(columns=["adult_icu"], inplace=True)
        adult_lab.drop(columns=["adult_icu"], inplace=True)

        # Save files. Make sure the target directory exists first.
        os.makedirs(self.data_dir, exist_ok=True)
        adult_vital.to_csv(
            os.path.join(self.data_dir, "adult_icu_vital.gz"),
            compression="gzip",
            index=False,
        )
        # NOTE(review): this saves mort_lab (all stays) rather than the
        # adult-only adult_lab, and the rest of the pipeline below keeps
        # using mort_lab; kept as-is for parity with the upstream repo —
        # confirm whether adult_lab was intended.
        mort_lab.to_csv(
            os.path.join(self.data_dir, "adult_icu_lab.gz"),
            compression="gzip",
            index=False,
        )

        # Drop NAs
        adult_vital = adult_vital.dropna(subset=["vitalid"])
        mort_lab = mort_lab.dropna(subset=["label"])

        # Get unique ids
        icu_ids = adult_vital.icustay_id.unique()

        # Create arrays. The 12 vital-side channels are the 8 vitals plus
        # 4 demographic channels (gender, age, ethnicity, first_icu_stay).
        x = np.zeros((len(icu_ids), 12, 48))
        x_lab = np.zeros((len(icu_ids), len(lab_IDs), 48))
        x_impute = np.zeros((len(icu_ids), 12, 48))
        y = np.zeros((len(icu_ids),))
        imp_mean = SimpleImputer(strategy="mean")
        missing_ids = []
        missing_map = np.zeros((len(icu_ids), 12))
        missing_map_lab = np.zeros((len(icu_ids), len(lab_IDs)))
        nan_map = np.zeros((len(icu_ids), len(lab_IDs) + 12))

        # Populate data
        pbar = get_progress_bars()(enumerate(icu_ids), total=len(icu_ids))
        for i, icu_id in pbar:
            patient_data = adult_vital.loc[adult_vital["icustay_id"] == icu_id]
            patient_data["vitalcharttime"] = patient_data[
                "vitalcharttime"
            ].astype("datetime64[s]")
            patient_lab_data = mort_lab.loc[mort_lab["icustay_id"] == icu_id]
            patient_lab_data["labcharttime"] = patient_lab_data[
                "labcharttime"
            ].astype("datetime64[s]")

            admit_time = patient_data["vitalcharttime"].min()
            n_missing_vitals = 0

            # Extract demographics and repeat them over time.
            x[i, -4, :] = int(patient_data["gender"].iloc[0])
            x[i, -3, :] = int(patient_data["age"].iloc[0])
            x[i, -2, :] = ethnicity_encoder(
                patient_data["ethnicity"].iloc[0], patient_data
            )
            x[i, -1, :] = int(patient_data["first_icu_stay"].iloc[0])
            y[i] = int(patient_data["mort_icu"].iloc[0])

            # Extract vital measurement information.
            vitals = patient_data.vitalid.unique()
            for vital in vitals:
                try:
                    vital_IDs.index(vital)
                    signal = patient_data[patient_data["vitalid"] == vital]
                    quantized_signal, _ = quantize_signal(
                        signal,
                        start=admit_time,
                        step_size=1,
                        n_steps=48,
                        value_column="vitalvalue",
                        charttime_column="vitalcharttime",
                    )
                    nan_arr, nan_count = check_nan(quantized_signal)
                    x[i, vital_IDs.index(vital), :] = np.array(
                        quantized_signal
                    )
                    nan_map[
                        i, len(lab_IDs) + vital_IDs.index(vital)
                    ] = nan_count
                    if nan_count == 48:
                        # BUG FIX: was `n_missing_vitals = +1` (assignment),
                        # which could never count more than one missing vital.
                        n_missing_vitals += 1
                        missing_map[i, vital_IDs.index(vital)] = 1
                    else:
                        x_impute[i, :, :] = imp_mean.fit_transform(
                            x[i, :, :].T
                        ).T
                except:  # noqa: E722
                    # Unknown vital labels are skipped on purpose.
                    pass

            # Extract lab measurement informations.
            labs = patient_lab_data.label.unique()
            for lab in labs:
                try:
                    lab_IDs.index(lab)
                    lab_measures = patient_lab_data[
                        patient_lab_data["label"] == lab
                    ]
                    quantized_lab, quantized_measures = quantize_signal(
                        lab_measures,
                        start=admit_time,
                        step_size=1,
                        n_steps=48,
                        value_column="labvalue",
                        charttime_column="labcharttime",
                    )
                    nan_arr, nan_count = check_nan(quantized_lab)
                    x_lab[i, lab_IDs.index(lab), :] = np.array(quantized_lab)
                    nan_map[i, lab_IDs.index(lab)] = nan_count
                    if nan_count == 48:
                        missing_map_lab[i, lab_IDs.index(lab)] = 1
                except:  # noqa: E722
                    # Unknown lab labels are skipped on purpose.
                    pass

            # Remove a patient that is missing a measurement for the
            # entire 48 hours.
            if n_missing_vitals > 0:
                missing_ids.append(i)

        # Record statistics of the dataset, remove missing samples and save
        # the signals.
        with open(os.path.join(self.data_dir, "stats.txt"), "a") as f:
            f.write(
                "\n ******************* Before removing missing *********************"
            )
            f.write(
                "\n Number of patients: "
                + str(len(y))
                + "\n Number of patients who died within their stay: "
                + str(np.count_nonzero(y))
            )
            f.write("\nMissingness report for Vital signals")
            for i, vital in enumerate(vital_IDs):
                f.write(
                    "\nMissingness for %s: %.2f"
                    % (vital, np.count_nonzero(missing_map[:, i]) / len(icu_ids))
                )
            f.write("\n")
            # BUG FIX: this header previously said "Vital signals" again.
            f.write("\nMissingness report for Lab signals")
            for i, lab in enumerate(lab_IDs):
                f.write(
                    "\nMissingness for %s: %.2f"
                    % (
                        lab,
                        np.count_nonzero(missing_map_lab[:, i]) / len(icu_ids),
                    )
                )
            f.write("\n")

            x_lab = np.delete(x_lab, missing_ids, axis=0)
            x_impute = np.delete(x_impute, missing_ids, axis=0)
            y = np.delete(y, missing_ids, axis=0)
            nan_map = np.delete(nan_map, missing_ids, axis=0)
            x_lab_impute = impute_lab(x_lab)
            missing_map = np.delete(missing_map, missing_ids, axis=0)
            missing_map_lab = np.delete(missing_map_lab, missing_ids, axis=0)
            all_data = np.concatenate((x_lab_impute, x_impute), axis=1)

            f.write(
                "\n ******************* After removing missing *********************"
            )
            f.write(
                "\n Final number of patients: "
                + str(len(y))
                + "\n Number of patients who died within their stay: "
                + str(np.count_nonzero(y))
            )
            f.write("\nMissingness report for Vital signals")
            for i, vital in enumerate(vital_IDs):
                # NOTE(review): denominators still use the pre-removal
                # len(icu_ids), as in the upstream code.
                f.write(
                    "\nMissingness for %s: %.2f"
                    % (vital, np.count_nonzero(missing_map[:, i]) / len(icu_ids))
                )
            f.write("\n")
            # BUG FIX: this header previously said "Vital signals" again.
            f.write("\nMissingness report for Lab signals")
            for i, lab in enumerate(lab_IDs):
                f.write(
                    "\nMissingness for %s: %.2f"
                    % (
                        lab,
                        np.count_nonzero(missing_map_lab[:, i]) / len(icu_ids),
                    )
                )
            f.write("\n")

        samples = [
            (all_data[i, :, :], y[i], nan_map[i, :]) for i in range(len(y))
        ]

        # Split train and test (contiguous split, not shuffled).
        train_size = int(len(samples) * prop_train)
        train_samples = samples[:train_size]
        test_samples = samples[train_size:]

        # Save preprocessed data
        with open(
            os.path.join(
                self.data_dir, "train_patient_vital_preprocessed.pkl"
            ),
            "wb",
        ) as f:
            pkl.dump(train_samples, f)
        with open(
            os.path.join(self.data_dir, "test_patient_vital_preprocessed.pkl"),
            "wb",
        ) as f:
            pkl.dump(test_samples, f)

    def prepare_data(self):
        """Download the data if the preprocessed pickles are missing."""
        # BUG FIX: the original second clause was `not os.path.join(...)`,
        # which is always False (a non-empty path string is truthy), so a
        # missing test pickle never triggered the download.
        if not os.path.exists(
            os.path.join(self.data_dir, "train_patient_vital_preprocessed.pkl")
        ) or not os.path.exists(
            os.path.join(self.data_dir, "test_patient_vital_preprocessed.pkl")
        ):
            sqluser = input("sqluser: ")
            self.download(sqluser=sqluser)

    def preprocess(self, split: str = "train") -> dict:
        """
        Load a preprocessed split and return normalised features and labels.

        Args:
            split (str): ``'train'`` or ``'test'``. ``'train'`` must be
                processed first so that normalisation statistics exist.

        Returns:
            dict: ``{'x': features, 'y': labels}``.
        """
        # Load data
        file = os.path.join(self.data_dir, f"{split}_")
        with open(file + "patient_vital_preprocessed.pkl", "rb") as fp:
            data = pkl.load(fp)

        # features: (patients, 48 time steps, channels); channel layout is
        # [labs..., vitals..., demographics...].
        features = th.Tensor([x for (x, y, z) in data]).transpose(1, 2)
        if self.task == "mortality":
            labels = th.Tensor([y for (x, y, z) in data])
        else:
            # Blood-pressure task: the label is the MeanBP vital channel.
            # Derived instead of hard-coded (22, :20, 23:) so the layout
            # stays correct if the channel lists change; with the current
            # constants these evaluate to the same indices. SysBP and
            # DiasBP are dropped from the inputs along with the label.
            mean_bp = len(lab_IDs) + vital_IDs.index("MeanBP")
            labels = features[..., mean_bp]
            features = th.cat(
                [
                    features[..., : len(lab_IDs) + 1],
                    features[..., mean_bp + 1 :],
                ],
                dim=-1,
            )

        # Compute mean and std on the train split only.
        if split == "train":
            self._mean = features.mean(dim=(0, 1), keepdim=True)
            self._std = features.std(dim=(0, 1), keepdim=True)
        else:
            assert split == "test", "split must be train or test"
            assert (
                self._mean is not None
            ), "You must call preprocess('train') first"

        # Normalise
        features = (features - self._mean) / (self._std + EPS)

        return {
            "x": features.float(),
            "y": labels.long()
            if self.task == "mortality"
            else labels.float(),
        }
def quantize_signal(
    signal, start, step_size, n_steps, value_column, charttime_column
):
    """Bucket an irregular time series into ``n_steps`` windows of
    ``step_size`` hours starting at ``start``.

    Args:
        signal: DataFrame holding one measurement per row.
        start: Window origin (datetime-like).
        step_size: Window width in hours.
        n_steps: Number of windows.
        value_column (str): Column with the measurement values.
        charttime_column (str): Column with the measurement timestamps.

    Returns:
        tuple: (list of per-window means — NaN for empty windows —
        and an array of per-window measurement counts).
    """
    quantized_signal = []
    quantized_counts = np.zeros((n_steps,))
    s = start
    u = start + timedelta(hours=step_size)
    for i in range(n_steps):
        # NOTE: both bounds are strict, so a measurement exactly on a
        # window edge is excluded (kept as in the upstream code).
        signal_window = signal[value_column][
            (signal[charttime_column] > s) & (signal[charttime_column] < u)
        ]
        quantized_signal.append(signal_window.mean())
        quantized_counts[i] = len(signal_window)
        s = u
        u = s + timedelta(hours=step_size)
    return quantized_signal, quantized_counts


def check_nan(a):
    """Return a 0/1 NaN mask of ``a`` and the number of NaN entries."""
    a = np.array(a)
    nan_arr = np.isnan(a).astype(int)
    nan_count = np.count_nonzero(nan_arr)
    return nan_arr, nan_count


def forward_impute(x, nan_arr):
    """Forward-fill NaNs in ``x`` (leading NaNs take the first valid value).

    Args:
        x: 1-D array of values.
        nan_arr: 0/1 mask marking NaN positions in ``x``.

    Returns:
        A copy of ``x`` with NaNs replaced by the last valid value.
    """
    x_impute = x.copy()
    first_value = 0
    while first_value < len(x) and nan_arr[first_value] == 1:
        first_value += 1
    # BUG FIX: if every entry is NaN the original indexed past the end of
    # the array (IndexError). Return the untouched copy instead; callers
    # (impute_lab) then fall back to the mean imputer.
    if first_value >= len(x):
        return x_impute
    last = x_impute[first_value]
    for i, measurement in enumerate(x):
        if nan_arr[i] == 1:
            x_impute[i] = last
        else:
            last = measurement
    return x_impute


def impute_lab(lab_data):
    """Impute missing lab values: forward-fill within each signal, then
    mean-impute fully-missing signals with a population-level imputer.

    Args:
        lab_data: Array of shape (patients, lab channels, time steps).

    Returns:
        Imputed array with the same shape.
    """
    imputer = SimpleImputer(strategy="mean")
    lab_data_impute = lab_data.copy()
    imputer.fit(lab_data.reshape((-1, lab_data.shape[1])))
    for i, patient in enumerate(lab_data):
        for j, signal in enumerate(patient):
            nan_arr, nan_count = check_nan(signal)
            if nan_count != len(signal):
                lab_data_impute[i, j, :] = forward_impute(signal, nan_arr)
    lab_data_impute = np.array(
        [imputer.transform(sample.T).T for sample in lab_data_impute]
    )
    return lab_data_impute


def ethnicity_encoder(eth, patient_data):
    """Encode a patient's ethnicity as an integer: 0 for the missing
    marker ``"0"``, otherwise 1 + index in ``eth_list``.

    Raises:
        ValueError: if the ethnicity is not in ``eth_list``.
    """
    return (
        0
        if eth == "0"
        else eth_list.index(patient_data["ethnicity"].iloc[0]) + 1
    )