Source code for sklift.datasets.datasets

import os
import shutil

import pandas as pd
import requests
from sklearn.utils import Bunch


def get_data_dir():
    """Return the path of the scikit-uplift data dir.

    This folder is used by the dataset loaders to avoid downloading
    the data several times.

    By default the data dir is a folder named 'scikit-uplift-data'
    in the user home folder.

    Returns:
        string: The path to scikit-uplift data dir.

    """
    return os.path.join(os.path.expanduser("~"), "scikit-uplift-data")
def _create_data_dir(path):
    """Create the directory which stores the datasets.

    Missing parent directories are created as well. Nothing happens if the
    directory already exists.

    Args:
        path (str): The path to scikit-uplift data dir.

    """
    # exist_ok avoids the check-then-create race of a separate isdir() test.
    os.makedirs(path, exist_ok=True)


def _download(url, dest_path):
    """Download the file from url and save it locally.

    The payload is streamed into a temporary ``<dest_path>.part`` file and
    moved to ``dest_path`` only after the transfer has finished, so a failed
    or interrupted download never leaves a truncated file at ``dest_path``.

    Args:
        url (str): URL address, must be a string.
        dest_path (str): Destination of the file.

    Raises:
        TypeError: If ``url`` is not a string.
        requests.HTTPError: If the server answers with an error status.

    """
    if not isinstance(url, str):
        raise TypeError("URL must be a string")

    tmp_path = os.fspath(dest_path) + ".part"
    try:
        # The context manager releases the connection even if writing fails.
        with requests.get(url, stream=True) as req:
            req.raise_for_status()
            with open(tmp_path, "wb") as fd:
                for chunk in req.iter_content(chunk_size=2 ** 20):
                    fd.write(chunk)
        os.replace(tmp_path, dest_path)
    finally:
        # After a successful replace the temporary file is already gone;
        # anything still present here is a leftover of a failed transfer.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _get_data(data_home, url, dest_subdir, dest_filename, download_if_missing):
    """Return the path to the dataset, downloading it first if necessary.

    Args:
        data_home (str): The path to scikit-uplift data dir.
            If None, the directory returned by ``get_data_dir()`` is used.
        url (str): The URL to the dataset.
        dest_subdir (str): The name of the folder in which the dataset is stored.
            If None, the dataset is stored directly in the data dir.
        dest_filename (str): The name of the dataset.
        download_if_missing (bool): If False, raise a IOError if the data is not locally available
            instead of trying to download the data from the source site.

    Returns:
        string: The path to the dataset.

    Raises:
        IOError: If the file is missing locally and ``download_if_missing`` is False.

    """
    base_dir = get_data_dir() if data_home is None else os.path.abspath(data_home)
    data_dir = base_dir if dest_subdir is None else os.path.join(base_dir, dest_subdir)

    _create_data_dir(data_dir)

    dest_path = os.path.join(data_dir, dest_filename)

    if not os.path.isfile(dest_path):
        if not download_if_missing:
            raise IOError("Dataset missing")
        _download(url, dest_path)

    return dest_path
def clear_data_dir(path=None):
    """Remove the data home cache directory together with everything in it.

    Args:
        path (str): The path to scikit-uplift data dir.
            If None, the directory returned by ``get_data_dir()`` is removed.

    """
    target = get_data_dir() if path is None else path

    # Nothing to do when the location is absent or is not a directory.
    if not os.path.isdir(target):
        return

    shutil.rmtree(target, ignore_errors=True)
def fetch_lenta(data_home=None, dest_subdir=None, download_if_missing=True, return_X_y_t=False):
    """Load and return the Lenta dataset (classification).

    An uplift modeling dataset containing data about Lenta's customers grocery shopping
    and related marketing campaigns.

    Major columns:

    - ``group`` (str): treatment/control group flag
    - ``response_att`` (binary): target
    - ``gender`` (str): customer gender
    - ``age`` (float): customer age
    - ``main_format`` (int): store type (1 - grocery store, 0 - superstore)

    Read more in the :ref:`docs <Lenta>`.

    Args:
        data_home (str): The path to the folder where datasets are stored.
        dest_subdir (str): The name of the folder in which the dataset is stored.
        download_if_missing (bool): Download the data if not present.
            Raises an IOError if False and data is missing.
        return_X_y_t (bool): If True, returns (data, target, treatment) instead of a Bunch object.

    Returns:
        Bunch or tuple: dataset.

        Bunch:
            By default dictionary-like object, with the following attributes:

                * ``data`` (DataFrame object): Dataset without target and treatment.
                * ``target`` (Series object): Column target by values.
                * ``treatment`` (Series object): Column treatment by values.
                * ``DESCR`` (str): Description of the Lenta dataset.
                * ``feature_names`` (list): Names of the features.
                * ``target_name`` (str): Name of the target.
                * ``treatment_name`` (str): Name of the treatment.

        Tuple:
            tuple (data, target, treatment) if `return_X_y_t` is True

    """
    target_col, treatment_col = 'response_att', 'group'

    url = 'https://winterschool123.s3.eu-north-1.amazonaws.com/lentadataset.csv.gz'
    csv_path = _get_data(
        data_home=data_home,
        url=url,
        dest_subdir=dest_subdir,
        dest_filename=url.rsplit('/', 1)[-1],
        download_if_missing=download_if_missing,
    )

    frame = pd.read_csv(csv_path)
    treatment = frame[treatment_col]
    target = frame[target_col]
    features = frame.drop([target_col, treatment_col], axis=1)

    if return_X_y_t:
        return features, target, treatment

    # The description ships next to this module in the ``descr`` folder.
    descr_path = os.path.join(os.path.dirname(__file__), 'descr', 'lenta.rst')
    with open(descr_path) as rst_file:
        fdescr = rst_file.read()

    return Bunch(
        data=features,
        target=target,
        treatment=treatment,
        DESCR=fdescr,
        feature_names=list(features.columns),
        target_name=target_col,
        treatment_name=treatment_col,
    )
def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True):
    """Load and return the X5 RetailHero dataset (classification).

    The dataset contains raw retail customer purchases, raw information about products
    and general info about customers.

    Major columns:

    - ``treatment_flg`` (binary): treatment/control group flag
    - ``target`` (binary): target
    - ``customer_id`` (str): customer id - primary key for joining

    Read more in the :ref:`docs <X5>`.

    Args:
        data_home (str, unicode): The path to the folder where datasets are stored.
        dest_subdir (str, unicode): The name of the folder in which the dataset is stored.
        download_if_missing (bool): Download the data if not present.
            Raises an IOError if False and data is missing

    Returns:
        Bunch: dataset.

        Dictionary-like object, with the following attributes.

            * ``data`` (Bunch object): dictionary-like object without target and treatment:

                * ``clients`` (DataFrame object): General info about clients.
                * ``train`` (DataFrame object): A subset of clients for training.
                * ``purchases`` (DataFrame object): clients' purchase history prior to communication.

            * ``target`` (Series object): Column target by values.
            * ``treatment`` (Series object): Column treatment by values.
            * ``DESCR`` (str): Description of the X5 dataset.
            * ``feature_names`` (Bunch object): Names of the features, with the attributes
              ``train_features``, ``clients_features`` and ``purchases_features``.
              ``train_features`` lists the columns of ``data.train``, i.e. without
              the target and treatment columns.
            * ``target_name`` (str): Name of the target.
            * ``treatment_name`` (str): Name of the treatment.

    References:
        https://ods.ai/competitions/x5-retailhero-uplift-modeling/data

    """
    target_col = 'target'
    treatment_col = 'treatment_flg'

    def _load_table(url):
        """Fetch (or reuse the cached copy of) one csv table and read it."""
        csv_path = _get_data(data_home=data_home,
                             url=url,
                             dest_subdir=dest_subdir,
                             dest_filename=url.split('/')[-1],
                             download_if_missing=download_if_missing)
        return pd.read_csv(csv_path)

    train = _load_table('https://timds.s3.eu-central-1.amazonaws.com/uplift_train.csv.gz')
    treatment, target = train[treatment_col], train[target_col]
    train = train.drop([target_col, treatment_col], axis=1)
    # Collected after the drop so the names match the columns of ``data.train``.
    train_features = list(train.columns)

    clients = _load_table('https://timds.s3.eu-central-1.amazonaws.com/clients.csv.gz')
    clients_features = list(clients.columns)

    purchases = _load_table('https://timds.s3.eu-central-1.amazonaws.com/purchases.csv.gz')
    purchases_features = list(purchases.columns)

    data = Bunch(clients=clients, train=train, purchases=purchases)
    feature_names = Bunch(train_features=train_features,
                          clients_features=clients_features,
                          purchases_features=purchases_features)

    module_path = os.path.dirname(__file__)
    with open(os.path.join(module_path, 'descr', 'x5.rst')) as rst_file:
        fdescr = rst_file.read()

    return Bunch(data=data, target=target, treatment=treatment, DESCR=fdescr,
                 feature_names=feature_names, target_name=target_col,
                 treatment_name=treatment_col)
def fetch_criteo(target_col='visit', treatment_col='treatment', data_home=None, dest_subdir=None,
                 download_if_missing=True, percent10=False, return_X_y_t=False):
    """Load and return the Criteo Uplift Prediction Dataset (classification).

    This dataset is constructed by assembling data resulting from several incrementality tests,
    a particular randomized trial procedure where a random part of the population is prevented
    from being targeted by advertising.

    Major columns:

    * ``treatment`` (binary): treatment
    * ``exposure`` (binary): treatment
    * ``visit`` (binary): target
    * ``conversion`` (binary): target
    * ``f0, ... , f11`` (float): feature values

    Read more in the :ref:`docs <Criteo>`.

    Args:
        target_col (string, 'visit', 'conversion' or 'all', default='visit'): Selects which column
            from dataset will be target. If 'all', return a DataFrame with all targets cols.
        treatment_col (string, 'treatment', 'exposure' or 'all', default='treatment'): Selects which
            column from dataset will be treatment. If 'all', return a DataFrame with all treatment cols.
        data_home (string): Specify a download and cache folder for the datasets.
        dest_subdir (string): The name of the folder in which the dataset is stored.
        download_if_missing (bool, default=True): If False, raise an IOError if the data is not
            locally available instead of trying to download the data from the source site.
        percent10 (bool, default=False): Whether to load only 10 percent of the data.
        return_X_y_t (bool, default=False): If True, returns (data, target, treatment) instead of
            a Bunch object.

    Returns:
        Bunch or tuple: dataset.

        Bunch:
            By default dictionary-like object, with the following attributes:

                * ``data`` (DataFrame object): Dataset without target and treatment.
                * ``target`` (Series or DataFrame object): Column target by values.
                * ``treatment`` (Series or DataFrame object): Column treatment by values.
                * ``DESCR`` (str): Description of the Criteo dataset.
                * ``feature_names`` (list): Names of the features.
                * ``target_name`` (str or list): Name of the target.
                * ``treatment_name`` (str or list): Name of the treatment.

        Tuple:
            tuple (data, target, treatment) if `return_X_y_t` is True

    Raises:
        ValueError: If ``treatment_col`` or ``target_col`` is not one of the accepted values.

    References:
        "A Large Scale Benchmark for Uplift Modeling"
        Eustache Diemert, Artem Betlei, Christophe Renaudin; (Criteo AI Lab),
        Massih-Reza Amini (LIG, Grenoble INP)

    """
    known_treatments = ['exposure', 'treatment']
    known_targets = ['visit', 'conversion']

    # 'all' expands to every known column; anything else must be a single known column.
    if treatment_col == 'all':
        treatment_col = known_treatments
    elif treatment_col not in known_treatments:
        raise ValueError(f"treatment_col value must be in {known_treatments + ['all']}. "
                         f"Got value {treatment_col}.")

    if target_col == 'all':
        target_col = known_targets
    elif target_col not in known_targets:
        raise ValueError(f"target_col value must be from {known_targets + ['all']}. "
                         f"Got value {target_col}.")

    archive = 'criteo10.csv.gz' if percent10 else 'criteo.csv.gz'
    url = 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/' + archive

    csv_path = _get_data(
        data_home=data_home,
        url=url,
        dest_subdir=dest_subdir,
        dest_filename=url.rsplit('/', 1)[-1],
        download_if_missing=download_if_missing,
    )

    # Every treatment and target column is read with the 'Int8' dtype.
    flag_dtypes = {column: 'Int8' for column in known_treatments + known_targets}
    frame = pd.read_csv(csv_path, dtype=flag_dtypes)

    treatment = frame[treatment_col]
    target = frame[target_col]
    features = frame.drop(known_targets + known_treatments, axis=1)

    if return_X_y_t:
        return features, target, treatment

    descr_path = os.path.join(os.path.dirname(__file__), 'descr', 'criteo.rst')
    with open(descr_path) as rst_file:
        fdescr = rst_file.read()

    return Bunch(
        data=features,
        target=target,
        treatment=treatment,
        DESCR=fdescr,
        feature_names=list(features.columns),
        target_name=target_col,
        treatment_name=treatment_col,
    )
def fetch_hillstrom(target_col='visit', data_home=None, dest_subdir=None, download_if_missing=True,
                    return_X_y_t=False):
    """Load and return Kevin Hillstrom Dataset MineThatData (classification or regression).

    This dataset contains 64,000 customers who last purchased within twelve months.
    The customers were involved in an e-mail test.

    Major columns:

    * ``visit`` (binary): target. 1/0 indicator, 1 = Customer visited website in the following two weeks.
    * ``conversion`` (binary): target. 1/0 indicator, 1 = Customer purchased merchandise in the following two weeks.
    * ``spend`` (float): target. Actual dollars spent in the following two weeks.
    * ``segment`` (str): treatment. The e-mail campaign the customer received

    Read more in the :ref:`docs <Hillstrom>`.

    Args:
        target_col (string, 'visit', 'conversion', 'spend' or 'all', default='visit'): Selects which
            column from dataset will be target. If 'all', all three target columns are returned.
        data_home (str): The path to the folder where datasets are stored.
        dest_subdir (str): The name of the folder in which the dataset is stored.
        download_if_missing (bool): Download the data if not present.
            Raises an IOError if False and data is missing.
        return_X_y_t (bool, default=False): If True, returns (data, target, treatment) instead of
            a Bunch object.

    Returns:
        Bunch or tuple: dataset.

        Bunch:
            By default dictionary-like object, with the following attributes:

                * ``data`` (DataFrame object): Dataset without target and treatment.
                * ``target`` (Series or DataFrame object): Column target by values.
                * ``treatment`` (Series object): Column treatment by values.
                * ``DESCR`` (str): Description of the Hillstrom dataset.
                * ``feature_names`` (list): Names of the features.
                * ``target_name`` (str or list): Name of the target.
                * ``treatment_name`` (str): Name of the treatment.

        Tuple:
            tuple (data, target, treatment) if `return_X_y_t` is True

    Raises:
        ValueError: If ``target_col`` is not one of the accepted values.

    References:
        https://blog.minethatdata.com/2008/03/minethatdata-e-mail-analytics-and-data.html

    """
    target_cols = ['visit', 'conversion', 'spend']
    if target_col == 'all':
        target_col = target_cols
    elif target_col not in target_cols:
        # Report the offending value itself; concatenating it with a list
        # would raise a TypeError and hide the intended ValueError.
        raise ValueError(f"target_col value must be from {target_cols + ['all']}. "
                         f"Got value {target_col}.")

    url = 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz'
    filename = url.split('/')[-1]
    csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
                         dest_filename=filename,
                         download_if_missing=download_if_missing)

    treatment_col = 'segment'

    data = pd.read_csv(csv_path)
    treatment, target = data[treatment_col], data[target_col]

    # All target columns are removed from the features, not only the selected one.
    data = data.drop(target_cols + [treatment_col], axis=1)

    if return_X_y_t:
        return data, target, treatment

    feature_names = list(data.columns)

    module_path = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(module_path, 'descr', 'hillstrom.rst')) as rst_file:
        fdescr = rst_file.read()

    return Bunch(data=data, target=target, treatment=treatment, DESCR=fdescr,
                 feature_names=feature_names, target_name=target_col,
                 treatment_name=treatment_col)